mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-10 20:40:32 +01:00
Make the huggingface loader more readable
This commit is contained in:
parent
03b4067f31
commit
c08d87b78d
@ -160,8 +160,22 @@ def huggingface_loader(model_name):
|
|||||||
else:
|
else:
|
||||||
LoaderClass = AutoModelForCausalLM
|
LoaderClass = AutoModelForCausalLM
|
||||||
|
|
||||||
|
# Determine if we should use default loading
|
||||||
|
should_use_default_loading = not any([
|
||||||
|
shared.args.cpu,
|
||||||
|
shared.args.load_in_8bit,
|
||||||
|
shared.args.load_in_4bit,
|
||||||
|
shared.args.auto_devices,
|
||||||
|
shared.args.disk,
|
||||||
|
shared.args.deepspeed,
|
||||||
|
shared.args.gpu_memory is not None,
|
||||||
|
shared.args.cpu_memory is not None,
|
||||||
|
shared.args.compress_pos_emb > 1,
|
||||||
|
shared.args.alpha_value > 1,
|
||||||
|
])
|
||||||
|
|
||||||
# Load the model without any special settings
|
# Load the model without any special settings
|
||||||
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.compress_pos_emb > 1, shared.args.alpha_value > 1]):
|
if should_use_default_loading:
|
||||||
logger.info("TRANSFORMERS_PARAMS=")
|
logger.info("TRANSFORMERS_PARAMS=")
|
||||||
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
|
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
|
||||||
print()
|
print()
|
||||||
@ -174,8 +188,20 @@ def huggingface_loader(model_name):
|
|||||||
|
|
||||||
# DeepSpeed ZeRO-3
|
# DeepSpeed ZeRO-3
|
||||||
elif shared.args.deepspeed:
|
elif shared.args.deepspeed:
|
||||||
model = LoaderClass.from_pretrained(path_to_model, torch_dtype=params['torch_dtype'], trust_remote_code=params.get('trust_remote_code'))
|
model = LoaderClass.from_pretrained(
|
||||||
model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
|
path_to_model,
|
||||||
|
torch_dtype=params['torch_dtype'],
|
||||||
|
trust_remote_code=params.get('trust_remote_code')
|
||||||
|
)
|
||||||
|
|
||||||
|
model = deepspeed.initialize(
|
||||||
|
model=model,
|
||||||
|
config_params=ds_config,
|
||||||
|
model_parameters=None,
|
||||||
|
optimizer=None,
|
||||||
|
lr_scheduler=None
|
||||||
|
)[0]
|
||||||
|
|
||||||
model.module.eval() # Inference
|
model.module.eval() # Inference
|
||||||
logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}')
|
logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}')
|
||||||
|
|
||||||
@ -197,16 +223,15 @@ def huggingface_loader(model_name):
|
|||||||
# and https://huggingface.co/blog/4bit-transformers-bitsandbytes
|
# and https://huggingface.co/blog/4bit-transformers-bitsandbytes
|
||||||
quantization_config_params = {
|
quantization_config_params = {
|
||||||
'load_in_4bit': True,
|
'load_in_4bit': True,
|
||||||
'bnb_4bit_compute_dtype': eval("torch.{}".format(shared.args.compute_dtype)) if shared.args.compute_dtype in ["bfloat16", "float16", "float32"] else None,
|
'bnb_4bit_compute_dtype': eval(f"torch.{shared.args.compute_dtype}") if shared.args.compute_dtype in ["bfloat16", "float16", "float32"] else None,
|
||||||
'bnb_4bit_quant_type': shared.args.quant_type,
|
'bnb_4bit_quant_type': shared.args.quant_type,
|
||||||
'bnb_4bit_use_double_quant': shared.args.use_double_quant,
|
'bnb_4bit_use_double_quant': shared.args.use_double_quant,
|
||||||
'llm_int8_enable_fp32_cpu_offload': True
|
'llm_int8_enable_fp32_cpu_offload': True
|
||||||
}
|
}
|
||||||
|
|
||||||
params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params)
|
params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params)
|
||||||
|
|
||||||
elif shared.args.load_in_8bit:
|
elif shared.args.load_in_8bit:
|
||||||
if any((shared.args.auto_devices, shared.args.gpu_memory)):
|
if shared.args.auto_devices or shared.args.gpu_memory:
|
||||||
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
|
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
|
||||||
else:
|
else:
|
||||||
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
|
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user