From 376fdb3caaffb521dbbdddae3583e62792d10e81 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 9 Aug 2023 19:24:09 -0700
Subject: [PATCH] Minor changes

---
 modules/training.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/modules/training.py b/modules/training.py
index fff7cde1..01a46209 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -327,12 +327,14 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch

     def encode(text, add_bos_token):
         result = shared.tokenizer.encode(text, truncation=True, max_length=cutoff_len)
+        # Check if the first two tokens are BOS
         if len(result) >= 2 and result[:2] == [shared.tokenizer.bos_token_id, shared.tokenizer.bos_token_id]:
             result = result[1:]

         if not add_bos_token and result[0] == shared.tokenizer.bos_token_id:
             result = result[1:]
+
         return result

     def tokenize(prompt, append_eos_token=False, prepend_bos_token=True):
@@ -393,11 +395,11 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
             raw_text = file.read().replace('\r', '')

         cut_string = hard_cut_string.replace('\\n', '\n')
-        newline_token = shared.tokenizer.encode('\n')[0]
+        newline_token = shared.tokenizer.encode('\n')[-1]
         eos_added = 0
         out_tokens = []
         if add_eos_token and shared.tokenizer.eos_token_id == shared.tokenizer.bos_token_id:
-            logger.warn("EOS and BOS tokens are identical when adding EOS tokens. Check model config.")
+            logger.warning("EOS and BOS tokens are identical when adding EOS tokens. Check model config.")

         for text_part in raw_text.split(cut_string):
@@ -716,6 +718,7 @@ def split_chunks(arr, size: int, step: int, max_newline_length: int, newline_tok
     split_end = num_tokens - size + step  # Don't split in the last overlap
     if split_end < 0:
         split_end = num_tokens
+
     split_starts = list(range(0, split_end, step))
     for index in range(1, len(split_starts)):  # First split always starts at 0
         if split_starts[index] + size > num_tokens:
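
Note on the newline_token change: SentencePiece-style tokenizers such as LLaMA's
prepend a BOS token to every encoded sequence, so shared.tokenizer.encode('\n')[0]
can return the BOS id rather than the newline token; taking the last element avoids
that. (The logger.warn -> logger.warning change replaces a deprecated alias in the
standard logging module.) Below is a minimal sketch of the failure mode, using a
hypothetical FakeTokenizer stand-in that is not part of this patch or the project's
code:

    # Hypothetical stand-in for a SentencePiece-style tokenizer that, like
    # LLaMA's, prepends a BOS token to everything it encodes.
    class FakeTokenizer:
        bos_token_id = 1

        def encode(self, text):
            # BOS first, then one dummy id per character.
            return [self.bos_token_id] + [ord(c) for c in text]

    tokenizer = FakeTokenizer()
    ids = tokenizer.encode('\n')
    print(ids)       # [1, 10] -> index 0 is BOS, not the newline token
    print(ids[0])    # 1       -> what the old code picked up
    print(ids[-1])   # 10      -> what the patched code picks up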