mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 01:30:20 +01:00
initial progress tracker in UI
This commit is contained in:
parent
c07bcd0850
commit
8fc723fc95
@ -1,4 +1,4 @@
|
|||||||
import sys, torch, json
|
import sys, torch, json, threading, time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
@ -6,6 +6,9 @@ import transformers
|
|||||||
from modules import ui, shared
|
from modules import ui, shared
|
||||||
from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict
|
from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict
|
||||||
|
|
||||||
|
CURRENT_STEPS = 0
|
||||||
|
MAX_STEPS = 0
|
||||||
|
|
||||||
def get_json_dataset(path: str):
|
def get_json_dataset(path: str):
|
||||||
def get_set():
|
def get_set():
|
||||||
return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path(path).glob('*.json'))), key=str.lower)
|
return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path(path).glob('*.json'))), key=str.lower)
|
||||||
@ -40,6 +43,12 @@ def create_train_interface():
|
|||||||
output = gr.Markdown(value="(...)")
|
output = gr.Markdown(value="(...)")
|
||||||
startButton.click(do_train, [loraName, microBatchSize, batchSize, epochs, learningRate, loraRank, loraAlpha, loraDropout, cutoffLen, dataset, evalDataset, format], [output])
|
startButton.click(do_train, [loraName, microBatchSize, batchSize, epochs, learningRate, loraRank, loraAlpha, loraDropout, cutoffLen, dataset, evalDataset, format], [output])
|
||||||
|
|
||||||
|
class Callbacks(transformers.TrainerCallback):
|
||||||
|
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||||
|
global CURRENT_STEPS, MAX_STEPS
|
||||||
|
CURRENT_STEPS = state.global_step
|
||||||
|
MAX_STEPS = state.max_steps
|
||||||
|
|
||||||
def cleanPath(basePath: str, path: str):
|
def cleanPath(basePath: str, path: str):
|
||||||
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
|
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
|
||||||
# TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
|
# TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
|
||||||
@ -50,8 +59,11 @@ def cleanPath(basePath: str, path: str):
|
|||||||
return f'{Path(basePath).absolute()}/{path}'
|
return f'{Path(basePath).absolute()}/{path}'
|
||||||
|
|
||||||
def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, learningRate: float, loraRank: int, loraAlpha: int, loraDropout: float, cutoffLen: int, dataset: str, evalDataset: str, format: str):
|
def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, learningRate: float, loraRank: int, loraAlpha: int, loraDropout: float, cutoffLen: int, dataset: str, evalDataset: str, format: str):
|
||||||
|
global CURRENT_STEPS, MAX_STEPS
|
||||||
|
CURRENT_STEPS = 0
|
||||||
|
MAX_STEPS = 0
|
||||||
yield "Prepping..."
|
yield "Prepping..."
|
||||||
# Input validation / processing
|
# == Input validation / processing ==
|
||||||
# TODO: --lora-dir PR once pulled will need to be applied here
|
# TODO: --lora-dir PR once pulled will need to be applied here
|
||||||
loraName = f"loras/{cleanPath(None, loraName)}"
|
loraName = f"loras/{cleanPath(None, loraName)}"
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
@ -62,7 +74,7 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
|
|||||||
actualLR = float(learningRate)
|
actualLR = float(learningRate)
|
||||||
shared.tokenizer.pad_token = 0
|
shared.tokenizer.pad_token = 0
|
||||||
shared.tokenizer.padding_side = "left"
|
shared.tokenizer.padding_side = "left"
|
||||||
# Prep the dataset, format, etc
|
# == Prep the dataset, format, etc ==
|
||||||
with open(cleanPath('training/formats', f'{format}.json'), 'r') as formatFile:
|
with open(cleanPath('training/formats', f'{format}.json'), 'r') as formatFile:
|
||||||
formatData: dict[str, str] = json.load(formatFile)
|
formatData: dict[str, str] = json.load(formatFile)
|
||||||
def tokenize(prompt):
|
def tokenize(prompt):
|
||||||
@ -89,7 +101,7 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
|
|||||||
else:
|
else:
|
||||||
evalData = load_dataset("json", data_files=cleanPath('training/datasets', f'{evalDataset}.json'))
|
evalData = load_dataset("json", data_files=cleanPath('training/datasets', f'{evalDataset}.json'))
|
||||||
evalData = evalData['train'].shuffle().map(generate_and_tokenize_prompt)
|
evalData = evalData['train'].shuffle().map(generate_and_tokenize_prompt)
|
||||||
# Start prepping the model itself
|
# == Start prepping the model itself ==
|
||||||
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
||||||
print("Getting model ready...")
|
print("Getting model ready...")
|
||||||
prepare_model_for_int8_training(shared.model)
|
prepare_model_for_int8_training(shared.model)
|
||||||
@ -128,6 +140,7 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
|
|||||||
ddp_find_unused_parameters=None
|
ddp_find_unused_parameters=None
|
||||||
),
|
),
|
||||||
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
|
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
|
||||||
|
callbacks=list([Callbacks()])
|
||||||
)
|
)
|
||||||
loraModel.config.use_cache = False
|
loraModel.config.use_cache = False
|
||||||
old_state_dict = loraModel.state_dict
|
old_state_dict = loraModel.state_dict
|
||||||
@ -136,12 +149,31 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
|
|||||||
).__get__(loraModel, type(loraModel))
|
).__get__(loraModel, type(loraModel))
|
||||||
if torch.__version__ >= "2" and sys.platform != "win32":
|
if torch.__version__ >= "2" and sys.platform != "win32":
|
||||||
loraModel = torch.compile(loraModel)
|
loraModel = torch.compile(loraModel)
|
||||||
# Actually start and run and save at the end
|
# == Main run and monitor loop ==
|
||||||
# TODO: save/load checkpoints to resume from?
|
# TODO: save/load checkpoints to resume from?
|
||||||
print("Starting training...")
|
print("Starting training...")
|
||||||
yield "Running..."
|
yield "Starting..."
|
||||||
|
def threadedRun():
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
thread = threading.Thread(target=threadedRun)
|
||||||
|
thread.start()
|
||||||
|
lastStep = 0
|
||||||
|
startTime = time.perf_counter()
|
||||||
|
while thread.is_alive():
|
||||||
|
time.sleep(0.5)
|
||||||
|
if CURRENT_STEPS != lastStep:
|
||||||
|
lastStep = CURRENT_STEPS
|
||||||
|
timeElapsed = time.perf_counter() - startTime
|
||||||
|
if timeElapsed <= 0:
|
||||||
|
timerInfo = ""
|
||||||
|
else:
|
||||||
|
its = CURRENT_STEPS / timeElapsed
|
||||||
|
if its > 1:
|
||||||
|
timerInfo = f"`{its:.2f}` it/s"
|
||||||
|
else:
|
||||||
|
timerInfo = f"`{1.0/its:.2f}` s/it"
|
||||||
|
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.1f}` seconds"
|
||||||
print("Training complete, saving...")
|
print("Training complete, saving...")
|
||||||
loraModel.save_pretrained(loraName)
|
loraModel.save_pretrained(loraName)
|
||||||
print("Training complete!")
|
print("Training complete!")
|
||||||
yield f"Done! Lora saved to `{loraName}`"
|
yield f"Done! LoRA saved to `{loraName}`"
|
||||||
|
Loading…
Reference in New Issue
Block a user