diff --git a/server.py b/server.py
index 4991ba50..40e2239e 100644
--- a/server.py
+++ b/server.py
@@ -142,11 +142,12 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
         input_ids = encode(question, 1)
         preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
         cuda = "" if args.cpu else ".cuda()"
+        if eos_token is not None:
+            n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
         for i in range(tokens):
             if eos_token is None:
                 output = eval(f"model.generate(input_ids, {preset}){cuda}")
             else:
-                n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
                 output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")

             reply = tokenizer.decode(output[0], skip_special_tokens=True)
@@ -240,8 +241,8 @@ elif args.chat or args.cai_chat:
         return question

     def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
-        history.append(['', ''])
         question = generate_chat_prompt(text, tokens, name1, name2, context)
+        history.append(['', ''])
         eos_token = '\n' if check else None
         for i in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
             reply = i[0]