Generalize GPTQ_loader, support any model (#615 from mayaeary/feature/gpt-j-4bit-v2)

This includes Pygmalion 4bit
This commit is contained in:
oobabooga 2023-03-28 18:00:09 -03:00 committed by GitHub
commit b2f356a9ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 48 additions and 13 deletions

View File

@ -177,7 +177,7 @@ Optionally, you can use the following command-line flags:
| `--cpu` | Use the CPU to generate text.| | `--cpu` | Use the CPU to generate text.|
| `--load-in-8bit` | Load the model with 8-bit precision.| | `--load-in-8bit` | Load the model with 8-bit precision.|
| `--wbits WBITS` | GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. | | `--wbits WBITS` | GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. |
| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently only LLaMA and OPT are supported. | | `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. |
| `--groupsize GROUPSIZE` | GPTQ: Group size. | | `--groupsize GROUPSIZE` | GPTQ: Group size. |
| `--pre_layer PRE_LAYER` | GPTQ: The number of layers to preload. | | `--pre_layer PRE_LAYER` | GPTQ: The number of layers to preload. |
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. | | `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |

View File

@ -4,15 +4,50 @@ from pathlib import Path
import accelerate import accelerate
import torch import torch
import transformers
from transformers import AutoConfig, AutoModelForCausalLM
import modules.shared as shared import modules.shared as shared
sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa"))) sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
import llama
import llama_inference_offload import llama_inference_offload
import opt from modelutils import find_layers
from quant import make_quant
def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
config = AutoConfig.from_pretrained(model)
def noop(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = noop
torch.nn.init.uniform_ = noop
torch.nn.init.normal_ = noop
torch.set_default_dtype(torch.half)
transformers.modeling_utils._init_weights = False
torch.set_default_dtype(torch.half)
model = AutoModelForCausalLM.from_config(config)
torch.set_default_dtype(torch.float)
model = model.eval()
layers = find_layers(model)
for name in exclude_layers:
if name in layers:
del layers[name]
make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold)
del layers
print('Loading model ...')
if checkpoint.endswith('.safetensors'):
from safetensors.torch import load_file as safe_load
model.load_state_dict(safe_load(checkpoint))
else:
model.load_state_dict(torch.load(checkpoint))
model.seqlen = 2048
print('Done.')
return model
def load_quantized(model_name): def load_quantized(model_name):
if not shared.args.model_type: if not shared.args.model_type:
# Try to determine model type from model name # Try to determine model type from model name
@ -20,6 +55,8 @@ def load_quantized(model_name):
model_type = 'llama' model_type = 'llama'
elif model_name.lower().startswith(('opt', 'galactica')): elif model_name.lower().startswith(('opt', 'galactica')):
model_type = 'opt' model_type = 'opt'
elif model_name.lower().startswith(('gpt-j', 'pygmalion-6b')):
model_type = 'gptj'
else: else:
print("Can't determine model type from model name. Please specify it manually using --model_type " print("Can't determine model type from model name. Please specify it manually using --model_type "
"argument") "argument")
@ -27,15 +64,12 @@ def load_quantized(model_name):
else: else:
model_type = shared.args.model_type.lower() model_type = shared.args.model_type.lower()
if model_type == 'llama': if model_type == 'llama' and shared.args.pre_layer:
if not shared.args.pre_layer: load_quant = llama_inference_offload.load_quant
load_quant = llama.load_quant elif model_type in ('llama', 'opt', 'gptj'):
else: load_quant = _load_quant
load_quant = llama_inference_offload.load_quant
elif model_type == 'opt':
load_quant = opt.load_quant
else: else:
print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported") print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
exit() exit()
# Now we are going to try to locate the quantized model file. # Now we are going to try to locate the quantized model file.
@ -75,7 +109,8 @@ def load_quantized(model_name):
if shared.args.pre_layer: if shared.args.pre_layer:
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer) model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
else: else:
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize) threshold = False if model_type == 'gptj' else 128
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)
# accelerate offload (doesn't work properly) # accelerate offload (doesn't work properly)
if shared.args.gpu_memory: if shared.args.gpu_memory:

View File

@ -84,7 +84,7 @@ parser.add_argument('--gptq-bits', type=int, default=0, help='DEPRECATED: use --
parser.add_argument('--gptq-model-type', type=str, help='DEPRECATED: use --model_type instead.') parser.add_argument('--gptq-model-type', type=str, help='DEPRECATED: use --model_type instead.')
parser.add_argument('--gptq-pre-layer', type=int, default=0, help='DEPRECATED: use --pre_layer instead.') parser.add_argument('--gptq-pre-layer', type=int, default=0, help='DEPRECATED: use --pre_layer instead.')
parser.add_argument('--wbits', type=int, default=0, help='GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.') parser.add_argument('--wbits', type=int, default=0, help='GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.')
parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently only LLaMA and OPT are supported.') parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported.')
parser.add_argument('--groupsize', type=int, default=-1, help='GPTQ: Group size.') parser.add_argument('--groupsize', type=int, default=-1, help='GPTQ: Group size.')
parser.add_argument('--pre_layer', type=int, default=0, help='GPTQ: The number of layers to preload.') parser.add_argument('--pre_layer', type=int, default=0, help='GPTQ: The number of layers to preload.')