Several Training Enhancements (#2868)

This commit is contained in:
FartyPants 2023-06-25 14:34:46 -04:00 committed by GitHub
parent 95212edf1f
commit 21c189112c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -38,20 +38,21 @@ try:
MODEL_CLASSES = {v: k for k, v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES} MODEL_CLASSES = {v: k for k, v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES}
except: except:
standard_modules = ["q_proj", "v_proj"] standard_modules = ["q_proj", "v_proj"]
model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules, "gpt_neox": ["query_key_value"], "rw":["query_key_value"]} model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules, "gpt_neox": ["query_key_value"], "rw": ["query_key_value"]}
MODEL_CLASSES = { MODEL_CLASSES = {
"LlamaForCausalLM": "llama", "LlamaForCausalLM": "llama",
"OPTForCausalLM": "opt", "OPTForCausalLM": "opt",
"GPTJForCausalLM": "gptj", "GPTJForCausalLM": "gptj",
"GPTNeoXForCausalLM": "gpt_neox", "GPTNeoXForCausalLM": "gpt_neox",
"RWForCausalLM": "rw" "RWForCausalLM": "rw"
} }
train_log = {} train_log = {}
train_template = {}
WANT_INTERRUPT = False WANT_INTERRUPT = False
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after"] PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss"]
def create_train_interface(): def create_train_interface():
@@ -109,6 +110,7 @@ def create_train_interface():
warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate will be lower than normal. This helps the trainer prepare the model and precompute statistics to improve the quality of training after the start.') warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate will be lower than normal. This helps the trainer prepare the model and precompute statistics to improve the quality of training after the start.')
optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.') optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.')
train_only_after = gr.Textbox(label='Train Only After', value='', info='Only consider text *after* this string in any given chunk for training. For Alpaca datasets, use "### Response:" to only train the response and ignore the input.') train_only_after = gr.Textbox(label='Train Only After', value='', info='Only consider text *after* this string in any given chunk for training. For Alpaca datasets, use "### Response:" to only train the response and ignore the input.')
stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)')
with gr.Row(): with gr.Row():
higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.') higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
@@ -142,7 +144,7 @@ def create_train_interface():
refresh_table = gr.Button('Refresh the table', elem_classes="small-button") refresh_table = gr.Button('Refresh the table', elem_classes="small-button")
# Training events # Training events
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after] all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss]
copy_from.change(do_copy_params, [copy_from] + all_params, all_params) copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
start_button.click(do_train, all_params, output) start_button.click(do_train, all_params, output)
stop_button.click(do_interrupt, None, None, queue=False) stop_button.click(do_interrupt, None, None, queue=False)
@@ -206,7 +208,7 @@ def clean_path(base_path: str, path: str):
return f'{Path(base_path).absolute()}/{path}' return f'{Path(base_path).absolute()}/{path}'
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str): def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float):
if shared.args.monkey_patch: if shared.args.monkey_patch:
from monkeypatch.peft_tuners_lora_monkey_patch import ( from monkeypatch.peft_tuners_lora_monkey_patch import (
@@ -296,9 +298,14 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
"attention_mask": input_ids.ne(shared.tokenizer.pad_token_id), "attention_mask": input_ids.ne(shared.tokenizer.pad_token_id),
} }
train_template.clear()
# == Prep the dataset, format, etc == # == Prep the dataset, format, etc ==
if raw_text_file not in ['None', '']: if raw_text_file not in ['None', '']:
logger.info("Loading raw text file dataset...") logger.info("Loading raw text file dataset...")
train_template["template_type"] = "raw_text"
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file: with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
raw_text = file.read().replace('\r', '') raw_text = file.read().replace('\r', '')
@@ -330,7 +337,6 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
train_data = Dataset.from_list([tokenize(x) for x in text_chunks]) train_data = Dataset.from_list([tokenize(x) for x in text_chunks])
del text_chunks del text_chunks
eval_data = None eval_data = None
else: else:
if dataset in ['None', '']: if dataset in ['None', '']:
yield "**Missing dataset choice input, cannot continue.**" yield "**Missing dataset choice input, cannot continue.**"
@@ -340,9 +346,16 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
yield "**Missing format choice input, cannot continue.**" yield "**Missing format choice input, cannot continue.**"
return return
with open(clean_path('training/formats', f'{format}.json'), 'r', encoding='utf-8') as formatFile: train_template["template_type"] = "dataset"
with open(clean_path('training/formats', f'{format}.json'), 'r', encoding='utf-8-sig') as formatFile:
format_data: dict[str, str] = json.load(formatFile) format_data: dict[str, str] = json.load(formatFile)
# == store training prompt ==
for _, value in format_data.items():
prompt_key = f"template_{len(train_template)}"
train_template[prompt_key] = value
def generate_prompt(data_point: dict[str, str]): def generate_prompt(data_point: dict[str, str]):
for options, data in format_data.items(): for options, data in format_data.items():
if set(options.split(',')) == set(x[0] for x in data_point.items() if (x[1] is not None and len(x[1].strip()) > 0)): if set(options.split(',')) == set(x[0] for x in data_point.items() if (x[1] is not None and len(x[1].strip()) > 0)):
@@ -369,7 +382,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
# == Start prepping the model itself == # == Start prepping the model itself ==
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'): if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
logger.info("Getting model ready...") logger.info("Getting model ready...")
prepare_model_for_kbit_training(shared.model) prepare_model_for_int8_training(shared.model)
logger.info("Prepping for training...") logger.info("Prepping for training...")
config = LoraConfig( config = LoraConfig(
@@ -418,19 +431,32 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
control.should_training_stop = True control.should_training_stop = True
elif state.global_step > 0 and actual_save_steps > 0 and state.global_step % actual_save_steps == 0: elif state.global_step > 0 and actual_save_steps > 0 and state.global_step % actual_save_steps == 0:
lora_model.save_pretrained(f"{lora_file_path}/checkpoint-{tracked.current_steps}/") lora_model.save_pretrained(f"{lora_file_path}/checkpoint-{tracked.current_steps}/")
# Save log # Save log
with open(f"{lora_file_path}/checkpoint-{tracked.current_steps}/training_log.json", 'w', encoding='utf-8') as file: with open(f"{lora_file_path}/checkpoint-{tracked.current_steps}/training_log.json", 'w', encoding='utf-8') as file:
json.dump(train_log, file, indent=2) json.dump(train_log, file, indent=2)
# == Save training prompt ==
with open(f"{lora_file_path}/checkpoint-{tracked.current_steps}/training_prompt.json", 'w', encoding='utf-8') as file:
json.dump(train_template, file, indent=2)
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs): def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
tracked.current_steps += 1 tracked.current_steps += 1
if WANT_INTERRUPT: if WANT_INTERRUPT:
control.should_epoch_stop = True control.should_epoch_stop = True
control.should_training_stop = True control.should_training_stop = True
def on_log(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, logs, **kwargs): def on_log(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, logs, **kwargs):
train_log.update(logs) train_log.update(logs)
train_log.update({"current_steps": tracked.current_steps})
if WANT_INTERRUPT:
print("\033[1;31;1mInterrupted by user\033[0;37;0m")
print(f"\033[1;30;40mStep: {tracked.current_steps} \033[0;37;0m", end='')
if 'loss' in logs:
loss = float(logs['loss'])
if loss <= stop_at_loss:
control.should_epoch_stop = True
control.should_training_stop = True
print(f"\033[1;31;1mStop Loss {stop_at_loss} reached.\033[0;37;0m")
trainer = transformers.Trainer( trainer = transformers.Trainer(
model=lora_model, model=lora_model,
@@ -444,7 +470,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
learning_rate=actual_lr, learning_rate=actual_lr,
fp16=False if shared.args.cpu else True, fp16=False if shared.args.cpu else True,
optim=optimizer, optim=optimizer,
logging_steps=5, logging_steps=2 if stop_at_loss > 0 else 5,
evaluation_strategy="steps" if eval_data is not None else "no", evaluation_strategy="steps" if eval_data is not None else "no",
eval_steps=math.ceil(eval_steps / gradient_accumulation_steps) if eval_data is not None else None, eval_steps=math.ceil(eval_steps / gradient_accumulation_steps) if eval_data is not None else None,
save_strategy="steps" if eval_data is not None else "no", save_strategy="steps" if eval_data is not None else "no",
@@ -469,9 +495,22 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
vars = locals() vars = locals()
json.dump({x: vars[x] for x in PARAMETERS}, file, indent=2) json.dump({x: vars[x] for x in PARAMETERS}, file, indent=2)
# == Save training prompt ==
with open(f"{lora_file_path}/training_prompt.json", 'w', encoding='utf-8') as file:
json.dump(train_template, file, indent=2)
# == Main run and monitor loop == # == Main run and monitor loop ==
logger.info("Starting training...") logger.info("Starting training...")
yield "Starting..." yield "Starting..."
train_log.update({"base_model_name": shared.model_name})
train_log.update({"base_model_class": shared.model.__class__.__name__})
train_log.update({"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False)})
train_log.update({"base_loaded_in_8bit": getattr(lora_model, "is_loaded_in_8bit", False)})
if stop_at_loss > 0:
print(f"Monitoring loss \033[1;31;1m(Auto-Stop at: {stop_at_loss})\033[0;37;0m")
if WANT_INTERRUPT: if WANT_INTERRUPT:
yield "Interrupted before start." yield "Interrupted before start."
return return