diff --git a/modules/training.py b/modules/training.py
index f6033d60..9dfe0a23 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -124,7 +124,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
         print(f"Warning: LoRA training has only currently been validated for LLaMA models. (Found model type: {model_type})")
         time.sleep(5)
 
-    if shared.args.wbits > 0 or shared.args.gptq_bits > 0:
+    if shared.args.wbits > 0:
         yield "LoRA training does not yet support 4bit. Please use `--load-in-8bit` for now."
         return
 