mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 09:40:20 +01:00
Better dispatch.
This commit is contained in:
parent
f3591ccfa1
commit
a405064ceb
@ -165,6 +165,7 @@ def load_quantized(model_name):
|
|||||||
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)
|
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)
|
||||||
|
|
||||||
# accelerate offload (doesn't work properly)
|
# accelerate offload (doesn't work properly)
|
||||||
|
if shared.args.gpu_memory or torch.cuda.device_count() > 1:
|
||||||
if shared.args.gpu_memory:
|
if shared.args.gpu_memory:
|
||||||
memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
|
memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
|
||||||
max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
|
max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
|
||||||
@ -172,9 +173,11 @@ def load_quantized(model_name):
|
|||||||
for i in range(len(memory_map)):
|
for i in range(len(memory_map)):
|
||||||
max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
|
max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
|
||||||
max_memory['cpu'] = max_cpu_memory
|
max_memory['cpu'] = max_cpu_memory
|
||||||
|
else:
|
||||||
|
max_memory = accelerate.utils.get_balanced_memory(model)
|
||||||
|
|
||||||
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
|
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
|
||||||
print("Using the following device map for the 4-bit model:", device_map)
|
print("Using the following device map for the quantized model:", device_map)
|
||||||
# https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
|
# https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
|
||||||
model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)
|
model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user