# Mirror of https://github.com/oobabooga/text-generation-webui.git
# (synced 2024-12-28 15:18:33 +01:00; 103 lines, 4.4 KiB, Python)
|
import torch
|
||
|
import transformers
|
||
|
from transformers import LogitsWarper
|
||
|
from transformers.generation.logits_process import LogitNormalization, LogitsProcessorList
|
||
|
|
||
|
|
||
|
class TailFreeLogitsWarper(LogitsWarper):
    """Tail-free sampling (TFS) logits warper.

    Sorts tokens by probability and takes the absolute second finite difference
    of the sorted probabilities. It normalizes that difference to sum to 1 per
    row and masks (with ``filter_value``) every token whose cumulative value
    exceeds ``tfs``.

    Args:
        tfs: Cutoff in ``[0, 1]``; converted to ``float``.
        filter_value: Value written into masked logits (default ``-inf``).
        min_tokens_to_keep: When greater than 1, this many top-ranked tokens
            are never masked.

    Raises:
        ValueError: If ``tfs`` is below 0 or above 1.
    """

    def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        tfs = float(tfs)
        # Kept as two separate comparisons (not a chained range test) so that
        # NaN behaves exactly as before: it is not rejected here.
        if tfs < 0 or tfs > 1.0:
            raise ValueError(f"`tfs` has to be a float >= 0 and <= 1, but is {tfs}")

        self.tfs = tfs
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # NOTE(review): scores is expected to be 2-D (batch, vocab), as implied by
        # scores.shape[0] below and the scatter along dim 1 — confirm with callers.
        batch_size = scores.shape[0]
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        sorted_probs = sorted_logits.softmax(dim=-1)

        # Absolute second derivative of the sorted probabilities, normalized per
        # row and accumulated into a CDF.
        curvature = sorted_probs.diff().diff().abs()
        curvature = curvature / curvature.sum(dim=-1, keepdim=True)
        cumulative = curvature.cumsum(dim=-1)

        # The second difference is two entries shorter than the vocabulary. Pad it
        # as in the original TFS algorithm: the top token is never marked and the
        # least-likely token is always marked for removal.
        keep_first = torch.zeros(batch_size, 1, dtype=torch.bool, device=scores.device)
        drop_last = torch.ones(batch_size, 1, dtype=torch.bool, device=scores.device)
        sorted_mask = torch.cat((keep_first, cumulative > self.tfs, drop_last), dim=-1)

        if self.min_tokens_to_keep > 1:
            # Never mask the first min_tokens_to_keep ranked tokens.
            sorted_mask[..., : self.min_tokens_to_keep] = 0

        # Map the mask from sorted order back to vocabulary order.
        mask = sorted_mask.scatter(1, sorted_indices, sorted_mask)
        return scores.masked_fill(mask, self.filter_value)
|
||
|
|
||
|
|
||
|
class TopALogitsWarper(LogitsWarper):
    """Top-A logits warper.

    Masks (with ``filter_value``) every token whose probability is below
    ``top_a * p_max * p_max``, where ``p_max`` is the highest probability in
    the row.

    Args:
        top_a: Scale factor in ``[0, 1]``; converted to ``float``.
        filter_value: Value written into masked logits (default ``-inf``).
        min_tokens_to_keep: When greater than 1, this many top-ranked tokens
            are never masked.

    Raises:
        ValueError: If ``top_a`` is below 0 or above 1.
    """

    def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        top_a = float(top_a)
        # Two separate comparisons (not a chained range test) keep NaN handling unchanged.
        out_of_range = top_a < 0 or top_a > 1.0
        if out_of_range:
            raise ValueError(f"`top_a` has to be a float >= 0 and <= 1, but is {top_a}")

        self.top_a = top_a
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        sorted_probs = sorted_logits.softmax(dim=-1)

        # After a descending sort, column 0 holds each row's maximum probability.
        # Indexing with None keeps it as a trailing size-1 dim for broadcasting.
        top_prob = sorted_probs[..., 0, None]
        cutoff = top_prob * top_prob * self.top_a
        sorted_mask = sorted_probs < cutoff

        if self.min_tokens_to_keep > 1:
            # Never mask the first min_tokens_to_keep ranked tokens.
            sorted_mask[..., : self.min_tokens_to_keep] = 0

        # Map the mask from sorted order back to vocabulary order (dim 1).
        mask = sorted_mask.scatter(1, sorted_indices, sorted_mask)
        return scores.masked_fill(mask, self.filter_value)
|
||
|
|
||
|
|
||
|
def get_logits_warper_patch(self, generation_config):
    """Replacement for ``GenerationMixin._get_logits_warper`` (installed by ``hijack_samplers``).

    Builds the stock warper list through the saved original method. It then
    appends ``TailFreeLogitsWarper`` and ``TopALogitsWarper`` when their
    settings are active. If the stock list ends with a ``LogitNormalization``,
    the new warpers are inserted before it so normalization still runs last.

    Args:
        generation_config: Generation settings; ``num_beams``, ``tfs`` and
            ``top_a`` are read from it. A missing ``tfs``/``top_a`` (e.g. a
            config created before the patch) is treated as disabled.

    Returns:
        A ``LogitsProcessorList`` (always callable, never a plain ``list``).
    """
    warpers = self._get_logits_warper_old(generation_config)
    warpers_to_add = LogitsProcessorList()
    min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1

    tfs = getattr(generation_config, "tfs", None)
    top_a = getattr(generation_config, "top_a", None)

    # tfs == 1.0 is the "disabled" default. TailFreeLogitsWarper always masks the
    # least-likely token, so installing it at 1.0 would still alter sampling.
    if tfs is not None and 0.0 <= tfs < 1.0:
        warpers_to_add.append(TailFreeLogitsWarper(tfs=tfs, min_tokens_to_keep=min_tokens_to_keep))
    # top_a == 0.0 is the "disabled" default and would filter nothing.
    if top_a is not None and 0.0 < top_a <= 1.0:
        warpers_to_add.append(TopALogitsWarper(top_a=top_a, min_tokens_to_keep=min_tokens_to_keep))

    if warpers and isinstance(warpers[-1], LogitNormalization):
        # Slicing a list subclass returns a plain (non-callable) list, so rebuild
        # the LogitsProcessorList that callers invoke as a function.
        warpers = LogitsProcessorList([*warpers[:-1], *warpers_to_add, warpers[-1]])
    else:
        warpers += warpers_to_add

    return warpers
|
||
|
|
||
|
|
||
|
def generation_config_init_patch(self, **kwargs):
    """Replacement for ``GenerationConfig.__init__`` that also records ``tfs`` and ``top_a``.

    All keyword arguments (including ``tfs``/``top_a``) are forwarded unchanged
    to the saved original initializer. Afterwards ``self.tfs`` (default 1.0) and
    ``self.top_a`` (default 0.0) are set from the same keyword arguments.
    """
    self.__init___old(**kwargs)

    # kwargs is not used after this point, so reading (rather than popping) is equivalent.
    self.tfs = kwargs.get("tfs", 1.0)
    self.top_a = kwargs.get("top_a", 0.0)
|
||
|
|
||
|
|
||
|
def hijack_samplers():
    """Monkey-patch transformers so generation honours ``tfs`` and ``top_a``.

    Replaces ``GenerationMixin._get_logits_warper`` with
    ``get_logits_warper_patch`` and ``GenerationConfig.__init__`` with
    ``generation_config_init_patch``. The originals are stashed as
    ``_get_logits_warper_old`` / ``__init___old`` for the patches to call.

    Safe to call more than once: the originals are saved only on the first
    call. Without this guard, a second call would overwrite the stash with the
    patched functions and they would recurse into themselves forever.
    """
    mixin = transformers.GenerationMixin
    if not hasattr(mixin, "_get_logits_warper_old"):
        mixin._get_logits_warper_old = mixin._get_logits_warper
    mixin._get_logits_warper = get_logits_warper_patch

    config_cls = transformers.GenerationConfig
    if not hasattr(config_cls, "__init___old"):
        config_cls.__init___old = config_cls.__init__
    config_cls.__init__ = generation_config_init_patch
|