mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
Merge branch 'raw_string_processing' of https://github.com/dandm1/text-generation-webui into raw_string_processing
# Conflicts: # modules/training.py
This commit is contained in:
commit
7bd293e79e
@ -327,12 +327,14 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
|
|
||||||
def encode(text, add_bos_token):
|
def encode(text, add_bos_token):
|
||||||
result = shared.tokenizer.encode(text, truncation=True, max_length=cutoff_len)
|
result = shared.tokenizer.encode(text, truncation=True, max_length=cutoff_len)
|
||||||
|
|
||||||
# Check if the first two tokens are BOS
|
# Check if the first two tokens are BOS
|
||||||
if len(result) >= 2 and result[:2] == [shared.tokenizer.bos_token_id, shared.tokenizer.bos_token_id]:
|
if len(result) >= 2 and result[:2] == [shared.tokenizer.bos_token_id, shared.tokenizer.bos_token_id]:
|
||||||
result = result[1:]
|
result = result[1:]
|
||||||
|
|
||||||
if not add_bos_token and result[0] == shared.tokenizer.bos_token_id:
|
if not add_bos_token and result[0] == shared.tokenizer.bos_token_id:
|
||||||
result = result[1:]
|
result = result[1:]
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def tokenize(prompt, append_eos_token=False, prepend_bos_token=True):
|
def tokenize(prompt, append_eos_token=False, prepend_bos_token=True):
|
||||||
@ -398,7 +400,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
eos_added = 0
|
eos_added = 0
|
||||||
out_tokens = []
|
out_tokens = []
|
||||||
if add_eos_token and shared.tokenizer.eos_token_id == shared.tokenizer.bos_token_id:
|
if add_eos_token and shared.tokenizer.eos_token_id == shared.tokenizer.bos_token_id:
|
||||||
logger.warn("EOS and BOS tokens are identical when adding EOS tokens. Check model config.")
|
logger.warning("EOS and BOS tokens are identical when adding EOS tokens. Check model config.")
|
||||||
|
|
||||||
for text_part in raw_text.split(cut_string):
|
for text_part in raw_text.split(cut_string):
|
||||||
|
|
||||||
@ -721,6 +723,7 @@ def split_chunks(arr, size: int, step: int, max_newline_length: int, newline_tok
|
|||||||
split_end = num_tokens - size + step # Don't split in the last overlap
|
split_end = num_tokens - size + step # Don't split in the last overlap
|
||||||
if split_end < 0:
|
if split_end < 0:
|
||||||
split_end = num_tokens
|
split_end = num_tokens
|
||||||
|
|
||||||
split_starts = list(range(0, split_end, step))
|
split_starts = list(range(0, split_end, step))
|
||||||
for index in range(1, len(split_starts)): # First split always starts at 0
|
for index in range(1, len(split_starts)): # First split always starts at 0
|
||||||
if split_starts[index] + size > num_tokens:
|
if split_starts[index] + size > num_tokens:
|
||||||
|
Loading…
Reference in New Issue
Block a user