diff --git a/server.py b/server.py
index 9c041bcb..1aa59913 100644
--- a/server.py
+++ b/server.py
@@ -27,11 +27,6 @@ def load_model(model_name):
     print(f"Loading {model_name}...")
     t0 = time.time()
 
-    if args.cpu:
-        dtype = torch.float32
-    else:
-        dtype = torch.float16
-
     # Loading the model
     if not args.cpu and Path(f"torch-dumps/{model_name}.pt").exists():
         print("Loading in .pt format...")
@@ -45,9 +40,9 @@ def load_model(model_name):
             model = T5ForConditionalGeneration.from_pretrained(Path(f"models/{model_name}")).cuda()
         else:
             if args.cpu:
-                model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=dtype)
+                model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.float32)
             else:
-                model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=dtype).cuda()
+                model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
 
     # Loading the tokenizer
     if model_name.lower().startswith('gpt4chan') and Path(f"models/gpt-j-6B/").exists():