diff --git a/discollama.py b/discollama.py index b8a52c5..33f0e80 100644 --- a/discollama.py +++ b/discollama.py @@ -43,6 +43,19 @@ async def generate_response(prompt, context=[]): yield json.loads(line) +async def buffered_generate_response(prompt, context=[]): + buffer = '' + async for chunk in generate_response(prompt, context): + if chunk['done']: + yield buffer, chunk + break + + buffer += chunk['response'] + if len(buffer) >= args.buffer_size: + yield buffer, chunk + buffer = '' + + def save_session(response, chunk): context = msgpack.packb(chunk['context']) redis.hset(f'ollama:{response.id}', 'context', context) @@ -77,29 +90,21 @@ async def on_message(message): # TODO: discord has a 2000 character limit, so we need to split the response response = None - buffer = '' response_content = '' async with message.channel.typing(): await message.add_reaction('🤔') - async for chunk in generate_response(raw_content, **load_session(message.reference)): + + async for buffer, chunk in buffered_generate_response(raw_content, **load_session(message.reference)): + response_content += buffer if chunk['done']: - response_content += buffer save_session(response, chunk) break - buffer += chunk['response'] - - if len(buffer) >= args.buffer_size: - # buffer the edit so as to not call Discord API too often - response_content += buffer - - if response: - await response.edit(content=response_content + '...') - else: - response = await message.reply(response_content) - await message.remove_reaction('🤔', client.user) - - buffer = '' + if response: + await response.edit(content=response_content + '...') + else: + response = await message.reply(response_content) + await message.remove_reaction('🤔', client.user) await response.edit(content=response_content)