commit
e730c6a648
|
@ -42,10 +42,11 @@ class Response:
|
||||||
|
|
||||||
|
|
||||||
class Discollama:
|
class Discollama:
|
||||||
def __init__(self, ollama, discord, redis):
|
def __init__(self, ollama, discord, redis, model):
|
||||||
self.ollama = ollama
|
self.ollama = ollama
|
||||||
self.discord = discord
|
self.discord = discord
|
||||||
self.redis = redis
|
self.redis = redis
|
||||||
|
self.model = model
|
||||||
|
|
||||||
# register event handlers
|
# register event handlers
|
||||||
self.discord.event(self.on_ready)
|
self.discord.event(self.on_ready)
|
||||||
|
@ -123,7 +124,7 @@ class Discollama:
|
||||||
sb = io.StringIO()
|
sb = io.StringIO()
|
||||||
|
|
||||||
t = datetime.now()
|
t = datetime.now()
|
||||||
async for part in await self.ollama.generate(model='llama2', prompt=content, context=context, stream=True):
|
async for part in await self.ollama.generate(model=self.model, prompt=content, context=context, stream=True):
|
||||||
sb.write(part['response'])
|
sb.write(part['response'])
|
||||||
|
|
||||||
if part['done'] or datetime.now() - t > timedelta(seconds=1):
|
if part['done'] or datetime.now() - t > timedelta(seconds=1):
|
||||||
|
@ -173,6 +174,7 @@ def main():
|
||||||
ollama.AsyncClient(base_url=f'{args.ollama_scheme}://{args.ollama_host}:{args.ollama_port}'),
|
ollama.AsyncClient(base_url=f'{args.ollama_scheme}://{args.ollama_host}:{args.ollama_port}'),
|
||||||
discord.Client(intents=intents),
|
discord.Client(intents=intents),
|
||||||
redis.Redis(host=args.redis_host, port=args.redis_port, db=0, decode_responses=True),
|
redis.Redis(host=args.redis_host, port=args.redis_port, db=0, decode_responses=True),
|
||||||
|
model=args.ollama_model,
|
||||||
).run(os.environ['DISCORD_TOKEN'])
|
).run(os.environ['DISCORD_TOKEN'])
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue