mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Lora trainer improvements (#763)
This commit is contained in:
parent
5b301d9a02
commit
0c7ef26981
@ -20,7 +20,7 @@ MAX_STEPS = 0
|
||||
CURRENT_GRADIENT_ACCUM = 1
|
||||
|
||||
def get_dataset(path: str, ext: str):
|
||||
return ['None'] + sorted(set((k.stem for k in Path(path).glob(f'*.{ext}'))), key=str.lower)
|
||||
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
|
||||
|
||||
def create_train_interface():
|
||||
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
|
||||
@ -45,22 +45,26 @@ def create_train_interface():
|
||||
with gr.Row():
|
||||
dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
|
||||
ui.create_refresh_button(dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
||||
eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The dataset file used to evaluate the model after training.')
|
||||
eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.')
|
||||
ui.create_refresh_button(eval_dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
||||
format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
|
||||
ui.create_refresh_button(format, lambda : None, lambda : {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
|
||||
|
||||
with gr.Tab(label="Raw Text File"):
|
||||
with gr.Row():
|
||||
raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
|
||||
ui.create_refresh_button(raw_text_file, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
|
||||
overlap_len = gr.Slider(label='Overlap Length', minimum=0,maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length above). Setting overlap to exactly half the cutoff length may be ideal.')
|
||||
with gr.Row():
|
||||
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
|
||||
newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
|
||||
|
||||
with gr.Row():
|
||||
start_button = gr.Button("Start LoRA Training")
|
||||
stop_button = gr.Button("Interrupt")
|
||||
|
||||
output = gr.Markdown(value="Ready")
|
||||
start_button.click(do_train, [lora_name, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len], [output])
|
||||
start_button.click(do_train, [lora_name, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout,
|
||||
cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output])
|
||||
stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
|
||||
|
||||
def do_interrupt():
|
||||
@ -91,8 +95,8 @@ def clean_path(base_path: str, path: str):
|
||||
return path
|
||||
return f'{Path(base_path).absolute()}/{path}'
|
||||
|
||||
def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int,
|
||||
lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int):
|
||||
def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float,
|
||||
cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int, newline_favor_len: int):
|
||||
global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
|
||||
WANT_INTERRUPT = False
|
||||
CURRENT_STEPS = 0
|
||||
@ -103,6 +107,25 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
lora_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}"
|
||||
actual_lr = float(learning_rate)
|
||||
|
||||
model_type = type(shared.model).__name__
|
||||
if model_type != "LlamaForCausalLM":
|
||||
if model_type == "PeftModelForCausalLM":
|
||||
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||
print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.")
|
||||
else:
|
||||
yield "LoRA training has only currently been validated for LLaMA models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||
print(f"Warning: LoRA training has only currently been validated for LLaMA models. (Found model type: {model_type})")
|
||||
time.sleep(5)
|
||||
|
||||
if shared.args.wbits > 0 or shared.args.gptq_bits > 0:
|
||||
yield "LoRA training does not yet support 4bit. Please use `--load-in-8bit` for now."
|
||||
return
|
||||
|
||||
elif not shared.args.load_in_8bit:
|
||||
yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
|
||||
print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.")
|
||||
time.sleep(2) # Give it a moment for the message to show in UI before continuing
|
||||
|
||||
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
|
||||
yield "Cannot input zeroes."
|
||||
return
|
||||
@ -126,15 +149,20 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
raw_text = file.read()
|
||||
tokens = shared.tokenizer.encode(raw_text)
|
||||
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
|
||||
|
||||
tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
|
||||
for i in range(1, len(tokens)):
|
||||
tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
|
||||
text_chunks = [shared.tokenizer.decode(x) for x in tokens]
|
||||
del tokens
|
||||
data = Dataset.from_list([tokenize(x) for x in text_chunks])
|
||||
train_data = data.shuffle()
|
||||
eval_data = None
|
||||
|
||||
if newline_favor_len > 0:
|
||||
text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]
|
||||
|
||||
train_data = Dataset.from_list([tokenize(x) for x in text_chunks])
|
||||
del text_chunks
|
||||
train_data = train_data.shuffle()
|
||||
eval_data = None
|
||||
|
||||
else:
|
||||
if dataset in ['None', '']:
|
||||
@ -232,33 +260,37 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
# TODO: save/load checkpoints to resume from?
|
||||
print("Starting training...")
|
||||
yield "Starting..."
|
||||
if WANT_INTERRUPT:
|
||||
yield "Interrupted before start."
|
||||
return
|
||||
|
||||
def threadedRun():
|
||||
def threaded_run():
|
||||
trainer.train()
|
||||
|
||||
thread = threading.Thread(target=threadedRun)
|
||||
thread = threading.Thread(target=threaded_run)
|
||||
thread.start()
|
||||
lastStep = 0
|
||||
startTime = time.perf_counter()
|
||||
last_step = 0
|
||||
start_time = time.perf_counter()
|
||||
|
||||
while thread.is_alive():
|
||||
time.sleep(0.5)
|
||||
if WANT_INTERRUPT:
|
||||
yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*"
|
||||
elif CURRENT_STEPS != lastStep:
|
||||
lastStep = CURRENT_STEPS
|
||||
timeElapsed = time.perf_counter() - startTime
|
||||
if timeElapsed <= 0:
|
||||
timerInfo = ""
|
||||
totalTimeEstimate = 999
|
||||
|
||||
elif CURRENT_STEPS != last_step:
|
||||
last_step = CURRENT_STEPS
|
||||
time_elapsed = time.perf_counter() - start_time
|
||||
if time_elapsed <= 0:
|
||||
timer_info = ""
|
||||
total_time_estimate = 999
|
||||
else:
|
||||
its = CURRENT_STEPS / timeElapsed
|
||||
its = CURRENT_STEPS / time_elapsed
|
||||
if its > 1:
|
||||
timerInfo = f"`{its:.2f}` it/s"
|
||||
timer_info = f"`{its:.2f}` it/s"
|
||||
else:
|
||||
timerInfo = f"`{1.0/its:.2f}` s/it"
|
||||
totalTimeEstimate = (1.0/its) * (MAX_STEPS)
|
||||
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds"
|
||||
timer_info = f"`{1.0/its:.2f}` s/it"
|
||||
total_time_estimate = (1.0/its) * (MAX_STEPS)
|
||||
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
|
||||
|
||||
print("Training complete, saving...")
|
||||
lora_model.save_pretrained(lora_name)
|
||||
@ -273,3 +305,25 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
def split_chunks(arr, step):
|
||||
for i in range(0, len(arr), step):
|
||||
yield arr[i:i + step]
|
||||
|
||||
def cut_chunk_for_newline(chunk: str, max_length: int):
|
||||
if '\n' not in chunk:
|
||||
return chunk
|
||||
first_newline = chunk.index('\n')
|
||||
if first_newline < max_length:
|
||||
chunk = chunk[first_newline + 1:]
|
||||
if '\n' not in chunk:
|
||||
return chunk
|
||||
last_newline = chunk.rindex('\n')
|
||||
if len(chunk) - last_newline < max_length:
|
||||
chunk = chunk[:last_newline]
|
||||
return chunk
|
||||
|
||||
def format_time(seconds: float):
|
||||
if seconds < 120:
|
||||
return f"`{seconds:.0f}` seconds"
|
||||
minutes = seconds / 60
|
||||
if minutes < 120:
|
||||
return f"`{minutes:.0f}` minutes"
|
||||
hours = minutes / 60
|
||||
return f"`{hours:.0f}` hours"
|
||||
|
Loading…
Reference in New Issue
Block a user