mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-23 00:18:20 +01:00
Make training.py more readable
This commit is contained in:
parent
a3eec62b50
commit
5c513a5f5c
@ -10,23 +10,27 @@ import gradio as gr
|
|||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from datasets import Dataset, load_dataset
|
from datasets import Dataset, load_dataset
|
||||||
from peft import (LoraConfig, get_peft_model, get_peft_model_state_dict,
|
from peft import (LoraConfig, PeftModel, get_peft_model,
|
||||||
PeftModel, prepare_model_for_int8_training)
|
get_peft_model_state_dict, prepare_model_for_int8_training)
|
||||||
|
|
||||||
try: # This mapping is from a very recent commit, not yet released.
|
|
||||||
from peft.utils.other import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as model_to_lora_modules
|
|
||||||
except: # So good backup for the 3 safe model types if not yet available.
|
|
||||||
standard_modules = ["q_proj", "v_proj"]
|
|
||||||
model_to_lora_modules = { "llama": standard_modules, "opt": standard_modules, "gptj": standard_modules }
|
|
||||||
|
|
||||||
from modules import shared, ui
|
from modules import shared, ui
|
||||||
|
|
||||||
|
# This mapping is from a very recent commit, not yet released.
|
||||||
|
# If not available, default to a backup map for the 3 safe model types.
|
||||||
|
try:
|
||||||
|
from peft.utils.other import \
|
||||||
|
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \
|
||||||
|
model_to_lora_modules
|
||||||
|
except:
|
||||||
|
standard_modules = ["q_proj", "v_proj"]
|
||||||
|
model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules}
|
||||||
|
|
||||||
WANT_INTERRUPT = False
|
WANT_INTERRUPT = False
|
||||||
|
|
||||||
|
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "do_shuffle", "higher_rank_limit"]
|
||||||
|
|
||||||
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lora_rank", "lora_alpha",
|
# Mapping of Python class names to peft IDs
|
||||||
"lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "do_shuffle", "higher_rank_limit"]
|
MODEL_CLASSES = {
|
||||||
MODEL_CLASSES = { # Mapping of Python class names to peft IDs
|
|
||||||
"LlamaForCausalLM": "llama",
|
"LlamaForCausalLM": "llama",
|
||||||
"OPTForCausalLM": "opt",
|
"OPTForCausalLM": "opt",
|
||||||
"GPTJForCausalLM": "gptj"
|
"GPTJForCausalLM": "gptj"
|
||||||
@ -79,12 +83,14 @@ def create_train_interface():
|
|||||||
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'json')}, 'refresh-button')
|
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'json')}, 'refresh-button')
|
||||||
format = gr.Dropdown(choices=get_datasets('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
|
format = gr.Dropdown(choices=get_datasets('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
|
||||||
ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_datasets('training/formats', 'json')}, 'refresh-button')
|
ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_datasets('training/formats', 'json')}, 'refresh-button')
|
||||||
|
|
||||||
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
|
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
|
||||||
|
|
||||||
with gr.Tab(label="Raw Text File"):
|
with gr.Tab(label="Raw Text File"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
raw_text_file = gr.Dropdown(choices=get_datasets('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
|
raw_text_file = gr.Dropdown(choices=get_datasets('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
|
||||||
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'txt')}, 'refresh-button')
|
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'txt')}, 'refresh-button')
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
|
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
|
||||||
newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
|
newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
|
||||||
@ -94,14 +100,14 @@ def create_train_interface():
|
|||||||
stop_button = gr.Button("Interrupt")
|
stop_button = gr.Button("Interrupt")
|
||||||
|
|
||||||
output = gr.Markdown(value="Ready")
|
output = gr.Markdown(value="Ready")
|
||||||
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout,
|
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, do_shuffle, higher_rank_limit]
|
||||||
cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, do_shuffle, higher_rank_limit]
|
|
||||||
start_button.click(do_train, all_params, [output])
|
start_button.click(do_train, all_params, [output])
|
||||||
stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
|
stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
|
||||||
|
|
||||||
def do_copy_params(lora_name: str):
|
def do_copy_params(lora_name: str):
|
||||||
with open(f"{shared.args.lora_dir}/{clean_path(None, lora_name)}/training_parameters.json", 'r', encoding='utf-8') as formatFile:
|
with open(f"{shared.args.lora_dir}/{clean_path(None, lora_name)}/training_parameters.json", 'r', encoding='utf-8') as formatFile:
|
||||||
params: dict[str, str] = json.load(formatFile)
|
params: dict[str, str] = json.load(formatFile)
|
||||||
|
|
||||||
return [params[x] for x in PARAMETERS]
|
return [params[x] for x in PARAMETERS]
|
||||||
|
|
||||||
copy_from.change(do_copy_params, [copy_from], all_params)
|
copy_from.change(do_copy_params, [copy_from], all_params)
|
||||||
@ -118,7 +124,6 @@ def do_interrupt():
|
|||||||
WANT_INTERRUPT = True
|
WANT_INTERRUPT = True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def clean_path(base_path: str, path: str):
|
def clean_path(base_path: str, path: str):
|
||||||
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
|
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
|
||||||
# TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
|
# TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
|
||||||
@ -126,6 +131,7 @@ def clean_path(base_path: str, path: str):
|
|||||||
path = path.replace('\\', '/').replace('..', '_')
|
path = path.replace('\\', '/').replace('..', '_')
|
||||||
if base_path is None:
|
if base_path is None:
|
||||||
return path
|
return path
|
||||||
|
|
||||||
return f'{Path(base_path).absolute()}/{path}'
|
return f'{Path(base_path).absolute()}/{path}'
|
||||||
|
|
||||||
|
|
||||||
@ -182,15 +188,15 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
print("Loading raw text file dataset...")
|
print("Loading raw text file dataset...")
|
||||||
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
|
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
|
||||||
raw_text = file.read()
|
raw_text = file.read()
|
||||||
|
|
||||||
tokens = shared.tokenizer.encode(raw_text)
|
tokens = shared.tokenizer.encode(raw_text)
|
||||||
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
|
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
|
||||||
|
|
||||||
tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
|
tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
|
||||||
for i in range(1, len(tokens)):
|
for i in range(1, len(tokens)):
|
||||||
tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
|
tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
|
||||||
|
|
||||||
text_chunks = [shared.tokenizer.decode(x) for x in tokens]
|
text_chunks = [shared.tokenizer.decode(x) for x in tokens]
|
||||||
del tokens
|
del tokens
|
||||||
|
|
||||||
if newline_favor_len > 0:
|
if newline_favor_len > 0:
|
||||||
text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]
|
text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]
|
||||||
|
|
||||||
@ -212,9 +218,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
|
|
||||||
def generate_prompt(data_point: dict[str, str]):
|
def generate_prompt(data_point: dict[str, str]):
|
||||||
for options, data in format_data.items():
|
for options, data in format_data.items():
|
||||||
if set(options.split(',')) == set(x[0] for x in data_point.items() if (x[1] != None and len(x[1].strip()) > 0)):
|
if set(options.split(',')) == set(x[0] for x in data_point.items() if (x[1] is not None and len(x[1].strip()) > 0)):
|
||||||
for key, val in data_point.items():
|
for key, val in data_point.items():
|
||||||
if val != None:
|
if val is not None:
|
||||||
data = data.replace(f'%{key}%', val)
|
data = data.replace(f'%{key}%', val)
|
||||||
return data
|
return data
|
||||||
raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
|
raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
|
||||||
@ -357,7 +363,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
timer_info = f"`{its:.2f}` it/s"
|
timer_info = f"`{its:.2f}` it/s"
|
||||||
else:
|
else:
|
||||||
timer_info = f"`{1.0/its:.2f}` s/it"
|
timer_info = f"`{1.0/its:.2f}` s/it"
|
||||||
|
|
||||||
total_time_estimate = (1.0 / its) * (tracked.max_steps)
|
total_time_estimate = (1.0 / its) * (tracked.max_steps)
|
||||||
|
|
||||||
yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
|
yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
|
||||||
|
|
||||||
print("Training complete, saving...")
|
print("Training complete, saving...")
|
||||||
@ -379,22 +387,28 @@ def split_chunks(arr, step):
|
|||||||
def cut_chunk_for_newline(chunk: str, max_length: int):
|
def cut_chunk_for_newline(chunk: str, max_length: int):
|
||||||
if '\n' not in chunk:
|
if '\n' not in chunk:
|
||||||
return chunk
|
return chunk
|
||||||
|
|
||||||
first_newline = chunk.index('\n')
|
first_newline = chunk.index('\n')
|
||||||
if first_newline < max_length:
|
if first_newline < max_length:
|
||||||
chunk = chunk[first_newline + 1:]
|
chunk = chunk[first_newline + 1:]
|
||||||
|
|
||||||
if '\n' not in chunk:
|
if '\n' not in chunk:
|
||||||
return chunk
|
return chunk
|
||||||
|
|
||||||
last_newline = chunk.rindex('\n')
|
last_newline = chunk.rindex('\n')
|
||||||
if len(chunk) - last_newline < max_length:
|
if len(chunk) - last_newline < max_length:
|
||||||
chunk = chunk[:last_newline]
|
chunk = chunk[:last_newline]
|
||||||
|
|
||||||
return chunk
|
return chunk
|
||||||
|
|
||||||
|
|
||||||
def format_time(seconds: float):
|
def format_time(seconds: float):
|
||||||
if seconds < 120:
|
if seconds < 120:
|
||||||
return f"`{seconds:.0f}` seconds"
|
return f"`{seconds:.0f}` seconds"
|
||||||
|
|
||||||
minutes = seconds / 60
|
minutes = seconds / 60
|
||||||
if minutes < 120:
|
if minutes < 120:
|
||||||
return f"`{minutes:.0f}` minutes"
|
return f"`{minutes:.0f}` minutes"
|
||||||
|
|
||||||
hours = minutes / 60
|
hours = minutes / 60
|
||||||
return f"`{hours:.0f}` hours"
|
return f"`{hours:.0f}` hours"
|
||||||
|
Loading…
Reference in New Issue
Block a user