diff --git a/modules/models.py b/modules/models.py
index 9c58b279..f551b828 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -160,8 +160,22 @@ def huggingface_loader(model_name):
     else:
         LoaderClass = AutoModelForCausalLM

+    # Determine if we should use default loading
+    should_use_default_loading = not any([
+        shared.args.cpu,
+        shared.args.load_in_8bit,
+        shared.args.load_in_4bit,
+        shared.args.auto_devices,
+        shared.args.disk,
+        shared.args.deepspeed,
+        shared.args.gpu_memory is not None,
+        shared.args.cpu_memory is not None,
+        shared.args.compress_pos_emb > 1,
+        shared.args.alpha_value > 1,
+    ])
+
     # Load the model without any special settings
-    if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.compress_pos_emb > 1, shared.args.alpha_value > 1]):
+    if should_use_default_loading:
         logger.info("TRANSFORMERS_PARAMS=")
         pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
         print()
@@ -174,8 +188,20 @@ def huggingface_loader(model_name):

     # DeepSpeed ZeRO-3
     elif shared.args.deepspeed:
-        model = LoaderClass.from_pretrained(path_to_model, torch_dtype=params['torch_dtype'], trust_remote_code=params.get('trust_remote_code'))
-        model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
+        model = LoaderClass.from_pretrained(
+            path_to_model,
+            torch_dtype=params['torch_dtype'],
+            trust_remote_code=params.get('trust_remote_code')
+        )
+
+        model = deepspeed.initialize(
+            model=model,
+            config_params=ds_config,
+            model_parameters=None,
+            optimizer=None,
+            lr_scheduler=None
+        )[0]
+
         model.module.eval()  # Inference
         logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}')

@@ -197,16 +223,15 @@ def huggingface_loader(model_name):
             # and https://huggingface.co/blog/4bit-transformers-bitsandbytes
             quantization_config_params = {
                 'load_in_4bit': True,
-                'bnb_4bit_compute_dtype': eval("torch.{}".format(shared.args.compute_dtype)) if shared.args.compute_dtype in ["bfloat16", "float16", "float32"] else None,
+                'bnb_4bit_compute_dtype': eval(f"torch.{shared.args.compute_dtype}") if shared.args.compute_dtype in ["bfloat16", "float16", "float32"] else None,
                 'bnb_4bit_quant_type': shared.args.quant_type,
                 'bnb_4bit_use_double_quant': shared.args.use_double_quant,
                 'llm_int8_enable_fp32_cpu_offload': True
             }
-
             params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params)

         elif shared.args.load_in_8bit:
-            if any((shared.args.auto_devices, shared.args.gpu_memory)):
+            if shared.args.auto_devices or shared.args.gpu_memory:
                 params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
             else:
                 params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
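
Note on the 4-bit hunk: the rewritten line swaps str.format for an f-string but still relies on eval() to turn the dtype string into a torch attribute. A minimal sketch of a drop-in alternative, assuming shared.args.compute_dtype is one of the three strings already whitelisted by the surrounding condition (the helper name is hypothetical, not part of the diff):

    import torch

    def resolve_compute_dtype(name):
        # getattr() looks the dtype up on the torch module directly,
        # e.g. "bfloat16" -> torch.bfloat16, without executing a string.
        # The guard mirrors the whitelist in the diff above.
        if name in ("bfloat16", "float16", "float32"):
            return getattr(torch, name)
        return None

Because the inline condition already restricts the input to those three names, getattr(torch, shared.args.compute_dtype) behaves identically to the eval() call while avoiding string execution.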