mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-09 03:59:05 +01:00
Connect XTC, DRY, smoothing_factor, and dynatemp to ExLlamaV2 loader (non-HF)
This commit is contained in:
parent
d56b500568
commit
3967520e71
@ -1,8 +1,8 @@
|
||||
import json
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from exllamav2 import (
|
||||
ExLlamaV2,
|
||||
ExLlamaV2Cache,
|
||||
@ -15,6 +15,7 @@ from exllamav2 import (
|
||||
ExLlamaV2Tokenizer
|
||||
)
|
||||
from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator
|
||||
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
from modules.text_generation import get_max_prompt_length
|
||||
@ -122,6 +123,10 @@ class Exllamav2Model:
|
||||
settings.token_presence_penalty = state['presence_penalty']
|
||||
|
||||
settings.temperature = state['temperature']
|
||||
settings.smoothing_factor = state['smoothing_factor']
|
||||
settings.min_temp = state['dynatemp_low'] if state['dynamic_temperature'] else 0
|
||||
settings.max_temp = state['dynatemp_high'] if state['dynamic_temperature'] else 0
|
||||
settings.temp_exponent = state['dynatemp_exponent']
|
||||
settings.top_k = state['top_k']
|
||||
settings.top_p = state['top_p']
|
||||
settings.top_a = state['top_a']
|
||||
@ -143,6 +148,29 @@ class Exllamav2Model:
|
||||
if len(to_ban) > 0:
|
||||
settings.disallow_tokens(self.tokenizer, to_ban)
|
||||
|
||||
settings.dry_allowed_length = state['dry_allowed_length']
|
||||
settings.dry_base = state['dry_base']
|
||||
settings.dry_multiplier = state['dry_multiplier']
|
||||
|
||||
# Dry sequence breakers processing
|
||||
if state['dry_multiplier'] > 0 and state['dry_sequence_breakers']:
|
||||
dry_sequence_breakers = state['dry_sequence_breakers']
|
||||
|
||||
# Support both JSON array notation and comma-separated strings.
|
||||
if not dry_sequence_breakers.startswith("["):
|
||||
dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
|
||||
|
||||
sequence_breaker_strings = json.loads(dry_sequence_breakers)
|
||||
# Prefix with 'a' to get the correct encoding of the token at the end of a text.
|
||||
sequence_breakers = {
|
||||
self.encode(f"a{s}")[0, -1].item() for s in sequence_breaker_strings
|
||||
}
|
||||
|
||||
settings.dry_sequence_breakers = sequence_breakers
|
||||
|
||||
settings.xtc_probability = state['xtc_probability']
|
||||
settings.xtc_threshold = state['xtc_threshold']
|
||||
|
||||
ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'], encode_special_tokens=True)
|
||||
ids = ids[:, -get_max_prompt_length(state):]
|
||||
|
||||
|
@ -194,6 +194,10 @@ loaders_samplers = {
|
||||
'ExLlamav2': {
|
||||
'temperature',
|
||||
'temperature_last',
|
||||
'smoothing_factor',
|
||||
'dynatemp_low',
|
||||
'dynatemp_high',
|
||||
'dynatemp_exponent',
|
||||
'top_p',
|
||||
'min_p',
|
||||
'top_k',
|
||||
@ -204,10 +208,16 @@ loaders_samplers = {
|
||||
'presence_penalty',
|
||||
'frequency_penalty',
|
||||
'repetition_penalty_range',
|
||||
'seed',
|
||||
'mirostat_mode',
|
||||
'mirostat_tau',
|
||||
'mirostat_eta',
|
||||
'dry_multiplier',
|
||||
'dry_base',
|
||||
'dry_allowed_length',
|
||||
'dry_sequence_breakers',
|
||||
'xtc_threshold',
|
||||
'xtc_probability',
|
||||
'seed',
|
||||
'ban_eos_token',
|
||||
'add_bos_token',
|
||||
'custom_token_bans',
|
||||
|
@ -495,7 +495,9 @@ def get_logits_processor_patch(self, **kwargs):
|
||||
|
||||
sequence_breaker_strings = json.loads(dry_sequence_breakers)
|
||||
# Prefix with 'a' to get the correct encoding of the token at the end of a text.
|
||||
sequence_breakers = {shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings}
|
||||
sequence_breakers = {
|
||||
shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings
|
||||
}
|
||||
|
||||
warpers.append(
|
||||
DRYLogitsProcessor(
|
||||
|
Loading…
Reference in New Issue
Block a user