From 00a12889e9240849db1e20cb5c1d956e75a4e809 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 9 Jan 2023 16:28:04 -0300
Subject: [PATCH] Refactor model loading function

---
 server.py                                 | 21 ++++++++++-----------
 torch-dumps/place-your-pt-models-here.txt |  0
 2 files changed, 10 insertions(+), 11 deletions(-)
 delete mode 100644 torch-dumps/place-your-pt-models-here.txt

diff --git a/server.py b/server.py
index 4d349249..9c041bcb 100644
--- a/server.py
+++ b/server.py
@@ -36,15 +36,18 @@ def load_model(model_name):
     if not args.cpu and Path(f"torch-dumps/{model_name}.pt").exists():
         print("Loading in .pt format...")
         model = torch.load(Path(f"torch-dumps/{model_name}.pt"))
-    elif model_name.lower().startswith(('gpt-neo', 'opt-', 'galactica')):
-        if any(size in model_name.lower() for size in ('13b', '20b', '30b')):
-            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True)
-        else:
-            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=dtype)
+    elif model_name.lower().startswith(('gpt-neo', 'opt-', 'galactica')) and any(size in model_name.lower() for size in ('13b', '20b', '30b')):
+        model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True)
     elif model_name in ['flan-t5', 't5-large']:
-        model = T5ForConditionalGeneration.from_pretrained(Path(f"models/{model_name}"))
+        if args.cpu:
+            model = T5ForConditionalGeneration.from_pretrained(Path(f"models/{model_name}"))
+        else:
+            model = T5ForConditionalGeneration.from_pretrained(Path(f"models/{model_name}")).cuda()
     else:
-        model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=dtype)
+        if args.cpu:
+            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=dtype)
+        else:
+            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=dtype).cuda()
 
     # Loading the tokenizer
     if model_name.lower().startswith('gpt4chan') and Path(f"models/gpt-j-6B/").exists():
@@ -54,10 +57,6 @@ def load_model(model_name):
     else:
         tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{model_name}/"))
 
-    # Sending to the GPU
-    if not (args.cpu or any(size in model_name.lower() for size in ('13b', '20b', '30b'))):
-        model = model.cuda()
-
     print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
     return model, tokenizer
 
diff --git a/torch-dumps/place-your-pt-models-here.txt b/torch-dumps/place-your-pt-models-here.txt
deleted file mode 100644
index e69de29b..00000000
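
For reference, a condensed sketch of the branching that load_model ends up with after this patch: the big 8-bit branch is collapsed into one condition, and the separate "Sending to the GPU" step is replaced by per-branch .cuda() calls. The function name load_model_core and the explicit args/dtype parameters are illustrative only (in server.py, args and dtype are module-level names); the transformers calls mirror the ones in the diff above.

    from pathlib import Path

    import torch
    from transformers import AutoModelForCausalLM, T5ForConditionalGeneration

    def load_model_core(model_name, args, dtype):
        # Hypothetical standalone restatement of the refactored branching above.
        name = model_name.lower()

        # Pre-converted .pt dumps load directly (skipped when --cpu is set).
        if not args.cpu and Path(f"torch-dumps/{model_name}.pt").exists():
            return torch.load(Path(f"torch-dumps/{model_name}.pt"))

        # Large gpt-neo/opt/galactica checkpoints: 8-bit with device_map='auto',
        # so no explicit .cuda() call is needed for this branch.
        if name.startswith(('gpt-neo', 'opt-', 'galactica')) and any(s in name for s in ('13b', '20b', '30b')):
            return AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True)

        # T5 family vs. everything else; the former post-load GPU transfer is now
        # decided per branch, here folded into a single trailing .cuda().
        if model_name in ['flan-t5', 't5-large']:
            model = T5ForConditionalGeneration.from_pretrained(Path(f"models/{model_name}"))
        else:
            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=dtype)

        return model if args.cpu else model.cuda()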