Disable kernel threshold for gpt-j

This commit is contained in:
Maya Eary 2023-03-28 22:45:38 +03:00
parent 1ac003d41c
commit 41ec682834

View File

@@ -14,7 +14,7 @@ import llama_inference_offload
 from quant import make_quant
 from modelutils import find_layers
-def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head']):
+def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
     config = AutoConfig.from_pretrained(model)
     def noop(*args, **kwargs):
         pass
@@ -32,7 +32,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     for name in exclude_layers:
         if name in layers:
             del layers[name]
-    make_quant(model, layers, wbits, groupsize, faster=faster_kernel)
+    make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold)
     del layers
@@ -109,7 +109,8 @@ def load_quantized(model_name):
     if shared.args.pre_layer:
         model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
     else:
-        model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize)
+        threshold = False if model_type == 'gptj' else 128
+        model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)
     # accelerate offload (doesn't work properly)
     if shared.args.gpu_memory: