From 3967520e71ba0ab386893d7c7e946fd621e25b06 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 4 Jan 2025 16:22:59 -0800 Subject: [PATCH] Connect XTC, DRY, smoothing_factor, and dynatemp to ExLlamaV2 loader (non-HF) --- modules/exllamav2.py | 30 +++++++++++++++++++++++++++++- modules/loaders.py | 12 +++++++++++- modules/sampler_hijack.py | 4 +++- 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/modules/exllamav2.py b/modules/exllamav2.py index 9b6da83c..0289bb21 100644 --- a/modules/exllamav2.py +++ b/modules/exllamav2.py @@ -1,8 +1,8 @@ +import json import traceback from pathlib import Path import torch - from exllamav2 import ( ExLlamaV2, ExLlamaV2Cache, @@ -15,6 +15,7 @@ from exllamav2 import ( ExLlamaV2Tokenizer ) from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator + from modules import shared from modules.logging_colors import logger from modules.text_generation import get_max_prompt_length @@ -122,6 +123,10 @@ class Exllamav2Model: settings.token_presence_penalty = state['presence_penalty'] settings.temperature = state['temperature'] + settings.smoothing_factor = state['smoothing_factor'] + settings.min_temp = state['dynatemp_low'] if state['dynamic_temperature'] else 0 + settings.max_temp = state['dynatemp_high'] if state['dynamic_temperature'] else 0 + settings.temp_exponent = state['dynatemp_exponent'] settings.top_k = state['top_k'] settings.top_p = state['top_p'] settings.top_a = state['top_a'] @@ -143,6 +148,29 @@ class Exllamav2Model: if len(to_ban) > 0: settings.disallow_tokens(self.tokenizer, to_ban) + settings.dry_allowed_length = state['dry_allowed_length'] + settings.dry_base = state['dry_base'] + settings.dry_multiplier = state['dry_multiplier'] + + # Dry sequence breakers processing + if state['dry_multiplier'] > 0 and state['dry_sequence_breakers']: + dry_sequence_breakers = state['dry_sequence_breakers'] + + # Support both JSON array notation and comma-separated strings. + if not dry_sequence_breakers.startswith("["): + dry_sequence_breakers = "[" + dry_sequence_breakers + "]" + + sequence_breaker_strings = json.loads(dry_sequence_breakers) + # Prefix with 'a' to get the correct encoding of the token at the end of a text. + sequence_breakers = { + self.encode(f"a{s}")[0, -1].item() for s in sequence_breaker_strings + } + + settings.dry_sequence_breakers = sequence_breakers + + settings.xtc_probability = state['xtc_probability'] + settings.xtc_threshold = state['xtc_threshold'] + ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'], encode_special_tokens=True) ids = ids[:, -get_max_prompt_length(state):] diff --git a/modules/loaders.py b/modules/loaders.py index 1cfdb31b..a4edf822 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -194,6 +194,10 @@ loaders_samplers = { 'ExLlamav2': { 'temperature', 'temperature_last', + 'smoothing_factor', + 'dynatemp_low', + 'dynatemp_high', + 'dynatemp_exponent', 'top_p', 'min_p', 'top_k', @@ -204,10 +208,16 @@ loaders_samplers = { 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', - 'seed', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', + 'dry_multiplier', + 'dry_base', + 'dry_allowed_length', + 'dry_sequence_breakers', + 'xtc_threshold', + 'xtc_probability', + 'seed', 'ban_eos_token', 'add_bos_token', 'custom_token_bans', diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index 62ceca8d..d202af1f 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -495,7 +495,9 @@ def get_logits_processor_patch(self, **kwargs): sequence_breaker_strings = json.loads(dry_sequence_breakers) # Prefix with 'a' to get the correct encoding of the token at the end of a text. - sequence_breakers = {shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings} + sequence_breakers = { + shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings + } warpers.append( DRYLogitsProcessor(