From ef10ffc6b4188eea6bef1f59264e5264bdf55781 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 17 May 2023 15:52:23 -0300 Subject: [PATCH] Add various checks to model loading functions --- modules/AutoGPTQ_loader.py | 21 ++++++++++----------- modules/models.py | 26 ++++++++++++++++++-------- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/modules/AutoGPTQ_loader.py b/modules/AutoGPTQ_loader.py index ac0a04dc..adbee7eb 100644 --- a/modules/AutoGPTQ_loader.py +++ b/modules/AutoGPTQ_loader.py @@ -13,19 +13,18 @@ def load_quantized(model_name): use_safetensors = False # Find the model checkpoint - found_pts = list(path_to_model.glob("*.pt")) - found_safetensors = list(path_to_model.glob("*.safetensors")) - if len(found_safetensors) > 0: - if len(found_safetensors) > 1: - logging.warning('More than one .safetensors model has been found. The last one will be selected. It could be wrong.') + for ext in ['.safetensors', '.pt', '.bin']: + found = list(path_to_model.glob(f"*{ext}")) + if len(found) > 0: + if len(found) > 1: + logging.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.') - use_safetensors = True - pt_path = found_safetensors[-1] - elif len(found_pts) > 0: - if len(found_pts) > 1: - logging.warning('More than one .pt model has been found. The last one will be selected. It could be wrong.') + pt_path = found[-1] + break - pt_path = found_pts[-1] + if pt_path is None: + logging.error("The model could not be loaded because its checkpoint file in .bin/.pt/.safetensors format could not be located.") + return # Define the params for AutoGPTQForCausalLM.from_quantized params = { diff --git a/modules/models.py b/modules/models.py index ded61b77..a04a7ec8 100644 --- a/modules/models.py +++ b/modules/models.py @@ -40,10 +40,14 @@ if shared.args.deepspeed: # Some models require special treatment in various parts of the code. # This function detects those models def find_model_type(model_name): + path_to_model = Path(f'{shared.args.model_dir}/{model_name}') + if not path_to_model.exists(): + return 'None' + model_name_lower = model_name.lower() if 'rwkv-' in model_name_lower: return 'rwkv' - elif len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))) > 0: + elif len(list(path_to_model.glob('*ggml*.bin'))) > 0: return 'llamacpp' elif re.match('.*ggml.*\.bin', model_name_lower): return 'llamacpp' @@ -58,7 +62,7 @@ def find_model_type(model_name): elif any((k in model_name_lower for k in ['gpt4chan', 'gpt-4chan'])): return 'gpt4chan' else: - config = AutoConfig.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), trust_remote_code=shared.args.trust_remote_code) + config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code) # Not a "catch all", but fairly accurate if config.to_dict().get("is_encoder_decoder", False): return 'HF_seq2seq' @@ -71,11 +75,14 @@ def load_model(model_name): t0 = time.time() shared.model_type = find_model_type(model_name) - if shared.args.wbits > 0 or shared.args.autogptq: - if shared.args.autogptq: - load_func = AutoGPTQ_loader - else: - load_func = GPTQ_loader + if shared.model_type == 'None': + logging.error('The path to the model does not exist. Exiting.') + return None, None + + if shared.args.autogptq: + load_func = AutoGPTQ_loader + elif shared.args.wbits > 0: + load_func = GPTQ_loader elif shared.model_type == 'llamacpp': load_func = llamacpp_loader elif shared.model_type == 'rwkv': @@ -101,6 +108,7 @@ def load_model(model_name): def load_tokenizer(model_name, model): + tokenizer = None if shared.model_type == 'gpt4chan' and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists(): tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/")) elif type(model) is transformers.LlamaForCausalLM: @@ -122,7 +130,9 @@ def load_tokenizer(model_name, model): except: pass else: - tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), trust_remote_code=shared.args.trust_remote_code) + path_to_model = Path(f"{shared.args.model_dir}/{model_name}/") + if path_to_model.exists(): + tokenizer = AutoTokenizer.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code) return tokenizer