mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 01:30:20 +01:00
Small style changes
This commit is contained in:
parent
c2cad30772
commit
2f0571bfa4
@ -41,7 +41,7 @@ ol li p, ul li p {
|
|||||||
display: inline-block;
|
display: inline-block;
|
||||||
}
|
}
|
||||||
|
|
||||||
#main, #parameters, #chat-settings, #interface-mode, #lora {
|
#main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab {
|
||||||
border: 0;
|
border: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,10 +1,17 @@
|
|||||||
import sys, torch, json, threading, time
|
import json
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from datasets import load_dataset
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from modules import ui, shared
|
from datasets import load_dataset
|
||||||
from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict
|
from peft import (LoraConfig, get_peft_model, get_peft_model_state_dict,
|
||||||
|
prepare_model_for_int8_training)
|
||||||
|
|
||||||
|
from modules import shared, ui
|
||||||
|
|
||||||
WANT_INTERRUPT = False
|
WANT_INTERRUPT = False
|
||||||
CURRENT_STEPS = 0
|
CURRENT_STEPS = 0
|
||||||
@ -44,7 +51,7 @@ def create_train_interface():
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
startButton = gr.Button("Start LoRA Training")
|
startButton = gr.Button("Start LoRA Training")
|
||||||
stopButton = gr.Button("Interrupt")
|
stopButton = gr.Button("Interrupt")
|
||||||
output = gr.Markdown(value="(...)")
|
output = gr.Markdown(value="Ready")
|
||||||
startEvent = startButton.click(do_train, [loraName, microBatchSize, batchSize, epochs, learningRate, loraRank, loraAlpha, loraDropout, cutoffLen, dataset, evalDataset, format], [output])
|
startEvent = startButton.click(do_train, [loraName, microBatchSize, batchSize, epochs, learningRate, loraRank, loraAlpha, loraDropout, cutoffLen, dataset, evalDataset, format], [output])
|
||||||
stopButton.click(doInterrupt, [], [], cancels=[], queue=False)
|
stopButton.click(doInterrupt, [], [], cancels=[], queue=False)
|
||||||
|
|
||||||
@ -169,16 +176,20 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
|
|||||||
).__get__(loraModel, type(loraModel))
|
).__get__(loraModel, type(loraModel))
|
||||||
if torch.__version__ >= "2" and sys.platform != "win32":
|
if torch.__version__ >= "2" and sys.platform != "win32":
|
||||||
loraModel = torch.compile(loraModel)
|
loraModel = torch.compile(loraModel)
|
||||||
|
|
||||||
# == Main run and monitor loop ==
|
# == Main run and monitor loop ==
|
||||||
# TODO: save/load checkpoints to resume from?
|
# TODO: save/load checkpoints to resume from?
|
||||||
print("Starting training...")
|
print("Starting training...")
|
||||||
yield "Starting..."
|
yield "Starting..."
|
||||||
|
|
||||||
def threadedRun():
|
def threadedRun():
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
thread = threading.Thread(target=threadedRun)
|
thread = threading.Thread(target=threadedRun)
|
||||||
thread.start()
|
thread.start()
|
||||||
lastStep = 0
|
lastStep = 0
|
||||||
startTime = time.perf_counter()
|
startTime = time.perf_counter()
|
||||||
|
|
||||||
while thread.is_alive():
|
while thread.is_alive():
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
if WANT_INTERRUPT:
|
if WANT_INTERRUPT:
|
||||||
@ -197,8 +208,10 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
|
|||||||
timerInfo = f"`{1.0/its:.2f}` s/it"
|
timerInfo = f"`{1.0/its:.2f}` s/it"
|
||||||
totalTimeEstimate = (1.0/its) * (MAX_STEPS)
|
totalTimeEstimate = (1.0/its) * (MAX_STEPS)
|
||||||
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds"
|
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds"
|
||||||
|
|
||||||
print("Training complete, saving...")
|
print("Training complete, saving...")
|
||||||
loraModel.save_pretrained(loraName)
|
loraModel.save_pretrained(loraName)
|
||||||
|
|
||||||
if WANT_INTERRUPT:
|
if WANT_INTERRUPT:
|
||||||
print("Training interrupted.")
|
print("Training interrupted.")
|
||||||
yield f"Interrupted. Incomplete LoRA saved to `{loraName}`"
|
yield f"Interrupted. Incomplete LoRA saved to `{loraName}`"
|
||||||
|
@ -9,8 +9,8 @@ from pathlib import Path
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import chat, shared, ui, training
|
|
||||||
import modules.extensions as extensions_module
|
import modules.extensions as extensions_module
|
||||||
|
from modules import chat, shared, training, ui
|
||||||
from modules.html_generator import generate_chat_html
|
from modules.html_generator import generate_chat_html
|
||||||
from modules.LoRA import add_lora_to_model
|
from modules.LoRA import add_lora_to_model
|
||||||
from modules.models import load_model, load_soft_prompt
|
from modules.models import load_model, load_soft_prompt
|
||||||
|
Loading…
Reference in New Issue
Block a user