mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Fix training dataset loading #636
This commit is contained in:
parent
41b58bc47e
commit
a6d0373063
@ -119,7 +119,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
|||||||
}
|
}
|
||||||
|
|
||||||
# == Prep the dataset, format, etc ==
|
# == Prep the dataset, format, etc ==
|
||||||
if raw_text_file is not None:
|
if raw_text_file not in ['None', '']:
|
||||||
print("Loading raw text file dataset...")
|
print("Loading raw text file dataset...")
|
||||||
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r') as file:
|
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r') as file:
|
||||||
raw_text = file.read()
|
raw_text = file.read()
|
||||||
@ -136,16 +136,17 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
|||||||
del text_chunks
|
del text_chunks
|
||||||
|
|
||||||
else:
|
else:
|
||||||
with open(clean_path('training/formats', f'{format}.json'), 'r') as formatFile:
|
if dataset in ['None', '']:
|
||||||
format_data: dict[str, str] = json.load(formatFile)
|
|
||||||
|
|
||||||
if dataset is None:
|
|
||||||
yield "**Missing dataset choice input, cannot continue.**"
|
yield "**Missing dataset choice input, cannot continue.**"
|
||||||
return
|
return
|
||||||
if format is None:
|
|
||||||
|
if format in ['None', '']:
|
||||||
yield "**Missing format choice input, cannot continue.**"
|
yield "**Missing format choice input, cannot continue.**"
|
||||||
return
|
return
|
||||||
|
|
||||||
|
with open(clean_path('training/formats', f'{format}.json'), 'r') as formatFile:
|
||||||
|
format_data: dict[str, str] = json.load(formatFile)
|
||||||
|
|
||||||
def generate_prompt(data_point: dict[str, str]):
|
def generate_prompt(data_point: dict[str, str]):
|
||||||
for options, data in format_data.items():
|
for options, data in format_data.items():
|
||||||
if set(options.split(',')) == set(x[0] for x in data_point.items() if len(x[1].strip()) > 0):
|
if set(options.split(',')) == set(x[0] for x in data_point.items() if len(x[1].strip()) > 0):
|
||||||
|
Loading…
Reference in New Issue
Block a user