From 022960a0874b43d5f5f8181778fb9fad2cd07185 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 18 Jan 2023 21:37:21 -0300 Subject: [PATCH] This is the correct way of sampling 1 token at a time --- server.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/server.py b/server.py index c838cc0a..79fe6628 100644 --- a/server.py +++ b/server.py @@ -139,11 +139,11 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok preset = infile.read() loaded_preset = inference_settings + input_ids = encode(question, 1) + preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1') + cuda = ".cuda()" if args.cpu else "" for i in range(tokens): - input_ids = encode(question, 1) - preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1') - cuda = ".cuda()" if args.cpu else "" if eos_token is None: output = eval(f"model.generate(input_ids, {preset}){cuda}") else: @@ -152,7 +152,6 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok reply = tokenizer.decode(output[0], skip_special_tokens=True) reply = reply.replace(r'<|endoftext|>', '') - question = reply if model_name.lower().startswith('galactica'): reply = fix_galactica(reply) yield reply, reply, generate_basic_html(reply) @@ -162,6 +161,8 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok else: yield reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply) + input_ids = output + # Choosing the default model if args.model is not None: model_name = args.model