From 94ea28c24dba40ae238d3d73d0dae8f0444419ea Mon Sep 17 00:00:00 2001
From: MB7979
Date: Thu, 10 Aug 2023 22:01:45 +0100
Subject: [PATCH] Changes to handle multiple line break characters

---
 modules/training.py | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/modules/training.py b/modules/training.py
index fff7cde1..50947735 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -393,7 +393,8 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
             raw_text = file.read().replace('\r', '')
 
         cut_string = hard_cut_string.replace('\\n', '\n')
-        newline_token = shared.tokenizer.encode('\n')[0]
+        newline_token = set(shared.tokenizer.encode('\n')[1:])
+        eos_added = 0
 
         out_tokens = []
         if add_eos_token and shared.tokenizer.eos_token_id == shared.tokenizer.bos_token_id:
@@ -711,7 +712,11 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     yield f"Done! LoRA saved to `{lora_file_path}`"
 
 
-def split_chunks(arr, size: int, step: int, max_newline_length: int, newline_token: int):
+def split_chunks(arr, size: int, step: int, max_newline_length: int, newline_tokens: set):
+    while arr and arr[1] in newline_tokens:  # Skip the first token, which will be the BOS token
+        del arr[1]
+    while arr and arr[-1] in newline_tokens:
+        del arr[-1]
     num_tokens = len(arr)
     split_end = num_tokens - size + step  # Don't split in the last overlap
     if split_end < 0:
@@ -721,8 +726,8 @@
         if split_starts[index] + size > num_tokens:
             split_starts[index] = num_tokens - size + 1
 
-        if max_newline_length > 0 and newline_token in arr[split_starts[index]:split_starts[index] + max_newline_length]:
-            first_newline = arr[split_starts[index]: split_starts[index] + max_newline_length].index(newline_token)
+        if max_newline_length > 0 and ( newline_tokens.intersection(arr[split_starts[index]:split_starts[index] + max_newline_length])):
+            first_newline = end_first_block(arr[split_starts[index]:],newline_tokens)
             split_starts[index] += first_newline
 
     labels = [1] * size
@@ -737,6 +742,14 @@
     }
 
 
+def end_first_block(arr:list, tokens: set):
+    offset = 0
+    while arr[offset] not in tokens:
+        offset += 1
+    while arr[offset] in tokens:
+        offset += 1
+    return offset
+
 def format_time(seconds: float):
     if seconds < 120:
         return f"`{seconds:.0f}` seconds"