From a405064cebefff79fe590a73e5bc6a3a24976847 Mon Sep 17 00:00:00 2001
From: Light
Date: Thu, 13 Apr 2023 01:48:17 +0800
Subject: [PATCH] Better dispatch.

---
 modules/GPTQ_loader.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
index 2d8b9b9e..1cd3e5cd 100644
--- a/modules/GPTQ_loader.py
+++ b/modules/GPTQ_loader.py
@@ -165,16 +165,19 @@ def load_quantized(model_name):
         model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)
 
         # accelerate offload (doesn't work properly)
-        if shared.args.gpu_memory:
-            memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
-            max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
-            max_memory = {}
-            for i in range(len(memory_map)):
-                max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
-            max_memory['cpu'] = max_cpu_memory
+        if shared.args.gpu_memory or torch.cuda.device_count() > 1:
+            if shared.args.gpu_memory:
+                memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
+                max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
+                max_memory = {}
+                for i in range(len(memory_map)):
+                    max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
+                max_memory['cpu'] = max_cpu_memory
+            else:
+                max_memory = accelerate.utils.get_balanced_memory(model)
 
             device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
-            print("Using the following device map for the 4-bit model:", device_map)
+            print("Using the following device map for the quantized model:", device_map)
             # https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
             model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)
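
For context, below is a minimal sketch (outside the diff) of the accelerate dispatch flow this patch arrives at when no --gpu-memory limits are given but more than one GPU is visible. It assumes `model` is an already-loaded quantized Llama model; only the three accelerate calls are taken from the patch itself, everything else is illustrative.

    import accelerate

    # With no explicit per-GPU limits, balance the weight budget across all visible GPUs.
    max_memory = accelerate.utils.get_balanced_memory(model)

    # Assign each module to a device, never splitting an individual decoder layer.
    device_map = accelerate.infer_auto_device_map(
        model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])

    # Place the weights and register hooks so activations cross devices during forward passes.
    model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)

When --gpu-memory is supplied, the patch instead builds `max_memory` by hand from the user-provided per-GPU values (appending "GiB" where no unit is given) and passes that dict to the same two calls.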