diff --git a/docs/Extensions.md b/docs/Extensions.md
index b80e1a26..724733b8 100644
--- a/docs/Extensions.md
+++ b/docs/Extensions.md
@@ -160,11 +160,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
     min_rows = 3
 
     # Finding the maximum prompt size
-    chat_prompt_size = state['chat_prompt_size']
-    if shared.soft_prompt:
-        chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
-
-    max_length = min(get_max_prompt_length(state), chat_prompt_size)
+    max_length = min(get_max_prompt_length(state), state['chat_prompt_size'])
 
     # Building the turn templates
     if 'turn_template' not in state or state['turn_template'] == '':
diff --git a/modules/chat.py b/modules/chat.py
index e78800fd..a62767cb 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -55,11 +55,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
     is_instruct = state['mode'] == 'instruct'
 
     # Find the maximum prompt size
-    chat_prompt_size = state['chat_prompt_size']
-    if shared.soft_prompt:
-        chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
-
-    max_length = min(get_max_prompt_length(state), chat_prompt_size)
+    max_length = min(get_max_prompt_length(state), state['chat_prompt_size'])
     all_substrings = {
         'chat': get_turn_substrings(state, instruct=False),
         'instruct': get_turn_substrings(state, instruct=True)
diff --git a/modules/models.py b/modules/models.py
index 3972133a..1a4eb5a0 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -1,12 +1,9 @@
 import gc
-import json
 import os
 import re
 import time
-import zipfile
 from pathlib import Path
 
-import numpy as np
 import torch
 import transformers
 from accelerate import infer_auto_device_map, init_empty_weights
@@ -338,32 +335,3 @@ def unload_model():
 def reload_model():
     unload_model()
     shared.model, shared.tokenizer = load_model(shared.model_name)
-
-
-def load_soft_prompt(name):
-    if name == 'None':
-        shared.soft_prompt = False
-        shared.soft_prompt_tensor = None
-    else:
-        with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
-            zf.extract('tensor.npy')
-            zf.extract('meta.json')
-            j = json.loads(open('meta.json', 'r').read())
-            logger.info(f"\nLoading the softprompt \"{name}\".")
-            for field in j:
-                if field != 'name':
-                    if type(j[field]) is list:
-                        logger.info(f"{field}: {', '.join(j[field])}")
-                    else:
-                        logger.info(f"{field}: {j[field]}")
-
-            tensor = np.load('tensor.npy')
-            Path('tensor.npy').unlink()
-            Path('meta.json').unlink()
-
-        tensor = torch.Tensor(tensor).to(device=shared.model.device, dtype=shared.model.dtype)
-        tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
-        shared.soft_prompt = True
-        shared.soft_prompt_tensor = tensor
-
-    return name
diff --git a/modules/shared.py b/modules/shared.py
index 3df0cc08..9f4f720c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -12,8 +12,6 @@ tokenizer = None
 model_name = "None"
 model_type = None
 lora_names = []
-soft_prompt_tensor = None
-soft_prompt = False
 
 # Chat variables
 history = {'internal': [], 'visible': []}
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 6e7c0166..00b7cc7b 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -27,11 +27,7 @@ def generate_reply(*args, **kwargs):
 
 
 def get_max_prompt_length(state):
-    max_length = state['truncation_length'] - state['max_new_tokens']
-    if shared.soft_prompt:
-        max_length -= shared.soft_prompt_tensor.shape[1]
-
-    return max_length
+    return state['truncation_length'] - state['max_new_tokens']
 
 
 def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
@@ -80,14 +76,6 @@ def decode(output_ids, skip_special_tokens=True):
     return shared.tokenizer.decode(output_ids, skip_special_tokens)
 
 
-def generate_softprompt_input_tensors(input_ids):
-    inputs_embeds = shared.model.transformer.wte(input_ids)
-    inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
-    filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
-    # filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
-    return inputs_embeds, filler_input_ids
-
-
 # Removes empty replies from gpt4chan outputs
 def fix_gpt4chan(s):
     for i in range(10):
@@ -232,18 +220,11 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
             eos_token_ids.append(int(encode(eos_token)[0][-1]))
 
     # Add the encoded tokens to generate_params
-    if shared.soft_prompt:
-        inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
-        question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds)
-        original_input_ids = input_ids
+    question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
+    original_input_ids = input_ids
+    generate_params.update({'inputs': input_ids})
+    if inputs_embeds is not None:
         generate_params.update({'inputs_embeds': inputs_embeds})
-        generate_params.update({'inputs': filler_input_ids})
-    else:
-        question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
-        original_input_ids = input_ids
-        generate_params.update({'inputs': input_ids})
-        if inputs_embeds is not None:
-            generate_params.update({'inputs_embeds': inputs_embeds})
 
     # Create the StoppingCriteriaList with the stopping strings (needs to be done after tokenizer extensions)
     stopping_criteria_list = transformers.StoppingCriteriaList()
@@ -269,9 +250,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
             if cuda:
                 output = output.cuda()
 
-            if shared.soft_prompt:
-                output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
-
             yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
 
         # Stream the reply 1 token at a time.
@@ -289,9 +267,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
 
             with generate_with_streaming(**generate_params) as generator:
                 for output in generator:
-                    if shared.soft_prompt:
-                        output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
-
                     yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
                     if output[-1] in eos_token_ids:
                         break
diff --git a/modules/utils.py b/modules/utils.py
index 84ca997f..4fa9f868 100644
--- a/modules/utils.py
+++ b/modules/utils.py
@@ -60,10 +60,6 @@ def get_available_extensions():
     return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=natural_keys)
 
 
-def get_available_softprompts():
-    return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=natural_keys)
-
-
 def get_available_loras():
     return sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=natural_keys)
 
diff --git a/server.py b/server.py
index ce7086a5..d047801d 100644
--- a/server.py
+++ b/server.py
@@ -26,7 +26,6 @@ import matplotlib
 matplotlib.use('Agg')  # This fixes LaTeX rendering on some systems
 
 import importlib
-import io
 import json
 import math
 import os
@@ -34,7 +33,6 @@ import re
 import sys
 import time
 import traceback
-import zipfile
 from datetime import datetime
 from functools import partial
 from pathlib import Path
@@ -50,7 +48,7 @@ from modules import chat, shared, training, ui, utils
 from modules.extensions import apply_extensions
 from modules.html_generator import chat_html_wrapper
 from modules.LoRA import add_lora_to_model
-from modules.models import load_model, load_soft_prompt, unload_model
+from modules.models import load_model, unload_model
 from modules.text_generation import (generate_reply_wrapper,
                                      get_encoded_length, stop_everything_event)
 
@@ -119,19 +117,6 @@ def load_preset_values(preset_menu, state, return_dict=False):
     return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
 
 
-def upload_soft_prompt(file):
-    with zipfile.ZipFile(io.BytesIO(file)) as zf:
-        zf.extract('meta.json')
-        j = json.loads(open('meta.json', 'r').read())
-        name = j['name']
-        Path('meta.json').unlink()
-
-    with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
-        f.write(file)
-
-    return name
-
-
 def open_save_prompt():
     fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
     return gr.update(value=fname, visible=True), gr.update(visible=False), gr.update(visible=True)
@@ -510,16 +495,6 @@ def create_settings_menus(default_preset):
                         shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
                         shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
 
-            gr.Markdown('Other')
-            with gr.Accordion('Soft prompt', open=False):
-                with gr.Row():
-                    shared.gradio['softprompts_menu'] = gr.Dropdown(choices=utils.get_available_softprompts(), value='None', label='Soft prompt')
-                    ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda: None, lambda: {'choices': utils.get_available_softprompts()}, 'refresh-button')
-
-                gr.Markdown('Upload a soft prompt (.zip format):')
-                with gr.Row():
-                    shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
-
         with gr.Box():
             with gr.Row():
                 with gr.Column():
@@ -535,8 +510,6 @@ def create_settings_menus(default_preset):
                     gr.Markdown('[Click here for more information.](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Generation-parameters.md)')
 
     shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']])
-    shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
-    shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
 
 
 def set_interface_arguments(interface_mode, extensions, bool_active):