Add softprompt support (for real this time)

Is this too much voodoo for our purposes?
oobabooga 2023-02-13 15:25:16 -03:00
parent aa1177ff15
commit 3277b751f5
4 changed files with 91 additions and 13 deletions
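
The "voodoo" in question: a soft prompt is a trained embedding tensor that gets prepended to the prompt's token embeddings rather than to its token ids, so model.generate() has to be fed inputs_embeds instead of input_ids. Because generate() still expects token ids of a matching length, the diff passes it a dummy input_ids tensor filled with BOS tokens alongside inputs_embeds, then swaps that filler prefix back out for the real prompt ids before decoding. A minimal sketch of the same idea, assuming a GPT-style model that exposes its token embedding table as model.transformer.wte (as the diff does); the helper names here are illustrative and not part of the commit:

import torch

def build_soft_prompt_inputs(model, input_ids, soft_prompt_tensor):
    # soft_prompt_tensor: shape (1, n_soft_tokens, hidden_size), already on model.device / model.dtype.
    # Embed the real prompt tokens (wte = word token embedding table on GPT-style models).
    inputs_embeds = model.transformer.wte(input_ids)
    # Prepend the learned soft prompt along the sequence axis: (1, n_soft + n_prompt, hidden).
    inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)
    # Dummy token ids of matching length, all BOS, so generate()'s length bookkeeping still works.
    filler_input_ids = torch.full((1, inputs_embeds.shape[1]), model.config.bos_token_id,
                                  dtype=input_ids.dtype, device=model.device)
    return inputs_embeds, filler_input_ids

def stitch_real_prompt(output, input_ids, filler_input_ids):
    # The first filler_input_ids.shape[1] tokens of the output are the BOS fillers;
    # replace them with the real prompt ids so decoding returns prompt + continuation.
    return torch.cat((input_ids[0], output[0][filler_input_ids.shape[1]:]))

The call the diff assembles via eval() is then effectively model.generate(filler_input_ids, inputs_embeds=inputs_embeds, max_new_tokens=..., ...), and the stitched ids are what get decoded and streamed back.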


@@ -1,5 +1,5 @@
 params = {
-    "soft prompt": " *I speak in an annoyingly cute way*",
+    "bias string": " *I speak in an annoyingly cute way*",
 }
 
 def input_modifier(string):
@@ -24,4 +24,4 @@ def bot_prefix_modifier(string):
     behavior.
     """
 
-    return string + params["soft prompt"]
+    return string + params["bias string"]


@@ -1,6 +1,7 @@
 accelerate==0.15.0
+beautifulsoup4
 bitsandbytes==0.37.0
 gradio==3.15.0
-transformers==4.25.1
+numpy
 safetensors==0.2.8
-beautifulsoup4
+git+https://github.com/huggingface/transformers


@@ -10,10 +10,12 @@ import re
 import sys
 import time
 import warnings
+import zipfile
 from datetime import datetime
 from pathlib import Path
 
 import gradio as gr
+import numpy as np
 import torch
 import transformers
 from PIL import Image
@@ -157,6 +159,37 @@ def load_model(model_name):
     print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
     return model, tokenizer
 
+def load_soft_prompt(name):
+    global soft_prompt, soft_prompt_tensor
+
+    if name == 'None':
+        soft_prompt = False
+        soft_prompt_tensor = None
+    else:
+        with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
+            zf.extract('tensor.npy')
+            tensor = np.load('tensor.npy')
+            tensor = torch.Tensor(tensor).to(device=model.device, dtype=model.dtype)
+            tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
+
+        soft_prompt = True
+        soft_prompt_tensor = tensor
+
+    return name
+
+def upload_softprompt_event(file):
+    with zipfile.ZipFile(io.BytesIO(file)) as zf:
+        zf.extract('meta.json')
+        j = json.loads(open('meta.json', 'r').read())
+        name = j['name']
+
+    with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
+        f.write(file)
+
+    load_soft_prompt(name)
+
+    return name
+
 def load_model_wrapper(selected_model):
     global model_name, model, tokenizer
@@ -244,7 +277,7 @@ def formatted_outputs(reply, model_name):
         return reply
 
 def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
-    global model_name, model, tokenizer
+    global model_name, model, tokenizer, soft_prompt, soft_prompt_tensor
 
     original_question = question
     if not (args.chat or args.cai_chat):
@@ -292,14 +325,29 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
     else:
         generate_params.append(f"max_new_tokens=8")
 
+    if soft_prompt:
+        inputs_embeds = model.transformer.wte(input_ids)
+        inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)
+        filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(model.device)
+        filler_input_ids += model.config.bos_token_id # setting dummy input_ids to bos tokens
+        generate_params.insert(0, "inputs_embeds=inputs_embeds")
+        generate_params.insert(0, "filler_input_ids")
+    else:
+        filler_input_ids = None
+        generate_params.insert(0, "input_ids")
+
     # Generate the entire reply at once
     if args.no_stream:
         t0 = time.time()
         with torch.no_grad():
-            output = eval(f"model.generate(input_ids, {','.join(generate_params)}){cuda}")
-        reply = decode(output[0])
+            output = eval(f"model.generate({','.join(generate_params)}){cuda}")
+        if soft_prompt:
+            output = torch.cat((input_ids[0], output[0][filler_input_ids.shape[1]:]))
+        else:
+            output = output[0]
+        reply = decode(output)
         t1 = time.time()
-        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output[0])-len(input_ids[0])} tokens)")
+        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
         if not (args.chat or args.cai_chat):
             reply = original_question + apply_extensions(reply[len(question):], "output")
         yield formatted_outputs(reply, model_name)
@@ -309,13 +357,26 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
         yield formatted_outputs(original_question, model_name)
         for i in tqdm(range(tokens//8+1)):
             with torch.no_grad():
-                output = eval(f"model.generate(input_ids, {','.join(generate_params)}){cuda}")
-            reply = decode(output[0])
+                output = eval(f"model.generate({','.join(generate_params)}){cuda}")
+            if soft_prompt:
+                output = torch.cat((input_ids[0], output[0][filler_input_ids.shape[1]:]))
+            else:
+                output = output[0]
+            reply = decode(output)
+
             if not (args.chat or args.cai_chat):
                 reply = original_question + apply_extensions(reply[len(question):], "output")
             yield formatted_outputs(reply, model_name)
-            input_ids = output
-            if output[0][-1] == n:
+
+            input_ids = torch.reshape(output, (1, output.shape[0]))
+            if soft_prompt:
+                inputs_embeds = model.transformer.wte(input_ids)
+                inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)
+                filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(model.device)
+                filler_input_ids += model.config.bos_token_id # setting dummy input_ids to bos tokens
+
+            if output[-1] == n:
                 break
 
 def apply_extensions(text, typ):
@@ -353,6 +414,9 @@ def get_available_characters():
 def get_available_extensions():
     return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
 
+def get_available_softprompts():
+    return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
+
 def create_extensions_block():
     extensions_ui_elements = []
     default_values = []
@@ -410,8 +474,19 @@ def create_settings_menus():
                 min_length = gr.Slider(0, 2000, step=1, value=generate_params["min_length"] if args.no_stream else 0, label="min_length", interactive=args.no_stream)
                 early_stopping = gr.Checkbox(value=generate_params["early_stopping"], label="early_stopping")
 
+    with gr.Accordion("Soft prompt", open=False):
+        with gr.Row():
+            softprompts_menu = gr.Dropdown(choices=available_softprompts, value="None", label='Soft prompt')
+            create_refresh_button(softprompts_menu, lambda : None, lambda : {"choices": get_available_softprompts()}, "refresh-button")
+
+        gr.Markdown('Upload a soft prompt:')
+        with gr.Row():
+            upload_softprompt = gr.File(type='binary')
+
     model_menu.change(load_model_wrapper, [model_menu], [model_menu], show_progress=True)
     preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping])
+    softprompts_menu.change(load_soft_prompt, [softprompts_menu], [softprompts_menu], show_progress=True)
+    upload_softprompt.upload(upload_softprompt_event, [upload_softprompt], [softprompts_menu])
 
     return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping
 
 # This gets the new line characters right.
@@ -718,6 +793,7 @@ available_models = get_available_models()
 available_presets = get_available_presets()
 available_characters = get_available_characters()
 available_extensions = get_available_extensions()
+available_softprompts = get_available_softprompts()
 extension_state = {}
 if args.extensions is not None:
     for i,ext in enumerate(args.extensions.split(',')):
@@ -746,7 +822,8 @@ else:
     print()
     model_name = available_models[i]
 model, tokenizer = load_model(model_name)
-loaded_preset = None
+loaded_preset = soft_prompt_tensor = None
+soft_prompt = False
 
 # UI settings
 if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
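
For reference, load_soft_prompt() and upload_softprompt_event() above imply a simple on-disk format: a .zip in the softprompts/ folder containing tensor.npy (the soft prompt embedding matrix, one row per virtual token) and a meta.json with at least a "name" field. A hedged sketch of packaging such an archive; the file name, token count, and hidden size below are placeholders, and a real soft prompt tensor would come from tuning with a hidden size matching the loaded model:

import json
import zipfile
from pathlib import Path

import numpy as np

# Placeholder soft prompt: 10 virtual tokens, hidden size 4096 (must match the model).
tensor = np.random.randn(10, 4096).astype(np.float32)

Path('softprompts').mkdir(exist_ok=True)
with zipfile.ZipFile('softprompts/my-softprompt.zip', 'w') as zf:
    # tensor.npy is what load_soft_prompt() extracts, loads with np.load(), and reshapes to (1, n, hidden).
    with zf.open('tensor.npy', 'w') as f:
        np.save(f, tensor)
    # The "name" field is what upload_softprompt_event() uses to name the saved zip.
    zf.writestr('meta.json', json.dumps({'name': 'my-softprompt'}))

The dropdown in the settings menu lists archives by filename, so whatever the zip is called (minus .zip) is what get_available_softprompts() shows.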