Option to select/target additional linear modules/layers in LORA training (#4178)

This commit is contained in:
omo 2023-10-22 11:57:19 -07:00 committed by GitHub
parent 7a3f885ea8
commit 4405513ca5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -41,7 +41,7 @@ from modules.models import reload_model
from modules.utils import natural_keys from modules.utils import natural_keys
MODEL_CLASSES = {v[1]: v[0] for v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.items()} MODEL_CLASSES = {v[1]: v[0] for v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.items()}
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to"] PARAMETERS = ["lora_name", "always_override", "q_proj_en", "v_proj_en", "k_proj_en", "o_proj_en", "gate_proj_en", "down_proj_en", "up_proj_en", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to"]
WANT_INTERRUPT = False WANT_INTERRUPT = False
train_log = {} train_log = {}
@ -67,13 +67,31 @@ def create_ui():
with gr.Column(): with gr.Column():
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background']) always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background'])
with gr.Accordion(label='Target Modules', open=False):
gr.Markdown("Selects which modules to target in training. Targeting more modules is closer to a full fine-tune at the cost of increased VRAM requirements and adapter size.\nNOTE: Only works for model_id='llama', other types will retain default training behavior and not use these settings.")
with gr.Row():
with gr.Column():
q_proj_en = gr.Checkbox(label='Enable q_proj', value=True)
with gr.Column():
v_proj_en = gr.Checkbox(label='Enable v_proj', value=True)
with gr.Column():
k_proj_en = gr.Checkbox(label='Enable k_proj', value=False)
with gr.Column():
o_proj_en = gr.Checkbox(label='Enable o_proj', value=False)
with gr.Column():
gate_proj_en = gr.Checkbox(label='Enable gate_proj', value=False)
with gr.Column():
down_proj_en = gr.Checkbox(label='Enable down_proj', value=False)
with gr.Column():
up_proj_en = gr.Checkbox(label='Enable up_proj', value=False)
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='Also called dimension count. Higher values = larger file, more content control. Smaller values = smaller file, less control. Use 4 or 8 for style, 128 or 256 to teach, 1024+ for fine-detail on big data. More VRAM is needed for higher ranks.') lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='Also called dimension count. Higher values = larger file, more content control. Smaller values = smaller file, less control. Use 4 or 8 for style, 128 or 256 to teach, 1024+ for fine-detail on big data. More VRAM is needed for higher ranks.')
lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.') lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
batch_size = gr.Slider(label='Batch Size', value=128, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.') batch_size = gr.Slider(label='Batch Size', value=128, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.')
micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.') micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.')
cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.') cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=4096, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
with gr.Column(): with gr.Column():
save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a checkpoint of the LoRA will be saved every time this many steps pass.') save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a checkpoint of the LoRA will be saved every time this many steps pass.')
@ -162,7 +180,7 @@ def create_ui():
refresh_table = gr.Button('Refresh the table', elem_classes="small-button", interactive=not mu) refresh_table = gr.Button('Refresh the table', elem_classes="small-button", interactive=not mu)
# Training events # Training events
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to] all_params = [lora_name, always_override, q_proj_en, v_proj_en, k_proj_en, o_proj_en, gate_proj_en, down_proj_en, up_proj_en, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to]
copy_from.change(do_copy_params, [copy_from] + all_params, all_params) copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
start_button.click(do_train, all_params, output) start_button.click(do_train, all_params, output)
@ -269,7 +287,7 @@ def calc_trainable_parameters(model):
return trainable_params, all_param return trainable_params, all_param
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str): def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str):
if shared.args.monkey_patch: if shared.args.monkey_patch:
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import ( from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
@ -320,6 +338,23 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
shared.tokenizer.pad_token_id = 0 shared.tokenizer.pad_token_id = 0
shared.tokenizer.padding_side = "left" shared.tokenizer.padding_side = "left"
# Populate the target_modules list with the chosen X_proj modules. Only Llama-based models are supported at the moment; non-llama models revert to the default behavior.
def list_target_modules(model_id):
    """Return the list of module names LoRA should target for ``model_id``.

    For any ``model_id`` other than ``"llama"`` the entry stored in
    ``model_to_lora_modules`` is returned unchanged. For ``"llama"`` the
    list is built from the ``*_proj_en`` flags that this function reads
    from the enclosing scope: every flag that is truthy contributes
    ``"<prefix>_proj"``.

    The result may be an empty list when no flag is set.
    """
    if model_id == "llama":
        # The order of this table fixes the order of the returned names.
        toggles = (
            ("gate", gate_proj_en),
            ("down", down_proj_en),
            ("up", up_proj_en),
            ("q", q_proj_en),
            ("v", v_proj_en),
            ("k", k_proj_en),
            ("o", o_proj_en),
        )

        selected = []
        for prefix, is_on in toggles:
            if is_on:
                selected.append(f"{prefix}_proj")

        return selected

    # Non-llama models keep their predefined module list.
    return model_to_lora_modules[model_id]
def encode(text, add_bos_token): def encode(text, add_bos_token):
result = shared.tokenizer.encode(text, truncation=True, max_length=cutoff_len) result = shared.tokenizer.encode(text, truncation=True, max_length=cutoff_len)
# Check if the first two tokens are BOS # Check if the first two tokens are BOS
@ -490,7 +525,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
config = LoraConfig( config = LoraConfig(
r=lora_rank, r=lora_rank,
lora_alpha=lora_alpha, lora_alpha=lora_alpha,
target_modules=model_to_lora_modules[model_id], target_modules=list_target_modules(model_id),
lora_dropout=lora_dropout, lora_dropout=lora_dropout,
bias="none", bias="none",
task_type="CAUSAL_LM" task_type="CAUSAL_LM"
@ -616,7 +651,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model) lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
projections_string = ", ".join([projection.replace("_proj", "") for projection in model_to_lora_modules[model_id]]) projections_string = ", ".join([projection.replace("_proj", "") for projection in list_target_modules(model_id)])
print(f"Training '{model_id}' model using ({projections_string}) projections") print(f"Training '{model_id}' model using ({projections_string}) projections")