From 11af199aff41d0863a06d3407a22c0a13fb2709b Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 4 Jan 2025 17:52:57 -0800 Subject: [PATCH] Add a "Static KV cache" option for transformers --- extensions/openai/typing.py | 1 + modules/loaders.py | 3 ++- modules/shared.py | 1 + modules/text_generation.py | 3 +++ modules/ui.py | 1 + modules/ui_parameters.py | 1 + settings-template.yaml | 1 + 7 files changed, 10 insertions(+), 1 deletion(-) diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py index f63c1f39..dfac8e03 100644 --- a/extensions/openai/typing.py +++ b/extensions/openai/typing.py @@ -42,6 +42,7 @@ class GenerationOptions(BaseModel): truncation_length: int = 0 max_tokens_second: int = 0 prompt_lookup_num_tokens: int = 0 + static_cache: bool = False custom_token_bans: str = "" sampler_priority: List[str] | str | None = Field(default=None, description="List of samplers where the first items will appear first in the stack. Example: [\"top_k\", \"temperature\", \"top_p\"].") auto_max_new_tokens: bool = False diff --git a/modules/loaders.py b/modules/loaders.py index a4edf822..e1a41bb1 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -183,7 +183,8 @@ def transformers_samplers(): 'add_bos_token', 'skip_special_tokens', 'auto_max_new_tokens', - 'prompt_lookup_num_tokens' + 'prompt_lookup_num_tokens', + 'static_cache', } diff --git a/modules/shared.py b/modules/shared.py index cab61226..f2ae05a6 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -46,6 +46,7 @@ settings = { 'max_tokens_second': 0, 'max_updates_second': 0, 'prompt_lookup_num_tokens': 0, + 'static_cache': False, 'custom_stopping_strings': '', 'custom_token_bans': '', 'auto_max_new_tokens': False, diff --git a/modules/text_generation.py b/modules/text_generation.py index db415dce..3e9788b8 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -302,6 +302,9 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings if state['prompt_lookup_num_tokens'] > 0: generate_params['prompt_lookup_num_tokens'] = state['prompt_lookup_num_tokens'] + if state['static_cache']: + generate_params['cache_implementation'] = 'static' + for k in ['epsilon_cutoff', 'eta_cutoff']: if state[k] > 0: generate_params[k] = state[k] * 1e-4 diff --git a/modules/ui.py b/modules/ui.py index a3bf520f..3c75f6ca 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -220,6 +220,7 @@ def list_interface_input_elements(): 'custom_stopping_strings', 'skip_special_tokens', 'stream', + 'static_cache', 'tfs', 'top_a', ] diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py index 727a1528..f22f6233 100644 --- a/modules/ui_parameters.py +++ b/modules/ui_parameters.py @@ -83,6 +83,7 @@ def create_ui(default_preset): shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.') shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.') shared.gradio['stream'] = gr.Checkbox(value=shared.settings['stream'], label='Activate text streaming') + shared.gradio['static_cache'] = gr.Checkbox(value=shared.settings['static_cache'], label='Static KV cache') with gr.Column(): shared.gradio['truncation_length'] = gr.Number(precision=0, step=256, value=get_truncation_length(), label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.') diff --git a/settings-template.yaml b/settings-template.yaml index 59c76c35..d5ed47c3 100644 --- a/settings-template.yaml +++ b/settings-template.yaml @@ -22,6 +22,7 @@ ban_eos_token: false add_bos_token: true skip_special_tokens: true stream: true +static_cache: false character: Assistant name1: You custom_system_message: ''