From c7eb7903898915c11cf5bb77324f2cf7dbd9252c Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 31 Jul 2023 10:25:09 -0700 Subject: [PATCH] buffer calls to discord --- discollama.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/discollama.py b/discollama.py index 7027515..50da1d3 100644 --- a/discollama.py +++ b/discollama.py @@ -25,7 +25,7 @@ async def generate_response(prompt, context=[], session=None): body = { key: value for key, value in { - 'model': ollama_model, + 'model': args.ollama_model, 'prompt': prompt, 'context': context, 'session_id': session, @@ -34,7 +34,7 @@ async def generate_response(prompt, context=[], session=None): async with aiohttp.ClientSession() as session: async with session.post( - f'http://{ollama_host}:{ollama_port}/api/generate', + f'http://{args.ollama_host}:{args.ollama_port}/api/generate', json=body) as r: async for line in r.content: yield json.loads(line) @@ -83,14 +83,22 @@ async def on_message(message): response = await message.channel.send(':thinking:', reference=message) # TODO: discord has a 2000 character limit, so we need to split the response + buffer = '' response_content = '' async for chunk in generate_response(raw_content, **load_session(message.reference)): if chunk['done']: + response_content += buffer save_session(response, chunk) break - response_content += chunk['response'] - await response.edit(content=response_content + '...') + buffer += chunk['response'] + + if len(buffer) >= args.buffer_size: + # buffer the edit so as to not call Discord API too often + response_content += buffer + await response.edit(content=response_content + '...') + + buffer = '' await response.edit(content=response_content) @@ -103,11 +111,9 @@ parser.add_argument('--ollama-model', default='llama2', type=str) default_redis = Path.home() / '.cache' / 'discollama' / 'brain.db' parser.add_argument('--redis', default=default_redis, type=Path) -args = parser.parse_args() +parser.add_argument('--buffer-size', default=64, type=int) -ollama_host = args.ollama_host -ollama_port = args.ollama_port -ollama_model = args.ollama_model +args = parser.parse_args() args.redis.parent.mkdir(parents=True, exist_ok=True) redis = Redis(args.redis)