mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Remove softprompt support
This commit is contained in:
parent
ccb4c9f178
commit
00b94847da
@ -160,11 +160,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
||||
min_rows = 3
|
||||
|
||||
# Finding the maximum prompt size
|
||||
chat_prompt_size = state['chat_prompt_size']
|
||||
if shared.soft_prompt:
|
||||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
||||
|
||||
max_length = min(get_max_prompt_length(state), chat_prompt_size)
|
||||
max_length = min(get_max_prompt_length(state), state['chat_prompt_size'])
|
||||
|
||||
# Building the turn templates
|
||||
if 'turn_template' not in state or state['turn_template'] == '':
|
||||
|
@ -55,11 +55,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||
is_instruct = state['mode'] == 'instruct'
|
||||
|
||||
# Find the maximum prompt size
|
||||
chat_prompt_size = state['chat_prompt_size']
|
||||
if shared.soft_prompt:
|
||||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
||||
|
||||
max_length = min(get_max_prompt_length(state), chat_prompt_size)
|
||||
max_length = min(get_max_prompt_length(state), state['chat_prompt_size'])
|
||||
all_substrings = {
|
||||
'chat': get_turn_substrings(state, instruct=False),
|
||||
'instruct': get_turn_substrings(state, instruct=True)
|
||||
|
@ -1,12 +1,9 @@
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import transformers
|
||||
from accelerate import infer_auto_device_map, init_empty_weights
|
||||
@ -338,32 +335,3 @@ def unload_model():
|
||||
def reload_model():
|
||||
unload_model()
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
|
||||
|
||||
def load_soft_prompt(name):
|
||||
if name == 'None':
|
||||
shared.soft_prompt = False
|
||||
shared.soft_prompt_tensor = None
|
||||
else:
|
||||
with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
|
||||
zf.extract('tensor.npy')
|
||||
zf.extract('meta.json')
|
||||
j = json.loads(open('meta.json', 'r').read())
|
||||
logger.info(f"\nLoading the softprompt \"{name}\".")
|
||||
for field in j:
|
||||
if field != 'name':
|
||||
if type(j[field]) is list:
|
||||
logger.info(f"{field}: {', '.join(j[field])}")
|
||||
else:
|
||||
logger.info(f"{field}: {j[field]}")
|
||||
|
||||
tensor = np.load('tensor.npy')
|
||||
Path('tensor.npy').unlink()
|
||||
Path('meta.json').unlink()
|
||||
|
||||
tensor = torch.Tensor(tensor).to(device=shared.model.device, dtype=shared.model.dtype)
|
||||
tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
|
||||
shared.soft_prompt = True
|
||||
shared.soft_prompt_tensor = tensor
|
||||
|
||||
return name
|
||||
|
@ -12,8 +12,6 @@ tokenizer = None
|
||||
model_name = "None"
|
||||
model_type = None
|
||||
lora_names = []
|
||||
soft_prompt_tensor = None
|
||||
soft_prompt = False
|
||||
|
||||
# Chat variables
|
||||
history = {'internal': [], 'visible': []}
|
||||
|
@ -27,11 +27,7 @@ def generate_reply(*args, **kwargs):
|
||||
|
||||
|
||||
def get_max_prompt_length(state):
|
||||
max_length = state['truncation_length'] - state['max_new_tokens']
|
||||
if shared.soft_prompt:
|
||||
max_length -= shared.soft_prompt_tensor.shape[1]
|
||||
|
||||
return max_length
|
||||
return state['truncation_length'] - state['max_new_tokens']
|
||||
|
||||
|
||||
def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
|
||||
@ -80,14 +76,6 @@ def decode(output_ids, skip_special_tokens=True):
|
||||
return shared.tokenizer.decode(output_ids, skip_special_tokens)
|
||||
|
||||
|
||||
def generate_softprompt_input_tensors(input_ids):
|
||||
inputs_embeds = shared.model.transformer.wte(input_ids)
|
||||
inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
|
||||
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
|
||||
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
|
||||
return inputs_embeds, filler_input_ids
|
||||
|
||||
|
||||
# Removes empty replies from gpt4chan outputs
|
||||
def fix_gpt4chan(s):
|
||||
for i in range(10):
|
||||
@ -232,18 +220,11 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
|
||||
eos_token_ids.append(int(encode(eos_token)[0][-1]))
|
||||
|
||||
# Add the encoded tokens to generate_params
|
||||
if shared.soft_prompt:
|
||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||
question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds)
|
||||
original_input_ids = input_ids
|
||||
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
|
||||
original_input_ids = input_ids
|
||||
generate_params.update({'inputs': input_ids})
|
||||
if inputs_embeds is not None:
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
generate_params.update({'inputs': filler_input_ids})
|
||||
else:
|
||||
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
|
||||
original_input_ids = input_ids
|
||||
generate_params.update({'inputs': input_ids})
|
||||
if inputs_embeds is not None:
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
|
||||
# Create the StoppingCriteriaList with the stopping strings (needs to be done after tokenizer extensions)
|
||||
stopping_criteria_list = transformers.StoppingCriteriaList()
|
||||
@ -269,9 +250,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
|
||||
if cuda:
|
||||
output = output.cuda()
|
||||
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
|
||||
|
||||
# Stream the reply 1 token at a time.
|
||||
@ -289,9 +267,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
|
||||
|
||||
with generate_with_streaming(**generate_params) as generator:
|
||||
for output in generator:
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
|
||||
if output[-1] in eos_token_ids:
|
||||
break
|
||||
|
@ -60,10 +60,6 @@ def get_available_extensions():
|
||||
return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=natural_keys)
|
||||
|
||||
|
||||
def get_available_softprompts():
|
||||
return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=natural_keys)
|
||||
|
||||
|
||||
def get_available_loras():
|
||||
return sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=natural_keys)
|
||||
|
||||
|
29
server.py
29
server.py
@ -26,7 +26,6 @@ import matplotlib
|
||||
matplotlib.use('Agg') # This fixes LaTeX rendering on some systems
|
||||
|
||||
import importlib
|
||||
import io
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
@ -34,7 +33,6 @@ import re
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
@ -50,7 +48,7 @@ from modules import chat, shared, training, ui, utils
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.html_generator import chat_html_wrapper
|
||||
from modules.LoRA import add_lora_to_model
|
||||
from modules.models import load_model, load_soft_prompt, unload_model
|
||||
from modules.models import load_model, unload_model
|
||||
from modules.text_generation import (generate_reply_wrapper,
|
||||
get_encoded_length, stop_everything_event)
|
||||
|
||||
@ -119,19 +117,6 @@ def load_preset_values(preset_menu, state, return_dict=False):
|
||||
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
|
||||
|
||||
|
||||
def upload_soft_prompt(file):
|
||||
with zipfile.ZipFile(io.BytesIO(file)) as zf:
|
||||
zf.extract('meta.json')
|
||||
j = json.loads(open('meta.json', 'r').read())
|
||||
name = j['name']
|
||||
Path('meta.json').unlink()
|
||||
|
||||
with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
|
||||
f.write(file)
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def open_save_prompt():
|
||||
fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
|
||||
return gr.update(value=fname, visible=True), gr.update(visible=False), gr.update(visible=True)
|
||||
@ -510,16 +495,6 @@ def create_settings_menus(default_preset):
|
||||
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
|
||||
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
|
||||
|
||||
gr.Markdown('Other')
|
||||
with gr.Accordion('Soft prompt', open=False):
|
||||
with gr.Row():
|
||||
shared.gradio['softprompts_menu'] = gr.Dropdown(choices=utils.get_available_softprompts(), value='None', label='Soft prompt')
|
||||
ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda: None, lambda: {'choices': utils.get_available_softprompts()}, 'refresh-button')
|
||||
|
||||
gr.Markdown('Upload a soft prompt (.zip format):')
|
||||
with gr.Row():
|
||||
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
|
||||
|
||||
with gr.Box():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
@ -535,8 +510,6 @@ def create_settings_menus(default_preset):
|
||||
gr.Markdown('[Click here for more information.](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Generation-parameters.md)')
|
||||
|
||||
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']])
|
||||
shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
|
||||
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
|
||||
|
||||
|
||||
def set_interface_arguments(interface_mode, extensions, bool_active):
|
||||
|
Loading…
Reference in New Issue
Block a user