Add a "Static KV cache" option for transformers

oobabooga 2025-01-04 17:52:57 -08:00
parent 3967520e71
commit 11af199aff
7 changed files with 10 additions and 1 deletion

View File

@@ -42,6 +42,7 @@ class GenerationOptions(BaseModel):
truncation_length: int = 0
max_tokens_second: int = 0
prompt_lookup_num_tokens: int = 0
+static_cache: bool = False
custom_token_bans: str = ""
sampler_priority: List[str] | str | None = Field(default=None, description="List of samplers where the first items will appear first in the stack. Example: [\"top_k\", \"temperature\", \"top_p\"].")
auto_max_new_tokens: bool = False
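
The new field is exposed through the project's OpenAI-compatible API, so a client can opt into the static cache per request. Below is a minimal sketch of such a request; the endpoint URL, port, prompt, and token limit are illustrative assumptions, not part of this commit.

    import requests

    # Hypothetical request; host and port follow the project's usual defaults,
    # and the prompt/max_tokens values are placeholders.
    response = requests.post(
        "http://127.0.0.1:5000/v1/completions",
        json={
            "prompt": "Explain the KV cache in one sentence.",
            "max_tokens": 64,
            "static_cache": True,  # the new GenerationOptions field
        },
    )
    print(response.json()["choices"][0]["text"])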

View File

@@ -183,7 +183,8 @@ def transformers_samplers():
'add_bos_token',
'skip_special_tokens',
'auto_max_new_tokens',
-'prompt_lookup_num_tokens'
+'prompt_lookup_num_tokens',
+'static_cache',
}

View File

@@ -46,6 +46,7 @@ settings = {
'max_tokens_second': 0,
'max_updates_second': 0,
'prompt_lookup_num_tokens': 0,
+'static_cache': False,
'custom_stopping_strings': '',
'custom_token_bans': '',
'auto_max_new_tokens': False,

View File

@@ -302,6 +302,9 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
if state['prompt_lookup_num_tokens'] > 0:
    generate_params['prompt_lookup_num_tokens'] = state['prompt_lookup_num_tokens']
+if state['static_cache']:
+    generate_params['cache_implementation'] = 'static'
for k in ['epsilon_cutoff', 'eta_cutoff']:
    if state[k] > 0:
        generate_params[k] = state[k] * 1e-4
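
Downstream, this option maps directly onto Hugging Face transformers: generate() accepts a cache_implementation argument, and "static" pre-allocates the key/value cache at a fixed size instead of growing it step by step, which also makes the decoding loop friendlier to torch.compile. A minimal standalone sketch, assuming a recent transformers release; the model name is an arbitrary example of an architecture that supports the static cache.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Example model only; any causal LM whose architecture supports the
    # static cache implementation (e.g. Llama-family models) works.
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    inputs = tokenizer("Static KV caches trade memory for", return_tensors="pt")

    # Equivalent of generate_params['cache_implementation'] = 'static' above.
    output_ids = model.generate(
        **inputs,
        max_new_tokens=32,
        cache_implementation="static",
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))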

View File

@@ -220,6 +220,7 @@ def list_interface_input_elements():
'custom_stopping_strings',
'skip_special_tokens',
'stream',
+'static_cache',
'tfs',
'top_a',
]

View File

@@ -83,6 +83,7 @@ def create_ui(default_preset):
shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
shared.gradio['stream'] = gr.Checkbox(value=shared.settings['stream'], label='Activate text streaming')
+shared.gradio['static_cache'] = gr.Checkbox(value=shared.settings['static_cache'], label='Static KV cache')
with gr.Column():
shared.gradio['truncation_length'] = gr.Number(precision=0, step=256, value=get_truncation_length(), label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')

View File

@@ -22,6 +22,7 @@ ban_eos_token: false
add_bos_token: true
skip_special_tokens: true
stream: true
+static_cache: false
character: Assistant
name1: You
custom_system_message: ''