Reorganize things

This commit is contained in:
oobabooga 2023-01-15 23:01:51 -03:00
parent a3bb18002a
commit b55486fa00

View File

@ -74,28 +74,21 @@ def load_model(model_name):
# Custom # Custom
else: else:
settings = ["low_cpu_mem_usage=True"] settings = ["low_cpu_mem_usage=True"]
cuda = ""
command = "AutoModelForCausalLM.from_pretrained" command = "AutoModelForCausalLM.from_pretrained"
if args.cpu: if args.cpu:
settings.append("torch_dtype=torch.float32") settings.append("torch_dtype=torch.float32")
else: else:
settings.append("device_map='auto'")
if args.max_gpu_memory is not None: if args.max_gpu_memory is not None:
settings.append(f"max_memory={{0: '{args.max_gpu_memory}GiB', 'cpu': '99GiB'}}") settings.append(f"max_memory={{0: '{args.max_gpu_memory}GiB', 'cpu': '99GiB'}}")
settings.append("device_map='auto'") if args.load_in_8bit:
settings.append("torch_dtype=torch.float16")
elif args.load_in_8bit:
settings.append("device_map='auto'")
settings.append("load_in_8bit=True") settings.append("load_in_8bit=True")
else: else:
settings.append("torch_dtype=torch.float16") settings.append("torch_dtype=torch.float16")
if args.auto_devices:
settings.append("device_map='auto'")
else:
cuda = ".cuda()"
settings = ', '.join(list(set(settings))) settings = ', '.join(list(set(settings)))
command = f"{command}(Path(f'models/{model_name}'), {settings}){cuda}" command = f"{command}(Path(f'models/{model_name}'), {settings})"
model = eval(command) model = eval(command)
# Loading the tokenizer # Loading the tokenizer