buffer calls to discord

pull/3/head
Michael Yang 2023-07-31 10:25:09 -07:00
parent 2c32d1d90c
commit c7eb790389
1 changed file with 14 additions and 8 deletions


@@ -25,7 +25,7 @@ async def generate_response(prompt, context=[], session=None):
   body = {
     key: value
     for key, value in {
-      'model': ollama_model,
+      'model': args.ollama_model,
       'prompt': prompt,
       'context': context,
       'session_id': session,
@@ -34,7 +34,7 @@ async def generate_response(prompt, context=[], session=None):
   async with aiohttp.ClientSession() as session:
     async with session.post(
-        f'http://{ollama_host}:{ollama_port}/api/generate',
+        f'http://{args.ollama_host}:{args.ollama_port}/api/generate',
         json=body) as r:
       async for line in r.content:
         yield json.loads(line)
@@ -83,15 +83,23 @@ async def on_message(message):
   response = await message.channel.send(':thinking:', reference=message)

   # TODO: discord has a 2000 character limit, so we need to split the response
+  buffer = ''
   response_content = ''
   async for chunk in generate_response(raw_content, **load_session(message.reference)):
     if chunk['done']:
+      response_content += buffer
       save_session(response, chunk)
       break

-    response_content += chunk['response']
-    await response.edit(content=response_content + '...')
+    buffer += chunk['response']
+    if len(buffer) >= args.buffer_size:
+      # buffer the edit so as to not call Discord API too often
+      response_content += buffer
+      await response.edit(content=response_content + '...')
+      buffer = ''

   await response.edit(content=response_content)
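The hunk above is the heart of the change: instead of editing the Discord message after every streamed chunk, the reply is accumulated in buffer and only flushed through response.edit once it reaches args.buffer_size characters, plus one final edit when the stream reports done. Below is a minimal, self-contained sketch of that pattern; FakeMessage and fake_generate_response are stand-ins invented here for illustration, since the real discord.Message object and the streaming Ollama call are outside this diff.

import asyncio

BUFFER_SIZE = 64  # stands in for args.buffer_size (the new flag defaults to 64)


async def fake_generate_response(text, chunk_size=8):
  # Stand-in for generate_response(): yields the reply a few characters at a
  # time, the way the streaming /api/generate endpoint does.
  for i in range(0, len(text), chunk_size):
    yield {'response': text[i:i + chunk_size], 'done': False}
  yield {'response': '', 'done': True}


class FakeMessage:
  # Stand-in for the discord.Message returned by channel.send(); counts edits.
  def __init__(self):
    self.content = ''
    self.edits = 0

  async def edit(self, content):
    self.content = content
    self.edits += 1


async def stream_reply(response, prompt):
  buffer = ''
  response_content = ''
  async for chunk in fake_generate_response(prompt):
    if chunk['done']:
      response_content += buffer
      break

    buffer += chunk['response']
    if len(buffer) >= BUFFER_SIZE:
      # flush the buffer so the (real) Discord API is called less often
      response_content += buffer
      await response.edit(content=response_content + '...')
      buffer = ''

  await response.edit(content=response_content)
  return response


if __name__ == '__main__':
  message = asyncio.run(stream_reply(FakeMessage(), 'lorem ipsum dolor sit amet ' * 20))
  print(f'{message.edits} edit calls for {len(message.content)} characters')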
@@ -103,11 +111,9 @@ parser.add_argument('--ollama-model', default='llama2', type=str)
 default_redis = Path.home() / '.cache' / 'discollama' / 'brain.db'
 parser.add_argument('--redis', default=default_redis, type=Path)
+parser.add_argument('--buffer-size', default=64, type=int)
 args = parser.parse_args()
-ollama_host = args.ollama_host
-ollama_port = args.ollama_port
-ollama_model = args.ollama_model

 args.redis.parent.mkdir(parents=True, exist_ok=True)
 redis = Redis(args.redis)
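The last hunk wires the threshold up as a --buffer-size argument (default 64) and drops the ollama_host/ollama_port/ollama_model aliases in favour of reading args.* directly, which is why the first two hunks switch to args.ollama_model and args.ollama_host/args.ollama_port. Passing a larger value such as --buffer-size 128 trades a slightly less "live" reply for even fewer edit calls to the Discord API.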