Mirror of https://github.com/oobabooga/text-generation-webui.git
Connect XTC, DRY, smoothing_factor, and dynatemp to ExLlamaV2 loader (non-HF)
parent d56b500568
commit 3967520e71
--- a/modules/exllamav2.py
+++ b/modules/exllamav2.py
@@ -1,8 +1,8 @@
+import json
 import traceback
 from pathlib import Path
 
 import torch
-
 from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Cache,
@@ -15,6 +15,7 @@ from exllamav2 import (
     ExLlamaV2Tokenizer
 )
 from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator
+
 from modules import shared
 from modules.logging_colors import logger
 from modules.text_generation import get_max_prompt_length
@@ -122,6 +123,10 @@ class Exllamav2Model:
         settings.token_presence_penalty = state['presence_penalty']
 
         settings.temperature = state['temperature']
+        settings.smoothing_factor = state['smoothing_factor']
+        settings.min_temp = state['dynatemp_low'] if state['dynamic_temperature'] else 0
+        settings.max_temp = state['dynatemp_high'] if state['dynamic_temperature'] else 0
+        settings.temp_exponent = state['dynatemp_exponent']
         settings.top_k = state['top_k']
         settings.top_p = state['top_p']
         settings.top_a = state['top_a']
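
A note on the dynamic-temperature lines above: min_temp and max_temp are taken from dynatemp_low/dynatemp_high only when dynamic_temperature is enabled, and are zeroed otherwise. A minimal sketch of that mapping, assuming a plain attribute object in place of ExLlamaV2Sampler.Settings and assuming that zeroed bounds mean "dynatemp off" on the ExLlamaV2 side:

    from types import SimpleNamespace

    def apply_temperature_settings(settings, state):
        # Sketch of the mapping in the hunk above, not the loader itself.
        settings.temperature = state['temperature']
        settings.smoothing_factor = state['smoothing_factor']
        # Assumption: min_temp == max_temp == 0 disables dynamic temperature.
        settings.min_temp = state['dynatemp_low'] if state['dynamic_temperature'] else 0
        settings.max_temp = state['dynatemp_high'] if state['dynamic_temperature'] else 0
        settings.temp_exponent = state['dynatemp_exponent']

    settings = SimpleNamespace()
    apply_temperature_settings(settings, {
        'temperature': 1.0, 'smoothing_factor': 0.0, 'dynamic_temperature': True,
        'dynatemp_low': 0.5, 'dynatemp_high': 1.5, 'dynatemp_exponent': 1.0,
    })
    assert (settings.min_temp, settings.max_temp) == (0.5, 1.5)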
@@ -143,6 +148,29 @@ class Exllamav2Model:
         if len(to_ban) > 0:
             settings.disallow_tokens(self.tokenizer, to_ban)
 
+        settings.dry_allowed_length = state['dry_allowed_length']
+        settings.dry_base = state['dry_base']
+        settings.dry_multiplier = state['dry_multiplier']
+
+        # Dry sequence breakers processing
+        if state['dry_multiplier'] > 0 and state['dry_sequence_breakers']:
+            dry_sequence_breakers = state['dry_sequence_breakers']
+
+            # Support both JSON array notation and comma-separated strings.
+            if not dry_sequence_breakers.startswith("["):
+                dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
+
+            sequence_breaker_strings = json.loads(dry_sequence_breakers)
+            # Prefix with 'a' to get the correct encoding of the token at the end of a text.
+            sequence_breakers = {
+                self.encode(f"a{s}")[0, -1].item() for s in sequence_breaker_strings
+            }
+
+            settings.dry_sequence_breakers = sequence_breakers
+
+        settings.xtc_probability = state['xtc_probability']
+        settings.xtc_threshold = state['xtc_threshold']
+
         ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'], encode_special_tokens=True)
         ids = ids[:, -get_max_prompt_length(state):]
 
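
The breaker-processing block above accepts either JSON array notation or a bare comma-separated string, then re-encodes each breaker as it would appear mid-text. A self-contained sketch of the same logic, with a toy character-level encode standing in for the real tokenizer (the toy tokenizer is illustrative only):

    import json

    def parse_sequence_breakers(raw, encode):
        # Same normalization as above: wrap bare comma-separated strings
        # ("\n", ":") into JSON array notation before parsing.
        if not raw.startswith("["):
            raw = "[" + raw + "]"
        strings = json.loads(raw)
        # Prefix with 'a' so each breaker is encoded as it appears mid-text,
        # then keep only the id of its final token.
        return {encode(f"a{s}")[-1] for s in strings}

    vocab = {}
    def toy_encode(text):
        # Toy stand-in for the real tokenizer: one id per character.
        return [vocab.setdefault(ch, len(vocab)) for ch in text]

    print(parse_sequence_breakers('"\\n", ":"', toy_encode))  # e.g. {1, 2}

The 'a' prefix matters because subword tokenizers often encode a string differently at the start of a text than after other characters; keeping only the last token id isolates the mid-text form.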
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -194,6 +194,10 @@ loaders_samplers = {
     'ExLlamav2': {
         'temperature',
         'temperature_last',
+        'smoothing_factor',
+        'dynatemp_low',
+        'dynatemp_high',
+        'dynatemp_exponent',
         'top_p',
         'min_p',
         'top_k',
@@ -204,10 +208,16 @@ loaders_samplers = {
         'presence_penalty',
         'frequency_penalty',
         'repetition_penalty_range',
-        'seed',
         'mirostat_mode',
         'mirostat_tau',
         'mirostat_eta',
+        'dry_multiplier',
+        'dry_base',
+        'dry_allowed_length',
+        'dry_sequence_breakers',
+        'xtc_threshold',
+        'xtc_probability',
+        'seed',
         'ban_eos_token',
         'add_bos_token',
         'custom_token_bans',
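
Both loaders.py hunks extend the 'ExLlamav2' entry of loaders_samplers, which maps each loader name to the set of sampler parameters it exposes. A hypothetical sketch of how such a registry can gate the UI's generation state (filter_state and the trimmed sets are illustrative, not code from the repo):

    loaders_samplers = {
        # Trimmed, illustrative version of the real registry.
        'ExLlamav2': {'temperature', 'smoothing_factor', 'dry_multiplier', 'xtc_threshold'},
    }

    def filter_state(state, loader):
        # Hypothetical helper: drop parameters the active loader does not list.
        supported = loaders_samplers.get(loader, set())
        return {k: v for k, v in state.items() if k in supported}

    state = {'temperature': 0.7, 'mirostat_mode': 2, 'xtc_threshold': 0.1}
    print(filter_state(state, 'ExLlamav2'))  # {'temperature': 0.7, 'xtc_threshold': 0.1}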
--- a/modules/sampler_hijack.py
+++ b/modules/sampler_hijack.py
@@ -495,7 +495,9 @@ def get_logits_processor_patch(self, **kwargs):
 
         sequence_breaker_strings = json.loads(dry_sequence_breakers)
         # Prefix with 'a' to get the correct encoding of the token at the end of a text.
-        sequence_breakers = {shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings}
+        sequence_breakers = {
+            shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings
+        }
 
         warpers.append(
             DRYLogitsProcessor(
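
The two DRY implementations index the encoded breaker differently: the ExLlamaV2 path uses self.encode as if it returns a batched 2-D tensor, hence [0, -1].item(), while shared.tokenizer.encode on the HF path is used as if it returns a flat list of ids, hence [-1]. A toy check of the equivalence, with shapes assumed from the diff rather than verified against either tokenizer:

    import torch

    ids_2d = torch.tensor([[17, 42, 99]])  # batched (1, n) output, indexed like self.encode(...)
    ids_flat = [17, 42, 99]                # flat list, indexed like shared.tokenizer.encode(...)

    assert ids_2d[0, -1].item() == ids_flat[-1] == 99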