diff --git a/modules/training.py b/modules/training.py
index 913866d9..62ba181c 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -2,6 +2,7 @@
 import json
 import sys
 import threading
 import time
+import traceback
 from pathlib import Path
 import gradio as gr
@@ -184,7 +185,13 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
         bias="none",
         task_type="CAUSAL_LM"
     )
-    lora_model = get_peft_model(shared.model, config)
+
+    try:
+        lora_model = get_peft_model(shared.model, config)
+    except Exception:
+        yield traceback.format_exc()
+        return
+
     trainer = transformers.Trainer(
         model=lora_model,
         train_dataset=train_data,