use ollama client
parent
35c6f1a270
commit
cfe47a5cdd
270
discollama.py
270
discollama.py
|
@ -1,160 +1,180 @@
|
|||
import io
|
||||
import os
|
||||
import json
|
||||
import aiohttp
|
||||
import discord
|
||||
import asyncio
|
||||
import argparse
|
||||
from redis import Redis
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import logging
|
||||
import ollama
|
||||
import discord
|
||||
import redis
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
from logging import getLogger
|
||||
|
||||
client = discord.Client(intents=intents)
|
||||
# piggy back on the logger discord.py set up
|
||||
logging = getLogger('discord.discollama')
|
||||
|
||||
|
||||
@client.event
|
||||
async def on_ready():
|
||||
class Response:
    """Accumulate streamed text and mirror it into one Discord message.

    The first non-empty write sends a new message; if the triggering message
    was posted in a regular text channel, a thread is created on it first and
    the reply goes there. Later writes edit that message in place. When the
    next write would push the buffer past 2000 characters (presumably
    Discord's message length limit), the buffer is cleared and a fresh
    message is started.
    """

    def __init__(self, message):
        """Bind the response to the Discord message that triggered it."""
        self.message = message
        self.channel = message.channel

        self.r = None  # message currently being edited; None until first send
        self.sb = io.StringIO()  # text accumulated for the current message

    async def write(self, s, end=''):
        """Append ``s`` to the buffer and publish the buffer to Discord.

        ``end`` is a transient suffix (for example ``'...'``): it is appended
        when editing an existing message but is never stored in the buffer.
        Nothing is sent or edited while the buffer holds no printable text.
        """
        if self.sb.seek(0, io.SEEK_END) + len(s) + len(end) > 2000:
            # The current message is full: forget it and start a new one.
            self.r = None
            self.sb.seek(0, io.SEEK_SET)
            self.sb.truncate()

        self.sb.write(s)

        value = self.sb.getvalue()
        if not value.strip():
            # Discord refuses empty messages, so wait until there is
            # something printable (covers the final ``write('')`` call when
            # the model produced no text at all).
            return

        if self.r:
            await self.r.edit(content=value + end)
            return

        if self.channel.type == discord.ChannelType.text:
            self.channel = await self.channel.create_thread(
                name='Discollama Says',
                message=self.message,
                auto_archive_duration=60,
            )

        self.r = await self.channel.send(value)
|
||||
|
||||
|
||||
class Discollama:
|
||||
def __init__(self, ollama, discord, redis):
    """Keep the given ollama, discord and redis objects and hook up handlers."""
    self.ollama, self.discord, self.redis = ollama, discord, redis

    # register event handlers
    for handler in (self.on_ready, self.on_message):
        self.discord.event(handler)
|
||||
|
||||
async def on_ready(self):
|
||||
activity = discord.Activity(name='Discollama', state='Ask me anything!', type=discord.ActivityType.custom)
|
||||
await self.discord.change_presence(activity=activity)
|
||||
|
||||
logging.info(
|
||||
'Ready! Invite URL: %s',
|
||||
discord.utils.oauth_url(
|
||||
client.application_id,
|
||||
permissions=discord.Permissions(read_messages=True, send_messages=True),
|
||||
self.discord.application_id,
|
||||
permissions=discord.Permissions(
|
||||
read_messages=True,
|
||||
send_messages=True,
|
||||
create_public_threads=True,
|
||||
),
|
||||
scopes=['bot'],
|
||||
))
|
||||
|
||||
|
||||
async def generate_response(prompt, context=None):
    """Stream a completion from the Ollama ``/api/generate`` endpoint.

    Args:
        prompt: Prompt text sent as the ``prompt`` field.
        context: Optional context from a previous exchange; omitted from the
            request when empty or ``None``.

    Yields:
        Each line of the HTTP response body, decoded from JSON.
    """
    fields = {
        'model': args.ollama_model,
        'prompt': prompt,
        'context': context,
    }
    # Only send fields that carry a value; falsy ones are left out entirely.
    body = {key: value for key, value in fields.items() if value}

    url = f'http://{args.ollama_host}:{args.ollama_port}/api/generate'
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=body) as r:
            async for line in r.content:
                yield json.loads(line)
|
||||
|
||||
|
||||
async def buffered_generate_response(prompt, context=None):
    """Re-chunk ``generate_response`` output into larger pieces.

    Args:
        prompt: Prompt text forwarded to ``generate_response``.
        context: Optional context forwarded to ``generate_response``.

    Yields:
        ``(buffer, part)`` tuples: ``buffer`` is the text accumulated since
        the previous yield and ``part`` is the most recent decoded chunk. A
        tuple is produced whenever the buffer reaches ``args.buffer_size``
        characters, and once more for the chunk flagged ``done``.

    Raises:
        Exception: If a chunk carries an ``error`` field.
    """
    buffer = ''
    async for part in generate_response(prompt, context):
        if error := part.get('error'):
            raise Exception(error)

        if part['done']:
            # Flush whatever is left together with the final chunk.
            yield buffer, part
            break

        buffer += part['response']
        if len(buffer) >= args.buffer_size:
            yield buffer, part
            buffer = ''
|
||||
|
||||
|
||||
def save_session(response, part):
    """Store the ``context`` of ``part`` in Redis under the id of ``response``.

    The entry is written as a JSON document and set to expire after
    ``60 * 60 * 24 * 7`` seconds.
    """
    context = part.get('context', [])
    key = f'ollama:{response.id}'

    redis.json().set(key, '$', {'context': context})
    redis.expire(key, 60 * 60 * 24 * 7)

    logging.info('saving message=%s: len(context)=%d', response.id, len(context))
|
||||
|
||||
|
||||
def load_session(reference):
    """Look up the context saved for the message that ``reference`` points to.

    Args:
        reference: A message reference exposing ``message_id``, or a falsy
            value when the incoming message is not a reply.

    Returns:
        An empty dict when ``reference`` is falsy; otherwise a dict whose
        ``'context'`` key holds the stored context, or ``[]`` if none exists.
    """
    kwargs = {}
    if reference:
        context = redis.json().get(f'ollama:{reference.message_id}', '.context')
        kwargs['context'] = context or []

    if kwargs.get('context'):
        logging.info(
            'loading message=%s: len(context)=%d',
            reference.message_id,
            len(kwargs['context']),
        )

    return kwargs
|
||||
|
||||
|
||||
@client.event
|
||||
async def on_message(message):
|
||||
if message.author == client.user:
|
||||
async def on_message(self, message):
|
||||
if self.discord.user == message.author:
|
||||
# don't respond to ourselves
|
||||
return
|
||||
|
||||
if client.user.id in message.raw_mentions:
|
||||
raw_content = message.content.replace(f'<@{client.user.id}>', '').strip()
|
||||
if raw_content.strip() == '':
|
||||
raw_content = 'Tell me about yourself.'
|
||||
if not self.discord.user.mentioned_in(message):
|
||||
# don't respond to messages that don't mention us
|
||||
return
|
||||
|
||||
response = None
|
||||
response_content = ''
|
||||
async with message.channel.typing():
|
||||
await message.add_reaction('🤔')
|
||||
content = message.content.replace(f'<@{self.discord.user.id}>', '').strip()
|
||||
if not content:
|
||||
content = 'Hi!'
|
||||
|
||||
channel = message.channel
|
||||
|
||||
context = []
|
||||
if reference := message.reference:
|
||||
if session := load_session(message.reference):
|
||||
context = session.get('context', [])
|
||||
else:
|
||||
reference_message = await message.channel.fetch_message(
|
||||
reference.message_id)
|
||||
reference_content = reference_message.content
|
||||
raw_content = '\n'.join([
|
||||
raw_content,
|
||||
'Use it to answer the prompt:',
|
||||
reference_content,
|
||||
])
|
||||
context = await self.load(message_id=reference.message_id)
|
||||
if not context:
|
||||
reference_message = await message.channel.fetch_message(reference.message_id)
|
||||
content = '\n'.join(
|
||||
[
|
||||
content,
|
||||
'Use this to answer the question if it is relevant, otherwise ignore it:',
|
||||
reference_message.content,
|
||||
]
|
||||
)
|
||||
|
||||
async for buffer, part in buffered_generate_response(
|
||||
raw_content,
|
||||
context=context,
|
||||
):
|
||||
response_content += buffer
|
||||
if part['done']:
|
||||
save_session(response, part)
|
||||
break
|
||||
if not context:
|
||||
context = await self.load(channel_id=channel.id)
|
||||
|
||||
if not response:
|
||||
response = await message.reply(response_content)
|
||||
await message.remove_reaction('🤔', client.user)
|
||||
continue
|
||||
r = Response(message)
|
||||
task = asyncio.create_task(self.thinking(message))
|
||||
async for part in self.generate(content, context):
|
||||
task.cancel()
|
||||
|
||||
if len(response_content) + 3 >= 2000:
|
||||
response = await response.reply(buffer)
|
||||
response_content = buffer
|
||||
continue
|
||||
await r.write(part['response'], end='...')
|
||||
|
||||
await response.edit(content=response_content + '...')
|
||||
await r.write('')
|
||||
await self.save(r.channel.id, message.id, part['context'])
|
||||
|
||||
await response.edit(content=response_content)
|
||||
async def thinking(self, message, timeout=999):
    """Show a thinking reaction and typing indicator on ``message``.

    Adds a 🤔 reaction, keeps the channel's typing context open for
    ``timeout`` seconds (or until the task is cancelled), and always tries
    to remove the reaction afterwards. Both steps are best-effort: failures
    are logged at debug level rather than raised, so a missing permission or
    a deleted message cannot surface as an unretrieved task exception.
    """
    try:
        await message.add_reaction('🤔')
        async with message.channel.typing():
            await asyncio.sleep(timeout)
    except Exception:
        logging.debug('thinking indicator failed', exc_info=True)
    finally:
        # Cancellation is not an ``Exception``; it still propagates after
        # this cleanup has run.
        try:
            await message.remove_reaction('🤔', self.discord.user)
        except Exception:
            logging.debug('could not remove thinking reaction', exc_info=True)
|
||||
|
||||
async def generate(self, content, context):
    """Stream a completion from Ollama, coalescing chunks into ~1s batches.

    Args:
        content: Prompt text passed as ``prompt``.
        context: Context passed through to ``self.ollama.generate``.

    Yields:
        The streamed part that triggered a flush, with its ``'response'``
        replaced by all text accumulated since the previous flush. A flush
        happens when a part is flagged ``done`` or when more than one second
        has passed since the last flush.
    """
    sb = io.StringIO()

    # Use the event loop's monotonic clock: wall-clock time can jump
    # (clock adjustments, DST) and would stall or spam the flushes.
    clock = asyncio.get_running_loop().time
    last_flush = clock()

    # NOTE(review): the model name is hard-coded here; the --ollama-model
    # argument is not consulted by this method.
    async for part in await self.ollama.generate(model='llama2', prompt=content, context=context, stream=True):
        sb.write(part['response'])

        if part['done'] or clock() - last_flush > 1:
            part['response'] = sb.getvalue()
            yield part
            last_flush = clock()
            sb.seek(0, io.SEEK_SET)
            sb.truncate()
|
||||
|
||||
async def save(self, channel_id, message_id, ctx: list[int]):
    """Persist ``ctx`` for ``message_id`` and mark it as the channel's latest.

    Two keys are written, both expiring after one week:
    ``discollama:message:<message_id>`` holds ``ctx`` as JSON and
    ``discollama:channel:<channel_id>`` holds ``message_id``.
    """
    ttl = 60 * 60 * 24 * 7  # one week, in seconds

    # Write the context before the pointer so the channel key never refers
    # to a message whose context has not been stored yet.
    self.redis.set(f'discollama:message:{message_id}', json.dumps(ctx), ex=ttl)
    self.redis.set(f'discollama:channel:{channel_id}', message_id, ex=ttl)
|
||||
|
||||
async def load(self, channel_id=None, message_id=None) -> list[int]:
    """Return the saved context for a message, or ``[]`` when none exists.

    Args:
        channel_id: When truthy, the message id is first resolved from
            ``discollama:channel:<channel_id>``, overriding ``message_id``.
        message_id: Message whose context is read from
            ``discollama:message:<message_id>``.

    Returns:
        The JSON-decoded context, or an empty list if no message id is known
        or nothing is stored for it.
    """
    if channel_id:
        message_id = self.redis.get(f'discollama:channel:{channel_id}')

    if message_id is None:
        # Nothing to look up; avoid querying 'discollama:message:None'.
        return []

    ctx = self.redis.get(f'discollama:message:{message_id}')
    return json.loads(ctx) if ctx else []
|
||||
|
||||
def run(self, token):
    """Run the Discord client with ``token`` and close Redis when it stops.

    Exceptions raised by the client are logged and not re-raised, so the
    call still returns normally as before. The Redis connection is closed
    on every exit path, not only on failure.
    """
    try:
        self.discord.run(token)
    except Exception:
        logging.exception('Discord client stopped unexpectedly')
    finally:
        self.redis.close()
|
||||
|
||||
|
||||
default_ollama_host = os.getenv('OLLAMA_HOST', '127.0.0.1')
|
||||
default_ollama_port = os.getenv('OLLAMA_PORT', 11434)
|
||||
default_ollama_model = os.getenv('OLLAMA_MODEL', 'llama2')
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--ollama-host', default=default_ollama_host)
|
||||
parser.add_argument('--ollama-port', default=default_ollama_port, type=int)
|
||||
parser.add_argument('--ollama-model', default=default_ollama_model, type=str)
|
||||
|
||||
parser.add_argument('--redis-host', default='localhost')
|
||||
parser.add_argument('--redis-port', default=6379)
|
||||
parser.add_argument('--ollama-scheme', default=os.getenv('OLLAMA_SCHEME', 'http'), choices=['http', 'https'])
|
||||
parser.add_argument('--ollama-host', default=os.getenv('OLLAMA_HOST', '127.0.0.1'), type=str)
|
||||
parser.add_argument('--ollama-port', default=os.getenv('OLLAMA_PORT', 11434), type=int)
|
||||
parser.add_argument('--ollama-model', default=os.getenv('OLLAMA_MODEL', 'llama2'), type=str)
|
||||
|
||||
parser.add_argument('--redis-host', default=os.getenv('REDIS_HOST', '127.0.0.1'), type=str)
|
||||
parser.add_argument('--redis-port', default=os.getenv('REDIS_PORT', 6379), type=int)
|
||||
|
||||
parser.add_argument('--buffer-size', default=32, type=int)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
redis = Redis(host=args.redis_host, port=args.redis_port)
|
||||
client.run(os.getenv('DISCORD_TOKEN'), root_logger=True)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
|
||||
redis.close()
|
||||
Discollama(
|
||||
ollama.AsyncClient(base_url=f'{args.ollama_scheme}://{args.ollama_host}:{args.ollama_port}'),
|
||||
discord.Client(intents=intents),
|
||||
redis.Redis(host=args.redis_host, port=args.redis_port, db=0, decode_responses=True),
|
||||
).run(os.environ['DISCORD_TOKEN'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiohttp"
|
||||
|
@ -109,6 +109,26 @@ files = [
|
|||
[package.dependencies]
|
||||
frozenlist = ">=1.1.0"
|
||||
|
||||
[[package]]
|
||||
name = "anyio"
|
||||
version = "4.2.0"
|
||||
description = "High level compatibility layer for multiple asynchronous event loop implementations"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "anyio-4.2.0-py3-none-any.whl", hash = "sha256:745843b39e829e108e518c489b31dc757de7d2131d53fac32bd8df268227bfee"},
|
||||
{file = "anyio-4.2.0.tar.gz", hash = "sha256:e1875bb4b4e2de1669f4bc7869b6d3f54231cdced71605e6e64c9be77e3be50f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
idna = ">=2.8"
|
||||
sniffio = ">=1.1"
|
||||
|
||||
[package.extras]
|
||||
doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"]
|
||||
test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"]
|
||||
trio = ["trio (>=0.23)"]
|
||||
|
||||
[[package]]
|
||||
name = "async-timeout"
|
||||
version = "4.0.3"
|
||||
|
@ -138,6 +158,17 @@ docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-
|
|||
tests = ["attrs[tests-no-zope]", "zope-interface"]
|
||||
tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
|
||||
|
||||
[[package]]
|
||||
name = "certifi"
|
||||
version = "2023.11.17"
|
||||
description = "Python package for providing Mozilla's CA Bundle."
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "certifi-2023.11.17-py3-none-any.whl", hash = "sha256:e036ab49d5b79556f99cfc2d9320b34cfbe5be05c5871b51de9329f0603b0474"},
|
||||
{file = "certifi-2023.11.17.tar.gz", hash = "sha256:9b469f3a900bf28dc19b8cfbf8019bf47f7fdd1a65a1d4ffb98fc14166beb4d1"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "discord-py"
|
||||
version = "2.3.2"
|
||||
|
@ -244,6 +275,62 @@ files = [
|
|||
{file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h11"
|
||||
version = "0.14.0"
|
||||
description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"},
|
||||
{file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "httpcore"
|
||||
version = "1.0.2"
|
||||
description = "A minimal low-level HTTP client."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "httpcore-1.0.2-py3-none-any.whl", hash = "sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7"},
|
||||
{file = "httpcore-1.0.2.tar.gz", hash = "sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
certifi = "*"
|
||||
h11 = ">=0.13,<0.15"
|
||||
|
||||
[package.extras]
|
||||
asyncio = ["anyio (>=4.0,<5.0)"]
|
||||
http2 = ["h2 (>=3,<5)"]
|
||||
socks = ["socksio (==1.*)"]
|
||||
trio = ["trio (>=0.22.0,<0.23.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "httpx"
|
||||
version = "0.25.2"
|
||||
description = "The next generation HTTP client."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "httpx-0.25.2-py3-none-any.whl", hash = "sha256:a05d3d052d9b2dfce0e3896636467f8a5342fb2b902c819428e1ac65413ca118"},
|
||||
{file = "httpx-0.25.2.tar.gz", hash = "sha256:8b8fcaa0c8ea7b05edd69a094e63a2094c4efcb48129fb757361bc423c0ad9e8"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = "*"
|
||||
certifi = "*"
|
||||
httpcore = "==1.*"
|
||||
idna = "*"
|
||||
sniffio = "*"
|
||||
|
||||
[package.extras]
|
||||
brotli = ["brotli", "brotlicffi"]
|
||||
cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
|
||||
http2 = ["h2 (>=3,<5)"]
|
||||
socks = ["socksio (==1.*)"]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "3.6"
|
||||
|
@ -338,6 +425,20 @@ files = [
|
|||
{file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ollama"
|
||||
version = "0.1.0"
|
||||
description = "The official Python client for Ollama."
|
||||
optional = false
|
||||
python-versions = ">=3.8,<4.0"
|
||||
files = [
|
||||
{file = "ollama-0.1.0-py3-none-any.whl", hash = "sha256:5c810773f13ee7c0078fe4ae3f5ba4939b27057ddd7e63d7d35fdd6cb8e6978f"},
|
||||
{file = "ollama-0.1.0.tar.gz", hash = "sha256:701d6824653f3827a1ee80bb605d0560e784cde55fc1dfd17ada438636ee4fed"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = ">=0.25.2,<0.26.0"
|
||||
|
||||
[[package]]
|
||||
name = "redis"
|
||||
version = "5.0.1"
|
||||
|
@ -356,6 +457,17 @@ async-timeout = {version = ">=4.0.2", markers = "python_full_version <= \"3.11.2
|
|||
hiredis = ["hiredis (>=1.0.0)"]
|
||||
ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "sniffio"
|
||||
version = "1.3.0"
|
||||
description = "Sniff out which async library your code is running under"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"},
|
||||
{file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yarl"
|
||||
version = "1.9.4"
|
||||
|
@ -462,4 +574,4 @@ multidict = ">=4.0"
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "4102f42c28a97587f87680d10290146c0a4a03f767feb1acfa9f8c0c8ca6962d"
|
||||
content-hash = "b2e530f606719d9f929f04b71536c228ba331f0e9b0976288ff0e4fcc55788f4"
|
||||
|
|
|
@ -9,6 +9,7 @@ readme = "README.md"
|
|||
python = "^3.11"
|
||||
discord-py = "^2.3.1"
|
||||
redis = "^5.0.1"
|
||||
ollama = "^0.1.0"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
|
Loading…
Reference in New Issue