Better dispatch.

This commit is contained in:
Light 2023-04-13 01:48:17 +08:00
parent f3591ccfa1
commit a405064ceb

View File

@ -165,16 +165,19 @@ def load_quantized(model_name):
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold) model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)
# accelerate offload (doesn't work properly) # accelerate offload (doesn't work properly)
if shared.args.gpu_memory: if shared.args.gpu_memory or torch.cuda.device_count() > 1:
memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory)) if shared.args.gpu_memory:
max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB' memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
max_memory = {} max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
for i in range(len(memory_map)): max_memory = {}
max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i] for i in range(len(memory_map)):
max_memory['cpu'] = max_cpu_memory max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
max_memory['cpu'] = max_cpu_memory
else:
max_memory = accelerate.utils.get_balanced_memory(model)
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"]) device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
print("Using the following device map for the 4-bit model:", device_map) print("Using the following device map for the quantized model:", device_map)
# https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model # https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True) model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)