From c0f600c8879fb28c5b84dd74bc6384619bfcf4fa Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 5 Jan 2025 05:45:12 -0800
Subject: [PATCH] Add a --torch-compile flag for transformers

---
 modules/loaders.py       | 3 ++-
 modules/models.py        | 3 +++
 modules/shared.py        | 1 +
 modules/ui.py            | 3 ++-
 modules/ui_model_menu.py | 1 +
 5 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/modules/loaders.py b/modules/loaders.py
index e1a41bb1..191126b3 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -9,12 +9,13 @@ loaders_and_params = OrderedDict({
     'Transformers': [
         'cpu_memory',
         'gpu_memory',
+        'load_in_4bit',
         'load_in_8bit',
+        'torch_compile',
         'bf16',
         'cpu',
         'disk',
         'auto_devices',
-        'load_in_4bit',
         'use_double_quant',
         'quant_type',
         'compute_dtype',
diff --git a/modules/models.py b/modules/models.py
index d906535b..cb1ba218 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -254,6 +254,9 @@ def huggingface_loader(model_name):
             print()
             model = LoaderClass.from_pretrained(path_to_model, **params)
 
+    if shared.args.torch_compile:
+        model = torch.compile(model)
+
     return model
 
 
diff --git a/modules/shared.py b/modules/shared.py
index f2ae05a6..891c0556 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -104,6 +104,7 @@ group.add_argument('--force-safetensors', action='store_true', help='Set use_saf
 group.add_argument('--no_use_fast', action='store_true', help='Set use_fast=False while loading the tokenizer (it\'s True by default). Use this if you have any problems related to use_fast.')
 group.add_argument('--use_flash_attention_2', action='store_true', help='Set use_flash_attention_2=True while loading the model.')
 group.add_argument('--use_eager_attention', action='store_true', help='Set attn_implementation= eager while loading the model.')
+group.add_argument('--torch-compile', action='store_true', help='Compile the model with torch.compile for improved performance.')
 
 # bitsandbytes 4-bit
 group = parser.add_argument_group('bitsandbytes 4-bit')
diff --git a/modules/ui.py b/modules/ui.py
index 3c75f6ca..30d4163c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -109,12 +109,13 @@ def list_model_elements():
         'disk',
         'cpu',
         'bf16',
+        'load_in_4bit',
         'load_in_8bit',
+        'torch_compile',
         'trust_remote_code',
         'no_use_fast',
         'use_flash_attention_2',
         'use_eager_attention',
-        'load_in_4bit',
         'compute_dtype',
         'quant_type',
         'use_double_quant',
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index c4bb8f01..f2814401 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -108,6 +108,7 @@ def create_ui():
                 shared.gradio['tensorcores'] = gr.Checkbox(label="tensorcores", value=shared.args.tensorcores, info='NVIDIA only: use llama-cpp-python compiled with tensor cores support. This may increase performance on newer cards.')
                 shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
                 shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
+                shared.gradio['torch_compile'] = gr.Checkbox(label="torch-compile", value=shared.args.torch_compile, info='Compile the model with torch.compile for improved performance.')
                 shared.gradio['flash_attn'] = gr.Checkbox(label="flash_attn", value=shared.args.flash_attn, info='Use flash-attention.')
                 shared.gradio['use_flash_attention_2'] = gr.Checkbox(label="use_flash_attention_2", value=shared.args.use_flash_attention_2, info='Set use_flash_attention_2=True while loading the model.')
                 shared.gradio['streaming_llm'] = gr.Checkbox(label="streaming_llm", value=shared.args.streaming_llm, info='(experimental) Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
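
For reference, a minimal standalone sketch of what the new flag does, separate from the patch itself: torch.compile() wraps the loaded model so its forward pass is JIT-compiled on first use. The model name and the local torch_compile variable below are placeholders (the latter standing in for shared.args.torch_compile); this is an illustration, not code from the repository.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "facebook/opt-125m"  # placeholder: any causal LM works, picked here for size
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    torch_compile = True  # stands in for shared.args.torch_compile
    if torch_compile:
        # Same call the patch adds right after from_pretrained() in huggingface_loader()
        model = torch.compile(model)

    inputs = tokenizer("Hello", return_tensors="pt")
    with torch.no_grad():
        # The first forward call triggers compilation (slow); later calls reuse the compiled graph
        logits = model(**inputs).logits
    print(logits.shape)

The first call after compilation pays a warm-up cost, so the benefit shows up on repeated generation with the same model rather than on a single short run.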