diff --git a/README.md b/README.md index f6b1d4f5..ba386852 100644 --- a/README.md +++ b/README.md @@ -177,7 +177,7 @@ Optionally, you can use the following command-line flags: | `--cpu` | Use the CPU to generate text.| | `--load-in-8bit` | Load the model with 8-bit precision.| | `--wbits WBITS` | GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. | -| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently only LLaMA and OPT are supported. | +| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. | | `--groupsize GROUPSIZE` | GPTQ: Group size. | | `--pre_layer PRE_LAYER` | GPTQ: The number of layers to preload. | | `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. | diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py index afb5695f..7926d0ab 100644 --- a/modules/GPTQ_loader.py +++ b/modules/GPTQ_loader.py @@ -4,15 +4,50 @@ from pathlib import Path import accelerate import torch +import transformers +from transformers import AutoConfig, AutoModelForCausalLM import modules.shared as shared sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa"))) -import llama import llama_inference_offload -import opt +from modelutils import find_layers +from quant import make_quant +def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128): + config = AutoConfig.from_pretrained(model) + def noop(*args, **kwargs): + pass + torch.nn.init.kaiming_uniform_ = noop + torch.nn.init.uniform_ = noop + torch.nn.init.normal_ = noop + + torch.set_default_dtype(torch.half) + transformers.modeling_utils._init_weights = False + torch.set_default_dtype(torch.half) + model = AutoModelForCausalLM.from_config(config) + torch.set_default_dtype(torch.float) + model = model.eval() + layers = find_layers(model) + for name in exclude_layers: + if name in layers: + del layers[name] + make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold) + + del layers + + print('Loading model ...') + if checkpoint.endswith('.safetensors'): + from safetensors.torch import load_file as safe_load + model.load_state_dict(safe_load(checkpoint)) + else: + model.load_state_dict(torch.load(checkpoint)) + model.seqlen = 2048 + print('Done.') + + return model + def load_quantized(model_name): if not shared.args.model_type: # Try to determine model type from model name @@ -20,6 +55,8 @@ def load_quantized(model_name): model_type = 'llama' elif model_name.lower().startswith(('opt', 'galactica')): model_type = 'opt' + elif model_name.lower().startswith(('gpt-j', 'pygmalion-6b')): + model_type = 'gptj' else: print("Can't determine model type from model name. Please specify it manually using --model_type " "argument") @@ -27,15 +64,12 @@ def load_quantized(model_name): else: model_type = shared.args.model_type.lower() - if model_type == 'llama': - if not shared.args.pre_layer: - load_quant = llama.load_quant - else: - load_quant = llama_inference_offload.load_quant - elif model_type == 'opt': - load_quant = opt.load_quant + if model_type == 'llama' and shared.args.pre_layer: + load_quant = llama_inference_offload.load_quant + elif model_type in ('llama', 'opt', 'gptj'): + load_quant = _load_quant else: - print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported") + print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported") exit() # Now we are going to try to locate the quantized model file. @@ -75,7 +109,8 @@ def load_quantized(model_name): if shared.args.pre_layer: model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer) else: - model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize) + threshold = False if model_type == 'gptj' else 128 + model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold) # accelerate offload (doesn't work properly) if shared.args.gpu_memory: diff --git a/modules/shared.py b/modules/shared.py index ac9d750c..5d1b42d4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -84,7 +84,7 @@ parser.add_argument('--gptq-bits', type=int, default=0, help='DEPRECATED: use -- parser.add_argument('--gptq-model-type', type=str, help='DEPRECATED: use --model_type instead.') parser.add_argument('--gptq-pre-layer', type=int, default=0, help='DEPRECATED: use --pre_layer instead.') parser.add_argument('--wbits', type=int, default=0, help='GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.') -parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently only LLaMA and OPT are supported.') +parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported.') parser.add_argument('--groupsize', type=int, default=-1, help='GPTQ: Group size.') parser.add_argument('--pre_layer', type=int, default=0, help='GPTQ: The number of layers to preload.')