From 83cb20aad85d3c35f8cc88f86183fa5320d3ec9e Mon Sep 17 00:00:00 2001 From: awoo Date: Thu, 16 Mar 2023 18:42:53 +0300 Subject: [PATCH] Add support for --gpu-memory witn --load-in-8bit --- modules/models.py | 63 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 20 deletions(-) diff --git a/modules/models.py b/modules/models.py index 2a7dca62..ea5fe757 100644 --- a/modules/models.py +++ b/modules/models.py @@ -7,7 +7,8 @@ from pathlib import Path import numpy as np import torch import transformers -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig +from accelerate import infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch import modules.shared as shared @@ -94,39 +95,61 @@ def load_model(model_name): # Custom else: - command = "AutoModelForCausalLM.from_pretrained" - params = ["low_cpu_mem_usage=True"] + params = {"low_cpu_mem_usage": True} if not shared.args.cpu and not torch.cuda.is_available(): print("Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n") shared.args.cpu = True if shared.args.cpu: - params.append("low_cpu_mem_usage=True") - params.append("torch_dtype=torch.float32") + params["torch_dtype"] = torch.float32 else: - params.append("device_map='auto'") - params.append("load_in_8bit=True" if shared.args.load_in_8bit else "torch_dtype=torch.bfloat16" if shared.args.bf16 else "torch_dtype=torch.float16") + params["device_map"] = 'auto' + if shared.args.load_in_8bit: + params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True) + elif shared.args.bf16: + params["torch_dtype"] = torch.bfloat16 + else: + params["torch_dtype"] = torch.float16 if shared.args.gpu_memory: memory_map = shared.args.gpu_memory - max_memory = f"max_memory={{0: '{memory_map[0]}GiB'" + max_memory = { 0: f'{memory_map[0]}GiB' } for i in range(1, len(memory_map)): - max_memory += (f", {i}: '{memory_map[i]}GiB'") - max_memory += (f", 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}") - params.append(max_memory) - elif not shared.args.load_in_8bit: - total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024)) - suggestion = round((total_mem-1000)/1000)*1000 - if total_mem-suggestion < 800: + max_memory[i] = f'{memory_map[i]}GiB' + max_memory['cpu'] = f'{shared.args.cpu_memory or 99}GiB' + params['max_memory'] = max_memory + else: + total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024)) + suggestion = round((total_mem - 1000) / 1000) * 1000 + if total_mem - suggestion < 800: suggestion -= 1000 suggestion = int(round(suggestion/1000)) print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m") - params.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}") - if shared.args.disk: - params.append(f"offload_folder='{shared.args.disk_cache_dir}'") + + max_memory = { + 0: f'{suggestion}GiB', + 'cpu': f'{shared.args.cpu_memory or 99}GiB' + } + params['max_memory'] = max_memory - command = f"{command}(Path(f'models/{shared.model_name}'), {', '.join(set(params))})" - model = eval(command) + if shared.args.disk: + params["offload_folder"] = shared.args.disk_cache_dir + + checkpoint = Path(f'models/{shared.model_name}') + + if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto': + config = AutoConfig.from_pretrained(checkpoint) + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + model.tie_weights() + params['device_map'] = infer_auto_device_map( + model, + dtype=torch.int8, + max_memory=params['max_memory'], + no_split_module_classes = model._no_split_modules + ) + + model = AutoModelForCausalLM.from_pretrained(checkpoint, **params) # Loading the tokenizer if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists():