# discollama/discollama.py

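"""Discord bot that answers @mentions by streaming completions from a local
Ollama server, keeping per-conversation context in Redis (RedisJSON).

Example (token elided):
  DISCORD_TOKEN=... python discollama.py --ollama-model llama2
"""
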
import os
import json
import aiohttp
import discord
import argparse
from redis import Redis
import logging

intents = discord.Intents.default()
intents.message_content = True

client = discord.Client(intents=intents)


@client.event
async def on_ready():
  logging.info(
    'Ready! Invite URL: %s',
    discord.utils.oauth_url(
      client.application_id,
      permissions=discord.Permissions(read_messages=True, send_messages=True),
      scopes=['bot'],
    ))

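# Stream a completion from the Ollama /api/generate endpoint, yielding each
# JSON-decoded chunk of the streaming response as it arrives.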
async def generate_response(prompt, context=[]):
  body = {
    key: value
    for key, value in {
      'model': args.ollama_model,
      'prompt': prompt,
      'context': context,
    }.items() if value
  }

  async with aiohttp.ClientSession() as session:
    async with session.post(
        f'http://{args.ollama_host}:{args.ollama_port}/api/generate',
        json=body) as r:
      async for line in r.content:
        yield json.loads(line)

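# Group streamed tokens into chunks of at least --buffer-size characters so
# the Discord reply is not edited on every token; the final partial chunk is
# yielded together with the 'done' part.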
async def buffered_generate_response(prompt, context=[]):
  buffer = ''
  async for part in generate_response(prompt, context):
    if part['done']:
      yield buffer, part
      break

    buffer += part['response']
    if len(buffer) >= args.buffer_size:
      yield buffer, part
      buffer = ''

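# Persist the Ollama context under the reply's message id with a seven-day
# expiry so a follow-up reply can continue the conversation.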
def save_session(response, part):
  context = part.get('context', [])
  redis.json().set(f'ollama:{response.id}', '$', {'context': context})
  redis.expire(f'ollama:{response.id}', 60 * 60 * 24 * 7)
  logging.info('saving message=%s: len(context)=%d', response.id, len(context))

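# Look up a previously saved context for the message being replied to, if any.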
def load_session(reference):
  kwargs = {}
  if reference:
    context = redis.json().get(f'ollama:{reference.message_id}', '.context')
    kwargs['context'] = context or []

  if kwargs.get('context'):
    logging.info(
      'loading message=%s: len(context)=%d',
      reference.message_id,
      len(kwargs['context']),
    )

  return kwargs

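# Respond whenever the bot is mentioned: stream the answer into a reply that
# is edited in place and continued in a new message near Discord's
# 2000-character limit.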
@client.event
async def on_message(message):
  if message.author == client.user:
    return

  if client.user.id in message.raw_mentions:
    raw_content = message.content.replace(f'<@{client.user.id}>', '').strip()
    if raw_content.strip() == '':
      raw_content = 'Tell me about yourself.'

    response = None
    response_content = ''
    async with message.channel.typing():
      await message.add_reaction('🤔')

      context = []
      if reference := message.reference:
        if session := load_session(message.reference):
          context = session.get('context', [])
        else:
          reference_message = await message.channel.fetch_message(
            reference.message_id)
          reference_content = reference_message.content
          raw_content = '\n'.join([
            raw_content,
            'Use it to answer the prompt:',
            reference_content,
          ])

      async for buffer, part in buffered_generate_response(
          raw_content,
          context=context,
      ):
        response_content += buffer

        if part['done']:
          if not response:
            # The whole answer arrived in a single chunk; send the reply now
            # so save_session has a message id to key the context on.
            response = await message.reply(response_content)
            await message.remove_reaction('🤔', client.user)
          save_session(response, part)
          break

        if not response:
          response = await message.reply(response_content)
          await message.remove_reaction('🤔', client.user)
          continue

        if len(response_content) + 3 >= 2000:
          # Discord messages are capped at 2000 characters; continue in a new
          # reply (the +3 accounts for the trailing '...').
          response = await response.reply(buffer)
          response_content = buffer
          continue

        await response.edit(content=response_content + '...')

      await response.edit(content=response_content)


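# Defaults come from the environment; command-line flags override them.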
default_ollama_host = os.getenv('OLLAMA_HOST', '127.0.0.1')
default_ollama_port = os.getenv('OLLAMA_PORT', 11434)
default_ollama_model = os.getenv('OLLAMA_MODEL', 'llama2')

parser = argparse.ArgumentParser()
parser.add_argument('--ollama-host', default=default_ollama_host)
parser.add_argument('--ollama-port', default=default_ollama_port, type=int)
parser.add_argument('--ollama-model', default=default_ollama_model, type=str)

parser.add_argument('--redis-host', default='localhost')
parser.add_argument('--redis-port', default=6379, type=int)

parser.add_argument('--buffer-size', default=32, type=int)

args = parser.parse_args()

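# Run the bot until interrupted; the token is read from DISCORD_TOKEN.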
try:
  redis = Redis(host=args.redis_host, port=args.redis_port)
  client.run(os.getenv('DISCORD_TOKEN'), root_logger=True)
except KeyboardInterrupt:
  pass

redis.close()