Add various checks to model loading functions

oobabooga 2023-05-17 15:52:23 -03:00
parent abd361b3a0
commit ef10ffc6b4
2 changed files with 28 additions and 19 deletions

modules/AutoGPTQ_loader.py

@@ -13,19 +13,18 @@ def load_quantized(model_name):
-    use_safetensors = False
     # Find the model checkpoint
-    found_pts = list(path_to_model.glob("*.pt"))
-    found_safetensors = list(path_to_model.glob("*.safetensors"))
-    if len(found_safetensors) > 0:
-        if len(found_safetensors) > 1:
-            logging.warning('More than one .safetensors model has been found. The last one will be selected. It could be wrong.')
+    for ext in ['.safetensors', '.pt', '.bin']:
+        found = list(path_to_model.glob(f"*{ext}"))
+        if len(found) > 0:
+            if len(found) > 1:
+                logging.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
 
-        use_safetensors = True
-        pt_path = found_safetensors[-1]
-    elif len(found_pts) > 0:
-        if len(found_pts) > 1:
-            logging.warning('More than one .pt model has been found. The last one will be selected. It could be wrong.')
+            pt_path = found[-1]
+            break
 
-        pt_path = found_pts[-1]
+    if pt_path is None:
+        logging.error("The model could not be loaded because its checkpoint file in .bin/.pt/.safetensors format could not be located.")
+        return
 
     # Define the params for AutoGPTQForCausalLM.from_quantized
     params = {
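Read on its own, the new search loop amounts to a small helper along these lines (a minimal sketch, not part of the patch; the name find_checkpoint and the standalone signature are assumptions for illustration):

import logging
from pathlib import Path


def find_checkpoint(path_to_model: Path):
    """Return the quantized checkpoint in path_to_model, preferring
    .safetensors over .pt over .bin, or None if nothing matches."""
    for ext in ['.safetensors', '.pt', '.bin']:
        found = list(path_to_model.glob(f"*{ext}"))
        if len(found) > 0:
            if len(found) > 1:
                logging.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
            return found[-1]
    return None  # load_quantized() now logs an error and bails out in this case

Because the loop breaks on the first extension that matches, a folder containing both a .safetensors and a .pt file still picks the .safetensors one, as before; the .bin extension is newly accepted as a last-resort fallback.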

modules/models.py

@@ -40,10 +40,14 @@ if shared.args.deepspeed:
 # Some models require special treatment in various parts of the code.
 # This function detects those models
 def find_model_type(model_name):
+    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
+    if not path_to_model.exists():
+        return 'None'
+
     model_name_lower = model_name.lower()
     if 'rwkv-' in model_name_lower:
         return 'rwkv'
-    elif len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))) > 0:
+    elif len(list(path_to_model.glob('*ggml*.bin'))) > 0:
         return 'llamacpp'
     elif re.match('.*ggml.*\.bin', model_name_lower):
         return 'llamacpp'
@@ -58,7 +62,7 @@ def find_model_type(model_name):
     elif any((k in model_name_lower for k in ['gpt4chan', 'gpt-4chan'])):
         return 'gpt4chan'
     else:
-        config = AutoConfig.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), trust_remote_code=shared.args.trust_remote_code)
+        config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
         # Not a "catch all", but fairly accurate
         if config.to_dict().get("is_encoder_decoder", False):
             return 'HF_seq2seq'
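The new guard makes a missing model folder fail fast, before any globbing or AutoConfig call, and it signals the failure with the string 'None' rather than Python None, which is what load_model tests for in the next hunk. A self-contained illustration of that pattern (the directory layout and the fallback label are assumptions for this sketch):

from pathlib import Path


def detect_type(model_dir: str, model_name: str) -> str:
    """Toy version of the guard added to find_model_type()."""
    path_to_model = Path(f'{model_dir}/{model_name}')
    if not path_to_model.exists():
        return 'None'  # string sentinel, deliberately not Python None
    if list(path_to_model.glob('*ggml*.bin')):
        return 'llamacpp'
    return 'other'  # stand-in for the remaining detection branches


print(detect_type('models', 'definitely-not-downloaded'))  # prints: None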
@@ -71,11 +75,14 @@ def load_model(model_name):
     t0 = time.time()
 
     shared.model_type = find_model_type(model_name)
-    if shared.args.wbits > 0 or shared.args.autogptq:
-        if shared.args.autogptq:
-            load_func = AutoGPTQ_loader
-        else:
-            load_func = GPTQ_loader
+    if shared.model_type == 'None':
+        logging.error('The path to the model does not exist. Exiting.')
+        return None, None
+
+    if shared.args.autogptq:
+        load_func = AutoGPTQ_loader
+    elif shared.args.wbits > 0:
+        load_func = GPTQ_loader
     elif shared.model_type == 'llamacpp':
         load_func = llamacpp_loader
     elif shared.model_type == 'rwkv':
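The reordered dispatch reads as: abort when find_model_type reported a missing path, let --autogptq win over --wbits, and only then fall back to the detected model type. A stand-in sketch of that precedence (the function and the returned labels are hypothetical, not names from the patch):

def pick_loader(model_type: str, autogptq: bool = False, wbits: int = 0):
    """Mirror of the new selection order in load_model()."""
    if model_type == 'None':
        return None          # load_model() logs an error and returns (None, None)
    if autogptq:
        return 'AutoGPTQ'
    if wbits > 0:
        return 'GPTQ'
    if model_type == 'llamacpp':
        return 'llama.cpp'
    if model_type == 'rwkv':
        return 'RWKV'
    return 'transformers'    # stand-in for the remaining branches


assert pick_loader('llamacpp', autogptq=True) == 'AutoGPTQ'
assert pick_loader('None', wbits=4) is None

Loader selection for the quantized backends is otherwise unchanged; the visible difference is the early, explicit error when the model folder does not exist.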
@@ -101,6 +108,7 @@ def load_model(model_name):
 
 def load_tokenizer(model_name, model):
     tokenizer = None
+
     if shared.model_type == 'gpt4chan' and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
     elif type(model) is transformers.LlamaForCausalLM:
@@ -122,7 +130,9 @@ def load_tokenizer(model_name, model):
         except:
             pass
     else:
-        tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), trust_remote_code=shared.args.trust_remote_code)
+        path_to_model = Path(f"{shared.args.model_dir}/{model_name}/")
+        if path_to_model.exists():
+            tokenizer = AutoTokenizer.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
 
     return tokenizer
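The fallback branch of load_tokenizer gets the same treatment: the folder is checked before AutoTokenizer.from_pretrained is called, so a missing directory yields tokenizer = None instead of an exception from transformers. In isolation the guarded branch looks like this (load_fallback_tokenizer is a hypothetical name for illustration):

from pathlib import Path

from transformers import AutoTokenizer


def load_fallback_tokenizer(model_dir: str, model_name: str, trust_remote_code: bool = False):
    """Only call from_pretrained when the model folder actually exists."""
    path_to_model = Path(f"{model_dir}/{model_name}/")
    if path_to_model.exists():
        return AutoTokenizer.from_pretrained(path_to_model, trust_remote_code=trust_remote_code)
    return None  # mirrors the tokenizer = None default in load_tokenizer()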