Remove softprompt support

This commit is contained in:
oobabooga 2023-06-06 07:42:23 -03:00
parent ccb4c9f178
commit 00b94847da
7 changed files with 8 additions and 106 deletions

View File

@ -160,11 +160,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
min_rows = 3 min_rows = 3
# Finding the maximum prompt size # Finding the maximum prompt size
chat_prompt_size = state['chat_prompt_size'] max_length = min(get_max_prompt_length(state), state['chat_prompt_size'])
if shared.soft_prompt:
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
max_length = min(get_max_prompt_length(state), chat_prompt_size)
# Building the turn templates # Building the turn templates
if 'turn_template' not in state or state['turn_template'] == '': if 'turn_template' not in state or state['turn_template'] == '':

View File

@ -55,11 +55,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
is_instruct = state['mode'] == 'instruct' is_instruct = state['mode'] == 'instruct'
# Find the maximum prompt size # Find the maximum prompt size
chat_prompt_size = state['chat_prompt_size'] max_length = min(get_max_prompt_length(state), state['chat_prompt_size'])
if shared.soft_prompt:
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
max_length = min(get_max_prompt_length(state), chat_prompt_size)
all_substrings = { all_substrings = {
'chat': get_turn_substrings(state, instruct=False), 'chat': get_turn_substrings(state, instruct=False),
'instruct': get_turn_substrings(state, instruct=True) 'instruct': get_turn_substrings(state, instruct=True)

View File

@ -1,12 +1,9 @@
import gc import gc
import json
import os import os
import re import re
import time import time
import zipfile
from pathlib import Path from pathlib import Path
import numpy as np
import torch import torch
import transformers import transformers
from accelerate import infer_auto_device_map, init_empty_weights from accelerate import infer_auto_device_map, init_empty_weights
@ -338,32 +335,3 @@ def unload_model():
def reload_model(): def reload_model():
unload_model() unload_model()
shared.model, shared.tokenizer = load_model(shared.model_name) shared.model, shared.tokenizer = load_model(shared.model_name)
def load_soft_prompt(name):
if name == 'None':
shared.soft_prompt = False
shared.soft_prompt_tensor = None
else:
with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
zf.extract('tensor.npy')
zf.extract('meta.json')
j = json.loads(open('meta.json', 'r').read())
logger.info(f"\nLoading the softprompt \"{name}\".")
for field in j:
if field != 'name':
if type(j[field]) is list:
logger.info(f"{field}: {', '.join(j[field])}")
else:
logger.info(f"{field}: {j[field]}")
tensor = np.load('tensor.npy')
Path('tensor.npy').unlink()
Path('meta.json').unlink()
tensor = torch.Tensor(tensor).to(device=shared.model.device, dtype=shared.model.dtype)
tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
shared.soft_prompt = True
shared.soft_prompt_tensor = tensor
return name

View File

@ -12,8 +12,6 @@ tokenizer = None
model_name = "None" model_name = "None"
model_type = None model_type = None
lora_names = [] lora_names = []
soft_prompt_tensor = None
soft_prompt = False
# Chat variables # Chat variables
history = {'internal': [], 'visible': []} history = {'internal': [], 'visible': []}

View File

@ -27,11 +27,7 @@ def generate_reply(*args, **kwargs):
def get_max_prompt_length(state): def get_max_prompt_length(state):
max_length = state['truncation_length'] - state['max_new_tokens'] return state['truncation_length'] - state['max_new_tokens']
if shared.soft_prompt:
max_length -= shared.soft_prompt_tensor.shape[1]
return max_length
def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None): def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
@ -80,14 +76,6 @@ def decode(output_ids, skip_special_tokens=True):
return shared.tokenizer.decode(output_ids, skip_special_tokens) return shared.tokenizer.decode(output_ids, skip_special_tokens)
def generate_softprompt_input_tensors(input_ids):
inputs_embeds = shared.model.transformer.wte(input_ids)
inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
return inputs_embeds, filler_input_ids
# Removes empty replies from gpt4chan outputs # Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s): def fix_gpt4chan(s):
for i in range(10): for i in range(10):
@ -232,18 +220,11 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
eos_token_ids.append(int(encode(eos_token)[0][-1])) eos_token_ids.append(int(encode(eos_token)[0][-1]))
# Add the encoded tokens to generate_params # Add the encoded tokens to generate_params
if shared.soft_prompt: question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) original_input_ids = input_ids
question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds) generate_params.update({'inputs': input_ids})
original_input_ids = input_ids if inputs_embeds is not None:
generate_params.update({'inputs_embeds': inputs_embeds}) generate_params.update({'inputs_embeds': inputs_embeds})
generate_params.update({'inputs': filler_input_ids})
else:
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
original_input_ids = input_ids
generate_params.update({'inputs': input_ids})
if inputs_embeds is not None:
generate_params.update({'inputs_embeds': inputs_embeds})
# Create the StoppingCriteriaList with the stopping strings (needs to be done after tokenizer extensions) # Create the StoppingCriteriaList with the stopping strings (needs to be done after tokenizer extensions)
stopping_criteria_list = transformers.StoppingCriteriaList() stopping_criteria_list = transformers.StoppingCriteriaList()
@ -269,9 +250,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
if cuda: if cuda:
output = output.cuda() output = output.cuda()
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat) yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
# Stream the reply 1 token at a time. # Stream the reply 1 token at a time.
@ -289,9 +267,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
with generate_with_streaming(**generate_params) as generator: with generate_with_streaming(**generate_params) as generator:
for output in generator: for output in generator:
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat) yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
if output[-1] in eos_token_ids: if output[-1] in eos_token_ids:
break break

View File

@ -60,10 +60,6 @@ def get_available_extensions():
return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=natural_keys) return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=natural_keys)
def get_available_softprompts():
return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=natural_keys)
def get_available_loras(): def get_available_loras():
return sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=natural_keys) return sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=natural_keys)

View File

@ -26,7 +26,6 @@ import matplotlib
matplotlib.use('Agg') # This fixes LaTeX rendering on some systems matplotlib.use('Agg') # This fixes LaTeX rendering on some systems
import importlib import importlib
import io
import json import json
import math import math
import os import os
@ -34,7 +33,6 @@ import re
import sys import sys
import time import time
import traceback import traceback
import zipfile
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
@ -50,7 +48,7 @@ from modules import chat, shared, training, ui, utils
from modules.extensions import apply_extensions from modules.extensions import apply_extensions
from modules.html_generator import chat_html_wrapper from modules.html_generator import chat_html_wrapper
from modules.LoRA import add_lora_to_model from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt, unload_model from modules.models import load_model, unload_model
from modules.text_generation import (generate_reply_wrapper, from modules.text_generation import (generate_reply_wrapper,
get_encoded_length, stop_everything_event) get_encoded_length, stop_everything_event)
@ -119,19 +117,6 @@ def load_preset_values(preset_menu, state, return_dict=False):
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']] return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
def upload_soft_prompt(file):
with zipfile.ZipFile(io.BytesIO(file)) as zf:
zf.extract('meta.json')
j = json.loads(open('meta.json', 'r').read())
name = j['name']
Path('meta.json').unlink()
with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
f.write(file)
return name
def open_save_prompt(): def open_save_prompt():
fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}" fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
return gr.update(value=fname, visible=True), gr.update(visible=False), gr.update(visible=True) return gr.update(value=fname, visible=True), gr.update(visible=False), gr.update(visible=True)
@ -510,16 +495,6 @@ def create_settings_menus(default_preset):
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau') shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta') shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
gr.Markdown('Other')
with gr.Accordion('Soft prompt', open=False):
with gr.Row():
shared.gradio['softprompts_menu'] = gr.Dropdown(choices=utils.get_available_softprompts(), value='None', label='Soft prompt')
ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda: None, lambda: {'choices': utils.get_available_softprompts()}, 'refresh-button')
gr.Markdown('Upload a soft prompt (.zip format):')
with gr.Row():
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
with gr.Box(): with gr.Box():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@ -535,8 +510,6 @@ def create_settings_menus(default_preset):
gr.Markdown('[Click here for more information.](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Generation-parameters.md)') gr.Markdown('[Click here for more information.](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Generation-parameters.md)')
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]) shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']])
shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
def set_interface_arguments(interface_mode, extensions, bool_active): def set_interface_arguments(interface_mode, extensions, bool_active):