pull models on start

multimodal
Michael Yang 2024-01-03 15:05:36 -07:00
parent eb955228e2
commit a6a2866c43
1 changed files with 11 additions and 2 deletions

View File

@ -53,6 +53,9 @@ class Discollama:
self.models = models
# registry setup hook
self.discord.setup_hook = self.setup_hook
# register event handlers
self.discord.event(self.on_ready)
self.discord.event(self.on_message)
@ -128,7 +131,7 @@ class Discollama:
await message.remove_reaction('🤔', self.discord.user)
async def generate(self, content, context, images=None):
model = self.models['images' if images else 'default']
model = self.models['images' if images else '']
sb = io.StringIO()
@ -164,6 +167,12 @@ class Discollama:
except Exception:
self.redis.close()
async def setup_hook(self):
for key, value in self.models.items():
logging.info('Downloading %s model %s...', key, value)
await self.ollama.pull(value)
logging.info('Downloading %s model %s... done', key, value)
def main():
parser = argparse.ArgumentParser()
@ -185,7 +194,7 @@ def main():
discord.Client(intents=discord.Intents.default()),
redis.Redis(host=args.redis_host, port=args.redis_port, db=0, decode_responses=True),
{
'default': args.ollama_model,
'': args.ollama_model,
'images': args.ollama_images_model,
},
).run(os.environ['DISCORD_TOKEN'])