Fix safetensors kwarg usage in AutoAWQ

This commit is contained in:
oobabooga 2023-10-10 19:03:09 -07:00
parent 39f16ff83d
commit f63361568c

View File

@ -277,24 +277,24 @@ def ctransformers_loader(model_name):
model, tokenizer = ctrans.from_pretrained(model_file) model, tokenizer = ctrans.from_pretrained(model_file)
return model, tokenizer return model, tokenizer
def AutoAWQ_loader(model_name): def AutoAWQ_loader(model_name):
from awq import AutoAWQForCausalLM from awq import AutoAWQForCausalLM
model_dir = Path(f'{shared.args.model_dir}/{model_name}') model_dir = Path(f'{shared.args.model_dir}/{model_name}')
if shared.args.deepspeed: model = AutoAWQForCausalLM.from_quantized(
logger.warn("AutoAWQ is incompatible with deepspeed") quant_path=model_dir,
max_new_tokens=shared.args.max_seq_len,
trust_remote_code=shared.args.trust_remote_code,
fuse_layers=not shared.args.no_inject_fused_attention,
max_memory=get_max_memory_dict(),
batch_size=shared.args.n_batch,
safetensors=any(model_dir.glob('*.safetensors')),
)
model = AutoAWQForCausalLM.from_quantized( return model
quant_path=model_dir,
max_new_tokens=shared.args.max_seq_len,
trust_remote_code=shared.args.trust_remote_code,
fuse_layers=not shared.args.no_inject_fused_attention,
max_memory=get_max_memory_dict(),
batch_size=shared.args.n_batch,
safetensors=not shared.args.trust_remote_code)
return model
def GPTQ_loader(model_name): def GPTQ_loader(model_name):