2023-03-27 19:25:08 +02:00
import sys , torch , json , threading , time
2023-03-25 20:08:26 +01:00
from pathlib import Path
import gradio as gr
from datasets import load_dataset
import transformers
from modules import ui , shared
from peft import prepare_model_for_int8_training , LoraConfig , get_peft_model , get_peft_model_state_dict
2023-03-27 19:25:08 +02:00
# Training progress counters shared across threads: written by Callbacks.on_step_begin
# and reset/read by do_train while it polls the training thread.
CURRENT_STEPS = 0
MAX_STEPS = 0
2023-03-25 20:08:26 +01:00
def get_json_dataset(path: str):
    """Return a zero-argument function that lists the ``*.json`` files in *path*.

    Each call of the returned function re-scans the directory and returns
    ``['None']`` followed by the file names with their final extension
    removed, de-duplicated and sorted case-insensitively.
    """
    def get_set():
        # Every glob match contains at least one '.', so dropping the last
        # dot-separated piece strips exactly the trailing extension.
        stems = {entry.name.rsplit('.', 1)[0] for entry in Path(path).glob('*.json')}
        return ['None'] + sorted(stems, key=str.lower)

    return get_set
def create_train_interface():
    """Build the 'Train LoRA' tab and connect its start button to ``do_train``.

    The components passed as inputs to ``do_train`` are listed in the same
    order as that function's parameters.
    """
    def file_dropdown(list_choices, label, info):
        # A dropdown over the listed files plus its refresh button; the
        # choices are re-listed each time the refresh callback is invoked.
        picker = gr.Dropdown(choices=list_choices(), value='None', label=label, info=info)
        ui.create_refresh_button(picker, lambda: None, lambda: {'choices': list_choices()}, 'refresh-button')
        return picker

    with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
        lora_name = gr.Textbox(label="Name", info="The name of your new LoRA file")

        with gr.Row():
            # TODO: Implement multi-device support.
            micro_batch_size = gr.Slider(
                label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1,
                info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.',
            )
            batch_size = gr.Slider(
                label='Batch Size', value=128, minimum=1, maximum=1024, step=4,
                info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.',
            )

        with gr.Row():
            epoch_count = gr.Number(
                label='Epochs', value=1,
                info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.',
            )
            learning_rate = gr.Textbox(
                label='Learning Rate', value='3e-4',
                info='Learning rate, in scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.',
            )

        # TODO: What is the actual maximum rank? Likely distinct per model. This might be better to somehow be on a log scale.
        lora_rank = gr.Slider(
            label='LoRA Rank', value=8, minimum=1, maximum=1024, step=4,
            info='LoRA Rank, or dimension count. Higher values produce a larger file with better control over the model\'s content. Smaller values produce a smaller file with less overall control. Small values like 4 or 8 are great for stylistic guidance, high values like 128 or 256 are good for teaching content upgrades. Higher ranks also require higher VRAM.',
        )
        lora_alpha = gr.Slider(
            label='LoRA Alpha', value=16, minimum=1, maximum=2048, step=4,
            info='LoRA Alpha. This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.',
        )
        # TODO: Better explain what this does, in terms of real world effect especially.
        lora_dropout = gr.Slider(
            label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05,
            info='Percentage probability for dropout of LoRA layers.',
        )
        cutoff_len = gr.Slider(
            label='Cutoff Length', minimum=1, maximum=2048, value=256, step=32,
            info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.',
        )

        with gr.Row():
            list_datasets = get_json_dataset('training/datasets')
            train_choice = file_dropdown(list_datasets, 'Dataset', 'The dataset file to use for training.')
            eval_choice = file_dropdown(list_datasets, 'Evaluation Dataset', 'The dataset file used to evaluate the model after training.')
            list_formats = get_json_dataset('training/formats')
            format_choice = file_dropdown(list_formats, 'Data Format', 'The format file used to decide how to format the dataset input.')

        start_button = gr.Button("Start LoRA Training")
        status = gr.Markdown(value="(...)")

        # Order must match do_train's positional parameters.
        train_inputs = [
            lora_name, micro_batch_size, batch_size, epoch_count, learning_rate,
            lora_rank, lora_alpha, lora_dropout, cutoff_len,
            train_choice, eval_choice, format_choice,
        ]
        start_button.click(do_train, train_inputs, [status])
2023-03-27 19:25:08 +02:00
class Callbacks(transformers.TrainerCallback):
    """Trainer callback that mirrors step progress into the module-level counters."""

    def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
        """Copy the trainer state's current and maximum step counts into CURRENT_STEPS / MAX_STEPS."""
        global CURRENT_STEPS, MAX_STEPS
        CURRENT_STEPS, MAX_STEPS = state.global_step, state.max_steps
2023-03-25 20:08:26 +01:00
def cleanPath(basePath: str, path: str):
    """Strip unusual symbols and forcibly build a path relative to the intended directory.

    Backslashes become forward slashes and every ``..`` becomes ``_``. When
    *basePath* is ``None`` the sanitized text is returned as-is; otherwise it
    is appended to the absolute form of *basePath*, joined with ``/``.
    """
    # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
    # Or swap it to a strict whitelist of [a-zA-Z_0-9]
    sanitized = path.replace('\\', '/').replace('..', '_')
    if basePath is not None:
        return f'{Path(basePath).absolute()}/{sanitized}'
    return sanitized
def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, learningRate: float, loraRank: int, loraAlpha: int, loraDropout: float, cutoffLen: int, dataset: str, evalDataset: str, format: str):
    """Train a LoRA on ``shared.model`` and yield status messages while doing so.

    This is a generator; ``create_train_interface`` wires it to a Markdown
    output, so every status (including validation failures) must be yielded.

    Args:
        loraName: Name for the LoRA; sanitized and placed under ``loras/``.
        microBatchSize: Per-device train batch size.
        batchSize: Global batch size; divided by ``microBatchSize`` to get the
            gradient accumulation step count (never less than 1).
        epochs: Number of training epochs.
        learningRate: Learning rate; converted with ``float()``, so a string
            such as ``'3e-4'`` is accepted.
        loraRank: LoRA ``r``.
        loraAlpha: LoRA alpha.
        loraDropout: LoRA dropout probability.
        cutoffLen: Token length every example is padded/truncated to.
        dataset: Stem of the JSON file in ``training/datasets`` to train on.
        evalDataset: Stem of the evaluation JSON file, or ``'None'`` for none.
        format: Stem of the JSON file in ``training/formats`` holding the
            prompt templates.

    Yields:
        Human-readable progress or error strings.

    Raises:
        ValueError: If ``learningRate`` cannot be converted to ``float``.
        RuntimeError: If a data point matches no template in the format file.
    """
    global CURRENT_STEPS, MAX_STEPS
    CURRENT_STEPS = 0
    MAX_STEPS = 0
    yield "Prepping..."

    # == Input validation / processing ==
    # 'None' is the placeholder entry the dataset/format dropdowns start on,
    # so it has to count as "nothing selected" alongside a real None.
    unselected = (None, '', 'None')
    if dataset in unselected:
        # A value passed to `return` in a generator is not part of the
        # iterated output, so the message has to be yielded to be shown.
        yield "**Missing dataset choice input, cannot continue.**"
        return
    if format in unselected:
        yield "**Missing format choice input, cannot continue.**"
        return
    # TODO: --lora-dir PR once pulled will need to be applied here
    loraName = f"loras/{cleanPath(None, loraName)}"
    # A micro batch larger than the global batch would floor-divide to 0
    # accumulation steps; one step is the smallest meaningful value.
    gradientAccumulationSteps = max(1, batchSize // microBatchSize)
    actualLR = float(learningRate)
    # pad_token is the token *string*; the numeric id goes in pad_token_id.
    shared.tokenizer.pad_token_id = 0
    shared.tokenizer.padding_side = "left"

    # == Prep the dataset, format, etc ==
    with open(cleanPath('training/formats', f'{format}.json'), 'r', encoding='utf-8') as formatFile:
        formatData: dict[str, str] = json.load(formatFile)

    def tokenize(prompt):
        # Pad/truncate to one position past the cutoff, then drop the final
        # position so each example is exactly cutoffLen long.
        result = shared.tokenizer(prompt, truncation=True, max_length=cutoffLen + 1, padding="max_length")
        return {
            "input_ids": result["input_ids"][:-1],
            "attention_mask": result["attention_mask"][:-1],
        }

    def generate_prompt(data_point: dict[str, str]):
        # Use the template whose comma-separated key list equals the set of
        # non-blank fields of this data point, then fill each %key% placeholder.
        for options, data in formatData.items():
            if set(options.split(',')) == set(x[0] for x in data_point.items() if len(x[1].strip()) > 0):
                for key, val in data_point.items():
                    data = data.replace(f'%{key}%', val)
                return data
        raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(formatData.keys())}"')

    def generate_and_tokenize_prompt(data_point):
        prompt = generate_prompt(data_point)
        return tokenize(prompt)

    print("Loading datasets...")
    data = load_dataset("json", data_files=cleanPath('training/datasets', f'{dataset}.json'))
    train_data = data['train'].shuffle().map(generate_and_tokenize_prompt)
    if evalDataset in unselected:
        evalData = None
    else:
        evalData = load_dataset("json", data_files=cleanPath('training/datasets', f'{evalDataset}.json'))
        evalData = evalData['train'].shuffle().map(generate_and_tokenize_prompt)

    # == Start prepping the model itself ==
    if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
        print("Getting model ready...")
        prepare_model_for_int8_training(shared.model)
    print("Prepping for training...")
    config = LoraConfig(
        r=loraRank,
        lora_alpha=loraAlpha,
        # TODO: Should target_modules be configurable?
        target_modules=["q_proj", "v_proj"],
        lora_dropout=loraDropout,
        bias="none",
        task_type="CAUSAL_LM"
    )
    loraModel = get_peft_model(shared.model, config)
    trainer = transformers.Trainer(
        model=loraModel,
        train_dataset=train_data,
        eval_dataset=evalData,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=microBatchSize,
            gradient_accumulation_steps=gradientAccumulationSteps,
            # TODO: Should more of these be configurable? Probably.
            warmup_steps=100,
            num_train_epochs=epochs,
            learning_rate=actualLR,
            fp16=True,
            logging_steps=20,
            evaluation_strategy="steps" if evalData is not None else "no",
            save_strategy="steps",
            eval_steps=200 if evalData is not None else None,
            save_steps=200,
            output_dir=loraName,
            save_total_limit=3,
            load_best_model_at_end=evalData is not None,
            # TODO: Enable multi-device support
            ddp_find_unused_parameters=None
        ),
        data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
        callbacks=[Callbacks()]
    )
    loraModel.config.use_cache = False
    # Route state_dict() through get_peft_model_state_dict so that saving
    # goes through the PEFT state-dict filter rather than the raw model dict.
    old_state_dict = loraModel.state_dict
    loraModel.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
    ).__get__(loraModel, type(loraModel))
    if torch.__version__ >= "2" and sys.platform != "win32":
        # NOTE(review): the Trainer above already holds the uncompiled model
        # object; confirm this rebinding has the intended effect on training.
        loraModel = torch.compile(loraModel)

    # == Main run and monitor loop ==
    # TODO: save/load checkpoints to resume from?
    print("Starting training...")
    yield "Starting..."

    trainErrors = []

    def threadedRun():
        try:
            trainer.train()
        except Exception as exc:
            # Remember the failure for the monitor below, then re-raise so
            # the thread's default excepthook still prints the traceback.
            trainErrors.append(exc)
            raise

    thread = threading.Thread(target=threadedRun)
    thread.start()
    lastStep = 0
    startTime = time.perf_counter()
    while thread.is_alive():
        time.sleep(0.5)
        if CURRENT_STEPS != lastStep:
            lastStep = CURRENT_STEPS
            timeElapsed = time.perf_counter() - startTime
            if timeElapsed <= 0:
                timerInfo = ""
            else:
                its = CURRENT_STEPS / timeElapsed
                if its > 1:
                    timerInfo = f"`{its:.2f}` it/s"
                else:
                    timerInfo = f"`{1.0 / its:.2f}` s/it"
            yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.1f}` seconds"

    if trainErrors:
        # Do not save or claim success when the training thread died.
        print("Training failed, not saving.")
        yield f"**Training failed:** `{trainErrors[0]!r}` (see the console for the full traceback)"
        return

    print("Training complete, saving...")
    loraModel.save_pretrained(loraName)
    print("Training complete!")
    yield f"Done! LoRA saved to `{loraName}`"