refactor buffer generate response

pull/3/head
Michael Yang 2023-08-04 17:58:44 -07:00
parent 0d4f5792de
commit f32a4b09b3
1 changed files with 21 additions and 16 deletions

View File

@ -43,6 +43,19 @@ async def generate_response(prompt, context=[]):
yield json.loads(line) yield json.loads(line)
async def buffered_generate_response(prompt, context=[]):
buffer = ''
async for chunk in generate_response(prompt, context):
if chunk['done']:
yield buffer, chunk
break
buffer += chunk['response']
if len(buffer) >= args.buffer_size:
yield buffer, chunk
buffer = ''
def save_session(response, chunk): def save_session(response, chunk):
context = msgpack.packb(chunk['context']) context = msgpack.packb(chunk['context'])
redis.hset(f'ollama:{response.id}', 'context', context) redis.hset(f'ollama:{response.id}', 'context', context)
@ -77,30 +90,22 @@ async def on_message(message):
# TODO: discord has a 2000 character limit, so we need to split the response # TODO: discord has a 2000 character limit, so we need to split the response
response = None response = None
buffer = ''
response_content = '' response_content = ''
async with message.channel.typing(): async with message.channel.typing():
await message.add_reaction('🤔') await message.add_reaction('🤔')
async for chunk in generate_response(raw_content, **load_session(message.reference)):
if chunk['done']: async for buffer, chunk in buffered_generate_response(raw_content, **load_session(message.reference)):
response_content += buffer response_content += buffer
if chunk['done']:
save_session(response, chunk) save_session(response, chunk)
break break
buffer += chunk['response']
if len(buffer) >= args.buffer_size:
# buffer the edit so as to not call Discord API too often
response_content += buffer
if response: if response:
await response.edit(content=response_content + '...') await response.edit(content=response_content + '...')
else: else:
response = await message.reply(response_content) response = await message.reply(response_content)
await message.remove_reaction('🤔', client.user) await message.remove_reaction('🤔', client.user)
buffer = ''
await response.edit(content=response_content) await response.edit(content=response_content)