2023-03-28 02:24:39 +02:00
import json
import sys
import threading
import time
2023-03-25 20:08:26 +01:00
from pathlib import Path
2023-03-28 02:24:39 +02:00
2023-03-25 20:08:26 +01:00
import gradio as gr
2023-03-28 02:24:39 +02:00
import torch
2023-03-25 20:08:26 +01:00
import transformers
2023-03-28 07:15:32 +02:00
from datasets import Dataset , load_dataset
2023-03-28 02:24:39 +02:00
from peft import ( LoraConfig , get_peft_model , get_peft_model_state_dict ,
prepare_model_for_int8_training )
from modules import shared , ui
2023-03-25 20:08:26 +01:00
2023-03-27 19:43:01 +02:00
# Cooperative cancellation flag: raised by do_interrupt(), polled by the
# trainer callbacks and by the monitoring loop in do_train().
WANT_INTERRUPT = False

# Progress counters published by the trainer callbacks for do_train()'s
# monitoring loop, measured in gradient-accumulation substeps
# (optimizer steps * CURRENT_GRADIENT_ACCUM).
CURRENT_STEPS, MAX_STEPS = 0, 0

# Gradient-accumulation factor of the active run; set by do_train() and used
# by the callbacks to convert optimizer steps into substeps.
CURRENT_GRADIENT_ACCUM = 1
2023-03-27 19:25:08 +02:00
2023-03-28 07:15:32 +02:00
def get_dataset(path: str, ext: str):
    """List the names of files in *path* matching ``*.{ext}``, minus their final extension.

    Duplicates are collapsed and the names are sorted case-insensitively.
    The placeholder ``'None'`` is always the first entry; the UI uses it as
    the "nothing selected" dropdown value.
    """
    matches = Path(path).glob(f'*.{ext}')
    # Only the last extension is dropped, so 'a.b.json' becomes 'a.b'.
    names = {'.'.join(item.name.split('.')[:-1]) for item in matches}
    return ['None', *sorted(names, key=str.lower)]
2023-03-25 20:08:26 +01:00
def create_train_interface():
    """Build the 'Train LoRA' tab and connect its buttons to do_train / do_interrupt.

    Dataset, evaluation-dataset and format dropdowns list files found under
    ``training/datasets`` and ``training/formats`` (via get_dataset); each
    dropdown gets a refresh button that re-scans its folder.
    """
    datasets_dir = 'training/datasets'
    formats_dir = 'training/formats'

    def do_nothing():
        return None

    def rescan(folder, ext):
        # Callable handed to the refresh button to rebuild the dropdown choices.
        return lambda: {'choices': get_dataset(folder, ext)}

    with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
        lora_name = gr.Textbox(label="Name", info="The name of your new LoRA file")

        with gr.Row():
            # TODO: Implement multi-device support.
            micro_batch_size = gr.Slider(
                label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1,
                info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.',
            )
            batch_size = gr.Slider(
                label='Batch Size', value=128, minimum=0, maximum=1024, step=4,
                info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.',
            )

        with gr.Row():
            epochs = gr.Number(
                label='Epochs', value=3,
                info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.',
            )
            learning_rate = gr.Textbox(
                label='Learning Rate', value='3e-4',
                info='Learning rate, in scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.',
            )

        # TODO: What is the actual maximum rank? Likely distinct per model. This might be better to somehow be on a log scale.
        lora_rank = gr.Slider(
            label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4,
            info="LoRA Rank, or dimension count. Higher values produce a larger file with better control over the model's content. Smaller values produce a smaller file with less overall control. Small values like 4 or 8 are great for stylistic guidance, high values like 128 or 256 are good for teaching content upgrades. Higher ranks also require higher VRAM.",
        )
        lora_alpha = gr.Slider(
            label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4,
            info='LoRA Alpha. This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.',
        )
        # TODO: Better explain what this does, in terms of real world effect especially.
        lora_dropout = gr.Slider(
            label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05,
            info='Percentage probability for dropout of LoRA layers.',
        )
        cutoff_len = gr.Slider(
            label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32,
            info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.',
        )

        with gr.Tab(label="Formatted Dataset"):
            with gr.Row():
                dataset = gr.Dropdown(
                    choices=get_dataset(datasets_dir, 'json'), value='None', label='Dataset',
                    info='The dataset file to use for training.',
                )
                ui.create_refresh_button(dataset, do_nothing, rescan(datasets_dir, 'json'), 'refresh-button')
                eval_dataset = gr.Dropdown(
                    choices=get_dataset(datasets_dir, 'json'), value='None', label='Evaluation Dataset',
                    info='The dataset file used to evaluate the model after training.',
                )
                ui.create_refresh_button(eval_dataset, do_nothing, rescan(datasets_dir, 'json'), 'refresh-button')
                format_choice = gr.Dropdown(
                    choices=get_dataset(formats_dir, 'json'), value='None', label='Data Format',
                    info='The format file used to decide how to format the dataset input.',
                )
                ui.create_refresh_button(format_choice, do_nothing, rescan(formats_dir, 'json'), 'refresh-button')

        with gr.Tab(label="Raw Text File"):
            with gr.Row():
                raw_text_file = gr.Dropdown(
                    choices=get_dataset(datasets_dir, 'txt'), value='None', label='Text File',
                    info='The raw text file to use for training.',
                )
                ui.create_refresh_button(raw_text_file, do_nothing, rescan(datasets_dir, 'txt'), 'refresh-button')

            overlap_len = gr.Slider(
                label='Overlap Length', minimum=0, maximum=512, value=128, step=16,
                info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length above). Setting overlap to exactly half the cutoff length may be ideal.',
            )

        with gr.Row():
            start_button = gr.Button("Start LoRA Training")
            stop_button = gr.Button("Interrupt")

        output = gr.Markdown(value="Ready")

        # Order must match do_train's positional parameters.
        train_inputs = [
            lora_name, micro_batch_size, batch_size, epochs, learning_rate,
            lora_rank, lora_alpha, lora_dropout, cutoff_len,
            dataset, eval_dataset, format_choice, raw_text_file, overlap_len,
        ]
        start_button.click(do_train, train_inputs, [output])
        stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
2023-03-27 19:43:01 +02:00
2023-03-28 03:19:06 +02:00
def do_interrupt():
    """Request that the running training job stop.

    Only raises the module-level WANT_INTERRUPT flag; the trainer callbacks
    and do_train()'s monitoring loop are what react to it.
    """
    global WANT_INTERRUPT
    WANT_INTERRUPT = True
2023-03-25 20:08:26 +01:00
2023-03-27 19:25:08 +02:00
class Callbacks(transformers.TrainerCallback):
    """Trainer hooks that publish progress globals and honour interrupt requests.

    Progress is tracked in gradient-accumulation substeps: each optimizer step
    resets the counter to ``global_step * CURRENT_GRADIENT_ACCUM`` and every
    completed substep adds one.
    """

    @staticmethod
    def _stop_if_requested(control: transformers.TrainerControl):
        # Ask the Trainer to end both the current epoch and the whole run.
        if WANT_INTERRUPT:
            control.should_epoch_stop = True
            control.should_training_stop = True

    def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
        global CURRENT_STEPS, MAX_STEPS
        accum = CURRENT_GRADIENT_ACCUM
        CURRENT_STEPS = state.global_step * accum
        MAX_STEPS = state.max_steps * accum
        self._stop_if_requested(control)

    def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
        global CURRENT_STEPS
        CURRENT_STEPS += 1
        self._stop_if_requested(control)
2023-03-27 19:25:08 +02:00
2023-03-28 03:19:06 +02:00
def clean_path(base_path: str, path: str):
    """Sanitize *path* and, if *base_path* is given, place it under that directory.

    Backslashes are normalized to forward slashes and every ``..`` sequence is
    replaced with ``_`` to neutralize parent-directory references. No other
    characters are altered.

    Args:
        base_path: Directory to build the path under, or ``None`` to return
            just the sanitized *path*.
        path: User-supplied relative path or file name.

    Returns:
        The sanitized path string. When *base_path* is not ``None`` it is
        prefixed with the absolute form of *base_path* followed by ``/``.
    """
    # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
    # Or swap it to a strict whitelist of [a-zA-Z_0-9]
    path = path.replace('\\', '/').replace('..', '_')
    if base_path is None:
        return path
    return f'{Path(base_path).absolute()}/{path}'
2023-03-25 20:08:26 +01:00
2023-03-28 07:15:32 +02:00
def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int,
             lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int):
    """Train a LoRA adapter for ``shared.model``, yielding Markdown status strings.

    This is a generator; each yielded string is sent to the ``output``
    component that create_train_interface wires it to.

    Training data comes either from *raw_text_file* (a ``.txt`` file in
    ``training/datasets``, split into overlapping token chunks) or, when no raw
    file is chosen, from *dataset* / *eval_dataset* (``.json`` files in
    ``training/datasets``) rendered through the *format* template file in
    ``training/formats``.

    Training runs on a background thread while this generator polls the
    module-level progress globals maintained by Callbacks. The adapter is saved
    to ``{shared.args.lora_dir}/{lora_name}`` when the run completes or is
    interrupted; if training raises, the error is reported and nothing is saved.
    """
    import traceback  # local import: only used to log training-thread failures

    global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
    WANT_INTERRUPT = False
    CURRENT_STEPS = 0
    MAX_STEPS = 0

    # == Input validation / processing ==
    yield "Prepping..."
    cleaned_name = clean_path(None, lora_name)
    if not cleaned_name.strip():
        # An empty name would make output_dir the LoRA root directory itself.
        yield "**Missing or invalid LoRA file name input, cannot continue.**"
        return
    lora_name = f"{shared.args.lora_dir}/{cleaned_name}"

    try:
        actual_lr = float(learning_rate)
    except (TypeError, ValueError):
        yield f"**Invalid learning rate `{learning_rate}`, cannot continue.**"
        return

    if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
        yield "Cannot input zeroes."
        return

    gradient_accumulation_steps = int(batch_size // micro_batch_size)
    if gradient_accumulation_steps < 1:
        # Integer division gives 0 when micro batch > batch, which the Trainer cannot use.
        yield "**Batch Size must be at least as large as Micro Batch Size, cannot continue.**"
        return
    CURRENT_GRADIENT_ACCUM = gradient_accumulation_steps

    # Pad with token *id* 0. `pad_token` expects a token, not an id; assigning
    # the int 0 there does not reliably select id 0 (it may be treated as the
    # token "0"), and the collator masks whatever pad_token_id is out of the loss.
    shared.tokenizer.pad_token_id = 0
    shared.tokenizer.padding_side = "left"

    def tokenize(prompt):
        result = shared.tokenizer(prompt, truncation=True, max_length=cutoff_len + 1, padding="max_length")
        # Drop the last position so every sample is exactly cutoff_len long.
        return {
            "input_ids": result["input_ids"][:-1],
            "attention_mask": result["attention_mask"][:-1],
        }

    # == Prep the dataset, format, etc ==
    if raw_text_file not in ['None', '']:
        if overlap_len < 0 or overlap_len >= cutoff_len:
            # The chunk step is cutoff_len - overlap_len and must be positive.
            yield "**Overlap Length must be smaller than Cutoff Length, cannot continue.**"
            return

        print("Loading raw text file dataset...")
        with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
            raw_text = file.read()

        tokens = shared.tokenizer.encode(raw_text)
        del raw_text  # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM

        tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
        for i in range(1, len(tokens)):
            # `x[-0:]` is the whole list, so the zero-overlap case needs its own branch;
            # otherwise every chunk would absorb all preceding text.
            prefix = tokens[i - 1][-overlap_len:] if overlap_len > 0 else []
            tokens[i] = prefix + tokens[i]

        text_chunks = [shared.tokenizer.decode(x) for x in tokens]
        del tokens
        data = Dataset.from_list([tokenize(x) for x in text_chunks])
        train_data = data.shuffle()
        eval_data = None
        del text_chunks

    else:
        if dataset in ['None', '']:
            yield "**Missing dataset choice input, cannot continue.**"
            return

        if format in ['None', '']:
            yield "**Missing format choice input, cannot continue.**"
            return

        with open(clean_path('training/formats', f'{format}.json'), 'r', encoding='utf-8') as formatFile:
            format_data: dict[str, str] = json.load(formatFile)

        def generate_prompt(data_point: dict[str, str]):
            # The set of keys with non-empty values selects which template applies.
            present_keys = {key for key, val in data_point.items() if val is not None and len(str(val).strip()) > 0}
            for options, template in format_data.items():
                if set(options.split(',')) == present_keys:
                    for key, val in data_point.items():
                        if val is not None:
                            template = template.replace(f'%{key}%', str(val))
                    return template
            raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')

        def generate_and_tokenize_prompt(data_point):
            prompt = generate_prompt(data_point)
            return tokenize(prompt)

        print("Loading JSON datasets...")
        data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
        train_data = data['train'].shuffle().map(generate_and_tokenize_prompt)

        if eval_dataset in ['None', '']:
            eval_data = None
        else:
            eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json'))
            eval_data = eval_data['train'].shuffle().map(generate_and_tokenize_prompt)

    # == Start prepping the model itself ==
    if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
        print("Getting model ready...")
        prepare_model_for_int8_training(shared.model)

    print("Prepping for training...")
    config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        # TODO: Should target_modules be configurable?
        target_modules=["q_proj", "v_proj"],
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM"
    )
    lora_model = get_peft_model(shared.model, config)

    trainer = transformers.Trainer(
        model=lora_model,
        train_dataset=train_data,
        eval_dataset=eval_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            # TODO: Should more of these be configurable? Probably.
            warmup_steps=100,
            num_train_epochs=epochs,
            learning_rate=actual_lr,
            fp16=True,
            logging_steps=20,
            evaluation_strategy="steps" if eval_data is not None else "no",
            save_strategy="steps",
            eval_steps=200 if eval_data is not None else None,
            save_steps=200,
            output_dir=lora_name,
            save_total_limit=3,
            load_best_model_at_end=eval_data is not None,
            # TODO: Enable multi-device support
            ddp_find_unused_parameters=None
        ),
        data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
        callbacks=[Callbacks()]
    )

    lora_model.config.use_cache = False
    # Route state_dict() through get_peft_model_state_dict, presumably so that
    # saved checkpoints contain only the adapter weights.
    old_state_dict = lora_model.state_dict
    lora_model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
    ).__get__(lora_model, type(lora_model))

    if torch.__version__ >= "2" and sys.platform != "win32":
        lora_model = torch.compile(lora_model)

    # == Main run and monitor loop ==
    # TODO: save/load checkpoints to resume from?
    print("Starting training...")
    yield "Starting..."

    train_errors = []

    def threadedRun():
        try:
            trainer.train()
        except Exception as exc:
            # Without this, the exception dies with the thread and the loop
            # below would still report success.
            traceback.print_exc()
            train_errors.append(exc)

    thread = threading.Thread(target=threadedRun)
    thread.start()
    lastStep = 0
    startTime = time.perf_counter()

    while thread.is_alive():
        time.sleep(0.5)
        if WANT_INTERRUPT:
            yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*"
        elif CURRENT_STEPS != lastStep:
            lastStep = CURRENT_STEPS
            timeElapsed = time.perf_counter() - startTime
            if timeElapsed <= 0:
                timerInfo = ""
                totalTimeEstimate = 999
            else:
                its = CURRENT_STEPS / timeElapsed
                if its > 1:
                    timerInfo = f"`{its:.2f}` it/s"
                else:
                    timerInfo = f"`{1.0/its:.2f}` s/it"
                totalTimeEstimate = (1.0 / its) * (MAX_STEPS)
            yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds"

    if train_errors:
        print("Training failed.")
        yield f"**Training failed:** `{train_errors[0]}`"
        return

    print("Training complete, saving...")
    lora_model.save_pretrained(lora_name)

    if WANT_INTERRUPT:
        print("Training interrupted.")
        yield f"Interrupted. Incomplete LoRA saved to `{lora_name}`"
    else:
        print("Training complete!")
        yield f"Done! LoRA saved to `{lora_name}`"
2023-03-28 07:15:32 +02:00
def split_chunks(arr, step):
    """Yield consecutive, non-overlapping slices of *arr*, each *step* items long.

    The final slice is shorter when ``len(arr)`` is not a multiple of *step*.
    Works for any sliceable sequence (lists, strings, tuples).

    Raises:
        ValueError: when iteration starts, if *step* is not positive. A
            negative step would otherwise silently yield nothing, and zero
            would fail inside ``range`` with an unhelpful message.
    """
    if step <= 0:
        raise ValueError(f"split_chunks step must be positive, got {step}")
    for i in range(0, len(arr), step):
        yield arr[i:i + step]