mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-11 21:10:40 +01:00
Add softprompt support (for real this time)
Is this too much voodoo for our purposes?
This commit is contained in:
parent
aa1177ff15
commit
3277b751f5
@ -1,5 +1,5 @@
|
||||
params = {
|
||||
"soft prompt": " *I speak in an annoyingly cute way*",
|
||||
"bias string": " *I speak in an annoyingly cute way*",
|
||||
}
|
||||
|
||||
def input_modifier(string):
|
||||
@ -24,4 +24,4 @@ def bot_prefix_modifier(string):
|
||||
behavior.
|
||||
"""
|
||||
|
||||
return string + params["soft prompt"]
|
||||
return string + params["bias string"]
|
@ -1,6 +1,7 @@
|
||||
accelerate==0.15.0
|
||||
beautifulsoup4
|
||||
bitsandbytes==0.37.0
|
||||
gradio==3.15.0
|
||||
transformers==4.25.1
|
||||
numpy
|
||||
safetensors==0.2.8
|
||||
beautifulsoup4
|
||||
git+https://github.com/huggingface/transformers
|
||||
|
95
server.py
95
server.py
@ -10,10 +10,12 @@ import re
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import torch
|
||||
import transformers
|
||||
from PIL import Image
|
||||
@ -157,6 +159,37 @@ def load_model(model_name):
|
||||
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
||||
return model, tokenizer
|
||||
|
||||
def load_soft_prompt(name):
|
||||
global soft_prompt, soft_prompt_tensor
|
||||
|
||||
if name == 'None':
|
||||
soft_prompt = False
|
||||
soft_prompt_tensor = None
|
||||
else:
|
||||
with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
|
||||
zf.extract('tensor.npy')
|
||||
tensor = np.load('tensor.npy')
|
||||
tensor = torch.Tensor(tensor).to(device=model.device, dtype=model.dtype)
|
||||
tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
|
||||
|
||||
soft_prompt = True
|
||||
soft_prompt_tensor = tensor
|
||||
|
||||
return name
|
||||
|
||||
def upload_softprompt_event(file):
|
||||
with zipfile.ZipFile(io.BytesIO(file)) as zf:
|
||||
zf.extract('meta.json')
|
||||
j = json.loads(open('meta.json', 'r').read())
|
||||
name = j['name']
|
||||
|
||||
with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
|
||||
f.write(file)
|
||||
|
||||
load_soft_prompt(name)
|
||||
|
||||
return name
|
||||
|
||||
def load_model_wrapper(selected_model):
|
||||
global model_name, model, tokenizer
|
||||
|
||||
@ -244,7 +277,7 @@ def formatted_outputs(reply, model_name):
|
||||
return reply
|
||||
|
||||
def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
|
||||
global model_name, model, tokenizer
|
||||
global model_name, model, tokenizer, soft_prompt, soft_prompt_tensor
|
||||
|
||||
original_question = question
|
||||
if not (args.chat or args.cai_chat):
|
||||
@ -292,14 +325,29 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
|
||||
else:
|
||||
generate_params.append(f"max_new_tokens=8")
|
||||
|
||||
if soft_prompt:
|
||||
inputs_embeds = model.transformer.wte(input_ids)
|
||||
inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)
|
||||
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(model.device)
|
||||
filler_input_ids += model.config.bos_token_id # setting dummy input_ids to bos tokens
|
||||
generate_params.insert(0, "inputs_embeds=inputs_embeds")
|
||||
generate_params.insert(0, "filler_input_ids")
|
||||
else:
|
||||
filler_input_ids = None
|
||||
generate_params.insert(0, "input_ids")
|
||||
|
||||
# Generate the entire reply at once
|
||||
if args.no_stream:
|
||||
t0 = time.time()
|
||||
with torch.no_grad():
|
||||
output = eval(f"model.generate(input_ids, {','.join(generate_params)}){cuda}")
|
||||
reply = decode(output[0])
|
||||
output = eval(f"model.generate({','.join(generate_params)}){cuda}")
|
||||
if soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[0][filler_input_ids.shape[1]:]))
|
||||
else:
|
||||
output = output[0]
|
||||
reply = decode(output)
|
||||
t1 = time.time()
|
||||
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output[0])-len(input_ids[0])} tokens)")
|
||||
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
|
||||
if not (args.chat or args.cai_chat):
|
||||
reply = original_question + apply_extensions(reply[len(question):], "output")
|
||||
yield formatted_outputs(reply, model_name)
|
||||
@ -309,13 +357,26 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
|
||||
yield formatted_outputs(original_question, model_name)
|
||||
for i in tqdm(range(tokens//8+1)):
|
||||
with torch.no_grad():
|
||||
output = eval(f"model.generate(input_ids, {','.join(generate_params)}){cuda}")
|
||||
reply = decode(output[0])
|
||||
output = eval(f"model.generate({','.join(generate_params)}){cuda}")
|
||||
|
||||
if soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[0][filler_input_ids.shape[1]:]))
|
||||
else:
|
||||
output = output[0]
|
||||
|
||||
reply = decode(output)
|
||||
if not (args.chat or args.cai_chat):
|
||||
reply = original_question + apply_extensions(reply[len(question):], "output")
|
||||
yield formatted_outputs(reply, model_name)
|
||||
input_ids = output
|
||||
if output[0][-1] == n:
|
||||
|
||||
input_ids = torch.reshape(output, (1, output.shape[0]))
|
||||
if soft_prompt:
|
||||
inputs_embeds = model.transformer.wte(input_ids)
|
||||
inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)
|
||||
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(model.device)
|
||||
filler_input_ids += model.config.bos_token_id # setting dummy input_ids to bos tokens
|
||||
|
||||
if output[-1] == n:
|
||||
break
|
||||
|
||||
def apply_extensions(text, typ):
|
||||
@ -353,6 +414,9 @@ def get_available_characters():
|
||||
def get_available_extensions():
|
||||
return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
|
||||
|
||||
def get_available_softprompts():
|
||||
return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
|
||||
|
||||
def create_extensions_block():
|
||||
extensions_ui_elements = []
|
||||
default_values = []
|
||||
@ -410,8 +474,19 @@ def create_settings_menus():
|
||||
min_length = gr.Slider(0, 2000, step=1, value=generate_params["min_length"] if args.no_stream else 0, label="min_length", interactive=args.no_stream)
|
||||
early_stopping = gr.Checkbox(value=generate_params["early_stopping"], label="early_stopping")
|
||||
|
||||
with gr.Accordion("Soft prompt", open=False):
|
||||
with gr.Row():
|
||||
softprompts_menu = gr.Dropdown(choices=available_softprompts, value="None", label='Soft prompt')
|
||||
create_refresh_button(softprompts_menu, lambda : None, lambda : {"choices": get_available_softprompts()}, "refresh-button")
|
||||
|
||||
gr.Markdown('Upload a soft prompt:')
|
||||
with gr.Row():
|
||||
upload_softprompt = gr.File(type='binary')
|
||||
|
||||
model_menu.change(load_model_wrapper, [model_menu], [model_menu], show_progress=True)
|
||||
preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping])
|
||||
softprompts_menu.change(load_soft_prompt, [softprompts_menu], [softprompts_menu], show_progress=True)
|
||||
upload_softprompt.upload(upload_softprompt_event, [upload_softprompt], [softprompts_menu])
|
||||
return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping
|
||||
|
||||
# This gets the new line characters right.
|
||||
@ -718,6 +793,7 @@ available_models = get_available_models()
|
||||
available_presets = get_available_presets()
|
||||
available_characters = get_available_characters()
|
||||
available_extensions = get_available_extensions()
|
||||
available_softprompts = get_available_softprompts()
|
||||
extension_state = {}
|
||||
if args.extensions is not None:
|
||||
for i,ext in enumerate(args.extensions.split(',')):
|
||||
@ -746,7 +822,8 @@ else:
|
||||
print()
|
||||
model_name = available_models[i]
|
||||
model, tokenizer = load_model(model_name)
|
||||
loaded_preset = None
|
||||
loaded_preset = soft_prompt_tensor = None
|
||||
soft_prompt = False
|
||||
|
||||
# UI settings
|
||||
if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
|
||||
|
0
softprompts/place-your-softprompts-here.txt
Normal file
0
softprompts/place-your-softprompts-here.txt
Normal file
Loading…
x
Reference in New Issue
Block a user