Compare commits
3 Commits
main
...
multimodal
Author | SHA1 | Date |
---|---|---|
Michael Yang | a6a2866c43 | |
Michael Yang | eb955228e2 | |
Michael Yang | 5ba591fbfb |
310
discollama.py
310
discollama.py
|
@ -1,160 +1,204 @@
|
||||||
|
import io
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import aiohttp
|
import asyncio
|
||||||
import discord
|
|
||||||
import argparse
|
import argparse
|
||||||
from redis import Redis
|
from datetime import datetime, timedelta
|
||||||
|
from base64 import b64encode, b64decode
|
||||||
|
|
||||||
import logging
|
import ollama
|
||||||
|
import discord
|
||||||
|
import redis
|
||||||
|
|
||||||
intents = discord.Intents.default()
|
from logging import getLogger
|
||||||
intents.message_content = True
|
|
||||||
|
|
||||||
client = discord.Client(intents=intents)
|
# piggy back on the logger discord.py set up
|
||||||
|
logging = getLogger('discord.discollama')
|
||||||
|
|
||||||
|
|
||||||
@client.event
|
class Response:
|
||||||
async def on_ready():
|
def __init__(self, message):
|
||||||
|
self.message = message
|
||||||
|
self.channel = message.channel
|
||||||
|
|
||||||
|
self.r = None
|
||||||
|
self.sb = io.StringIO()
|
||||||
|
|
||||||
|
async def write(self, s, end=''):
|
||||||
|
if self.sb.seek(0, io.SEEK_END) + len(s) + len(end) > 2000:
|
||||||
|
self.r = None
|
||||||
|
self.sb.seek(0, io.SEEK_SET)
|
||||||
|
self.sb.truncate()
|
||||||
|
|
||||||
|
self.sb.write(s)
|
||||||
|
|
||||||
|
if self.sb.seek(0, io.SEEK_END) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.r:
|
||||||
|
await self.r.edit(content=self.sb.getvalue() + end)
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.channel.type == discord.ChannelType.text:
|
||||||
|
self.channel = await self.channel.create_thread(name='Discollama Says', message=self.message, auto_archive_duration=60)
|
||||||
|
|
||||||
|
self.r = await self.channel.send(self.sb.getvalue())
|
||||||
|
|
||||||
|
|
||||||
|
class Discollama:
|
||||||
|
def __init__(self, ollama, discord, redis, models):
|
||||||
|
self.ollama = ollama
|
||||||
|
self.discord = discord
|
||||||
|
self.redis = redis
|
||||||
|
|
||||||
|
self.models = models
|
||||||
|
|
||||||
|
# registry setup hook
|
||||||
|
self.discord.setup_hook = self.setup_hook
|
||||||
|
|
||||||
|
# register event handlers
|
||||||
|
self.discord.event(self.on_ready)
|
||||||
|
self.discord.event(self.on_message)
|
||||||
|
|
||||||
|
async def on_ready(self):
|
||||||
|
activity = discord.Activity(name='Discollama', state='Ask me anything!', type=discord.ActivityType.custom)
|
||||||
|
await self.discord.change_presence(activity=activity)
|
||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
'Ready! Invite URL: %s',
|
'Ready! Invite URL: %s',
|
||||||
discord.utils.oauth_url(
|
discord.utils.oauth_url(
|
||||||
client.application_id,
|
self.discord.application_id,
|
||||||
permissions=discord.Permissions(read_messages=True, send_messages=True),
|
permissions=discord.Permissions(
|
||||||
|
read_messages=True,
|
||||||
|
send_messages=True,
|
||||||
|
create_public_threads=True,
|
||||||
|
),
|
||||||
scopes=['bot'],
|
scopes=['bot'],
|
||||||
))
|
),
|
||||||
|
|
||||||
|
|
||||||
async def generate_response(prompt, context=[]):
|
|
||||||
body = {
|
|
||||||
key: value
|
|
||||||
for key, value in {
|
|
||||||
'model': args.ollama_model,
|
|
||||||
'prompt': prompt,
|
|
||||||
'context': context,
|
|
||||||
}.items() if value
|
|
||||||
}
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
f'http://{args.ollama_host}:{args.ollama_port}/api/generate',
|
|
||||||
json=body) as r:
|
|
||||||
async for line in r.content:
|
|
||||||
yield json.loads(line)
|
|
||||||
|
|
||||||
|
|
||||||
async def buffered_generate_response(prompt, context=[]):
|
|
||||||
buffer = ''
|
|
||||||
async for part in generate_response(prompt, context):
|
|
||||||
if error := part.get('error'):
|
|
||||||
raise Exception(error)
|
|
||||||
|
|
||||||
if part['done']:
|
|
||||||
yield buffer, part
|
|
||||||
break
|
|
||||||
|
|
||||||
buffer += part['response']
|
|
||||||
if len(buffer) >= args.buffer_size:
|
|
||||||
yield buffer, part
|
|
||||||
buffer = ''
|
|
||||||
|
|
||||||
|
|
||||||
def save_session(response, part):
|
|
||||||
context = part.get('context', [])
|
|
||||||
redis.json().set(f'ollama:{response.id}', '$', {'context': context})
|
|
||||||
|
|
||||||
redis.expire(f'ollama:{response.id}', 60 * 60 * 24 * 7)
|
|
||||||
logging.info('saving message=%s: len(context)=%d', response.id, len(context))
|
|
||||||
|
|
||||||
|
|
||||||
def load_session(reference):
|
|
||||||
kwargs = {}
|
|
||||||
if reference:
|
|
||||||
context = redis.json().get(f'ollama:{reference.message_id}', '.context')
|
|
||||||
kwargs['context'] = context or []
|
|
||||||
|
|
||||||
if kwargs.get('context'):
|
|
||||||
logging.info(
|
|
||||||
'loading message=%s: len(context)=%d',
|
|
||||||
reference.message_id,
|
|
||||||
len(kwargs['context']),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return kwargs
|
async def on_message(self, message):
|
||||||
|
if self.discord.user == message.author:
|
||||||
|
# don't respond to ourselves
|
||||||
@client.event
|
|
||||||
async def on_message(message):
|
|
||||||
if message.author == client.user:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if client.user.id in message.raw_mentions:
|
if not self.discord.user.mentioned_in(message):
|
||||||
raw_content = message.content.replace(f'<@{client.user.id}>', '').strip()
|
# don't respond to messages that don't mention us
|
||||||
if raw_content.strip() == '':
|
return
|
||||||
raw_content = 'Tell me about yourself.'
|
|
||||||
|
|
||||||
response = None
|
content = message.content.replace(f'<@{self.discord.user.id}>', '').strip()
|
||||||
response_content = ''
|
if not content:
|
||||||
async with message.channel.typing():
|
content = 'Hi!'
|
||||||
await message.add_reaction('🤔')
|
|
||||||
|
|
||||||
context = []
|
channel = message.channel
|
||||||
|
|
||||||
|
context, images = [], []
|
||||||
if reference := message.reference:
|
if reference := message.reference:
|
||||||
if session := load_session(message.reference):
|
context, images = await self.load(message_id=reference.message_id)
|
||||||
context = session.get('context', [])
|
if not context:
|
||||||
else:
|
reference_message = await message.channel.fetch_message(reference.message_id)
|
||||||
reference_message = await message.channel.fetch_message(
|
content = '\n'.join(
|
||||||
reference.message_id)
|
[
|
||||||
reference_content = reference_message.content
|
content,
|
||||||
raw_content = '\n'.join([
|
'Use this to answer the question if it is relevant, otherwise ignore it:',
|
||||||
raw_content,
|
reference_message.content,
|
||||||
'Use it to answer the prompt:',
|
]
|
||||||
reference_content,
|
)
|
||||||
])
|
|
||||||
|
|
||||||
async for buffer, part in buffered_generate_response(
|
if not context:
|
||||||
raw_content,
|
context, images = await self.load(channel_id=channel.id)
|
||||||
context=context,
|
|
||||||
):
|
|
||||||
response_content += buffer
|
|
||||||
if part['done']:
|
|
||||||
save_session(response, part)
|
|
||||||
break
|
|
||||||
|
|
||||||
if not response:
|
images.extend([await attachment.read() for attachment in message.attachments if attachment.content_type.startswith('image/')])
|
||||||
response = await message.reply(response_content)
|
|
||||||
await message.remove_reaction('🤔', client.user)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if len(response_content) + 3 >= 2000:
|
r = Response(message)
|
||||||
response = await response.reply(buffer)
|
task = asyncio.create_task(self.thinking(message))
|
||||||
response_content = buffer
|
async for part in self.generate(content, context, images=images):
|
||||||
continue
|
task.cancel()
|
||||||
|
|
||||||
await response.edit(content=response_content + '...')
|
await r.write(part['response'], end='...')
|
||||||
|
|
||||||
await response.edit(content=response_content)
|
await r.write('')
|
||||||
|
await self.save(r.channel.id, message.id, part['context'], images)
|
||||||
|
|
||||||
|
async def thinking(self, message, timeout=999):
|
||||||
default_ollama_host = os.getenv('OLLAMA_HOST', '127.0.0.1')
|
try:
|
||||||
default_ollama_port = os.getenv('OLLAMA_PORT', 11434)
|
await message.add_reaction('🤔')
|
||||||
default_ollama_model = os.getenv('OLLAMA_MODEL', 'llama2')
|
async with message.channel.typing():
|
||||||
|
await asyncio.sleep(timeout)
|
||||||
parser = argparse.ArgumentParser()
|
except Exception:
|
||||||
parser.add_argument('--ollama-host', default=default_ollama_host)
|
|
||||||
parser.add_argument('--ollama-port', default=default_ollama_port, type=int)
|
|
||||||
parser.add_argument('--ollama-model', default=default_ollama_model, type=str)
|
|
||||||
|
|
||||||
parser.add_argument('--redis-host', default='localhost')
|
|
||||||
parser.add_argument('--redis-port', default=6379)
|
|
||||||
|
|
||||||
parser.add_argument('--buffer-size', default=32, type=int)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
|
||||||
redis = Redis(host=args.redis_host, port=args.redis_port)
|
|
||||||
client.run(os.getenv('DISCORD_TOKEN'), root_logger=True)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
pass
|
pass
|
||||||
|
finally:
|
||||||
|
await message.remove_reaction('🤔', self.discord.user)
|
||||||
|
|
||||||
redis.close()
|
async def generate(self, content, context, images=None):
|
||||||
|
model = self.models['images' if images else '']
|
||||||
|
|
||||||
|
sb = io.StringIO()
|
||||||
|
|
||||||
|
t = datetime.now()
|
||||||
|
async for part in await self.ollama.generate(model=model, prompt=content, context=context, images=images, stream=True):
|
||||||
|
sb.write(part['response'])
|
||||||
|
|
||||||
|
if part['done'] or datetime.now() - t > timedelta(seconds=1):
|
||||||
|
part['response'] = sb.getvalue()
|
||||||
|
yield part
|
||||||
|
t = datetime.now()
|
||||||
|
sb.seek(0, io.SEEK_SET)
|
||||||
|
sb.truncate()
|
||||||
|
|
||||||
|
async def save(self, channel_id, message_id, context, images):
|
||||||
|
self.redis.set(f'discollama:channel:{channel_id}', message_id, ex=60 * 60 * 24 * 7)
|
||||||
|
self.redis.set(f'discollama:message:{message_id}', json.dumps(context), ex=60 * 60 * 24 * 7)
|
||||||
|
|
||||||
|
images = [b64encode(image).decode('utf-8') for image in images]
|
||||||
|
self.redis.set(f'discollama:images:{message_id}', json.dumps(images), ex=60 * 60 * 24 * 7)
|
||||||
|
|
||||||
|
async def load(self, channel_id=None, message_id=None):
|
||||||
|
if channel_id:
|
||||||
|
message_id = self.redis.get(f'discollama:channel:{channel_id}')
|
||||||
|
|
||||||
|
context = self.redis.get(f'discollama:message:{message_id}')
|
||||||
|
images = self.redis.get(f'discollama:images:{message_id}')
|
||||||
|
return json.loads(context) if context else [], [b64decode(image) for image in json.loads(images)] if images else []
|
||||||
|
|
||||||
|
def run(self, token):
|
||||||
|
try:
|
||||||
|
self.discord.run(token)
|
||||||
|
except Exception:
|
||||||
|
self.redis.close()
|
||||||
|
|
||||||
|
async def setup_hook(self):
|
||||||
|
for key, value in self.models.items():
|
||||||
|
logging.info('Downloading %s model %s...', key, value)
|
||||||
|
await self.ollama.pull(value)
|
||||||
|
logging.info('Downloading %s model %s... done', key, value)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument('--ollama-scheme', default=os.getenv('OLLAMA_SCHEME', 'http'), choices=['http', 'https'])
|
||||||
|
parser.add_argument('--ollama-host', default=os.getenv('OLLAMA_HOST', '127.0.0.1'), type=str)
|
||||||
|
parser.add_argument('--ollama-port', default=os.getenv('OLLAMA_PORT', 11434), type=int)
|
||||||
|
|
||||||
|
parser.add_argument('--ollama-model', default=os.getenv('OLLAMA_MODEL', 'llama2'), type=str)
|
||||||
|
parser.add_argument('--ollama-images-model', default=os.getenv('OLLAMA_IMAGES_MODEL', 'llava'), type=str)
|
||||||
|
|
||||||
|
parser.add_argument('--redis-host', default=os.getenv('REDIS_HOST', '127.0.0.1'), type=str)
|
||||||
|
parser.add_argument('--redis-port', default=os.getenv('REDIS_PORT', 6379), type=int)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
Discollama(
|
||||||
|
ollama.AsyncClient(base_url=f'{args.ollama_scheme}://{args.ollama_host}:{args.ollama_port}'),
|
||||||
|
discord.Client(intents=discord.Intents.default()),
|
||||||
|
redis.Redis(host=args.redis_host, port=args.redis_port, db=0, decode_responses=True),
|
||||||
|
{
|
||||||
|
'': args.ollama_model,
|
||||||
|
'images': args.ollama_images_model,
|
||||||
|
},
|
||||||
|
).run(os.environ['DISCORD_TOKEN'])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
Loading…
Reference in New Issue