mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-12 05:17:40 +01:00
Allow more granular KV cache settings (#6561)
This commit is contained in:
parent
c43ee5db11
commit
addad3c63e
@ -2,17 +2,19 @@ import traceback
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from exllamav2 import (
|
||||
ExLlamaV2,
|
||||
ExLlamaV2Cache,
|
||||
ExLlamaV2Cache_8bit,
|
||||
ExLlamaV2Cache_Q4,
|
||||
ExLlamaV2Cache_Q6,
|
||||
ExLlamaV2Cache_Q8,
|
||||
ExLlamaV2Cache_TP,
|
||||
ExLlamaV2Config,
|
||||
ExLlamaV2Tokenizer
|
||||
)
|
||||
from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator
|
||||
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
from modules.text_generation import get_max_prompt_length
|
||||
@ -57,12 +59,22 @@ class Exllamav2Model:
|
||||
model.load(split)
|
||||
|
||||
# Determine the correct cache type
|
||||
if shared.args.cache_8bit:
|
||||
kv_cache_type = 'fp16'
|
||||
if shared.args.cache_type:
|
||||
kv_cache_type = shared.args.cache_type.lower()
|
||||
|
||||
if kv_cache_type == 'fp16':
|
||||
cache_type = ExLlamaV2Cache
|
||||
elif kv_cache_type == 'fp8':
|
||||
cache_type = ExLlamaV2Cache_8bit
|
||||
elif shared.args.cache_4bit:
|
||||
elif kv_cache_type == 'q8':
|
||||
cache_type = ExLlamaV2Cache_Q8
|
||||
elif kv_cache_type == 'q6':
|
||||
cache_type = ExLlamaV2Cache_Q6
|
||||
elif kv_cache_type == 'q4':
|
||||
cache_type = ExLlamaV2Cache_Q4
|
||||
else:
|
||||
cache_type = ExLlamaV2Cache
|
||||
raise ValueError(f"Invalid cache type for ExLlamaV2: {cache_type}. Valid options are: fp16, fp8, q8, q6, q4.")
|
||||
|
||||
# Use TP if specified
|
||||
if shared.args.enable_tp:
|
||||
|
@ -4,18 +4,20 @@ from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
from exllamav2 import (
|
||||
ExLlamaV2,
|
||||
ExLlamaV2Cache,
|
||||
ExLlamaV2Cache_8bit,
|
||||
ExLlamaV2Cache_Q4,
|
||||
ExLlamaV2Cache_Q6,
|
||||
ExLlamaV2Cache_Q8,
|
||||
ExLlamaV2Cache_TP,
|
||||
ExLlamaV2Config
|
||||
)
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
|
||||
@ -45,12 +47,22 @@ class Exllamav2HF(PreTrainedModel):
|
||||
self.ex_model.load(split)
|
||||
|
||||
# Determine the correct cache type
|
||||
if shared.args.cache_8bit:
|
||||
kv_cache_type = 'fp16'
|
||||
if shared.args.cache_type:
|
||||
kv_cache_type = shared.args.cache_type.lower()
|
||||
|
||||
if kv_cache_type == 'fp16':
|
||||
cache_type = ExLlamaV2Cache
|
||||
elif kv_cache_type == 'fp8':
|
||||
cache_type = ExLlamaV2Cache_8bit
|
||||
elif shared.args.cache_4bit:
|
||||
elif kv_cache_type == 'q8':
|
||||
cache_type = ExLlamaV2Cache_Q8
|
||||
elif kv_cache_type == 'q6':
|
||||
cache_type = ExLlamaV2Cache_Q6
|
||||
elif kv_cache_type == 'q4':
|
||||
cache_type = ExLlamaV2Cache_Q4
|
||||
else:
|
||||
cache_type = ExLlamaV2Cache
|
||||
raise ValueError(f"Invalid cache type for ExLlamaV2: {cache_type}. Valid options are: fp16, fp8, q8, q6, q4.")
|
||||
|
||||
# Use TP if specified
|
||||
if shared.args.enable_tp:
|
||||
|
@ -9,6 +9,7 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
from modules import shared
|
||||
from modules.llama_cpp_python_hijack import llama_cpp_lib
|
||||
from modules.llamacpp_model import get_llamacpp_cache_type_for_string
|
||||
from modules.logging_colors import logger
|
||||
|
||||
|
||||
@ -196,12 +197,9 @@ class LlamacppHF(PreTrainedModel):
|
||||
'flash_attn': shared.args.flash_attn
|
||||
}
|
||||
|
||||
if shared.args.cache_4bit:
|
||||
params["type_k"] = 2
|
||||
params["type_v"] = 2
|
||||
elif shared.args.cache_8bit:
|
||||
params["type_k"] = 8
|
||||
params["type_v"] = 8
|
||||
if shared.args.cache_type:
|
||||
params["type_k"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)
|
||||
params["type_v"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)
|
||||
|
||||
Llama = llama_cpp_lib().Llama
|
||||
model = Llama(**params)
|
||||
|
@ -10,6 +10,35 @@ from modules.llama_cpp_python_hijack import llama_cpp_lib
|
||||
from modules.logging_colors import logger
|
||||
from modules.text_generation import get_max_prompt_length
|
||||
|
||||
llamacpp_quant_mapping = {
|
||||
'f32': 0,
|
||||
'fp16': 1,
|
||||
'q4_0': 2,
|
||||
'q4_1': 3,
|
||||
'q5_0': 6,
|
||||
'q5_1': 7,
|
||||
'q8_0': 8,
|
||||
'q8_1': 9,
|
||||
'q2_k': 10,
|
||||
'q3_k': 11,
|
||||
'q4_k': 12,
|
||||
'q5_k': 13,
|
||||
'q6_k': 14,
|
||||
'q8_k': 15,
|
||||
'iq4_nl': 20,
|
||||
'bf16': 30,
|
||||
}
|
||||
|
||||
llamacpp_valid_cache_types = {'fp16', 'q8_0', 'q4_0'}
|
||||
|
||||
|
||||
def get_llamacpp_cache_type_for_string(quant_type: str):
|
||||
quant_type = quant_type.lower()
|
||||
if quant_type in llamacpp_valid_cache_types:
|
||||
return llamacpp_quant_mapping[quant_type]
|
||||
else:
|
||||
raise ValueError(f"Invalid cache type for llama.cpp: {quant_type}. Valid options are: fp16, q8_0, q4_0.")
|
||||
|
||||
|
||||
def ban_eos_logits_processor(eos_token, input_ids, logits):
|
||||
logits[eos_token] = -float('inf')
|
||||
@ -75,12 +104,9 @@ class LlamaCppModel:
|
||||
'flash_attn': shared.args.flash_attn
|
||||
}
|
||||
|
||||
if shared.args.cache_4bit:
|
||||
params["type_k"] = 2
|
||||
params["type_v"] = 2
|
||||
elif shared.args.cache_8bit:
|
||||
params["type_k"] = 8
|
||||
params["type_v"] = 8
|
||||
if shared.args.cache_type:
|
||||
params["type_k"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)
|
||||
params["type_v"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)
|
||||
|
||||
result.model = Llama(**params)
|
||||
if cache_capacity > 0:
|
||||
|
@ -31,8 +31,7 @@ loaders_and_params = OrderedDict({
|
||||
'llama.cpp': [
|
||||
'n_ctx',
|
||||
'n_gpu_layers',
|
||||
'cache_8bit',
|
||||
'cache_4bit',
|
||||
'cache_type',
|
||||
'tensor_split',
|
||||
'n_batch',
|
||||
'threads',
|
||||
@ -54,8 +53,7 @@ loaders_and_params = OrderedDict({
|
||||
'llamacpp_HF': [
|
||||
'n_ctx',
|
||||
'n_gpu_layers',
|
||||
'cache_8bit',
|
||||
'cache_4bit',
|
||||
'cache_type',
|
||||
'tensor_split',
|
||||
'n_batch',
|
||||
'threads',
|
||||
@ -87,8 +85,7 @@ loaders_and_params = OrderedDict({
|
||||
'no_xformers',
|
||||
'no_sdpa',
|
||||
'num_experts_per_token',
|
||||
'cache_8bit',
|
||||
'cache_4bit',
|
||||
'cache_type',
|
||||
'autosplit',
|
||||
'enable_tp',
|
||||
'alpha_value',
|
||||
@ -103,8 +100,7 @@ loaders_and_params = OrderedDict({
|
||||
'no_xformers',
|
||||
'no_sdpa',
|
||||
'num_experts_per_token',
|
||||
'cache_8bit',
|
||||
'cache_4bit',
|
||||
'cache_type',
|
||||
'autosplit',
|
||||
'enable_tp',
|
||||
'alpha_value',
|
||||
|
@ -142,8 +142,6 @@ group.add_argument('--cfg-cache', action='store_true', help='ExLlamav2_HF: Creat
|
||||
group.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
|
||||
group.add_argument('--no_xformers', action='store_true', help='Force xformers to not be used.')
|
||||
group.add_argument('--no_sdpa', action='store_true', help='Force Torch SDPA to not be used.')
|
||||
group.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.')
|
||||
group.add_argument('--cache_4bit', action='store_true', help='Use Q4 cache to save VRAM.')
|
||||
group.add_argument('--num_experts_per_token', type=int, default=2, help='Number of experts to use for generation. Applies to MoE models like Mixtral.')
|
||||
group.add_argument('--enable_tp', action='store_true', help='Enable Tensor Parallelism (TP) in ExLlamaV2.')
|
||||
|
||||
@ -166,6 +164,10 @@ group.add_argument('--hqq-backend', type=str, default='PYTORCH_COMPILE', help='B
|
||||
group = parser.add_argument_group('TensorRT-LLM')
|
||||
group.add_argument('--cpp-runner', action='store_true', help='Use the ModelRunnerCpp runner, which is faster than the default ModelRunner but doesn\'t support streaming yet.')
|
||||
|
||||
# Cache
|
||||
group = parser.add_argument_group('Cache')
|
||||
group.add_argument('--cache_type', type=str, default=None, help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
|
||||
|
||||
# DeepSpeed
|
||||
group = parser.add_argument_group('DeepSpeed')
|
||||
group.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
|
||||
@ -213,6 +215,8 @@ group.add_argument('--pre_layer', type=int, nargs='+', help='DEPRECATED')
|
||||
group.add_argument('--checkpoint', type=str, help='DEPRECATED')
|
||||
group.add_argument('--monkey-patch', action='store_true', help='DEPRECATED')
|
||||
group.add_argument('--no_inject_fused_attention', action='store_true', help='DEPRECATED')
|
||||
group.add_argument('--cache_4bit', action='store_true', help='DEPRECATED')
|
||||
group.add_argument('--cache_8bit', action='store_true', help='DEPRECATED')
|
||||
group.add_argument('--chat-buttons', action='store_true', help='DEPRECATED')
|
||||
|
||||
args = parser.parse_args()
|
||||
@ -270,6 +274,59 @@ def fix_loader_name(name):
|
||||
return 'TensorRT-LLM'
|
||||
|
||||
|
||||
def transform_legacy_kv_cache_options(opts):
|
||||
# Handle both argparse.Namespace and dict here
|
||||
def get(key):
|
||||
return opts.get(key) if isinstance(opts, dict) else getattr(opts, key, None)
|
||||
|
||||
def set(key, value):
|
||||
if isinstance(opts, dict):
|
||||
opts[key] = value
|
||||
else:
|
||||
setattr(opts, key, value)
|
||||
|
||||
def del_key(key, fallback_set):
|
||||
# only remove from user dict, can't delete from argparse.Namespace
|
||||
if type(opts) is dict:
|
||||
if key in opts:
|
||||
del opts[key]
|
||||
else:
|
||||
setattr(opts, key, fallback_set)
|
||||
|
||||
# Retrieve values
|
||||
loader = get('loader')
|
||||
cache_type = get('cache_type')
|
||||
cache_8bit = get('cache_8bit')
|
||||
cache_4bit = get('cache_4bit')
|
||||
|
||||
# Determine cache type based on loader or legacy flags
|
||||
if not cache_type:
|
||||
if not loader:
|
||||
# Legacy behavior: prefer 8-bit over 4-bit to minimize breakage
|
||||
if cache_8bit:
|
||||
set('cache_type', 'fp8')
|
||||
elif cache_4bit:
|
||||
set('cache_type', 'q4')
|
||||
elif loader.lower() in ['exllamav2', 'exllamav2_hf']:
|
||||
# ExLlamaV2 loader-specific cache type
|
||||
if cache_8bit:
|
||||
set('cache_type', 'fp8')
|
||||
elif cache_4bit:
|
||||
set('cache_type', 'q4')
|
||||
elif loader.lower() in ['llama.cpp', 'llamacpp_hf']:
|
||||
# Llama.cpp loader-specific cache type
|
||||
if cache_4bit:
|
||||
set('cache_type', 'q4_0')
|
||||
elif cache_8bit:
|
||||
set('cache_type', 'q8_0')
|
||||
|
||||
# Clean up legacy keys
|
||||
del_key('cache_4bit', False)
|
||||
del_key('cache_8bit', False)
|
||||
|
||||
return opts
|
||||
|
||||
|
||||
def add_extension(name, last=False):
|
||||
if args.extensions is None:
|
||||
args.extensions = [name]
|
||||
@ -298,10 +355,14 @@ def load_user_config():
|
||||
else:
|
||||
user_config = {}
|
||||
|
||||
for model_name in user_config:
|
||||
user_config[model_name] = transform_legacy_kv_cache_options(user_config[model_name])
|
||||
|
||||
return user_config
|
||||
|
||||
|
||||
args.loader = fix_loader_name(args.loader)
|
||||
args = transform_legacy_kv_cache_options(args)
|
||||
|
||||
# Activate the multimodal extension
|
||||
if args.multimodal_pipeline is not None:
|
||||
|
@ -130,8 +130,7 @@ def list_model_elements():
|
||||
'no_xformers',
|
||||
'no_sdpa',
|
||||
'num_experts_per_token',
|
||||
'cache_8bit',
|
||||
'cache_4bit',
|
||||
'cache_type',
|
||||
'autosplit',
|
||||
'enable_tp',
|
||||
'threads',
|
||||
|
@ -118,8 +118,7 @@ def create_ui():
|
||||
shared.gradio['flash_attn'] = gr.Checkbox(label="flash_attn", value=shared.args.flash_attn, info='Use flash-attention.')
|
||||
shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices)
|
||||
shared.gradio['tensorcores'] = gr.Checkbox(label="tensorcores", value=shared.args.tensorcores, info='NVIDIA only: use llama-cpp-python compiled with tensor cores support. This may increase performance on newer cards.')
|
||||
shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit cache to save VRAM.')
|
||||
shared.gradio['cache_4bit'] = gr.Checkbox(label="cache_4bit", value=shared.args.cache_4bit, info='Use Q4 cache to save VRAM.')
|
||||
shared.gradio['cache_type'] = gr.Dropdown(label="cache_type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q6', 'q4'], value=shared.args.cache_type, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
|
||||
shared.gradio['streaming_llm'] = gr.Checkbox(label="streaming_llm", value=shared.args.streaming_llm, info='(experimental) Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
|
||||
shared.gradio['attention_sink_size'] = gr.Number(label="attention_sink_size", value=shared.args.attention_sink_size, precision=0, info='StreamingLLM: number of sink tokens. Only used if the trimmed prompt doesn\'t share a prefix with the old prompt.')
|
||||
shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='llama.cpp: Use llama-cpp-python compiled without GPU acceleration. Transformers: use PyTorch in CPU mode.')
|
||||
|
Loading…
x
Reference in New Issue
Block a user