More robust and error prone training (#3058)

commit 9b55d3a9f9 (parent 30f37530d5)
Mirror of https://github.com/oobabooga/text-generation-webui.git

@@ -478,11 +478,16 @@ def load_character(character, name1, name2, instruct=False):
     if character not in ['None', '', None]:
         folder = 'characters' if not instruct else 'characters/instruction-following'
         picture = generate_pfp_cache(character)
+        filepath = None
         for extension in ["yml", "yaml", "json"]:
             filepath = Path(f'{folder}/{character}.{extension}')
             if filepath.exists():
                 break
 
+        if filepath is None:
+            logger.error(f"Could not find character file for {character} in {folder} folder. Please check your spelling.")
+            return name1, name2, picture, greeting, context, turn_template.replace("\n", r"\n")
+
         file_contents = open(filepath, 'r', encoding='utf-8').read()
         data = json.loads(file_contents) if extension == "json" else yaml.safe_load(file_contents)
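
The guard above covers characters whose file is missing for every supported extension. One caveat worth noting: because `filepath` is reassigned on every loop iteration, it is no longer `None` after an unsuccessful search (it holds the last candidate `Path`), so the `is None` branch as written only fires for an empty extension list. A hypothetical variant, not part of the commit, that assigns only on a hit would make the guard effective:

    # Hypothetical sketch: keep filepath as None unless a candidate file exists.
    filepath = None
    for extension in ["yml", "yaml", "json"]:
        candidate = Path(f'{folder}/{character}.{extension}')
        if candidate.exists():
            filepath = candidate
            break

    if filepath is None:
        logger.error(f"Could not find character file for {character} in {folder} folder. Please check your spelling.")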

@@ -339,6 +339,7 @@ def clear_torch_cache():
 def unload_model():
     shared.model = shared.tokenizer = None
     shared.lora_names = []
+    shared.model_dirty_from_training = False
     clear_torch_cache()
 

@@ -12,6 +12,7 @@ tokenizer = None
 is_seq2seq = False
 model_name = "None"
 lora_names = []
+model_dirty_from_training = False
 
 # Chat variables
 stop_everything = False
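
Together with the `unload_model()` change above, `model_dirty_from_training` acts as a module-level latch: set once the base model has been prepared for training (see the `prepare_model_for_int8_training` hunk below), checked at the start of the next run, and cleared whenever the model is unloaded. A minimal sketch of the lifecycle, using only names from this commit:

    # 1. Default in the shared-state module:
    model_dirty_from_training = False

    # 2. Set in do_train() once the base model has been frozen for int8 training:
    shared.model_dirty_from_training = True

    # 3. Checked on the next run; unload_model() resets it to False:
    if shared.model_dirty_from_training:
        unload_model()
        shared.model, shared.tokenizer = load_model(shared.model_name, None)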

@@ -17,6 +17,8 @@ from pathlib import Path
 import gradio as gr
 import torch
 import transformers
+from modules.models import load_model, unload_model
+
 from datasets import Dataset, load_dataset
 from peft import (
     LoraConfig,

@@ -60,7 +62,7 @@ train_log = {}
 train_template = {}
 
 WANT_INTERRUPT = False
-PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "report_to"]
+PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to"]
 
 
 def create_train_interface():

@@ -108,6 +110,7 @@ def create_train_interface():
             raw_text_file = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.')
             ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'txt')}, 'refresh-button')
             hard_cut_string = gr.Textbox(label='Hard Cut String', value='\\n\\n\\n', info='String that indicates a hard cut between text parts. Helps prevent unwanted overlap.')
+            min_chars = gr.Number(label='Ignore small blocks', value=0, info='Ignore Hard Cut blocks that have less or equal characters than this number')
 
             with gr.Row():
                 overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
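
Note the threshold semantics used later in this diff: the filter compares with `<=`, so the default of 0 skips only empty blocks (preserving the old behavior), while a value of 50 also drops blocks of exactly 50 characters. An equivalent standalone filter, as a sketch:

    # Equivalent to the per-block check applied in the raw-text path below:
    kept_blocks = [part for part in raw_text.split(cut_string)
                   if len(part.strip()) > min_chars]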

@@ -119,6 +122,7 @@ def create_train_interface():
             optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.')
             train_only_after = gr.Textbox(label='Train Only After', value='', info='Only consider text *after* this string in any given chunk for training. For Alpaca datasets, use "### Response:" to only train the response and ignore the input.')
             stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)')
+            add_eos_token = gr.Checkbox(label='Add EOS token', value=False, info="Adds EOS token for each dataset item. In case of raw text, the EOS will be added at the Hard Cut")
 
             with gr.Row():
                 higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')

@@ -154,7 +158,8 @@ def create_train_interface():
             refresh_table = gr.Button('Refresh the table', elem_classes="small-button")
 
     # Training events
-    all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, report_to]
+
+    all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to]
 
     copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
     start_button.click(do_train, all_params, output)

@@ -264,7 +269,7 @@ def calc_trainable_parameters(model):
     return trainable_params, all_param
 
 
-def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, report_to: str):
+def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str):
 
     if shared.args.monkey_patch:
         from monkeypatch.peft_tuners_lora_monkey_patch import (
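
The two new parameters have to be threaded through three places in matching positions: the `PARAMETERS` name list, the `all_params` component list, and the `do_train` signature, since Gradio passes the values positionally. A toy illustration (values assumed) of why the ordering matters:

    # Settings are paired with names by position, conceptually like:
    PARAMETERS = ["stop_at_loss", "add_eos_token", "min_chars", "report_to"]
    values = [1.8, True, 50, "None"]
    settings = dict(zip(PARAMETERS, values))  # a misordered list silently misbinds values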

@@ -322,14 +327,22 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
 
     def encode(text, add_bos_token):
         result = shared.tokenizer.encode(text, truncation=True, max_length=cutoff_len)
+        # Check if the first two tokens are BOS
+        if len(result) >= 2 and result[:2] == [shared.tokenizer.bos_token_id, shared.tokenizer.bos_token_id]:
+            result = result[1:]
+
         if not add_bos_token and result[0] == shared.tokenizer.bos_token_id:
             result = result[1:]
         return result
 
-    def tokenize(prompt):
+    def tokenize(prompt, append_eos_token=False):
+
         if train_only_after == '' or train_only_after not in prompt:
             input_ids = encode(prompt, True)
+
+            if append_eos_token and input_ids[-1] != shared.tokenizer.eos_token_id and len(input_ids) < cutoff_len:
+                input_ids.append(shared.tokenizer.eos_token_id)
+
             input_ids = [shared.tokenizer.pad_token_id] * (cutoff_len - len(input_ids)) + input_ids
             labels = [1] * len(input_ids)
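
Two robustness fixes live here. First, some tokenizers prepend BOS on their own, so encoding text that already starts with BOS can yield a doubled token; the new check collapses it. Second, `tokenize` can now append EOS, but only while the sequence is still under `cutoff_len`, which keeps the left-padding step that follows valid. A toy walk-through with assumed token ids:

    # Assumed ids for illustration: BOS = 1, EOS = 2
    bos_token_id, eos_token_id, cutoff_len = 1, 2, 8

    result = [1, 1, 523, 99]                  # doubled BOS from the tokenizer
    if len(result) >= 2 and result[:2] == [bos_token_id, bos_token_id]:
        result = result[1:]                   # -> [1, 523, 99]

    if result[-1] != eos_token_id and len(result) < cutoff_len:
        result.append(eos_token_id)           # -> [1, 523, 99, 2]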

@@ -338,6 +351,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
             before_tokens = encode(prompt[:ind], True)
             after_tokens = encode(prompt[ind:], False)
 
+            if append_eos_token and after_tokens[-1] != shared.tokenizer.eos_token_id:
+                after_tokens.append(shared.tokenizer.eos_token_id)
+
             full_length = len(after_tokens) + len(before_tokens)
             if full_length > cutoff_len:
                 after_tokens = after_tokens[:cutoff_len - len(before_tokens)]
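
One ordering subtlety in this branch: EOS is appended before the length check, so when `len(before_tokens) + len(after_tokens)` exceeds `cutoff_len`, the slice `after_tokens[:cutoff_len - len(before_tokens)]` can cut the freshly added EOS off again. For example, with a `cutoff_len` of 8, 5 `before_tokens`, and 4 `after_tokens` ending in EOS, only the first 3 `after_tokens` survive and the EOS is dropped; runs that rely on EOS markers may want a generous cutoff.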

@@ -377,12 +393,18 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
             raw_text = file.read().replace('\r', '')
 
         cut_string = hard_cut_string.replace('\\n', '\n')
+        eos_added = 0
         out_tokens = []
         for text_part in raw_text.split(cut_string):
-            if text_part.strip() == '':
+
+            if len(text_part.strip()) <= min_chars:
                 continue
 
             tokens = shared.tokenizer.encode(text_part)
+            if add_eos_token:
+                tokens.append(shared.tokenizer.eos_token_id)
+                eos_added += 1
+
             step = cutoff_len - overlap_len
             if step <= 0:
                 yield f"Error: overlap_len ({overlap_len}) cannot be greater than or equal to cutoff_len ({cutoff_len})"
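
The raw-text path is now: split on the hard-cut string, drop blocks at or below `min_chars`, optionally terminate each surviving block with EOS, then window each block into overlapping chunks. `split_chunks` is not shown in this diff; a presumably equivalent sketch of the whole pipeline:

    def split_chunks(arr, size, step):
        # Presumed behavior: fixed-size windows advancing by `step` tokens.
        return [arr[i:i + size] for i in range(0, len(arr), step)]

    out_tokens, eos_added = [], 0
    for text_part in raw_text.split(cut_string):
        if len(text_part.strip()) <= min_chars:
            continue

        tokens = shared.tokenizer.encode(text_part)
        if add_eos_token:
            tokens.append(shared.tokenizer.eos_token_id)
            eos_added += 1

        out_tokens.extend(split_chunks(tokens, cutoff_len, cutoff_len - overlap_len))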

@@ -390,6 +412,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
 
             out_tokens.extend(split_chunks(tokens, cutoff_len, step))
 
+        if eos_added > 0:
+            print(f"EOS added to {eos_added} text blocks")
+
         del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
         text_chunks = [shared.tokenizer.decode(x) for x in out_tokens]
         del out_tokens

@@ -429,7 +454,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
 
         def generate_and_tokenize_prompt(data_point):
             prompt = generate_prompt(data_point)
-            return tokenize(prompt)
+            return tokenize(prompt, add_eos_token)
 
         logger.info("Loading JSON datasets...")
         data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))

@@ -441,11 +466,33 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
             eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json'))
             eval_data = eval_data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
 
+    # == We MUST reload model if it went through any previous training, even failed one ==
+    if shared.model_dirty_from_training:
+        selected_model = shared.model_name
+        if selected_model:
+            print("\033[1;31;1m(Model has been modified by previous training, it needs to be reloaded...)\033[0;37;0m")
+            try:
+                yield f"Reloading {selected_model}..."
+                unload_model()
+                shared.model, shared.tokenizer = load_model(shared.model_name, None)
+                if shared.model is not None:
+                    print("Model reloaded OK, continue with training.")
+                else:
+                    return f"Failed to load {selected_model}."
+            except:
+                exc = traceback.format_exc()
+                logger.error('Failed to reload the model.')
+                print(exc)
+                return exc
+
     # == Start prepping the model itself ==
     if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
         logger.info("Getting model ready...")
         prepare_model_for_int8_training(shared.model)
 
+        # base model is now frozen and should not be reused for any other LoRA training than this one
+        shared.model_dirty_from_training = True
+
     logger.info("Prepping for training...")
     config = LoraConfig(
         r=lora_rank,
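
This reload block is the heart of the fix: `prepare_model_for_int8_training` mutates the loaded model in place (freezing base weights and casting layers), so a second run against the same in-memory model, including a retry after a failed run, would otherwise train on a corrupted base. Because `do_train` is a generator, `yield f"Reloading {selected_model}..."` streams a status message to the UI while the reload happens, and `return` aborts the run with an error string. One small caveat: the bare `except:` also swallows `KeyboardInterrupt`; `except Exception:` would be the more conventional choice here.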

@@ -575,6 +622,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
 
     lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
 
+    projections_string = ", ".join([projection.replace("_proj", "") for projection in model_to_lora_modules[model_id]])
+
+    print(f"Training '{model_id}' model using ({projections_string}) projections")
+
     if lora_all_param > 0:
         print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})")
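
`model_to_lora_modules[model_id]` holds the projection layers targeted for the detected model family; stripping the `_proj` suffix produces a compact summary for the log line. For example (module list assumed for illustration):

    modules = ["q_proj", "v_proj"]   # assumed LLaMA-style target modules
    projections_string = ", ".join(m.replace("_proj", "") for m in modules)
    print(projections_string)        # prints: q, v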

@@ -582,6 +633,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     train_log.update({"base_model_class": shared.model.__class__.__name__})
     train_log.update({"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False)})
     train_log.update({"base_loaded_in_8bit": getattr(lora_model, "is_loaded_in_8bit", False)})
+    train_log.update({"projections": projections_string})
 
     if stop_at_loss > 0:
         print(f"Monitoring loss \033[1;31;1m(Auto-Stop at: {stop_at_loss})\033[0;37;0m")