mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Option to select/target additional linear modules/layers in LORA training (#4178)
This commit is contained in:
parent
7a3f885ea8
commit
4405513ca5
@ -41,7 +41,7 @@ from modules.models import reload_model
|
||||
from modules.utils import natural_keys
|
||||
|
||||
MODEL_CLASSES = {v[1]: v[0] for v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.items()}
|
||||
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to"]
|
||||
PARAMETERS = ["lora_name", "always_override", "q_proj_en", "v_proj_en", "k_proj_en", "o_proj_en", "gate_proj_en", "down_proj_en", "up_proj_en", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to"]
|
||||
WANT_INTERRUPT = False
|
||||
|
||||
train_log = {}
|
||||
@ -67,13 +67,31 @@ def create_ui():
|
||||
with gr.Column():
|
||||
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background'])
|
||||
|
||||
with gr.Accordion(label='Target Modules', open=False):
|
||||
gr.Markdown("Selects which modules to target in training. Targeting more modules is closer to a full fine-tune at the cost of increased VRAM requirements and adapter size.\nNOTE: Only works for model_id='llama', other types will retain default training behavior and not use these settings.")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
q_proj_en = gr.Checkbox(label='Enable q_proj', value=True)
|
||||
with gr.Column():
|
||||
v_proj_en = gr.Checkbox(label='Enable v_proj', value=True)
|
||||
with gr.Column():
|
||||
k_proj_en = gr.Checkbox(label='Enable k_proj', value=False)
|
||||
with gr.Column():
|
||||
o_proj_en = gr.Checkbox(label='Enable o_proj', value=False)
|
||||
with gr.Column():
|
||||
gate_proj_en = gr.Checkbox(label='Enable gate_proj', value=False)
|
||||
with gr.Column():
|
||||
down_proj_en = gr.Checkbox(label='Enable down_proj', value=False)
|
||||
with gr.Column():
|
||||
up_proj_en = gr.Checkbox(label='Enable up_proj', value=False)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='Also called dimension count. Higher values = larger file, more content control. Smaller values = smaller file, less control. Use 4 or 8 for style, 128 or 256 to teach, 1024+ for fine-detail on big data. More VRAM is needed for higher ranks.')
|
||||
lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
|
||||
batch_size = gr.Slider(label='Batch Size', value=128, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.')
|
||||
micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.')
|
||||
cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
|
||||
cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=4096, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
|
||||
|
||||
with gr.Column():
|
||||
save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a checkpoint of the LoRA will be saved every time this many steps pass.')
|
||||
@ -162,7 +180,7 @@ def create_ui():
|
||||
refresh_table = gr.Button('Refresh the table', elem_classes="small-button", interactive=not mu)
|
||||
|
||||
# Training events
|
||||
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to]
|
||||
all_params = [lora_name, always_override, q_proj_en, v_proj_en, k_proj_en, o_proj_en, gate_proj_en, down_proj_en, up_proj_en, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to]
|
||||
|
||||
copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
|
||||
start_button.click(do_train, all_params, output)
|
||||
@ -269,7 +287,7 @@ def calc_trainable_parameters(model):
|
||||
return trainable_params, all_param
|
||||
|
||||
|
||||
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str):
|
||||
def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str):
|
||||
|
||||
if shared.args.monkey_patch:
|
||||
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
|
||||
@ -320,6 +338,23 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
shared.tokenizer.pad_token_id = 0
|
||||
shared.tokenizer.padding_side = "left"
|
||||
|
||||
# Populate target_modules list with chosen X_proj modules. Llama-based models only atm, non-llama will revert to default behavior.
|
||||
def list_target_modules(model_id):
|
||||
if model_id != "llama":
|
||||
return model_to_lora_modules[model_id]
|
||||
|
||||
available_modules = {
|
||||
"gate": gate_proj_en,
|
||||
"down": down_proj_en,
|
||||
"up": up_proj_en,
|
||||
"q": q_proj_en,
|
||||
"v": v_proj_en,
|
||||
"k": k_proj_en,
|
||||
"o": o_proj_en,
|
||||
}
|
||||
target_mods = [f"{name}_proj" for name, enabled in available_modules.items() if enabled]
|
||||
return target_mods
|
||||
|
||||
def encode(text, add_bos_token):
|
||||
result = shared.tokenizer.encode(text, truncation=True, max_length=cutoff_len)
|
||||
# Check if the first two tokens are BOS
|
||||
@ -490,7 +525,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
config = LoraConfig(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=model_to_lora_modules[model_id],
|
||||
target_modules=list_target_modules(model_id),
|
||||
lora_dropout=lora_dropout,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
@ -616,7 +651,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
|
||||
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
|
||||
|
||||
projections_string = ", ".join([projection.replace("_proj", "") for projection in model_to_lora_modules[model_id]])
|
||||
projections_string = ", ".join([projection.replace("_proj", "") for projection in list_target_modules(model_id)])
|
||||
|
||||
print(f"Training '{model_id}' model using ({projections_string}) projections")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user