diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py
index f02bea4f..447c8782 100644
--- a/modules/sampler_hijack.py
+++ b/modules/sampler_hijack.py
@@ -1,7 +1,11 @@
+import math
+
 import torch
 import transformers
 from transformers import LogitsWarper
-from transformers.generation.logits_process import LogitNormalization, LogitsProcessorList
+from transformers.generation.logits_process import (LogitNormalization,
+                                                    LogitsProcessorList,
+                                                    TemperatureLogitsWarper)
 
 
 class TailFreeLogitsWarper(LogitsWarper):
@@ -70,15 +74,67 @@ class TopALogitsWarper(LogitsWarper):
         return scores
 
 
+class MirostatLogitsWarper(LogitsWarper):
+    def __init__(self, mirostat_mode: int, mirostat_tau: float, mirostat_eta: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
+        if mirostat_mode not in [2]:
+            raise ValueError(f"`mirostat` has to be the integer 2, but is {mirostat_mode}")
+        self.mirostat_mode = mirostat_mode
+        self.mirostat_eta = mirostat_eta
+        self.mirostat_tau = mirostat_tau
+        self.filter_value = filter_value
+        self.min_tokens_to_keep = min_tokens_to_keep
+        self.mu = 2 * self.mirostat_tau
+        self.e = 0
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        logits = scores[0]
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        prob_original = torch.softmax(sorted_logits, dim=-1).tolist()  # candidates
+
+        # Truncate the words with surprise values greater than mu
+        for i, candidate in enumerate(prob_original):
+            if candidate > 0 and -math.log2(candidate) > self.mu:
+                if i == 0:
+                    sorted_logits = sorted_logits[:1]
+                else:
+                    sorted_logits = sorted_logits[:i]
+                break
+
+        # Normalize the probabilities of the remaining words
+        prob_topk = torch.softmax(sorted_logits, dim=0)
+
+        prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to(logits.device)
+
+        observed_surprise = -math.log2(prob_topk[prev_i])
+        self.e = observed_surprise - self.mirostat_tau
+
+        # Update mu using the learning rate and error
+        self.mu -= self.mirostat_eta * self.e
+
+        sorted_indices_to_remove = torch.ones_like(scores[0], dtype=torch.bool)
+        sorted_indices_to_remove[prev_i] = False
+
+        indices_to_remove = sorted_indices_to_remove.unsqueeze(0).scatter(1, sorted_indices.unsqueeze(0), sorted_indices_to_remove.unsqueeze(0))
+        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+        return scores
+
+
 def get_logits_warper_patch(self, generation_config):
     warpers = self._get_logits_warper_old(generation_config)
     warpers_to_add = LogitsProcessorList()
     min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
-    if generation_config.tfs is not None and 0.0 <= generation_config.tfs <= 1.0:
-        warpers_to_add.append(TailFreeLogitsWarper(tfs=generation_config.tfs, min_tokens_to_keep=min_tokens_to_keep))
-    if generation_config.top_a is not None and 0.0 <= generation_config.top_a <= 1.0:
-        warpers_to_add.append(TopALogitsWarper(top_a=generation_config.top_a, min_tokens_to_keep=min_tokens_to_keep))
+    if generation_config.mirostat_mode is not None and generation_config.mirostat_mode == 2:
+        warpers_to_add.append(MirostatLogitsWarper(mirostat_mode=generation_config.mirostat_mode, mirostat_eta=generation_config.mirostat_eta, mirostat_tau=generation_config.mirostat_tau, min_tokens_to_keep=min_tokens_to_keep))
+        # We need to disable samplers other than temperature
+        for warper in warpers[:]:
+            if not isinstance(warper, TemperatureLogitsWarper):
+                warpers.remove(warper)
+    else:
+        if generation_config.tfs is not None and 0.0 <= generation_config.tfs <= 1.0:
+            warpers_to_add.append(TailFreeLogitsWarper(tfs=generation_config.tfs, min_tokens_to_keep=min_tokens_to_keep))
+        if generation_config.top_a is not None and 0.0 <= generation_config.top_a <= 1.0:
+            warpers_to_add.append(TopALogitsWarper(top_a=generation_config.top_a, min_tokens_to_keep=min_tokens_to_keep))
 
     if warpers and isinstance(warpers[-1], LogitNormalization):
         warpers = warpers[:-1] + warpers_to_add + [warpers[-1]]
     else:
@@ -92,6 +148,9 @@ def generation_config_init_patch(self, **kwargs):
     self.__init___old(**kwargs)
     self.tfs = kwargs.pop("tfs", 1.0)
     self.top_a = kwargs.pop("top_a", 0.0)
+    self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
+    self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
+    self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
 
 
 def hijack_samplers():
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 00b7cc7b..2dd7d38a 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -193,7 +193,7 @@ def _generate_reply(question, state, eos_token=None, stopping_strings=None, is_c
 
 def generate_reply_HF(question, original_question, seed, state, eos_token=None, stopping_strings=None, is_chat=False):
     generate_params = {}
-    for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a']:
+    for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
         generate_params[k] = state[k]
 
     for k in ['epsilon_cutoff', 'eta_cutoff']:
diff --git a/server.py b/server.py
index 25dbae8e..a507cf77 100644
--- a/server.py
+++ b/server.py
@@ -504,7 +504,7 @@ def create_settings_menus(default_preset):
                                     shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
 
                                 with gr.Column():
-                                    gr.Markdown('Mirostat (for llama.cpp)')
+                                    gr.Markdown('Mirostat (mode=1 is only for llama.cpp)')
                                     shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode')
                                     shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
                                     shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
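
Note for reviewers: the standalone sketch below reproduces, outside the transformers plumbing, the truncate-sample-update step that the new MirostatLogitsWarper performs for mode 2. It is illustrative only; the helper name mirostat_step, the toy logits, and the driver loop are not part of the patch, although tau=5, eta=0.1, and the initial mu=2*tau mirror the defaults the patch adds.

# Illustrative sketch of the Mirostat (mode 2) step; names and toy data are hypothetical.
import math

import torch


def mirostat_step(logits: torch.Tensor, mu: float, tau: float, eta: float):
    """Pick one token with Mirostat mode 2 and return (token_index, updated_mu)."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    probs = torch.softmax(sorted_logits, dim=-1)

    # Keep only the candidates whose surprise -log2(p) does not exceed mu (at least one).
    surprises = -torch.log2(probs)
    keep = max(1, int((surprises <= mu).sum().item()))

    # Renormalize the kept candidates and sample one of them.
    kept_probs = torch.softmax(sorted_logits[:keep], dim=-1)
    choice = torch.multinomial(kept_probs, num_samples=1).item()

    # Move mu toward the target surprise tau with learning rate eta.
    observed_surprise = -math.log2(kept_probs[choice].item())
    mu -= eta * (observed_surprise - tau)
    return sorted_indices[choice].item(), mu


tau, eta = 5.0, 0.1        # defaults introduced by the patch
mu = 2 * tau               # initial mu, as in MirostatLogitsWarper.__init__
logits = torch.randn(50)   # toy "vocabulary" of 50 tokens
for _ in range(5):
    token, mu = mirostat_step(logits, mu, tau, eta)
    print(token, round(mu, 3))

Unlike ordinary warpers, this sampler is stateful (mu carries over between calls), which is why the patch also strips every warper other than temperature when mirostat_mode == 2.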