Handle training exception for unsupported models

This commit is contained in:
oobabooga 2023-03-29 11:55:34 -03:00 committed by GitHub
parent a6d0373063
commit 58349f44a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,6 +2,7 @@ import json
import sys import sys
import threading import threading
import time import time
import traceback
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
@ -184,7 +185,13 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
bias="none", bias="none",
task_type="CAUSAL_LM" task_type="CAUSAL_LM"
) )
lora_model = get_peft_model(shared.model, config)
try:
lora_model = get_peft_model(shared.model, config)
except:
yield traceback.format_exc()
return
trainer = transformers.Trainer( trainer = transformers.Trainer(
model=lora_model, model=lora_model,
train_dataset=train_data, train_dataset=train_data,