Small style changes

This commit is contained in:
oobabooga 2023-03-27 21:24:39 -03:00
parent c2cad30772
commit 2f0571bfa4
3 changed files with 20 additions and 7 deletions

View File

@ -41,7 +41,7 @@ ol li p, ul li p {
display: inline-block; display: inline-block;
} }
#main, #parameters, #chat-settings, #interface-mode, #lora { #main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab {
border: 0; border: 0;
} }

View File

@ -1,10 +1,17 @@
import sys, torch, json, threading, time import json
import sys
import threading
import time
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
from datasets import load_dataset import torch
import transformers import transformers
from modules import ui, shared from datasets import load_dataset
from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict from peft import (LoraConfig, get_peft_model, get_peft_model_state_dict,
prepare_model_for_int8_training)
from modules import shared, ui
WANT_INTERRUPT = False WANT_INTERRUPT = False
CURRENT_STEPS = 0 CURRENT_STEPS = 0
@ -44,7 +51,7 @@ def create_train_interface():
with gr.Row(): with gr.Row():
startButton = gr.Button("Start LoRA Training") startButton = gr.Button("Start LoRA Training")
stopButton = gr.Button("Interrupt") stopButton = gr.Button("Interrupt")
output = gr.Markdown(value="(...)") output = gr.Markdown(value="Ready")
startEvent = startButton.click(do_train, [loraName, microBatchSize, batchSize, epochs, learningRate, loraRank, loraAlpha, loraDropout, cutoffLen, dataset, evalDataset, format], [output]) startEvent = startButton.click(do_train, [loraName, microBatchSize, batchSize, epochs, learningRate, loraRank, loraAlpha, loraDropout, cutoffLen, dataset, evalDataset, format], [output])
stopButton.click(doInterrupt, [], [], cancels=[], queue=False) stopButton.click(doInterrupt, [], [], cancels=[], queue=False)
@ -169,16 +176,20 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
).__get__(loraModel, type(loraModel)) ).__get__(loraModel, type(loraModel))
if torch.__version__ >= "2" and sys.platform != "win32": if torch.__version__ >= "2" and sys.platform != "win32":
loraModel = torch.compile(loraModel) loraModel = torch.compile(loraModel)
# == Main run and monitor loop == # == Main run and monitor loop ==
# TODO: save/load checkpoints to resume from? # TODO: save/load checkpoints to resume from?
print("Starting training...") print("Starting training...")
yield "Starting..." yield "Starting..."
def threadedRun(): def threadedRun():
trainer.train() trainer.train()
thread = threading.Thread(target=threadedRun) thread = threading.Thread(target=threadedRun)
thread.start() thread.start()
lastStep = 0 lastStep = 0
startTime = time.perf_counter() startTime = time.perf_counter()
while thread.is_alive(): while thread.is_alive():
time.sleep(0.5) time.sleep(0.5)
if WANT_INTERRUPT: if WANT_INTERRUPT:
@ -197,8 +208,10 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
timerInfo = f"`{1.0/its:.2f}` s/it" timerInfo = f"`{1.0/its:.2f}` s/it"
totalTimeEstimate = (1.0/its) * (MAX_STEPS) totalTimeEstimate = (1.0/its) * (MAX_STEPS)
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds" yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds"
print("Training complete, saving...") print("Training complete, saving...")
loraModel.save_pretrained(loraName) loraModel.save_pretrained(loraName)
if WANT_INTERRUPT: if WANT_INTERRUPT:
print("Training interrupted.") print("Training interrupted.")
yield f"Interrupted. Incomplete LoRA saved to `{loraName}`" yield f"Interrupted. Incomplete LoRA saved to `{loraName}`"

View File

@ -9,8 +9,8 @@ from pathlib import Path
import gradio as gr import gradio as gr
from modules import chat, shared, ui, training
import modules.extensions as extensions_module import modules.extensions as extensions_module
from modules import chat, shared, training, ui
from modules.html_generator import generate_chat_html from modules.html_generator import generate_chat_html
from modules.LoRA import add_lora_to_model from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt from modules.models import load_model, load_soft_prompt