diff --git a/modules/training.py b/modules/training.py index 7bcecb38..913866d9 100644 --- a/modules/training.py +++ b/modules/training.py @@ -119,7 +119,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int } # == Prep the dataset, format, etc == - if raw_text_file is not None: + if raw_text_file not in ['None', '']: print("Loading raw text file dataset...") with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r') as file: raw_text = file.read() @@ -136,16 +136,17 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int del text_chunks else: - with open(clean_path('training/formats', f'{format}.json'), 'r') as formatFile: - format_data: dict[str, str] = json.load(formatFile) - - if dataset is None: + if dataset in ['None', '']: yield "**Missing dataset choice input, cannot continue.**" return - if format is None: + + if format in ['None', '']: yield "**Missing format choice input, cannot continue.**" return + with open(clean_path('training/formats', f'{format}.json'), 'r') as formatFile: + format_data: dict[str, str] = json.load(formatFile) + def generate_prompt(data_point: dict[str, str]): for options, data in format_data.items(): if set(options.split(',')) == set(x[0] for x in data_point.items() if len(x[1].strip()) > 0):