Mirror of https://github.com/oobabooga/text-generation-webui.git
Changes to handle multiple line break characters

parent 08e6dfde35
commit 94ea28c24d
@@ -393,7 +393,8 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     raw_text = file.read().replace('\r', '')

     cut_string = hard_cut_string.replace('\\n', '\n')
-    newline_token = shared.tokenizer.encode('\n')[0]
+    newline_token = set(shared.tokenizer.encode('\n')[1:])
+
     eos_added = 0
     out_tokens = []
     if add_eos_token and shared.tokenizer.eos_token_id == shared.tokenizer.bos_token_id:
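Note on the change above: with some tokenizers (SentencePiece-based ones in particular) encode('\n') returns more than one id, and the first id is a BOS or prefix piece rather than the newline itself, so taking element [0] could select a token that never marks a line break. The committed code therefore keeps every id after the first as a set. A minimal sketch of the difference, assuming the transformers library; the model name and the example ids are only illustrative:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # arbitrary example model
    ids = tok.encode('\n')             # e.g. [1, 29871, 13]: BOS, prefix piece, newline piece
    old_marker = ids[0]                # old behaviour: may be the BOS id, not a newline id
    newline_tokens = set(ids[1:])      # new behaviour: every id after the first counts as a newline marker
    print(old_marker, newline_tokens)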
@@ -711,7 +712,11 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     yield f"Done! LoRA saved to `{lora_file_path}`"


-def split_chunks(arr, size: int, step: int, max_newline_length: int, newline_token: int):
+def split_chunks(arr, size: int, step: int, max_newline_length: int, newline_tokens: set):
+    while arr and arr[1] in newline_tokens:  # Skip the first token, which will be <BOS>
+        del arr[1]
+    while arr and arr[-1] in newline_tokens:
+        del arr[-1]
     num_tokens = len(arr)
     split_end = num_tokens - size + step  # Don't split in the last overlap
     if split_end < 0:
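The new loops at the top of split_chunks trim runs of newline ids from just after the first token (assumed to be the BOS) and from the end of the sequence before it is chunked. A small sketch of that behaviour on a plain list of made-up ids, where 1 stands in for BOS and 3 for a newline piece:

    newline_tokens = {3}                       # hypothetical newline id
    arr = [1, 3, 3, 10, 11, 3, 12, 3, 3]       # 1 plays the role of <BOS>

    while arr and arr[1] in newline_tokens:    # skip arr[0], assumed to be <BOS>
        del arr[1]
    while arr and arr[-1] in newline_tokens:
        del arr[-1]

    print(arr)                                 # [1, 10, 11, 3, 12]; interior newlines are kept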
@@ -721,8 +726,8 @@ def split_chunks(arr, size: int, step: int, max_newline_length: int, newline_tok
     if split_starts[index] + size > num_tokens:
         split_starts[index] = num_tokens - size + 1

-    if max_newline_length > 0 and newline_token in arr[split_starts[index]:split_starts[index] + max_newline_length]:
-        first_newline = arr[split_starts[index]: split_starts[index] + max_newline_length].index(newline_token)
+    if max_newline_length > 0 and ( newline_tokens.intersection(arr[split_starts[index]:split_starts[index] + max_newline_length])):
+        first_newline = end_first_block(arr[split_starts[index]:],newline_tokens)
         split_starts[index] += first_newline

     labels = [1] * size
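The condition above swaps a single-id membership test for a set intersection, which is truthy as soon as any of the newline ids appears in the inspected window. A rough equivalence check with invented ids (13 and 30004 are placeholders, not ids from any particular tokenizer):

    newline_tokens = {13, 30004}
    window = [523, 13, 978, 30004, 6]

    old_style = 13 in window                                # old check: one specific id
    new_style = bool(newline_tokens.intersection(window))   # new check: any of the newline ids
    assert old_style == new_style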
@@ -737,6 +742,14 @@ def split_chunks(arr, size: int, step: int, max_newline_length: int, newline_tok
     }


+def end_first_block(arr:list, tokens: set):
+    offset = 0
+    while arr[offset] not in tokens:
+        offset += 1
+    while arr[offset] in tokens:
+        offset += 1
+    return offset
+
 def format_time(seconds: float):
     if seconds < 120:
         return f"`{seconds:.0f}` seconds"
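For reference, a standalone copy of the added helper run on an invented input (3 and 4 are placeholders for newline ids). Where the old .index() call returned the position of the first newline id in the window, the helper returns the offset just past the first run of newline ids, so the adjusted split start lands on the first token of the next line; the intersection check in split_chunks ensures there is a newline for the first loop to find:

    def end_first_block(arr: list, tokens: set):
        offset = 0
        while arr[offset] not in tokens:    # advance to the first newline id
            offset += 1
        while arr[offset] in tokens:        # then walk past the whole run of newline ids
            offset += 1
        return offset

    print(end_first_block([7, 8, 3, 4, 3, 9, 10], {3, 4}))  # prints 5, the index just after the run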