# discollama/discollama.py
#
# Source listing metadata: 247 lines, 7.7 KiB, Python.

import io
import os
import json
import asyncio
import argparse
from datetime import datetime, timedelta
import ollama
import chromadb
import discord
import redis
from logging import getLogger
# Child of the 'discord' logger, so records propagate to whatever handlers
# discord.py set up. It is bound to the name `logging` so the rest of this file
# calls logging.info(...)/logging.exception(...) on this logger; note that this
# shadows the stdlib module name within this file.
logging = getLogger('discord').getChild('discollama')
class Response:
    """Stream model output into Discord messages answering one user message.

    Text is buffered and the current bot message is edited as more text arrives.
    When the buffered text would exceed Discord's 2000-character message limit,
    a new bot message is started. When the user message is in a regular text
    channel, the first reply creates a thread from it and all replies go there.
    """

    # Discord rejects message content longer than this many characters.
    MAX_LENGTH = 2000

    def __init__(self, message):
        self.message = message  # the user message being answered
        self.channel = message.channel  # where replies go; replaced by a thread on first send
        self.r = None  # bot message currently being edited, or None to send a new one
        self.sb = io.StringIO()  # text of the current bot message

    async def write(self, s, end=''):
        """Append ``s`` to the response and update Discord.

        Args:
            s: text to append.
            end: suffix displayed after the text while editing (e.g. ``'...'``).
                It is not stored in the buffer, so a later ``write('')`` edits
                the message to drop it.
        """
        # Leave room for the suffix; never let the limit reach zero.
        limit = max(1, self.MAX_LENGTH - len(end))
        if self.sb.seek(0, io.SEEK_END) + len(s) > limit:
            self._reset()

        # A single chunk can itself exceed the limit (e.g. output that accumulated
        # while an earlier Discord call was slow), so send it in limit-sized pieces
        # instead of posting an oversized message that Discord would reject.
        while len(s) > limit:
            self.sb.write(s[:limit])
            s = s[limit:]
            await self._flush(end)
            self._reset()

        self.sb.write(s)
        await self._flush(end)

    def _reset(self):
        """Clear the buffer so the next flush sends a new Discord message."""
        self.r = None
        self.sb.seek(0, io.SEEK_SET)
        self.sb.truncate()

    async def _flush(self, end):
        """Send or edit the current Discord message to show the buffered text.

        ``end`` is appended only when editing an existing message; a new
        message is sent with the stripped buffered text alone.
        """
        value = self.sb.getvalue().strip()
        if not value:
            return

        if self.r:
            await self.r.edit(content=value + end)
            return

        if self.channel.type == discord.ChannelType.text:
            self.channel = await self.channel.create_thread(name='Discollama Says', message=self.message, auto_archive_duration=60)

        self.r = await self.channel.send(value)
class Discollama:
    """Discord bot that answers mentions using an Ollama model plus a Chroma knowledge base.

    The Ollama ``context`` returned for each answer is stored in Redis, keyed by
    the user's message and by the channel the reply was posted in, so later
    messages can continue the conversation.
    """

    def __init__(self, ollama, discord, redis, model, collection):
        """
        Args:
            ollama: async Ollama client; ``generate`` is awaited and streamed.
            discord: discord.py client; the event handlers are registered on it.
            redis: Redis client used to persist conversation context.
            model: name of the Ollama model to generate with.
            collection: Chroma collection queried for relevant documents.
        """
        self.ollama = ollama
        self.discord = discord
        self.redis = redis
        self.model = model
        self.collection = collection

        # register event handlers
        self.discord.event(self.on_ready)
        self.discord.event(self.on_message)

    async def on_ready(self):
        """Set the bot's presence and log an invite URL with the permissions it needs."""
        activity = discord.Activity(name='Discollama', state='Ask me anything!', type=discord.ActivityType.custom)
        await self.discord.change_presence(activity=activity)

        logging.info(
            'Ready! Invite URL: %s',
            discord.utils.oauth_url(
                self.discord.application_id,
                permissions=discord.Permissions(
                    read_messages=True,
                    send_messages=True,
                    create_public_threads=True,
                ),
                scopes=['bot'],
            ),
        )

    async def on_message(self, message):
        """Answer ``message`` if it mentions the bot, streaming the reply into Discord."""
        if self.discord.user == message.author:
            # don't respond to ourselves
            return

        if not self.discord.user.mentioned_in(message):
            # don't respond to messages that don't mention us
            return

        user_id = self.discord.user.id
        # Strip both mention forms: <@id> and the legacy nickname form <@!id>.
        content = (
            message.content
            .replace(f'<@{user_id}>', '')
            .replace(f'<@!{user_id}>', '')
            .strip()
        )
        if not content:
            content = 'Hi!'

        channel = message.channel

        context = []
        if reference := message.reference:
            context = await self.load(message_id=reference.message_id)
            if not context:
                # No saved conversation for the referenced message: include its text instead.
                reference_message = await message.channel.fetch_message(reference.message_id)
                content = '\n'.join(
                    [
                        content,
                        'Use this to answer the question if it is relevant, otherwise ignore it:',
                        reference_message.content,
                    ]
                )

        # Retrieve relevant context from the vector store.
        # NOTE(review): collection.query is a synchronous call made on the event-loop thread.
        knowledge = self.collection.query(
            query_texts=[content],
            n_results=2,
        )
        # Documents for the first (only) query text; tolerate a missing/None entry.
        documents = (knowledge.get('documents') or [[]])[0]

        content = '\n'.join(
            [
                'Using the provided document, answer the user question to the best of your ability. You must try to use information from the provided document. Combine information in the document into a coherent answer.',
                'If there is nothing in the document relevant to the user question, say \'Hmm, I don\'t know about that, try referencing the docs.\', before providing any other information you know.',
                'Anything between the following `document` html blocks is retrieved from a knowledge bank, not part of the conversation with the user.',
                '<document>',
                '\n'.join(documents) if documents else '',
                '</document>',
                'Anything between the following `user` html blocks is part of the conversation with the user.',
                '<user>',
                content,
                '</user>',
            ]
        )

        if not context:
            context = await self.load(channel_id=channel.id)

        r = Response(message)
        task = asyncio.create_task(self.thinking(message))
        last_part = None
        try:
            async for part in self.generate(content, context):
                if last_part is None:
                    # Stop the thinking indicator exactly once, on the first chunk.
                    # Cancelling again later could interrupt the reaction clean-up
                    # that thinking() performs in its `finally`.
                    task.cancel()
                last_part = part

                await r.write(part['response'], end='...')

            # Final write drops the '...' suffix from the last message.
            await r.write('')
        finally:
            if last_part is None:
                # Generation failed or produced nothing: still clear the indicator.
                task.cancel()

        if last_part is not None:
            await self.save(r.channel.id, message.id, last_part['context'])

    async def thinking(self, message, timeout=999):
        """Show a 🤔 reaction and typing indicator on ``message`` until cancelled or ``timeout`` seconds pass.

        Best effort: failures to add or remove the indicator are logged at debug
        level and otherwise ignored.
        """
        try:
            await message.add_reaction('🤔')
            async with message.channel.typing():
                await asyncio.sleep(timeout)
        except Exception:
            logging.debug('could not show thinking indicator', exc_info=True)
        finally:
            try:
                await message.remove_reaction('🤔', self.discord.user)
            except Exception:
                # e.g. the reaction was never added
                logging.debug('could not remove thinking reaction', exc_info=True)

    async def generate(self, content, context):
        """Stream a completion for ``content`` from Ollama, batching output.

        Yields Ollama stream parts whose ``'response'`` has been replaced with all
        text accumulated since the previous yield. A part is yielded when it is
        the final (``done``) part or when more than one second has passed since
        the last yield.
        """
        sb = io.StringIO()

        t = datetime.now()
        async for part in await self.ollama.generate(model=self.model, prompt=content, context=context, keep_alive=-1, stream=True):
            sb.write(part['response'])

            if part['done'] or datetime.now() - t > timedelta(seconds=1):
                part['response'] = sb.getvalue()
                yield part
                t = datetime.now()
                sb.seek(0, io.SEEK_SET)
                sb.truncate()

    async def save(self, channel_id, message_id, ctx: list[int]):
        """Store ``ctx`` for ``message_id`` and mark ``message_id`` as the latest for ``channel_id``.

        Both Redis keys expire after 7 days.
        """
        self.redis.set(f'discollama:channel:{channel_id}', message_id, ex=60 * 60 * 24 * 7)
        self.redis.set(f'discollama:message:{message_id}', json.dumps(ctx), ex=60 * 60 * 24 * 7)

    async def load(self, channel_id=None, message_id=None) -> list[int]:
        """Return the saved context for a message, or ``[]`` if none is stored.

        When ``channel_id`` is given it takes precedence: the context of the most
        recently saved message for that channel is returned.
        """
        if channel_id:
            message_id = self.redis.get(f'discollama:channel:{channel_id}')

        if message_id is None:
            # nothing saved for this channel (or no id given): skip the lookup
            return []

        ctx = self.redis.get(f'discollama:message:{message_id}')
        return json.loads(ctx) if ctx else []

    def run(self, token):
        """Run the Discord client until it stops, then close the Redis connection."""
        try:
            self.discord.run(token)
        except Exception as e:
            logging.exception("An error occurred while running the bot: %s", e)
        finally:
            self.redis.close()
def embed_data(collection, data_dir='data'):
    """Load JSON-array documents from ``data_dir`` into ``collection``.

    Each ``*.json`` file in ``data_dir`` is expected to hold a JSON array; every
    element becomes one document with id ``<file name without extension>-<index>``.
    Files that are not arrays or fail to parse are logged and skipped.

    Args:
        collection: Chroma collection to add the documents to.
        data_dir: folder to read (default ``'data'``, relative to the working directory).
    """
    logging.info('embedding data...')
    documents = []
    ids = []

    # Read all data from the data folder; sorted so ids/document order are deterministic.
    for filename in sorted(os.listdir(data_dir)):
        if not filename.endswith('.json'):
            continue

        filepath = os.path.join(data_dir, filename)
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform default encoding.
        with open(filepath, 'r', encoding='utf-8') as file:
            try:
                data = json.load(file)
            except json.JSONDecodeError as e:
                logging.exception('Error decoding JSON from file %s: %s', filename, e)
                continue
            except Exception as e:
                logging.exception('An error occurred while processing file %s: %s', filename, e)
                continue

        if not isinstance(data, list):
            logging.warning('The file %s is not a JSON array.', filename)
            continue

        stem = filename.rsplit('.', 1)[0]
        for index, item in enumerate(data):
            # NOTE(review): Chroma expects documents to be strings — confirm the data files only contain strings.
            documents.append(item)
            ids.append(f'{stem}-{index}')

    if not documents:
        # Nothing to store; avoid calling add() with empty lists.
        logging.warning('no documents found in %s', data_dir)
        return

    # store the data in chroma for look-up
    collection.add(
        documents=documents,
        ids=ids,
    )
def main():
    """Parse configuration, index the knowledge base, and run the bot.

    Options default to the matching environment variables. The Discord token is
    read from the ``DISCORD_TOKEN`` environment variable.

    Raises:
        SystemExit: (via argparse) on invalid arguments or when ``DISCORD_TOKEN`` is unset/empty.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--ollama-scheme', default=os.getenv('OLLAMA_SCHEME', 'http'), choices=['http', 'https'])
    parser.add_argument('--ollama-host', default=os.getenv('OLLAMA_HOST', '127.0.0.1'), type=str)
    parser.add_argument('--ollama-port', default=os.getenv('OLLAMA_PORT', 11434), type=int)
    parser.add_argument('--ollama-model', default=os.getenv('OLLAMA_MODEL', 'llama2'), type=str)

    parser.add_argument('--redis-host', default=os.getenv('REDIS_HOST', '127.0.0.1'), type=str)
    parser.add_argument('--redis-port', default=os.getenv('REDIS_PORT', 6379), type=int)

    # NOTE(review): args.buffer_size is not referenced elsewhere in this file; kept for CLI compatibility.
    parser.add_argument('--buffer-size', default=32, type=int)

    args = parser.parse_args()

    # Fail fast, before the potentially slow embedding step, if the token is missing.
    token = os.environ.get('DISCORD_TOKEN')
    if not token:
        parser.error('the DISCORD_TOKEN environment variable must be set')

    intents = discord.Intents.default()
    intents.message_content = True

    chroma = chromadb.Client()
    collection = chroma.get_or_create_collection(name='discollama')
    embed_data(collection)

    Discollama(
        ollama.AsyncClient(host=f'{args.ollama_scheme}://{args.ollama_host}:{args.ollama_port}'),
        discord.Client(intents=intents),
        redis.Redis(host=args.redis_host, port=args.redis_port, db=0, decode_responses=True),
        model=args.ollama_model,
        collection=collection,
    ).run(token)


if __name__ == '__main__':
    main()