Make training.py more readable

This commit is contained in:
oobabooga 2023-04-16 02:46:27 -03:00
parent a3eec62b50
commit 5c513a5f5c

View File

@ -10,23 +10,27 @@ import gradio as gr
import torch import torch
import transformers import transformers
from datasets import Dataset, load_dataset from datasets import Dataset, load_dataset
from peft import (LoraConfig, get_peft_model, get_peft_model_state_dict, from peft import (LoraConfig, PeftModel, get_peft_model,
PeftModel, prepare_model_for_int8_training) get_peft_model_state_dict, prepare_model_for_int8_training)
try: # This mapping is from a very recent commit, not yet released.
from peft.utils.other import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as model_to_lora_modules
except: # So good backup for the 3 safe model types if not yet available.
standard_modules = ["q_proj", "v_proj"]
model_to_lora_modules = { "llama": standard_modules, "opt": standard_modules, "gptj": standard_modules }
from modules import shared, ui from modules import shared, ui
# This mapping is from a very recent commit, not yet released.
# If not available, default to a backup map for the 3 safe model types.
try:
from peft.utils.other import \
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \
model_to_lora_modules
except:
standard_modules = ["q_proj", "v_proj"]
model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules}
WANT_INTERRUPT = False WANT_INTERRUPT = False
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "do_shuffle", "higher_rank_limit"]
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lora_rank", "lora_alpha", # Mapping of Python class names to peft IDs
"lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "do_shuffle", "higher_rank_limit"] MODEL_CLASSES = {
MODEL_CLASSES = { # Mapping of Python class names to peft IDs
"LlamaForCausalLM": "llama", "LlamaForCausalLM": "llama",
"OPTForCausalLM": "opt", "OPTForCausalLM": "opt",
"GPTJForCausalLM": "gptj" "GPTJForCausalLM": "gptj"
@ -79,12 +83,14 @@ def create_train_interface():
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'json')}, 'refresh-button') ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'json')}, 'refresh-button')
format = gr.Dropdown(choices=get_datasets('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.') format = gr.Dropdown(choices=get_datasets('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_datasets('training/formats', 'json')}, 'refresh-button') ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_datasets('training/formats', 'json')}, 'refresh-button')
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.') eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
with gr.Tab(label="Raw Text File"): with gr.Tab(label="Raw Text File"):
with gr.Row(): with gr.Row():
raw_text_file = gr.Dropdown(choices=get_datasets('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.') raw_text_file = gr.Dropdown(choices=get_datasets('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'txt')}, 'refresh-button') ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'txt')}, 'refresh-button')
with gr.Row(): with gr.Row():
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.') overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.') newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
@ -94,14 +100,14 @@ def create_train_interface():
stop_button = gr.Button("Interrupt") stop_button = gr.Button("Interrupt")
output = gr.Markdown(value="Ready") output = gr.Markdown(value="Ready")
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout, all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, do_shuffle, higher_rank_limit]
cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, do_shuffle, higher_rank_limit]
start_button.click(do_train, all_params, [output]) start_button.click(do_train, all_params, [output])
stop_button.click(do_interrupt, [], [], cancels=[], queue=False) stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
def do_copy_params(lora_name: str): def do_copy_params(lora_name: str):
with open(f"{shared.args.lora_dir}/{clean_path(None, lora_name)}/training_parameters.json", 'r', encoding='utf-8') as formatFile: with open(f"{shared.args.lora_dir}/{clean_path(None, lora_name)}/training_parameters.json", 'r', encoding='utf-8') as formatFile:
params: dict[str, str] = json.load(formatFile) params: dict[str, str] = json.load(formatFile)
return [params[x] for x in PARAMETERS] return [params[x] for x in PARAMETERS]
copy_from.change(do_copy_params, [copy_from], all_params) copy_from.change(do_copy_params, [copy_from], all_params)
@ -118,7 +124,6 @@ def do_interrupt():
WANT_INTERRUPT = True WANT_INTERRUPT = True
def clean_path(base_path: str, path: str): def clean_path(base_path: str, path: str):
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory.""" """"Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
# TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path. # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
@ -126,6 +131,7 @@ def clean_path(base_path: str, path: str):
path = path.replace('\\', '/').replace('..', '_') path = path.replace('\\', '/').replace('..', '_')
if base_path is None: if base_path is None:
return path return path
return f'{Path(base_path).absolute()}/{path}' return f'{Path(base_path).absolute()}/{path}'
@ -182,15 +188,15 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
print("Loading raw text file dataset...") print("Loading raw text file dataset...")
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file: with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
raw_text = file.read() raw_text = file.read()
tokens = shared.tokenizer.encode(raw_text) tokens = shared.tokenizer.encode(raw_text)
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
tokens = list(split_chunks(tokens, cutoff_len - overlap_len)) tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
for i in range(1, len(tokens)): for i in range(1, len(tokens)):
tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i] tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
text_chunks = [shared.tokenizer.decode(x) for x in tokens] text_chunks = [shared.tokenizer.decode(x) for x in tokens]
del tokens del tokens
if newline_favor_len > 0: if newline_favor_len > 0:
text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks] text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]
@ -212,9 +218,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
def generate_prompt(data_point: dict[str, str]): def generate_prompt(data_point: dict[str, str]):
for options, data in format_data.items(): for options, data in format_data.items():
if set(options.split(',')) == set(x[0] for x in data_point.items() if (x[1] != None and len(x[1].strip()) > 0)): if set(options.split(',')) == set(x[0] for x in data_point.items() if (x[1] is not None and len(x[1].strip()) > 0)):
for key, val in data_point.items(): for key, val in data_point.items():
if val != None: if val is not None:
data = data.replace(f'%{key}%', val) data = data.replace(f'%{key}%', val)
return data return data
raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"') raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
@ -357,7 +363,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
timer_info = f"`{its:.2f}` it/s" timer_info = f"`{its:.2f}` it/s"
else: else:
timer_info = f"`{1.0/its:.2f}` s/it" timer_info = f"`{1.0/its:.2f}` s/it"
total_time_estimate = (1.0 / its) * (tracked.max_steps) total_time_estimate = (1.0 / its) * (tracked.max_steps)
yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining" yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
print("Training complete, saving...") print("Training complete, saving...")
@ -379,22 +387,28 @@ def split_chunks(arr, step):
def cut_chunk_for_newline(chunk: str, max_length: int): def cut_chunk_for_newline(chunk: str, max_length: int):
if '\n' not in chunk: if '\n' not in chunk:
return chunk return chunk
first_newline = chunk.index('\n') first_newline = chunk.index('\n')
if first_newline < max_length: if first_newline < max_length:
chunk = chunk[first_newline + 1:] chunk = chunk[first_newline + 1:]
if '\n' not in chunk: if '\n' not in chunk:
return chunk return chunk
last_newline = chunk.rindex('\n') last_newline = chunk.rindex('\n')
if len(chunk) - last_newline < max_length: if len(chunk) - last_newline < max_length:
chunk = chunk[:last_newline] chunk = chunk[:last_newline]
return chunk return chunk
def format_time(seconds: float): def format_time(seconds: float):
if seconds < 120: if seconds < 120:
return f"`{seconds:.0f}` seconds" return f"`{seconds:.0f}` seconds"
minutes = seconds / 60 minutes = seconds / 60
if minutes < 120: if minutes < 120:
return f"`{minutes:.0f}` minutes" return f"`{minutes:.0f}` minutes"
hours = minutes / 60 hours = minutes / 60
return f"`{hours:.0f}` hours" return f"`{hours:.0f}` hours"