diff --git a/modules/models.py b/modules/models.py index 64cbffe1..c89bcae4 100644 --- a/modules/models.py +++ b/modules/models.py @@ -205,6 +205,7 @@ def huggingface_loader(model_name): 'bnb_4bit_compute_dtype': eval("torch.{}".format(shared.args.compute_dtype)) if shared.args.compute_dtype in ["bfloat16", "float16", "float32"] else None, 'bnb_4bit_quant_type': shared.args.quant_type, 'bnb_4bit_use_double_quant': shared.args.use_double_quant, + 'llm_int8_enable_fp32_cpu_offload': True } params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params)