Update comments

This commit is contained in:
oobabooga 2023-03-20 16:40:08 -03:00 committed by GitHub
parent 7618f3fe8c
commit db4219a340
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -57,13 +57,13 @@ def load_quantized(model_name):
print(f"Could not find {pt_model}, exiting...")
exit()
# Using qwopqwop200's offload
# qwopqwop200's offload
if shared.args.gptq_pre_layer:
model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits, shared.args.gptq_pre_layer)
else:
model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits)
# Using accelerate offload (doesn't work properly)
# accelerate offload (doesn't work properly)
if shared.args.gpu_memory:
memory_map = list(map(lambda x : x.strip(), shared.args.gpu_memory))
max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
@ -76,6 +76,8 @@ def load_quantized(model_name):
print("Using the following device map for the 4-bit model:", device_map)
# https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)
# No offload
elif not shared.args.cpu:
model = model.to(torch.device('cuda:0'))