From 8b3bb512ef69f4a1dc3117645ebb5a73c6d48050 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 13 Feb 2023 23:34:04 -0300
Subject: [PATCH] Minor bug fix (soft prompt was being loaded twice)

---
 server.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/server.py b/server.py
index 943fb5d2..c079a2aa 100644
--- a/server.py
+++ b/server.py
@@ -146,7 +146,7 @@ def load_model(model_name):
         if args.disk:
             params.append(f"offload_folder='{args.disk_cache_dir or 'cache'}'")
 
-        command = f"{command}(Path(f'models/{model_name}'), {','.join(set(params))})"
+        command = f"{command}(Path(f'models/{model_name}'), {', '.join(set(params))})"
         model = eval(command)
 
     # Loading the tokenizer
@@ -186,8 +186,6 @@ def upload_soft_prompt(file):
     with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
         f.write(file)
 
-    load_soft_prompt(name)
-
     return name
 
 def load_model_wrapper(selected_model):
@@ -343,7 +341,7 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
     if args.no_stream:
         t0 = time.time()
         with torch.no_grad():
-            output = eval(f"model.generate({','.join(generate_params)}){cuda}")[0]
+            output = eval(f"model.generate({', '.join(generate_params)}){cuda}")[0]
         if soft_prompt:
             output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
 
@@ -360,7 +358,7 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
         yield formatted_outputs(original_question, model_name)
         for i in tqdm(range(tokens//8+1)):
             with torch.no_grad():
-                output = eval(f"model.generate({','.join(generate_params)}){cuda}")[0]
+                output = eval(f"model.generate({', '.join(generate_params)}){cuda}")[0]
             if soft_prompt:
                 output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
 
@@ -476,7 +474,7 @@ def create_settings_menus():
             softprompts_menu = gr.Dropdown(choices=available_softprompts, value="None", label='Soft prompt')
             create_refresh_button(softprompts_menu, lambda : None, lambda : {"choices": get_available_softprompts()}, "refresh-button")
 
-        gr.Markdown('Upload a soft prompt:')
+        gr.Markdown('Upload a soft prompt (.zip format):')
         with gr.Row():
             upload_softprompt = gr.File(type='binary')
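
The removed load_soft_prompt(name) call is what caused the double load: upload_soft_prompt returns the prompt name to the soft-prompt dropdown, whose change event already triggers a load. Below is a minimal, self-contained sketch of that event wiring, assuming the same Gradio setup as server.py; the handler bodies and the 'example' prompt name are illustrative stand-ins, not the actual implementation.

import gradio as gr

def load_soft_prompt(name):
    # Illustrative stand-in for the real loader in server.py.
    print(f"Loading soft prompt: {name}")

def upload_soft_prompt(file):
    # As in the patched function: persist the upload and return the
    # prompt name (server.py reads it from the zip's meta.json).
    name = 'example'  # illustrative; not the real name-extraction logic
    return name

with gr.Blocks() as demo:
    softprompts_menu = gr.Dropdown(choices=['None', 'example'],
                                   value='None', label='Soft prompt')
    upload_softprompt = gr.File(type='binary')

    # The dropdown's change event is the single place the load happens.
    softprompts_menu.change(load_soft_prompt, softprompts_menu, None)

    # Uploading returns the name, which updates the dropdown and fires
    # its change event; calling load_soft_prompt inside
    # upload_soft_prompt as well is what loaded the prompt twice.
    upload_softprompt.upload(upload_soft_prompt, upload_softprompt, softprompts_menu)

demo.launch()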