use ollama client

pull/2/head
Michael Yang 2023-12-24 19:31:17 -08:00
parent 35c6f1a270
commit cfe47a5cdd
3 changed files with 260 additions and 127 deletions

View File

@ -1,160 +1,180 @@
import io
import os import os
import json import json
import aiohttp import asyncio
import discord
import argparse import argparse
from redis import Redis from datetime import datetime, timedelta
import logging import ollama
import discord
import redis
intents = discord.Intents.default() from logging import getLogger
intents.message_content = True
client = discord.Client(intents=intents) # piggy back on the logger discord.py set up
logging = getLogger('discord.discollama')
@client.event class Response:
async def on_ready(): def __init__(self, message):
logging.info( self.message = message
'Ready! Invite URL: %s', self.channel = message.channel
discord.utils.oauth_url(
client.application_id, self.r = None
permissions=discord.Permissions(read_messages=True, send_messages=True), self.sb = io.StringIO()
scopes=['bot'],
)) async def write(self, s, end=''):
if self.sb.seek(0, io.SEEK_END) + len(s) + len(end) > 2000:
self.r = None
self.sb.seek(0, io.SEEK_SET)
self.sb.truncate()
self.sb.write(s)
if self.r:
await self.r.edit(content=self.sb.getvalue() + end)
return
if self.channel.type == discord.ChannelType.text:
self.channel = await self.channel.create_thread(name='Discollama Says', message=self.message, auto_archive_duration=60)
self.r = await self.channel.send(self.sb.getvalue())
async def generate_response(prompt, context=[]): class Discollama:
body = { def __init__(self, ollama, discord, redis):
key: value self.ollama = ollama
for key, value in { self.discord = discord
'model': args.ollama_model, self.redis = redis
'prompt': prompt,
'context': context,
}.items() if value
}
async with aiohttp.ClientSession() as session: # register event handlers
async with session.post( self.discord.event(self.on_ready)
f'http://{args.ollama_host}:{args.ollama_port}/api/generate', self.discord.event(self.on_message)
json=body) as r:
async for line in r.content:
yield json.loads(line)
async def on_ready(self):
activity = discord.Activity(name='Discollama', state='Ask me anything!', type=discord.ActivityType.custom)
await self.discord.change_presence(activity=activity)
async def buffered_generate_response(prompt, context=[]):
buffer = ''
async for part in generate_response(prompt, context):
if error := part.get('error'):
raise Exception(error)
if part['done']:
yield buffer, part
break
buffer += part['response']
if len(buffer) >= args.buffer_size:
yield buffer, part
buffer = ''
def save_session(response, part):
context = part.get('context', [])
redis.json().set(f'ollama:{response.id}', '$', {'context': context})
redis.expire(f'ollama:{response.id}', 60 * 60 * 24 * 7)
logging.info('saving message=%s: len(context)=%d', response.id, len(context))
def load_session(reference):
kwargs = {}
if reference:
context = redis.json().get(f'ollama:{reference.message_id}', '.context')
kwargs['context'] = context or []
if kwargs.get('context'):
logging.info( logging.info(
'loading message=%s: len(context)=%d', 'Ready! Invite URL: %s',
reference.message_id, discord.utils.oauth_url(
len(kwargs['context']), self.discord.application_id,
permissions=discord.Permissions(
read_messages=True,
send_messages=True,
create_public_threads=True,
),
scopes=['bot'],
),
) )
return kwargs async def on_message(self, message):
if self.discord.user == message.author:
# don't respond to ourselves
return
if not self.discord.user.mentioned_in(message):
# don't respond to messages that don't mention us
return
@client.event content = message.content.replace(f'<@{self.discord.user.id}>', '').strip()
async def on_message(message): if not content:
if message.author == client.user: content = 'Hi!'
return
if client.user.id in message.raw_mentions: channel = message.channel
raw_content = message.content.replace(f'<@{client.user.id}>', '').strip()
if raw_content.strip() == '':
raw_content = 'Tell me about yourself.'
response = None context = []
response_content = '' if reference := message.reference:
async with message.channel.typing(): context = await self.load(message_id=reference.message_id)
if not context:
reference_message = await message.channel.fetch_message(reference.message_id)
content = '\n'.join(
[
content,
'Use this to answer the question if it is relevant, otherwise ignore it:',
reference_message.content,
]
)
if not context:
context = await self.load(channel_id=channel.id)
r = Response(message)
task = asyncio.create_task(self.thinking(message))
async for part in self.generate(content, context):
task.cancel()
await r.write(part['response'], end='...')
await r.write('')
await self.save(r.channel.id, message.id, part['context'])
async def thinking(self, message, timeout=999):
try:
await message.add_reaction('🤔') await message.add_reaction('🤔')
async with message.channel.typing():
await asyncio.sleep(timeout)
except Exception:
pass
finally:
await message.remove_reaction('🤔', self.discord.user)
context = [] async def generate(self, content, context):
if reference := message.reference: sb = io.StringIO()
if session := load_session(message.reference):
context = session.get('context', [])
else:
reference_message = await message.channel.fetch_message(
reference.message_id)
reference_content = reference_message.content
raw_content = '\n'.join([
raw_content,
'Use it to answer the prompt:',
reference_content,
])
async for buffer, part in buffered_generate_response( t = datetime.now()
raw_content, async for part in await self.ollama.generate(model='llama2', prompt=content, context=context, stream=True):
context=context, sb.write(part['response'])
):
response_content += buffer
if part['done']:
save_session(response, part)
break
if not response: if part['done'] or datetime.now() - t > timedelta(seconds=1):
response = await message.reply(response_content) part['response'] = sb.getvalue()
await message.remove_reaction('🤔', client.user) yield part
continue t = datetime.now()
sb.seek(0, io.SEEK_SET)
sb.truncate()
if len(response_content) + 3 >= 2000: async def save(self, channel_id, message_id, ctx: list[int]):
response = await response.reply(buffer) self.redis.set(f'discollama:channel:{channel_id}', message_id, ex=60 * 60 * 24 * 7)
response_content = buffer self.redis.set(f'discollama:message:{message_id}', json.dumps(ctx), ex=60 * 60 * 24 * 7)
continue
await response.edit(content=response_content + '...') async def load(self, channel_id=None, message_id=None) -> list[int]:
if channel_id:
message_id = self.redis.get(f'discollama:channel:{channel_id}')
await response.edit(content=response_content) ctx = self.redis.get(f'discollama:message:{message_id}')
return json.loads(ctx) if ctx else []
def run(self, token):
try:
self.discord.run(token)
except Exception:
self.redis.close()
default_ollama_host = os.getenv('OLLAMA_HOST', '127.0.0.1') def main():
default_ollama_port = os.getenv('OLLAMA_PORT', 11434) parser = argparse.ArgumentParser()
default_ollama_model = os.getenv('OLLAMA_MODEL', 'llama2')
parser = argparse.ArgumentParser() parser.add_argument('--ollama-scheme', default=os.getenv('OLLAMA_SCHEME', 'http'), choices=['http', 'https'])
parser.add_argument('--ollama-host', default=default_ollama_host) parser.add_argument('--ollama-host', default=os.getenv('OLLAMA_HOST', '127.0.0.1'), type=str)
parser.add_argument('--ollama-port', default=default_ollama_port, type=int) parser.add_argument('--ollama-port', default=os.getenv('OLLAMA_PORT', 11434), type=int)
parser.add_argument('--ollama-model', default=default_ollama_model, type=str) parser.add_argument('--ollama-model', default=os.getenv('OLLAMA_MODEL', 'llama2'), type=str)
parser.add_argument('--redis-host', default='localhost') parser.add_argument('--redis-host', default=os.getenv('REDIS_HOST', '127.0.0.1'), type=str)
parser.add_argument('--redis-port', default=6379) parser.add_argument('--redis-port', default=os.getenv('REDIS_PORT', 6379), type=int)
parser.add_argument('--buffer-size', default=32, type=int) parser.add_argument('--buffer-size', default=32, type=int)
args = parser.parse_args() args = parser.parse_args()
try: intents = discord.Intents.default()
redis = Redis(host=args.redis_host, port=args.redis_port) intents.message_content = True
client.run(os.getenv('DISCORD_TOKEN'), root_logger=True)
except KeyboardInterrupt:
pass
redis.close() Discollama(
ollama.AsyncClient(base_url=f'{args.ollama_scheme}://{args.ollama_host}:{args.ollama_port}'),
discord.Client(intents=intents),
redis.Redis(host=args.redis_host, port=args.redis_port, db=0, decode_responses=True),
).run(os.environ['DISCORD_TOKEN'])
if __name__ == '__main__':
main()

116
poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. # This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
[[package]] [[package]]
name = "aiohttp" name = "aiohttp"
@ -109,6 +109,26 @@ files = [
[package.dependencies] [package.dependencies]
frozenlist = ">=1.1.0" frozenlist = ">=1.1.0"
[[package]]
name = "anyio"
version = "4.2.0"
description = "High level compatibility layer for multiple asynchronous event loop implementations"
optional = false
python-versions = ">=3.8"
files = [
{file = "anyio-4.2.0-py3-none-any.whl", hash = "sha256:745843b39e829e108e518c489b31dc757de7d2131d53fac32bd8df268227bfee"},
{file = "anyio-4.2.0.tar.gz", hash = "sha256:e1875bb4b4e2de1669f4bc7869b6d3f54231cdced71605e6e64c9be77e3be50f"},
]
[package.dependencies]
idna = ">=2.8"
sniffio = ">=1.1"
[package.extras]
doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"]
test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"]
trio = ["trio (>=0.23)"]
[[package]] [[package]]
name = "async-timeout" name = "async-timeout"
version = "4.0.3" version = "4.0.3"
@ -138,6 +158,17 @@ docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-
tests = ["attrs[tests-no-zope]", "zope-interface"] tests = ["attrs[tests-no-zope]", "zope-interface"]
tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
[[package]]
name = "certifi"
version = "2023.11.17"
description = "Python package for providing Mozilla's CA Bundle."
optional = false
python-versions = ">=3.6"
files = [
{file = "certifi-2023.11.17-py3-none-any.whl", hash = "sha256:e036ab49d5b79556f99cfc2d9320b34cfbe5be05c5871b51de9329f0603b0474"},
{file = "certifi-2023.11.17.tar.gz", hash = "sha256:9b469f3a900bf28dc19b8cfbf8019bf47f7fdd1a65a1d4ffb98fc14166beb4d1"},
]
[[package]] [[package]]
name = "discord-py" name = "discord-py"
version = "2.3.2" version = "2.3.2"
@ -244,6 +275,62 @@ files = [
{file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"}, {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"},
] ]
[[package]]
name = "h11"
version = "0.14.0"
description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1"
optional = false
python-versions = ">=3.7"
files = [
{file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"},
{file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"},
]
[[package]]
name = "httpcore"
version = "1.0.2"
description = "A minimal low-level HTTP client."
optional = false
python-versions = ">=3.8"
files = [
{file = "httpcore-1.0.2-py3-none-any.whl", hash = "sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7"},
{file = "httpcore-1.0.2.tar.gz", hash = "sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535"},
]
[package.dependencies]
certifi = "*"
h11 = ">=0.13,<0.15"
[package.extras]
asyncio = ["anyio (>=4.0,<5.0)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
trio = ["trio (>=0.22.0,<0.23.0)"]
[[package]]
name = "httpx"
version = "0.25.2"
description = "The next generation HTTP client."
optional = false
python-versions = ">=3.8"
files = [
{file = "httpx-0.25.2-py3-none-any.whl", hash = "sha256:a05d3d052d9b2dfce0e3896636467f8a5342fb2b902c819428e1ac65413ca118"},
{file = "httpx-0.25.2.tar.gz", hash = "sha256:8b8fcaa0c8ea7b05edd69a094e63a2094c4efcb48129fb757361bc423c0ad9e8"},
]
[package.dependencies]
anyio = "*"
certifi = "*"
httpcore = "==1.*"
idna = "*"
sniffio = "*"
[package.extras]
brotli = ["brotli", "brotlicffi"]
cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
[[package]] [[package]]
name = "idna" name = "idna"
version = "3.6" version = "3.6"
@ -338,6 +425,20 @@ files = [
{file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"}, {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"},
] ]
[[package]]
name = "ollama"
version = "0.1.0"
description = "The official Python client for Ollama."
optional = false
python-versions = ">=3.8,<4.0"
files = [
{file = "ollama-0.1.0-py3-none-any.whl", hash = "sha256:5c810773f13ee7c0078fe4ae3f5ba4939b27057ddd7e63d7d35fdd6cb8e6978f"},
{file = "ollama-0.1.0.tar.gz", hash = "sha256:701d6824653f3827a1ee80bb605d0560e784cde55fc1dfd17ada438636ee4fed"},
]
[package.dependencies]
httpx = ">=0.25.2,<0.26.0"
[[package]] [[package]]
name = "redis" name = "redis"
version = "5.0.1" version = "5.0.1"
@ -356,6 +457,17 @@ async-timeout = {version = ">=4.0.2", markers = "python_full_version <= \"3.11.2
hiredis = ["hiredis (>=1.0.0)"] hiredis = ["hiredis (>=1.0.0)"]
ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"]
[[package]]
name = "sniffio"
version = "1.3.0"
description = "Sniff out which async library your code is running under"
optional = false
python-versions = ">=3.7"
files = [
{file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"},
{file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"},
]
[[package]] [[package]]
name = "yarl" name = "yarl"
version = "1.9.4" version = "1.9.4"
@ -462,4 +574,4 @@ multidict = ">=4.0"
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.11" python-versions = "^3.11"
content-hash = "4102f42c28a97587f87680d10290146c0a4a03f767feb1acfa9f8c0c8ca6962d" content-hash = "b2e530f606719d9f929f04b71536c228ba331f0e9b0976288ff0e4fcc55788f4"

View File

@ -9,6 +9,7 @@ readme = "README.md"
python = "^3.11" python = "^3.11"
discord-py = "^2.3.1" discord-py = "^2.3.1"
redis = "^5.0.1" redis = "^5.0.1"
ollama = "^0.1.0"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]