From 16ea4fc36df9ec0cde796eaecf22db64c4d91fd8 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Mon, 27 Mar 2023 10:43:01 -0700 Subject: [PATCH] interrupt button --- modules/training.py | 42 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/modules/training.py b/modules/training.py index c83427d6..19f33220 100644 --- a/modules/training.py +++ b/modules/training.py @@ -6,8 +6,10 @@ import transformers from modules import ui, shared from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict +WANT_INTERRUPT = False CURRENT_STEPS = 0 MAX_STEPS = 0 +CURRENT_GRADIENT_ACCUM = 1 def get_json_dataset(path: str): def get_set(): @@ -39,15 +41,31 @@ def create_train_interface(): formatsFunction = get_json_dataset('training/formats') format = gr.Dropdown(choices=formatsFunction(), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.') ui.create_refresh_button(format, lambda : None, lambda : {'choices': formatsFunction()}, 'refresh-button') - startButton = gr.Button("Start LoRA Training") + with gr.Row(): + startButton = gr.Button("Start LoRA Training") + stopButton = gr.Button("Interrupt") output = gr.Markdown(value="(...)") - startButton.click(do_train, [loraName, microBatchSize, batchSize, epochs, learningRate, loraRank, loraAlpha, loraDropout, cutoffLen, dataset, evalDataset, format], [output]) + startEvent = startButton.click(do_train, [loraName, microBatchSize, batchSize, epochs, learningRate, loraRank, loraAlpha, loraDropout, cutoffLen, dataset, evalDataset, format], [output]) + stopButton.click(doInterrupt, [], [], cancels=[], queue=False) + +def doInterrupt(): + global WANT_INTERRUPT + WANT_INTERRUPT = True class Callbacks(transformers.TrainerCallback): def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs): global CURRENT_STEPS, MAX_STEPS - CURRENT_STEPS = state.global_step - MAX_STEPS = state.max_steps + CURRENT_STEPS = state.global_step * CURRENT_GRADIENT_ACCUM + MAX_STEPS = state.max_steps * CURRENT_GRADIENT_ACCUM + if WANT_INTERRUPT: + control.should_epoch_stop = True + control.should_training_stop = True + def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs): + global CURRENT_STEPS + CURRENT_STEPS += 1 + if WANT_INTERRUPT: + control.should_epoch_stop = True + control.should_training_stop = True def cleanPath(basePath: str, path: str): """"Strips unusual symbols and forcibly builds a path as relative to the intended directory.""" @@ -59,7 +77,8 @@ def cleanPath(basePath: str, path: str): return f'{Path(basePath).absolute()}/{path}' def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, learningRate: float, loraRank: int, loraAlpha: int, loraDropout: float, cutoffLen: int, dataset: str, evalDataset: str, format: str): - global CURRENT_STEPS, MAX_STEPS + global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM + WANT_INTERRUPT = False CURRENT_STEPS = 0 MAX_STEPS = 0 yield "Prepping..." @@ -71,6 +90,7 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le if format is None: return "**Missing format choice input, cannot continue.**" gradientAccumulationSteps = batchSize // microBatchSize + CURRENT_GRADIENT_ACCUM = gradientAccumulationSteps actualLR = float(learningRate) shared.tokenizer.pad_token = 0 shared.tokenizer.padding_side = "left" @@ -161,7 +181,9 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le startTime = time.perf_counter() while thread.is_alive(): time.sleep(0.5) - if CURRENT_STEPS != lastStep: + if WANT_INTERRUPT: + yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*" + elif CURRENT_STEPS != lastStep: lastStep = CURRENT_STEPS timeElapsed = time.perf_counter() - startTime if timeElapsed <= 0: @@ -175,5 +197,9 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.1f}` seconds" print("Training complete, saving...") loraModel.save_pretrained(loraName) - print("Training complete!") - yield f"Done! LoRA saved to `{loraName}`" + if WANT_INTERRUPT: + print("Training interrupted.") + yield f"Interrupted. Incomplete LoRA saved to `{loraName}`" + else: + print("Training complete!") + yield f"Done! LoRA saved to `{loraName}`"