mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-29 21:50:16 +01:00
Update to support GPTQ triton commit c90adef (#1229)
This commit is contained in:
parent
209fcd21d5
commit
b57ffc2ec9
@ -236,7 +236,9 @@ Optionally, you can use the following command-line flags:
|
|||||||
| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. |
|
| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. |
|
||||||
| `--groupsize GROUPSIZE` | GPTQ: Group size. |
|
| `--groupsize GROUPSIZE` | GPTQ: Group size. |
|
||||||
| `--pre_layer PRE_LAYER` | GPTQ: The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models. |
|
| `--pre_layer PRE_LAYER` | GPTQ: The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models. |
|
||||||
|
| `--no-quant_attn` | GPTQ: Disable quant attention for triton. If you encounter incoherent results try disabling this. |
|
||||||
| `--no-warmup_autotune` | GPTQ: Disable warmup autotune for triton. |
|
| `--no-warmup_autotune` | GPTQ: Disable warmup autotune for triton. |
|
||||||
|
| `--no-fused_mlp` | GPTQ: Disable fused mlp for triton. If you encounter "Unexpected mma -> mma layout conversion" try disabling this. |
|
||||||
| `--monkey-patch` | GPTQ: Apply the monkey patch for using LoRAs with quantized models. |
|
| `--monkey-patch` | GPTQ: Apply the monkey patch for using LoRAs with quantized models. |
|
||||||
|
|
||||||
#### FlexGen
|
#### FlexGen
|
||||||
|
@ -13,12 +13,18 @@ import modules.shared as shared
|
|||||||
sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
|
sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
|
||||||
import llama_inference_offload
|
import llama_inference_offload
|
||||||
from modelutils import find_layers
|
from modelutils import find_layers
|
||||||
from quant import make_quant
|
|
||||||
|
try:
|
||||||
|
from quant import make_quant
|
||||||
|
is_triton = False
|
||||||
|
except ImportError:
|
||||||
|
import quant
|
||||||
|
is_triton = True
|
||||||
|
|
||||||
|
|
||||||
# This function is a replacement for the load_quant function in the
|
# This function is a replacement for the load_quant function in the
|
||||||
# GPTQ-for_LLaMa repository. It supports more models and branches.
|
# GPTQ-for_LLaMa repository. It supports more models and branches.
|
||||||
def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
|
def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128, eval=True):
|
||||||
|
|
||||||
def noop(*args, **kwargs):
|
def noop(*args, **kwargs):
|
||||||
pass
|
pass
|
||||||
@ -33,27 +39,31 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
|
|||||||
torch.set_default_dtype(torch.half)
|
torch.set_default_dtype(torch.half)
|
||||||
model = AutoModelForCausalLM.from_config(config)
|
model = AutoModelForCausalLM.from_config(config)
|
||||||
torch.set_default_dtype(torch.float)
|
torch.set_default_dtype(torch.float)
|
||||||
model = model.eval()
|
if eval:
|
||||||
|
model = model.eval()
|
||||||
layers = find_layers(model)
|
layers = find_layers(model)
|
||||||
for name in exclude_layers:
|
for name in exclude_layers:
|
||||||
if name in layers:
|
if name in layers:
|
||||||
del layers[name]
|
del layers[name]
|
||||||
|
|
||||||
gptq_args = inspect.getfullargspec(make_quant).args
|
if not is_triton:
|
||||||
|
gptq_args = inspect.getfullargspec(make_quant).args
|
||||||
|
|
||||||
make_quant_kwargs = {
|
make_quant_kwargs = {
|
||||||
'module': model,
|
'module': model,
|
||||||
'names': layers,
|
'names': layers,
|
||||||
'bits': wbits,
|
'bits': wbits,
|
||||||
}
|
}
|
||||||
if 'groupsize' in gptq_args:
|
if 'groupsize' in gptq_args:
|
||||||
make_quant_kwargs['groupsize'] = groupsize
|
make_quant_kwargs['groupsize'] = groupsize
|
||||||
if 'faster' in gptq_args:
|
if 'faster' in gptq_args:
|
||||||
make_quant_kwargs['faster'] = faster_kernel
|
make_quant_kwargs['faster'] = faster_kernel
|
||||||
if 'kernel_switch_threshold' in gptq_args:
|
if 'kernel_switch_threshold' in gptq_args:
|
||||||
make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
|
make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
|
||||||
|
|
||||||
make_quant(**make_quant_kwargs)
|
make_quant(**make_quant_kwargs)
|
||||||
|
else:
|
||||||
|
quant.make_quant_linear(model, layers, wbits, groupsize)
|
||||||
|
|
||||||
del layers
|
del layers
|
||||||
|
|
||||||
@ -64,15 +74,16 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
|
|||||||
else:
|
else:
|
||||||
model.load_state_dict(torch.load(checkpoint), strict=False)
|
model.load_state_dict(torch.load(checkpoint), strict=False)
|
||||||
|
|
||||||
try:
|
if is_triton:
|
||||||
from quant import autotune_warmup, make_quant_attn
|
if not shared.args.no_quant_attn:
|
||||||
|
quant.make_quant_attn(model)
|
||||||
|
if eval and not shared.args.no_fused_mlp:
|
||||||
|
quant.make_fused_mlp(model)
|
||||||
|
|
||||||
# triton branch
|
|
||||||
make_quant_attn(model)
|
|
||||||
if not shared.args.no_warmup_autotune:
|
if not shared.args.no_warmup_autotune:
|
||||||
autotune_warmup(model)
|
quant.autotune_warmup_linear(model, transpose=not eval)
|
||||||
except ImportError: # not triton branch
|
if eval and not shared.args.no_fused_mlp:
|
||||||
pass
|
quant.autotune_warmup_fused(model)
|
||||||
|
|
||||||
model.seqlen = 2048
|
model.seqlen = 2048
|
||||||
print('Done.')
|
print('Done.')
|
||||||
|
@ -123,7 +123,9 @@ parser.add_argument('--wbits', type=int, default=0, help='GPTQ: Load a pre-quant
|
|||||||
parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported.')
|
parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported.')
|
||||||
parser.add_argument('--groupsize', type=int, default=-1, help='GPTQ: Group size.')
|
parser.add_argument('--groupsize', type=int, default=-1, help='GPTQ: Group size.')
|
||||||
parser.add_argument('--pre_layer', type=int, default=0, help='GPTQ: The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models.')
|
parser.add_argument('--pre_layer', type=int, default=0, help='GPTQ: The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models.')
|
||||||
|
parser.add_argument('--no-quant_attn', action='store_true', help='GPTQ: Disable quant attention for triton. If you encounter incoherent results try disabling this.')
|
||||||
parser.add_argument('--no-warmup_autotune', action='store_true', help='GPTQ: Disable warmup autotune for triton.')
|
parser.add_argument('--no-warmup_autotune', action='store_true', help='GPTQ: Disable warmup autotune for triton.')
|
||||||
|
parser.add_argument('--no-fused_mlp', action='store_true', help='GPTQ: Disable fused mlp for triton. If you encounter "Unexpected mma -> mma layout conversion" try disabling this.')
|
||||||
parser.add_argument('--monkey-patch', action='store_true', help='GPTQ: Apply the monkey patch for using LoRAs with quantized models.')
|
parser.add_argument('--monkey-patch', action='store_true', help='GPTQ: Apply the monkey patch for using LoRAs with quantized models.')
|
||||||
|
|
||||||
# FlexGen
|
# FlexGen
|
||||||
|
Loading…
Reference in New Issue
Block a user