From 826e297b0ec40299318f1002f9165e7ac9c9c257 Mon Sep 17 00:00:00 2001
From: rohvani <3782201+rohvani@users.noreply.github.com>
Date: Thu, 9 Mar 2023 18:31:32 -0800
Subject: [PATCH] add llama-65b-4bit support & multiple pt paths

---
 modules/models.py | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/modules/models.py b/modules/models.py
index 3e6cea18..062ccb1f 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -97,19 +97,27 @@ def load_model(model_name):
         pt_model = ''
         if path_to_model.name.lower().startswith('llama-7b'):
             pt_model = 'llama-7b-4bit.pt'
-        if path_to_model.name.lower().startswith('llama-13b'):
+        elif path_to_model.name.lower().startswith('llama-13b'):
             pt_model = 'llama-13b-4bit.pt'
-        if path_to_model.name.lower().startswith('llama-30b'):
+        elif path_to_model.name.lower().startswith('llama-30b'):
             pt_model = 'llama-30b-4bit.pt'
-
-        if not Path(f"models/{pt_model}").exists():
-            print(f"Could not find models/{pt_model}, exiting...")
-            exit()
-        elif pt_model == '':
+        elif path_to_model.name.lower().startswith('llama-65b'):
+            pt_model = 'llama-65b-4bit.pt'
+        else:
             print(f"Could not find the .pt model for {model_name}, exiting...")
             exit()
 
-        model = load_quant(path_to_model, Path(f"models/{pt_model}"), 4)
+        # check root of models folder, and model path root
+        paths = [ f"{path_to_model}/{pt_model}", f"models/{pt_model}" ]
+        for path in [ Path(p) for p in paths ]:
+            if path.exists():
+                pt_path = path
+
+        if not pt_path:
+            print(f"Could not find {pt_model}, exiting...")
+            exit()
+
+        model = load_quant(path_to_model, pt_path, 4)
         model = model.to(torch.device('cuda:0'))
 
     # Custom
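
For reference, the checkpoint lookup this patch adds can be sketched as a small standalone function; the find_pt_file name, the explicit None default, and the example paths below are illustrative assumptions, not part of the patch itself.

from pathlib import Path

def find_pt_file(path_to_model: Path, pt_model: str):
    # Check the model's own directory first, then the models/ root,
    # and return the first candidate that exists (None if neither does).
    candidates = [path_to_model / pt_model, Path("models") / pt_model]
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return None

# Hypothetical usage for the newly supported 65B checkpoint:
pt_path = find_pt_file(Path("models/llama-65b"), "llama-65b-4bit.pt")
if pt_path is None:
    print("Could not find llama-65b-4bit.pt, exiting...")

Starting from an explicit None makes the "not found" branch unambiguous even when neither candidate path exists.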