From 4405513ca5cec6c3733fcc8c66f71702659e7c43 Mon Sep 17 00:00:00 2001
From: omo <140050869+computerman00@users.noreply.github.com>
Date: Sun, 22 Oct 2023 11:57:19 -0700
Subject: [PATCH] Option to select/target additional linear modules/layers in LORA training (#4178)

---
 modules/training.py | 47 +++++++++++++++++++++++++++++++++++++++------
 1 file changed, 41 insertions(+), 6 deletions(-)

diff --git a/modules/training.py b/modules/training.py
index 107bab72..b880152a 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -41,7 +41,7 @@ from modules.models import reload_model
 from modules.utils import natural_keys
 
 MODEL_CLASSES = {v[1]: v[0] for v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.items()}
-PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to"]
+PARAMETERS = ["lora_name", "always_override", "q_proj_en", "v_proj_en", "k_proj_en", "o_proj_en", "gate_proj_en", "down_proj_en", "up_proj_en", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to"]
 WANT_INTERRUPT = False
 train_log = {}
 
@@ -67,13 +67,31 @@ def create_ui():
                         with gr.Column():
                             always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background'])
 
+                    with gr.Accordion(label='Target Modules', open=False):
+                        gr.Markdown("Selects which modules to target in training. Targeting more modules is closer to a full fine-tune at the cost of increased VRAM requirements and adapter size.\nNOTE: Only works for model_id='llama', other types will retain default training behavior and not use these settings.")
+                        with gr.Row():
+                            with gr.Column():
+                                q_proj_en = gr.Checkbox(label='Enable q_proj', value=True)
+                            with gr.Column():
+                                v_proj_en = gr.Checkbox(label='Enable v_proj', value=True)
+                            with gr.Column():
+                                k_proj_en = gr.Checkbox(label='Enable k_proj', value=False)
+                            with gr.Column():
+                                o_proj_en = gr.Checkbox(label='Enable o_proj', value=False)
+                            with gr.Column():
+                                gate_proj_en = gr.Checkbox(label='Enable gate_proj', value=False)
+                            with gr.Column():
+                                down_proj_en = gr.Checkbox(label='Enable down_proj', value=False)
+                            with gr.Column():
+                                up_proj_en = gr.Checkbox(label='Enable up_proj', value=False)
+
                     with gr.Row():
                         with gr.Column():
                             lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='Also called dimension count. Higher values = larger file, more content control. Smaller values = smaller file, less control. Use 4 or 8 for style, 128 or 256 to teach, 1024+ for fine-detail on big data. More VRAM is needed for higher ranks.')
                             lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
                             batch_size = gr.Slider(label='Batch Size', value=128, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.')
                             micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.')
-                            cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
+                            cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=4096, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
 
                         with gr.Column():
                             save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a checkpoint of the LoRA will be saved every time this many steps pass.')
@@ -162,7 +180,7 @@ def create_ui():
             refresh_table = gr.Button('Refresh the table', elem_classes="small-button", interactive=not mu)
 
     # Training events
-    all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to]
+    all_params = [lora_name, always_override, q_proj_en, v_proj_en, k_proj_en, o_proj_en, gate_proj_en, down_proj_en, up_proj_en, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to]
 
     copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
     start_button.click(do_train, all_params, output)
@@ -269,7 +287,7 @@ def calc_trainable_parameters(model):
     return trainable_params, all_param
 
 
-def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str):
+def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str):
 
     if shared.args.monkey_patch:
         from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
@@ -320,6 +338,23 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     shared.tokenizer.pad_token_id = 0
     shared.tokenizer.padding_side = "left"
 
+    # Populate target_modules list with chosen X_proj modules. Llama-based models only atm, non-llama will revert to default behavior.
+    def list_target_modules(model_id):
+        if model_id != "llama":
+            return model_to_lora_modules[model_id]
+
+        available_modules = {
+            "gate": gate_proj_en,
+            "down": down_proj_en,
+            "up": up_proj_en,
+            "q": q_proj_en,
+            "v": v_proj_en,
+            "k": k_proj_en,
+            "o": o_proj_en,
+        }
+        target_mods = [f"{name}_proj" for name, enabled in available_modules.items() if enabled]
+        return target_mods
+
     def encode(text, add_bos_token):
         result = shared.tokenizer.encode(text, truncation=True, max_length=cutoff_len)
         # Check if the first two tokens are BOS
@@ -490,7 +525,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     config = LoraConfig(
         r=lora_rank,
         lora_alpha=lora_alpha,
-        target_modules=model_to_lora_modules[model_id],
+        target_modules=list_target_modules(model_id),
        lora_dropout=lora_dropout,
         bias="none",
         task_type="CAUSAL_LM"
@@ -616,7 +651,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
 
     lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
 
-    projections_string = ", ".join([projection.replace("_proj", "") for projection in model_to_lora_modules[model_id]])
+    projections_string = ", ".join([projection.replace("_proj", "") for projection in list_target_modules(model_id)])
 
     print(f"Training '{model_id}' model using ({projections_string}) projections")
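
Below is a minimal, self-contained sketch of what the selected target modules amount to on the peft side, assuming peft's standard LoraConfig/get_peft_model API. The model name and the particular projection choices are placeholders for illustration, not part of this patch:

    # Sketch only: mirrors what do_train() passes to LoraConfig via list_target_modules() for a llama-type model.
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM

    # Hypothetical UI selection: the default q_proj/v_proj plus gate_proj enabled.
    target_modules = ["q_proj", "v_proj", "gate_proj"]

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder model

    config = LoraConfig(
        r=32,                           # LoRA rank ("dimension count" in the UI)
        lora_alpha=64,                  # scaling = lora_alpha / rank
        target_modules=target_modules,  # more modules -> closer to a full fine-tune, larger adapter
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )
    lora_model = get_peft_model(model, config)
    lora_model.print_trainable_parameters()  # trainable-parameter count grows with each extra projection

Targeting only q_proj/v_proj keeps the adapter small; enabling the MLP projections (gate/down/up) trains more of each block and raises VRAM use and adapter size accordingly, which is the trade-off the accordion text describes.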