Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-11-22 08:07:56 +01:00)
Add Classifier Free Guidance (CFG) for Transformers/ExLlama (#3325)
Commit 0af10ab49b (parent 5134878344)
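In short, the commit threads two new generation parameters, guidance_scale and negative_prompt, through the API example scripts, the request-parameter builders, the ExLlama streaming generator, the ExLlama_HF and llamacpp_HF wrappers, the per-loader sampler whitelists, the presets module, and the settings UI. Classifier-Free Guidance blends the token distribution conditioned on the prompt with one conditioned on a negative (or empty) prompt; guidance_scale above 1 pushes sampling toward the prompt and away from the negative prompt, while 1 leaves guidance off. A minimal sketch of the mixing rule used in the ExLlama hunk further down (illustration only, not code from the commit):

import torch
import torch.nn.functional as F

def cfg_mix(next_token_logits: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    """Blend prompt-conditioned and negative-prompt-conditioned logits.

    next_token_logits: shape [2, vocab]; row 0 = prompt, row 1 = negative prompt.
    guidance_scale == 1 reduces to the plain prompt-conditioned distribution.
    """
    logprobs = F.log_softmax(next_token_logits, dim=-1)
    return guidance_scale * logprobs[0] + (1 - guidance_scale) * logprobs[1]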
@@ -63,6 +63,8 @@ async def run(user_input, history):
         'mirostat_mode': 0,
         'mirostat_tau': 5,
         'mirostat_eta': 0.1,
+        'guidance_scale': 1,
+        'negative_prompt': '',
 
         'seed': -1,
         'add_bos_token': True,
@@ -57,6 +57,8 @@ def run(user_input, history):
         'mirostat_mode': 0,
         'mirostat_tau': 5,
         'mirostat_eta': 0.1,
+        'guidance_scale': 1,
+        'negative_prompt': '',
 
         'seed': -1,
         'add_bos_token': True,
@@ -45,6 +45,8 @@ async def run(context):
         'mirostat_mode': 0,
         'mirostat_tau': 5,
         'mirostat_eta': 0.1,
+        'guidance_scale': 1,
+        'negative_prompt': '',
 
         'seed': -1,
         'add_bos_token': True,
@@ -37,6 +37,8 @@ def run(prompt):
         'mirostat_mode': 0,
         'mirostat_tau': 5,
         'mirostat_eta': 0.1,
+        'guidance_scale': 1,
+        'negative_prompt': '',
 
         'seed': -1,
         'add_bos_token': True,
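The four API example scripts above all gain the same two request fields. A request body exercising them might look like this (illustrative values; every other field keeps the defaults shown in the examples, and 'guidance_scale': 1 leaves CFG off):

request = {
    'prompt': 'Write a haiku about autumn.',
    'max_new_tokens': 80,

    # New in this commit
    'guidance_scale': 1.5,                        # >1 strengthens guidance; 1 disables it
    'negative_prompt': 'low quality, rambling',   # what generation is steered away from

    'seed': -1,
    'add_bos_token': True,
}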
@@ -43,6 +43,8 @@ def build_parameters(body, chat=False):
         'mirostat_mode': int(body.get('mirostat_mode', 0)),
         'mirostat_tau': float(body.get('mirostat_tau', 5)),
         'mirostat_eta': float(body.get('mirostat_eta', 0.1)),
+        'guidance_scale': float(body.get('guidance_scale', 1)),
+        'negative_prompt': str(body.get('negative_prompt', '')),
         'seed': int(body.get('seed', -1)),
         'add_bos_token': bool(body.get('add_bos_token', True)),
         'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))),
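On the server side, build_parameters() coerces the incoming values and falls back to CFG-off defaults when the new fields are missing, so requests from older clients keep working unchanged. A small sketch of that behaviour with hypothetical request bodies:

old_client_body = {'prompt': 'Hello'}                           # no CFG fields at all
assert float(old_client_body.get('guidance_scale', 1)) == 1.0   # CFG stays disabled
assert str(old_client_body.get('negative_prompt', '')) == ''

new_client_body = {'prompt': 'Hello', 'guidance_scale': '1.5'}  # string value from JSON
assert float(new_client_body.get('guidance_scale', 1)) == 1.5   # coerced to float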
@@ -33,6 +33,8 @@ default_req_params = {
     'mirostat_mode': 0,
     'mirostat_tau': 5.0,
     'mirostat_eta': 0.1,
+    'guidance_scale': 1,
+    'negative_prompt': '',
     'ban_eos_token': False,
     'skip_special_tokens': True,
     'custom_stopping_strings': '',
@@ -1,9 +1,11 @@
 from pathlib import Path
 
+import torch.nn.functional as F
 from torch import version as torch_version
 
 from modules import shared
 from modules.logging_colors import logger
+from modules.models import clear_torch_cache
 from modules.text_generation import get_max_prompt_length
 
 try:
@@ -78,6 +80,21 @@ class ExllamaModel:
         return result, result
 
     def generate_with_streaming(self, prompt, state):
+
+        # The cache batch size must be 2 for CFG and 1 otherwise
+        if state['guidance_scale'] == 1:
+            if self.cache.batch_size == 2:
+                del self.cache
+                clear_torch_cache()
+                self.cache = ExLlamaCache(self.model)
+                self.generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
+        else:
+            if self.cache.batch_size == 1:
+                del self.cache
+                clear_torch_cache()
+                self.cache = ExLlamaCache(self.model, batch_size=2)
+                self.generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
+
         self.generator.settings.temperature = state['temperature']
         self.generator.settings.top_p = state['top_p']
         self.generator.settings.top_k = state['top_k']
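A note on the batch-size switch above (my reading of the code): with CFG active, the ExLlama cache has to hold two sequences at once, the prompt in batch row 0 and the negative prompt in row 1, so the cache is rebuilt with batch_size=2 when guidance is enabled and shrunk back to a single row when it is not. The decision reduces to something like this hypothetical helper:

def required_cache_batch_size(state: dict) -> int:
    # Row 0 holds the prompt; row 1 is only needed for the negative prompt (CFG).
    return 1 if state['guidance_scale'] == 1 else 2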
@@ -89,31 +106,71 @@ class ExllamaModel:
         else:
             self.generator.disallow_tokens(None)
 
-        self.generator.end_beam_search()
-
-        # Tokenizing the input
-        ids = self.generator.tokenizer.encode(prompt)
-        ids = ids[:, -get_max_prompt_length(state):]
-        if state['auto_max_new_tokens']:
-            max_new_tokens = state['truncation_length'] - ids.shape[-1]
-        else:
-            max_new_tokens = state['max_new_tokens']
-
-        self.generator.gen_begin_reuse(ids)
-        initial_len = self.generator.sequence[0].shape[0]
-        has_leading_space = False
-        for i in range(max_new_tokens):
-            token = self.generator.gen_single_token()
-            if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
-                has_leading_space = True
-
-            decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
-            if has_leading_space:
-                decoded_text = ' ' + decoded_text
-
-            yield decoded_text
-            if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
-                break
+        # Case 1: no CFG
+        if state['guidance_scale'] == 1:
+            self.generator.end_beam_search()
+
+            # Tokenizing the input
+            ids = self.generator.tokenizer.encode(prompt)
+            ids = ids[:, -get_max_prompt_length(state):]
+            if state['auto_max_new_tokens']:
+                max_new_tokens = state['truncation_length'] - ids.shape[-1]
+            else:
+                max_new_tokens = state['max_new_tokens']
+
+            self.generator.gen_begin_reuse(ids)
+            initial_len = self.generator.sequence[0].shape[0]
+            has_leading_space = False
+
+            for i in range(max_new_tokens):
+                token = self.generator.gen_single_token()
+                if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
+                    has_leading_space = True
+
+                decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
+                if has_leading_space:
+                    decoded_text = ' ' + decoded_text
+
+                yield decoded_text
+                if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
+                    break
+
+        # Case 2: CFG
+        else:
+            alpha = state['guidance_scale']
+            prompts = [prompt, state['negative_prompt'] or '']
+
+            ids, mask = self.tokenizer.encode(prompts, return_mask=True)
+            if state['auto_max_new_tokens']:
+                max_new_tokens = state['truncation_length'] - ids[0].shape[-1]
+            else:
+                max_new_tokens = state['max_new_tokens']
+
+            self.generator.gen_begin(ids, mask=mask)
+            initial_len = self.generator.sequence[0].shape[0]
+            has_leading_space = False
+
+            for i in range(max_new_tokens):
+                logits = self.model.forward(self.generator.sequence[:, -1:], self.cache, input_mask=mask)
+                self.generator.apply_rep_penalty(logits)
+
+                logits = F.log_softmax(logits, dim=-1)
+                logits_mixed = alpha * logits[0] + (1 - alpha) * logits[1]
+
+                token, _ = self.generator.sample_current(logits_mixed)
+                if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
+                    has_leading_space = True
+
+                decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
+                if has_leading_space:
+                    decoded_text = ' ' + decoded_text
+
+                yield decoded_text
+                if token.item() == self.tokenizer.eos_token_id or shared.stop_everything:
+                    break
+
+                batch_token = token.repeat(2, 1)
+                self.generator.gen_accept_token(batch_token)
 
     def generate(self, prompt, state):
         output = ''
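Two details of the CFG branch above are worth spelling out. First, the mixing line is algebraically the negative-prompt distribution nudged toward the prompt distribution by a factor of alpha. Second, the sampled token is repeated into both batch rows before being accepted, so the positive and negative sequences (and their caches) advance in lockstep. A standalone illustration, assuming a 32000-entry vocabulary (not code from the commit):

import torch

logp = torch.randn(2, 32000).log_softmax(dim=-1)   # row 0: prompt, row 1: negative prompt
alpha = 1.5

mixed = alpha * logp[0] + (1 - alpha) * logp[1]
equivalent = logp[1] + alpha * (logp[0] - logp[1])
assert torch.allclose(mixed, equivalent, atol=1e-6)

token = torch.tensor([[42]])        # hypothetical sampled token id, shape [1, 1]
batch_token = token.repeat(2, 1)    # shape [2, 1]: the same token is appended to both rows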
@@ -47,12 +47,11 @@ class ExllamaHF(PreTrainedModel):
         return torch.device(0)
 
     def __call__(self, *args, **kwargs):
-        # TODO: Some decoding methods (such as Contrastive Search) may not work at this time
-        assert len(args) == 0, 'no *args should be passed to forward'
+        input_ids = args[0] if len(args) > 0 else kwargs['input_ids']
         use_cache = kwargs.get('use_cache', True)
         labels = kwargs.get('labels', None)
-        seq = kwargs['input_ids'][0].tolist()
-        cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
+        cache = kwargs.get('past_key_values', None)
+        seq = input_ids[0].tolist()
 
         if labels is None:
             if cache is None:
@@ -60,7 +59,7 @@ class ExllamaHF(PreTrainedModel):
                 cache = self.ex_cache
                 self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True, lora=self.lora)
 
-            logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(kwargs['input_ids'].device)
+            logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(input_ids.device)
         else:
             if cache is None:
                 self.ex_cache.current_seq_len = 0
@@ -49,12 +49,11 @@ class LlamacppHF(PreTrainedModel):
         return torch.device(0)
 
     def __call__(self, *args, **kwargs):
-        # TODO: Some decoding methods (such as Contrastive Search) may not work at this time
-        assert len(args) == 0, 'no *args should be passed to forward'
+        input_ids = args[0] if len(args) > 0 else kwargs['input_ids']
         use_cache = kwargs.get('use_cache', True)
         labels = kwargs.get('labels', None)
-        seq = kwargs['input_ids'][0].tolist()
-        cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
+        cache = kwargs.get('past_key_values', None)
+        seq = input_ids[0].tolist()
 
         # Make the forward call
         seq_tensor = torch.tensor(seq)
@@ -70,7 +69,7 @@ class LlamacppHF(PreTrainedModel):
                 self.model.reset()
                 self.model.eval(seq)
                 logits = torch.tensor(self.model.eval_logits)
-                logits = logits.view(1, logits.shape[0], logits.shape[1]).to(kwargs['input_ids'].device)
+                logits = logits.view(1, logits.shape[0], logits.shape[1]).to(input_ids.device)
 
             self.cache = seq_tensor
 
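Both PreTrainedModel wrappers used to assert that no positional arguments reached __call__ and read input_ids straight out of kwargs. The rewrite accepts input_ids either positionally or as a keyword and reuses that tensor for the .to(...) device move, presumably because the CFG machinery in the pinned transformers revision calls the wrapped model with a positional tensor for the negative-prompt pass. A tiny sketch of the resolution rule now shared by both wrappers:

import torch

def resolve_input_ids(*args, **kwargs) -> torch.Tensor:
    # Same rule the wrappers now apply at the top of __call__
    return args[0] if len(args) > 0 else kwargs['input_ids']

ids = torch.tensor([[1, 2, 3]])
assert torch.equal(resolve_input_ids(ids), ids)            # positional call
assert torch.equal(resolve_input_ids(input_ids=ids), ids)  # keyword call (old behaviour)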
@@ -115,6 +115,8 @@ loaders_samplers = {
         'mirostat_mode',
         'mirostat_tau',
         'mirostat_eta',
+        'guidance_scale',
+        'negative_prompt',
         'ban_eos_token',
         'add_bos_token',
         'skip_special_tokens',
@@ -152,6 +154,8 @@ loaders_samplers = {
         'repetition_penalty',
         'repetition_penalty_range',
         'seed',
+        'guidance_scale',
+        'negative_prompt',
         'ban_eos_token',
         'auto_max_new_tokens',
     },
@@ -178,6 +182,8 @@ loaders_samplers = {
         'mirostat_mode',
         'mirostat_tau',
         'mirostat_eta',
+        'guidance_scale',
+        'negative_prompt',
         'ban_eos_token',
         'add_bos_token',
         'skip_special_tokens',
@@ -206,6 +212,8 @@ loaders_samplers = {
         'mirostat_mode',
         'mirostat_tau',
         'mirostat_eta',
+        'guidance_scale',
+        'negative_prompt',
         'ban_eos_token',
         'add_bos_token',
         'skip_special_tokens',
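The four hunks above extend the per-loader sampler whitelists, which appear to map each loader name to the set of generation parameters its backend honours; adding guidance_scale and negative_prompt here is what keeps the new UI controls from being blacklisted for loaders that received CFG support. A minimal membership check, assuming loaders_samplers is a plain dict of parameter-name sets as the hunks suggest:

from modules import loaders  # assumes the webui package is importable

def supports_cfg(loader_name: str) -> bool:
    params = loaders.loaders_samplers.get(loader_name, set())
    return {'guidance_scale', 'negative_prompt'} <= params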
@@ -9,6 +9,7 @@ def default_preset():
         'do_sample': True,
         'temperature': 1,
         'top_p': 1,
+        'top_k': 0,
         'typical_p': 1,
         'epsilon_cutoff': 0,
         'eta_cutoff': 0,
@@ -17,19 +18,23 @@ def default_preset():
         'repetition_penalty': 1,
         'repetition_penalty_range': 0,
         'encoder_repetition_penalty': 1,
-        'top_k': 0,
-        'num_beams': 1,
-        'penalty_alpha': 0,
-        'min_length': 0,
-        'length_penalty': 1,
         'no_repeat_ngram_size': 0,
-        'early_stopping': False,
+        'min_length': 0,
+        'guidance_scale': 1,
         'mirostat_mode': 0,
         'mirostat_tau': 5.0,
         'mirostat_eta': 0.1,
+        'penalty_alpha': 0,
+        'num_beams': 1,
+        'length_penalty': 1,
+        'early_stopping': False,
     }
 
 
+def presets_params():
+    return [k for k in default_preset()]
+
+
 def load_preset(name):
     generate_params = default_preset()
     if name not in ['None', None, '']:
@@ -51,12 +56,12 @@ def load_preset_memoized(name):
 def load_preset_for_ui(name, state):
     generate_params = load_preset(name)
     state.update(generate_params)
-    return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
+    return state, *[generate_params[k] for k in presets_params()]
 
 
 def generate_preset_yaml(state):
     defaults = default_preset()
-    data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']}
+    data = {k: state[k] for k in presets_params()}
 
     # Remove entries that are identical to the defaults
     for k in list(data.keys()):
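The presets refactor replaces two long hand-maintained key lists with a presets_params() helper derived from default_preset(), so any parameter added to the default preset (guidance_scale in this commit) automatically flows through preset loading, the UI outputs, and YAML export. The idea in miniature, with a trimmed-down stand-in for the real default_preset():

def default_preset():
    # Stand-in with only three keys; the real function returns the full preset dict.
    return {'temperature': 1, 'top_p': 1, 'guidance_scale': 1}

def presets_params():
    return [k for k in default_preset()]

# New default keys appear everywhere that iterates presets_params(), with no list to update.
assert presets_params() == ['temperature', 'top_p', 'guidance_scale']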
@@ -42,6 +42,7 @@ settings = {
     'max_new_tokens_max': 4096,
     'auto_max_new_tokens': False,
     'seed': -1,
+    'negative_prompt': '',
     'character': 'None',
     'name1': 'You',
     'name2': 'Assistant',
@@ -226,9 +226,12 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
 
 def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
     generate_params = {}
-    for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
+    for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
         generate_params[k] = state[k]
 
+    if state['negative_prompt'] != '':
+        generate_params['negative_prompt_ids'] = encode(state['negative_prompt'])
+
     for k in ['epsilon_cutoff', 'eta_cutoff']:
         if state[k] > 0:
             generate_params[k] = state[k] * 1e-4
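For the Transformers loader, guidance_scale is now copied into generate_params, and a non-empty negative prompt is pre-encoded into negative_prompt_ids; both are generate() keyword arguments in recent transformers releases, which is presumably why requirements.txt is switched to a pinned transformers git revision later in this diff. A hedged sketch of the eventual call, using an arbitrary small model for illustration:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = 'facebook/opt-125m'  # any causal LM works the same way
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt_ids = tokenizer('Write a haiku about autumn.', return_tensors='pt').input_ids
negative_ids = tokenizer('low quality, rambling', return_tensors='pt').input_ids

output_ids = model.generate(
    prompt_ids,
    max_new_tokens=80,
    guidance_scale=1.5,               # CFG strength; 1 disables guidance
    negative_prompt_ids=negative_ids,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))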
@@ -100,6 +100,8 @@ def list_interface_input_elements():
         'mirostat_mode',
         'mirostat_tau',
         'mirostat_eta',
+        'negative_prompt',
+        'guidance_scale',
         'add_bos_token',
         'ban_eos_token',
         'truncation_length',
@@ -15,10 +15,10 @@ safetensors==0.3.1
 scipy
 sentencepiece
 tensorboard
-transformers==4.31.*
 tqdm
 wandb
 git+https://github.com/huggingface/peft@96c0277a1b9a381b10ab34dbf84917f9b3b992e6
+git+https://github.com/huggingface/transformers@d533465150532b0c5de167b574e59f64c68b1154
 bitsandbytes==0.41.1; platform_system != "Windows"
 https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows"
 https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.3.0/auto_gptq-0.3.0+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
@@ -229,7 +229,7 @@ def create_model_menus():
                     shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0)
                     shared.gradio['autogptq_info'] = gr.Markdown('* ExLlama_HF is recommended over AutoGPTQ for models derived from LLaMA.')
                     shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
-                    shared.gradio['max_seq_len'] = gr.Slider(label='max_seq_len', minimum=2048, maximum=16384, step=256, info='Maximum sequence length.', value=shared.args.max_seq_len)
+                    shared.gradio['max_seq_len'] = gr.Slider(label='max_seq_len', minimum=0, maximum=16384, step=256, info='Maximum sequence length.', value=shared.args.max_seq_len)
                     shared.gradio['compress_pos_emb'] = gr.Slider(label='compress_pos_emb', minimum=1, maximum=8, step=1, info='Positional embeddings compression factor. Should typically be set to max_seq_len / 2048.', value=shared.args.compress_pos_emb)
                     shared.gradio['alpha_value'] = gr.Slider(label='alpha_value', minimum=1, maximum=32, step=1, info='Positional embeddings alpha factor for NTK RoPE scaling. Scaling is not identical to embedding compression. Use either this or compress_pos_emb, not both.', value=shared.args.alpha_value)
 
@@ -408,6 +408,8 @@ def create_settings_menus(default_preset):
             with gr.Box():
                 with gr.Row():
                     with gr.Column():
+                        shared.gradio['guidance_scale'] = gr.Slider(-0.5, 2.5, step=0.05, value=generate_params['guidance_scale'], label='guidance_scale', info='For CFG. 1.5 is a good value.')
+                        shared.gradio['negative_prompt'] = gr.Textbox(value=shared.settings['negative_prompt'], label='Negative prompt')
                         shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.')
                        shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
                        shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
@@ -433,7 +435,7 @@ def create_settings_menus(default_preset):
         shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
 
     filter_by_loader.change(loaders.blacklist_samplers, filter_by_loader, gradio(loaders.list_all_samplers()), show_progress=False)
-    shared.gradio['preset_menu'].change(presets.load_preset_for_ui, gradio('preset_menu', 'interface_state'), gradio('interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a'))
+    shared.gradio['preset_menu'].change(presets.load_preset_for_ui, gradio('preset_menu', 'interface_state'), gradio('interface_state') + gradio(presets.presets_params()))
 
 
 def create_file_saving_menus():
@@ -5,6 +5,7 @@ max_new_tokens_min: 1
 max_new_tokens_max: 4096
 auto_max_new_tokens: false
 seed: -1
+negative_prompt: ''
 character: None
 name1: You
 name2: Assistant