From 5c513a5f5c369008947a0a39f4910e4c7c120b86 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 16 Apr 2023 02:46:27 -0300 Subject: [PATCH] Make training.py more readable --- modules/training.py | 52 ++++++++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/modules/training.py b/modules/training.py index cd4f7c0e..409d7102 100644 --- a/modules/training.py +++ b/modules/training.py @@ -10,23 +10,27 @@ import gradio as gr import torch import transformers from datasets import Dataset, load_dataset -from peft import (LoraConfig, get_peft_model, get_peft_model_state_dict, - PeftModel, prepare_model_for_int8_training) - -try: # This mapping is from a very recent commit, not yet released. - from peft.utils.other import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as model_to_lora_modules -except: # So good backup for the 3 safe model types if not yet available. - standard_modules = ["q_proj", "v_proj"] - model_to_lora_modules = { "llama": standard_modules, "opt": standard_modules, "gptj": standard_modules } +from peft import (LoraConfig, PeftModel, get_peft_model, + get_peft_model_state_dict, prepare_model_for_int8_training) from modules import shared, ui +# This mapping is from a very recent commit, not yet released. +# If not available, default to a backup map for the 3 safe model types. +try: + from peft.utils.other import \ + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \ + model_to_lora_modules +except: + standard_modules = ["q_proj", "v_proj"] + model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules} + WANT_INTERRUPT = False +PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "do_shuffle", "higher_rank_limit"] -PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lora_rank", "lora_alpha", - "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "do_shuffle", "higher_rank_limit"] -MODEL_CLASSES = { # Mapping of Python class names to peft IDs +# Mapping of Python class names to peft IDs +MODEL_CLASSES = { "LlamaForCausalLM": "llama", "OPTForCausalLM": "opt", "GPTJForCausalLM": "gptj" @@ -79,12 +83,14 @@ def create_train_interface(): ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'json')}, 'refresh-button') format = gr.Dropdown(choices=get_datasets('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.') ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_datasets('training/formats', 'json')}, 'refresh-button') + eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.') with gr.Tab(label="Raw Text File"): with gr.Row(): raw_text_file = gr.Dropdown(choices=get_datasets('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.') ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'txt')}, 'refresh-button') + with gr.Row(): overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.') newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.') @@ -94,16 +100,16 @@ def create_train_interface(): stop_button = gr.Button("Interrupt") output = gr.Markdown(value="Ready") - all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout, - cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, do_shuffle, higher_rank_limit] + all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, do_shuffle, higher_rank_limit] start_button.click(do_train, all_params, [output]) stop_button.click(do_interrupt, [], [], cancels=[], queue=False) def do_copy_params(lora_name: str): with open(f"{shared.args.lora_dir}/{clean_path(None, lora_name)}/training_parameters.json", 'r', encoding='utf-8') as formatFile: params: dict[str, str] = json.load(formatFile) + return [params[x] for x in PARAMETERS] - + copy_from.change(do_copy_params, [copy_from], all_params) def change_rank_limit(use_higher_ranks: bool): @@ -118,7 +124,6 @@ def do_interrupt(): WANT_INTERRUPT = True - def clean_path(base_path: str, path: str): """"Strips unusual symbols and forcibly builds a path as relative to the intended directory.""" # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path. @@ -126,6 +131,7 @@ def clean_path(base_path: str, path: str): path = path.replace('\\', '/').replace('..', '_') if base_path is None: return path + return f'{Path(base_path).absolute()}/{path}' @@ -182,15 +188,15 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch print("Loading raw text file dataset...") with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file: raw_text = file.read() + tokens = shared.tokenizer.encode(raw_text) del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM - tokens = list(split_chunks(tokens, cutoff_len - overlap_len)) for i in range(1, len(tokens)): tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i] + text_chunks = [shared.tokenizer.decode(x) for x in tokens] del tokens - if newline_favor_len > 0: text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks] @@ -212,9 +218,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch def generate_prompt(data_point: dict[str, str]): for options, data in format_data.items(): - if set(options.split(',')) == set(x[0] for x in data_point.items() if (x[1] != None and len(x[1].strip()) > 0)): + if set(options.split(',')) == set(x[0] for x in data_point.items() if (x[1] is not None and len(x[1].strip()) > 0)): for key, val in data_point.items(): - if val != None: + if val is not None: data = data.replace(f'%{key}%', val) return data raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"') @@ -357,7 +363,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch timer_info = f"`{its:.2f}` it/s" else: timer_info = f"`{1.0/its:.2f}` s/it" + total_time_estimate = (1.0 / its) * (tracked.max_steps) + yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining" print("Training complete, saving...") @@ -379,22 +387,28 @@ def split_chunks(arr, step): def cut_chunk_for_newline(chunk: str, max_length: int): if '\n' not in chunk: return chunk + first_newline = chunk.index('\n') if first_newline < max_length: chunk = chunk[first_newline + 1:] + if '\n' not in chunk: return chunk + last_newline = chunk.rindex('\n') if len(chunk) - last_newline < max_length: chunk = chunk[:last_newline] + return chunk def format_time(seconds: float): if seconds < 120: return f"`{seconds:.0f}` seconds" + minutes = seconds / 60 if minutes < 120: return f"`{minutes:.0f}` minutes" + hours = minutes / 60 return f"`{hours:.0f}` hours"