mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 09:19:23 +01:00
Generalize GPTQ_loader, support any model (#615 from mayaeary/feature/gpt-j-4bit-v2)
This includes Pygmalion 4bit
This commit is contained in:
commit
b2f356a9ae
@ -177,7 +177,7 @@ Optionally, you can use the following command-line flags:
|
|||||||
| `--cpu` | Use the CPU to generate text.|
|
| `--cpu` | Use the CPU to generate text.|
|
||||||
| `--load-in-8bit` | Load the model with 8-bit precision.|
|
| `--load-in-8bit` | Load the model with 8-bit precision.|
|
||||||
| `--wbits WBITS` | GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. |
|
| `--wbits WBITS` | GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. |
|
||||||
| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently only LLaMA and OPT are supported. |
|
| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. |
|
||||||
| `--groupsize GROUPSIZE` | GPTQ: Group size. |
|
| `--groupsize GROUPSIZE` | GPTQ: Group size. |
|
||||||
| `--pre_layer PRE_LAYER` | GPTQ: The number of layers to preload. |
|
| `--pre_layer PRE_LAYER` | GPTQ: The number of layers to preload. |
|
||||||
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
|
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
|
||||||
|
@ -4,15 +4,50 @@ from pathlib import Path
|
|||||||
|
|
||||||
import accelerate
|
import accelerate
|
||||||
import torch
|
import torch
|
||||||
|
import transformers
|
||||||
|
from transformers import AutoConfig, AutoModelForCausalLM
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
|
||||||
sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
|
sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
|
||||||
import llama
|
|
||||||
import llama_inference_offload
|
import llama_inference_offload
|
||||||
import opt
|
from modelutils import find_layers
|
||||||
|
from quant import make_quant
|
||||||
|
|
||||||
|
|
||||||
|
def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
|
||||||
|
config = AutoConfig.from_pretrained(model)
|
||||||
|
def noop(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
torch.nn.init.kaiming_uniform_ = noop
|
||||||
|
torch.nn.init.uniform_ = noop
|
||||||
|
torch.nn.init.normal_ = noop
|
||||||
|
|
||||||
|
torch.set_default_dtype(torch.half)
|
||||||
|
transformers.modeling_utils._init_weights = False
|
||||||
|
torch.set_default_dtype(torch.half)
|
||||||
|
model = AutoModelForCausalLM.from_config(config)
|
||||||
|
torch.set_default_dtype(torch.float)
|
||||||
|
model = model.eval()
|
||||||
|
layers = find_layers(model)
|
||||||
|
for name in exclude_layers:
|
||||||
|
if name in layers:
|
||||||
|
del layers[name]
|
||||||
|
make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold)
|
||||||
|
|
||||||
|
del layers
|
||||||
|
|
||||||
|
print('Loading model ...')
|
||||||
|
if checkpoint.endswith('.safetensors'):
|
||||||
|
from safetensors.torch import load_file as safe_load
|
||||||
|
model.load_state_dict(safe_load(checkpoint))
|
||||||
|
else:
|
||||||
|
model.load_state_dict(torch.load(checkpoint))
|
||||||
|
model.seqlen = 2048
|
||||||
|
print('Done.')
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
def load_quantized(model_name):
|
def load_quantized(model_name):
|
||||||
if not shared.args.model_type:
|
if not shared.args.model_type:
|
||||||
# Try to determine model type from model name
|
# Try to determine model type from model name
|
||||||
@ -20,6 +55,8 @@ def load_quantized(model_name):
|
|||||||
model_type = 'llama'
|
model_type = 'llama'
|
||||||
elif model_name.lower().startswith(('opt', 'galactica')):
|
elif model_name.lower().startswith(('opt', 'galactica')):
|
||||||
model_type = 'opt'
|
model_type = 'opt'
|
||||||
|
elif model_name.lower().startswith(('gpt-j', 'pygmalion-6b')):
|
||||||
|
model_type = 'gptj'
|
||||||
else:
|
else:
|
||||||
print("Can't determine model type from model name. Please specify it manually using --model_type "
|
print("Can't determine model type from model name. Please specify it manually using --model_type "
|
||||||
"argument")
|
"argument")
|
||||||
@ -27,15 +64,12 @@ def load_quantized(model_name):
|
|||||||
else:
|
else:
|
||||||
model_type = shared.args.model_type.lower()
|
model_type = shared.args.model_type.lower()
|
||||||
|
|
||||||
if model_type == 'llama':
|
if model_type == 'llama' and shared.args.pre_layer:
|
||||||
if not shared.args.pre_layer:
|
|
||||||
load_quant = llama.load_quant
|
|
||||||
else:
|
|
||||||
load_quant = llama_inference_offload.load_quant
|
load_quant = llama_inference_offload.load_quant
|
||||||
elif model_type == 'opt':
|
elif model_type in ('llama', 'opt', 'gptj'):
|
||||||
load_quant = opt.load_quant
|
load_quant = _load_quant
|
||||||
else:
|
else:
|
||||||
print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported")
|
print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
# Now we are going to try to locate the quantized model file.
|
# Now we are going to try to locate the quantized model file.
|
||||||
@ -75,7 +109,8 @@ def load_quantized(model_name):
|
|||||||
if shared.args.pre_layer:
|
if shared.args.pre_layer:
|
||||||
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
|
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
|
||||||
else:
|
else:
|
||||||
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize)
|
threshold = False if model_type == 'gptj' else 128
|
||||||
|
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)
|
||||||
|
|
||||||
# accelerate offload (doesn't work properly)
|
# accelerate offload (doesn't work properly)
|
||||||
if shared.args.gpu_memory:
|
if shared.args.gpu_memory:
|
||||||
|
@ -84,7 +84,7 @@ parser.add_argument('--gptq-bits', type=int, default=0, help='DEPRECATED: use --
|
|||||||
parser.add_argument('--gptq-model-type', type=str, help='DEPRECATED: use --model_type instead.')
|
parser.add_argument('--gptq-model-type', type=str, help='DEPRECATED: use --model_type instead.')
|
||||||
parser.add_argument('--gptq-pre-layer', type=int, default=0, help='DEPRECATED: use --pre_layer instead.')
|
parser.add_argument('--gptq-pre-layer', type=int, default=0, help='DEPRECATED: use --pre_layer instead.')
|
||||||
parser.add_argument('--wbits', type=int, default=0, help='GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.')
|
parser.add_argument('--wbits', type=int, default=0, help='GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.')
|
||||||
parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently only LLaMA and OPT are supported.')
|
parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported.')
|
||||||
parser.add_argument('--groupsize', type=int, default=-1, help='GPTQ: Group size.')
|
parser.add_argument('--groupsize', type=int, default=-1, help='GPTQ: Group size.')
|
||||||
parser.add_argument('--pre_layer', type=int, default=0, help='GPTQ: The number of layers to preload.')
|
parser.add_argument('--pre_layer', type=int, default=0, help='GPTQ: The number of layers to preload.')
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user