Add softprompt support (for real this time)

Is this too much voodoo for our purposes?
This commit is contained in:
oobabooga 2023-02-13 15:25:16 -03:00
parent aa1177ff15
commit 3277b751f5
4 changed files with 91 additions and 13 deletions

View File

@ -1,5 +1,5 @@
params = {
"soft prompt": " *I speak in an annoyingly cute way*",
"bias string": " *I speak in an annoyingly cute way*",
}
def input_modifier(string):
@ -24,4 +24,4 @@ def bot_prefix_modifier(string):
behavior.
"""
return string + params["soft prompt"]
return string + params["bias string"]

View File

@ -1,6 +1,7 @@
accelerate==0.15.0
beautifulsoup4
bitsandbytes==0.37.0
gradio==3.15.0
transformers==4.25.1
numpy
safetensors==0.2.8
beautifulsoup4
git+https://github.com/huggingface/transformers

View File

@ -10,10 +10,12 @@ import re
import sys
import time
import warnings
import zipfile
from datetime import datetime
from pathlib import Path
import gradio as gr
import numpy as np
import torch
import transformers
from PIL import Image
@ -157,6 +159,37 @@ def load_model(model_name):
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
return model, tokenizer
def load_soft_prompt(name):
global soft_prompt, soft_prompt_tensor
if name == 'None':
soft_prompt = False
soft_prompt_tensor = None
else:
with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
zf.extract('tensor.npy')
tensor = np.load('tensor.npy')
tensor = torch.Tensor(tensor).to(device=model.device, dtype=model.dtype)
tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
soft_prompt = True
soft_prompt_tensor = tensor
return name
def upload_softprompt_event(file):
with zipfile.ZipFile(io.BytesIO(file)) as zf:
zf.extract('meta.json')
j = json.loads(open('meta.json', 'r').read())
name = j['name']
with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
f.write(file)
load_soft_prompt(name)
return name
def load_model_wrapper(selected_model):
global model_name, model, tokenizer
@ -244,7 +277,7 @@ def formatted_outputs(reply, model_name):
return reply
def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
global model_name, model, tokenizer
global model_name, model, tokenizer, soft_prompt, soft_prompt_tensor
original_question = question
if not (args.chat or args.cai_chat):
@ -292,14 +325,29 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
else:
generate_params.append(f"max_new_tokens=8")
if soft_prompt:
inputs_embeds = model.transformer.wte(input_ids)
inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(model.device)
filler_input_ids += model.config.bos_token_id # setting dummy input_ids to bos tokens
generate_params.insert(0, "inputs_embeds=inputs_embeds")
generate_params.insert(0, "filler_input_ids")
else:
filler_input_ids = None
generate_params.insert(0, "input_ids")
# Generate the entire reply at once
if args.no_stream:
t0 = time.time()
with torch.no_grad():
output = eval(f"model.generate(input_ids, {','.join(generate_params)}){cuda}")
reply = decode(output[0])
output = eval(f"model.generate({','.join(generate_params)}){cuda}")
if soft_prompt:
output = torch.cat((input_ids[0], output[0][filler_input_ids.shape[1]:]))
else:
output = output[0]
reply = decode(output)
t1 = time.time()
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output[0])-len(input_ids[0])} tokens)")
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
if not (args.chat or args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output")
yield formatted_outputs(reply, model_name)
@ -309,13 +357,26 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
yield formatted_outputs(original_question, model_name)
for i in tqdm(range(tokens//8+1)):
with torch.no_grad():
output = eval(f"model.generate(input_ids, {','.join(generate_params)}){cuda}")
reply = decode(output[0])
output = eval(f"model.generate({','.join(generate_params)}){cuda}")
if soft_prompt:
output = torch.cat((input_ids[0], output[0][filler_input_ids.shape[1]:]))
else:
output = output[0]
reply = decode(output)
if not (args.chat or args.cai_chat):
reply = original_question + apply_extensions(reply[len(question):], "output")
yield formatted_outputs(reply, model_name)
input_ids = output
if output[0][-1] == n:
input_ids = torch.reshape(output, (1, output.shape[0]))
if soft_prompt:
inputs_embeds = model.transformer.wte(input_ids)
inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(model.device)
filler_input_ids += model.config.bos_token_id # setting dummy input_ids to bos tokens
if output[-1] == n:
break
def apply_extensions(text, typ):
@ -353,6 +414,9 @@ def get_available_characters():
def get_available_extensions():
return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
def get_available_softprompts():
return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
def create_extensions_block():
extensions_ui_elements = []
default_values = []
@ -410,8 +474,19 @@ def create_settings_menus():
min_length = gr.Slider(0, 2000, step=1, value=generate_params["min_length"] if args.no_stream else 0, label="min_length", interactive=args.no_stream)
early_stopping = gr.Checkbox(value=generate_params["early_stopping"], label="early_stopping")
with gr.Accordion("Soft prompt", open=False):
with gr.Row():
softprompts_menu = gr.Dropdown(choices=available_softprompts, value="None", label='Soft prompt')
create_refresh_button(softprompts_menu, lambda : None, lambda : {"choices": get_available_softprompts()}, "refresh-button")
gr.Markdown('Upload a soft prompt:')
with gr.Row():
upload_softprompt = gr.File(type='binary')
model_menu.change(load_model_wrapper, [model_menu], [model_menu], show_progress=True)
preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping])
softprompts_menu.change(load_soft_prompt, [softprompts_menu], [softprompts_menu], show_progress=True)
upload_softprompt.upload(upload_softprompt_event, [upload_softprompt], [softprompts_menu])
return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping
# This gets the new line characters right.
@ -718,6 +793,7 @@ available_models = get_available_models()
available_presets = get_available_presets()
available_characters = get_available_characters()
available_extensions = get_available_extensions()
available_softprompts = get_available_softprompts()
extension_state = {}
if args.extensions is not None:
for i,ext in enumerate(args.extensions.split(',')):
@ -746,7 +822,8 @@ else:
print()
model_name = available_models[i]
model, tokenizer = load_model(model_name)
loaded_preset = None
loaded_preset = soft_prompt_tensor = None
soft_prompt = False
# UI settings
if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):