mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-24 17:06:53 +01:00
Add ability to load all text files from a subdirectory for training (#1997)
* Update utils.py returns individual txt files and subdirectories to getdatasets to allow for training from a directory of text files * Update training.py minor tweak to training on raw datasets to detect if a directory is selected, and if so, to load in all the txt files in that directory for training * Update put-trainer-datasets-here.txt document * Minor change * Use pathlib, sort by natural keys * Space --------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
parent
73a0def4af
commit
5d513eea22
@ -32,6 +32,7 @@ from modules.evaluate import (
|
||||
save_past_evaluations
|
||||
)
|
||||
from modules.logging_colors import logger
|
||||
from modules.utils import natural_keys
|
||||
|
||||
# This mapping is from a very recent commit, not yet released.
|
||||
# If not available, default to a backup map for some common model types.
|
||||
@ -354,12 +355,23 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
|
||||
# == Prep the dataset, format, etc ==
|
||||
if raw_text_file not in ['None', '']:
|
||||
logger.info("Loading raw text file dataset...")
|
||||
|
||||
train_template["template_type"] = "raw_text"
|
||||
logger.info("Loading raw text file dataset...")
|
||||
fullpath = clean_path('training/datasets', f'{raw_text_file}')
|
||||
fullpath = Path(fullpath)
|
||||
if fullpath.is_dir():
|
||||
logger.info('Training path directory {}'.format(raw_text_file))
|
||||
raw_text = ""
|
||||
file_paths = sorted(fullpath.glob('*.txt'), key=lambda path: natural_keys(path.name))
|
||||
for file_path in file_paths:
|
||||
if file_path.is_file():
|
||||
with file_path.open('r', encoding='utf-8') as file:
|
||||
raw_text += file.read()
|
||||
|
||||
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
|
||||
raw_text = file.read().replace('\r', '')
|
||||
logger.info(f"Loaded training file: {file_path.name}")
|
||||
else:
|
||||
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
|
||||
raw_text = file.read()
|
||||
|
||||
cut_string = hard_cut_string.replace('\\n', '\n')
|
||||
out_tokens = []
|
||||
@ -579,7 +591,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
if WANT_INTERRUPT:
|
||||
yield "Interrupted before start."
|
||||
return
|
||||
|
||||
|
||||
def log_train_dataset(trainer):
|
||||
decoded_entries = []
|
||||
# Try to decode the entries and write the log file
|
||||
|
@ -114,6 +114,10 @@ def get_available_loras():
|
||||
|
||||
|
||||
def get_datasets(path: str, ext: str):
|
||||
# include subdirectories for raw txt files to allow training from a subdirectory of txt files
|
||||
if ext == "txt":
|
||||
return ['None'] + sorted(set([k.stem for k in list(Path(path).glob('txt'))+list(Path(path).glob('*/')) if k.stem != 'put-trainer-datasets-here']), key=natural_keys)
|
||||
|
||||
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=natural_keys)
|
||||
|
||||
|
||||
|
@ -0,0 +1 @@
|
||||
to load multiple raw text files create a subdirectory and put them all there
|
Loading…
Reference in New Issue
Block a user