From f3c9103e04f2fca128b7ec5b78113dcfa9429807 Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 24 Apr 2024 00:09:14 -0400 Subject: [PATCH] Revert walrus operator for params['max_memory'] (#5878) --- modules/models.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/models.py b/modules/models.py index 20a65764..cf4e4634 100644 --- a/modules/models.py +++ b/modules/models.py @@ -179,7 +179,7 @@ def huggingface_loader(model_name): # DeepSpeed ZeRO-3 elif shared.args.deepspeed: - model = LoaderClass.from_pretrained(path_to_model, torch_dtype=params['torch_dtype'], trust_remote_code=params['trust_remote_code']) + model = LoaderClass.from_pretrained(path_to_model, torch_dtype=params['torch_dtype'], trust_remote_code=params.get('trust_remote_code')) model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0] model.module.eval() # Inference logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}') @@ -194,7 +194,7 @@ def huggingface_loader(model_name): params['torch_dtype'] = torch.float32 else: params['device_map'] = 'auto' - if x := get_max_memory_dict(): + if x:= get_max_memory_dict(): params['max_memory'] = x if shared.args.load_in_4bit: @@ -215,15 +215,15 @@ def huggingface_loader(model_name): else: params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True) - if params['max_memory'] is not None: + if params.get('max_memory') is not None: with init_empty_weights(): - model = LoaderClass.from_config(config, trust_remote_code=params['trust_remote_code']) + model = LoaderClass.from_config(config, trust_remote_code=params.get('trust_remote_code')) model.tie_weights() params['device_map'] = infer_auto_device_map( model, dtype=torch.int8, - max_memory=params['max_memory'], + max_memory=params.get('max_memory'), no_split_module_classes=model._no_split_modules )