From 2f0571bfa4a17300113b3e91f422cc8aa5471b4d Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 27 Mar 2023 21:24:39 -0300 Subject: [PATCH] Small style changes --- css/main.css | 2 +- modules/training.py | 23 ++++++++++++++++++----- server.py | 2 +- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/css/main.css b/css/main.css index 3f044094..6aa3bc1a 100644 --- a/css/main.css +++ b/css/main.css @@ -41,7 +41,7 @@ ol li p, ul li p { display: inline-block; } -#main, #parameters, #chat-settings, #interface-mode, #lora { +#main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab { border: 0; } diff --git a/modules/training.py b/modules/training.py index f8846049..bc5b3878 100644 --- a/modules/training.py +++ b/modules/training.py @@ -1,10 +1,17 @@ -import sys, torch, json, threading, time +import json +import sys +import threading +import time from pathlib import Path + import gradio as gr -from datasets import load_dataset +import torch import transformers -from modules import ui, shared -from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict +from datasets import load_dataset +from peft import (LoraConfig, get_peft_model, get_peft_model_state_dict, + prepare_model_for_int8_training) + +from modules import shared, ui WANT_INTERRUPT = False CURRENT_STEPS = 0 @@ -44,7 +51,7 @@ def create_train_interface(): with gr.Row(): startButton = gr.Button("Start LoRA Training") stopButton = gr.Button("Interrupt") - output = gr.Markdown(value="(...)") + output = gr.Markdown(value="Ready") startEvent = startButton.click(do_train, [loraName, microBatchSize, batchSize, epochs, learningRate, loraRank, loraAlpha, loraDropout, cutoffLen, dataset, evalDataset, format], [output]) stopButton.click(doInterrupt, [], [], cancels=[], queue=False) @@ -169,16 +176,20 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le ).__get__(loraModel, type(loraModel)) if torch.__version__ >= "2" and sys.platform != "win32": loraModel = torch.compile(loraModel) + # == Main run and monitor loop == # TODO: save/load checkpoints to resume from? print("Starting training...") yield "Starting..." + def threadedRun(): trainer.train() + thread = threading.Thread(target=threadedRun) thread.start() lastStep = 0 startTime = time.perf_counter() + while thread.is_alive(): time.sleep(0.5) if WANT_INTERRUPT: @@ -197,8 +208,10 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le timerInfo = f"`{1.0/its:.2f}` s/it" totalTimeEstimate = (1.0/its) * (MAX_STEPS) yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds" + print("Training complete, saving...") loraModel.save_pretrained(loraName) + if WANT_INTERRUPT: print("Training interrupted.") yield f"Interrupted. Incomplete LoRA saved to `{loraName}`" diff --git a/server.py b/server.py index cf37dc50..c3c8d2c8 100644 --- a/server.py +++ b/server.py @@ -9,8 +9,8 @@ from pathlib import Path import gradio as gr -from modules import chat, shared, ui, training import modules.extensions as extensions_module +from modules import chat, shared, training, ui from modules.html_generator import generate_chat_html from modules.LoRA import add_lora_to_model from modules.models import load_model, load_soft_prompt