Mirror of https://github.com/oobabooga/text-generation-webui.git
Fixes to splitting of raw strings

parent 6afc1a193b
commit 334efe987c
@@ -335,10 +335,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
             result = result[1:]
         return result
 
-    def tokenize(prompt, append_eos_token=False):
+    def tokenize(prompt, append_eos_token=False, prepend_bos_token=True):
 
         if train_only_after == '' or train_only_after not in prompt:
-            input_ids = encode(prompt, True)
+            input_ids = encode(prompt, prepend_bos_token)
 
             if append_eos_token and input_ids[-1] != shared.tokenizer.eos_token_id and len(input_ids) < cutoff_len:
                 input_ids.append(shared.tokenizer.eos_token_id)
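Note: this hunk and the next thread a new prepend_bos_token flag from tokenize() into encode(), so callers can suppress the BOS token when a chunk is a continuation rather than the start of a sequence. A minimal sketch of the failure mode this guards against, using the transformers tokenizer API directly (the model name is illustrative, not from this commit):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # any BOS-prepending tokenizer

first = tok.encode("first half of a document, ")   # BOS prepended here...
second = tok.encode("...second half")              # ...and again here

joined = first + second
# The joined sequence carries a stray BOS in the middle, something the
# model never sees at inference time:
print(tok.bos_token_id in joined[1:])  # True

# Encoding continuations without special tokens avoids the stray BOS:
second = tok.encode("...second half", add_special_tokens=False)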
@@ -348,7 +348,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
 
         else:
             ind = prompt.index(train_only_after) + len(train_only_after)
-            before_tokens = encode(prompt[:ind], True)
+            before_tokens = encode(prompt[:ind], prepend_bos_token)
             after_tokens = encode(prompt[ind:], False)
 
             if append_eos_token and after_tokens[-1] != shared.tokenizer.eos_token_id:
@@ -393,8 +393,12 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
             raw_text = file.read().replace('\r', '')
 
         cut_string = hard_cut_string.replace('\\n', '\n')
+        newline_token = shared.tokenizer.encode('\n')[0]
         eos_added = 0
         out_tokens = []
+        if add_eos_token and shared.tokenizer.eos_token_id == shared.tokenizer.bos_token_id:
+            logger.warn("EOS and BOS tokens are identical when adding EOS tokens. Check model config.")
+
         for text_part in raw_text.split(cut_string):
 
             if len(text_part.strip()) <= min_chars:
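The added warning covers models whose config reports the same id for BOS and EOS: with identical ids, an appended EOS is indistinguishable from a BOS, which defeats the BOS-stripping logic in encode(). A quick check of the two ids the warning compares (GPT-2 is a real example of the overlap):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
print(tok.bos_token_id, tok.eos_token_id)  # 50256 50256 -- identical, so the warning applies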
@@ -410,19 +414,17 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
                 yield f"Error: overlap_len ({overlap_len}) cannot be greater than or equal to cutoff_len ({cutoff_len})"
                 return
 
-            out_tokens.extend(split_chunks(tokens, cutoff_len, step))
+            out_tokens.extend(split_chunks(tokens, cutoff_len, step, newline_favor_len, newline_token))
 
         if eos_added > 0:
-            print(f"EOS added to {eos_added} text blocks")
+            logger.info(f"EOS token added to {eos_added} text blocks")
 
         del raw_text  # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
-        text_chunks = [shared.tokenizer.decode(x) for x in out_tokens]
-        del out_tokens
-        if newline_favor_len > 0:
-            text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]
-
-        train_data = Dataset.from_list([tokenize(x) for x in text_chunks])
-        del text_chunks
+
+        train_data = Dataset.from_list(out_tokens)
+        del out_tokens
 
         eval_data = None
     else:
         if dataset in ['None', '']:
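With split_chunks() now yielding ready-made feature dicts (see the last hunk below), the old decode -> cut_chunk_for_newline -> re-tokenize round trip is gone and out_tokens feeds Dataset.from_list directly. A toy sketch of the shape from_list expects (the values are made up; the real dicts hold torch tensors):

from datasets import Dataset

records = [
    {"input_ids": [5, 6, 7], "labels": [1, 1, 1], "attention_mask": [1, 1, 1]},
    {"input_ids": [8, 9, 0], "labels": [1, 1, 1], "attention_mask": [1, 1, 0]},
]
train_data = Dataset.from_list(records)
print(train_data[0]["input_ids"])  # [5, 6, 7]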
@@ -710,28 +712,30 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     yield f"Done! LoRA saved to `{lora_file_path}`"
 
 
-def split_chunks(arr, size, step):
-    for i in range(0, len(arr), step):
-        yield arr[i:i + size]
+def split_chunks(arr, size: int, step: int, max_newline_length: int, newline_token: int):
+    num_tokens = len(arr)
+    split_end = num_tokens - size + step  # Don't split in the last overlap
+    if split_end < 0:
+        split_end = num_tokens
+    split_starts = list(range(0, split_end, step))
+    for index in range(1, len(split_starts)):  # First split always starts at 0
+        if split_starts[index] + size > num_tokens:
+            split_starts[index] = num_tokens - size + 1
+
+        if max_newline_length > 0 and newline_token in arr[split_starts[index]:split_starts[index] + max_newline_length]:
+            first_newline = arr[split_starts[index]:split_starts[index] + max_newline_length].index(newline_token)
+            split_starts[index] += first_newline
 
-
-def cut_chunk_for_newline(chunk: str, max_length: int):
-    if '\n' not in chunk:
-        return chunk
-
-    first_newline = chunk.index('\n')
-    if first_newline < max_length:
-        chunk = chunk[first_newline + 1:]
-
-    if '\n' not in chunk:
-        return chunk
-
-    last_newline = chunk.rindex('\n')
-    if len(chunk) - last_newline < max_length:
-        chunk = chunk[:last_newline]
-
-    return chunk
+    labels = [1] * size
+    for i in split_starts:
+        input_ids = arr[i:i + size]
+        input_ids = [shared.tokenizer.pad_token_id] * (size - len(input_ids)) + input_ids
+        input_ids = torch.tensor(input_ids)
+        yield {
+            "input_ids": input_ids,
+            "labels": labels,
+            "attention_mask": input_ids.ne(shared.tokenizer.pad_token_id),
+        }
 
 
 def format_time(seconds: float):
     if seconds < 120:
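The rewritten split_chunks() computes all window starts up front: the last window is clamped so it cannot run past the end of the sequence, and any start that falls within max_newline_length tokens of a newline is snapped forward to that newline, replacing the string-level cut_chunk_for_newline(). A standalone sketch of just the start-selection logic, on plain ints so it runs without a tokenizer (names and values are illustrative):

def window_starts(arr, size, step, max_newline_length, newline_token):
    num_tokens = len(arr)
    split_end = num_tokens - size + step  # don't start a window inside the last overlap
    if split_end < 0:
        split_end = num_tokens
    starts = list(range(0, split_end, step))
    for i in range(1, len(starts)):  # the first window always starts at 0
        if starts[i] + size > num_tokens:
            starts[i] = num_tokens - size + 1  # clamp the final window
        head = arr[starts[i]:starts[i] + max_newline_length]
        if max_newline_length > 0 and newline_token in head:
            starts[i] += head.index(newline_token)  # snap the start to the newline
    return starts

tokens = [7, 7, 13, 7, 7, 7, 7, 13, 7, 7]  # 13 stands in for the '\n' token
print(window_starts(tokens, size=4, step=3, max_newline_length=2, newline_token=13))
# -> [0, 3, 7]: the third window snaps to the newline at index 7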