Add softprompt support (for real this time)

Is this too much voodoo for our purposes?

commit 3277b751f5
parent aa1177ff15
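For context: a soft prompt is not text but a small tensor of learned embedding vectors that is prepended to the token embeddings before generation. A minimal sketch of the core operation this commit adds, assuming a GPT-style model that exposes its embedding table as transformer.wte (the names follow the server.py diff below; prepend_soft_prompt itself is a hypothetical helper, not part of the commit):

import torch

def prepend_soft_prompt(model, input_ids, soft_prompt_tensor):
    # Look up the normal token embeddings; transformer.wte assumes a GPT-style model.
    inputs_embeds = model.transformer.wte(input_ids)
    # Prepend the learned soft-prompt vectors along the sequence dimension.
    inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)
    # generate() still expects token ids of matching length, so pad with BOS tokens.
    filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(model.device)
    filler_input_ids += model.config.bos_token_id
    return inputs_embeds, filler_input_ids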
@@ -1,5 +1,5 @@
 params = {
-    "soft prompt": " *I speak in an annoyingly cute way*",
+    "bias string": " *I speak in an annoyingly cute way*",
 }
 
 def input_modifier(string):
@@ -24,4 +24,4 @@ def bot_prefix_modifier(string):
     behavior.
     """
 
-    return string + params["soft prompt"]
+    return string + params["bias string"]
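This hunk renames the example extension's parameter, presumably so that "soft prompt" no longer collides with the new feature; the extension itself just appends a bias string to the bot's reply prefix. A minimal sketch of such an extension script, assuming the webui's convention of a module-level params dict plus modifier hooks (only the params entry and the bot_prefix_modifier return are taken from the diff; the rest is illustrative):

params = {
    "bias string": " *I speak in an annoyingly cute way*",
}

def input_modifier(string):
    # Applied to the user's text before it reaches the model; left unchanged here.
    return string

def bot_prefix_modifier(string):
    # Appended to the prefix of the bot's reply, biasing its behavior.
    return string + params["bias string"]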
@@ -1,6 +1,7 @@
 accelerate==0.15.0
+beautifulsoup4
 bitsandbytes==0.37.0
 gradio==3.15.0
-transformers==4.25.1
+numpy
 safetensors==0.2.8
-beautifulsoup4
+git+https://github.com/huggingface/transformers
server.py (95 changed lines)
@@ -10,10 +10,12 @@ import re
 import sys
 import time
 import warnings
+import zipfile
 from datetime import datetime
 from pathlib import Path
 
 import gradio as gr
+import numpy as np
 import torch
 import transformers
 from PIL import Image
@@ -157,6 +159,37 @@ def load_model(model_name):
     print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
     return model, tokenizer
 
+def load_soft_prompt(name):
+    global soft_prompt, soft_prompt_tensor
+
+    if name == 'None':
+        soft_prompt = False
+        soft_prompt_tensor = None
+    else:
+        with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
+            zf.extract('tensor.npy')
+            tensor = np.load('tensor.npy')
+            tensor = torch.Tensor(tensor).to(device=model.device, dtype=model.dtype)
+            tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
+
+        soft_prompt = True
+        soft_prompt_tensor = tensor
+
+    return name
+
+def upload_softprompt_event(file):
+    with zipfile.ZipFile(io.BytesIO(file)) as zf:
+        zf.extract('meta.json')
+        j = json.loads(open('meta.json', 'r').read())
+        name = j['name']
+
+    with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
+        f.write(file)
+
+    load_soft_prompt(name)
+
+    return name
+
 def load_model_wrapper(selected_model):
     global model_name, model, tokenizer
 
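From the two functions above, a soft prompt is shipped as a .zip containing at least tensor.npy (the embedding matrix, one row per virtual token) and meta.json with a "name" field. A sketch of how such an archive could be packed for testing, assuming those two entries are all the loader reads (the vectors here are random placeholders, not a trained prompt, and hidden_size must match the target model):

import io, json, zipfile
import numpy as np

def pack_soft_prompt(path, name, n_tokens=16, hidden_size=4096):
    # Placeholder vectors; a real soft prompt would contain trained embeddings.
    tensor = np.random.randn(n_tokens, hidden_size).astype(np.float32)
    with zipfile.ZipFile(path, 'w') as zf:
        buf = io.BytesIO()
        np.save(buf, tensor)
        zf.writestr('tensor.npy', buf.getvalue())
        zf.writestr('meta.json', json.dumps({'name': name}))

pack_soft_prompt('softprompts/test-prompt.zip', 'test-prompt')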
@@ -244,7 +277,7 @@ def formatted_outputs(reply, model_name):
     return reply
 
 def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
-    global model_name, model, tokenizer
+    global model_name, model, tokenizer, soft_prompt, soft_prompt_tensor
 
     original_question = question
     if not (args.chat or args.cai_chat):
@@ -292,14 +325,29 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
     else:
         generate_params.append(f"max_new_tokens=8")
 
+    if soft_prompt:
+        inputs_embeds = model.transformer.wte(input_ids)
+        inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)
+        filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(model.device)
+        filler_input_ids += model.config.bos_token_id # setting dummy input_ids to bos tokens
+        generate_params.insert(0, "inputs_embeds=inputs_embeds")
+        generate_params.insert(0, "filler_input_ids")
+    else:
+        filler_input_ids = None
+        generate_params.insert(0, "input_ids")
+
     # Generate the entire reply at once
     if args.no_stream:
         t0 = time.time()
         with torch.no_grad():
-            output = eval(f"model.generate(input_ids, {','.join(generate_params)}){cuda}")
-        reply = decode(output[0])
+            output = eval(f"model.generate({','.join(generate_params)}){cuda}")
+        if soft_prompt:
+            output = torch.cat((input_ids[0], output[0][filler_input_ids.shape[1]:]))
+        else:
+            output = output[0]
+        reply = decode(output)
         t1 = time.time()
-        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output[0])-len(input_ids[0])} tokens)")
+        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
         if not (args.chat or args.cai_chat):
             reply = original_question + apply_extensions(reply[len(question):], "output")
         yield formatted_outputs(reply, model_name)
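The generate call is assembled as a list of argument strings and eval'd, so the soft-prompt branch replaces the leading input_ids argument with filler_input_ids plus inputs_embeds=inputs_embeds. Based on the inserted lines above, the eval'd expression with a soft prompt active amounts to roughly the following (a sketch with the sampling parameters elided; decode is the webui's existing helper):

# Rough expansion of the eval'd call when soft_prompt is True (sketch, not the literal code):
output = model.generate(
    filler_input_ids,              # BOS filler ids matching the embedded sequence length
    inputs_embeds=inputs_embeds,   # soft-prompt vectors + real prompt embeddings
    max_new_tokens=8,              # ...other sampling parameters elided
)

# The returned sequence starts with the filler ids, so the commit splices the real
# prompt ids back in before decoding.
output = torch.cat((input_ids[0], output[0][filler_input_ids.shape[1]:]))
reply = decode(output)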
@@ -309,13 +357,26 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
         yield formatted_outputs(original_question, model_name)
         for i in tqdm(range(tokens//8+1)):
             with torch.no_grad():
-                output = eval(f"model.generate(input_ids, {','.join(generate_params)}){cuda}")
-            reply = decode(output[0])
+                output = eval(f"model.generate({','.join(generate_params)}){cuda}")
+
+            if soft_prompt:
+                output = torch.cat((input_ids[0], output[0][filler_input_ids.shape[1]:]))
+            else:
+                output = output[0]
+
+            reply = decode(output)
             if not (args.chat or args.cai_chat):
                 reply = original_question + apply_extensions(reply[len(question):], "output")
             yield formatted_outputs(reply, model_name)
-            input_ids = output
-            if output[0][-1] == n:
+
+            input_ids = torch.reshape(output, (1, output.shape[0]))
+            if soft_prompt:
+                inputs_embeds = model.transformer.wte(input_ids)
+                inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)
+                filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(model.device)
+                filler_input_ids += model.config.bos_token_id # setting dummy input_ids to bos tokens
+
+            if output[-1] == n:
                 break
 
 def apply_extensions(text, typ):
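In the streaming path each chunk of up to 8 tokens is generated separately, so after every chunk the commit rebuilds input_ids from the output and recomputes the soft-prompt embeddings; since generate_params already contains the string "inputs_embeds=inputs_embeds", the refreshed tensor is picked up by the next eval'd call. Using the hypothetical prepend_soft_prompt helper sketched near the top, the per-chunk update amounts to:

# Grow the token sequence with the chunk just generated, then refresh the
# embeddings (soft prompt re-prepended) for the next generate() call.
input_ids = torch.reshape(output, (1, output.shape[0]))
if soft_prompt:
    inputs_embeds, filler_input_ids = prepend_soft_prompt(model, input_ids, soft_prompt_tensor)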
@@ -353,6 +414,9 @@ def get_available_characters():
 def get_available_extensions():
     return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
 
+def get_available_softprompts():
+    return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
+
 def create_extensions_block():
     extensions_ui_elements = []
     default_values = []
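The dropdown entries are derived from the zip filenames by dropping only the final extension, so dotted names keep their inner dots. A quick worked example of the expression used above:

name = 'my.soft.prompt.zip'
print('.'.join(name.split('.')[:-1]))  # -> my.soft.prompt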
@@ -410,8 +474,19 @@ def create_settings_menus():
                 min_length = gr.Slider(0, 2000, step=1, value=generate_params["min_length"] if args.no_stream else 0, label="min_length", interactive=args.no_stream)
                 early_stopping = gr.Checkbox(value=generate_params["early_stopping"], label="early_stopping")
 
+    with gr.Accordion("Soft prompt", open=False):
+        with gr.Row():
+            softprompts_menu = gr.Dropdown(choices=available_softprompts, value="None", label='Soft prompt')
+            create_refresh_button(softprompts_menu, lambda : None, lambda : {"choices": get_available_softprompts()}, "refresh-button")
+
+        gr.Markdown('Upload a soft prompt:')
+        with gr.Row():
+            upload_softprompt = gr.File(type='binary')
+
     model_menu.change(load_model_wrapper, [model_menu], [model_menu], show_progress=True)
     preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping])
+    softprompts_menu.change(load_soft_prompt, [softprompts_menu], [softprompts_menu], show_progress=True)
+    upload_softprompt.upload(upload_softprompt_event, [upload_softprompt], [softprompts_menu])
     return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping
 
 # This gets the new line characters right.
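Judging from the call above, create_refresh_button wires a small refresh button next to the dropdown: it takes the component, a refresh callback (a no-op here), a lambda returning the updated constructor kwargs, and an element class. A rough stand-alone equivalent in plain gradio 3.x might look like this (an assumption about the helper's intent, not its actual implementation):

# Hypothetical replacement for create_refresh_button, assuming gradio 3.x's
# Component.update() API for pushing new dropdown choices.
def refresh_softprompts_list():
    return gr.Dropdown.update(choices=get_available_softprompts())

refresh_button = gr.Button('Refresh')
refresh_button.click(refresh_softprompts_list, None, softprompts_menu)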
@@ -718,6 +793,7 @@ available_models = get_available_models()
 available_presets = get_available_presets()
 available_characters = get_available_characters()
 available_extensions = get_available_extensions()
+available_softprompts = get_available_softprompts()
 extension_state = {}
 if args.extensions is not None:
     for i,ext in enumerate(args.extensions.split(',')):
@@ -746,7 +822,8 @@ else:
     print()
     model_name = available_models[i]
 model, tokenizer = load_model(model_name)
-loaded_preset = None
+loaded_preset = soft_prompt_tensor = None
+soft_prompt = False
 
 # UI settings
 if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
softprompts/place-your-softprompts-here.txt (new file, 0 lines)