diff --git a/modules/models.py b/modules/models.py
index 23eab64f..63060d43 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -105,8 +105,10 @@ def load_model(model_name):
             params["torch_dtype"] = torch.float32
         else:
             params["device_map"] = 'auto'
-            if shared.args.load_in_8bit:
+            if shared.args.load_in_8bit and any((shared.args.auto_devices, shared.args.gpu_memory)):
                 params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
+            elif shared.args.load_in_8bit:
+                params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
             elif shared.args.bf16:
                 params["torch_dtype"] = torch.bfloat16
             else:
@@ -119,7 +121,7 @@ def load_model(model_name):
                     max_memory[i] = f'{memory_map[i]}GiB'
                 max_memory['cpu'] = f'{shared.args.cpu_memory or 99}GiB'
                 params['max_memory'] = max_memory
-            else:
+            elif shared.args.auto_devices:
                 total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024*1024))
                 suggestion = round((total_mem-1000) / 1000) * 1000
                 if total_mem - suggestion < 800:
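
For reference, a minimal sketch of the branch logic the first hunk produces. `BitsAndBytesConfig` is the real class from `transformers`; the `args` namespace and the `build_quantization_params` helper are hypothetical stand-ins for `shared.args` and the surrounding code in `modules/models.py`, not functions that exist in the repository:

import torch
from transformers import BitsAndBytesConfig

def build_quantization_params(args):
    # Sketch only: `args` stands in for shared.args.
    params = {"device_map": "auto"}
    if args.load_in_8bit and (args.auto_devices or args.gpu_memory):
        # Device placement may push some layers to the CPU, so those layers
        # need fp32 compute: enable llm_int8_enable_fp32_cpu_offload.
        params["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_enable_fp32_cpu_offload=True,
        )
    elif args.load_in_8bit:
        # Plain 8-bit GPU load: no CPU offload, so the extra flag is omitted.
        params["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    elif args.bf16:
        params["torch_dtype"] = torch.bfloat16
    else:
        params["torch_dtype"] = torch.float16
    return params

The second hunk narrows the fallback in the same spirit: previously the VRAM-based max_memory suggestion ran whenever --gpu-memory was absent, whereas after the patch it only runs when --auto-devices is set, so a plain --load-in-8bit launch no longer receives an implicit per-device memory cap.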