From cfe47a5cdd2166083761c68081a9f89a48d1b92f Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Sun, 24 Dec 2023 19:31:17 -0800 Subject: [PATCH] use ollama client --- discollama.py | 270 ++++++++++++++++++++++++++----------------------- poetry.lock | 116 ++++++++++++++++++++- pyproject.toml | 1 + 3 files changed, 260 insertions(+), 127 deletions(-) diff --git a/discollama.py b/discollama.py index 24b85d5..076e3f0 100644 --- a/discollama.py +++ b/discollama.py @@ -1,160 +1,180 @@ +import io import os import json -import aiohttp -import discord +import asyncio import argparse -from redis import Redis +from datetime import datetime, timedelta -import logging +import ollama +import discord +import redis -intents = discord.Intents.default() -intents.message_content = True +from logging import getLogger -client = discord.Client(intents=intents) +# piggy back on the logger discord.py set up +logging = getLogger('discord.discollama') -@client.event -async def on_ready(): - logging.info( - 'Ready! Invite URL: %s', - discord.utils.oauth_url( - client.application_id, - permissions=discord.Permissions(read_messages=True, send_messages=True), - scopes=['bot'], - )) +class Response: + def __init__(self, message): + self.message = message + self.channel = message.channel + + self.r = None + self.sb = io.StringIO() + + async def write(self, s, end=''): + if self.sb.seek(0, io.SEEK_END) + len(s) + len(end) > 2000: + self.r = None + self.sb.seek(0, io.SEEK_SET) + self.sb.truncate() + + self.sb.write(s) + + if self.r: + await self.r.edit(content=self.sb.getvalue() + end) + return + + if self.channel.type == discord.ChannelType.text: + self.channel = await self.channel.create_thread(name='Discollama Says', message=self.message, auto_archive_duration=60) + + self.r = await self.channel.send(self.sb.getvalue()) -async def generate_response(prompt, context=[]): - body = { - key: value - for key, value in { - 'model': args.ollama_model, - 'prompt': prompt, - 'context': context, - }.items() if value - } +class Discollama: + def __init__(self, ollama, discord, redis): + self.ollama = ollama + self.discord = discord + self.redis = redis - async with aiohttp.ClientSession() as session: - async with session.post( - f'http://{args.ollama_host}:{args.ollama_port}/api/generate', - json=body) as r: - async for line in r.content: - yield json.loads(line) + # register event handlers + self.discord.event(self.on_ready) + self.discord.event(self.on_message) + async def on_ready(self): + activity = discord.Activity(name='Discollama', state='Ask me anything!', type=discord.ActivityType.custom) + await self.discord.change_presence(activity=activity) -async def buffered_generate_response(prompt, context=[]): - buffer = '' - async for part in generate_response(prompt, context): - if error := part.get('error'): - raise Exception(error) - - if part['done']: - yield buffer, part - break - - buffer += part['response'] - if len(buffer) >= args.buffer_size: - yield buffer, part - buffer = '' - - -def save_session(response, part): - context = part.get('context', []) - redis.json().set(f'ollama:{response.id}', '$', {'context': context}) - - redis.expire(f'ollama:{response.id}', 60 * 60 * 24 * 7) - logging.info('saving message=%s: len(context)=%d', response.id, len(context)) - - -def load_session(reference): - kwargs = {} - if reference: - context = redis.json().get(f'ollama:{reference.message_id}', '.context') - kwargs['context'] = context or [] - - if kwargs.get('context'): logging.info( - 'loading message=%s: len(context)=%d', - reference.message_id, - len(kwargs['context']), + 'Ready! Invite URL: %s', + discord.utils.oauth_url( + self.discord.application_id, + permissions=discord.Permissions( + read_messages=True, + send_messages=True, + create_public_threads=True, + ), + scopes=['bot'], + ), ) - return kwargs + async def on_message(self, message): + if self.discord.user == message.author: + # don't respond to ourselves + return + if not self.discord.user.mentioned_in(message): + # don't respond to messages that don't mention us + return -@client.event -async def on_message(message): - if message.author == client.user: - return + content = message.content.replace(f'<@{self.discord.user.id}>', '').strip() + if not content: + content = 'Hi!' - if client.user.id in message.raw_mentions: - raw_content = message.content.replace(f'<@{client.user.id}>', '').strip() - if raw_content.strip() == '': - raw_content = 'Tell me about yourself.' + channel = message.channel - response = None - response_content = '' - async with message.channel.typing(): + context = [] + if reference := message.reference: + context = await self.load(message_id=reference.message_id) + if not context: + reference_message = await message.channel.fetch_message(reference.message_id) + content = '\n'.join( + [ + content, + 'Use this to answer the question if it is relevant, otherwise ignore it:', + reference_message.content, + ] + ) + + if not context: + context = await self.load(channel_id=channel.id) + + r = Response(message) + task = asyncio.create_task(self.thinking(message)) + async for part in self.generate(content, context): + task.cancel() + + await r.write(part['response'], end='...') + + await r.write('') + await self.save(r.channel.id, message.id, part['context']) + + async def thinking(self, message, timeout=999): + try: await message.add_reaction('🤔') + async with message.channel.typing(): + await asyncio.sleep(timeout) + except Exception: + pass + finally: + await message.remove_reaction('🤔', self.discord.user) - context = [] - if reference := message.reference: - if session := load_session(message.reference): - context = session.get('context', []) - else: - reference_message = await message.channel.fetch_message( - reference.message_id) - reference_content = reference_message.content - raw_content = '\n'.join([ - raw_content, - 'Use it to answer the prompt:', - reference_content, - ]) + async def generate(self, content, context): + sb = io.StringIO() - async for buffer, part in buffered_generate_response( - raw_content, - context=context, - ): - response_content += buffer - if part['done']: - save_session(response, part) - break + t = datetime.now() + async for part in await self.ollama.generate(model='llama2', prompt=content, context=context, stream=True): + sb.write(part['response']) - if not response: - response = await message.reply(response_content) - await message.remove_reaction('🤔', client.user) - continue + if part['done'] or datetime.now() - t > timedelta(seconds=1): + part['response'] = sb.getvalue() + yield part + t = datetime.now() + sb.seek(0, io.SEEK_SET) + sb.truncate() - if len(response_content) + 3 >= 2000: - response = await response.reply(buffer) - response_content = buffer - continue + async def save(self, channel_id, message_id, ctx: list[int]): + self.redis.set(f'discollama:channel:{channel_id}', message_id, ex=60 * 60 * 24 * 7) + self.redis.set(f'discollama:message:{message_id}', json.dumps(ctx), ex=60 * 60 * 24 * 7) - await response.edit(content=response_content + '...') + async def load(self, channel_id=None, message_id=None) -> list[int]: + if channel_id: + message_id = self.redis.get(f'discollama:channel:{channel_id}') - await response.edit(content=response_content) + ctx = self.redis.get(f'discollama:message:{message_id}') + return json.loads(ctx) if ctx else [] + + def run(self, token): + try: + self.discord.run(token) + except Exception: + self.redis.close() -default_ollama_host = os.getenv('OLLAMA_HOST', '127.0.0.1') -default_ollama_port = os.getenv('OLLAMA_PORT', 11434) -default_ollama_model = os.getenv('OLLAMA_MODEL', 'llama2') +def main(): + parser = argparse.ArgumentParser() -parser = argparse.ArgumentParser() -parser.add_argument('--ollama-host', default=default_ollama_host) -parser.add_argument('--ollama-port', default=default_ollama_port, type=int) -parser.add_argument('--ollama-model', default=default_ollama_model, type=str) + parser.add_argument('--ollama-scheme', default=os.getenv('OLLAMA_SCHEME', 'http'), choices=['http', 'https']) + parser.add_argument('--ollama-host', default=os.getenv('OLLAMA_HOST', '127.0.0.1'), type=str) + parser.add_argument('--ollama-port', default=os.getenv('OLLAMA_PORT', 11434), type=int) + parser.add_argument('--ollama-model', default=os.getenv('OLLAMA_MODEL', 'llama2'), type=str) -parser.add_argument('--redis-host', default='localhost') -parser.add_argument('--redis-port', default=6379) + parser.add_argument('--redis-host', default=os.getenv('REDIS_HOST', '127.0.0.1'), type=str) + parser.add_argument('--redis-port', default=os.getenv('REDIS_PORT', 6379), type=int) -parser.add_argument('--buffer-size', default=32, type=int) + parser.add_argument('--buffer-size', default=32, type=int) -args = parser.parse_args() + args = parser.parse_args() -try: - redis = Redis(host=args.redis_host, port=args.redis_port) - client.run(os.getenv('DISCORD_TOKEN'), root_logger=True) -except KeyboardInterrupt: - pass + intents = discord.Intents.default() + intents.message_content = True -redis.close() + Discollama( + ollama.AsyncClient(base_url=f'{args.ollama_scheme}://{args.ollama_host}:{args.ollama_port}'), + discord.Client(intents=intents), + redis.Redis(host=args.redis_host, port=args.redis_port, db=0, decode_responses=True), + ).run(os.environ['DISCORD_TOKEN']) + + +if __name__ == '__main__': + main() diff --git a/poetry.lock b/poetry.lock index 1f33bc7..d7dc4b4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -109,6 +109,26 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" +[[package]] +name = "anyio" +version = "4.2.0" +description = "High level compatibility layer for multiple asynchronous event loop implementations" +optional = false +python-versions = ">=3.8" +files = [ + {file = "anyio-4.2.0-py3-none-any.whl", hash = "sha256:745843b39e829e108e518c489b31dc757de7d2131d53fac32bd8df268227bfee"}, + {file = "anyio-4.2.0.tar.gz", hash = "sha256:e1875bb4b4e2de1669f4bc7869b6d3f54231cdced71605e6e64c9be77e3be50f"}, +] + +[package.dependencies] +idna = ">=2.8" +sniffio = ">=1.1" + +[package.extras] +doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] +trio = ["trio (>=0.23)"] + [[package]] name = "async-timeout" version = "4.0.3" @@ -138,6 +158,17 @@ docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib- tests = ["attrs[tests-no-zope]", "zope-interface"] tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +[[package]] +name = "certifi" +version = "2023.11.17" +description = "Python package for providing Mozilla's CA Bundle." +optional = false +python-versions = ">=3.6" +files = [ + {file = "certifi-2023.11.17-py3-none-any.whl", hash = "sha256:e036ab49d5b79556f99cfc2d9320b34cfbe5be05c5871b51de9329f0603b0474"}, + {file = "certifi-2023.11.17.tar.gz", hash = "sha256:9b469f3a900bf28dc19b8cfbf8019bf47f7fdd1a65a1d4ffb98fc14166beb4d1"}, +] + [[package]] name = "discord-py" version = "2.3.2" @@ -244,6 +275,62 @@ files = [ {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"}, ] +[[package]] +name = "h11" +version = "0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +optional = false +python-versions = ">=3.7" +files = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] + +[[package]] +name = "httpcore" +version = "1.0.2" +description = "A minimal low-level HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpcore-1.0.2-py3-none-any.whl", hash = "sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7"}, + {file = "httpcore-1.0.2.tar.gz", hash = "sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535"}, +] + +[package.dependencies] +certifi = "*" +h11 = ">=0.13,<0.15" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<0.23.0)"] + +[[package]] +name = "httpx" +version = "0.25.2" +description = "The next generation HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx-0.25.2-py3-none-any.whl", hash = "sha256:a05d3d052d9b2dfce0e3896636467f8a5342fb2b902c819428e1ac65413ca118"}, + {file = "httpx-0.25.2.tar.gz", hash = "sha256:8b8fcaa0c8ea7b05edd69a094e63a2094c4efcb48129fb757361bc423c0ad9e8"}, +] + +[package.dependencies] +anyio = "*" +certifi = "*" +httpcore = "==1.*" +idna = "*" +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] + [[package]] name = "idna" version = "3.6" @@ -338,6 +425,20 @@ files = [ {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"}, ] +[[package]] +name = "ollama" +version = "0.1.0" +description = "The official Python client for Ollama." +optional = false +python-versions = ">=3.8,<4.0" +files = [ + {file = "ollama-0.1.0-py3-none-any.whl", hash = "sha256:5c810773f13ee7c0078fe4ae3f5ba4939b27057ddd7e63d7d35fdd6cb8e6978f"}, + {file = "ollama-0.1.0.tar.gz", hash = "sha256:701d6824653f3827a1ee80bb605d0560e784cde55fc1dfd17ada438636ee4fed"}, +] + +[package.dependencies] +httpx = ">=0.25.2,<0.26.0" + [[package]] name = "redis" version = "5.0.1" @@ -356,6 +457,17 @@ async-timeout = {version = ">=4.0.2", markers = "python_full_version <= \"3.11.2 hiredis = ["hiredis (>=1.0.0)"] ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] +[[package]] +name = "sniffio" +version = "1.3.0" +description = "Sniff out which async library your code is running under" +optional = false +python-versions = ">=3.7" +files = [ + {file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"}, + {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, +] + [[package]] name = "yarl" version = "1.9.4" @@ -462,4 +574,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "4102f42c28a97587f87680d10290146c0a4a03f767feb1acfa9f8c0c8ca6962d" +content-hash = "b2e530f606719d9f929f04b71536c228ba331f0e9b0976288ff0e4fcc55788f4" diff --git a/pyproject.toml b/pyproject.toml index f77790d..f022125 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ readme = "README.md" python = "^3.11" discord-py = "^2.3.1" redis = "^5.0.1" +ollama = "^0.1.0" [build-system] requires = ["poetry-core"]