diff --git a/server.py b/server.py index bf9ff219..1407b4c7 100644 --- a/server.py +++ b/server.py @@ -97,10 +97,16 @@ def load_model(model_name): settings.append("device_map='auto'") settings.append("load_in_8bit=True" if args.load_in_8bit else "torch_dtype=torch.float16") - if args.gpu_memory and args.cpu_memory: - settings.append(f"max_memory={{0: '{args.gpu_memory}GiB', 'cpu': '{args.cpu_memory}GiB'}}") - elif args.gpu_memory: - settings.append(f"max_memory={{0: '{args.gpu_memory}GiB', 'cpu': '99GiB'}}") + if args.gpu_memory: + settings.append(f"max_memory={{0: '{args.gpu_memory or '99'}GiB', 'cpu': '{args.cpu_memory or '99'}GiB'}}") + elif not args.load_in_8bit: + total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024)) + suggestion = round((total_mem-1000)/1000)*1000 + if total_mem-suggestion < 800: + suggestion -= 1000 + suggestion = int(round(suggestion/1000)) + print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m") + settings.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{args.cpu_memory or '99'}GiB'}}") if args.disk: settings.append(f"offload_folder='{args.disk_cache_dir or 'cache'}'")