mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-01 20:04:04 +01:00
pycodestyle cleanup
This commit is contained in:
parent
334efe987c
commit
08e6dfde35
@ -421,7 +421,6 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
|
|
||||||
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
|
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
|
||||||
|
|
||||||
|
|
||||||
train_data = Dataset.from_list(out_tokens)
|
train_data = Dataset.from_list(out_tokens)
|
||||||
del out_tokens
|
del out_tokens
|
||||||
|
|
||||||
@ -714,15 +713,15 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
|
|
||||||
def split_chunks(arr, size: int, step: int, max_newline_length: int, newline_token: int):
|
def split_chunks(arr, size: int, step: int, max_newline_length: int, newline_token: int):
|
||||||
num_tokens = len(arr)
|
num_tokens = len(arr)
|
||||||
split_end = num_tokens - size + step # Don't split in the last overlap
|
split_end = num_tokens - size + step # Don't split in the last overlap
|
||||||
if split_end < 0:
|
if split_end < 0:
|
||||||
split_end = num_tokens
|
split_end = num_tokens
|
||||||
split_starts = list(range(0, split_end, step))
|
split_starts = list(range(0, split_end, step))
|
||||||
for index in range(1, len(split_starts)): # First split always starts at 0
|
for index in range(1, len(split_starts)): # First split always starts at 0
|
||||||
if split_starts[index] + size > num_tokens:
|
if split_starts[index] + size > num_tokens:
|
||||||
split_starts[index] = num_tokens - size + 1
|
split_starts[index] = num_tokens - size + 1
|
||||||
|
|
||||||
if max_newline_length > 0 and newline_token in arr[split_starts[index] : split_starts[index] + max_newline_length]:
|
if max_newline_length > 0 and newline_token in arr[split_starts[index]:split_starts[index] + max_newline_length]:
|
||||||
first_newline = arr[split_starts[index]: split_starts[index] + max_newline_length].index(newline_token)
|
first_newline = arr[split_starts[index]: split_starts[index] + max_newline_length].index(newline_token)
|
||||||
split_starts[index] += first_newline
|
split_starts[index] += first_newline
|
||||||
|
|
||||||
@ -737,6 +736,7 @@ def split_chunks(arr, size: int, step: int, max_newline_length: int, newline_tok
|
|||||||
"attention_mask": input_ids.ne(shared.tokenizer.pad_token_id),
|
"attention_mask": input_ids.ne(shared.tokenizer.pad_token_id),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def format_time(seconds: float):
|
def format_time(seconds: float):
|
||||||
if seconds < 120:
|
if seconds < 120:
|
||||||
return f"`{seconds:.0f}` seconds"
|
return f"`{seconds:.0f}` seconds"
|
||||||
|
Loading…
Reference in New Issue
Block a user