diff --git a/modules/training.py b/modules/training.py index aaca44c5..f6033d60 100644 --- a/modules/training.py +++ b/modules/training.py @@ -185,10 +185,11 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int def generate_prompt(data_point: dict[str, str]): for options, data in format_data.items(): - if set(options.split(',')) == set(x[0] for x in data_point.items() if len(x[1].strip()) > 0): + if set(options.split(',')) == set(x[0] for x in data_point.items() if (x[1] != None and len(x[1].strip()) > 0)): for key, val in data_point.items(): - data = data.replace(f'%{key}%', val) - return data + if val != None: + data = data.replace(f'%{key}%', val) + return data raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"') def generate_and_tokenize_prompt(data_point):