diff --git a/README.md b/README.md
index ad8087ee..d75121ea 100644
--- a/README.md
+++ b/README.md
@@ -305,6 +305,12 @@ List of command-line flags
 |-------------|-------------|
 | `--model_type MODEL_TYPE` | Model type of pre-quantized model. Currently gpt2, gptj, gptneox, falcon, llama, mpt, starcoder (gptbigcode), dollyv2, and replit are supported. |
 
+#### HQQ
+
+| Flag | Description |
+|-------------|-------------|
+| `--hqq-backend` | Backend for the HQQ loader. Valid options: PYTORCH, PYTORCH_COMPILE, ATEN. |
+
 #### DeepSpeed
 
 | Flag | Description |
diff --git a/modules/loaders.py b/modules/loaders.py
index 9f1c70d1..a532830b 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -155,6 +155,11 @@ loaders_and_params = OrderedDict({
         'trust_remote_code',
         'no_use_fast',
         'no_flash_attn',
+    ],
+    'HQQ': [
+        'hqq_backend',
+        'trust_remote_code',
+        'no_use_fast',
     ]
 })
 
@@ -503,6 +508,43 @@ loaders_samplers = {
         'skip_special_tokens',
         'auto_max_new_tokens',
     },
+    'HQQ': {
+        'temperature',
+        'temperature_last',
+        'top_p',
+        'min_p',
+        'top_k',
+        'typical_p',
+        'epsilon_cutoff',
+        'eta_cutoff',
+        'tfs',
+        'top_a',
+        'repetition_penalty',
+        'presence_penalty',
+        'frequency_penalty',
+        'repetition_penalty_range',
+        'encoder_repetition_penalty',
+        'no_repeat_ngram_size',
+        'min_length',
+        'seed',
+        'do_sample',
+        'penalty_alpha',
+        'num_beams',
+        'length_penalty',
+        'early_stopping',
+        'mirostat_mode',
+        'mirostat_tau',
+        'mirostat_eta',
+        'grammar_file_row',
+        'grammar_string',
+        'guidance_scale',
+        'negative_prompt',
+        'ban_eos_token',
+        'custom_token_bans',
+        'add_bos_token',
+        'skip_special_tokens',
+        'auto_max_new_tokens',
+    },
 }
 
 loaders_model_types = {
diff --git a/modules/models.py b/modules/models.py
index 49e5f818..5a23f743 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -73,6 +73,7 @@ def load_model(model_name, loader=None):
         'ctransformers': ctransformers_loader,
         'AutoAWQ': AutoAWQ_loader,
         'QuIP#': QuipSharp_loader,
+        'HQQ': HQQ_loader,
     }
 
     metadata = get_model_metadata(model_name)
@@ -411,6 +412,18 @@ def ExLlamav2_HF_loader(model_name):
     return Exllamav2HF.from_pretrained(model_name)
 
 
+def HQQ_loader(model_name):
+    from hqq.engine.hf import HQQModelForCausalLM
+    from hqq.core.quantize import HQQLinear, HQQBackend
+
+    logger.info(f"Loading HQQ model with backend: {shared.args.hqq_backend}")
+
+    model_dir = Path(f'{shared.args.model_dir}/{model_name}')
+    model = HQQModelForCausalLM.from_quantized(str(model_dir))
+    HQQLinear.set_backend(getattr(HQQBackend, shared.args.hqq_backend))
+    return model
+
+
 def RWKV_loader(model_name):
     '''
     This loader is not currently maintained as RWKV can now be loaded
diff --git a/modules/models_settings.py b/modules/models_settings.py
index 156c05d9..4e1fb1ad 100644
--- a/modules/models_settings.py
+++ b/modules/models_settings.py
@@ -163,6 +163,8 @@ def infer_loader(model_name, model_settings):
         loader = 'RWKV'
     elif re.match(r'.*exl2', model_name.lower()):
         loader = 'ExLlamav2_HF'
+    elif re.match(r'.*-hqq', model_name.lower()):
+        return 'HQQ'
     else:
         loader = 'Transformers'
 
diff --git a/modules/shared.py b/modules/shared.py
index edd74af1..5afcaebf 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -144,6 +144,9 @@ parser.add_argument('--pre_layer', type=int, nargs='+', help='The number of laye
 parser.add_argument('--checkpoint', type=str, help='The path to the quantized checkpoint file. If not specified, it will be automatically detected.')
 parser.add_argument('--monkey-patch', action='store_true', help='Apply the monkey patch for using LoRAs with quantized models.')
 
+# HQQ
+parser.add_argument('--hqq-backend', type=str, default='PYTORCH_COMPILE', help='Backend for the HQQ loader. Valid options: PYTORCH, PYTORCH_COMPILE, ATEN.')
+
 # DeepSpeed
 parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
 parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
@@ -246,6 +249,8 @@ def fix_loader_name(name):
         return 'AutoAWQ'
     elif name in ['quip#', 'quip-sharp', 'quipsharp', 'quip_sharp']:
         return 'QuIP#'
+    elif name in ['hqq']:
+        return 'HQQ'
 
 
 def add_extension(name, last=False):
diff --git a/modules/ui.py b/modules/ui.py
index 285e2fc3..aa735d24 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -91,6 +91,7 @@ def list_model_elements():
         'rope_freq_base',
         'numa',
         'logits_all',
+        'hqq_backend',
     ]
     if is_torch_xpu_available():
         for i in range(torch.xpu.device_count()):
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index 7f81ca2d..ae50f697 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -84,6 +84,7 @@ def create_ui():
                         shared.gradio['transformers_info'] = gr.Markdown('load-in-4bit params:')
                         shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype)
                         shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type)
+                        shared.gradio['hqq_backend'] = gr.Dropdown(label="hqq_backend", choices=["PYTORCH", "PYTORCH_COMPILE", "ATEN"], value=shared.args.hqq_backend)
 
                         shared.gradio['n_gpu_layers'] = gr.Slider(label="n-gpu-layers", minimum=0, maximum=128, value=shared.args.n_gpu_layers)
                         shared.gradio['n_ctx'] = gr.Slider(minimum=0, maximum=shared.settings['truncation_length_max'], step=256, label="n_ctx", value=shared.args.n_ctx, info='Context length. Try lowering this if you run out of memory while loading the model.')
diff --git a/requirements.txt b/requirements.txt
index 827e7654..12753783 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,6 +4,7 @@ datasets
 einops
 exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64"
 gradio==3.50.*
+hqq==0.1.1
 markdown
 numpy==1.24.*
 optimum==1.15.*
diff --git a/requirements_amd.txt b/requirements_amd.txt
index bd8ccbd6..62c9e889 100644
--- a/requirements_amd.txt
+++ b/requirements_amd.txt
@@ -4,6 +4,7 @@ datasets
 einops
 exllamav2==0.0.11
 gradio==3.50.*
+hqq==0.1.1
 markdown
 numpy==1.24.*
 optimum==1.15.*
diff --git a/requirements_amd_noavx2.txt b/requirements_amd_noavx2.txt
index d7e51706..1d17ca68 100644
--- a/requirements_amd_noavx2.txt
+++ b/requirements_amd_noavx2.txt
@@ -4,6 +4,7 @@ datasets
 einops
 exllamav2==0.0.11
 gradio==3.50.*
+hqq==0.1.1
 markdown
 numpy==1.24.*
 optimum==1.15.*
diff --git a/requirements_apple_intel.txt b/requirements_apple_intel.txt
index f0ed2341..55fc0d2c 100644
--- a/requirements_apple_intel.txt
+++ b/requirements_apple_intel.txt
@@ -4,6 +4,7 @@ datasets
 einops
 exllamav2==0.0.11
 gradio==3.50.*
+hqq==0.1.1
 markdown
 numpy==1.24.*
 optimum==1.15.*
diff --git a/requirements_apple_silicon.txt b/requirements_apple_silicon.txt
index 201a55a8..a161eb30 100644
--- a/requirements_apple_silicon.txt
+++ b/requirements_apple_silicon.txt
@@ -4,6 +4,7 @@ datasets
 einops
 exllamav2==0.0.11
 gradio==3.50.*
+hqq==0.1.1
 markdown
 numpy==1.24.*
 optimum==1.15.*
diff --git a/requirements_cpu_only.txt b/requirements_cpu_only.txt
index 7bd9da9e..7e71bc38 100644
--- a/requirements_cpu_only.txt
+++ b/requirements_cpu_only.txt
@@ -4,6 +4,7 @@ datasets
 einops
 exllamav2==0.0.11
 gradio==3.50.*
+hqq==0.1.1
 markdown
 numpy==1.24.*
 optimum==1.15.*
diff --git a/requirements_cpu_only_noavx2.txt b/requirements_cpu_only_noavx2.txt
index d9b73ef9..6f38369f 100644
--- a/requirements_cpu_only_noavx2.txt
+++ b/requirements_cpu_only_noavx2.txt
@@ -4,6 +4,7 @@ datasets
 einops
 exllamav2==0.0.11
 gradio==3.50.*
+hqq==0.1.1
 markdown
 numpy==1.24.*
 optimum==1.15.*
diff --git a/requirements_noavx2.txt b/requirements_noavx2.txt
index a193967d..f705d92c 100644
--- a/requirements_noavx2.txt
+++ b/requirements_noavx2.txt
@@ -4,6 +4,7 @@ datasets
 einops
 exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64"
 gradio==3.50.*
+hqq==0.1.1
 markdown
 numpy==1.24.*
 optimum==1.15.*
diff --git a/requirements_nowheels.txt b/requirements_nowheels.txt
index 4c1161f9..1270bf50 100644
--- a/requirements_nowheels.txt
+++ b/requirements_nowheels.txt
@@ -4,6 +4,7 @@ datasets
 einops
 exllamav2==0.0.11
 gradio==3.50.*
+hqq==0.1.1
 markdown
 numpy==1.24.*
 optimum==1.15.*
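
For reference, the new `HQQ_loader` amounts to the following standalone sketch. It uses only the calls that appear in the diff above; the model directory name is a made-up example, and its `-hqq` suffix is what lets the updated `infer_loader` pick the HQQ loader automatically.

```python
# Minimal sketch of what HQQ_loader does, outside the webui.
# Assumes hqq==0.1.1 is installed and that models/mistral-7b-hqq (a
# hypothetical directory) holds a model already quantized with the hqq library.
from hqq.core.quantize import HQQBackend, HQQLinear
from hqq.engine.hf import HQQModelForCausalLM

model = HQQModelForCausalLM.from_quantized("models/mistral-7b-hqq")

# Same call the loader makes: resolve the --hqq-backend string to an
# HQQBackend member and select that dequantization/forward implementation
# for the HQQLinear layers.
HQQLinear.set_backend(getattr(HQQBackend, "PYTORCH_COMPILE"))
```

Through the webui itself, the equivalent invocation would be something like `python server.py --model mistral-7b-hqq --loader hqq --hqq-backend PYTORCH_COMPILE` (again with a placeholder model name).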
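
To produce such a folder in the first place, the hqq library's own quantize-and-save flow can be used. This sketch is based on the upstream hqq 0.1.x API and is an assumption rather than part of this change; the model id, bit width, and group size are illustrative only.

```python
# Hypothetical one-off script: quantize a Transformers model with hqq and
# save it under models/ with a -hqq suffix so the webui auto-detects the loader.
from hqq.core.quantize import BaseQuantizeConfig
from hqq.engine.hf import HQQModelForCausalLM
from transformers import AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"                      # example model id (assumption)
save_dir = "models/Mistral-7B-v0.1-hqq"                     # "-hqq" suffix triggers infer_loader
quant_config = BaseQuantizeConfig(nbits=4, group_size=64)   # example quantization settings

model = HQQModelForCausalLM.from_pretrained(model_id)
model.quantize_model(quant_config=quant_config)
model.save_quantized(save_dir)

# The webui also expects tokenizer files next to the weights.
AutoTokenizer.from_pretrained(model_id).save_pretrained(save_dir)
```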