Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2025-01-10 20:40:32 +01:00
Make the huggingface loader more readable
parent 03b4067f31
commit c08d87b78d
@@ -160,8 +160,22 @@ def huggingface_loader(model_name):
     else:
         LoaderClass = AutoModelForCausalLM
 
+    # Determine if we should use default loading
+    should_use_default_loading = not any([
+        shared.args.cpu,
+        shared.args.load_in_8bit,
+        shared.args.load_in_4bit,
+        shared.args.auto_devices,
+        shared.args.disk,
+        shared.args.deepspeed,
+        shared.args.gpu_memory is not None,
+        shared.args.cpu_memory is not None,
+        shared.args.compress_pos_emb > 1,
+        shared.args.alpha_value > 1,
+    ])
+
     # Load the model without any special settings
-    if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.compress_pos_emb > 1, shared.args.alpha_value > 1]):
+    if should_use_default_loading:
         logger.info("TRANSFORMERS_PARAMS=")
         pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
         print()
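The hunk above is a pure readability refactor: the long inline `not any([...])` test is hoisted into a named boolean, one flag per line. A minimal standalone sketch of the same pattern, where the `args` namespace is a hypothetical stand-in for `shared.args`:

from types import SimpleNamespace

# Hypothetical stand-in for shared.args, just to make the sketch runnable
args = SimpleNamespace(cpu=False, load_in_8bit=False, gpu_memory=None)

# Before: one hard-to-scan line
if not any([args.cpu, args.load_in_8bit, args.gpu_memory is not None]):
    print("default loading")

# After: the condition gets a name, and each flag gets its own line
should_use_default_loading = not any([
    args.cpu,
    args.load_in_8bit,
    args.gpu_memory is not None,
])
if should_use_default_loading:
    print("default loading")

Behavior is identical; the trailing comma after the last list item also keeps future one-flag diffs to a single line.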
@@ -174,8 +188,20 @@ def huggingface_loader(model_name):
 
     # DeepSpeed ZeRO-3
     elif shared.args.deepspeed:
-        model = LoaderClass.from_pretrained(path_to_model, torch_dtype=params['torch_dtype'], trust_remote_code=params.get('trust_remote_code'))
-        model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
+        model = LoaderClass.from_pretrained(
+            path_to_model,
+            torch_dtype=params['torch_dtype'],
+            trust_remote_code=params.get('trust_remote_code')
+        )
+
+        model = deepspeed.initialize(
+            model=model,
+            config_params=ds_config,
+            model_parameters=None,
+            optimizer=None,
+            lr_scheduler=None
+        )[0]
+
         model.module.eval()  # Inference
         logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}')
 
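The `[0]` on the `deepspeed.initialize(...)` call picks the engine out of the tuple DeepSpeed returns (engine, optimizer, training dataloader, LR scheduler). A sketch of the same call with explicit unpacking; it assumes a working DeepSpeed install plus the `model` and `ds_config` objects from the surrounding function, so it illustrates the call shape rather than running on its own:

import deepspeed

# Same call as in the hunk above, but unpacking the return tuple
# instead of indexing [0]; ds_config is assumed to be the ZeRO-3
# config dict built elsewhere in this module.
engine, _optimizer, _dataloader, _lr_scheduler = deepspeed.initialize(
    model=model,              # the model returned by from_pretrained above
    config_params=ds_config,  # DeepSpeed config passed as a dict
    model_parameters=None,    # inference only: nothing to train
    optimizer=None,
    lr_scheduler=None,
)
engine.module.eval()  # the original HF model is reachable via engine.module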
@@ -197,16 +223,15 @@ def huggingface_loader(model_name):
         # and https://huggingface.co/blog/4bit-transformers-bitsandbytes
         quantization_config_params = {
             'load_in_4bit': True,
-            'bnb_4bit_compute_dtype': eval("torch.{}".format(shared.args.compute_dtype)) if shared.args.compute_dtype in ["bfloat16", "float16", "float32"] else None,
+            'bnb_4bit_compute_dtype': eval(f"torch.{shared.args.compute_dtype}") if shared.args.compute_dtype in ["bfloat16", "float16", "float32"] else None,
             'bnb_4bit_quant_type': shared.args.quant_type,
             'bnb_4bit_use_double_quant': shared.args.use_double_quant,
             'llm_int8_enable_fp32_cpu_offload': True
         }
 
         params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params)
 
     elif shared.args.load_in_8bit:
-        if any((shared.args.auto_devices, shared.args.gpu_memory)):
+        if shared.args.auto_devices or shared.args.gpu_memory:
             params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
         else:
             params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
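Two small cleanups land in this last hunk: the `"torch.{}".format(...)` call becomes an f-string, and `any((a, b))` becomes the short-circuiting, more readable `a or b` (equivalent inside an `if`). The f-string version still goes through `eval`, which `getattr` would avoid entirely; a minimal sketch of that alternative, with `compute_dtype` as a hypothetical stand-in for `shared.args.compute_dtype` and assumed values for the other flags:

import torch
from transformers import BitsAndBytesConfig

compute_dtype = "bfloat16"  # hypothetical stand-in for shared.args.compute_dtype

# getattr(torch, "bfloat16") is torch.bfloat16 -- no eval() needed
dtype = getattr(torch, compute_dtype) if compute_dtype in ("bfloat16", "float16", "float32") else None

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=dtype,
    bnb_4bit_quant_type="nf4",        # assumed value of shared.args.quant_type
    bnb_4bit_use_double_quant=False,  # assumed value of shared.args.use_double_quant
    llm_int8_enable_fp32_cpu_offload=True,
)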