Merge branch 'raw_string_processing' of https://github.com/dandm1/text-generation-webui into raw_string_processing

# Conflicts:
#	modules/training.py
MB7979 2023-08-10 22:10:32 +01:00
commit 7bd293e79e


@@ -327,12 +327,14 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
    def encode(text, add_bos_token):
        result = shared.tokenizer.encode(text, truncation=True, max_length=cutoff_len)
        # Check if the first two tokens are BOS
        if len(result) >= 2 and result[:2] == [shared.tokenizer.bos_token_id, shared.tokenizer.bos_token_id]:
            result = result[1:]
        if not add_bos_token and result[0] == shared.tokenizer.bos_token_id:
            result = result[1:]
        return result

    def tokenize(prompt, append_eos_token=False, prepend_bos_token=True):
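
As a side note on what the new check does, the sketch below is a minimal, standalone illustration of the double-BOS stripping; it is not code from this commit, and the strip_double_bos name and the fake bos_token_id of 1 stand in for shared.tokenizer and the inline logic above.

# Illustrative sketch only (not from the diff): mimics the double-BOS check
# above with a fake tokenizer whose bos_token_id is 1.
def strip_double_bos(result, bos_token_id=1, add_bos_token=True):
    # Some tokenizers can emit two leading BOS tokens when the text itself
    # already starts with one; keep only a single BOS in that case.
    if len(result) >= 2 and result[:2] == [bos_token_id, bos_token_id]:
        result = result[1:]
    # If the caller does not want a BOS at all, drop the remaining one too.
    if not add_bos_token and result and result[0] == bos_token_id:
        result = result[1:]
    return result

print(strip_double_bos([1, 1, 42, 7]))                    # [1, 42, 7]
print(strip_double_bos([1, 42, 7], add_bos_token=False))  # [42, 7]
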
@@ -398,7 +400,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
        eos_added = 0
        out_tokens = []
        if add_eos_token and shared.tokenizer.eos_token_id == shared.tokenizer.bos_token_id:
-            logger.warn("EOS and BOS tokens are identical when adding EOS tokens. Check model config.")
+            logger.warning("EOS and BOS tokens are identical when adding EOS tokens. Check model config.")
        for text_part in raw_text.split(cut_string):
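
The only change in this hunk is logger.warn to logger.warning. Assuming logger here is a standard-library logging.Logger (as the call style suggests), warn() is a deprecated alias of warning(), so the rename emits the same message without the deprecation notice; a minimal sketch:

import logging

# Standalone illustration of the API difference; the logger name is made up.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("training")

# Logger.warn() still works but is a deprecated alias of Logger.warning();
# the hunk above switches to the supported spelling.
logger.warning("EOS and BOS tokens are identical when adding EOS tokens. Check model config.")
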
@@ -721,6 +723,7 @@ def split_chunks(arr, size: int, step: int, max_newline_length: int, newline_tok
    split_end = num_tokens - size + step  # Don't split in the last overlap
    if split_end < 0:
        split_end = num_tokens
    split_starts = list(range(0, split_end, step))
    for index in range(1, len(split_starts)):  # First split always starts at 0
        if split_starts[index] + size > num_tokens:
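
The hunk cuts off at the boundary check, so as a rough standalone sketch of how the split starts are laid out: the function below reuses the arithmetic shown above with made-up numbers, and the body of the final if (clamping the last start back to num_tokens - size) is an assumption, since the diff is truncated before it.

# Illustrative sketch only: chunk-start layout as in the hunk above, with
# made-up sizes instead of a real token array.
def chunk_starts(num_tokens, size, step):
    split_end = num_tokens - size + step  # Don't split in the last overlap
    if split_end < 0:
        split_end = num_tokens
    split_starts = list(range(0, split_end, step))
    for index in range(1, len(split_starts)):  # First split always starts at 0
        if split_starts[index] + size > num_tokens:
            # Assumed clamp: the truncated diff does not show this line.
            split_starts[index] = num_tokens - size
    return split_starts

print(chunk_starts(num_tokens=10, size=4, step=3))  # [0, 3, 6]
print(chunk_starts(num_tokens=10, size=4, step=4))  # [0, 4, 6]
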