add multimodal response
parent
5ba591fbfb
commit
eb955228e2
|
@ -4,6 +4,7 @@ import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import argparse
|
import argparse
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from base64 import b64encode, b64decode
|
||||||
|
|
||||||
import ollama
|
import ollama
|
||||||
import discord
|
import discord
|
||||||
|
@ -31,6 +32,9 @@ class Response:
|
||||||
|
|
||||||
self.sb.write(s)
|
self.sb.write(s)
|
||||||
|
|
||||||
|
if self.sb.seek(0, io.SEEK_END) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
if self.r:
|
if self.r:
|
||||||
await self.r.edit(content=self.sb.getvalue() + end)
|
await self.r.edit(content=self.sb.getvalue() + end)
|
||||||
return
|
return
|
||||||
|
@ -42,11 +46,13 @@ class Response:
|
||||||
|
|
||||||
|
|
||||||
class Discollama:
|
class Discollama:
|
||||||
def __init__(self, ollama, discord, redis):
|
def __init__(self, ollama, discord, redis, models):
|
||||||
self.ollama = ollama
|
self.ollama = ollama
|
||||||
self.discord = discord
|
self.discord = discord
|
||||||
self.redis = redis
|
self.redis = redis
|
||||||
|
|
||||||
|
self.models = models
|
||||||
|
|
||||||
# register event handlers
|
# register event handlers
|
||||||
self.discord.event(self.on_ready)
|
self.discord.event(self.on_ready)
|
||||||
self.discord.event(self.on_message)
|
self.discord.event(self.on_message)
|
||||||
|
@ -83,9 +89,9 @@ class Discollama:
|
||||||
|
|
||||||
channel = message.channel
|
channel = message.channel
|
||||||
|
|
||||||
context = []
|
context, images = [], []
|
||||||
if reference := message.reference:
|
if reference := message.reference:
|
||||||
context = await self.load(message_id=reference.message_id)
|
context, images = await self.load(message_id=reference.message_id)
|
||||||
if not context:
|
if not context:
|
||||||
reference_message = await message.channel.fetch_message(reference.message_id)
|
reference_message = await message.channel.fetch_message(reference.message_id)
|
||||||
content = '\n'.join(
|
content = '\n'.join(
|
||||||
|
@ -97,17 +103,19 @@ class Discollama:
|
||||||
)
|
)
|
||||||
|
|
||||||
if not context:
|
if not context:
|
||||||
context = await self.load(channel_id=channel.id)
|
context, images = await self.load(channel_id=channel.id)
|
||||||
|
|
||||||
|
images.extend([await attachment.read() for attachment in message.attachments if attachment.content_type.startswith('image/')])
|
||||||
|
|
||||||
r = Response(message)
|
r = Response(message)
|
||||||
task = asyncio.create_task(self.thinking(message))
|
task = asyncio.create_task(self.thinking(message))
|
||||||
async for part in self.generate(content, context):
|
async for part in self.generate(content, context, images=images):
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
await r.write(part['response'], end='...')
|
await r.write(part['response'], end='...')
|
||||||
|
|
||||||
await r.write('')
|
await r.write('')
|
||||||
await self.save(r.channel.id, message.id, part['context'])
|
await self.save(r.channel.id, message.id, part['context'], images)
|
||||||
|
|
||||||
async def thinking(self, message, timeout=999):
|
async def thinking(self, message, timeout=999):
|
||||||
try:
|
try:
|
||||||
|
@ -119,11 +127,13 @@ class Discollama:
|
||||||
finally:
|
finally:
|
||||||
await message.remove_reaction('🤔', self.discord.user)
|
await message.remove_reaction('🤔', self.discord.user)
|
||||||
|
|
||||||
async def generate(self, content, context):
|
async def generate(self, content, context, images=None):
|
||||||
|
model = self.models['images' if images else 'default']
|
||||||
|
|
||||||
sb = io.StringIO()
|
sb = io.StringIO()
|
||||||
|
|
||||||
t = datetime.now()
|
t = datetime.now()
|
||||||
async for part in await self.ollama.generate(model='llama2', prompt=content, context=context, stream=True):
|
async for part in await self.ollama.generate(model=model, prompt=content, context=context, images=images, stream=True):
|
||||||
sb.write(part['response'])
|
sb.write(part['response'])
|
||||||
|
|
||||||
if part['done'] or datetime.now() - t > timedelta(seconds=1):
|
if part['done'] or datetime.now() - t > timedelta(seconds=1):
|
||||||
|
@ -133,16 +143,20 @@ class Discollama:
|
||||||
sb.seek(0, io.SEEK_SET)
|
sb.seek(0, io.SEEK_SET)
|
||||||
sb.truncate()
|
sb.truncate()
|
||||||
|
|
||||||
async def save(self, channel_id, message_id, ctx: list[int]):
|
async def save(self, channel_id, message_id, context, images):
|
||||||
self.redis.set(f'discollama:channel:{channel_id}', message_id, ex=60 * 60 * 24 * 7)
|
self.redis.set(f'discollama:channel:{channel_id}', message_id, ex=60 * 60 * 24 * 7)
|
||||||
self.redis.set(f'discollama:message:{message_id}', json.dumps(ctx), ex=60 * 60 * 24 * 7)
|
self.redis.set(f'discollama:message:{message_id}', json.dumps(context), ex=60 * 60 * 24 * 7)
|
||||||
|
|
||||||
async def load(self, channel_id=None, message_id=None) -> list[int]:
|
images = [b64encode(image).decode('utf-8') for image in images]
|
||||||
|
self.redis.set(f'discollama:images:{message_id}', json.dumps(images), ex=60 * 60 * 24 * 7)
|
||||||
|
|
||||||
|
async def load(self, channel_id=None, message_id=None):
|
||||||
if channel_id:
|
if channel_id:
|
||||||
message_id = self.redis.get(f'discollama:channel:{channel_id}')
|
message_id = self.redis.get(f'discollama:channel:{channel_id}')
|
||||||
|
|
||||||
ctx = self.redis.get(f'discollama:message:{message_id}')
|
context = self.redis.get(f'discollama:message:{message_id}')
|
||||||
return json.loads(ctx) if ctx else []
|
images = self.redis.get(f'discollama:images:{message_id}')
|
||||||
|
return json.loads(context) if context else [], [b64decode(image) for image in json.loads(images)] if images else []
|
||||||
|
|
||||||
def run(self, token):
|
def run(self, token):
|
||||||
try:
|
try:
|
||||||
|
@ -157,22 +171,23 @@ def main():
|
||||||
parser.add_argument('--ollama-scheme', default=os.getenv('OLLAMA_SCHEME', 'http'), choices=['http', 'https'])
|
parser.add_argument('--ollama-scheme', default=os.getenv('OLLAMA_SCHEME', 'http'), choices=['http', 'https'])
|
||||||
parser.add_argument('--ollama-host', default=os.getenv('OLLAMA_HOST', '127.0.0.1'), type=str)
|
parser.add_argument('--ollama-host', default=os.getenv('OLLAMA_HOST', '127.0.0.1'), type=str)
|
||||||
parser.add_argument('--ollama-port', default=os.getenv('OLLAMA_PORT', 11434), type=int)
|
parser.add_argument('--ollama-port', default=os.getenv('OLLAMA_PORT', 11434), type=int)
|
||||||
|
|
||||||
parser.add_argument('--ollama-model', default=os.getenv('OLLAMA_MODEL', 'llama2'), type=str)
|
parser.add_argument('--ollama-model', default=os.getenv('OLLAMA_MODEL', 'llama2'), type=str)
|
||||||
|
parser.add_argument('--ollama-images-model', default=os.getenv('OLLAMA_IMAGES_MODEL', 'llava'), type=str)
|
||||||
|
|
||||||
parser.add_argument('--redis-host', default=os.getenv('REDIS_HOST', '127.0.0.1'), type=str)
|
parser.add_argument('--redis-host', default=os.getenv('REDIS_HOST', '127.0.0.1'), type=str)
|
||||||
parser.add_argument('--redis-port', default=os.getenv('REDIS_PORT', 6379), type=int)
|
parser.add_argument('--redis-port', default=os.getenv('REDIS_PORT', 6379), type=int)
|
||||||
|
|
||||||
parser.add_argument('--buffer-size', default=32, type=int)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
intents = discord.Intents.default()
|
|
||||||
intents.message_content = True
|
|
||||||
|
|
||||||
Discollama(
|
Discollama(
|
||||||
ollama.AsyncClient(base_url=f'{args.ollama_scheme}://{args.ollama_host}:{args.ollama_port}'),
|
ollama.AsyncClient(base_url=f'{args.ollama_scheme}://{args.ollama_host}:{args.ollama_port}'),
|
||||||
discord.Client(intents=intents),
|
discord.Client(intents=discord.Intents.default()),
|
||||||
redis.Redis(host=args.redis_host, port=args.redis_port, db=0, decode_responses=True),
|
redis.Redis(host=args.redis_host, port=args.redis_port, db=0, decode_responses=True),
|
||||||
|
{
|
||||||
|
'default': args.ollama_model,
|
||||||
|
'images': args.ollama_images_model,
|
||||||
|
},
|
||||||
).run(os.environ['DISCORD_TOKEN'])
|
).run(os.environ['DISCORD_TOKEN'])
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue