Make the huggingface loader more readable

This commit is contained in:
oobabooga 2025-01-09 12:23:38 -08:00
parent 03b4067f31
commit c08d87b78d

View File

@ -160,8 +160,22 @@ def huggingface_loader(model_name):
else: else:
LoaderClass = AutoModelForCausalLM LoaderClass = AutoModelForCausalLM
# Determine if we should use default loading
should_use_default_loading = not any([
shared.args.cpu,
shared.args.load_in_8bit,
shared.args.load_in_4bit,
shared.args.auto_devices,
shared.args.disk,
shared.args.deepspeed,
shared.args.gpu_memory is not None,
shared.args.cpu_memory is not None,
shared.args.compress_pos_emb > 1,
shared.args.alpha_value > 1,
])
# Load the model without any special settings # Load the model without any special settings
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.compress_pos_emb > 1, shared.args.alpha_value > 1]): if should_use_default_loading:
logger.info("TRANSFORMERS_PARAMS=") logger.info("TRANSFORMERS_PARAMS=")
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params) pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
print() print()
@ -174,8 +188,20 @@ def huggingface_loader(model_name):
# DeepSpeed ZeRO-3 # DeepSpeed ZeRO-3
elif shared.args.deepspeed: elif shared.args.deepspeed:
model = LoaderClass.from_pretrained(path_to_model, torch_dtype=params['torch_dtype'], trust_remote_code=params.get('trust_remote_code')) model = LoaderClass.from_pretrained(
model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0] path_to_model,
torch_dtype=params['torch_dtype'],
trust_remote_code=params.get('trust_remote_code')
)
model = deepspeed.initialize(
model=model,
config_params=ds_config,
model_parameters=None,
optimizer=None,
lr_scheduler=None
)[0]
model.module.eval() # Inference model.module.eval() # Inference
logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}') logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}')
@ -197,16 +223,15 @@ def huggingface_loader(model_name):
# and https://huggingface.co/blog/4bit-transformers-bitsandbytes # and https://huggingface.co/blog/4bit-transformers-bitsandbytes
quantization_config_params = { quantization_config_params = {
'load_in_4bit': True, 'load_in_4bit': True,
'bnb_4bit_compute_dtype': eval("torch.{}".format(shared.args.compute_dtype)) if shared.args.compute_dtype in ["bfloat16", "float16", "float32"] else None, 'bnb_4bit_compute_dtype': eval(f"torch.{shared.args.compute_dtype}") if shared.args.compute_dtype in ["bfloat16", "float16", "float32"] else None,
'bnb_4bit_quant_type': shared.args.quant_type, 'bnb_4bit_quant_type': shared.args.quant_type,
'bnb_4bit_use_double_quant': shared.args.use_double_quant, 'bnb_4bit_use_double_quant': shared.args.use_double_quant,
'llm_int8_enable_fp32_cpu_offload': True 'llm_int8_enable_fp32_cpu_offload': True
} }
params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params) params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params)
elif shared.args.load_in_8bit: elif shared.args.load_in_8bit:
if any((shared.args.auto_devices, shared.args.gpu_memory)): if shared.args.auto_devices or shared.args.gpu_memory:
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True) params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
else: else:
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True) params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)