mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
Add support for --gpu-memory witn --load-in-8bit
This commit is contained in:
parent
23a5e886e1
commit
83cb20aad8
@ -7,7 +7,8 @@ from pathlib import Path
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
|
||||||
|
from accelerate import infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
|
||||||
@ -94,39 +95,61 @@ def load_model(model_name):
|
|||||||
|
|
||||||
# Custom
|
# Custom
|
||||||
else:
|
else:
|
||||||
command = "AutoModelForCausalLM.from_pretrained"
|
params = {"low_cpu_mem_usage": True}
|
||||||
params = ["low_cpu_mem_usage=True"]
|
|
||||||
if not shared.args.cpu and not torch.cuda.is_available():
|
if not shared.args.cpu and not torch.cuda.is_available():
|
||||||
print("Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n")
|
print("Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n")
|
||||||
shared.args.cpu = True
|
shared.args.cpu = True
|
||||||
|
|
||||||
if shared.args.cpu:
|
if shared.args.cpu:
|
||||||
params.append("low_cpu_mem_usage=True")
|
params["torch_dtype"] = torch.float32
|
||||||
params.append("torch_dtype=torch.float32")
|
|
||||||
else:
|
else:
|
||||||
params.append("device_map='auto'")
|
params["device_map"] = 'auto'
|
||||||
params.append("load_in_8bit=True" if shared.args.load_in_8bit else "torch_dtype=torch.bfloat16" if shared.args.bf16 else "torch_dtype=torch.float16")
|
if shared.args.load_in_8bit:
|
||||||
|
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
|
||||||
|
elif shared.args.bf16:
|
||||||
|
params["torch_dtype"] = torch.bfloat16
|
||||||
|
else:
|
||||||
|
params["torch_dtype"] = torch.float16
|
||||||
|
|
||||||
if shared.args.gpu_memory:
|
if shared.args.gpu_memory:
|
||||||
memory_map = shared.args.gpu_memory
|
memory_map = shared.args.gpu_memory
|
||||||
max_memory = f"max_memory={{0: '{memory_map[0]}GiB'"
|
max_memory = { 0: f'{memory_map[0]}GiB' }
|
||||||
for i in range(1, len(memory_map)):
|
for i in range(1, len(memory_map)):
|
||||||
max_memory += (f", {i}: '{memory_map[i]}GiB'")
|
max_memory[i] = f'{memory_map[i]}GiB'
|
||||||
max_memory += (f", 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
|
max_memory['cpu'] = f'{shared.args.cpu_memory or 99}GiB'
|
||||||
params.append(max_memory)
|
params['max_memory'] = max_memory
|
||||||
elif not shared.args.load_in_8bit:
|
else:
|
||||||
total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024))
|
total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
|
||||||
suggestion = round((total_mem-1000)/1000)*1000
|
suggestion = round((total_mem - 1000) / 1000) * 1000
|
||||||
if total_mem-suggestion < 800:
|
if total_mem - suggestion < 800:
|
||||||
suggestion -= 1000
|
suggestion -= 1000
|
||||||
suggestion = int(round(suggestion/1000))
|
suggestion = int(round(suggestion/1000))
|
||||||
print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
|
print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
|
||||||
params.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
|
|
||||||
if shared.args.disk:
|
|
||||||
params.append(f"offload_folder='{shared.args.disk_cache_dir}'")
|
|
||||||
|
|
||||||
command = f"{command}(Path(f'models/{shared.model_name}'), {', '.join(set(params))})"
|
max_memory = {
|
||||||
model = eval(command)
|
0: f'{suggestion}GiB',
|
||||||
|
'cpu': f'{shared.args.cpu_memory or 99}GiB'
|
||||||
|
}
|
||||||
|
params['max_memory'] = max_memory
|
||||||
|
|
||||||
|
if shared.args.disk:
|
||||||
|
params["offload_folder"] = shared.args.disk_cache_dir
|
||||||
|
|
||||||
|
checkpoint = Path(f'models/{shared.model_name}')
|
||||||
|
|
||||||
|
if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
|
||||||
|
config = AutoConfig.from_pretrained(checkpoint)
|
||||||
|
with init_empty_weights():
|
||||||
|
model = AutoModelForCausalLM.from_config(config)
|
||||||
|
model.tie_weights()
|
||||||
|
params['device_map'] = infer_auto_device_map(
|
||||||
|
model,
|
||||||
|
dtype=torch.int8,
|
||||||
|
max_memory=params['max_memory'],
|
||||||
|
no_split_module_classes = model._no_split_modules
|
||||||
|
)
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
|
||||||
|
|
||||||
# Loading the tokenizer
|
# Loading the tokenizer
|
||||||
if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists():
|
if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists():
|
||||||
|
Loading…
Reference in New Issue
Block a user