diff --git a/discollama.py b/discollama.py index 076e3f0..5124bf1 100644 --- a/discollama.py +++ b/discollama.py @@ -4,6 +4,7 @@ import json import asyncio import argparse from datetime import datetime, timedelta +from base64 import b64encode, b64decode import ollama import discord @@ -31,6 +32,9 @@ class Response: self.sb.write(s) + if self.sb.seek(0, io.SEEK_END) == 0: + return + if self.r: await self.r.edit(content=self.sb.getvalue() + end) return @@ -42,11 +46,16 @@ class Response: class Discollama: - def __init__(self, ollama, discord, redis): + def __init__(self, ollama, discord, redis, models): self.ollama = ollama self.discord = discord self.redis = redis + self.models = models + + # registry setup hook + self.discord.setup_hook = self.setup_hook + # register event handlers self.discord.event(self.on_ready) self.discord.event(self.on_message) @@ -83,9 +92,9 @@ class Discollama: channel = message.channel - context = [] + context, images = [], [] if reference := message.reference: - context = await self.load(message_id=reference.message_id) + context, images = await self.load(message_id=reference.message_id) if not context: reference_message = await message.channel.fetch_message(reference.message_id) content = '\n'.join( @@ -97,17 +106,19 @@ class Discollama: ) if not context: - context = await self.load(channel_id=channel.id) + context, images = await self.load(channel_id=channel.id) + + images.extend([await attachment.read() for attachment in message.attachments if attachment.content_type.startswith('image/')]) r = Response(message) task = asyncio.create_task(self.thinking(message)) - async for part in self.generate(content, context): + async for part in self.generate(content, context, images=images): task.cancel() await r.write(part['response'], end='...') await r.write('') - await self.save(r.channel.id, message.id, part['context']) + await self.save(r.channel.id, message.id, part['context'], images) async def thinking(self, message, timeout=999): try: @@ -119,11 +130,13 @@ class Discollama: finally: await message.remove_reaction('🤔', self.discord.user) - async def generate(self, content, context): + async def generate(self, content, context, images=None): + model = self.models['images' if images else ''] + sb = io.StringIO() t = datetime.now() - async for part in await self.ollama.generate(model='llama2', prompt=content, context=context, stream=True): + async for part in await self.ollama.generate(model=model, prompt=content, context=context, images=images, stream=True): sb.write(part['response']) if part['done'] or datetime.now() - t > timedelta(seconds=1): @@ -133,16 +146,20 @@ class Discollama: sb.seek(0, io.SEEK_SET) sb.truncate() - async def save(self, channel_id, message_id, ctx: list[int]): + async def save(self, channel_id, message_id, context, images): self.redis.set(f'discollama:channel:{channel_id}', message_id, ex=60 * 60 * 24 * 7) - self.redis.set(f'discollama:message:{message_id}', json.dumps(ctx), ex=60 * 60 * 24 * 7) + self.redis.set(f'discollama:message:{message_id}', json.dumps(context), ex=60 * 60 * 24 * 7) - async def load(self, channel_id=None, message_id=None) -> list[int]: + images = [b64encode(image).decode('utf-8') for image in images] + self.redis.set(f'discollama:images:{message_id}', json.dumps(images), ex=60 * 60 * 24 * 7) + + async def load(self, channel_id=None, message_id=None): if channel_id: message_id = self.redis.get(f'discollama:channel:{channel_id}') - ctx = self.redis.get(f'discollama:message:{message_id}') - return json.loads(ctx) if ctx else [] + context = self.redis.get(f'discollama:message:{message_id}') + images = self.redis.get(f'discollama:images:{message_id}') + return json.loads(context) if context else [], [b64decode(image) for image in json.loads(images)] if images else [] def run(self, token): try: @@ -150,6 +167,12 @@ class Discollama: except Exception: self.redis.close() + async def setup_hook(self): + for key, value in self.models.items(): + logging.info('Downloading %s model %s...', key, value) + await self.ollama.pull(value) + logging.info('Downloading %s model %s... done', key, value) + def main(): parser = argparse.ArgumentParser() @@ -157,22 +180,23 @@ def main(): parser.add_argument('--ollama-scheme', default=os.getenv('OLLAMA_SCHEME', 'http'), choices=['http', 'https']) parser.add_argument('--ollama-host', default=os.getenv('OLLAMA_HOST', '127.0.0.1'), type=str) parser.add_argument('--ollama-port', default=os.getenv('OLLAMA_PORT', 11434), type=int) + parser.add_argument('--ollama-model', default=os.getenv('OLLAMA_MODEL', 'llama2'), type=str) + parser.add_argument('--ollama-images-model', default=os.getenv('OLLAMA_IMAGES_MODEL', 'llava'), type=str) parser.add_argument('--redis-host', default=os.getenv('REDIS_HOST', '127.0.0.1'), type=str) parser.add_argument('--redis-port', default=os.getenv('REDIS_PORT', 6379), type=int) - parser.add_argument('--buffer-size', default=32, type=int) - args = parser.parse_args() - intents = discord.Intents.default() - intents.message_content = True - Discollama( ollama.AsyncClient(base_url=f'{args.ollama_scheme}://{args.ollama_host}:{args.ollama_port}'), - discord.Client(intents=intents), + discord.Client(intents=discord.Intents.default()), redis.Redis(host=args.redis_host, port=args.redis_port, db=0, decode_responses=True), + { + '': args.ollama_model, + 'images': args.ollama_images_model, + }, ).run(os.environ['DISCORD_TOKEN'])