diff --git a/modules/training.py b/modules/training.py
index c98fded2..db4c59c4 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -335,10 +335,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
             result = result[1:]
 
         return result
 
-    def tokenize(prompt, append_eos_token=False):
+    def tokenize(prompt, append_eos_token=False, prepend_bos_token=True):
         if train_only_after == '' or train_only_after not in prompt:
-            input_ids = encode(prompt, True)
+            input_ids = encode(prompt, prepend_bos_token)
 
             if append_eos_token and input_ids[-1] != shared.tokenizer.eos_token_id and len(input_ids) < cutoff_len:
                 input_ids.append(shared.tokenizer.eos_token_id)
@@ -348,7 +348,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
         else:
             ind = prompt.index(train_only_after) + len(train_only_after)
-            before_tokens = encode(prompt[:ind], True)
+            before_tokens = encode(prompt[:ind], prepend_bos_token)
             after_tokens = encode(prompt[ind:], False)
 
             if append_eos_token and after_tokens[-1] != shared.tokenizer.eos_token_id:
                 after_tokens.append(shared.tokenizer.eos_token_id)
@@ -393,8 +393,12 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
                 raw_text = file.read().replace('\r', '')
 
         cut_string = hard_cut_string.replace('\\n', '\n')
+        newline_token = shared.tokenizer.encode('\n')[0]
         eos_added = 0
        out_tokens = []
 
+        if add_eos_token and shared.tokenizer.eos_token_id == shared.tokenizer.bos_token_id:
+            logger.warning("EOS and BOS tokens are identical; adding EOS tokens may not behave as expected. Check the model config.")
+
         for text_part in raw_text.split(cut_string):
             if len(text_part.strip()) <= min_chars:
@@ -410,19 +414,17 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
                 yield f"Error: overlap_len ({overlap_len}) cannot be greater than or equal to cutoff_len ({cutoff_len})"
                 return
 
-            out_tokens.extend(split_chunks(tokens, cutoff_len, step))
+            out_tokens.extend(split_chunks(tokens, cutoff_len, step, newline_favor_len, newline_token))
 
         if eos_added > 0:
-            print(f"EOS added to {eos_added} text blocks")
+            logger.info(f"EOS token added to {eos_added} text blocks")
 
         del raw_text  # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
-        text_chunks = [shared.tokenizer.decode(x) for x in out_tokens]
-        del out_tokens
-        if newline_favor_len > 0:
-            text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]
-        train_data = Dataset.from_list([tokenize(x) for x in text_chunks])
-        del text_chunks
+
+        train_data = Dataset.from_list(out_tokens)
+        del out_tokens
+        eval_data = None
     else:
         if dataset in ['None', '']:
@@ -710,28 +712,30 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
             yield f"Done! LoRA saved to `{lora_file_path}`"
 
 
-def split_chunks(arr, size, step):
-    for i in range(0, len(arr), step):
-        yield arr[i:i + size]
+def split_chunks(arr, size: int, step: int, max_newline_length: int, newline_token: int):
+    num_tokens = len(arr)
+    split_end = num_tokens - size + step  # Don't split in the last overlap
+    if split_end < 0:
+        split_end = num_tokens
+    split_starts = list(range(0, split_end, step))
+    for index in range(1, len(split_starts)):  # First split always starts at 0
+        if split_starts[index] + size > num_tokens:
+            split_starts[index] = num_tokens - size + 1
+        if max_newline_length > 0 and newline_token in arr[split_starts[index]:split_starts[index] + max_newline_length]:
+            first_newline = arr[split_starts[index]:split_starts[index] + max_newline_length].index(newline_token)
+            split_starts[index] += first_newline
 
-
-def cut_chunk_for_newline(chunk: str, max_length: int):
-    if '\n' not in chunk:
-        return chunk
-
-    first_newline = chunk.index('\n')
-    if first_newline < max_length:
-        chunk = chunk[first_newline + 1:]
-
-    if '\n' not in chunk:
-        return chunk
-
-    last_newline = chunk.rindex('\n')
-    if len(chunk) - last_newline < max_length:
-        chunk = chunk[:last_newline]
-
-    return chunk
-
+    labels = [1] * size
+    for i in split_starts:
+        input_ids = arr[i:i + size]
+        input_ids = [shared.tokenizer.pad_token_id] * (size - len(input_ids)) + input_ids
+        input_ids = torch.tensor(input_ids)
+        yield {
+            "input_ids": input_ids,
+            "labels": labels,
+            "attention_mask": input_ids.ne(shared.tokenizer.pad_token_id),
+        }
+
 
 def format_time(seconds: float):
     if seconds < 120:
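
For context (not part of the patch): the new `split_chunks` yields ready-made training samples, so `Dataset.from_list(out_tokens)` can consume them directly and the old decode -> `cut_chunk_for_newline` -> re-tokenize round trip goes away. The sketch below illustrates only the chunk, left-pad, and attention-mask idea; `chunk_and_pad` and `PAD_ID` are illustrative stand-ins (not names from the patch), and the newline-alignment and last-overlap handling of the real `split_chunks` are omitted.

```python
# Simplified sketch of the chunk-and-pad scheme, assuming a stand-in pad id
# instead of shared.tokenizer.pad_token_id.
import torch

PAD_ID = 0  # assumption: placeholder for the tokenizer's pad_token_id


def chunk_and_pad(tokens, size, step, pad_id=PAD_ID):
    """Yield fixed-size training samples from a long token list.

    A chunk starts every `step` tokens; short chunks are left-padded to `size`,
    and the attention mask zeroes out the padded positions.
    """
    for start in range(0, len(tokens), step):
        chunk = tokens[start:start + size]
        input_ids = [pad_id] * (size - len(chunk)) + chunk
        input_ids = torch.tensor(input_ids)
        yield {
            "input_ids": input_ids,
            "labels": [1] * size,
            "attention_mask": input_ids.ne(pad_id),
        }


samples = list(chunk_and_pad(list(range(1, 11)), size=4, step=3))
print(samples[0]["input_ids"])        # tensor([1, 2, 3, 4])
print(samples[-1]["input_ids"])       # last chunk is left-padded with PAD_ID
print(samples[-1]["attention_mask"])  # padded positions are False
```

The overlap between consecutive chunks comes from `step` being smaller than `size` (in the patch, `step = cutoff_len - overlap_len`), which is why the real function also clamps the final start position instead of emitting a mostly-padded tail chunk.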