remove session_id

pull/3/head
Michael Yang 2023-08-01 19:58:36 -07:00
parent df7cc830e3
commit 193c83fbd3
1 changed files with 5 additions and 13 deletions

View File

@ -21,14 +21,13 @@ async def on_ready():
logging.info('ready') logging.info('ready')
async def generate_response(prompt, context=[], session=None): async def generate_response(prompt, context=[]):
body = { body = {
key: value key: value
for key, value in { for key, value in {
'model': args.ollama_model, 'model': args.ollama_model,
'prompt': prompt, 'prompt': prompt,
'context': context, 'context': context,
'session_id': session,
}.items() if value }.items() if value
} }
@ -41,29 +40,22 @@ async def generate_response(prompt, context=[], session=None):
def save_session(response, chunk): def save_session(response, chunk):
session = msgpack.packb(chunk['session_id'])
redis.hset(f'ollama:{response.id}', 'session', session)
context = msgpack.packb(chunk['context']) context = msgpack.packb(chunk['context'])
redis.hset(f'ollama:{response.id}', 'context', context) redis.hset(f'ollama:{response.id}', 'context', context)
redis.expire(f'ollama:{response.id}', 60 * 60 * 24 * 7) redis.expire(f'ollama:{response.id}', 60 * 60 * 24 * 7)
logging.info('[%s] saving session %s: len(context)=%d', response.id, chunk['session_id'], len(chunk['context'])) logging.info('saving message=%s: len(context)=%d', response.id, len(chunk['context']))
def load_session(reference): def load_session(reference):
kwargs = {} kwargs = {}
if reference: if reference:
session = redis.hget(f'ollama:{reference.message_id}', 'session')
kwargs['session'] = msgpack.unpackb(session) if session else None
context = redis.hget(f'ollama:{reference.message_id}', 'context') context = redis.hget(f'ollama:{reference.message_id}', 'context')
kwargs['context'] = msgpack.unpackb(context) if context else [] kwargs['context'] = msgpack.unpackb(context) if context else []
if kwargs.get('session'): if kwargs.get('context'):
logging.info( logging.info(
'[%s] loading session %s: len(context)=%d', 'loading message=%s: len(context)=%d',
reference.message_id, reference.message_id,
kwargs['session'],
len(kwargs['context'])) len(kwargs['context']))
return kwargs return kwargs
@ -111,7 +103,7 @@ parser.add_argument('--ollama-model', default='llama2', type=str)
default_redis = Path.home() / '.cache' / 'discollama' / 'brain.db' default_redis = Path.home() / '.cache' / 'discollama' / 'brain.db'
parser.add_argument('--redis', default=default_redis, type=Path) parser.add_argument('--redis', default=default_redis, type=Path)
parser.add_argument('--buffer-size', default=64, type=int) parser.add_argument('--buffer-size', default=32, type=int)
args = parser.parse_args() args = parser.parse_args()