Compare commits

...

3 Commits

Author SHA1 Message Date
Michael Yang a6a2866c43 pull models on start 2024-01-03 15:05:36 -07:00
Michael Yang eb955228e2 add multimodal response 2023-12-24 22:11:30 -08:00
Michael Yang 5ba591fbfb use ollama client 2023-12-24 22:11:30 -08:00
1 changed file with 167 additions and 123 deletions

View File

@ -1,160 +1,204 @@
import io
import os import os
import json import json
import aiohttp import asyncio
import discord
import argparse import argparse
from redis import Redis from datetime import datetime, timedelta
from base64 import b64encode, b64decode
import logging import ollama
import discord
import redis
intents = discord.Intents.default() from logging import getLogger
intents.message_content = True
client = discord.Client(intents=intents) # piggy back on the logger discord.py set up
logging = getLogger('discord.discollama')
@client.event class Response:
async def on_ready(): def __init__(self, message):
logging.info( self.message = message
'Ready! Invite URL: %s', self.channel = message.channel
discord.utils.oauth_url(
client.application_id, self.r = None
permissions=discord.Permissions(read_messages=True, send_messages=True), self.sb = io.StringIO()
scopes=['bot'],
)) async def write(self, s, end=''):
if self.sb.seek(0, io.SEEK_END) + len(s) + len(end) > 2000:
self.r = None
self.sb.seek(0, io.SEEK_SET)
self.sb.truncate()
self.sb.write(s)
if self.sb.seek(0, io.SEEK_END) == 0:
return
if self.r:
await self.r.edit(content=self.sb.getvalue() + end)
return
if self.channel.type == discord.ChannelType.text:
self.channel = await self.channel.create_thread(name='Discollama Says', message=self.message, auto_archive_duration=60)
self.r = await self.channel.send(self.sb.getvalue())
async def generate_response(prompt, context=[]): class Discollama:
body = { def __init__(self, ollama, discord, redis, models):
key: value self.ollama = ollama
for key, value in { self.discord = discord
'model': args.ollama_model, self.redis = redis
'prompt': prompt,
'context': context,
}.items() if value
}
async with aiohttp.ClientSession() as session: self.models = models
async with session.post(
f'http://{args.ollama_host}:{args.ollama_port}/api/generate',
json=body) as r:
async for line in r.content:
yield json.loads(line)
# registry setup hook
self.discord.setup_hook = self.setup_hook
async def buffered_generate_response(prompt, context=[]): # register event handlers
buffer = '' self.discord.event(self.on_ready)
async for part in generate_response(prompt, context): self.discord.event(self.on_message)
if error := part.get('error'):
raise Exception(error)
if part['done']: async def on_ready(self):
yield buffer, part activity = discord.Activity(name='Discollama', state='Ask me anything!', type=discord.ActivityType.custom)
break await self.discord.change_presence(activity=activity)
buffer += part['response']
if len(buffer) >= args.buffer_size:
yield buffer, part
buffer = ''
def save_session(response, part):
context = part.get('context', [])
redis.json().set(f'ollama:{response.id}', '$', {'context': context})
redis.expire(f'ollama:{response.id}', 60 * 60 * 24 * 7)
logging.info('saving message=%s: len(context)=%d', response.id, len(context))
def load_session(reference):
kwargs = {}
if reference:
context = redis.json().get(f'ollama:{reference.message_id}', '.context')
kwargs['context'] = context or []
if kwargs.get('context'):
logging.info( logging.info(
'loading message=%s: len(context)=%d', 'Ready! Invite URL: %s',
reference.message_id, discord.utils.oauth_url(
len(kwargs['context']), self.discord.application_id,
permissions=discord.Permissions(
read_messages=True,
send_messages=True,
create_public_threads=True,
),
scopes=['bot'],
),
) )
return kwargs async def on_message(self, message):
if self.discord.user == message.author:
# don't respond to ourselves
return
if not self.discord.user.mentioned_in(message):
# don't respond to messages that don't mention us
return
@client.event content = message.content.replace(f'<@{self.discord.user.id}>', '').strip()
async def on_message(message): if not content:
if message.author == client.user: content = 'Hi!'
return
if client.user.id in message.raw_mentions: channel = message.channel
raw_content = message.content.replace(f'<@{client.user.id}>', '').strip()
if raw_content.strip() == '':
raw_content = 'Tell me about yourself.'
response = None context, images = [], []
response_content = '' if reference := message.reference:
async with message.channel.typing(): context, images = await self.load(message_id=reference.message_id)
if not context:
reference_message = await message.channel.fetch_message(reference.message_id)
content = '\n'.join(
[
content,
'Use this to answer the question if it is relevant, otherwise ignore it:',
reference_message.content,
]
)
if not context:
context, images = await self.load(channel_id=channel.id)
images.extend([await attachment.read() for attachment in message.attachments if attachment.content_type.startswith('image/')])
r = Response(message)
task = asyncio.create_task(self.thinking(message))
async for part in self.generate(content, context, images=images):
task.cancel()
await r.write(part['response'], end='...')
await r.write('')
await self.save(r.channel.id, message.id, part['context'], images)
async def thinking(self, message, timeout=999):
    """Show a "thinking" indicator on ``message`` while a reply is prepared.

    Adds a 🤔 reaction to the message and keeps the channel's typing
    indicator open for up to ``timeout`` seconds. ``on_message`` runs this
    as a background task and cancels it once response parts start to
    arrive; the reaction is removed on every exit path.

    Args:
        message: The message being answered; must provide
            ``add_reaction``, ``remove_reaction`` and ``channel.typing()``.
        timeout: Maximum number of seconds to keep the indicator up.

    The indicator is purely cosmetic, so failures are logged instead of
    raised. Task cancellation still propagates to the caller.
    """
    try:
        await message.add_reaction('🤔')
        async with message.channel.typing():
            await asyncio.sleep(timeout)
    except Exception:
        # Best effort. asyncio.CancelledError derives from BaseException on
        # Python 3.8+, so cancellation is not swallowed here.
        logging.debug('thinking indicator failed', exc_info=True)
    finally:
        try:
            await message.remove_reaction('🤔', self.discord.user)
        except Exception:
            # Nobody awaits this task, so an error escaping here would only
            # surface as an unretrieved task exception.
            logging.debug('could not remove thinking reaction', exc_info=True)
context = [] async def generate(self, content, context, images=None):
if reference := message.reference: model = self.models['images' if images else '']
if session := load_session(message.reference):
context = session.get('context', [])
else:
reference_message = await message.channel.fetch_message(
reference.message_id)
reference_content = reference_message.content
raw_content = '\n'.join([
raw_content,
'Use it to answer the prompt:',
reference_content,
])
async for buffer, part in buffered_generate_response( sb = io.StringIO()
raw_content,
context=context,
):
response_content += buffer
if part['done']:
save_session(response, part)
break
if not response: t = datetime.now()
response = await message.reply(response_content) async for part in await self.ollama.generate(model=model, prompt=content, context=context, images=images, stream=True):
await message.remove_reaction('🤔', client.user) sb.write(part['response'])
continue
if len(response_content) + 3 >= 2000: if part['done'] or datetime.now() - t > timedelta(seconds=1):
response = await response.reply(buffer) part['response'] = sb.getvalue()
response_content = buffer yield part
continue t = datetime.now()
sb.seek(0, io.SEEK_SET)
sb.truncate()
await response.edit(content=response_content + '...') async def save(self, channel_id, message_id, context, images):
self.redis.set(f'discollama:channel:{channel_id}', message_id, ex=60 * 60 * 24 * 7)
self.redis.set(f'discollama:message:{message_id}', json.dumps(context), ex=60 * 60 * 24 * 7)
await response.edit(content=response_content) images = [b64encode(image).decode('utf-8') for image in images]
self.redis.set(f'discollama:images:{message_id}', json.dumps(images), ex=60 * 60 * 24 * 7)
async def load(self, channel_id=None, message_id=None):
    """Fetch the saved conversation state for a channel or a message.

    Args:
        channel_id: When truthy, the message id stored under
            ``discollama:channel:<channel_id>`` is looked up and used in
            place of ``message_id``.
        message_id: Id whose ``discollama:message:<id>`` and
            ``discollama:images:<id>`` entries are read.

    Returns:
        A ``(context, images)`` tuple. ``context`` is the JSON-decoded
        value stored for the message and ``images`` is the list of
        base64-decoded entries stored alongside it. Each is a fresh empty
        list when nothing is stored.
    """
    if channel_id:
        message_id = self.redis.get(f'discollama:channel:{channel_id}')

    if message_id is None:
        # Nothing to resolve; avoid querying the literal "...:None" keys.
        return [], []

    raw_context = self.redis.get(f'discollama:message:{message_id}')
    raw_images = self.redis.get(f'discollama:images:{message_id}')

    context = json.loads(raw_context) if raw_context else []
    images = [b64decode(image) for image in json.loads(raw_images)] if raw_images else []
    return context, images
def run(self, token):
    """Run the Discord client until it stops, then close the Redis client.

    Args:
        token: Bot token handed to the Discord client's ``run``.

    An exception raised by the client is logged and not re-raised, so
    callers see the same non-raising behaviour as before. The Redis
    connection is released on every exit path, not only on failure.
    """
    try:
        self.discord.run(token)
    except Exception:
        logging.exception('discord client stopped with an error')
    finally:
        self.redis.close()
async def setup_hook(self):
    """Pull every configured Ollama model.

    ``__init__`` assigns this coroutine as the Discord client's
    ``setup_hook``. ``self.models`` maps a role key to a model name; each
    distinct model name is pulled once, in mapping order, with a log line
    before and after the download.
    """
    pulled = set()
    for key, value in self.models.items():
        if value in pulled:
            # Several roles may share one model; a single pull is enough.
            continue
        pulled.add(value)

        logging.info('Downloading %s model %s...', key, value)
        await self.ollama.pull(value)
        logging.info('Downloading %s model %s... done', key, value)
default_ollama_host = os.getenv('OLLAMA_HOST', '127.0.0.1') def main():
default_ollama_port = os.getenv('OLLAMA_PORT', 11434) parser = argparse.ArgumentParser()
default_ollama_model = os.getenv('OLLAMA_MODEL', 'llama2')
parser = argparse.ArgumentParser() parser.add_argument('--ollama-scheme', default=os.getenv('OLLAMA_SCHEME', 'http'), choices=['http', 'https'])
parser.add_argument('--ollama-host', default=default_ollama_host) parser.add_argument('--ollama-host', default=os.getenv('OLLAMA_HOST', '127.0.0.1'), type=str)
parser.add_argument('--ollama-port', default=default_ollama_port, type=int) parser.add_argument('--ollama-port', default=os.getenv('OLLAMA_PORT', 11434), type=int)
parser.add_argument('--ollama-model', default=default_ollama_model, type=str)
parser.add_argument('--redis-host', default='localhost') parser.add_argument('--ollama-model', default=os.getenv('OLLAMA_MODEL', 'llama2'), type=str)
parser.add_argument('--redis-port', default=6379) parser.add_argument('--ollama-images-model', default=os.getenv('OLLAMA_IMAGES_MODEL', 'llava'), type=str)
parser.add_argument('--buffer-size', default=32, type=int) parser.add_argument('--redis-host', default=os.getenv('REDIS_HOST', '127.0.0.1'), type=str)
parser.add_argument('--redis-port', default=os.getenv('REDIS_PORT', 6379), type=int)
args = parser.parse_args() args = parser.parse_args()
try: Discollama(
redis = Redis(host=args.redis_host, port=args.redis_port) ollama.AsyncClient(base_url=f'{args.ollama_scheme}://{args.ollama_host}:{args.ollama_port}'),
client.run(os.getenv('DISCORD_TOKEN'), root_logger=True) discord.Client(intents=discord.Intents.default()),
except KeyboardInterrupt: redis.Redis(host=args.redis_host, port=args.redis_port, db=0, decode_responses=True),
pass {
'': args.ollama_model,
'images': args.ollama_images_model,
},
).run(os.environ['DISCORD_TOKEN'])
redis.close()
if __name__ == '__main__':
main()