Add HQQ quant loader (#4888)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
Author: Water · Committed by: GitHub
Date: 2023-12-18 19:23:16 -05:00
parent 64a57d9dc2
commit 674be9a09a
16 changed files with 79 additions and 0 deletions


@@ -305,6 +305,12 @@ List of command-line flags
|-------------|-------------|
| `--model_type MODEL_TYPE` | Model type of pre-quantized model. Currently gpt2, gptj, gptneox, falcon, llama, mpt, starcoder (gptbigcode), dollyv2, and replit are supported. |

#### HQQ
| Flag | Description |
|-------------|-------------|
| `--hqq-backend` | Backend for the HQQ loader. Valid options: PYTORCH, PYTORCH_COMPILE, ATEN. |

#### DeepSpeed
| Flag | Description |
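With the flag documented above, a minimal launch sketch (the model directory name is hypothetical; `--model` and `--loader` are pre-existing webui flags, and `hqq` is normalized to `HQQ` by `fix_loader_name` further down in this commit):

    python server.py --model Llama-2-7B-HQQ --loader hqq --hqq-backend PYTORCH_COMPILE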


@@ -155,6 +155,11 @@ loaders_and_params = OrderedDict({
        'trust_remote_code',
        'no_use_fast',
        'no_flash_attn',
    ],
    'HQQ': [
        'hqq_backend',
        'trust_remote_code',
        'no_use_fast',
    ]
})
@@ -503,6 +508,43 @@ loaders_samplers = {
        'skip_special_tokens',
        'auto_max_new_tokens',
    },
    'HQQ': {
        'temperature',
        'temperature_last',
        'top_p',
        'min_p',
        'top_k',
        'typical_p',
        'epsilon_cutoff',
        'eta_cutoff',
        'tfs',
        'top_a',
        'repetition_penalty',
        'presence_penalty',
        'frequency_penalty',
        'repetition_penalty_range',
        'encoder_repetition_penalty',
        'no_repeat_ngram_size',
        'min_length',
        'seed',
        'do_sample',
        'penalty_alpha',
        'num_beams',
        'length_penalty',
        'early_stopping',
        'mirostat_mode',
        'mirostat_tau',
        'mirostat_eta',
        'grammar_file_row',
        'grammar_string',
        'guidance_scale',
        'negative_prompt',
        'ban_eos_token',
        'custom_token_bans',
        'add_bos_token',
        'skip_special_tokens',
        'auto_max_new_tokens',
    },
}

loaders_model_types = {
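The `'HQQ'` entry added to `loaders_samplers` above lists what appears to be the full Transformers sampler set, consistent with generation running through the standard Hugging Face `generate()` path. A quick membership check, assuming the webui consumes this mapping to decide which sampler controls to expose per loader:

    from modules.loaders import loaders_samplers

    'mirostat_tau' in loaders_samplers['HQQ']  # True: mirostat is exposed for HQQ
    'temperature' in loaders_samplers['HQQ']   # True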


@@ -73,6 +73,7 @@ def load_model(model_name, loader=None):
        'ctransformers': ctransformers_loader,
        'AutoAWQ': AutoAWQ_loader,
        'QuIP#': QuipSharp_loader,
        'HQQ': HQQ_loader,
    }

    metadata = get_model_metadata(model_name)
@@ -411,6 +412,18 @@ def ExLlamav2_HF_loader(model_name):
    return Exllamav2HF.from_pretrained(model_name)
def HQQ_loader(model_name):
    from hqq.engine.hf import HQQModelForCausalLM
    from hqq.core.quantize import HQQLinear, HQQBackend

    logger.info(f"Loading HQQ model with backend: {shared.args.hqq_backend}")

    model_dir = Path(f'{shared.args.model_dir}/{model_name}')
    model = HQQModelForCausalLM.from_quantized(str(model_dir))
    HQQLinear.set_backend(getattr(HQQBackend, shared.args.hqq_backend))

    return model


def RWKV_loader(model_name):
    '''
    This loader is not currently maintained as RWKV can now be loaded
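Stripped of webui plumbing, the `HQQ_loader` added above reduces to the following standalone sketch (model path hypothetical; the `hqq` calls are exactly those used in the function):

    from pathlib import Path

    from hqq.engine.hf import HQQModelForCausalLM
    from hqq.core.quantize import HQQLinear, HQQBackend

    # Load a pre-quantized HQQ checkpoint from disk (hypothetical path)
    model_dir = Path('models/Llama-2-7B-HQQ')
    model = HQQModelForCausalLM.from_quantized(str(model_dir))

    # Select the matmul backend; PYTORCH_COMPILE is the flag's default
    HQQLinear.set_backend(HQQBackend.PYTORCH_COMPILE)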


@@ -163,6 +163,8 @@ def infer_loader(model_name, model_settings):
        loader = 'RWKV'
    elif re.match(r'.*exl2', model_name.lower()):
        loader = 'ExLlamav2_HF'
    elif re.match(r'.*-hqq', model_name.lower()):
        return 'HQQ'
    else:
        loader = 'Transformers'
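Note that loader inference is keyed off the folder name: any model directory ending in `-hqq` is routed to the new loader. A minimal check (example names hypothetical):

    import re

    re.match(r'.*-hqq', 'Llama-2-7B-hqq'.lower())  # match -> 'HQQ'
    re.match(r'.*-hqq', 'Llama-2-7B-AWQ'.lower())  # None -> falls through to the remaining branches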


@@ -144,6 +144,9 @@ parser.add_argument('--pre_layer', type=int, nargs='+', help='The number of laye
parser.add_argument('--checkpoint', type=str, help='The path to the quantized checkpoint file. If not specified, it will be automatically detected.')
parser.add_argument('--monkey-patch', action='store_true', help='Apply the monkey patch for using LoRAs with quantized models.')

# HQQ
parser.add_argument('--hqq-backend', type=str, default='PYTORCH_COMPILE', help='Backend for the HQQ loader. Valid options: PYTORCH, PYTORCH_COMPILE, ATEN.')

# DeepSpeed
parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
@@ -246,6 +249,8 @@ def fix_loader_name(name):
        return 'AutoAWQ'
    elif name in ['quip#', 'quip-sharp', 'quipsharp', 'quip_sharp']:
        return 'QuIP#'
    elif name in ['hqq']:
        return 'HQQ'


def add_extension(name, last=False):
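Taken together, the two hunks above register the backend flag and accept `hqq` as a loader alias. A sketch of the expected behavior (assumes, as the neighboring aliases suggest, that `fix_loader_name` lower-cases its argument before these branches):

    fix_loader_name('HQQ')     # -> 'HQQ'
    shared.args.hqq_backend    # -> 'PYTORCH_COMPILE' unless overridden on the command line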


@@ -91,6 +91,7 @@ def list_model_elements():
        'rope_freq_base',
        'numa',
        'logits_all',
        'hqq_backend',
    ]

    if is_torch_xpu_available():
        for i in range(torch.xpu.device_count()):


@@ -84,6 +84,7 @@ def create_ui():
    shared.gradio['transformers_info'] = gr.Markdown('load-in-4bit params:')
    shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype)
    shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type)
    shared.gradio['hqq_backend'] = gr.Dropdown(label="hqq_backend", choices=["PYTORCH", "PYTORCH_COMPILE", "ATEN"], value=shared.args.hqq_backend)
    shared.gradio['n_gpu_layers'] = gr.Slider(label="n-gpu-layers", minimum=0, maximum=128, value=shared.args.n_gpu_layers)
    shared.gradio['n_ctx'] = gr.Slider(minimum=0, maximum=shared.settings['truncation_length_max'], step=256, label="n_ctx", value=shared.args.n_ctx, info='Context length. Try lowering this if you run out of memory while loading the model.')


@@ -4,6 +4,7 @@ datasets
einops
exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64"
gradio==3.50.*
hqq==0.1.1
markdown
numpy==1.24.*
optimum==1.15.*


@@ -4,6 +4,7 @@ datasets
einops
exllamav2==0.0.11
gradio==3.50.*
hqq==0.1.1
markdown
numpy==1.24.*
optimum==1.15.*


@@ -4,6 +4,7 @@ datasets
einops
exllamav2==0.0.11
gradio==3.50.*
hqq==0.1.1
markdown
numpy==1.24.*
optimum==1.15.*


@@ -4,6 +4,7 @@ datasets
einops
exllamav2==0.0.11
gradio==3.50.*
hqq==0.1.1
markdown
numpy==1.24.*
optimum==1.15.*


@@ -4,6 +4,7 @@ datasets
einops
exllamav2==0.0.11
gradio==3.50.*
hqq==0.1.1
markdown
numpy==1.24.*
optimum==1.15.*


@@ -4,6 +4,7 @@ datasets
einops
exllamav2==0.0.11
gradio==3.50.*
hqq==0.1.1
markdown
numpy==1.24.*
optimum==1.15.*


@@ -4,6 +4,7 @@ datasets
einops
exllamav2==0.0.11
gradio==3.50.*
hqq==0.1.1
markdown
numpy==1.24.*
optimum==1.15.*


@@ -4,6 +4,7 @@ datasets
einops
exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64"
gradio==3.50.*
hqq==0.1.1
markdown
numpy==1.24.*
optimum==1.15.*


@@ -4,6 +4,7 @@ datasets
einops
exllamav2==0.0.11
gradio==3.50.*
hqq==0.1.1
markdown
numpy==1.24.*
optimum==1.15.*