From c8207d474f9c5365ab5a1c269eb71bff05a31988 Mon Sep 17 00:00:00 2001
From: Maya Eary
Date: Tue, 28 Mar 2023 20:38:55 +0300
Subject: [PATCH 1/5] Generalized load_quantized

---
 modules/GPTQ_loader.py | 54 ++++++++++++++++++++++++++++++++++--------
 1 file changed, 44 insertions(+), 10 deletions(-)

diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
index afb5695f..351d658d 100644
--- a/modules/GPTQ_loader.py
+++ b/modules/GPTQ_loader.py
@@ -4,13 +4,48 @@ from pathlib import Path
 
 import accelerate
 import torch
+import transformers
+from transformers import AutoConfig, AutoModelForCausalLM
 
 import modules.shared as shared
 
 sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
-import llama
 import llama_inference_offload
-import opt
+from quant import make_quant
+from modelutils import find_layers
+
+def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head']):
+    config = AutoConfig.from_pretrained(model)
+    def noop(*args, **kwargs):
+        pass
+    torch.nn.init.kaiming_uniform_ = noop
+    torch.nn.init.uniform_ = noop
+    torch.nn.init.normal_ = noop
+
+    torch.set_default_dtype(torch.half)
+    transformers.modeling_utils._init_weights = False
+    torch.set_default_dtype(torch.half)
+    model = AutoModelForCausalLM.from_config(config)
+    torch.set_default_dtype(torch.float)
+    model = model.eval()
+    layers = find_layers(model)
+    for name in exclude_layers:
+        if name in layers:
+            del layers[name]
+    make_quant(model, layers, wbits, groupsize, faster=faster_kernel)
+
+    del layers
+
+    print('Loading model ...')
+    if checkpoint.endswith('.safetensors'):
+        from safetensors.torch import load_file as safe_load
+        model.load_state_dict(safe_load(checkpoint))
+    else:
+        model.load_state_dict(torch.load(checkpoint))
+    model.seqlen = 2048
+    print('Done.')
+
+    return model
 
 
 def load_quantized(model_name):
@@ -20,6 +55,8 @@ def load_quantized(model_name):
             model_type = 'llama'
         elif model_name.lower().startswith(('opt', 'galactica')):
             model_type = 'opt'
+        elif model_name.lower().startswith(('gpt-j', 'pygmalion-6b')):
+            model_type = 'gptj'
         else:
             print("Can't determine model type from model name. Please specify it manually using --model_type "
                   "argument")
@@ -27,15 +64,12 @@ def load_quantized(model_name):
     else:
         model_type = shared.args.model_type.lower()
 
-    if model_type == 'llama':
-        if not shared.args.pre_layer:
-            load_quant = llama.load_quant
-        else:
-            load_quant = llama_inference_offload.load_quant
-    elif model_type == 'opt':
-        load_quant = opt.load_quant
+    if model_type == 'llama' and shared.args.pre_layer:
+        oad_quant = llama_inference_offload.load_quant
+    elif model_type in ('llama', 'opt', 'gptj'):
+        load_quant = _load_quant
     else:
-        print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported")
+        print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
         exit()
 
     # Now we are going to try to locate the quantized model file.
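
For orientation, the new _load_quant can also be exercised on its own. The sketch below is illustrative and not part of the patch: the model directory, checkpoint path, and group size are made-up placeholders, and only the function, its arguments, and the behaviour shown in the diff above are taken from the series.

    from modules.GPTQ_loader import _load_quant

    # Placeholder paths: a directory containing config.json plus a matching
    # pre-quantized checkpoint (.safetensors or .pt) produced by GPTQ-for-LLaMa.
    model_dir = "models/llama-7b-4bit"
    checkpoint = "models/llama-7b-4bit/llama-7b-4bit.safetensors"

    # _load_quant builds an fp16 skeleton from the config, swaps the layers
    # returned by find_layers (minus lm_head) for quantized ones via make_quant,
    # then loads the saved state dict over the skeleton.
    model = _load_quant(model_dir, checkpoint, wbits=4, groupsize=128)
    model = model.to("cuda")  # device placement is left to the caller

The noop overrides and transformers.modeling_utils._init_weights = False are what keep this cheap: the skeleton is never randomly initialized, since load_state_dict overwrites every weight anyway.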
From 1c075d8d219b5fd2bfeba1b4bad8f912b22a26da Mon Sep 17 00:00:00 2001
From: Maya Eary
Date: Tue, 28 Mar 2023 20:43:50 +0300
Subject: [PATCH 2/5] Fix typo

---
 modules/GPTQ_loader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
index 351d658d..1fdd23c0 100644
--- a/modules/GPTQ_loader.py
+++ b/modules/GPTQ_loader.py
@@ -65,7 +65,7 @@ def load_quantized(model_name):
         model_type = shared.args.model_type.lower()
 
     if model_type == 'llama' and shared.args.pre_layer:
-        oad_quant = llama_inference_offload.load_quant
+        load_quant = llama_inference_offload.load_quant
     elif model_type in ('llama', 'opt', 'gptj'):
         load_quant = _load_quant
     else:

From 41ec682834de3e7b79cd8e27aeec98690bc209ac Mon Sep 17 00:00:00 2001
From: Maya Eary
Date: Tue, 28 Mar 2023 22:45:38 +0300
Subject: [PATCH 3/5] Disable kernel threshold for gpt-j

---
 modules/GPTQ_loader.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
index 1fdd23c0..2a9039a3 100644
--- a/modules/GPTQ_loader.py
+++ b/modules/GPTQ_loader.py
@@ -14,7 +14,7 @@ import llama_inference_offload
 from quant import make_quant
 from modelutils import find_layers
 
-def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head']):
+def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
     config = AutoConfig.from_pretrained(model)
     def noop(*args, **kwargs):
         pass
@@ -32,7 +32,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     for name in exclude_layers:
         if name in layers:
             del layers[name]
-    make_quant(model, layers, wbits, groupsize, faster=faster_kernel)
+    make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold)
 
     del layers
 
@@ -109,7 +109,8 @@ def load_quantized(model_name):
     if shared.args.pre_layer:
         model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
     else:
-        model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize)
+        threshold = False if model_type == 'gptj' else 128
+        model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)
 
     # accelerate offload (doesn't work properly)
     if shared.args.gpu_memory:

From 0bec15ebcd1571155a54e87b371dc40534864f2e Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 28 Mar 2023 17:34:15 -0300
Subject: [PATCH 4/5] Reorder imports

---
 modules/GPTQ_loader.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
index 2a9039a3..c99a63f3 100644
--- a/modules/GPTQ_loader.py
+++ b/modules/GPTQ_loader.py
@@ -5,14 +5,15 @@ from pathlib import Path
 import accelerate
 import torch
 import transformers
-from transformers import AutoConfig, AutoModelForCausalLM 
+from transformers import AutoConfig, AutoModelForCausalLM
 
 import modules.shared as shared
 
 sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
 import llama_inference_offload
-from quant import make_quant
 from modelutils import find_layers
+from quant import make_quant
+
 
 def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
     config = AutoConfig.from_pretrained(model)
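
Before the final documentation patch, it is worth noting how the pieces above compose. The helper below is hypothetical, written only to illustrate the flow; the Pygmalion paths are placeholders, and only _load_quant, its kernel_switch_threshold argument, and the False-for-gptj / 128-otherwise selection come from the patches.

    from modules.GPTQ_loader import _load_quant


    def load_gptq_checkpoint(model_dir, checkpoint, wbits, groupsize, model_type):
        # Hypothetical wrapper mirroring load_quantized after patch 3: gpt-j
        # turns GPTQ-for-LLaMa's kernel switching off, other types keep 128.
        threshold = False if model_type == 'gptj' else 128
        return _load_quant(model_dir, checkpoint, wbits, groupsize,
                           kernel_switch_threshold=threshold)


    # Placeholder paths for an assumed 4-bit GPT-J (Pygmalion-6B) conversion.
    model = load_gptq_checkpoint("models/pygmalion-6b-4bit",
                                 "models/pygmalion-6b-4bit/pygmalion-6b-4bit.safetensors",
                                 wbits=4, groupsize=-1, model_type='gptj')

What the threshold means internally is left to GPTQ-for-LLaMa's make_quant; the patches only choose the value and forward it unchanged.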
From 010b259dde859b5703a6ea4cf2ea6c0aa4f25343 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 28 Mar 2023 17:46:00 -0300
Subject: [PATCH 5/5] Update documentation

---
 README.md              | 2 +-
 modules/GPTQ_loader.py | 1 -
 modules/shared.py      | 2 +-
 3 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index f6b1d4f5..ba386852 100644
--- a/README.md
+++ b/README.md
@@ -177,7 +177,7 @@ Optionally, you can use the following command-line flags:
 | `--cpu` | Use the CPU to generate text.|
 | `--load-in-8bit` | Load the model with 8-bit precision.|
 | `--wbits WBITS` | GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. |
-| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently only LLaMA and OPT are supported. |
+| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. |
 | `--groupsize GROUPSIZE` | GPTQ: Group size. |
 | `--pre_layer PRE_LAYER` | GPTQ: The number of layers to preload. |
 | `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
index c99a63f3..7926d0ab 100644
--- a/modules/GPTQ_loader.py
+++ b/modules/GPTQ_loader.py
@@ -48,7 +48,6 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
 
     return model
 
-
 def load_quantized(model_name):
     if not shared.args.model_type:
         # Try to determine model type from model name
diff --git a/modules/shared.py b/modules/shared.py
index ac9d750c..5d1b42d4 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -84,7 +84,7 @@ parser.add_argument('--gptq-bits', type=int, default=0, help='DEPRECATED: use --
 parser.add_argument('--gptq-model-type', type=str, help='DEPRECATED: use --model_type instead.')
 parser.add_argument('--gptq-pre-layer', type=int, default=0, help='DEPRECATED: use --pre_layer instead.')
 parser.add_argument('--wbits', type=int, default=0, help='GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.')
-parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently only LLaMA and OPT are supported.')
+parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported.')
 parser.add_argument('--groupsize', type=int, default=-1, help='GPTQ: Group size.')
 parser.add_argument('--pre_layer', type=int, default=0, help='GPTQ: The number of layers to preload.')
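
With the documentation updated, a pre-quantized GPT-J checkpoint would be loaded from the command line roughly as follows. This is an assumed invocation, not part of the patch: the model folder name is a placeholder, and the server.py entry point and --model flag are taken to be the repository's usual ones; only --wbits, --groupsize, and --model_type are defined in the diffs above.

    python server.py --model pygmalion-6b-4bit --model_type gptj --wbits 4 --groupsize 128

Since the first patch also teaches load_quantized to recognize names starting with gpt-j or pygmalion-6b, --model_type gptj can be omitted for folders named that way.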