Update to support GPTQ triton commit c90adef (#1229)

Author: sgsdxzy, 2023-04-17 12:11:18 +08:00 (committed by GitHub)
parent 209fcd21d5
commit b57ffc2ec9
3 changed files with 38 additions and 23 deletions

README.md

@@ -236,7 +236,9 @@ Optionally, you can use the following command-line flags:
 | `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. |
 | `--groupsize GROUPSIZE` | GPTQ: Group size. |
 | `--pre_layer PRE_LAYER` | GPTQ: The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models. |
+| `--no-quant_attn` | GPTQ: Disable quant attention for triton. If you encounter incoherent results try disabling this. |
 | `--no-warmup_autotune` | GPTQ: Disable warmup autotune for triton. |
+| `--no-fused_mlp` | GPTQ: Disable fused mlp for triton. If you encounter "Unexpected mma -> mma layout conversion" try disabling this. |
 | `--monkey-patch` | GPTQ: Apply the monkey patch for using LoRAs with quantized models. |
 
 #### FlexGen

modules/GPTQ_loader.py

@@ -13,12 +13,18 @@ import modules.shared as shared
 sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
 import llama_inference_offload
 from modelutils import find_layers
-from quant import make_quant
+
+try:
+    from quant import make_quant
+    is_triton = False
+except ImportError:
+    import quant
+    is_triton = True
 
 
 # This function is a replacement for the load_quant function in the
 # GPTQ-for_LLaMa repository. It supports more models and branches.
-def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
+def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128, eval=True):
 
     def noop(*args, **kwargs):
         pass
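
The new import block is doing feature detection: the loader treats a missing `quant.make_quant` as meaning the triton branch of GPTQ-for-LLaMa is installed, and in that case falls back to the `quant` module namespace (`quant.make_quant_linear`, `quant.make_quant_attn`, and so on, used further down). A standalone sketch of that detection, assuming a GPTQ-for-LLaMa checkout under `repositories/` as in the code above:

```python
# Minimal sketch of the import-based branch detection; assumes a GPTQ-for-LLaMa
# checkout (CUDA branch or triton branch around c90adef) under repositories/.
import sys
from pathlib import Path

sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))

try:
    # The CUDA branch still provides quant.make_quant.
    from quant import make_quant  # noqa: F401
    is_triton = False
except ImportError:
    # The triton branch does not, so fall back to the quant module namespace.
    import quant  # noqa: F401
    is_triton = True

print("triton branch detected:", is_triton)
```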
@@ -33,27 +39,31 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
         torch.set_default_dtype(torch.half)
         model = AutoModelForCausalLM.from_config(config)
         torch.set_default_dtype(torch.float)
-        model = model.eval()
+        if eval:
+            model = model.eval()
         layers = find_layers(model)
         for name in exclude_layers:
             if name in layers:
                 del layers[name]
 
-        gptq_args = inspect.getfullargspec(make_quant).args
-
-        make_quant_kwargs = {
-            'module': model,
-            'names': layers,
-            'bits': wbits,
-        }
-        if 'groupsize' in gptq_args:
-            make_quant_kwargs['groupsize'] = groupsize
-        if 'faster' in gptq_args:
-            make_quant_kwargs['faster'] = faster_kernel
-        if 'kernel_switch_threshold' in gptq_args:
-            make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
-
-        make_quant(**make_quant_kwargs)
+        if not is_triton:
+            gptq_args = inspect.getfullargspec(make_quant).args
+
+            make_quant_kwargs = {
+                'module': model,
+                'names': layers,
+                'bits': wbits,
+            }
+            if 'groupsize' in gptq_args:
+                make_quant_kwargs['groupsize'] = groupsize
+            if 'faster' in gptq_args:
+                make_quant_kwargs['faster'] = faster_kernel
+            if 'kernel_switch_threshold' in gptq_args:
+                make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
+
+            make_quant(**make_quant_kwargs)
+        else:
+            quant.make_quant_linear(model, layers, wbits, groupsize)
 
         del layers
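
The non-triton path keeps the existing compatibility shim: it inspects `make_quant`'s signature and only passes the keyword arguments the installed checkout actually accepts. A small self-contained sketch of that pattern, using a hypothetical stand-in function instead of the real `make_quant`:

```python
import inspect

# Hypothetical stand-in for make_quant; older checkouts accept fewer keyword
# arguments, which is exactly what the signature inspection guards against.
def make_quant_stub(module, names, bits, groupsize=-1):
    print(f"make_quant called with bits={bits}, groupsize={groupsize}")

optional_kwargs = {'groupsize': 128, 'faster': False, 'kernel_switch_threshold': 128}

accepted = inspect.getfullargspec(make_quant_stub).args
kwargs = {'module': None, 'names': {}, 'bits': 4}
for name, value in optional_kwargs.items():
    if name in accepted:  # silently drop kwargs the target does not know about
        kwargs[name] = value

make_quant_stub(**kwargs)  # -> make_quant called with bits=4, groupsize=128
```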
@@ -64,15 +74,16 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     else:
         model.load_state_dict(torch.load(checkpoint), strict=False)
 
-    try:
-        from quant import autotune_warmup, make_quant_attn
-
-        # triton branch
-        make_quant_attn(model)
-        if not shared.args.no_warmup_autotune:
-            autotune_warmup(model)
-    except ImportError:  # not triton branch
-        pass
+    if is_triton:
+        if not shared.args.no_quant_attn:
+            quant.make_quant_attn(model)
+        if eval and not shared.args.no_fused_mlp:
+            quant.make_fused_mlp(model)
+        if not shared.args.no_warmup_autotune:
+            quant.autotune_warmup_linear(model, transpose=not eval)
+            if eval and not shared.args.no_fused_mlp:
+                quant.autotune_warmup_fused(model)
 
     model.seqlen = 2048
     print('Done.')
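
After the weights are loaded, the triton-only post-processing is now gated on the new flags: quantized attention unless `--no-quant_attn`, fused MLP only in eval mode and unless `--no-fused_mlp`, and warmup autotuning unless `--no-warmup_autotune`, with the fused-MLP warmup nested inside the warmup step. The control flow can be exercised in isolation with stubbed kernels and a fake `shared.args` namespace (names follow the flags added in this commit):

```python
from argparse import Namespace

# Stand-ins for quant.make_quant_attn / make_fused_mlp / autotune_warmup_* so the
# gating can be run without triton or a real model.
def make_quant_attn(model): print("fusing quantized attention")
def make_fused_mlp(model): print("fusing quantized MLP")
def autotune_warmup_linear(model, transpose): print(f"warmup linear kernels (transpose={transpose})")
def autotune_warmup_fused(model): print("warmup fused MLP kernels")

args = Namespace(no_quant_attn=False, no_fused_mlp=False, no_warmup_autotune=False)
eval_mode = True   # mirrors the new eval=True parameter of _load_quant
model = object()

if not args.no_quant_attn:
    make_quant_attn(model)
if eval_mode and not args.no_fused_mlp:
    make_fused_mlp(model)
if not args.no_warmup_autotune:
    autotune_warmup_linear(model, transpose=not eval_mode)
    if eval_mode and not args.no_fused_mlp:
        autotune_warmup_fused(model)
```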

modules/shared.py

@@ -123,7 +123,9 @@ parser.add_argument('--wbits', type=int, default=0, help='GPTQ: Load a pre-quant
 parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported.')
 parser.add_argument('--groupsize', type=int, default=-1, help='GPTQ: Group size.')
 parser.add_argument('--pre_layer', type=int, default=0, help='GPTQ: The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models.')
+parser.add_argument('--no-quant_attn', action='store_true', help='GPTQ: Disable quant attention for triton. If you encounter incoherent results try disabling this.')
 parser.add_argument('--no-warmup_autotune', action='store_true', help='GPTQ: Disable warmup autotune for triton.')
+parser.add_argument('--no-fused_mlp', action='store_true', help='GPTQ: Disable fused mlp for triton. If you encounter "Unexpected mma -> mma layout conversion" try disabling this.')
 parser.add_argument('--monkey-patch', action='store_true', help='GPTQ: Apply the monkey patch for using LoRAs with quantized models.')
 
 # FlexGen
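
One detail that ties this file to GPTQ_loader.py: argparse converts the dash inside each option string to an underscore when it derives the attribute name, so `--no-quant_attn` and `--no-fused_mlp` surface as `shared.args.no_quant_attn` and `shared.args.no_fused_mlp`. A quick standalone check:

```python
import argparse

# Same option strings as above; the dest names get the dash replaced by an underscore.
parser = argparse.ArgumentParser()
parser.add_argument('--no-quant_attn', action='store_true')
parser.add_argument('--no-fused_mlp', action='store_true')

args = parser.parse_args(['--no-fused_mlp'])
print(args.no_quant_attn, args.no_fused_mlp)  # False True
```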