mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-29 19:09:32 +01:00
Merge branch 'oobabooga:main' into ds
This commit is contained in:
commit
248ec4fa21
@ -200,7 +200,7 @@ def load_model(model_name):
|
|||||||
# Custom
|
# Custom
|
||||||
else:
|
else:
|
||||||
command = "AutoModelForCausalLM.from_pretrained"
|
command = "AutoModelForCausalLM.from_pretrained"
|
||||||
settings = []
|
settings = ["low_cpu_mem_usage=True"]
|
||||||
|
|
||||||
if args.cpu:
|
if args.cpu:
|
||||||
settings.append("low_cpu_mem_usage=True")
|
settings.append("low_cpu_mem_usage=True")
|
||||||
@ -211,7 +211,7 @@ def load_model(model_name):
|
|||||||
|
|
||||||
if args.gpu_memory:
|
if args.gpu_memory:
|
||||||
settings.append(f"max_memory={{0: '{args.gpu_memory or '99'}GiB', 'cpu': '{args.cpu_memory or '99'}GiB'}}")
|
settings.append(f"max_memory={{0: '{args.gpu_memory or '99'}GiB', 'cpu': '{args.cpu_memory or '99'}GiB'}}")
|
||||||
elif not args.load_in_8bit:
|
elif (args.gpu_memory or args.cpu_memory) and not args.load_in_8bit:
|
||||||
total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024))
|
total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024))
|
||||||
suggestion = round((total_mem-1000)/1000)*1000
|
suggestion = round((total_mem-1000)/1000)*1000
|
||||||
if total_mem-suggestion < 800:
|
if total_mem-suggestion < 800:
|
||||||
|
Loading…
Reference in New Issue
Block a user