Make the huggingface loader more readable

This commit is contained in:
oobabooga 2025-01-09 12:23:38 -08:00
parent 03b4067f31
commit c08d87b78d

View File

@ -160,8 +160,22 @@ def huggingface_loader(model_name):
else:
LoaderClass = AutoModelForCausalLM
# Determine if we should use default loading
should_use_default_loading = not any([
shared.args.cpu,
shared.args.load_in_8bit,
shared.args.load_in_4bit,
shared.args.auto_devices,
shared.args.disk,
shared.args.deepspeed,
shared.args.gpu_memory is not None,
shared.args.cpu_memory is not None,
shared.args.compress_pos_emb > 1,
shared.args.alpha_value > 1,
])
# Load the model without any special settings
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.compress_pos_emb > 1, shared.args.alpha_value > 1]):
if should_use_default_loading:
logger.info("TRANSFORMERS_PARAMS=")
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
print()
@ -174,8 +188,20 @@ def huggingface_loader(model_name):
# DeepSpeed ZeRO-3
elif shared.args.deepspeed:
model = LoaderClass.from_pretrained(path_to_model, torch_dtype=params['torch_dtype'], trust_remote_code=params.get('trust_remote_code'))
model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
model = LoaderClass.from_pretrained(
path_to_model,
torch_dtype=params['torch_dtype'],
trust_remote_code=params.get('trust_remote_code')
)
model = deepspeed.initialize(
model=model,
config_params=ds_config,
model_parameters=None,
optimizer=None,
lr_scheduler=None
)[0]
model.module.eval() # Inference
logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}')
@ -197,16 +223,15 @@ def huggingface_loader(model_name):
# and https://huggingface.co/blog/4bit-transformers-bitsandbytes
quantization_config_params = {
'load_in_4bit': True,
'bnb_4bit_compute_dtype': eval("torch.{}".format(shared.args.compute_dtype)) if shared.args.compute_dtype in ["bfloat16", "float16", "float32"] else None,
'bnb_4bit_compute_dtype': eval(f"torch.{shared.args.compute_dtype}") if shared.args.compute_dtype in ["bfloat16", "float16", "float32"] else None,
'bnb_4bit_quant_type': shared.args.quant_type,
'bnb_4bit_use_double_quant': shared.args.use_double_quant,
'llm_int8_enable_fp32_cpu_offload': True
}
params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params)
elif shared.args.load_in_8bit:
if any((shared.args.auto_devices, shared.args.gpu_memory)):
if shared.args.auto_devices or shared.args.gpu_memory:
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
else:
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)