pull models on start
parent
eb955228e2
commit
a6a2866c43
|
@ -53,6 +53,9 @@ class Discollama:
|
|||
|
||||
self.models = models
|
||||
|
||||
# registry setup hook
|
||||
self.discord.setup_hook = self.setup_hook
|
||||
|
||||
# register event handlers
|
||||
self.discord.event(self.on_ready)
|
||||
self.discord.event(self.on_message)
|
||||
|
@ -128,7 +131,7 @@ class Discollama:
|
|||
await message.remove_reaction('🤔', self.discord.user)
|
||||
|
||||
async def generate(self, content, context, images=None):
|
||||
model = self.models['images' if images else 'default']
|
||||
model = self.models['images' if images else '']
|
||||
|
||||
sb = io.StringIO()
|
||||
|
||||
|
@ -164,6 +167,12 @@ class Discollama:
|
|||
except Exception:
|
||||
self.redis.close()
|
||||
|
||||
async def setup_hook(self):
|
||||
for key, value in self.models.items():
|
||||
logging.info('Downloading %s model %s...', key, value)
|
||||
await self.ollama.pull(value)
|
||||
logging.info('Downloading %s model %s... done', key, value)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -185,7 +194,7 @@ def main():
|
|||
discord.Client(intents=discord.Intents.default()),
|
||||
redis.Redis(host=args.redis_host, port=args.redis_port, db=0, decode_responses=True),
|
||||
{
|
||||
'default': args.ollama_model,
|
||||
'': args.ollama_model,
|
||||
'images': args.ollama_images_model,
|
||||
},
|
||||
).run(os.environ['DISCORD_TOKEN'])
|
||||
|
|
Loading…
Reference in New Issue