diff --git a/server.py b/server.py
index 73aedf31..4154fe44 100644
--- a/server.py
+++ b/server.py
@@ -27,7 +27,7 @@ def load_model(model_name):
     if os.path.exists(f"torch-dumps/{model_name}.pt"):
         print("Loading in .pt format...")
        model = torch.load(f"torch-dumps/{model_name}.pt").cuda()
-    elif model_name in ['gpt-neox-20b', 'opt-13b', 'OPT-13B-Erebus']:
+    elif model_name.lower().startswith(('gpt-neo', 'opt-')):
         model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", device_map='auto', load_in_8bit=True)
     elif model_name in ['gpt-j-6B']:
         model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
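
For reference, a quick standalone check (not part of the patch) that the new case-insensitive prefix test still covers all three names the old hard-coded list matched, while `gpt-j-6B` continues to fall through to the float16 branch:

```python
# Names the old list matched, plus gpt-j-6B as a control case.
for name in ['gpt-neox-20b', 'opt-13b', 'OPT-13B-Erebus', 'gpt-j-6B']:
    # The new condition from the diff: lowercase, then match either prefix.
    print(name, '->', name.lower().startswith(('gpt-neo', 'opt-')))
# gpt-neox-20b -> True
# opt-13b -> True
# OPT-13B-Erebus -> True
# gpt-j-6B -> False
```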