diff --git a/modules/models.py b/modules/models.py
index 4303b937..5bd9db74 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -277,24 +277,24 @@ def ctransformers_loader(model_name):
     model, tokenizer = ctrans.from_pretrained(model_file)
     return model, tokenizer
 
+
 def AutoAWQ_loader(model_name):
-    from awq import AutoAWQForCausalLM
+    from awq import AutoAWQForCausalLM
 
-    model_dir = Path(f'{shared.args.model_dir}/{model_name}')
+    model_dir = Path(f'{shared.args.model_dir}/{model_name}')
 
-    if shared.args.deepspeed:
-        logger.warn("AutoAWQ is incompatible with deepspeed")
+    model = AutoAWQForCausalLM.from_quantized(
+        quant_path=model_dir,
+        max_new_tokens=shared.args.max_seq_len,
+        trust_remote_code=shared.args.trust_remote_code,
+        fuse_layers=not shared.args.no_inject_fused_attention,
+        max_memory=get_max_memory_dict(),
+        batch_size=shared.args.n_batch,
+        safetensors=any(model_dir.glob('*.safetensors')),
+    )
 
-    model = AutoAWQForCausalLM.from_quantized(
-        quant_path=model_dir,
-        max_new_tokens=shared.args.max_seq_len,
-        trust_remote_code=shared.args.trust_remote_code,
-        fuse_layers=not shared.args.no_inject_fused_attention,
-        max_memory=get_max_memory_dict(),
-        batch_size=shared.args.n_batch,
-        safetensors=not shared.args.trust_remote_code)
+    return model
 
-    return model
 
 def GPTQ_loader(model_name):