add multimodal response

multimodal
Michael Yang 2023-12-24 21:09:20 -08:00
parent 5ba591fbfb
commit eb955228e2
1 changed files with 34 additions and 19 deletions

View File

@ -4,6 +4,7 @@ import json
import asyncio import asyncio
import argparse import argparse
from datetime import datetime, timedelta from datetime import datetime, timedelta
from base64 import b64encode, b64decode
import ollama import ollama
import discord import discord
@ -31,6 +32,9 @@ class Response:
self.sb.write(s) self.sb.write(s)
if self.sb.seek(0, io.SEEK_END) == 0:
return
if self.r: if self.r:
await self.r.edit(content=self.sb.getvalue() + end) await self.r.edit(content=self.sb.getvalue() + end)
return return
@ -42,11 +46,13 @@ class Response:
class Discollama: class Discollama:
def __init__(self, ollama, discord, redis): def __init__(self, ollama, discord, redis, models):
self.ollama = ollama self.ollama = ollama
self.discord = discord self.discord = discord
self.redis = redis self.redis = redis
self.models = models
# register event handlers # register event handlers
self.discord.event(self.on_ready) self.discord.event(self.on_ready)
self.discord.event(self.on_message) self.discord.event(self.on_message)
@ -83,9 +89,9 @@ class Discollama:
channel = message.channel channel = message.channel
context = [] context, images = [], []
if reference := message.reference: if reference := message.reference:
context = await self.load(message_id=reference.message_id) context, images = await self.load(message_id=reference.message_id)
if not context: if not context:
reference_message = await message.channel.fetch_message(reference.message_id) reference_message = await message.channel.fetch_message(reference.message_id)
content = '\n'.join( content = '\n'.join(
@ -97,17 +103,19 @@ class Discollama:
) )
if not context: if not context:
context = await self.load(channel_id=channel.id) context, images = await self.load(channel_id=channel.id)
images.extend([await attachment.read() for attachment in message.attachments if attachment.content_type.startswith('image/')])
r = Response(message) r = Response(message)
task = asyncio.create_task(self.thinking(message)) task = asyncio.create_task(self.thinking(message))
async for part in self.generate(content, context): async for part in self.generate(content, context, images=images):
task.cancel() task.cancel()
await r.write(part['response'], end='...') await r.write(part['response'], end='...')
await r.write('') await r.write('')
await self.save(r.channel.id, message.id, part['context']) await self.save(r.channel.id, message.id, part['context'], images)
async def thinking(self, message, timeout=999): async def thinking(self, message, timeout=999):
try: try:
@ -119,11 +127,13 @@ class Discollama:
finally: finally:
await message.remove_reaction('🤔', self.discord.user) await message.remove_reaction('🤔', self.discord.user)
async def generate(self, content, context): async def generate(self, content, context, images=None):
model = self.models['images' if images else 'default']
sb = io.StringIO() sb = io.StringIO()
t = datetime.now() t = datetime.now()
async for part in await self.ollama.generate(model='llama2', prompt=content, context=context, stream=True): async for part in await self.ollama.generate(model=model, prompt=content, context=context, images=images, stream=True):
sb.write(part['response']) sb.write(part['response'])
if part['done'] or datetime.now() - t > timedelta(seconds=1): if part['done'] or datetime.now() - t > timedelta(seconds=1):
@ -133,16 +143,20 @@ class Discollama:
sb.seek(0, io.SEEK_SET) sb.seek(0, io.SEEK_SET)
sb.truncate() sb.truncate()
async def save(self, channel_id, message_id, ctx: list[int]): async def save(self, channel_id, message_id, context, images):
self.redis.set(f'discollama:channel:{channel_id}', message_id, ex=60 * 60 * 24 * 7) self.redis.set(f'discollama:channel:{channel_id}', message_id, ex=60 * 60 * 24 * 7)
self.redis.set(f'discollama:message:{message_id}', json.dumps(ctx), ex=60 * 60 * 24 * 7) self.redis.set(f'discollama:message:{message_id}', json.dumps(context), ex=60 * 60 * 24 * 7)
async def load(self, channel_id=None, message_id=None) -> list[int]: images = [b64encode(image).decode('utf-8') for image in images]
self.redis.set(f'discollama:images:{message_id}', json.dumps(images), ex=60 * 60 * 24 * 7)
async def load(self, channel_id=None, message_id=None):
if channel_id: if channel_id:
message_id = self.redis.get(f'discollama:channel:{channel_id}') message_id = self.redis.get(f'discollama:channel:{channel_id}')
ctx = self.redis.get(f'discollama:message:{message_id}') context = self.redis.get(f'discollama:message:{message_id}')
return json.loads(ctx) if ctx else [] images = self.redis.get(f'discollama:images:{message_id}')
return json.loads(context) if context else [], [b64decode(image) for image in json.loads(images)] if images else []
def run(self, token): def run(self, token):
try: try:
@ -157,22 +171,23 @@ def main():
parser.add_argument('--ollama-scheme', default=os.getenv('OLLAMA_SCHEME', 'http'), choices=['http', 'https']) parser.add_argument('--ollama-scheme', default=os.getenv('OLLAMA_SCHEME', 'http'), choices=['http', 'https'])
parser.add_argument('--ollama-host', default=os.getenv('OLLAMA_HOST', '127.0.0.1'), type=str) parser.add_argument('--ollama-host', default=os.getenv('OLLAMA_HOST', '127.0.0.1'), type=str)
parser.add_argument('--ollama-port', default=os.getenv('OLLAMA_PORT', 11434), type=int) parser.add_argument('--ollama-port', default=os.getenv('OLLAMA_PORT', 11434), type=int)
parser.add_argument('--ollama-model', default=os.getenv('OLLAMA_MODEL', 'llama2'), type=str) parser.add_argument('--ollama-model', default=os.getenv('OLLAMA_MODEL', 'llama2'), type=str)
parser.add_argument('--ollama-images-model', default=os.getenv('OLLAMA_IMAGES_MODEL', 'llava'), type=str)
parser.add_argument('--redis-host', default=os.getenv('REDIS_HOST', '127.0.0.1'), type=str) parser.add_argument('--redis-host', default=os.getenv('REDIS_HOST', '127.0.0.1'), type=str)
parser.add_argument('--redis-port', default=os.getenv('REDIS_PORT', 6379), type=int) parser.add_argument('--redis-port', default=os.getenv('REDIS_PORT', 6379), type=int)
parser.add_argument('--buffer-size', default=32, type=int)
args = parser.parse_args() args = parser.parse_args()
intents = discord.Intents.default()
intents.message_content = True
Discollama( Discollama(
ollama.AsyncClient(base_url=f'{args.ollama_scheme}://{args.ollama_host}:{args.ollama_port}'), ollama.AsyncClient(base_url=f'{args.ollama_scheme}://{args.ollama_host}:{args.ollama_port}'),
discord.Client(intents=intents), discord.Client(intents=discord.Intents.default()),
redis.Redis(host=args.redis_host, port=args.redis_port, db=0, decode_responses=True), redis.Redis(host=args.redis_host, port=args.redis_port, db=0, decode_responses=True),
{
'default': args.ollama_model,
'images': args.ollama_images_model,
},
).run(os.environ['DISCORD_TOKEN']) ).run(os.environ['DISCORD_TOKEN'])