pull/8/merge
Michael Yang 2024-01-05 09:58:02 -08:00 committed by GitHub
commit da06488732
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 43 additions and 19 deletions

View File

@ -4,6 +4,7 @@ import json
import asyncio import asyncio
import argparse import argparse
from datetime import datetime, timedelta from datetime import datetime, timedelta
from base64 import b64encode, b64decode
import ollama import ollama
import discord import discord
@ -31,6 +32,9 @@ class Response:
self.sb.write(s) self.sb.write(s)
if self.sb.seek(0, io.SEEK_END) == 0:
return
if self.r: if self.r:
await self.r.edit(content=self.sb.getvalue() + end) await self.r.edit(content=self.sb.getvalue() + end)
return return
@ -42,11 +46,16 @@ class Response:
class Discollama: class Discollama:
def __init__(self, ollama, discord, redis): def __init__(self, ollama, discord, redis, models):
self.ollama = ollama self.ollama = ollama
self.discord = discord self.discord = discord
self.redis = redis self.redis = redis
self.models = models
# registry setup hook
self.discord.setup_hook = self.setup_hook
# register event handlers # register event handlers
self.discord.event(self.on_ready) self.discord.event(self.on_ready)
self.discord.event(self.on_message) self.discord.event(self.on_message)
@ -83,9 +92,9 @@ class Discollama:
channel = message.channel channel = message.channel
context = [] context, images = [], []
if reference := message.reference: if reference := message.reference:
context = await self.load(message_id=reference.message_id) context, images = await self.load(message_id=reference.message_id)
if not context: if not context:
reference_message = await message.channel.fetch_message(reference.message_id) reference_message = await message.channel.fetch_message(reference.message_id)
content = '\n'.join( content = '\n'.join(
@ -97,17 +106,19 @@ class Discollama:
) )
if not context: if not context:
context = await self.load(channel_id=channel.id) context, images = await self.load(channel_id=channel.id)
images.extend([await attachment.read() for attachment in message.attachments if attachment.content_type.startswith('image/')])
r = Response(message) r = Response(message)
task = asyncio.create_task(self.thinking(message)) task = asyncio.create_task(self.thinking(message))
async for part in self.generate(content, context): async for part in self.generate(content, context, images=images):
task.cancel() task.cancel()
await r.write(part['response'], end='...') await r.write(part['response'], end='...')
await r.write('') await r.write('')
await self.save(r.channel.id, message.id, part['context']) await self.save(r.channel.id, message.id, part['context'], images)
async def thinking(self, message, timeout=999): async def thinking(self, message, timeout=999):
try: try:
@ -119,11 +130,13 @@ class Discollama:
finally: finally:
await message.remove_reaction('🤔', self.discord.user) await message.remove_reaction('🤔', self.discord.user)
async def generate(self, content, context): async def generate(self, content, context, images=None):
model = self.models['images' if images else '']
sb = io.StringIO() sb = io.StringIO()
t = datetime.now() t = datetime.now()
async for part in await self.ollama.generate(model='llama2', prompt=content, context=context, stream=True): async for part in await self.ollama.generate(model=model, prompt=content, context=context, images=images, stream=True):
sb.write(part['response']) sb.write(part['response'])
if part['done'] or datetime.now() - t > timedelta(seconds=1): if part['done'] or datetime.now() - t > timedelta(seconds=1):
@ -133,16 +146,20 @@ class Discollama:
sb.seek(0, io.SEEK_SET) sb.seek(0, io.SEEK_SET)
sb.truncate() sb.truncate()
async def save(self, channel_id, message_id, ctx: list[int]): async def save(self, channel_id, message_id, context, images):
self.redis.set(f'discollama:channel:{channel_id}', message_id, ex=60 * 60 * 24 * 7) self.redis.set(f'discollama:channel:{channel_id}', message_id, ex=60 * 60 * 24 * 7)
self.redis.set(f'discollama:message:{message_id}', json.dumps(ctx), ex=60 * 60 * 24 * 7) self.redis.set(f'discollama:message:{message_id}', json.dumps(context), ex=60 * 60 * 24 * 7)
async def load(self, channel_id=None, message_id=None) -> list[int]: images = [b64encode(image).decode('utf-8') for image in images]
self.redis.set(f'discollama:images:{message_id}', json.dumps(images), ex=60 * 60 * 24 * 7)
async def load(self, channel_id=None, message_id=None):
if channel_id: if channel_id:
message_id = self.redis.get(f'discollama:channel:{channel_id}') message_id = self.redis.get(f'discollama:channel:{channel_id}')
ctx = self.redis.get(f'discollama:message:{message_id}') context = self.redis.get(f'discollama:message:{message_id}')
return json.loads(ctx) if ctx else [] images = self.redis.get(f'discollama:images:{message_id}')
return json.loads(context) if context else [], [b64decode(image) for image in json.loads(images)] if images else []
def run(self, token): def run(self, token):
try: try:
@ -150,6 +167,12 @@ class Discollama:
except Exception: except Exception:
self.redis.close() self.redis.close()
async def setup_hook(self):
for key, value in self.models.items():
logging.info('Downloading %s model %s...', key, value)
await self.ollama.pull(value)
logging.info('Downloading %s model %s... done', key, value)
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -157,22 +180,23 @@ def main():
parser.add_argument('--ollama-scheme', default=os.getenv('OLLAMA_SCHEME', 'http'), choices=['http', 'https']) parser.add_argument('--ollama-scheme', default=os.getenv('OLLAMA_SCHEME', 'http'), choices=['http', 'https'])
parser.add_argument('--ollama-host', default=os.getenv('OLLAMA_HOST', '127.0.0.1'), type=str) parser.add_argument('--ollama-host', default=os.getenv('OLLAMA_HOST', '127.0.0.1'), type=str)
parser.add_argument('--ollama-port', default=os.getenv('OLLAMA_PORT', 11434), type=int) parser.add_argument('--ollama-port', default=os.getenv('OLLAMA_PORT', 11434), type=int)
parser.add_argument('--ollama-model', default=os.getenv('OLLAMA_MODEL', 'llama2'), type=str) parser.add_argument('--ollama-model', default=os.getenv('OLLAMA_MODEL', 'llama2'), type=str)
parser.add_argument('--ollama-images-model', default=os.getenv('OLLAMA_IMAGES_MODEL', 'llava'), type=str)
parser.add_argument('--redis-host', default=os.getenv('REDIS_HOST', '127.0.0.1'), type=str) parser.add_argument('--redis-host', default=os.getenv('REDIS_HOST', '127.0.0.1'), type=str)
parser.add_argument('--redis-port', default=os.getenv('REDIS_PORT', 6379), type=int) parser.add_argument('--redis-port', default=os.getenv('REDIS_PORT', 6379), type=int)
parser.add_argument('--buffer-size', default=32, type=int)
args = parser.parse_args() args = parser.parse_args()
intents = discord.Intents.default()
intents.message_content = True
Discollama( Discollama(
ollama.AsyncClient(base_url=f'{args.ollama_scheme}://{args.ollama_host}:{args.ollama_port}'), ollama.AsyncClient(base_url=f'{args.ollama_scheme}://{args.ollama_host}:{args.ollama_port}'),
discord.Client(intents=intents), discord.Client(intents=discord.Intents.default()),
redis.Redis(host=args.redis_host, port=args.redis_port, db=0, decode_responses=True), redis.Redis(host=args.redis_host, port=args.redis_port, db=0, decode_responses=True),
{
'': args.ollama_model,
'images': args.ollama_images_model,
},
).run(os.environ['DISCORD_TOKEN']) ).run(os.environ['DISCORD_TOKEN'])