Merge branch 'oobabooga:main' into ds

This commit is contained in:
81300 2023-02-01 20:50:51 +02:00 committed by GitHub
commit 248ec4fa21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -200,7 +200,7 @@ def load_model(model_name):
     # Custom
     else:
         command = "AutoModelForCausalLM.from_pretrained"
-        settings = []
+        settings = ["low_cpu_mem_usage=True"]
         if args.cpu:
             settings.append("low_cpu_mem_usage=True")
@@ -211,7 +211,7 @@ def load_model(model_name):
         if args.gpu_memory:
             settings.append(f"max_memory={{0: '{args.gpu_memory or '99'}GiB', 'cpu': '{args.cpu_memory or '99'}GiB'}}")
-        elif not args.load_in_8bit:
+        elif (args.gpu_memory or args.cpu_memory) and not args.load_in_8bit:
             total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024))
             suggestion = round((total_mem-1000)/1000)*1000
             if total_mem-suggestion < 800: