Clean up the transformers loader

Author: oobabooga
Date: 2023-09-24 20:23:05 -07:00
Parent: 36c38d7561
Commit: 63de9eb24f


@@ -2,6 +2,7 @@ import gc
 import os
 import re
 import time
+import traceback
 from pathlib import Path
 
 import torch
@@ -117,12 +118,17 @@ def load_tokenizer(model_name, model):
 
 def huggingface_loader(model_name):
     path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
-    config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
+    params = {
+        'low_cpu_mem_usage': True,
+        'trust_remote_code': shared.args.trust_remote_code,
+        'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16
+    }
+
+    config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=params['trust_remote_code'])
 
     if 'chatglm' in model_name.lower():
         LoaderClass = AutoModel
     else:
-        if config.to_dict().get("is_encoder_decoder", False):
+        if config.to_dict().get('is_encoder_decoder', False):
             LoaderClass = AutoModelForSeq2SeqLM
             shared.is_seq2seq = True
         else:
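
Note: the hunk above replaces scattered keyword arguments with a single params dict that every loading branch below reuses. A minimal standalone sketch of that pattern; the model path and flag values are placeholders, not part of the commit:

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM

model_path = 'models/my-model'  # hypothetical local checkpoint
use_bf16 = False                # stands in for shared.args.bf16

params = {
    'low_cpu_mem_usage': True,  # avoid materializing a full extra copy while loading shards
    'trust_remote_code': False,
    'torch_dtype': torch.bfloat16 if use_bf16 else torch.float16,
}

config = AutoConfig.from_pretrained(model_path, trust_remote_code=params['trust_remote_code'])
LoaderClass = AutoModelForSeq2SeqLM if config.to_dict().get('is_encoder_decoder', False) else AutoModelForCausalLM
model = LoaderClass.from_pretrained(model_path, **params)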
@@ -130,7 +136,7 @@ def huggingface_loader(model_name):
 
     # Load the model in simple 16-bit mode by default
     if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.compress_pos_emb > 1, shared.args.alpha_value > 1, shared.args.disable_exllama]):
-        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=shared.args.trust_remote_code)
+        model = LoaderClass.from_pretrained(path_to_model, **params)
         if torch.backends.mps.is_available():
             device = torch.device('mps')
             model = model.to(device)
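
Note: the default branch is a plain 16-bit load followed by moving the whole model to one device, preferring Apple's MPS backend when present. A self-contained sketch of that behaviour; the path is a placeholder and the CUDA fallback is an assumption about the unchanged lines that follow this hunk:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    'models/my-model',          # hypothetical local checkpoint
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)

# Prefer Apple's Metal backend when available, otherwise use CUDA if present.
if torch.backends.mps.is_available():
    model = model.to(torch.device('mps'))
elif torch.cuda.is_available():
    model = model.cuda()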
@@ -139,28 +145,23 @@ def huggingface_loader(model_name):
 
     # DeepSpeed ZeRO-3
     elif shared.args.deepspeed:
-        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
+        model = LoaderClass.from_pretrained(path_to_model, torch_dtype=params['torch_dtype'])
         model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
         model.module.eval()  # Inference
-        logger.info(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
+        logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}')
 
-    # Custom
+    # Load with quantization and/or offloading
     else:
-        params = {
-            "low_cpu_mem_usage": True,
-            "trust_remote_code": shared.args.trust_remote_code
-        }
-
         if not any((shared.args.cpu, torch.cuda.is_available(), torch.backends.mps.is_available())):
-            logger.warning("torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.")
+            logger.warning('torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.')
             shared.args.cpu = True
 
         if shared.args.cpu:
-            params["torch_dtype"] = torch.float32
+            params['torch_dtype'] = torch.float32
         else:
-            params["device_map"] = 'auto'
-
+            params['device_map'] = 'auto'
+            params['max_memory'] = get_max_memory_dict()
             if shared.args.load_in_4bit:
                 # See https://github.com/huggingface/transformers/pull/23479/files
                 # and https://huggingface.co/blog/4bit-transformers-bitsandbytes
                 quantization_config_params = {
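
Note: get_max_memory_dict() is the webui helper that converts the --gpu-memory/--cpu-memory flags into the per-device budget format accelerate expects. A sketch of how such a budget combines with device_map='auto'; the concrete sizes and path are assumptions:

import torch
from transformers import AutoModelForCausalLM

# Integer keys are GPU indices; 'cpu' is system RAM that overflow layers may use.
max_memory = {0: '10GiB', 'cpu': '30GiB'}   # hypothetical budgets

model = AutoModelForCausalLM.from_pretrained(
    'models/my-model',                      # hypothetical local checkpoint
    device_map='auto',                      # let accelerate place layers across the listed devices
    max_memory=max_memory,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)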
@@ -170,7 +171,7 @@ def huggingface_loader(model_name):
                     'bnb_4bit_quant_type': shared.args.quant_type,
                     'bnb_4bit_use_double_quant': shared.args.use_double_quant,
                 }
-                logger.warning("Using the following 4-bit params: " + str(quantization_config_params))
+                logger.info('Using the following 4-bit params: ' + str(quantization_config_params))
                 params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params)
 
             elif shared.args.load_in_8bit:
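
Note: the 4-bit branch above maps the webui flags onto transformers' BitsAndBytesConfig. A standalone sketch of an equivalent 4-bit load; the compute dtype, quant type, and path are illustrative assumptions:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # dtype used for matmuls on top of the 4-bit weights
    bnb_4bit_quant_type='nf4',              # 'nf4' or 'fp4'
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
)

model = AutoModelForCausalLM.from_pretrained(
    'models/my-model',                      # hypothetical local checkpoint
    device_map='auto',
    quantization_config=quant_config,
)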
@@ -178,14 +179,21 @@ def huggingface_loader(model_name):
                     params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
                 else:
                     params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
-            elif shared.args.bf16:
-                params["torch_dtype"] = torch.bfloat16
-            else:
-                params["torch_dtype"] = torch.float16
 
-            params['max_memory'] = get_max_memory_dict()
+                if params['max_memory'] is not None:
+                    with init_empty_weights():
+                        model = LoaderClass.from_config(config, trust_remote_code=params['trust_remote_code'])
+
+                    model.tie_weights()
+                    params['device_map'] = infer_auto_device_map(
+                        model,
+                        dtype=torch.int8,
+                        max_memory=params['max_memory'],
+                        no_split_module_classes=model._no_split_modules
+                    )
+
             if shared.args.disk:
-                params["offload_folder"] = shared.args.disk_cache_dir
+                params['offload_folder'] = shared.args.disk_cache_dir
 
         if shared.args.disable_exllama:
             try:
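
Note: the new block plans the device map on an empty (meta) model so that accelerate can honour the max_memory budget before any 8-bit weights are materialized. A self-contained sketch of the same technique; the path and budgets are assumptions:

import torch
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig

model_path = 'models/my-model'                             # hypothetical local checkpoint
config = AutoConfig.from_pretrained(model_path)

# Build the architecture without allocating real weight tensors,
# so planning the placement is cheap.
with init_empty_weights():
    meta_model = AutoModelForCausalLM.from_config(config)

meta_model.tie_weights()
device_map = infer_auto_device_map(
    meta_model,
    dtype=torch.int8,                                      # plan for 8-bit weight sizes
    max_memory={0: '10GiB', 'cpu': '30GiB'},               # hypothetical budgets
    no_split_module_classes=meta_model._no_split_modules,  # keep each transformer block on one device
)

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map=device_map,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)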
@@ -193,20 +201,9 @@ def huggingface_loader(model_name):
                 params['quantization_config'] = gptq_config
                 logger.info('Loading with ExLlama kernel disabled.')
             except:
+                exc = traceback.format_exc()
                 logger.error('Failed to disable exllama. Does the config.json for this model contain the necessary quantization info?')
+                print(exc)
 
-        if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
-            config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
-            with init_empty_weights():
-                model = LoaderClass.from_config(config, trust_remote_code=shared.args.trust_remote_code)
-
-            model.tie_weights()
-            params['device_map'] = infer_auto_device_map(
-                model,
-                dtype=torch.int8,
-                max_memory=params['max_memory'],
-                no_split_module_classes=model._no_split_modules
-            )
-
         if shared.args.compress_pos_emb > 1:
             params['rope_scaling'] = {'type': 'linear', 'factor': shared.args.compress_pos_emb}
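
Note: --compress_pos_emb maps onto transformers' linear RoPE scaling, which stretches the positional embedding range so the model accepts a longer context. A minimal sketch with an assumed factor and path:

from transformers import AutoModelForCausalLM

# A factor of 2 roughly doubles the usable context of a RoPE model,
# at some cost in quality (values here are assumptions).
model = AutoModelForCausalLM.from_pretrained(
    'models/my-model',
    rope_scaling={'type': 'linear', 'factor': 2.0},
)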