diff --git a/extensions/softprompt/script.py b/extensions/character_bias/script.py
similarity index 82%
rename from extensions/softprompt/script.py
rename to extensions/character_bias/script.py
index 1892cbf6..b7bdaa9d 100644
--- a/extensions/softprompt/script.py
+++ b/extensions/character_bias/script.py
@@ -1,5 +1,5 @@
 params = {
-    "soft prompt": " *I speak in an annoyingly cute way*",
+    "bias string": " *I speak in an annoyingly cute way*",
 }
 
 def input_modifier(string):
@@ -24,4 +24,4 @@ def bot_prefix_modifier(string):
     behavior.
     """
 
-    return string + params["soft prompt"]
+    return string + params["bias string"]
diff --git a/requirements.txt b/requirements.txt
index 90744ad1..7420966f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 accelerate==0.15.0
+beautifulsoup4
 bitsandbytes==0.37.0
 gradio==3.15.0
-transformers==4.25.1
+numpy
 safetensors==0.2.8
-beautifulsoup4
+git+https://github.com/huggingface/transformers
diff --git a/server.py b/server.py
index 49ca5ddf..07bcd220 100644
--- a/server.py
+++ b/server.py
@@ -10,10 +10,12 @@ import re
 import sys
 import time
 import warnings
+import zipfile
 from datetime import datetime
 from pathlib import Path
 
 import gradio as gr
+import numpy as np
 import torch
 import transformers
 from PIL import Image
@@ -157,6 +159,37 @@ def load_model(model_name):
     print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
     return model, tokenizer
 
+def load_soft_prompt(name):
+    global soft_prompt, soft_prompt_tensor
+
+    if name == 'None':
+        soft_prompt = False
+        soft_prompt_tensor = None
+    else:
+        with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
+            zf.extract('tensor.npy')
+            tensor = np.load('tensor.npy')
+            tensor = torch.Tensor(tensor).to(device=model.device, dtype=model.dtype)
+            tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
+
+        soft_prompt = True
+        soft_prompt_tensor = tensor
+
+    return name
+
+def upload_softprompt_event(file):
+    with zipfile.ZipFile(io.BytesIO(file)) as zf:
+        zf.extract('meta.json')
+        j = json.loads(open('meta.json', 'r').read())
+        name = j['name']
+
+    with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
+        f.write(file)
+
+    load_soft_prompt(name)
+
+    return name
+
 def load_model_wrapper(selected_model):
     global model_name, model, tokenizer
 
@@ -244,7 +277,7 @@ def formatted_outputs(reply, model_name):
         return reply
 
 def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
-    global model_name, model, tokenizer
+    global model_name, model, tokenizer, soft_prompt, soft_prompt_tensor
 
     original_question = question
     if not (args.chat or args.cai_chat):
@@ -292,14 +325,29 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
     else:
         generate_params.append(f"max_new_tokens=8")
 
+    if soft_prompt:
+        inputs_embeds = model.transformer.wte(input_ids)
+        inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)
+        filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(model.device)
+        filler_input_ids += model.config.bos_token_id # setting dummy input_ids to bos tokens
+        generate_params.insert(0, "inputs_embeds=inputs_embeds")
+        generate_params.insert(0, "filler_input_ids")
+    else:
+        filler_input_ids = None
+        generate_params.insert(0, "input_ids")
+
     # Generate the entire reply at once
     if args.no_stream:
         t0 = time.time()
         with torch.no_grad():
-            output = eval(f"model.generate(input_ids, {','.join(generate_params)}){cuda}")
-        reply = decode(output[0])
+            output = eval(f"model.generate({','.join(generate_params)}){cuda}")
+        if soft_prompt:
+            output = torch.cat((input_ids[0], output[0][filler_input_ids.shape[1]:]))
+        else:
+            output = output[0]
+        reply = decode(output)
         t1 = time.time()
-        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output[0])-len(input_ids[0])} tokens)")
+        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
         if not (args.chat or args.cai_chat):
             reply = original_question + apply_extensions(reply[len(question):], "output")
         yield formatted_outputs(reply, model_name)
@@ -309,13 +357,26 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
         yield formatted_outputs(original_question, model_name)
         for i in tqdm(range(tokens//8+1)):
             with torch.no_grad():
-                output = eval(f"model.generate(input_ids, {','.join(generate_params)}){cuda}")
-            reply = decode(output[0])
+                output = eval(f"model.generate({','.join(generate_params)}){cuda}")
+
+            if soft_prompt:
+                output = torch.cat((input_ids[0], output[0][filler_input_ids.shape[1]:]))
+            else:
+                output = output[0]
+
+            reply = decode(output)
             if not (args.chat or args.cai_chat):
                 reply = original_question + apply_extensions(reply[len(question):], "output")
             yield formatted_outputs(reply, model_name)
-            input_ids = output
-            if output[0][-1] == n:
+
+            input_ids = torch.reshape(output, (1, output.shape[0]))
+            if soft_prompt:
+                inputs_embeds = model.transformer.wte(input_ids)
+                inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)
+                filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(model.device)
+                filler_input_ids += model.config.bos_token_id # setting dummy input_ids to bos tokens
+
+            if output[-1] == n:
                 break
 
 def apply_extensions(text, typ):
@@ -353,6 +414,9 @@ def get_available_characters():
 def get_available_extensions():
     return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
 
+def get_available_softprompts():
+    return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
+
 def create_extensions_block():
     extensions_ui_elements = []
     default_values = []
@@ -410,8 +474,19 @@ def create_settings_menus():
                 min_length = gr.Slider(0, 2000, step=1, value=generate_params["min_length"] if args.no_stream else 0, label="min_length", interactive=args.no_stream)
                 early_stopping = gr.Checkbox(value=generate_params["early_stopping"], label="early_stopping")
 
+    with gr.Accordion("Soft prompt", open=False):
+        with gr.Row():
+            softprompts_menu = gr.Dropdown(choices=available_softprompts, value="None", label='Soft prompt')
+            create_refresh_button(softprompts_menu, lambda : None, lambda : {"choices": get_available_softprompts()}, "refresh-button")
+
+        gr.Markdown('Upload a soft prompt:')
+        with gr.Row():
+            upload_softprompt = gr.File(type='binary')
+
     model_menu.change(load_model_wrapper, [model_menu], [model_menu], show_progress=True)
     preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping])
+    softprompts_menu.change(load_soft_prompt, [softprompts_menu], [softprompts_menu], show_progress=True)
+    upload_softprompt.upload(upload_softprompt_event, [upload_softprompt], [softprompts_menu])
     return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping
 
 # This gets the new line characters right.
@@ -718,6 +793,7 @@ available_models = get_available_models()
 available_presets = get_available_presets()
 available_characters = get_available_characters()
 available_extensions = get_available_extensions()
+available_softprompts = get_available_softprompts()
 extension_state = {}
 if args.extensions is not None:
     for i,ext in enumerate(args.extensions.split(',')):
@@ -746,7 +822,8 @@ else:
     print()
     model_name = available_models[i]
 model, tokenizer = load_model(model_name)
-loaded_preset = None
+loaded_preset = soft_prompt_tensor = None
+soft_prompt = False
 
 # UI settings
 if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
diff --git a/softprompts/place-your-softprompts-here.txt b/softprompts/place-your-softprompts-here.txt
new file mode 100644
index 00000000..e69de29b
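
Note on the archive layout implied by this diff: load_soft_prompt() and upload_softprompt_event() read a .zip in softprompts/ containing tensor.npy (a 2-D array of learned prompt embeddings, reshaped at load time to (1, n_tokens, embedding_dim)) and meta.json with at least a "name" field. The sketch below shows one way such an archive could be packaged; the pack_soft_prompt helper and the example dimensions are illustrative assumptions, not part of this diff.

```python
# Hypothetical packaging helper (not part of this diff): writes the layout
# that load_soft_prompt()/upload_softprompt_event() read back.
import json
import os
import zipfile

import numpy as np

def pack_soft_prompt(tensor, name, out_dir="softprompts"):
    # tensor: 2-D float array of shape (n_tokens, embedding_dim)
    os.makedirs(out_dir, exist_ok=True)
    np.save("tensor.npy", tensor)
    with open("meta.json", "w") as f:
        json.dump({"name": name}, f)
    # Store both members at the archive root, matching zf.extract('tensor.npy')
    # and zf.extract('meta.json') in server.py.
    with zipfile.ZipFile(os.path.join(out_dir, f"{name}.zip"), "w") as zf:
        zf.write("tensor.npy")
        zf.write("meta.json")

# Example: 20 virtual tokens for a model with 4096-dim embeddings (illustrative values).
pack_soft_prompt(np.zeros((20, 4096), dtype=np.float32), "my-soft-prompt")
```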