Revert walrus operator for params['max_memory'] (#5878)

Colin 2024-04-24 00:09:14 -04:00 committed by GitHub
parent c725d97368
commit f3c9103e04


@@ -179,7 +179,7 @@ def huggingface_loader(model_name):
     # DeepSpeed ZeRO-3
     elif shared.args.deepspeed:
-        model = LoaderClass.from_pretrained(path_to_model, torch_dtype=params['torch_dtype'], trust_remote_code=params['trust_remote_code'])
+        model = LoaderClass.from_pretrained(path_to_model, torch_dtype=params['torch_dtype'], trust_remote_code=params.get('trust_remote_code'))
         model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
         model.module.eval()  # Inference
         logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}')
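The hunk above swaps direct indexing for dict.get() when reading trust_remote_code. A minimal sketch of why that matters, assuming params only conditionally contains the key (that assumption is mine, not shown in this diff):

# Minimal sketch, not from this commit: assume 'params' only gains a
# 'trust_remote_code' entry when the corresponding flag is enabled.
params = {'low_cpu_mem_usage': True}

try:
    params['trust_remote_code']            # direct indexing
except KeyError as err:
    print('KeyError:', err)                # raised when the key was never set

print(params.get('trust_remote_code'))     # -> None, no exception

Passing None here should leave the library's default behavior in effect, whereas the KeyError aborted model loading outright.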
@@ -194,7 +194,7 @@ def huggingface_loader(model_name):
             params['torch_dtype'] = torch.float32
         else:
             params['device_map'] = 'auto'
-            if x := get_max_memory_dict():
+            if x:= get_max_memory_dict():
                 params['max_memory'] = x

             if shared.args.load_in_4bit:
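The walrus line itself survives this commit (only its spacing is touched); it assigns and tests in one expression, so params['max_memory'] is set only when the helper returns a non-empty dict. A minimal sketch with a stand-in helper (not the real get_max_memory_dict):

# Minimal sketch, not from this commit: the stand-in helper returns an
# empty dict, as when no explicit memory limits are configured.
def get_max_memory_dict():
    return {}

params = {'device_map': 'auto'}
if x := get_max_memory_dict():    # {} is falsy, so the body never runs
    params['max_memory'] = x

print('max_memory' in params)     # -> False
# Later code doing params['max_memory'] would raise KeyError here;
# params.get('max_memory') returns None instead, which is what the
# surrounding .get() changes rely on.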
@@ -215,15 +215,15 @@ def huggingface_loader(model_name):
                 else:
                     params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)

-                if params['max_memory'] is not None:
+                if params.get('max_memory') is not None:
                     with init_empty_weights():
-                        model = LoaderClass.from_config(config, trust_remote_code=params['trust_remote_code'])
+                        model = LoaderClass.from_config(config, trust_remote_code=params.get('trust_remote_code'))

                     model.tie_weights()
                     params['device_map'] = infer_auto_device_map(
                         model,
                         dtype=torch.int8,
-                        max_memory=params['max_memory'],
+                        max_memory=params.get('max_memory'),
                         no_split_module_classes=model._no_split_modules
                     )
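Taken together, the 8-bit path now tolerates a missing 'max_memory' key: the guard reads it with .get(), so an absent key behaves like an explicit None and the device-map re-inference is simply skipped. A condensed sketch of that control flow (assumed shape, not the commit's code):

# Condensed sketch, not from this commit: the guarded block with and
# without a 'max_memory' entry present.
def plan_device_map(params):
    if params.get('max_memory') is not None:
        return f"infer device map with limits {params['max_memory']}"
    return "keep device_map='auto' and let the library allocate"

print(plan_device_map({'device_map': 'auto'}))                              # key absent
print(plan_device_map({'device_map': 'auto', 'max_memory': {0: '10GiB'}}))  # key present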