mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-13 13:52:49 +01:00
Add custom sampler order support (#5443)
This commit is contained in:
parent
7301c7618f
commit
8c35fefb3b
@ -55,8 +55,8 @@ For more information about the parameters, the [transformers documentation](http
|
||||
* **mirostat_tau**: No idea, see the paper for details. According to the Preset Arena, 8 is a good value.
|
||||
* **mirostat_eta**: No idea, see the paper for details. According to the Preset Arena, 0.1 is a good value.
|
||||
* **dynamic_temperature**: Activates Dynamic Temperature. This modifies temperature to range between "dynatemp_low" (minimum) and "dynatemp_high" (maximum), with an entropy-based scaling. The steepness of the curve is controlled by "dynatemp_exponent".
|
||||
* **smoothing_factor**: Activates Quadratic Sampling. This takes precedence over regular temperature and dynamic temperature, and replaces those samplers. When `0 < smoothing_factor < 1`, the logits distribution becomes flatter. When `smoothing_factor > 1`, it becomes more peaked.
|
||||
* **temperature_last**: Makes temperature the last sampler instead of the first. With this, you can remove low probability tokens with a sampler like min_p and then use a high temperature to make the model creative without losing coherency.
|
||||
* **smoothing_factor**: Activates Quadratic Sampling. When `0 < smoothing_factor < 1`, the logits distribution becomes flatter. When `smoothing_factor > 1`, it becomes more peaked.
|
||||
* **temperature_last**: Makes temperature the last sampler instead of the first. With this, you can remove low probability tokens with a sampler like min_p and then use a high temperature to make the model creative without losing coherency. Note: this parameter takes precedence over "Sampler priority". That means that `temperature`/`dynamic_temperature`/`quadratic_sampling` will be removed from wherever they are and moved to the end of the stack.
|
||||
* **do_sample**: When unchecked, sampling is entirely disabled, and greedy decoding is used instead (the most likely token is always picked).
|
||||
* **Seed**: Set the Pytorch seed to this number. Note that some loaders do not use Pytorch (notably llama.cpp), and others are not deterministic (notably ExLlama v1 and v2). For these loaders, the seed has no effect.
|
||||
* **encoder_repetition_penalty**: Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge.
|
||||
@ -77,6 +77,7 @@ To the right (or below if you are on mobile), the following parameters are prese
|
||||
* **Add the bos_token to the beginning of prompts**: By default, the tokenizer will add a BOS (Beginning of Sequence) token to your prompt. During training, BOS tokens are used to separate different documents. If unchecked, no BOS token will be added, and the model will interpret your prompt as being in the middle of a document instead of at the start of one. This significantly changes the output and can make it more creative.
|
||||
* **Skip special tokens**: When decoding the generated tokens, skip special tokens from being converted to their text representation. Otherwise, BOS appears as `<s>`, EOS as `</s>`, etc.
|
||||
* **Activate text streaming**: When unchecked, the full response is outputted at once, without streaming the words one at a time. I recommend unchecking this parameter on high latency networks like running the webui on Google Colab or using `--share`.
|
||||
* **Sampler priority**: Allows you to customize the order in which the different samplers are applied. The first sampler on the list gets applied first. With this, custom orders like `top_p -> temperature -> top_k` can be defined.
|
||||
* **Load grammar from file**: Loads a GBNF grammar from a file under `text-generation-webui/grammars`. The output is written to the "Grammar" box below. You can also save and delete custom grammars using this menu.
|
||||
* **Grammar**: Allows you to constrain the model output to a particular format. For instance, you can make the model generate lists, JSON, specific words, etc. Grammar is extremely powerful and I highly recommend it. The syntax looks a bit daunting at first sight, but it gets very easy once you understand it. See the [GBNF Guide](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md) for details.
|
||||
|
||||
|
@ -40,6 +40,7 @@ class GenerationOptions(BaseModel):
|
||||
max_tokens_second: int = 0
|
||||
prompt_lookup_num_tokens: int = 0
|
||||
custom_token_bans: str = ""
|
||||
sampler_priority: List[str] | str | None = Field(default=None, description="List of samplers where the first items will appear first in the stack. Example: [\"top_k\", \"temperature\", \"top_p\"].")
|
||||
auto_max_new_tokens: bool = False
|
||||
ban_eos_token: bool = False
|
||||
add_bos_token: bool = True
|
||||
|
@ -182,6 +182,7 @@ def transformers_samplers():
|
||||
'negative_prompt',
|
||||
'ban_eos_token',
|
||||
'custom_token_bans',
|
||||
'sampler_priority',
|
||||
'add_bos_token',
|
||||
'skip_special_tokens',
|
||||
'auto_max_new_tokens',
|
||||
@ -230,6 +231,7 @@ loaders_samplers = {
|
||||
'negative_prompt',
|
||||
'ban_eos_token',
|
||||
'custom_token_bans',
|
||||
'sampler_priority',
|
||||
'add_bos_token',
|
||||
'skip_special_tokens',
|
||||
'auto_max_new_tokens',
|
||||
@ -287,6 +289,7 @@ loaders_samplers = {
|
||||
'negative_prompt',
|
||||
'ban_eos_token',
|
||||
'custom_token_bans',
|
||||
'sampler_priority',
|
||||
'add_bos_token',
|
||||
'skip_special_tokens',
|
||||
'auto_max_new_tokens',
|
||||
|
@ -42,6 +42,7 @@ def default_preset():
|
||||
'num_beams': 1,
|
||||
'length_penalty': 1,
|
||||
'early_stopping': False,
|
||||
'sampler_priority': 'temperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat'
|
||||
}
|
||||
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
import math
|
||||
import pprint
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
@ -6,21 +7,21 @@ from transformers import LogitsWarper, is_torch_xpu_available
|
||||
from transformers.generation.logits_process import (
|
||||
LogitNormalization,
|
||||
LogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
TemperatureLogitsWarper
|
||||
LogitsProcessorList
|
||||
)
|
||||
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
|
||||
global_scores = None
|
||||
|
||||
|
||||
class ModifiedTemperatureLogitsWarper(LogitsWarper):
|
||||
class TemperatureLogitsWarperCustom(LogitsWarper):
|
||||
'''
|
||||
Based on the original Transformers temperature logits warper, this
|
||||
adds support for dynamic temperature and quadratic sampling.
|
||||
A copy of the original Transformers temperature logits warper.
|
||||
'''
|
||||
def __init__(self, temperature: float, dynamic_temperature: bool, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float, smoothing_factor: float):
|
||||
|
||||
def __init__(self, temperature: float):
|
||||
if not isinstance(temperature, float) or not (temperature > 0):
|
||||
except_msg = (
|
||||
f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
|
||||
@ -32,81 +33,90 @@ class ModifiedTemperatureLogitsWarper(LogitsWarper):
|
||||
raise ValueError(except_msg)
|
||||
|
||||
self.temperature = temperature
|
||||
self.dynamic_temperature = dynamic_temperature
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
scores = scores / self.temperature
|
||||
return scores
|
||||
|
||||
|
||||
class DynamicTemperatureLogitsWarper(LogitsWarper):
|
||||
'''
|
||||
Dynamic temperature.
|
||||
'''
|
||||
|
||||
def __init__(self, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float):
|
||||
self.dynatemp_low = dynatemp_low
|
||||
self.dynatemp_high = dynatemp_high
|
||||
self.dynatemp_exponent = dynatemp_exponent
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
min_temp = self.dynatemp_low
|
||||
max_temp = self.dynatemp_high
|
||||
exponent_val = self.dynatemp_exponent
|
||||
|
||||
# Convert logits to probabilities
|
||||
probs = torch.softmax(scores, dim=-1)
|
||||
|
||||
# Calculate entropy of the softmax probabilities
|
||||
entropy = -1.0 * torch.where(probs > 0, probs * torch.log(probs), torch.zeros_like(probs)).sum()
|
||||
|
||||
# Guard against future possible division by zero
|
||||
entropy = max(entropy, torch.tensor(1e-10)) # Ensures entropy is slightly greater than 0
|
||||
|
||||
# Any logits which are not -Infinity will be considered for calculating max entropy.
|
||||
num_valid_tokens = torch.sum(scores > -float('inf')).item()
|
||||
|
||||
# Now, calculate the max entropy by using only the valid tokens' count
|
||||
max_entropy = math.log(num_valid_tokens)
|
||||
|
||||
# Guard against future possible division by zero
|
||||
max_entropy = max_entropy if max_entropy > 0.0 else 1e-10
|
||||
|
||||
# Normalize the entropy
|
||||
normalized_entropy = entropy / max_entropy
|
||||
|
||||
# Map the normalized entropy to the desired temperature range using the power function
|
||||
dyn_temp = min_temp + (max_temp - min_temp) * (normalized_entropy.pow(exponent_val))
|
||||
|
||||
# Apply the dynamically calculated temperature scaling
|
||||
scores = scores / dyn_temp
|
||||
|
||||
# print("----------------------\nTemperature from generation_config:", self.temperature)
|
||||
# print("min_temp:", min_temp)
|
||||
# print("max_temp:", max_temp)
|
||||
# print("Entropy:", entropy.item())
|
||||
# print("Max Possible Entropy considering valid tokens only:", max_entropy)
|
||||
# print("Normalized Entropy:", normalized_entropy.item())
|
||||
# print("Dynamic Temperature (dyn_temp):", dyn_temp.item())
|
||||
# print("----------------------")
|
||||
|
||||
# max_prob_token_id = torch.argmax(scores, dim=-1) # Get the token ID with the highest probability
|
||||
# max_prob_token = shared.tokenizer.convert_ids_to_tokens(int(max_prob_token_id)) # Convert ID to token
|
||||
# print("--- T=", float(dyn_temp), "token=", max_prob_token, "min=", min_temp, "max=", max_temp, "exponent=", exponent_val)
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
class QuadraticSamplingLogitsWarper(LogitsWarper):
|
||||
'''
|
||||
Quadratic sampling.
|
||||
'''
|
||||
|
||||
def __init__(self, smoothing_factor: float):
|
||||
self.smoothing_factor = smoothing_factor
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Compute the maximum logit value
|
||||
max_logit = scores.max()
|
||||
|
||||
# Quadratic sampling
|
||||
if self.smoothing_factor > 0:
|
||||
# Apply the quadratic transformation
|
||||
transformed_logits = -(self.smoothing_factor * (scores - max_logit)**2) + max_logit
|
||||
|
||||
# Compute the maximum logit value
|
||||
max_logit = scores.max()
|
||||
# No need to print the top 5 logits since this is not required
|
||||
# print("Original top 5 logits: ", torch.topk(scores, 5))
|
||||
# print("New top 5 logits: ", torch.topk(transformed_logits, 5))
|
||||
|
||||
# Apply the quadratic transformation
|
||||
transformed_logits = -(self.smoothing_factor * (scores - max_logit)**2) + max_logit
|
||||
|
||||
# No need to print the top 5 logits since this is not required
|
||||
# print("Original top 5 logits: ", torch.topk(scores, 5))
|
||||
# print("New top 5 logits: ", torch.topk(transformed_logits, 5))
|
||||
|
||||
return transformed_logits
|
||||
|
||||
# Dynamic temperature
|
||||
elif self.dynamic_temperature:
|
||||
min_temp = self.dynatemp_low
|
||||
max_temp = self.dynatemp_high
|
||||
exponent_val = self.dynatemp_exponent
|
||||
|
||||
# Convert logits to probabilities
|
||||
probs = torch.softmax(scores, dim=-1)
|
||||
|
||||
# Calculate entropy of the softmax probabilities
|
||||
entropy = -1.0 * torch.where(probs > 0, probs * torch.log(probs), torch.zeros_like(probs)).sum()
|
||||
|
||||
# Guard against future possible division by zero
|
||||
entropy = max(entropy, torch.tensor(1e-10)) # Ensures entropy is slightly greater than 0
|
||||
|
||||
# Any logits which are not -Infinity will be considered for calculating max entropy.
|
||||
num_valid_tokens = torch.sum(scores > -float('inf')).item()
|
||||
|
||||
# Now, calculate the max entropy by using only the valid tokens' count
|
||||
max_entropy = math.log(num_valid_tokens)
|
||||
|
||||
# Guard against future possible division by zero
|
||||
max_entropy = max_entropy if max_entropy > 0.0 else 1e-10
|
||||
|
||||
# Normalize the entropy
|
||||
normalized_entropy = entropy / max_entropy
|
||||
|
||||
# Map the normalized entropy to the desired temperature range using the power function
|
||||
dyn_temp = min_temp + (max_temp - min_temp) * (normalized_entropy.pow(exponent_val))
|
||||
|
||||
# Apply the dynamically calculated temperature scaling
|
||||
scores = scores / dyn_temp
|
||||
|
||||
# print("----------------------\nTemperature from generation_config:", self.temperature)
|
||||
# print("min_temp:", min_temp)
|
||||
# print("max_temp:", max_temp)
|
||||
# print("Entropy:", entropy.item())
|
||||
# print("Max Possible Entropy considering valid tokens only:", max_entropy)
|
||||
# print("Normalized Entropy:", normalized_entropy.item())
|
||||
# print("Dynamic Temperature (dyn_temp):", dyn_temp.item())
|
||||
# print("----------------------")
|
||||
|
||||
# max_prob_token_id = torch.argmax(scores, dim=-1) # Get the token ID with the highest probability
|
||||
# max_prob_token = shared.tokenizer.convert_ids_to_tokens(int(max_prob_token_id)) # Convert ID to token
|
||||
# print("--- T=", float(dyn_temp), "token=", max_prob_token, "min=", min_temp, "max=", max_temp, "exponent=", exponent_val)
|
||||
|
||||
return scores
|
||||
|
||||
# Regular temperature
|
||||
else:
|
||||
scores = scores / self.temperature
|
||||
return scores
|
||||
return transformed_logits
|
||||
|
||||
|
||||
class MinPLogitsWarper(LogitsWarper):
|
||||
@ -209,6 +219,7 @@ class MirostatLogitsWarper(LogitsWarper):
|
||||
def __init__(self, mirostat_mode: int, mirostat_tau: float, mirostat_eta: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
if mirostat_mode not in [2]:
|
||||
raise ValueError(f"`mirostat` has to be a an integer 2, but is {mirostat_mode}")
|
||||
|
||||
self.mirostat_mode = mirostat_mode
|
||||
self.mirostat_eta = mirostat_eta
|
||||
self.mirostat_tau = mirostat_tau
|
||||
@ -301,44 +312,74 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
|
||||
|
||||
|
||||
def get_logits_warper_patch(self, generation_config):
|
||||
# Make sure that temperature is float and not int
|
||||
|
||||
# Parameter sanitization
|
||||
if isinstance(generation_config.temperature, int):
|
||||
generation_config.temperature = float(generation_config.temperature)
|
||||
|
||||
temperature = generation_config.temperature
|
||||
if generation_config.dynamic_temperature or generation_config.smoothing_factor > 0:
|
||||
# Make sure TemperatureLogitsWarper will be created by temporarily
|
||||
# setting temperature to a value != 1.
|
||||
generation_config.temperature = 1.1
|
||||
generation_config.temperature = float(generation_config.temperature) # Must be float
|
||||
|
||||
# Get the original warpers
|
||||
warpers = self._get_logits_warper_old(generation_config)
|
||||
|
||||
# Replace temperature with our modified class.
|
||||
# Currently, it behaves identically to the original.
|
||||
for i in range(len(warpers)):
|
||||
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
|
||||
warpers[i] = ModifiedTemperatureLogitsWarper(
|
||||
temperature,
|
||||
generation_config.dynamic_temperature,
|
||||
generation_config.dynatemp_low,
|
||||
generation_config.dynatemp_high,
|
||||
generation_config.dynatemp_exponent,
|
||||
generation_config.smoothing_factor
|
||||
warpers[i] = TemperatureLogitsWarperCustom(
|
||||
generation_config.temperature,
|
||||
)
|
||||
|
||||
# Add custom warpers
|
||||
warpers_to_add = LogitsProcessorList()
|
||||
min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
|
||||
if generation_config.tfs is not None and 0.0 <= generation_config.tfs < 1.0:
|
||||
warpers_to_add.append(
|
||||
TailFreeLogitsWarper(
|
||||
tfs=generation_config.tfs,
|
||||
min_tokens_to_keep=min_tokens_to_keep
|
||||
)
|
||||
)
|
||||
|
||||
if generation_config.top_a is not None and 0.0 < generation_config.top_a <= 1.0:
|
||||
warpers_to_add.append(
|
||||
TopALogitsWarper(
|
||||
top_a=generation_config.top_a,
|
||||
min_tokens_to_keep=min_tokens_to_keep
|
||||
)
|
||||
)
|
||||
|
||||
if generation_config.min_p is not None and 0.0 < generation_config.min_p <= 1.0:
|
||||
warpers_to_add.append(
|
||||
MinPLogitsWarper(
|
||||
min_p=generation_config.min_p,
|
||||
min_tokens_to_keep=min_tokens_to_keep
|
||||
)
|
||||
)
|
||||
|
||||
if generation_config.dynamic_temperature:
|
||||
warpers_to_add.append(
|
||||
DynamicTemperatureLogitsWarper(
|
||||
dynatemp_low=generation_config.dynatemp_low,
|
||||
dynatemp_high=generation_config.dynatemp_high,
|
||||
dynatemp_exponent=generation_config.dynatemp_exponent,
|
||||
)
|
||||
)
|
||||
|
||||
if generation_config.smoothing_factor > 0:
|
||||
warpers_to_add.append(
|
||||
QuadraticSamplingLogitsWarper(
|
||||
smoothing_factor=generation_config.smoothing_factor
|
||||
)
|
||||
)
|
||||
|
||||
if generation_config.mirostat_mode is not None and generation_config.mirostat_mode == 2:
|
||||
warpers_to_add.append(MirostatLogitsWarper(mirostat_mode=generation_config.mirostat_mode, mirostat_eta=generation_config.mirostat_eta, mirostat_tau=generation_config.mirostat_tau, min_tokens_to_keep=min_tokens_to_keep))
|
||||
# We need to disable samplers other than temperature
|
||||
for warper in warpers:
|
||||
if not isinstance(warper, TemperatureLogitsWarper):
|
||||
warpers.remove(warper)
|
||||
else:
|
||||
if generation_config.tfs is not None and 0.0 <= generation_config.tfs < 1.0:
|
||||
warpers_to_add.append(TailFreeLogitsWarper(tfs=generation_config.tfs, min_tokens_to_keep=min_tokens_to_keep))
|
||||
if generation_config.top_a is not None and 0.0 < generation_config.top_a <= 1.0:
|
||||
warpers_to_add.append(TopALogitsWarper(top_a=generation_config.top_a, min_tokens_to_keep=min_tokens_to_keep))
|
||||
if generation_config.min_p is not None and 0.0 < generation_config.min_p <= 1.0:
|
||||
warpers_to_add.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
|
||||
warpers_to_add.append(
|
||||
MirostatLogitsWarper(
|
||||
mirostat_mode=generation_config.mirostat_mode,
|
||||
mirostat_eta=generation_config.mirostat_eta,
|
||||
mirostat_tau=generation_config.mirostat_tau,
|
||||
min_tokens_to_keep=min_tokens_to_keep
|
||||
)
|
||||
)
|
||||
|
||||
if len(warpers) > 0 and isinstance(warpers[-1], LogitNormalization):
|
||||
normalize = warpers.pop(-1)
|
||||
@ -346,23 +387,57 @@ def get_logits_warper_patch(self, generation_config):
|
||||
normalize = None
|
||||
|
||||
warpers += warpers_to_add
|
||||
if generation_config.temperature_last:
|
||||
temperature_idx = None
|
||||
for i in range(len(warpers)):
|
||||
if warpers[i].__class__.__name__ in ['TemperatureLogitsWarper', 'ModifiedTemperatureLogitsWarper']:
|
||||
temperature_idx = i
|
||||
break
|
||||
|
||||
if temperature_idx is not None:
|
||||
warpers.append(warpers.pop(temperature_idx))
|
||||
# Sort the samplers.
|
||||
sampler_priority = generation_config.sampler_priority
|
||||
|
||||
# Handle temperature_last
|
||||
if generation_config.temperature_last:
|
||||
for param_name in ['temperature', 'dynamic_temperature', 'quadratic_sampling']:
|
||||
if param_name in sampler_priority:
|
||||
if param_name in sampler_priority:
|
||||
index = sampler_priority.index(param_name)
|
||||
sampler_priority.append(sampler_priority.pop(index))
|
||||
else:
|
||||
sampler_priority.append(param_name)
|
||||
|
||||
class_name_to_nickname = {
|
||||
'DynamicTemperatureLogitsWarper': 'dynamic_temperature',
|
||||
'EpsilonLogitsWarper': 'epsilon_cutoff',
|
||||
'EtaLogitsWarper': 'eta_cutoff',
|
||||
'MinPLogitsWarper': 'min_p',
|
||||
'MirostatLogitsWarper': 'mirostat',
|
||||
'QuadraticSamplingLogitsWarper': 'quadratic_sampling',
|
||||
'TailFreeLogitsWarper': 'tfs',
|
||||
'TemperatureLogitsWarperCustom': 'temperature',
|
||||
'TopALogitsWarper': 'top_a',
|
||||
'TopKLogitsWarper': 'top_k',
|
||||
'TopPLogitsWarper': 'top_p',
|
||||
'TypicalLogitsWarper': 'typical_p'
|
||||
}
|
||||
|
||||
def custom_sort_key(obj):
|
||||
class_name = obj.__class__.__name__
|
||||
|
||||
# Return a large value if class name is not mapped or if the mapped nickname is not in priority
|
||||
if class_name not in class_name_to_nickname or class_name_to_nickname[class_name] not in sampler_priority:
|
||||
return float('inf')
|
||||
|
||||
# Return the index of the nickname in the priority list for sorting
|
||||
return sampler_priority.index(class_name_to_nickname[class_name])
|
||||
|
||||
# Sort the list using the custom key function
|
||||
warpers = sorted(warpers, key=custom_sort_key)
|
||||
|
||||
if normalize is not None:
|
||||
warpers.append(normalize)
|
||||
|
||||
warpers.append(SpyLogitsWarper())
|
||||
warpers = LogitsProcessorList(warpers)
|
||||
# for i in range(len(warpers)):
|
||||
# print(warpers[i].__class__.__name__)
|
||||
if shared.args.verbose:
|
||||
logger.info("WARPERS=")
|
||||
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint([x.__class__.__name__ for x in warpers])
|
||||
|
||||
return warpers
|
||||
|
||||
|
||||
@ -402,6 +477,7 @@ def generation_config_init_patch(self, **kwargs):
|
||||
self.presence_penalty = kwargs.pop("presence_penalty", 0)
|
||||
self.frequency_penalty = kwargs.pop("frequency_penalty", 0)
|
||||
self.temperature_last = kwargs.pop("temperature_last", False)
|
||||
self.sampler_priority = kwargs.pop("sampler_priority", ['temperature', 'dynamic_temperature', 'quadratic_sampling', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'mirostat'])
|
||||
|
||||
|
||||
def hijack_samplers():
|
||||
|
@ -50,6 +50,7 @@ settings = {
|
||||
'prompt_lookup_num_tokens': 0,
|
||||
'custom_stopping_strings': '',
|
||||
'custom_token_bans': '',
|
||||
'sampler_priority': 'temperature,top_k,top_p,typical_p,epsilon_cutoff,eta_cutoff,tfs,top_a,min_p,dynamic_temperature,quadratic_sampling,mirostat',
|
||||
'auto_max_new_tokens': False,
|
||||
'ban_eos_token': False,
|
||||
'add_bos_token': True,
|
||||
|
@ -291,6 +291,11 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
||||
if k in state:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
if isinstance(state['sampler_priority'], list):
|
||||
generate_params['sampler_priority'] = state['sampler_priority']
|
||||
elif isinstance(state['sampler_priority'], str):
|
||||
generate_params['sampler_priority'] = [x.strip() for x in state['sampler_priority'].replace('\n', ',').split(',') if x.strip()]
|
||||
|
||||
if state['negative_prompt'] != '':
|
||||
generate_params['negative_prompt_ids'] = encode(state['negative_prompt'])
|
||||
|
||||
|
@ -149,6 +149,7 @@ def list_interface_input_elements():
|
||||
'add_bos_token',
|
||||
'ban_eos_token',
|
||||
'custom_token_bans',
|
||||
'sampler_priority',
|
||||
'truncation_length',
|
||||
'custom_stopping_strings',
|
||||
'skip_special_tokens',
|
||||
|
@ -49,12 +49,12 @@ def create_ui(default_preset):
|
||||
shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.')
|
||||
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
|
||||
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
|
||||
shared.gradio['smoothing_factor'] = gr.Slider(0.0, 10.0, value=generate_params['smoothing_factor'], step=0.01, label='smoothing_factor', info='Replaces temperature with Quadratic Sampling.')
|
||||
shared.gradio['smoothing_factor'] = gr.Slider(0.0, 10.0, value=generate_params['smoothing_factor'], step=0.01, label='smoothing_factor', info='Activates Quadratic Sampling.')
|
||||
shared.gradio['dynamic_temperature'] = gr.Checkbox(value=generate_params['dynamic_temperature'], label='dynamic_temperature')
|
||||
shared.gradio['dynatemp_low'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_low'], step=0.01, label='dynatemp_low', visible=generate_params['dynamic_temperature'])
|
||||
shared.gradio['dynatemp_high'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_high'], step=0.01, label='dynatemp_high', visible=generate_params['dynamic_temperature'])
|
||||
shared.gradio['dynatemp_exponent'] = gr.Slider(0.01, 5, value=generate_params['dynatemp_exponent'], step=0.01, label='dynatemp_exponent', visible=generate_params['dynamic_temperature'])
|
||||
shared.gradio['temperature_last'] = gr.Checkbox(value=generate_params['temperature_last'], label='temperature_last', info='Makes temperature the last sampler instead of the first.')
|
||||
shared.gradio['temperature_last'] = gr.Checkbox(value=generate_params['temperature_last'], label='temperature_last', info='Moves temperature/dynamic temperature/quadratic sampling to the end of the sampler stack, ignoring their positions in "Sampler priority".')
|
||||
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
|
||||
shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
|
||||
with gr.Accordion('Other parameters', open=False):
|
||||
@ -85,6 +85,9 @@ def create_ui(default_preset):
|
||||
shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
|
||||
shared.gradio['stream'] = gr.Checkbox(value=shared.settings['stream'], label='Activate text streaming')
|
||||
|
||||
with gr.Blocks():
|
||||
shared.gradio['sampler_priority'] = gr.Textbox(value=generate_params['sampler_priority'], lines=12, label='Sampler priority', info='Parameter names separated by new lines or commas.')
|
||||
|
||||
with gr.Row() as shared.gradio['grammar_file_row']:
|
||||
shared.gradio['grammar_file'] = gr.Dropdown(value='None', choices=utils.get_available_grammars(), label='Load grammar from file (.gbnf)', elem_classes='slim-dropdown')
|
||||
ui.create_refresh_button(shared.gradio['grammar_file'], lambda: None, lambda: {'choices': utils.get_available_grammars()}, 'refresh-button', interactive=not mu)
|
||||
|
Loading…
x
Reference in New Issue
Block a user