Connect XTC, DRY, smoothing_factor, and dynatemp to ExLlamaV2 loader (non-HF)

oobabooga 2025-01-04 16:22:59 -08:00
parent d56b500568
commit 3967520e71
3 changed files with 43 additions and 3 deletions


@@ -1,8 +1,8 @@
+import json
 import traceback
 from pathlib import Path
 import torch
 from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Cache,
@@ -15,6 +15,7 @@ from exllamav2 import (
     ExLlamaV2Tokenizer
 )
 from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator
 from modules import shared
 from modules.logging_colors import logger
 from modules.text_generation import get_max_prompt_length
@@ -122,6 +123,10 @@ class Exllamav2Model:
         settings.token_presence_penalty = state['presence_penalty']
         settings.temperature = state['temperature']
+        settings.smoothing_factor = state['smoothing_factor']
+        settings.min_temp = state['dynatemp_low'] if state['dynamic_temperature'] else 0
+        settings.max_temp = state['dynatemp_high'] if state['dynamic_temperature'] else 0
+        settings.temp_exponent = state['dynatemp_exponent']
         settings.top_k = state['top_k']
         settings.top_p = state['top_p']
         settings.top_a = state['top_a']
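
The four added lines map the UI's dynamic-temperature fields onto ExLlamaV2's sampler settings, with min_temp and max_temp forced to 0 when dynamic temperature is switched off. A minimal sketch of that mapping, using a plain dict in place of the web UI's state object (the helper name is illustrative only):

def dynatemp_settings(state):
    # Returns (min_temp, max_temp, temp_exponent); 0/0 leaves the
    # dynamic-temperature range disabled, matching the diff above.
    if state['dynamic_temperature']:
        return state['dynatemp_low'], state['dynatemp_high'], state['dynatemp_exponent']
    return 0, 0, state['dynatemp_exponent']

print(dynatemp_settings({'dynamic_temperature': True, 'dynatemp_low': 0.5,
                         'dynatemp_high': 1.5, 'dynatemp_exponent': 1.0}))
# (0.5, 1.5, 1.0)
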
@@ -143,6 +148,29 @@ class Exllamav2Model:
         if len(to_ban) > 0:
             settings.disallow_tokens(self.tokenizer, to_ban)
+        settings.dry_allowed_length = state['dry_allowed_length']
+        settings.dry_base = state['dry_base']
+        settings.dry_multiplier = state['dry_multiplier']
+        # Dry sequence breakers processing
+        if state['dry_multiplier'] > 0 and state['dry_sequence_breakers']:
+            dry_sequence_breakers = state['dry_sequence_breakers']
+            # Support both JSON array notation and comma-separated strings.
+            if not dry_sequence_breakers.startswith("["):
+                dry_sequence_breakers = "[" + dry_sequence_breakers + "]"
+            sequence_breaker_strings = json.loads(dry_sequence_breakers)
+            # Prefix with 'a' to get the correct encoding of the token at the end of a text.
+            sequence_breakers = {
+                self.encode(f"a{s}")[0, -1].item() for s in sequence_breaker_strings
+            }
+            settings.dry_sequence_breakers = sequence_breakers
+        settings.xtc_probability = state['xtc_probability']
+        settings.xtc_threshold = state['xtc_threshold']
         ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'], encode_special_tokens=True)
         ids = ids[:, -get_max_prompt_length(state):]
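
The sequence-breaker block above accepts either JSON array notation or a bare comma-separated list and resolves each breaker to a single token ID. A standalone sketch of that parsing step, with encode standing in for the loader's self.encode (assumed, as in the diff, to return a 2-D tensor of token IDs):

import json

def parse_sequence_breakers(raw, encode):
    # Wrap comma-separated input so both notations parse as JSON.
    if not raw.startswith("["):
        raw = "[" + raw + "]"
    strings = json.loads(raw)
    # Prefix with 'a' so each breaker is tokenized as it appears mid-text,
    # then keep only the ID of its final token.
    return {encode(f"a{s}")[0, -1].item() for s in strings}

With the loader's tokenizer this produces the same set that the diff assigns to settings.dry_sequence_breakers.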


@@ -194,6 +194,10 @@ loaders_samplers = {
     'ExLlamav2': {
         'temperature',
         'temperature_last',
+        'smoothing_factor',
+        'dynatemp_low',
+        'dynatemp_high',
+        'dynatemp_exponent',
         'top_p',
         'min_p',
         'top_k',
@@ -204,10 +208,16 @@ loaders_samplers = {
         'presence_penalty',
         'frequency_penalty',
         'repetition_penalty_range',
-        'seed',
         'mirostat_mode',
         'mirostat_tau',
         'mirostat_eta',
+        'dry_multiplier',
+        'dry_base',
+        'dry_allowed_length',
+        'dry_sequence_breakers',
+        'xtc_threshold',
+        'xtc_probability',
+        'seed',
         'ban_eos_token',
         'add_bos_token',
         'custom_token_bans',
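
loaders_samplers is the per-loader allow-list that decides which sampling parameters apply to each backend, so listing the dynatemp, DRY, and XTC keys here is what exposes them for the non-HF ExLlamav2 loader. A rough, purely illustrative sketch of how such an allow-list can be applied (filter_state and the trimmed dict are stand-ins, not code from the repository):

loaders_samplers = {
    'ExLlamav2': {'temperature', 'top_p', 'dry_multiplier', 'xtc_probability', 'seed'},
}

def filter_state(loader, state):
    # Keep only the parameters the selected loader declares support for.
    allowed = loaders_samplers.get(loader, set())
    return {k: v for k, v in state.items() if k in allowed}

print(filter_state('ExLlamav2', {'temperature': 0.7, 'penalty_alpha': 0.6}))
# {'temperature': 0.7}
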


@@ -495,7 +495,9 @@ def get_logits_processor_patch(self, **kwargs):
         sequence_breaker_strings = json.loads(dry_sequence_breakers)
         # Prefix with 'a' to get the correct encoding of the token at the end of a text.
-        sequence_breakers = {shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings}
+        sequence_breakers = {
+            shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings
+        }
         warpers.append(
             DRYLogitsProcessor(
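
The 'a' prefix used in both loaders exists because SentencePiece-style tokenizers encode a string differently at the start of a text than in the middle of one; taking the last ID of the prefixed encoding gives the token a breaker actually produces during generation. A toy, purely illustrative model of that behavior (the vocabulary and IDs are made up):

def toy_encode(text):
    # Stand-in for a SentencePiece-style tokenizer: the first piece of a text
    # carries a leading '▁' marker, so ':' at the start of a text is a
    # different token than ':' after other characters.
    pieces = ['▁' + text[0]] + list(text[1:])
    vocab = {'▁:': 10, ':': 11, '▁a': 12}
    return [vocab[p] for p in pieces]

print(toy_encode(":"))       # [10] -> start-of-text form of ':'
print(toy_encode("a:")[-1])  # 11   -> mid-text form, the ID DRY needs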