mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 09:40:20 +01:00
interrupt button
This commit is contained in:
parent
8fc723fc95
commit
16ea4fc36d
@ -6,8 +6,10 @@ import transformers
|
|||||||
from modules import ui, shared
|
from modules import ui, shared
|
||||||
from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict
|
from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict
|
||||||
|
|
||||||
|
WANT_INTERRUPT = False
|
||||||
CURRENT_STEPS = 0
|
CURRENT_STEPS = 0
|
||||||
MAX_STEPS = 0
|
MAX_STEPS = 0
|
||||||
|
CURRENT_GRADIENT_ACCUM = 1
|
||||||
|
|
||||||
def get_json_dataset(path: str):
|
def get_json_dataset(path: str):
|
||||||
def get_set():
|
def get_set():
|
||||||
@ -39,15 +41,31 @@ def create_train_interface():
|
|||||||
formatsFunction = get_json_dataset('training/formats')
|
formatsFunction = get_json_dataset('training/formats')
|
||||||
format = gr.Dropdown(choices=formatsFunction(), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
|
format = gr.Dropdown(choices=formatsFunction(), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
|
||||||
ui.create_refresh_button(format, lambda : None, lambda : {'choices': formatsFunction()}, 'refresh-button')
|
ui.create_refresh_button(format, lambda : None, lambda : {'choices': formatsFunction()}, 'refresh-button')
|
||||||
startButton = gr.Button("Start LoRA Training")
|
with gr.Row():
|
||||||
|
startButton = gr.Button("Start LoRA Training")
|
||||||
|
stopButton = gr.Button("Interrupt")
|
||||||
output = gr.Markdown(value="(...)")
|
output = gr.Markdown(value="(...)")
|
||||||
startButton.click(do_train, [loraName, microBatchSize, batchSize, epochs, learningRate, loraRank, loraAlpha, loraDropout, cutoffLen, dataset, evalDataset, format], [output])
|
startEvent = startButton.click(do_train, [loraName, microBatchSize, batchSize, epochs, learningRate, loraRank, loraAlpha, loraDropout, cutoffLen, dataset, evalDataset, format], [output])
|
||||||
|
stopButton.click(doInterrupt, [], [], cancels=[], queue=False)
|
||||||
|
|
||||||
|
def doInterrupt():
|
||||||
|
global WANT_INTERRUPT
|
||||||
|
WANT_INTERRUPT = True
|
||||||
|
|
||||||
class Callbacks(transformers.TrainerCallback):
|
class Callbacks(transformers.TrainerCallback):
|
||||||
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||||
global CURRENT_STEPS, MAX_STEPS
|
global CURRENT_STEPS, MAX_STEPS
|
||||||
CURRENT_STEPS = state.global_step
|
CURRENT_STEPS = state.global_step * CURRENT_GRADIENT_ACCUM
|
||||||
MAX_STEPS = state.max_steps
|
MAX_STEPS = state.max_steps * CURRENT_GRADIENT_ACCUM
|
||||||
|
if WANT_INTERRUPT:
|
||||||
|
control.should_epoch_stop = True
|
||||||
|
control.should_training_stop = True
|
||||||
|
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||||
|
global CURRENT_STEPS
|
||||||
|
CURRENT_STEPS += 1
|
||||||
|
if WANT_INTERRUPT:
|
||||||
|
control.should_epoch_stop = True
|
||||||
|
control.should_training_stop = True
|
||||||
|
|
||||||
def cleanPath(basePath: str, path: str):
|
def cleanPath(basePath: str, path: str):
|
||||||
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
|
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
|
||||||
@ -59,7 +77,8 @@ def cleanPath(basePath: str, path: str):
|
|||||||
return f'{Path(basePath).absolute()}/{path}'
|
return f'{Path(basePath).absolute()}/{path}'
|
||||||
|
|
||||||
def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, learningRate: float, loraRank: int, loraAlpha: int, loraDropout: float, cutoffLen: int, dataset: str, evalDataset: str, format: str):
|
def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, learningRate: float, loraRank: int, loraAlpha: int, loraDropout: float, cutoffLen: int, dataset: str, evalDataset: str, format: str):
|
||||||
global CURRENT_STEPS, MAX_STEPS
|
global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
|
||||||
|
WANT_INTERRUPT = False
|
||||||
CURRENT_STEPS = 0
|
CURRENT_STEPS = 0
|
||||||
MAX_STEPS = 0
|
MAX_STEPS = 0
|
||||||
yield "Prepping..."
|
yield "Prepping..."
|
||||||
@ -71,6 +90,7 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
|
|||||||
if format is None:
|
if format is None:
|
||||||
return "**Missing format choice input, cannot continue.**"
|
return "**Missing format choice input, cannot continue.**"
|
||||||
gradientAccumulationSteps = batchSize // microBatchSize
|
gradientAccumulationSteps = batchSize // microBatchSize
|
||||||
|
CURRENT_GRADIENT_ACCUM = gradientAccumulationSteps
|
||||||
actualLR = float(learningRate)
|
actualLR = float(learningRate)
|
||||||
shared.tokenizer.pad_token = 0
|
shared.tokenizer.pad_token = 0
|
||||||
shared.tokenizer.padding_side = "left"
|
shared.tokenizer.padding_side = "left"
|
||||||
@ -161,7 +181,9 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
|
|||||||
startTime = time.perf_counter()
|
startTime = time.perf_counter()
|
||||||
while thread.is_alive():
|
while thread.is_alive():
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
if CURRENT_STEPS != lastStep:
|
if WANT_INTERRUPT:
|
||||||
|
yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*"
|
||||||
|
elif CURRENT_STEPS != lastStep:
|
||||||
lastStep = CURRENT_STEPS
|
lastStep = CURRENT_STEPS
|
||||||
timeElapsed = time.perf_counter() - startTime
|
timeElapsed = time.perf_counter() - startTime
|
||||||
if timeElapsed <= 0:
|
if timeElapsed <= 0:
|
||||||
@ -175,5 +197,9 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
|
|||||||
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.1f}` seconds"
|
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.1f}` seconds"
|
||||||
print("Training complete, saving...")
|
print("Training complete, saving...")
|
||||||
loraModel.save_pretrained(loraName)
|
loraModel.save_pretrained(loraName)
|
||||||
print("Training complete!")
|
if WANT_INTERRUPT:
|
||||||
yield f"Done! LoRA saved to `{loraName}`"
|
print("Training interrupted.")
|
||||||
|
yield f"Interrupted. Incomplete LoRA saved to `{loraName}`"
|
||||||
|
else:
|
||||||
|
print("Training complete!")
|
||||||
|
yield f"Done! LoRA saved to `{loraName}`"
|
||||||
|
Loading…
Reference in New Issue
Block a user