2024-05-20 08:23:47 +05:30
|
|
|
import json
|
2023-06-09 19:26:31 -05:00
|
|
|
import math
|
2024-02-06 11:20:10 -03:00
|
|
|
import pprint
|
2024-09-28 07:20:12 +05:30
|
|
|
import random
|
2023-06-09 19:26:31 -05:00
|
|
|
|
2023-05-30 08:40:01 +08:00
|
|
|
import torch
|
|
|
|
import transformers
|
2023-10-27 08:09:51 +05:30
|
|
|
from transformers import LogitsWarper, is_torch_xpu_available
|
2023-06-25 01:44:36 -03:00
|
|
|
from transformers.generation.logits_process import (
|
|
|
|
LogitNormalization,
|
2023-06-29 13:40:13 -03:00
|
|
|
LogitsProcessor,
|
2024-02-06 11:20:10 -03:00
|
|
|
LogitsProcessorList
|
2023-06-25 01:44:36 -03:00
|
|
|
)
|
2023-05-30 08:40:01 +08:00
|
|
|
|
2024-01-07 07:36:26 -06:00
|
|
|
from modules import shared
|
2024-02-06 11:20:10 -03:00
|
|
|
from modules.logging_colors import logger
|
2024-01-07 07:36:26 -06:00
|
|
|
|
2023-08-22 20:18:16 -07:00
|
|
|
# Most recent `scores` tensor seen by SpyLogitsWarper (assigned in its __call__); None until the first call.
global_scores = None
|
|
|
|
|
2023-11-04 13:09:07 -03:00
|
|
|
|
2024-02-06 11:20:10 -03:00
|
|
|
class TemperatureLogitsWarperCustom(LogitsWarper):
    '''
    A copy of the original Transformers temperature logits warper.

    Divides every logit by a fixed temperature, which must be a float strictly
    greater than zero.
    '''

    def __init__(self, temperature: float):
        # Accept only a float strictly above zero; ints, 0.0, negatives and NaN
        # all fall through to the error path below.
        if isinstance(temperature, float) and temperature > 0:
            self.temperature = temperature
            return

        message = (
            f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
            "scores will be invalid."
        )
        if isinstance(temperature, float) and temperature == 0.0:
            message += " If you're looking for greedy decoding strategies, set `do_sample=False`."
        raise ValueError(message)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        '''Return a new tensor equal to `scores` divided by the temperature.'''
        return scores / self.temperature
|
|
|
|
|
|
|
|
|
|
|
|
class DynamicTemperatureLogitsWarper(LogitsWarper):
    '''
    Dynamic temperature.

    For every row of `scores`, the entropy of the softmax distribution is divided
    by the largest entropy possible over that row's non-masked tokens (the log of
    the number of logits that are not -inf). That normalized entropy is raised to
    `dynatemp_exponent` and mapped linearly onto [dynatemp_low, dynatemp_high].
    The row's logits are then divided by the resulting temperature.
    '''

    def __init__(self, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float):
        self.dynatemp_low = dynatemp_low
        self.dynatemp_high = dynatemp_high
        self.dynatemp_exponent = dynatemp_exponent

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        min_temp = self.dynatemp_low
        max_temp = self.dynatemp_high
        exponent_val = self.dynatemp_exponent

        # Convert logits to probabilities
        probs = torch.softmax(scores, dim=-1)

        # Entropy of each row (0 * log(0) terms count as 0). Reducing only the last
        # dimension keeps rows of a batch independent of each other.
        entropy = -torch.where(probs > 0, probs * torch.log(probs), torch.zeros_like(probs)).sum(dim=-1, keepdim=True)

        # Keep the entropy slightly greater than 0.
        entropy = entropy.clamp(min=1e-10)

        # Any logits which are not -Infinity are considered when computing the max entropy.
        num_valid_tokens = (scores > -float('inf')).sum(dim=-1, keepdim=True)

        # Max entropy from the valid-token count. Clamping guards the division below:
        # one valid token gives log(1) = 0, and a fully masked row gives log(0) = -inf.
        max_entropy = torch.log(num_valid_tokens.to(torch.float32)).clamp(min=1e-10)

        # Normalize the entropy
        normalized_entropy = entropy / max_entropy

        # Map the normalized entropy to the desired temperature range using the power function
        dyn_temp = min_temp + (max_temp - min_temp) * normalized_entropy.pow(exponent_val)

        # Apply one temperature per row; cast so the output keeps the dtype of `scores`.
        return scores / dyn_temp.to(scores.dtype)
|
2024-01-07 07:36:26 -06:00
|
|
|
|
|
|
|
|
2024-02-06 11:20:10 -03:00
|
|
|
class QuadraticSamplingLogitsWarper(LogitsWarper):
    '''
    Quadratic sampling with smoothing factor and smoothing curve parameters.

    Every logit x that is not -inf is replaced by

        max - k * f * d**2 + s * f * d**3,   where d = x - max,

    and where max is the largest logit of the same row, f = smoothing_factor,
    k = (3 - smoothing_curve) / 2 and s = (smoothing_curve - 1) / 2.
    Logits equal to -inf (already filtered out) are left untouched.
    '''

    def __init__(self, smoothing_factor, smoothing_curve):
        self.smoothing_factor = smoothing_factor
        self.smoothing_curve = smoothing_curve

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Per-row maximum so rows of a batch do not affect each other. For a single
        # row this equals scores.max().
        max_logit = scores.max(dim=-1, keepdim=True).values
        diff = scores - max_logit
        k = (3 - self.smoothing_curve) / 2
        s = (self.smoothing_curve - 1) / 2

        transformed_logits = -(k * self.smoothing_factor * diff**2) + (s * self.smoothing_factor * diff**3) + max_logit

        # Only transform logits that are not -inf; masked tokens stay masked.
        return torch.where(scores != float('-inf'), transformed_logits, scores)
|
2024-02-03 21:20:02 -06:00
|
|
|
|
2024-01-07 07:36:26 -06:00
|
|
|
|
2023-05-30 08:40:01 +08:00
|
|
|
class TailFreeLogitsWarper(LogitsWarper):
    '''
    Tail free sampling (TFS).

    The logits are sorted. The absolute second difference of the sorted
    probabilities is normalized into a CDF, and tokens whose CDF value exceeds
    `tfs` are masked with `filter_value`. The top token is never masked and the
    last sorted token always is.
    '''

    def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        tfs = float(tfs)
        if tfs > 1.0 or tfs < 0:
            raise ValueError(f"`tfs` has to be a float >= 0 and <= 1, but is {tfs}")

        self.tfs = tfs
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        logits_desc, order = torch.sort(scores, descending=True)
        probs_desc = logits_desc.softmax(dim=-1)

        # Absolute second derivative of the sorted probabilities, turned into a CDF.
        curvature = probs_desc.diff().diff().abs()
        curvature_cdf = (curvature / curvature.sum(dim=-1, keepdim=True)).cumsum(dim=-1)

        # The second difference is two entries shorter than the row. Pad it back to
        # full width, which centres the cut-off as in the original algorithm:
        # keep the head, and always drop the tail.
        rows = scores.shape[0]
        keep_head = torch.zeros(rows, 1, dtype=torch.bool, device=scores.device)
        drop_tail = torch.ones(rows, 1, dtype=torch.bool, device=scores.device)
        remove_sorted = torch.cat((keep_head, curvature_cdf > self.tfs, drop_tail), dim=-1)

        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep
            remove_sorted[..., : self.min_tokens_to_keep] = 0

        # Map the mask from sorted order back to vocabulary order.
        remove = remove_sorted.scatter(1, order, remove_sorted)
        return scores.masked_fill(remove, self.filter_value)
|
|
|
|
|
|
|
|
|
|
|
|
class TopALogitsWarper(LogitsWarper):
    '''
    Top-A sampling.

    Masks every token whose probability is below top_a * p_max**2, where p_max
    is the highest probability in the row.
    '''

    def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        top_a = float(top_a)
        if top_a > 1.0 or top_a < 0:
            raise ValueError(f"`top_a` has to be a float >= 0 and <= 1, but is {top_a}")

        self.top_a = top_a
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        logits_desc, order = torch.sort(scores, descending=True)
        probs_desc = logits_desc.softmax(dim=-1)

        # The first sorted entry holds each row's highest probability.
        p_max = probs_desc[..., :1]
        cutoff = p_max * p_max * self.top_a
        remove_sorted = probs_desc < cutoff

        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep
            remove_sorted[..., : self.min_tokens_to_keep] = 0

        # Map the mask from sorted order back to vocabulary order.
        remove = remove_sorted.scatter(1, order, remove_sorted)
        return scores.masked_fill(remove, self.filter_value)
|
|
|
|
|
|
|
|
|
2024-09-28 07:20:12 +05:30
|
|
|
# Exclude Top Choices (XTC)
class XTCLogitsWarper(LogitsWarper):
    '''
    Exclude Top Choices (XTC).

    With probability `probability`, masks every token whose probability is at
    least `threshold`, except the least likely of those tokens. The step is
    skipped entirely when it would mask the newline token or the EOS token.
    '''

    def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")):
        self.threshold = threshold
        self.probability = probability
        self.filter_value = filter_value

        # Token ids XTC must never remove: the last id of encode("\n"), plus EOS when defined.
        newline_id = shared.tokenizer.encode("\n")[-1]
        eos_id = shared.tokenizer.eos_token_id
        self.special_token_ids = [newline_id] if eos_id is None else [newline_id, eos_id]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # random.random() lies in [0, 1). A probability of 0 therefore never fires,
        # and a probability of 1 always does. Drawing first is equivalent to
        # checking the threshold first, but lets us skip the sort when XTC does not fire.
        if random.random() >= self.probability:
            return scores

        logits_desc, order = torch.sort(scores, descending=True)
        probs_desc = logits_desc.softmax(dim=-1)

        # Mark sorted position i whenever position i + 1 still meets the threshold.
        # Because the probabilities are sorted, this selects every token at or above
        # the threshold except the least probable one.
        remove_sorted = torch.zeros_like(probs_desc, dtype=torch.bool)
        remove_sorted[..., :-1] = probs_desc[..., 1:] >= self.threshold

        # Convert the mask back to original (vocabulary) indices.
        remove = remove_sorted.scatter(1, order, remove_sorted)

        # Leave the scores untouched if newline or EOS would be removed.
        if remove[:, self.special_token_ids].any():
            return scores

        return scores.masked_fill(remove, self.filter_value)
|
|
|
|
|
|
|
|
|
2024-05-20 08:23:47 +05:30
|
|
|
class DRYLogitsProcessor(LogitsProcessor):
    '''
    DRY ("Don't Repeat Yourself") repetition penalty.

    For each row, find earlier occurrences of the most recent token. For each
    occurrence, measure how far the tokens before it match the tokens before the
    end of the context. The token that followed an occurrence is penalized by
    `multiplier * base ** (match_length - allowed_length)` when its longest
    match is at least `allowed_length`. Tokens in `sequence_breakers` stop the
    matching, and the penalty is applied in place to `scores`.
    '''

    # Upper bound on a match length. It keeps `base ** (length - allowed_length)`
    # from overflowing and bounds the worst-case backward scan per candidate.
    max_match_length = 50

    def __init__(self, multiplier: float, base: float, allowed_length: int, sequence_breakers: set[int], _range: int):
        self.multiplier = multiplier
        self.base = base
        self.allowed_length = allowed_length
        self.sequence_breakers = sequence_breakers
        self._range = _range

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Only search the last `_range` tokens (0 means the whole context).
        if self._range > 0:
            input_ids = input_ids[:, -self._range:]

        for input_ids_row, scores_row in zip(input_ids, scores):
            # Use normal Python data types for improved performance. Bind them to a
            # new name rather than rebinding the `input_ids` parameter.
            tokens = input_ids_row.tolist()
            if not tokens:
                continue

            last_token = tokens[-1]
            if last_token in self.sequence_breakers:
                continue

            # Maximum matching sequence length for each token that immediately
            # followed an earlier occurrence of the matched sequence.
            match_lengths = {}

            # Exclude the last token itself, as it always matches.
            for i, token in enumerate(tokens[:-1]):
                if token != last_token:
                    continue

                next_token = tokens[i + 1]
                if next_token in self.sequence_breakers:
                    continue

                match_length = self._match_length(tokens, i)
                match_lengths[next_token] = max(match_length, match_lengths.get(next_token, 0))

            # Apply penalties (scores_row is a row of `scores`, modified in place).
            for token, match_length in match_lengths.items():
                if match_length >= self.allowed_length:
                    penalty = self.multiplier * self.base ** (match_length - self.allowed_length)
                    scores_row[token] -= penalty

        return scores

    def _match_length(self, tokens, i):
        '''Return how many tokens ending at index i match the tokens ending at the last index (at least 1).'''
        # The caller guarantees tokens[i] == tokens[-1], so the match starts at length 1.
        match_length = 1
        while match_length < self.max_match_length:
            j = i - match_length
            if j < 0:
                # Start of input reached.
                break

            previous_token = tokens[-(match_length + 1)]
            if tokens[j] != previous_token:
                # Start of match reached.
                break

            if previous_token in self.sequence_breakers:
                # Sequence-breaking token reached.
                break

            match_length += 1

        return match_length
|
|
|
|
|
|
|
|
|
2023-06-09 19:26:31 -05:00
|
|
|
class MirostatLogitsWarper(LogitsWarper):
    '''
    Mirostat v2 sampler.

    Keeps a running surprise target `mu`, initialized to 2 * tau. On each call:

    1. Drop sorted candidates whose surprise (-log2 p) exceeds `mu`, always
       keeping at least the most likely one.
    2. Sample one token from the remaining candidates.
    3. Mask every other token.
    4. Update `mu -= eta * (observed_surprise - tau)`.

    The token is chosen from the first row of the batch only, and the resulting
    mask is applied to every row.
    '''

    def __init__(self, mirostat_mode: int, mirostat_tau: float, mirostat_eta: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if mirostat_mode not in [2]:
            raise ValueError(f"`mirostat` has to be a an integer 2, but is {mirostat_mode}")

        self.mirostat_mode = mirostat_mode
        self.mirostat_eta = mirostat_eta
        self.mirostat_tau = mirostat_tau
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep
        self.mu = 2 * self.mirostat_tau
        self.e = 0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        logits = scores[0]
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        prob_original = torch.softmax(sorted_logits, dim=-1).tolist()  # candidates

        # Truncate the words with surprise values greater than mu (always keep at least one).
        for i, candidate in enumerate(prob_original):
            if candidate > 0 and -math.log2(candidate) > self.mu:
                sorted_logits = sorted_logits[:max(i, 1)]
                break

        # Normalize the probabilities of the remaining words and sample one of them.
        # Everything stays on the device `scores` lives on, instead of a hard-coded
        # 'cuda'/'xpu' target, so CPU, MPS, CUDA and XPU all work.
        prob_topk = torch.softmax(sorted_logits, dim=0)
        prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True)

        observed_surprise = -math.log2(prob_topk[prev_i].item())
        self.e = observed_surprise - self.mirostat_tau

        # Update mu using the learning rate and error
        self.mu -= self.mirostat_eta * self.e

        # Mask everything except the sampled position, then map back to vocabulary order.
        sorted_indices_to_remove = torch.ones_like(scores[0], dtype=torch.bool)
        sorted_indices_to_remove[prev_i] = False

        indices_to_remove = sorted_indices_to_remove.unsqueeze(0).scatter(1, sorted_indices.unsqueeze(0), sorted_indices_to_remove.unsqueeze(0))
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
|
|
|
|
|
|
|
|
|
2023-08-22 20:18:16 -07:00
|
|
|
class SpyLogitsWarper(LogitsWarper):
    '''
    Pass-through processor that records the scores it sees.

    Each call stores the incoming `scores` tensor (the same object, not a copy)
    in the module-level `global_scores` and returns it unchanged.
    '''

    def __init__(self):
        '''Nothing to configure.'''

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        global global_scores
        global_scores = scores
        return scores
|
|
|
|
|
|
|
|
|
2023-06-29 13:40:13 -03:00
|
|
|
class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
    '''
    Multiplicative repetition penalty limited to the last `_range` tokens.

    Every token id present in that window has its logit divided by `penalty`
    when the logit is non-negative, and multiplied by `penalty` when it is
    negative. A `_range` of 0 slices `input_ids[:, 0:]`, i.e. the whole context.
    '''

    def __init__(self, penalty: float, _range: int):
        if penalty > 0:
            self.penalty = penalty
            self._range = _range
            return

        raise ValueError(f"`penalty` has to be strictly positive, but is {penalty}")

    def apply_repetition_penalty(self, input_ids_row, scores_row):
        '''Penalize, in place, every token id occurring in `input_ids_row`; return `scores_row`.'''
        seen_ids = torch.unique(input_ids_row)
        current = scores_row.gather(0, seen_ids)

        # Apply multiplicative repetition penalty
        penalized = torch.where(current < 0, current * self.penalty, current / self.penalty)
        scores_row.scatter_(0, seen_ids, penalized)
        return scores_row

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        window = input_ids[:, -self._range:]
        for ids_row, logits_row in zip(window, scores):
            # Each logits_row is a row of `scores`, updated in place.
            self.apply_repetition_penalty(ids_row, logits_row)

        return scores
|
|
|
|
|
2023-06-29 13:40:13 -03:00
|
|
|
|
2024-09-27 19:03:25 -03:00
|
|
|
class PresencePenaltyLogitsProcessor(LogitsProcessor):
    '''
    Additive presence penalty over the last `_range` tokens.

    Subtracts `presence_penalty` once from the logit of every distinct token id
    present in the window, however often it occurs.
    '''

    def __init__(self, presence_penalty: float, _range: int):
        self.presence_penalty = presence_penalty
        self._range = _range

    def apply_presence_penalty(self, input_ids_row, scores_row):
        '''Subtract the presence penalty, in place, from each distinct id in `input_ids_row`; return `scores_row`.'''
        unique_ids, counts = torch.unique(input_ids_row, return_counts=True)

        # Every id returned by torch.unique occurs at least once, so this is a vector of ones.
        occurred = (counts > 0).to(scores_row.dtype)
        scores_row.scatter_add_(0, unique_ids, -(occurred * self.presence_penalty))
        return scores_row

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        window = input_ids[:, -self._range:]
        for ids_row, logits_row in zip(window, scores):
            # Each logits_row is a row of `scores`, updated in place.
            self.apply_presence_penalty(ids_row, logits_row)

        return scores
|
|
|
|
|
|
|
|
|
|
|
|
class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
    '''
    Additive frequency penalty over the last `_range` tokens.

    Subtracts `frequency_penalty * count` from the logit of every token id,
    where count is how many times that id occurs in the window.
    '''

    def __init__(self, frequency_penalty: float, _range: int):
        self.frequency_penalty = frequency_penalty
        self._range = _range

    def apply_frequency_penalty(self, input_ids_row, scores_row):
        '''Subtract count * frequency_penalty, in place, for each id in `input_ids_row`; return `scores_row`.'''
        unique_ids, counts = torch.unique(input_ids_row, return_counts=True)

        # Apply frequency penalty
        penalty = counts.to(scores_row.dtype) * self.frequency_penalty
        scores_row.scatter_add_(0, unique_ids, -penalty)
        return scores_row

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        window = input_ids[:, -self._range:]
        for ids_row, logits_row in zip(window, scores):
            # Each logits_row is a row of `scores`, updated in place.
            self.apply_frequency_penalty(ids_row, logits_row)

        return scores
|
|
|
|
|
|
|
|
|
2024-09-27 19:03:25 -03:00
|
|
|
def _resolve_sampler_priority(generation_config):
    '''
    Return the sampler order to use for this generation.

    The result is a copy of `generation_config.sampler_priority`, so the
    config's own list is never mutated. When `temperature_last` is set,
    'temperature', 'dynamic_temperature' and 'quadratic_sampling' are moved to
    the end in that order, and appended if they were missing.
    '''
    sampler_priority = list(generation_config.sampler_priority)
    if generation_config.temperature_last:
        for param_name in ['temperature', 'dynamic_temperature', 'quadratic_sampling']:
            if param_name in sampler_priority:
                sampler_priority.remove(param_name)
            sampler_priority.append(param_name)

    return sampler_priority


def _parse_dry_sequence_breakers(dry_sequence_breakers):
    '''
    Convert the DRY sequence-breaker setting into a set of token ids.

    Accepts either a JSON array such as '["a", "b"]' or the same contents
    without the surrounding brackets, such as '"a", "b"'.
    '''
    # Support both JSON array notation and comma-separated strings.
    if not dry_sequence_breakers.startswith("["):
        dry_sequence_breakers = "[" + dry_sequence_breakers + "]"

    sequence_breaker_strings = json.loads(dry_sequence_breakers)

    # Prefix with 'a' to get the correct encoding of the token at the end of a text.
    return {shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings}


def get_logits_processor_patch(self, **kwargs):
    '''
    Replacement for `transformers.GenerationMixin._get_logits_processor`.

    Steps:

    1. Build the stock processor list via `self._get_logits_processor_old`.
    2. Swap in `TemperatureLogitsWarperCustom`, and drop the stock
       `SuppressTokensLogitsProcessor` and `RepetitionPenaltyLogitsProcessor`.
    3. Add this module's samplers as requested by `generation_config`.
    4. Order everything by the sampler priority.
    5. Re-append a trailing `LogitNormalization` (if the stock list ended with
       one), followed by a `SpyLogitsWarper`.
    '''
    generation_config = kwargs['generation_config']

    # Parameter sanitization
    if isinstance(generation_config.temperature, int):
        generation_config.temperature = float(generation_config.temperature)  # Must be float

    # Get the original warpers
    warpers = self._get_logits_processor_old(**kwargs)

    # Walk backwards so that `del` does not shift entries still to be visited.
    for i in range(len(warpers) - 1, -1, -1):
        class_name = warpers[i].__class__.__name__

        # Replace temperature with our modified class.
        if class_name == 'TemperatureLogitsWarper':
            warpers[i] = TemperatureLogitsWarperCustom(
                generation_config.temperature,
            )

        # Stuff we don't need
        elif class_name in ['SuppressTokensLogitsProcessor', 'RepetitionPenaltyLogitsProcessor']:
            del warpers[i]

    # Add custom warpers. Every one of them goes into `warpers_to_add`, never into
    # `warpers`, so a trailing LogitNormalization in `warpers` is still found below.
    warpers_to_add = LogitsProcessorList()
    min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1

    if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
        warpers_to_add.append(
            RepetitionPenaltyLogitsProcessorWithRange(
                penalty=generation_config.repetition_penalty,
                _range=generation_config.repetition_penalty_range
            )
        )

    if generation_config.presence_penalty is not None and generation_config.presence_penalty != 0.0:
        warpers_to_add.append(
            PresencePenaltyLogitsProcessor(
                presence_penalty=generation_config.presence_penalty,
                _range=generation_config.repetition_penalty_range
            )
        )

    if generation_config.frequency_penalty is not None and generation_config.frequency_penalty != 0.0:
        warpers_to_add.append(
            FrequencyPenaltyLogitsProcessor(
                frequency_penalty=generation_config.frequency_penalty,
                _range=generation_config.repetition_penalty_range
            )
        )

    if generation_config.dry_multiplier is not None and generation_config.dry_multiplier > 0.0:
        warpers_to_add.append(
            DRYLogitsProcessor(
                multiplier=generation_config.dry_multiplier,
                base=generation_config.dry_base,
                allowed_length=generation_config.dry_allowed_length,
                sequence_breakers=_parse_dry_sequence_breakers(generation_config.dry_sequence_breakers),
                _range=generation_config.repetition_penalty_range,
            )
        )

    if generation_config.tfs is not None and 0.0 <= generation_config.tfs < 1.0:
        warpers_to_add.append(
            TailFreeLogitsWarper(
                tfs=generation_config.tfs,
                min_tokens_to_keep=min_tokens_to_keep
            )
        )

    if generation_config.top_a is not None and 0.0 < generation_config.top_a <= 1.0:
        warpers_to_add.append(
            TopALogitsWarper(
                top_a=generation_config.top_a,
                min_tokens_to_keep=min_tokens_to_keep
            )
        )

    if generation_config.xtc_probability is not None and generation_config.xtc_probability > 0:
        warpers_to_add.append(
            XTCLogitsWarper(
                threshold=generation_config.xtc_threshold,
                probability=generation_config.xtc_probability,
            )
        )

    if generation_config.dynamic_temperature:
        warpers_to_add.append(
            DynamicTemperatureLogitsWarper(
                dynatemp_low=generation_config.dynatemp_low,
                dynatemp_high=generation_config.dynatemp_high,
                dynatemp_exponent=generation_config.dynatemp_exponent,
            )
        )

    if generation_config.smoothing_factor > 0:
        warpers_to_add.append(
            QuadraticSamplingLogitsWarper(
                smoothing_factor=generation_config.smoothing_factor,
                smoothing_curve=generation_config.smoothing_curve
            )
        )

    if generation_config.mirostat_mode is not None and generation_config.mirostat_mode == 2:
        warpers_to_add.append(
            MirostatLogitsWarper(
                mirostat_mode=generation_config.mirostat_mode,
                mirostat_eta=generation_config.mirostat_eta,
                mirostat_tau=generation_config.mirostat_tau,
                min_tokens_to_keep=min_tokens_to_keep
            )
        )

    # Set a trailing LogitNormalization aside so it can be re-appended after sorting.
    if len(warpers) > 0 and isinstance(warpers[-1], LogitNormalization):
        normalize = warpers.pop(-1)
    else:
        normalize = None

    warpers += warpers_to_add

    # Sort the samplers.
    sampler_priority = _resolve_sampler_priority(generation_config)

    class_name_to_nickname = {
        'DynamicTemperatureLogitsWarper': 'dynamic_temperature',
        'EpsilonLogitsWarper': 'epsilon_cutoff',
        'EtaLogitsWarper': 'eta_cutoff',
        'MinPLogitsWarper': 'min_p',
        'MirostatLogitsWarper': 'mirostat',
        'QuadraticSamplingLogitsWarper': 'quadratic_sampling',
        'TailFreeLogitsWarper': 'tfs',
        'TemperatureLogitsWarperCustom': 'temperature',
        'TopALogitsWarper': 'top_a',
        'TopKLogitsWarper': 'top_k',
        'TopPLogitsWarper': 'top_p',
        'TypicalLogitsWarper': 'typical_p',
        'XTCLogitsWarper': 'xtc',
        'RepetitionPenaltyLogitsProcessorWithRange': 'repetition_penalty',
        'PresencePenaltyLogitsProcessor': 'presence_penalty',
        'FrequencyPenaltyLogitsProcessor': 'frequency_penalty',
        'DRYLogitsProcessor': 'dry',
        'EncoderRepetitionPenaltyLogitsProcessor': 'encoder_repetition_penalty',
        'NoRepeatNGramLogitsProcessor': 'no_repeat_ngram',
    }

    def custom_sort_key(obj):
        nickname = class_name_to_nickname.get(obj.__class__.__name__)

        # Return -1 if the class is not mapped or not listed, so it sorts first.
        if nickname is None or nickname not in sampler_priority:
            return -1

        return sampler_priority.index(nickname)

    # Sort the list using the custom key function (sorted() is stable).
    warpers = sorted(warpers, key=custom_sort_key)
    if shared.args.verbose:
        logger.info("WARPERS=")
        pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint([x.__class__.__name__ for x in warpers])
        print()

    if normalize is not None:
        warpers.append(normalize)

    warpers.append(SpyLogitsWarper())
    warpers = LogitsProcessorList(warpers)
    return warpers
|
|
|
|
|
|
|
|
|
|
|
|
def generation_config_init_patch(self, **kwargs):
    '''
    Replacement for `transformers.GenerationConfig.__init__`.

    Runs the stock initializer (saved as `__init___old`) with all keyword
    arguments. It then sets each extra sampling parameter used by this module
    from `kwargs`, falling back to the defaults below.
    '''
    self.__init___old(**kwargs)

    # The table is rebuilt on every call, so each config gets its own fresh
    # `sampler_priority` list.
    extra_defaults = {
        "min_p": 0.0,
        "dynamic_temperature": False,
        "dynatemp_low": 1,
        "dynatemp_high": 1,
        "dynatemp_exponent": 1,
        "smoothing_factor": 0.0,
        "smoothing_curve": 1.0,
        "tfs": 1.0,
        "top_a": 0.0,
        "mirostat_mode": 0,
        "mirostat_eta": 0.1,
        "mirostat_tau": 5,
        "repetition_penalty_range": 0,
        "presence_penalty": 0,
        "frequency_penalty": 0,
        "dry_multiplier": 0.0,
        "dry_base": 1.75,
        "dry_allowed_length": 2,
        "dry_sequence_breakers": '"\\n", ":", "\\"", "*"',
        "xtc_threshold": 0.1,
        "xtc_probability": 0,
        "temperature_last": False,
        "sampler_priority": ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram'],
    }

    for attr_name, default_value in extra_defaults.items():
        setattr(self, attr_name, kwargs.pop(attr_name, default_value))
|
2023-05-30 08:40:01 +08:00
|
|
|
|
|
|
|
|
|
|
|
def hijack_samplers():
    '''
    Monkey-patch transformers so that generation uses this module's samplers.

    The original `GenerationMixin._get_logits_processor` and
    `GenerationConfig.__init__` are first saved as `_get_logits_processor_old`
    and `__init___old` (the replacement functions call them), then replaced.
    '''
    mixin = transformers.GenerationMixin
    mixin._get_logits_processor_old = mixin._get_logits_processor
    mixin._get_logits_processor = get_logits_processor_patch

    config_cls = transformers.GenerationConfig
    config_cls.__init___old = config_cls.__init__
    config_cls.__init__ = generation_config_init_patch
|