mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
Clean up the transformers loader
This commit is contained in:
parent
36c38d7561
commit
63de9eb24f
@ -2,6 +2,7 @@ import gc
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
@ -117,12 +118,17 @@ def load_tokenizer(model_name, model):
|
||||
def huggingface_loader(model_name):
|
||||
|
||||
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
||||
config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
|
||||
params = {
|
||||
'low_cpu_mem_usage': True,
|
||||
'trust_remote_code': shared.args.trust_remote_code,
|
||||
'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16
|
||||
}
|
||||
config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=params['trust_remote_code'])
|
||||
|
||||
if 'chatglm' in model_name.lower():
|
||||
LoaderClass = AutoModel
|
||||
else:
|
||||
if config.to_dict().get("is_encoder_decoder", False):
|
||||
if config.to_dict().get('is_encoder_decoder', False):
|
||||
LoaderClass = AutoModelForSeq2SeqLM
|
||||
shared.is_seq2seq = True
|
||||
else:
|
||||
@ -130,7 +136,7 @@ def huggingface_loader(model_name):
|
||||
|
||||
# Load the model in simple 16-bit mode by default
|
||||
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.compress_pos_emb > 1, shared.args.alpha_value > 1, shared.args.disable_exllama]):
|
||||
model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=shared.args.trust_remote_code)
|
||||
model = LoaderClass.from_pretrained(path_to_model, **params)
|
||||
if torch.backends.mps.is_available():
|
||||
device = torch.device('mps')
|
||||
model = model.to(device)
|
||||
@ -139,28 +145,23 @@ def huggingface_loader(model_name):
|
||||
|
||||
# DeepSpeed ZeRO-3
|
||||
elif shared.args.deepspeed:
|
||||
model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
|
||||
model = LoaderClass.from_pretrained(path_to_model, torch_dtype=params['torch_dtype'])
|
||||
model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
|
||||
model.module.eval() # Inference
|
||||
logger.info(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
|
||||
logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}')
|
||||
|
||||
# Custom
|
||||
# Load with quantization and/or offloading
|
||||
else:
|
||||
params = {
|
||||
"low_cpu_mem_usage": True,
|
||||
"trust_remote_code": shared.args.trust_remote_code
|
||||
}
|
||||
|
||||
if not any((shared.args.cpu, torch.cuda.is_available(), torch.backends.mps.is_available())):
|
||||
logger.warning("torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.")
|
||||
logger.warning('torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.')
|
||||
shared.args.cpu = True
|
||||
|
||||
if shared.args.cpu:
|
||||
params["torch_dtype"] = torch.float32
|
||||
params['torch_dtype'] = torch.float32
|
||||
else:
|
||||
params["device_map"] = 'auto'
|
||||
params['device_map'] = 'auto'
|
||||
params['max_memory'] = get_max_memory_dict()
|
||||
if shared.args.load_in_4bit:
|
||||
|
||||
# See https://github.com/huggingface/transformers/pull/23479/files
|
||||
# and https://huggingface.co/blog/4bit-transformers-bitsandbytes
|
||||
quantization_config_params = {
|
||||
@ -170,7 +171,7 @@ def huggingface_loader(model_name):
|
||||
'bnb_4bit_use_double_quant': shared.args.use_double_quant,
|
||||
}
|
||||
|
||||
logger.warning("Using the following 4-bit params: " + str(quantization_config_params))
|
||||
logger.info('Using the following 4-bit params: ' + str(quantization_config_params))
|
||||
params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params)
|
||||
|
||||
elif shared.args.load_in_8bit:
|
||||
@ -178,14 +179,21 @@ def huggingface_loader(model_name):
|
||||
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
|
||||
else:
|
||||
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
elif shared.args.bf16:
|
||||
params["torch_dtype"] = torch.bfloat16
|
||||
else:
|
||||
params["torch_dtype"] = torch.float16
|
||||
|
||||
params['max_memory'] = get_max_memory_dict()
|
||||
if params['max_memory'] is not None:
|
||||
with init_empty_weights():
|
||||
model = LoaderClass.from_config(config, trust_remote_code=params['trust_remote_code'])
|
||||
|
||||
model.tie_weights()
|
||||
params['device_map'] = infer_auto_device_map(
|
||||
model,
|
||||
dtype=torch.int8,
|
||||
max_memory=params['max_memory'],
|
||||
no_split_module_classes=model._no_split_modules
|
||||
)
|
||||
|
||||
if shared.args.disk:
|
||||
params["offload_folder"] = shared.args.disk_cache_dir
|
||||
params['offload_folder'] = shared.args.disk_cache_dir
|
||||
|
||||
if shared.args.disable_exllama:
|
||||
try:
|
||||
@ -193,20 +201,9 @@ def huggingface_loader(model_name):
|
||||
params['quantization_config'] = gptq_config
|
||||
logger.info('Loading with ExLlama kernel disabled.')
|
||||
except:
|
||||
exc = traceback.format_exc()
|
||||
logger.error('Failed to disable exllama. Does the config.json for this model contain the necessary quantization info?')
|
||||
|
||||
if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
|
||||
config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
|
||||
with init_empty_weights():
|
||||
model = LoaderClass.from_config(config, trust_remote_code=shared.args.trust_remote_code)
|
||||
|
||||
model.tie_weights()
|
||||
params['device_map'] = infer_auto_device_map(
|
||||
model,
|
||||
dtype=torch.int8,
|
||||
max_memory=params['max_memory'],
|
||||
no_split_module_classes=model._no_split_modules
|
||||
)
|
||||
print(exc)
|
||||
|
||||
if shared.args.compress_pos_emb > 1:
|
||||
params['rope_scaling'] = {'type': 'linear', 'factor': shared.args.compress_pos_emb}
|
||||
|
Loading…
Reference in New Issue
Block a user