formatting
parent
924017de97
commit
c564d05de5
175
discollama.py
175
discollama.py
|
@ -17,114 +17,125 @@ client = discord.Client(intents=intents)
|
||||||
|
|
||||||
@client.event
|
@client.event
|
||||||
async def on_ready():
|
async def on_ready():
|
||||||
logging.info(
|
logging.info(
|
||||||
'Ready! Invite URL: %s',
|
'Ready! Invite URL: %s',
|
||||||
discord.utils.oauth_url(
|
discord.utils.oauth_url(
|
||||||
client.application_id,
|
client.application_id,
|
||||||
permissions=discord.Permissions(read_messages=True, send_messages=True),
|
permissions=discord.Permissions(read_messages=True, send_messages=True),
|
||||||
scopes=['bot']))
|
scopes=['bot'],
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
async def generate_response(prompt, context=[]):
|
async def generate_response(prompt, context=[]):
|
||||||
body = {
|
body = {
|
||||||
key: value
|
key: value
|
||||||
for key, value in {
|
for key, value in {
|
||||||
'model': args.ollama_model,
|
'model': args.ollama_model,
|
||||||
'prompt': prompt,
|
'prompt': prompt,
|
||||||
'context': context,
|
'context': context,
|
||||||
}.items() if value
|
}.items() if value
|
||||||
}
|
}
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f'http://{args.ollama_host}:{args.ollama_port}/api/generate',
|
f'http://{args.ollama_host}:{args.ollama_port}/api/generate',
|
||||||
json=body) as r:
|
json=body) as r:
|
||||||
async for line in r.content:
|
async for line in r.content:
|
||||||
yield json.loads(line)
|
yield json.loads(line)
|
||||||
|
|
||||||
|
|
||||||
async def buffered_generate_response(prompt, context=[]):
|
async def buffered_generate_response(prompt, context=[]):
|
||||||
buffer = ''
|
buffer = ''
|
||||||
async for chunk in generate_response(prompt, context):
|
async for chunk in generate_response(prompt, context):
|
||||||
if chunk['done']:
|
if chunk['done']:
|
||||||
yield buffer, chunk
|
yield buffer, chunk
|
||||||
break
|
break
|
||||||
|
|
||||||
buffer += chunk['response']
|
buffer += chunk['response']
|
||||||
if len(buffer) >= args.buffer_size:
|
if len(buffer) >= args.buffer_size:
|
||||||
yield buffer, chunk
|
yield buffer, chunk
|
||||||
buffer = ''
|
buffer = ''
|
||||||
|
|
||||||
|
|
||||||
def save_session(response, chunk):
|
def save_session(response, chunk):
|
||||||
context = msgpack.packb(chunk['context'])
|
context = msgpack.packb(chunk['context'])
|
||||||
redis.hset(f'ollama:{response.id}', 'context', context)
|
redis.hset(f'ollama:{response.id}', 'context', context)
|
||||||
|
|
||||||
|
redis.expire(f'ollama:{response.id}', 60 * 60 * 24 * 7)
|
||||||
|
logging.info(
|
||||||
|
'saving message=%s: len(context)=%d',
|
||||||
|
response.id,
|
||||||
|
len(chunk['context']),
|
||||||
|
)
|
||||||
|
|
||||||
redis.expire(f'ollama:{response.id}', 60 * 60 * 24 * 7)
|
|
||||||
logging.info('saving message=%s: len(context)=%d', response.id, len(chunk['context']))
|
|
||||||
|
|
||||||
def load_session(reference):
|
def load_session(reference):
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if reference:
|
if reference:
|
||||||
context = redis.hget(f'ollama:{reference.message_id}', 'context')
|
context = redis.hget(f'ollama:{reference.message_id}', 'context')
|
||||||
kwargs['context'] = msgpack.unpackb(context) if context else []
|
kwargs['context'] = msgpack.unpackb(context) if context else []
|
||||||
|
|
||||||
if kwargs.get('context'):
|
if kwargs.get('context'):
|
||||||
logging.info(
|
logging.info(
|
||||||
'loading message=%s: len(context)=%d',
|
'loading message=%s: len(context)=%d',
|
||||||
reference.message_id,
|
reference.message_id,
|
||||||
len(kwargs['context']))
|
len(kwargs['context']),
|
||||||
|
)
|
||||||
|
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
@client.event
|
@client.event
|
||||||
async def on_message(message):
|
async def on_message(message):
|
||||||
if message.author == client.user:
|
if message.author == client.user:
|
||||||
return
|
return
|
||||||
|
|
||||||
if client.user.id in message.raw_mentions:
|
if client.user.id in message.raw_mentions:
|
||||||
raw_content = message.content.replace(f'<@{client.user.id}>', '').strip()
|
raw_content = message.content.replace(f'<@{client.user.id}>', '').strip()
|
||||||
if raw_content.strip() == '':
|
if raw_content.strip() == '':
|
||||||
raw_content = 'Tell me about yourself.'
|
raw_content = 'Tell me about yourself.'
|
||||||
|
|
||||||
response = None
|
response = None
|
||||||
response_content = ''
|
response_content = ''
|
||||||
async with message.channel.typing():
|
async with message.channel.typing():
|
||||||
await message.add_reaction('🤔')
|
await message.add_reaction('🤔')
|
||||||
|
|
||||||
context = []
|
context = []
|
||||||
if reference := message.reference:
|
if reference := message.reference:
|
||||||
if session := load_session(message.reference):
|
if session := load_session(message.reference):
|
||||||
context = session.get('context', [])
|
context = session.get('context', [])
|
||||||
|
|
||||||
reference_message = await message.channel.fetch_message(reference.message_id)
|
reference_message = await message.channel.fetch_message(
|
||||||
reference_content = reference_message.content
|
reference.message_id)
|
||||||
raw_content = '\n'.join([
|
reference_content = reference_message.content
|
||||||
raw_content,
|
raw_content = '\n'.join([
|
||||||
'Use it to answer the prompt:',
|
raw_content,
|
||||||
reference_content,
|
'Use it to answer the prompt:',
|
||||||
])
|
reference_content,
|
||||||
|
])
|
||||||
|
|
||||||
async for buffer, chunk in buffered_generate_response(raw_content, context=context):
|
async for buffer, chunk in buffered_generate_response(
|
||||||
response_content += buffer
|
raw_content,
|
||||||
if chunk['done']:
|
context=context,
|
||||||
save_session(response, chunk)
|
):
|
||||||
break
|
response_content += buffer
|
||||||
|
if chunk['done']:
|
||||||
|
save_session(response, chunk)
|
||||||
|
break
|
||||||
|
|
||||||
if not response:
|
if not response:
|
||||||
response = await message.reply(response_content)
|
response = await message.reply(response_content)
|
||||||
await message.remove_reaction('🤔', client.user)
|
await message.remove_reaction('🤔', client.user)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if len(response_content) + 3 >= 2000:
|
if len(response_content) + 3 >= 2000:
|
||||||
response = await response.reply(buffer)
|
response = await response.reply(buffer)
|
||||||
response_content = buffer
|
response_content = buffer
|
||||||
continue
|
continue
|
||||||
|
|
||||||
await response.edit(content=response_content + '...')
|
await response.edit(content=response_content + '...')
|
||||||
|
|
||||||
await response.edit(content=response_content)
|
await response.edit(content=response_content)
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
@ -142,9 +153,9 @@ args = parser.parse_args()
|
||||||
args.redis.parent.mkdir(parents=True, exist_ok=True)
|
args.redis.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
redis = Redis(args.redis)
|
redis = Redis(args.redis)
|
||||||
client.run(os.getenv('DISCORD_TOKEN'), root_logger=True)
|
client.run(os.getenv('DISCORD_TOKEN'), root_logger=True)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
redis.close()
|
redis.close()
|
||||||
|
|
Loading…
Reference in New Issue