Fixes to splitting of raw strings

MB7979 2023-08-06 11:57:33 +01:00
parent 6afc1a193b
commit 334efe987c


@@ -335,10 +335,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
             result = result[1:]
         return result
-    def tokenize(prompt, append_eos_token=False):
+    def tokenize(prompt, append_eos_token=False, prepend_bos_token=True):
         if train_only_after == '' or train_only_after not in prompt:
-            input_ids = encode(prompt, True)
+            input_ids = encode(prompt, prepend_bos_token)
             if append_eos_token and input_ids[-1] != shared.tokenizer.eos_token_id and len(input_ids) < cutoff_len:
                 input_ids.append(shared.tokenizer.eos_token_id)
@@ -348,7 +348,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
         else:
             ind = prompt.index(train_only_after) + len(train_only_after)
-            before_tokens = encode(prompt[:ind], True)
+            before_tokens = encode(prompt[:ind], prepend_bos_token)
             after_tokens = encode(prompt[ind:], False)
             if append_eos_token and after_tokens[-1] != shared.tokenizer.eos_token_id:
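Note on the two hunks above: the `True` that used to be hard-coded into `encode` is replaced by a `prepend_bos_token` flag on `tokenize`, so callers can opt out of the leading BOS token. A minimal, self-contained sketch of that pattern, assuming a Hugging Face tokenizer (the helper name and checks here are illustrative, not the repository's own `encode`):

    # Illustrative only: a BOS-toggling encode helper in the spirit of the change above.
    def encode_with_bos_toggle(tokenizer, text, add_bos_token=True):
        ids = tokenizer.encode(text)
        # Drop the leading BOS when the caller does not want one, e.g. for text that
        # continues an earlier chunk rather than starting a fresh sequence.
        if not add_bos_token and ids and ids[0] == tokenizer.bos_token_id:
            ids = ids[1:]
        return ids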
@@ -393,8 +393,12 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
                 raw_text = file.read().replace('\r', '')
         cut_string = hard_cut_string.replace('\\n', '\n')
+        newline_token = shared.tokenizer.encode('\n')[0]
         eos_added = 0
         out_tokens = []
+        if add_eos_token and shared.tokenizer.eos_token_id == shared.tokenizer.bos_token_id:
+            logger.warn("EOS and BOS tokens are identical when adding EOS tokens. Check model config.")
         for text_part in raw_text.split(cut_string):
             if len(text_part.strip()) <= min_chars:
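The new warning above targets model configs that ship with identical BOS and EOS ids, where appending an EOS to every text block would presumably also act as a BOS. A quick way to inspect these ids for a given model, using the standard transformers API (the model id below is only an example):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # example model id
    print(tok.bos_token_id, tok.eos_token_id, tok.pad_token_id)
    # If bos_token_id == eos_token_id, the trainer now logs the warning before adding EOS tokens.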
@@ -410,19 +414,17 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
                 yield f"Error: overlap_len ({overlap_len}) cannot be greater than or equal to cutoff_len ({cutoff_len})"
                 return
-            out_tokens.extend(split_chunks(tokens, cutoff_len, step))
+            out_tokens.extend(split_chunks(tokens, cutoff_len, step, newline_favor_len, newline_token))
         if eos_added > 0:
-            print(f"EOS added to {eos_added} text blocks")
+            logger.info(f"EOS token added to {eos_added} text blocks")
         del raw_text  # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
-        text_chunks = [shared.tokenizer.decode(x) for x in out_tokens]
-        del out_tokens
-        if newline_favor_len > 0:
-            text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]
-        train_data = Dataset.from_list([tokenize(x) for x in text_chunks])
-        del text_chunks
+        train_data = Dataset.from_list(out_tokens)
+        del out_tokens
         eval_data = None
     else:
         if dataset in ['None', '']:
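With this hunk the raw-text path no longer decodes token chunks back into text, trims them at newlines, and re-tokenizes them; `out_tokens` now holds the ready-made feature dicts yielded by the new `split_chunks` (see the last hunk below), and `Dataset.from_list` consumes them directly. A toy illustration of that final step, with made-up values in the same key layout:

    from datasets import Dataset

    # Rows shaped like the dicts the new split_chunks yields (toy values, plain lists for clarity).
    rows = [
        {"input_ids": [1, 42, 77, 2], "labels": [1, 1, 1, 1], "attention_mask": [1, 1, 1, 1]},
        {"input_ids": [1, 13, 99, 2], "labels": [1, 1, 1, 1], "attention_mask": [1, 1, 1, 1]},
    ]
    train_data = Dataset.from_list(rows)
    print(train_data.column_names)  # ['input_ids', 'labels', 'attention_mask']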
@@ -710,28 +712,30 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
         yield f"Done! LoRA saved to `{lora_file_path}`"
-def split_chunks(arr, size, step):
-    for i in range(0, len(arr), step):
-        yield arr[i:i + size]
+def split_chunks(arr, size: int, step: int, max_newline_length: int, newline_token: int):
+    num_tokens = len(arr)
+    split_end = num_tokens - size + step  # Don't split in the last overlap
+    if split_end < 0:
+        split_end = num_tokens
+    split_starts = list(range(0, split_end, step))
+    for index in range(1, len(split_starts)):  # First split always starts at 0
+        if split_starts[index] + size > num_tokens:
+            split_starts[index] = num_tokens - size + 1
+        if max_newline_length > 0 and newline_token in arr[split_starts[index] : split_starts[index] + max_newline_length]:
+            first_newline = arr[split_starts[index]: split_starts[index] + max_newline_length].index(newline_token)
+            split_starts[index] += first_newline
-def cut_chunk_for_newline(chunk: str, max_length: int):
-    if '\n' not in chunk:
-        return chunk
-    first_newline = chunk.index('\n')
-    if first_newline < max_length:
-        chunk = chunk[first_newline + 1:]
-    if '\n' not in chunk:
-        return chunk
-    last_newline = chunk.rindex('\n')
-    if len(chunk) - last_newline < max_length:
-        chunk = chunk[:last_newline]
-    return chunk
+    labels = [1] * size
+    for i in split_starts:
+        input_ids = arr[i:i + size]
+        input_ids = [shared.tokenizer.pad_token_id] * (size - len(input_ids)) + input_ids
+        input_ids = torch.tensor(input_ids)
+        yield {
+            "input_ids": input_ids,
+            "labels": labels,
+            "attention_mask": input_ids.ne(shared.tokenizer.pad_token_id),
+        }
def format_time(seconds: float):
if seconds < 120:
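For reference, the start-index logic in the new `split_chunks` can be exercised in isolation. The sketch below strips out the tokenizer and padding parts and uses made-up names and toy values: it places overlapping windows of `size` tokens every `step` tokens and, when a newline token falls within the first `max_newline_length` positions of a window, shifts that window forward so it begins at the newline. Compared with the removed `cut_chunk_for_newline`, which trimmed decoded text chunks at newlines, this keeps the newline-favouring behaviour at the token level, so chunks never have to be decoded and re-tokenized.

    # Standalone sketch of the window-start selection in the new split_chunks (illustrative only).
    def window_starts(arr, size, step, max_newline_length, newline_token):
        num_tokens = len(arr)
        split_end = num_tokens - size + step      # don't start a new window inside the final overlap
        if split_end < 0:
            split_end = num_tokens
        starts = list(range(0, split_end, step))
        for i in range(1, len(starts)):           # the first window always starts at 0
            if starts[i] + size > num_tokens:
                starts[i] = num_tokens - size + 1  # clamp the last window so it ends at the final token
            head = arr[starts[i]:starts[i] + max_newline_length]
            if max_newline_length > 0 and newline_token in head:
                starts[i] += head.index(newline_token)  # shift the window onto the nearby newline
        return starts

    toks = [7, 7, 7, 13, 7, 7, 7, 7, 13, 7, 7, 7]  # 13 stands in for the newline token id
    print(window_starts(toks, size=5, step=3, max_newline_length=2, newline_token=13))
    # -> [0, 3, 6, 8]: the last window is clamped from 9 to 8 and lands on the newline at index 8.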