Remove softprompt support
commit 00b94847da
parent ccb4c9f178
@@ -160,11 +160,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
     min_rows = 3
 
     # Finding the maximum prompt size
-    chat_prompt_size = state['chat_prompt_size']
-    if shared.soft_prompt:
-        chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
-
-    max_length = min(get_max_prompt_length(state), chat_prompt_size)
+    max_length = min(get_max_prompt_length(state), state['chat_prompt_size'])
 
     # Building the turn templates
     if 'turn_template' not in state or state['turn_template'] == '':
@@ -55,11 +55,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
     is_instruct = state['mode'] == 'instruct'
 
     # Find the maximum prompt size
-    chat_prompt_size = state['chat_prompt_size']
-    if shared.soft_prompt:
-        chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
-
-    max_length = min(get_max_prompt_length(state), chat_prompt_size)
+    max_length = min(get_max_prompt_length(state), state['chat_prompt_size'])
     all_substrings = {
         'chat': get_turn_substrings(state, instruct=False),
         'instruct': get_turn_substrings(state, instruct=True)
@@ -1,12 +1,9 @@
 import gc
-import json
 import os
 import re
 import time
-import zipfile
 from pathlib import Path
 
-import numpy as np
 import torch
 import transformers
 from accelerate import infer_auto_device_map, init_empty_weights
@@ -338,32 +335,3 @@ def unload_model():
 def reload_model():
     unload_model()
     shared.model, shared.tokenizer = load_model(shared.model_name)
-
-
-def load_soft_prompt(name):
-    if name == 'None':
-        shared.soft_prompt = False
-        shared.soft_prompt_tensor = None
-    else:
-        with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
-            zf.extract('tensor.npy')
-            zf.extract('meta.json')
-            j = json.loads(open('meta.json', 'r').read())
-            logger.info(f"\nLoading the softprompt \"{name}\".")
-            for field in j:
-                if field != 'name':
-                    if type(j[field]) is list:
-                        logger.info(f"{field}: {', '.join(j[field])}")
-                    else:
-                        logger.info(f"{field}: {j[field]}")
-
-            tensor = np.load('tensor.npy')
-            Path('tensor.npy').unlink()
-            Path('meta.json').unlink()
-
-        tensor = torch.Tensor(tensor).to(device=shared.model.device, dtype=shared.model.dtype)
-        tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
-        shared.soft_prompt = True
-        shared.soft_prompt_tensor = tensor
-
-    return name
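For reference, the removed `load_soft_prompt` expected each soft prompt to ship as a `softprompts/<name>.zip` archive containing a `tensor.npy` embedding matrix and a `meta.json` description, and cached the result in `shared.soft_prompt_tensor` with the model's device and dtype. Below is a minimal standalone sketch of that unpacking step; the helper name `read_soft_prompt` is illustrative (not from the codebase), the archive members are read in memory rather than extracted to disk, and device/dtype are passed in by the caller instead of taken from `shared.model`.

import io
import json
import zipfile
from pathlib import Path

import numpy as np
import torch


def read_soft_prompt(path, device='cpu', dtype=torch.float16):
    """Return (tensor, metadata) from a soft prompt .zip archive."""
    with zipfile.ZipFile(Path(path)) as zf:
        meta = json.loads(zf.read('meta.json'))
        tensor = np.load(io.BytesIO(zf.read('tensor.npy')))

    # Shape the prompt as a batch of one: (1, n_virtual_tokens, hidden_size).
    tensor = torch.tensor(tensor, device=device, dtype=dtype).unsqueeze(0)
    return tensor, meta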
@@ -12,8 +12,6 @@ tokenizer = None
 model_name = "None"
 model_type = None
 lora_names = []
-soft_prompt_tensor = None
-soft_prompt = False
 
 # Chat variables
 history = {'internal': [], 'visible': []}
@@ -27,11 +27,7 @@ def generate_reply(*args, **kwargs):
 
 
 def get_max_prompt_length(state):
-    max_length = state['truncation_length'] - state['max_new_tokens']
-    if shared.soft_prompt:
-        max_length -= shared.soft_prompt_tensor.shape[1]
-
-    return max_length
+    return state['truncation_length'] - state['max_new_tokens']
 
 
 def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
@@ -80,14 +76,6 @@ def decode(output_ids, skip_special_tokens=True):
     return shared.tokenizer.decode(output_ids, skip_special_tokens)
 
 
-def generate_softprompt_input_tensors(input_ids):
-    inputs_embeds = shared.model.transformer.wte(input_ids)
-    inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
-    filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
-    # filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
-    return inputs_embeds, filler_input_ids
-
-
 # Removes empty replies from gpt4chan outputs
 def fix_gpt4chan(s):
     for i in range(10):
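The helper removed here implemented the usual soft prompt trick: embed the real prompt tokens, prepend the learned soft prompt tensor along the sequence dimension, and pass the result to `generate()` through `inputs_embeds`, together with dummy `input_ids` of matching length. A minimal sketch of that pattern, assuming (as the removed code did) a GPT-style Hugging Face model that exposes its token embedding matrix as `model.transformer.wte`; the function name `build_softprompt_inputs` is illustrative.

import torch


def build_softprompt_inputs(model, soft_prompt_tensor, input_ids):
    """Prepend a soft prompt to the embeddings of `input_ids`.

    Returns (inputs_embeds, filler_input_ids), intended for
    model.generate(inputs=filler_input_ids, inputs_embeds=inputs_embeds).
    """
    # Embed the real prompt tokens (GPT-style models keep the embedding
    # matrix at model.transformer.wte).
    token_embeds = model.transformer.wte(input_ids)

    # Prepend the learned prompt along the sequence dimension.
    inputs_embeds = torch.cat((soft_prompt_tensor, token_embeds), dim=1)

    # Dummy token ids whose only job is to match the embedded length.
    filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]),
                                   dtype=input_ids.dtype).to(model.device)
    return inputs_embeds, filler_input_ids

As the hunks below show, the generation code then swapped the first `filler_input_ids.shape[1]` positions of the output back to the original prompt ids before decoding.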
@@ -232,13 +220,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
         eos_token_ids.append(int(encode(eos_token)[0][-1]))
 
     # Add the encoded tokens to generate_params
-    if shared.soft_prompt:
-        inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
-        question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds)
-        original_input_ids = input_ids
-        generate_params.update({'inputs_embeds': inputs_embeds})
-        generate_params.update({'inputs': filler_input_ids})
-    else:
-        question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
-        original_input_ids = input_ids
-        generate_params.update({'inputs': input_ids})
+    question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
+    original_input_ids = input_ids
+    generate_params.update({'inputs': input_ids})
@@ -269,9 +250,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
                 if cuda:
                     output = output.cuda()
 
-                if shared.soft_prompt:
-                    output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
-
             yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
 
         # Stream the reply 1 token at a time.
@@ -289,9 +267,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
 
             with generate_with_streaming(**generate_params) as generator:
                 for output in generator:
-                    if shared.soft_prompt:
-                        output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
-
                     yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
                     if output[-1] in eos_token_ids:
                         break
@@ -60,10 +60,6 @@ def get_available_extensions():
     return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=natural_keys)
 
 
-def get_available_softprompts():
-    return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=natural_keys)
-
-
 def get_available_loras():
     return sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=natural_keys)
 
server.py (29 changed lines)
@@ -26,7 +26,6 @@ import matplotlib
 matplotlib.use('Agg')  # This fixes LaTeX rendering on some systems
 
 import importlib
-import io
 import json
 import math
 import os
@@ -34,7 +33,6 @@ import re
 import sys
 import time
 import traceback
-import zipfile
 from datetime import datetime
 from functools import partial
 from pathlib import Path
@@ -50,7 +48,7 @@ from modules import chat, shared, training, ui, utils
 from modules.extensions import apply_extensions
 from modules.html_generator import chat_html_wrapper
 from modules.LoRA import add_lora_to_model
-from modules.models import load_model, load_soft_prompt, unload_model
+from modules.models import load_model, unload_model
 from modules.text_generation import (generate_reply_wrapper,
                                      get_encoded_length, stop_everything_event)
 
@@ -119,19 +117,6 @@ def load_preset_values(preset_menu, state, return_dict=False):
         return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
 
 
-def upload_soft_prompt(file):
-    with zipfile.ZipFile(io.BytesIO(file)) as zf:
-        zf.extract('meta.json')
-        j = json.loads(open('meta.json', 'r').read())
-        name = j['name']
-        Path('meta.json').unlink()
-
-    with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
-        f.write(file)
-
-    return name
-
-
 def open_save_prompt():
     fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
     return gr.update(value=fname, visible=True), gr.update(visible=False), gr.update(visible=True)
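The removed upload handler only needed the `name` field from `meta.json` to decide where to store the uploaded archive. A small sketch of the same flow that reads the metadata in memory instead of extracting a temporary file; the helper name `save_uploaded_soft_prompt` and the `target_dir` parameter are illustrative, and the sketch assumes the target directory already exists (as `softprompts/` did in the repository).

import io
import json
import zipfile
from pathlib import Path


def save_uploaded_soft_prompt(file_bytes, target_dir='softprompts'):
    """Store an uploaded soft prompt .zip under its declared name."""
    with zipfile.ZipFile(io.BytesIO(file_bytes)) as zf:
        # The archive declares its own name in meta.json.
        name = json.loads(zf.read('meta.json'))['name']

    target = Path(target_dir) / f'{name}.zip'
    target.write_bytes(file_bytes)
    return name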
@@ -510,16 +495,6 @@ def create_settings_menus(default_preset):
                         shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
                         shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
 
-    gr.Markdown('Other')
-    with gr.Accordion('Soft prompt', open=False):
-        with gr.Row():
-            shared.gradio['softprompts_menu'] = gr.Dropdown(choices=utils.get_available_softprompts(), value='None', label='Soft prompt')
-            ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda: None, lambda: {'choices': utils.get_available_softprompts()}, 'refresh-button')
-
-        gr.Markdown('Upload a soft prompt (.zip format):')
-        with gr.Row():
-            shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
-
     with gr.Box():
         with gr.Row():
             with gr.Column():
@@ -535,8 +510,6 @@ def create_settings_menus(default_preset):
     gr.Markdown('[Click here for more information.](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Generation-parameters.md)')
 
     shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']])
-    shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
-    shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
 
 
 def set_interface_arguments(interface_mode, extensions, bool_active):