mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
Add Mirostat v2 sampling to transformer models (#2571)
This commit is contained in:
parent
aff3e04df4
commit
b04e18d10c
@ -1,7 +1,11 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import LogitsWarper
|
from transformers import LogitsWarper
|
||||||
from transformers.generation.logits_process import LogitNormalization, LogitsProcessorList
|
from transformers.generation.logits_process import (LogitNormalization,
|
||||||
|
LogitsProcessorList,
|
||||||
|
TemperatureLogitsWarper)
|
||||||
|
|
||||||
|
|
||||||
class TailFreeLogitsWarper(LogitsWarper):
|
class TailFreeLogitsWarper(LogitsWarper):
|
||||||
@ -70,15 +74,67 @@ class TopALogitsWarper(LogitsWarper):
|
|||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class MirostatLogitsWarper(LogitsWarper):
|
||||||
|
def __init__(self, mirostat_mode: int, mirostat_tau: float, mirostat_eta: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||||
|
if mirostat_mode not in [2]:
|
||||||
|
raise ValueError(f"`mirostat` has to be a an integer 2, but is {mirostat_mode}")
|
||||||
|
self.mirostat_mode = mirostat_mode
|
||||||
|
self.mirostat_eta = mirostat_eta
|
||||||
|
self.mirostat_tau = mirostat_tau
|
||||||
|
self.filter_value = filter_value
|
||||||
|
self.min_tokens_to_keep = min_tokens_to_keep
|
||||||
|
self.mu = 2 * self.mirostat_tau
|
||||||
|
self.e = 0
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
logits = scores[0]
|
||||||
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||||
|
prob_original = torch.softmax(sorted_logits, dim=-1).tolist() # candidates
|
||||||
|
|
||||||
|
# Truncate the words with surprise values greater than mu
|
||||||
|
for i, candidate in enumerate(prob_original):
|
||||||
|
if candidate > 0 and -math.log2(candidate) > self.mu:
|
||||||
|
if (i == 0):
|
||||||
|
sorted_logits = sorted_logits[:1]
|
||||||
|
else:
|
||||||
|
sorted_logits = sorted_logits[:i]
|
||||||
|
break
|
||||||
|
|
||||||
|
# Normalize the probabilities of the remaining words
|
||||||
|
prob_topk = torch.softmax(sorted_logits, dim=0)
|
||||||
|
|
||||||
|
prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to('cuda')
|
||||||
|
|
||||||
|
observed_surprise = -math.log2(prob_topk[prev_i])
|
||||||
|
self.e = observed_surprise - self.mirostat_tau
|
||||||
|
|
||||||
|
# Update mu using the learning rate and error
|
||||||
|
self.mu -= self.mirostat_eta * self.e
|
||||||
|
|
||||||
|
sorted_indices_to_remove = torch.ones_like(scores[0], dtype=torch.bool)
|
||||||
|
sorted_indices_to_remove[prev_i] = False
|
||||||
|
|
||||||
|
indices_to_remove = sorted_indices_to_remove.unsqueeze(0).scatter(1, sorted_indices.unsqueeze(0), sorted_indices_to_remove.unsqueeze(0))
|
||||||
|
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
def get_logits_warper_patch(self, generation_config):
|
def get_logits_warper_patch(self, generation_config):
|
||||||
warpers = self._get_logits_warper_old(generation_config)
|
warpers = self._get_logits_warper_old(generation_config)
|
||||||
warpers_to_add = LogitsProcessorList()
|
warpers_to_add = LogitsProcessorList()
|
||||||
min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
|
min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
|
||||||
|
|
||||||
if generation_config.tfs is not None and 0.0 <= generation_config.tfs <= 1.0:
|
if generation_config.mirostat_mode is not None and generation_config.mirostat_mode == 2:
|
||||||
warpers_to_add.append(TailFreeLogitsWarper(tfs=generation_config.tfs, min_tokens_to_keep=min_tokens_to_keep))
|
warpers_to_add.append(MirostatLogitsWarper(mirostat_mode=generation_config.mirostat_mode, mirostat_eta=generation_config.mirostat_eta, mirostat_tau=generation_config.mirostat_tau, min_tokens_to_keep=min_tokens_to_keep))
|
||||||
if generation_config.top_a is not None and 0.0 <= generation_config.top_a <= 1.0:
|
# We need to disable samplers other than temperature
|
||||||
warpers_to_add.append(TopALogitsWarper(top_a=generation_config.top_a, min_tokens_to_keep=min_tokens_to_keep))
|
for warper in warpers:
|
||||||
|
if not isinstance(warper, TemperatureLogitsWarper):
|
||||||
|
warpers.remove(warper)
|
||||||
|
else:
|
||||||
|
if generation_config.tfs is not None and 0.0 <= generation_config.tfs <= 1.0:
|
||||||
|
warpers_to_add.append(TailFreeLogitsWarper(tfs=generation_config.tfs, min_tokens_to_keep=min_tokens_to_keep))
|
||||||
|
if generation_config.top_a is not None and 0.0 <= generation_config.top_a <= 1.0:
|
||||||
|
warpers_to_add.append(TopALogitsWarper(top_a=generation_config.top_a, min_tokens_to_keep=min_tokens_to_keep))
|
||||||
|
|
||||||
if warpers and isinstance(warpers[-1], LogitNormalization):
|
if warpers and isinstance(warpers[-1], LogitNormalization):
|
||||||
warpers = warpers[:-1] + warpers_to_add + [warpers[-1]]
|
warpers = warpers[:-1] + warpers_to_add + [warpers[-1]]
|
||||||
@ -92,6 +148,9 @@ def generation_config_init_patch(self, **kwargs):
|
|||||||
self.__init___old(**kwargs)
|
self.__init___old(**kwargs)
|
||||||
self.tfs = kwargs.pop("tfs", 1.0)
|
self.tfs = kwargs.pop("tfs", 1.0)
|
||||||
self.top_a = kwargs.pop("top_a", 0.0)
|
self.top_a = kwargs.pop("top_a", 0.0)
|
||||||
|
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
|
||||||
|
self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
|
||||||
|
self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
|
||||||
|
|
||||||
|
|
||||||
def hijack_samplers():
|
def hijack_samplers():
|
||||||
|
@ -193,7 +193,7 @@ def _generate_reply(question, state, eos_token=None, stopping_strings=None, is_c
|
|||||||
|
|
||||||
def generate_reply_HF(question, original_question, seed, state, eos_token=None, stopping_strings=None, is_chat=False):
|
def generate_reply_HF(question, original_question, seed, state, eos_token=None, stopping_strings=None, is_chat=False):
|
||||||
generate_params = {}
|
generate_params = {}
|
||||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a']:
|
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
|
||||||
generate_params[k] = state[k]
|
generate_params[k] = state[k]
|
||||||
|
|
||||||
for k in ['epsilon_cutoff', 'eta_cutoff']:
|
for k in ['epsilon_cutoff', 'eta_cutoff']:
|
||||||
|
@ -504,7 +504,7 @@ def create_settings_menus(default_preset):
|
|||||||
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
gr.Markdown('Mirostat (for llama.cpp)')
|
gr.Markdown('Mirostat (mode=1 is only for llama.cpp)')
|
||||||
shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode')
|
shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode')
|
||||||
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
|
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
|
||||||
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
|
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
|
||||||
|
Loading…
Reference in New Issue
Block a user