discollama/discollama.py

205 lines
6.3 KiB
Python
Raw Normal View History

2023-12-24 19:31:17 -08:00
import io
2023-07-28 20:00:40 -07:00
import os
import json
2023-12-24 19:31:17 -08:00
import asyncio
2023-07-28 20:00:40 -07:00
import argparse
2023-12-24 19:31:17 -08:00
from datetime import datetime, timedelta
2023-12-24 21:09:20 -08:00
from base64 import b64encode, b64decode
2023-07-28 20:00:40 -07:00
2023-12-24 19:31:17 -08:00
import ollama
import discord
import redis
2023-07-28 20:00:40 -07:00
2023-12-24 19:31:17 -08:00
from logging import getLogger
2023-07-28 20:00:40 -07:00
2023-12-24 19:31:17 -08:00
# Use a child of discord.py's 'discord' logger so our records propagate into
# whatever logging discord.py has configured.
logging = getLogger(name='discord.discollama')
2023-07-28 20:00:40 -07:00
2023-12-24 19:31:17 -08:00
class Response:
    """Stream text into one or more Discord messages.

    Text accumulates in an in-memory buffer that is mirrored into a single
    Discord message, which is edited as more text arrives. When the buffer
    would grow past ``MAX_LENGTH`` characters, a fresh message is started.
    If the triggering message is in a plain text channel, the first message
    is sent into a new thread created on that message.
    """

    # Discord rejects message content longer than 2000 characters.
    MAX_LENGTH = 2000

    def __init__(self, message):
        self.message = message
        self.channel = message.channel

        # Discord message currently being edited; None until the first send
        # and after each rollover to a new message.
        self.r = None
        # Text currently shown in ``self.r``.
        self.sb = io.StringIO()

    def _size(self):
        """Return the buffered length, leaving the cursor at the end for appends."""
        return self.sb.seek(0, io.SEEK_END)

    def _start_new_message(self):
        """Forget the current message and clear the buffer."""
        self.r = None
        self.sb.seek(0, io.SEEK_SET)
        self.sb.truncate()

    async def _flush(self, end):
        """Publish the buffer by editing the current message or sending a new one.

        ``end`` is appended only when editing; a newly sent message carries the
        buffer text alone.
        """
        if self.r:
            await self.r.edit(content=self.sb.getvalue() + end)
            return

        if self.channel.type == discord.ChannelType.text:
            self.channel = await self.channel.create_thread(name='Discollama Says', message=self.message, auto_archive_duration=60)

        self.r = await self.channel.send(self.sb.getvalue())

    async def write(self, s, end=''):
        """Append ``s`` to the reply and update Discord.

        Args:
            s: text to append.
            end: suffix shown after the text when editing (e.g. '...' while
                streaming); it is not kept in the buffer, so a later call with
                ``end=''`` removes it.
        """
        if self._size() + len(s) + len(end) > self.MAX_LENGTH:
            self._start_new_message()

        # A single chunk can be longer than an entire message. Emit full-size
        # pieces as standalone messages so no send/edit exceeds the limit.
        # (Reaching this loop implies the buffer was just cleared above.)
        room = max(self.MAX_LENGTH - len(end), 1)
        while len(s) > room:
            self.sb.write(s[:room])
            await self._flush('')
            self._start_new_message()
            s = s[room:]

        self.sb.write(s)

        if self._size() == 0:
            return

        await self._flush(end)
2023-08-04 17:58:44 -07:00
2023-09-20 20:15:09 -07:00
2023-12-24 19:31:17 -08:00
class Discollama:
    """Discord bot that answers mentions using Ollama models.

    Per-exchange state (the model's ``context`` list and any images) is saved
    in Redis keyed by the triggering message id, plus a per-channel pointer to
    the latest message id, so replies and follow-up mentions continue a
    conversation.
    """

    # Expiry, in seconds, for every key written to Redis (one week).
    TTL = 60 * 60 * 24 * 7

    def __init__(self, ollama, discord, redis, models):
        self.ollama = ollama
        self.discord = discord
        self.redis = redis

        # '' -> default text model, 'images' -> model used when images are attached.
        self.models = models

        # register setup hook (pulls the configured models)
        self.discord.setup_hook = self.setup_hook

        # register event handlers
        self.discord.event(self.on_ready)
        self.discord.event(self.on_message)

    async def on_ready(self):
        """Set the bot's presence and log an invite URL with the needed permissions."""
        activity = discord.Activity(name='Discollama', state='Ask me anything!', type=discord.ActivityType.custom)
        await self.discord.change_presence(activity=activity)

        logging.info(
            'Ready! Invite URL: %s',
            discord.utils.oauth_url(
                self.discord.application_id,
                permissions=discord.Permissions(
                    read_messages=True,
                    send_messages=True,
                    create_public_threads=True,
                ),
                scopes=['bot'],
            ),
        )

    def _clean_content(self, content):
        """Remove mentions of this bot from ``content``; fall back to 'Hi!' if nothing is left."""
        user_id = self.discord.user.id
        # Mentions may appear as <@id> or the legacy nickname form <@!id>.
        for mention in (f'<@{user_id}>', f'<@!{user_id}>'):
            content = content.replace(mention, '')
        return content.strip() or 'Hi!'

    @staticmethod
    def _is_image(attachment):
        """Return True if the attachment declares an image MIME type (content_type may be None)."""
        return (attachment.content_type or '').startswith('image/')

    async def on_message(self, message):
        """Answer a message that mentions the bot, streaming the reply to Discord.

        Context comes from the replied-to message when the bot saved context
        for it; otherwise the replied-to message's text is appended to the
        prompt. If there is still no context, the last exchange saved for this
        channel is used. Image attachments are passed to the model.
        """
        if self.discord.user == message.author:
            # don't respond to ourselves
            return

        if not self.discord.user.mentioned_in(message):
            # don't respond to messages that don't mention us
            return

        content = self._clean_content(message.content)
        channel = message.channel

        context, images = [], []
        if reference := message.reference:
            context, images = await self.load(message_id=reference.message_id)
            if not context:
                reference_message = await message.channel.fetch_message(reference.message_id)
                content = '\n'.join(
                    [
                        content,
                        'Use this to answer the question if it is relevant, otherwise ignore it:',
                        reference_message.content,
                    ]
                )

        if not context:
            context, images = await self.load(channel_id=channel.id)

        images.extend([await attachment.read() for attachment in message.attachments if self._is_image(attachment)])

        r = Response(message)
        task = asyncio.create_task(self.thinking(message))
        try:
            async for part in self.generate(content, context, images=images):
                # output has started: stop the "thinking" indicator
                task.cancel()
                await r.write(part['response'], end='...')
        finally:
            # also stop the indicator if generation or sending fails
            task.cancel()

        # a final write with end='' replaces the trailing '...' shown while streaming
        await r.write('')
        await self.save(r.channel.id, message.id, part['context'], images)

    async def thinking(self, message, timeout=999):
        """Show a 🤔 reaction and typing indicator until cancelled or ``timeout`` seconds pass.

        Best effort: failures (e.g. missing permissions) are logged, not raised.
        """
        try:
            await message.add_reaction('🤔')
            async with message.channel.typing():
                await asyncio.sleep(timeout)
        except Exception:
            logging.debug('could not show thinking indicator', exc_info=True)
        finally:
            try:
                await message.remove_reaction('🤔', self.discord.user)
            except Exception:
                logging.debug('could not remove thinking reaction', exc_info=True)

    async def generate(self, content, context, images=None):
        """Stream a completion from Ollama, yielding batched parts.

        Each yielded part is an Ollama stream part whose ``response`` holds all
        text accumulated since the previous yield. A part is yielded when more
        than one second has passed since the last yield or when the stream is
        done, so the final (``done``) part is always yielded. The 'images'
        model is used when ``images`` is non-empty.
        """
        model = self.models['images' if images else '']

        sb = io.StringIO()

        t = datetime.now()
        async for part in await self.ollama.generate(model=model, prompt=content, context=context, images=images, stream=True):
            sb.write(part['response'])

            # throttle: presumably to limit how often Discord messages get edited
            if part['done'] or datetime.now() - t > timedelta(seconds=1):
                part['response'] = sb.getvalue()
                yield part
                t = datetime.now()
                sb.seek(0, io.SEEK_SET)
                sb.truncate()

    async def save(self, channel_id, message_id, context, images):
        """Persist an exchange: channel -> latest message id, and that message's context and images."""
        self.redis.set(f'discollama:channel:{channel_id}', message_id, ex=self.TTL)
        self.redis.set(f'discollama:message:{message_id}', json.dumps(context), ex=self.TTL)

        # images are raw bytes; store them as a JSON list of base64 strings
        images = [b64encode(image).decode('utf-8') for image in images]
        self.redis.set(f'discollama:images:{message_id}', json.dumps(images), ex=self.TTL)

    async def load(self, channel_id=None, message_id=None):
        """Return ``(context, images)`` saved for a message, or for a channel's latest message.

        ``channel_id`` takes precedence over ``message_id``. Returns ``([], [])``
        when nothing is stored.
        """
        if channel_id:
            message_id = self.redis.get(f'discollama:channel:{channel_id}')

        if not message_id:
            # nothing saved for this channel (or no id given)
            return [], []

        context = self.redis.get(f'discollama:message:{message_id}')
        images = self.redis.get(f'discollama:images:{message_id}')
        return (
            json.loads(context) if context else [],
            [b64decode(image) for image in json.loads(images)] if images else [],
        )

    def run(self, token):
        """Run the Discord client with ``token``.

        Redis is closed however the client exits. Errors (e.g. a rejected
        token) propagate to the caller instead of being silently swallowed.
        """
        try:
            self.discord.run(token)
        finally:
            self.redis.close()

    async def setup_hook(self):
        """Pull every configured Ollama model, logging progress."""
        for key, value in self.models.items():
            logging.info('Downloading %s model %s...', key, value)
            await self.ollama.pull(value)
            logging.info('Downloading %s model %s... done', key, value)
2023-12-24 19:31:17 -08:00
def main():
    """Parse options (each defaulting to an environment variable) and run the bot.

    The Discord token is read from the ``DISCORD_TOKEN`` environment variable.
    If it is missing or empty, the program exits with a usage error before
    any client is constructed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--ollama-scheme', default=os.getenv('OLLAMA_SCHEME', 'http'), choices=['http', 'https'])
    parser.add_argument('--ollama-host', default=os.getenv('OLLAMA_HOST', '127.0.0.1'), type=str)
    parser.add_argument('--ollama-port', default=os.getenv('OLLAMA_PORT', 11434), type=int)

    parser.add_argument('--ollama-model', default=os.getenv('OLLAMA_MODEL', 'llama2'), type=str)
    parser.add_argument('--ollama-images-model', default=os.getenv('OLLAMA_IMAGES_MODEL', 'llava'), type=str)

    parser.add_argument('--redis-host', default=os.getenv('REDIS_HOST', '127.0.0.1'), type=str)
    parser.add_argument('--redis-port', default=os.getenv('REDIS_PORT', 6379), type=int)
    args = parser.parse_args()

    # Fail fast with a readable message instead of a KeyError traceback.
    token = os.getenv('DISCORD_TOKEN')
    if not token:
        parser.error('the DISCORD_TOKEN environment variable must be set')

    Discollama(
        ollama.AsyncClient(base_url=f'{args.ollama_scheme}://{args.ollama_host}:{args.ollama_port}'),
        discord.Client(intents=discord.Intents.default()),
        redis.Redis(host=args.redis_host, port=args.redis_port, db=0, decode_responses=True),
        {
            '': args.ollama_model,
            'images': args.ollama_images_model,
        },
    ).run(token)
2023-09-20 20:15:09 -07:00
2023-12-24 19:31:17 -08:00
# Start the bot only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()