From 50c70e28f03dc084c180b1ab884457d157aad5c3 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" <4000772+mcmonkey4eva@users.noreply.github.com> Date: Fri, 19 May 2023 08:58:54 -0700 Subject: [PATCH] Lora Trainer improvements, part 6 - slightly better raw text inputs (#2108) --- docs/Training-LoRAs.md | 7 +++++++ modules/training.py | 41 ++++++++++++++++++++++++++++------------- 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/docs/Training-LoRAs.md b/docs/Training-LoRAs.md index 406ec1e4..83e6d5a7 100644 --- a/docs/Training-LoRAs.md +++ b/docs/Training-LoRAs.md @@ -75,6 +75,13 @@ So for example if a dataset has `"instruction": "answer my question"`, then the If you have different sets of key inputs, you can make your own format file to match it. This format-file is designed to be as simple as possible to enable easy editing to match your needs. +## Raw Text File Settings + +When using raw text files as your dataset, the text is automatically split into chunks based on your `Cutoff Length` you get a few basic options to configure them. +- `Overlap Length` is how much to overlap chunks by. Overlapping chunks helps prevent the model from learning strange mid-sentence cuts, and instead learn continual sentences that flow from earlier text. +- `Prefer Newline Cut Length` sets a maximum distance in characters to shift the chunk cut towards newlines. Doing this helps prevent lines from starting or ending mid-sentence, preventing the model from learning to cut off sentences randomly. +- `Hard Cut String` sets a string that indicates there must be a hard cut without overlap. This defaults to `\n\n\n`, meaning 3 newlines. No trained chunk will ever contain this string. This allows you to insert unrelated sections of text in the same text file, but still ensure the model won't be taught to randomly change the subject. + ## Parameters The basic purpose and function of each parameter is documented on-page in the WebUI, so read through them in the UI to understand your options. diff --git a/modules/training.py b/modules/training.py index e2410edc..0228b1c6 100644 --- a/modules/training.py +++ b/modules/training.py @@ -36,9 +36,9 @@ except: "GPTNeoXForCausalLM": "gpt_neox" } -WANT_INTERRUPT = False -PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer"] +WANT_INTERRUPT = False +PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string"] def create_train_interface(): @@ -85,6 +85,7 @@ def create_train_interface(): with gr.Row(): raw_text_file = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.') ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'txt')}, 'refresh-button') + hard_cut_string = gr.Textbox(label='Hard Cut String', value='\\n\\n\\n', info='String that indicates a hard cut between text parts. Helps prevent unwanted overlap.') with gr.Row(): overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.') @@ -125,7 +126,7 @@ def create_train_interface(): save_comments = gr.Button('Save comments') # Training events - all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer] + all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string] copy_from.change(do_copy_params, [copy_from] + all_params, all_params) start_button.click(do_train, all_params, output) stop_button.click(do_interrupt, None, None, queue=False) @@ -178,7 +179,7 @@ def change_rank_limit(use_higher_ranks: bool): def clean_path(base_path: str, path: str): - """"Strips unusual symbols and forcibly builds a path as relative to the intended directory.""" + """Strips unusual symbols and forcibly builds a path as relative to the intended directory.""" # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path. # Or swap it to a strict whitelist of [a-zA-Z_0-9] path = path.replace('\\', '/').replace('..', '_') @@ -188,7 +189,7 @@ def clean_path(base_path: str, path: str): return f'{Path(base_path).absolute()}/{path}' -def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str): +def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str): if shared.args.monkey_patch: from monkeypatch.peft_tuners_lora_monkey_patch import \ @@ -254,16 +255,30 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch if raw_text_file not in ['None', '']: logging.info("Loading raw text file dataset...") with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file: - raw_text = file.read() + raw_text = file.read().replace('\r', '') + + cut_string = hard_cut_string.replace('\\n', '\n') + out_tokens = [] + for text_part in raw_text.split(cut_string): + if text_part.strip() == '': + continue + + tokens = shared.tokenizer.encode(text_part) + step = cutoff_len - overlap_len + if step <= 0: + yield f"Error: overlap_len ({overlap_len}) cannot be greater than or equal to cutoff_len ({cutoff_len})" + return + + tokens = list(split_chunks(tokens, step)) + for i in range(1, len(tokens)): + tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i] + + out_tokens.extend(tokens) + del tokens - tokens = shared.tokenizer.encode(raw_text) del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM - tokens = list(split_chunks(tokens, cutoff_len - overlap_len)) - for i in range(1, len(tokens)): - tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i] - - text_chunks = [shared.tokenizer.decode(x) for x in tokens] - del tokens + text_chunks = [shared.tokenizer.decode(x) for x in out_tokens] + del out_tokens if newline_favor_len > 0: text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]