diff --git a/modules/training.py b/modules/training.py index 442b92b3..2f9a7768 100644 --- a/modules/training.py +++ b/modules/training.py @@ -32,6 +32,7 @@ from modules.evaluate import ( save_past_evaluations ) from modules.logging_colors import logger +from modules.utils import natural_keys # This mapping is from a very recent commit, not yet released. # If not available, default to a backup map for some common model types. @@ -354,12 +355,23 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch # == Prep the dataset, format, etc == if raw_text_file not in ['None', '']: - logger.info("Loading raw text file dataset...") - train_template["template_type"] = "raw_text" + logger.info("Loading raw text file dataset...") + fullpath = clean_path('training/datasets', f'{raw_text_file}') + fullpath = Path(fullpath) + if fullpath.is_dir(): + logger.info('Training path directory {}'.format(raw_text_file)) + raw_text = "" + file_paths = sorted(fullpath.glob('*.txt'), key=lambda path: natural_keys(path.name)) + for file_path in file_paths: + if file_path.is_file(): + with file_path.open('r', encoding='utf-8') as file: + raw_text += file.read() - with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file: - raw_text = file.read().replace('\r', '') + logger.info(f"Loaded training file: {file_path.name}") + else: + with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file: + raw_text = file.read() cut_string = hard_cut_string.replace('\\n', '\n') out_tokens = [] @@ -579,7 +591,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch if WANT_INTERRUPT: yield "Interrupted before start." return - + def log_train_dataset(trainer): decoded_entries = [] # Try to decode the entries and write the log file diff --git a/modules/utils.py b/modules/utils.py index 72a0dfa1..8b662be1 100644 --- a/modules/utils.py +++ b/modules/utils.py @@ -114,6 +114,10 @@ def get_available_loras(): def get_datasets(path: str, ext: str): + # include subdirectories for raw txt files to allow training from a subdirectory of txt files + if ext == "txt": + return ['None'] + sorted(set([k.stem for k in list(Path(path).glob('txt'))+list(Path(path).glob('*/')) if k.stem != 'put-trainer-datasets-here']), key=natural_keys) + return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=natural_keys) diff --git a/training/datasets/put-trainer-datasets-here.txt b/training/datasets/put-trainer-datasets-here.txt index e69de29b..932eacf8 100644 --- a/training/datasets/put-trainer-datasets-here.txt +++ b/training/datasets/put-trainer-datasets-here.txt @@ -0,0 +1 @@ +to load multiple raw text files create a subdirectory and put them all there