formatting

pull/3/head
Michael Yang 2023-09-20 20:15:09 -07:00
parent 924017de97
commit c564d05de5
1 changed files with 93 additions and 82 deletions

View File

@ -17,114 +17,125 @@ client = discord.Client(intents=intents)
@client.event @client.event
async def on_ready(): async def on_ready():
logging.info( logging.info(
'Ready! Invite URL: %s', 'Ready! Invite URL: %s',
discord.utils.oauth_url( discord.utils.oauth_url(
client.application_id, client.application_id,
permissions=discord.Permissions(read_messages=True, send_messages=True), permissions=discord.Permissions(read_messages=True, send_messages=True),
scopes=['bot'])) scopes=['bot'],
))
async def generate_response(prompt, context=[]): async def generate_response(prompt, context=[]):
body = { body = {
key: value key: value
for key, value in { for key, value in {
'model': args.ollama_model, 'model': args.ollama_model,
'prompt': prompt, 'prompt': prompt,
'context': context, 'context': context,
}.items() if value }.items() if value
} }
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post( async with session.post(
f'http://{args.ollama_host}:{args.ollama_port}/api/generate', f'http://{args.ollama_host}:{args.ollama_port}/api/generate',
json=body) as r: json=body) as r:
async for line in r.content: async for line in r.content:
yield json.loads(line) yield json.loads(line)
async def buffered_generate_response(prompt, context=[]): async def buffered_generate_response(prompt, context=[]):
buffer = '' buffer = ''
async for chunk in generate_response(prompt, context): async for chunk in generate_response(prompt, context):
if chunk['done']: if chunk['done']:
yield buffer, chunk yield buffer, chunk
break break
buffer += chunk['response'] buffer += chunk['response']
if len(buffer) >= args.buffer_size: if len(buffer) >= args.buffer_size:
yield buffer, chunk yield buffer, chunk
buffer = '' buffer = ''
def save_session(response, chunk): def save_session(response, chunk):
context = msgpack.packb(chunk['context']) context = msgpack.packb(chunk['context'])
redis.hset(f'ollama:{response.id}', 'context', context) redis.hset(f'ollama:{response.id}', 'context', context)
redis.expire(f'ollama:{response.id}', 60 * 60 * 24 * 7)
logging.info(
'saving message=%s: len(context)=%d',
response.id,
len(chunk['context']),
)
redis.expire(f'ollama:{response.id}', 60 * 60 * 24 * 7)
logging.info('saving message=%s: len(context)=%d', response.id, len(chunk['context']))
def load_session(reference): def load_session(reference):
kwargs = {} kwargs = {}
if reference: if reference:
context = redis.hget(f'ollama:{reference.message_id}', 'context') context = redis.hget(f'ollama:{reference.message_id}', 'context')
kwargs['context'] = msgpack.unpackb(context) if context else [] kwargs['context'] = msgpack.unpackb(context) if context else []
if kwargs.get('context'): if kwargs.get('context'):
logging.info( logging.info(
'loading message=%s: len(context)=%d', 'loading message=%s: len(context)=%d',
reference.message_id, reference.message_id,
len(kwargs['context'])) len(kwargs['context']),
)
return kwargs return kwargs
@client.event @client.event
async def on_message(message): async def on_message(message):
if message.author == client.user: if message.author == client.user:
return return
if client.user.id in message.raw_mentions: if client.user.id in message.raw_mentions:
raw_content = message.content.replace(f'<@{client.user.id}>', '').strip() raw_content = message.content.replace(f'<@{client.user.id}>', '').strip()
if raw_content.strip() == '': if raw_content.strip() == '':
raw_content = 'Tell me about yourself.' raw_content = 'Tell me about yourself.'
response = None response = None
response_content = '' response_content = ''
async with message.channel.typing(): async with message.channel.typing():
await message.add_reaction('🤔') await message.add_reaction('🤔')
context = [] context = []
if reference := message.reference: if reference := message.reference:
if session := load_session(message.reference): if session := load_session(message.reference):
context = session.get('context', []) context = session.get('context', [])
reference_message = await message.channel.fetch_message(reference.message_id) reference_message = await message.channel.fetch_message(
reference_content = reference_message.content reference.message_id)
raw_content = '\n'.join([ reference_content = reference_message.content
raw_content, raw_content = '\n'.join([
'Use it to answer the prompt:', raw_content,
reference_content, 'Use it to answer the prompt:',
]) reference_content,
])
async for buffer, chunk in buffered_generate_response(raw_content, context=context): async for buffer, chunk in buffered_generate_response(
response_content += buffer raw_content,
if chunk['done']: context=context,
save_session(response, chunk) ):
break response_content += buffer
if chunk['done']:
save_session(response, chunk)
break
if not response: if not response:
response = await message.reply(response_content) response = await message.reply(response_content)
await message.remove_reaction('🤔', client.user) await message.remove_reaction('🤔', client.user)
continue continue
if len(response_content) + 3 >= 2000: if len(response_content) + 3 >= 2000:
response = await response.reply(buffer) response = await response.reply(buffer)
response_content = buffer response_content = buffer
continue continue
await response.edit(content=response_content + '...') await response.edit(content=response_content + '...')
await response.edit(content=response_content) await response.edit(content=response_content)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -142,9 +153,9 @@ args = parser.parse_args()
args.redis.parent.mkdir(parents=True, exist_ok=True) args.redis.parent.mkdir(parents=True, exist_ok=True)
try: try:
redis = Redis(args.redis) redis = Redis(args.redis)
client.run(os.getenv('DISCORD_TOKEN'), root_logger=True) client.run(os.getenv('DISCORD_TOKEN'), root_logger=True)
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
redis.close() redis.close()