mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 09:19:23 +01:00
Revert walrus operator for params['max_memory'] (#5878)
This commit is contained in:
parent
c725d97368
commit
f3c9103e04
@ -179,7 +179,7 @@ def huggingface_loader(model_name):
|
||||
|
||||
# DeepSpeed ZeRO-3
|
||||
elif shared.args.deepspeed:
|
||||
model = LoaderClass.from_pretrained(path_to_model, torch_dtype=params['torch_dtype'], trust_remote_code=params['trust_remote_code'])
|
||||
model = LoaderClass.from_pretrained(path_to_model, torch_dtype=params['torch_dtype'], trust_remote_code=params.get('trust_remote_code'))
|
||||
model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
|
||||
model.module.eval() # Inference
|
||||
logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}')
|
||||
@ -194,7 +194,7 @@ def huggingface_loader(model_name):
|
||||
params['torch_dtype'] = torch.float32
|
||||
else:
|
||||
params['device_map'] = 'auto'
|
||||
if x := get_max_memory_dict():
|
||||
if x:= get_max_memory_dict():
|
||||
params['max_memory'] = x
|
||||
|
||||
if shared.args.load_in_4bit:
|
||||
@ -215,15 +215,15 @@ def huggingface_loader(model_name):
|
||||
else:
|
||||
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
|
||||
if params['max_memory'] is not None:
|
||||
if params.get('max_memory') is not None:
|
||||
with init_empty_weights():
|
||||
model = LoaderClass.from_config(config, trust_remote_code=params['trust_remote_code'])
|
||||
model = LoaderClass.from_config(config, trust_remote_code=params.get('trust_remote_code'))
|
||||
|
||||
model.tie_weights()
|
||||
params['device_map'] = infer_auto_device_map(
|
||||
model,
|
||||
dtype=torch.int8,
|
||||
max_memory=params['max_memory'],
|
||||
max_memory=params.get('max_memory'),
|
||||
no_split_module_classes=model._no_split_modules
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user