mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 09:19:23 +01:00
Revert walrus operator for params['max_memory'] (#5878)
This commit is contained in:
parent
c725d97368
commit
f3c9103e04
@ -179,7 +179,7 @@ def huggingface_loader(model_name):
|
|||||||
|
|
||||||
# DeepSpeed ZeRO-3
|
# DeepSpeed ZeRO-3
|
||||||
elif shared.args.deepspeed:
|
elif shared.args.deepspeed:
|
||||||
model = LoaderClass.from_pretrained(path_to_model, torch_dtype=params['torch_dtype'], trust_remote_code=params['trust_remote_code'])
|
model = LoaderClass.from_pretrained(path_to_model, torch_dtype=params['torch_dtype'], trust_remote_code=params.get('trust_remote_code'))
|
||||||
model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
|
model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
|
||||||
model.module.eval() # Inference
|
model.module.eval() # Inference
|
||||||
logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}')
|
logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}')
|
||||||
@ -194,7 +194,7 @@ def huggingface_loader(model_name):
|
|||||||
params['torch_dtype'] = torch.float32
|
params['torch_dtype'] = torch.float32
|
||||||
else:
|
else:
|
||||||
params['device_map'] = 'auto'
|
params['device_map'] = 'auto'
|
||||||
if x := get_max_memory_dict():
|
if x:= get_max_memory_dict():
|
||||||
params['max_memory'] = x
|
params['max_memory'] = x
|
||||||
|
|
||||||
if shared.args.load_in_4bit:
|
if shared.args.load_in_4bit:
|
||||||
@ -215,15 +215,15 @@ def huggingface_loader(model_name):
|
|||||||
else:
|
else:
|
||||||
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
|
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
|
||||||
|
|
||||||
if params['max_memory'] is not None:
|
if params.get('max_memory') is not None:
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model = LoaderClass.from_config(config, trust_remote_code=params['trust_remote_code'])
|
model = LoaderClass.from_config(config, trust_remote_code=params.get('trust_remote_code'))
|
||||||
|
|
||||||
model.tie_weights()
|
model.tie_weights()
|
||||||
params['device_map'] = infer_auto_device_map(
|
params['device_map'] = infer_auto_device_map(
|
||||||
model,
|
model,
|
||||||
dtype=torch.int8,
|
dtype=torch.int8,
|
||||||
max_memory=params['max_memory'],
|
max_memory=params.get('max_memory'),
|
||||||
no_split_module_classes=model._no_split_modules
|
no_split_module_classes=model._no_split_modules
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user