Make training.py more readable

This commit is contained in:
oobabooga 2023-04-16 02:46:27 -03:00
parent a3eec62b50
commit 5c513a5f5c

View File

@ -10,23 +10,27 @@ import gradio as gr
import torch
import transformers
from datasets import Dataset, load_dataset
from peft import (LoraConfig, get_peft_model, get_peft_model_state_dict,
PeftModel, prepare_model_for_int8_training)
try: # This mapping is from a very recent commit, not yet released.
from peft.utils.other import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as model_to_lora_modules
except: # So good backup for the 3 safe model types if not yet available.
standard_modules = ["q_proj", "v_proj"]
model_to_lora_modules = { "llama": standard_modules, "opt": standard_modules, "gptj": standard_modules }
from peft import (LoraConfig, PeftModel, get_peft_model,
get_peft_model_state_dict, prepare_model_for_int8_training)
from modules import shared, ui
# This mapping is from a very recent commit, not yet released.
# If not available, default to a backup map for the 3 safe model types.
try:
from peft.utils.other import \
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \
model_to_lora_modules
except:
standard_modules = ["q_proj", "v_proj"]
model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules}
WANT_INTERRUPT = False
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "do_shuffle", "higher_rank_limit"]
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lora_rank", "lora_alpha",
"lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "do_shuffle", "higher_rank_limit"]
MODEL_CLASSES = { # Mapping of Python class names to peft IDs
# Mapping of Python class names to peft IDs
MODEL_CLASSES = {
"LlamaForCausalLM": "llama",
"OPTForCausalLM": "opt",
"GPTJForCausalLM": "gptj"
@ -79,12 +83,14 @@ def create_train_interface():
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'json')}, 'refresh-button')
format = gr.Dropdown(choices=get_datasets('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_datasets('training/formats', 'json')}, 'refresh-button')
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
with gr.Tab(label="Raw Text File"):
with gr.Row():
raw_text_file = gr.Dropdown(choices=get_datasets('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'txt')}, 'refresh-button')
with gr.Row():
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
@ -94,14 +100,14 @@ def create_train_interface():
stop_button = gr.Button("Interrupt")
output = gr.Markdown(value="Ready")
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout,
cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, do_shuffle, higher_rank_limit]
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, do_shuffle, higher_rank_limit]
start_button.click(do_train, all_params, [output])
stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
def do_copy_params(lora_name: str):
with open(f"{shared.args.lora_dir}/{clean_path(None, lora_name)}/training_parameters.json", 'r', encoding='utf-8') as formatFile:
params: dict[str, str] = json.load(formatFile)
return [params[x] for x in PARAMETERS]
copy_from.change(do_copy_params, [copy_from], all_params)
@ -118,7 +124,6 @@ def do_interrupt():
WANT_INTERRUPT = True
def clean_path(base_path: str, path: str):
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
# TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
@ -126,6 +131,7 @@ def clean_path(base_path: str, path: str):
path = path.replace('\\', '/').replace('..', '_')
if base_path is None:
return path
return f'{Path(base_path).absolute()}/{path}'
@ -182,15 +188,15 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
print("Loading raw text file dataset...")
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
raw_text = file.read()
tokens = shared.tokenizer.encode(raw_text)
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
for i in range(1, len(tokens)):
tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
text_chunks = [shared.tokenizer.decode(x) for x in tokens]
del tokens
if newline_favor_len > 0:
text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]
@ -212,9 +218,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
def generate_prompt(data_point: dict[str, str]):
for options, data in format_data.items():
if set(options.split(',')) == set(x[0] for x in data_point.items() if (x[1] != None and len(x[1].strip()) > 0)):
if set(options.split(',')) == set(x[0] for x in data_point.items() if (x[1] is not None and len(x[1].strip()) > 0)):
for key, val in data_point.items():
if val != None:
if val is not None:
data = data.replace(f'%{key}%', val)
return data
raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
@ -357,7 +363,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
timer_info = f"`{its:.2f}` it/s"
else:
timer_info = f"`{1.0/its:.2f}` s/it"
total_time_estimate = (1.0 / its) * (tracked.max_steps)
yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
print("Training complete, saving...")
@ -379,22 +387,28 @@ def split_chunks(arr, step):
def cut_chunk_for_newline(chunk: str, max_length: int):
if '\n' not in chunk:
return chunk
first_newline = chunk.index('\n')
if first_newline < max_length:
chunk = chunk[first_newline + 1:]
if '\n' not in chunk:
return chunk
last_newline = chunk.rindex('\n')
if len(chunk) - last_newline < max_length:
chunk = chunk[:last_newline]
return chunk
def format_time(seconds: float):
if seconds < 120:
return f"`{seconds:.0f}` seconds"
minutes = seconds / 60
if minutes < 120:
return f"`{minutes:.0f}` minutes"
hours = minutes / 60
return f"`{hours:.0f}` hours"