Add support for --gpu-memory with --load-in-8bit

awoo 2023-03-16 18:42:53 +03:00
parent 23a5e886e1
commit 83cb20aad8
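
For context, this commit lets the two flags in the title be combined on the command line. A hedged usage sketch, assuming server.py is the webui entry point and llama-7b is a placeholder for a model directory under models/ (the memory limits are illustrative values in GiB, matching the --gpu-memory/--cpu-memory handling in the diff below):

    python server.py --model llama-7b --load-in-8bit --gpu-memory 10 --cpu-memory 32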


@@ -7,7 +7,8 @@ from pathlib import Path
 import numpy as np
 import torch
 import transformers
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
+from accelerate import infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch

 import modules.shared as shared
@@ -94,39 +95,61 @@ def load_model(model_name):
     # Custom
     else:
-        command = "AutoModelForCausalLM.from_pretrained"
-        params = ["low_cpu_mem_usage=True"]
+        params = {"low_cpu_mem_usage": True}
         if not shared.args.cpu and not torch.cuda.is_available():
             print("Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n")
             shared.args.cpu = True

         if shared.args.cpu:
-            params.append("low_cpu_mem_usage=True")
-            params.append("torch_dtype=torch.float32")
+            params["torch_dtype"] = torch.float32
         else:
-            params.append("device_map='auto'")
-            params.append("load_in_8bit=True" if shared.args.load_in_8bit else "torch_dtype=torch.bfloat16" if shared.args.bf16 else "torch_dtype=torch.float16")
+            params["device_map"] = 'auto'
+            if shared.args.load_in_8bit:
+                params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
+            elif shared.args.bf16:
+                params["torch_dtype"] = torch.bfloat16
+            else:
+                params["torch_dtype"] = torch.float16

             if shared.args.gpu_memory:
                 memory_map = shared.args.gpu_memory
-                max_memory = f"max_memory={{0: '{memory_map[0]}GiB'"
+                max_memory = { 0: f'{memory_map[0]}GiB' }
                 for i in range(1, len(memory_map)):
-                    max_memory += (f", {i}: '{memory_map[i]}GiB'")
-                max_memory += (f", 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
-                params.append(max_memory)
-            elif not shared.args.load_in_8bit:
-                total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024))
-                suggestion = round((total_mem-1000)/1000)*1000
-                if total_mem-suggestion < 800:
+                    max_memory[i] = f'{memory_map[i]}GiB'
+                max_memory['cpu'] = f'{shared.args.cpu_memory or 99}GiB'
+                params['max_memory'] = max_memory
+            else:
+                total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
+                suggestion = round((total_mem - 1000) / 1000) * 1000
+                if total_mem - suggestion < 800:
                     suggestion -= 1000
                 suggestion = int(round(suggestion/1000))
                 print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
-                params.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
-            if shared.args.disk:
-                params.append(f"offload_folder='{shared.args.disk_cache_dir}'")
-
-        command = f"{command}(Path(f'models/{shared.model_name}'), {', '.join(set(params))})"
-        model = eval(command)
+
+                max_memory = {
+                    0: f'{suggestion}GiB',
+                    'cpu': f'{shared.args.cpu_memory or 99}GiB'
+                }
+                params['max_memory'] = max_memory
+
+            if shared.args.disk:
+                params["offload_folder"] = shared.args.disk_cache_dir
+
+        checkpoint = Path(f'models/{shared.model_name}')
+
+        if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
+            config = AutoConfig.from_pretrained(checkpoint)
+            with init_empty_weights():
+                model = AutoModelForCausalLM.from_config(config)
+            model.tie_weights()
+            params['device_map'] = infer_auto_device_map(
+                model,
+                dtype=torch.int8,
+                max_memory=params['max_memory'],
+                no_split_module_classes = model._no_split_modules
+            )
+
+        model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)

     # Loading the tokenizer
     if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists():
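
Taken out of the webui plumbing, the new 8-bit path boils down to: plan a device map over empty (meta) weights with accelerate, sized for int8 and capped by the user's max_memory budget, then hand that map to from_pretrained together with a BitsAndBytesConfig. The following is a minimal standalone sketch of that sequence, not the commit's code verbatim; the checkpoint path and memory limits are placeholders:

    from pathlib import Path

    import torch
    from accelerate import infer_auto_device_map, init_empty_weights
    from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig

    checkpoint = Path('models/llama-7b')               # placeholder: any local HF checkpoint directory
    max_memory = {0: '10GiB', 'cpu': '99GiB'}          # placeholder limits, same shape as --gpu-memory/--cpu-memory

    # Instantiate the model on the meta device so the plan can be computed
    # without allocating real weights, then size each module as int8.
    config = AutoConfig.from_pretrained(checkpoint)
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)
    model.tie_weights()
    device_map = infer_auto_device_map(
        model,
        dtype=torch.int8,
        max_memory=max_memory,
        no_split_module_classes=model._no_split_modules,
    )

    # Load the real weights in 8-bit; whatever the budget pushes onto the CPU
    # stays in fp32 thanks to llm_int8_enable_fp32_cpu_offload.
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        device_map=device_map,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True),
        low_cpu_mem_usage=True,
    )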