refactor buffer generate response

pull/3/head
Michael Yang 2023-08-04 17:58:44 -07:00
parent 0d4f5792de
commit f32a4b09b3
1 changed files with 21 additions and 16 deletions

View File

@ -43,6 +43,19 @@ async def generate_response(prompt, context=[]):
yield json.loads(line) yield json.loads(line)
async def buffered_generate_response(prompt, context=[]):
buffer = ''
async for chunk in generate_response(prompt, context):
if chunk['done']:
yield buffer, chunk
break
buffer += chunk['response']
if len(buffer) >= args.buffer_size:
yield buffer, chunk
buffer = ''
def save_session(response, chunk): def save_session(response, chunk):
context = msgpack.packb(chunk['context']) context = msgpack.packb(chunk['context'])
redis.hset(f'ollama:{response.id}', 'context', context) redis.hset(f'ollama:{response.id}', 'context', context)
@ -77,29 +90,21 @@ async def on_message(message):
# TODO: discord has a 2000 character limit, so we need to split the response # TODO: discord has a 2000 character limit, so we need to split the response
response = None response = None
buffer = ''
response_content = '' response_content = ''
async with message.channel.typing(): async with message.channel.typing():
await message.add_reaction('🤔') await message.add_reaction('🤔')
async for chunk in generate_response(raw_content, **load_session(message.reference)):
async for buffer, chunk in buffered_generate_response(raw_content, **load_session(message.reference)):
response_content += buffer
if chunk['done']: if chunk['done']:
response_content += buffer
save_session(response, chunk) save_session(response, chunk)
break break
buffer += chunk['response'] if response:
await response.edit(content=response_content + '...')
if len(buffer) >= args.buffer_size: else:
# buffer the edit so as to not call Discord API too often response = await message.reply(response_content)
response_content += buffer await message.remove_reaction('🤔', client.user)
if response:
await response.edit(content=response_content + '...')
else:
response = await message.reply(response_content)
await message.remove_reaction('🤔', client.user)
buffer = ''
await response.edit(content=response_content) await response.edit(content=response_content)