Fix training dataset loading #636

This commit is contained in:
oobabooga 2023-03-29 11:48:17 -03:00 committed by GitHub
parent 41b58bc47e
commit a6d0373063
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -119,7 +119,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
} }
# == Prep the dataset, format, etc == # == Prep the dataset, format, etc ==
if raw_text_file is not None: if raw_text_file not in ['None', '']:
print("Loading raw text file dataset...") print("Loading raw text file dataset...")
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r') as file: with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r') as file:
raw_text = file.read() raw_text = file.read()
@@ -136,16 +136,17 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
del text_chunks del text_chunks
else: else:
with open(clean_path('training/formats', f'{format}.json'), 'r') as formatFile: if dataset in ['None', '']:
format_data: dict[str, str] = json.load(formatFile)
if dataset is None:
yield "**Missing dataset choice input, cannot continue.**" yield "**Missing dataset choice input, cannot continue.**"
return return
if format is None:
if format in ['None', '']:
yield "**Missing format choice input, cannot continue.**" yield "**Missing format choice input, cannot continue.**"
return return
with open(clean_path('training/formats', f'{format}.json'), 'r') as formatFile:
format_data: dict[str, str] = json.load(formatFile)
def generate_prompt(data_point: dict[str, str]): def generate_prompt(data_point: dict[str, str]):
for options, data in format_data.items(): for options, data in format_data.items():
if set(options.split(',')) == set(x[0] for x in data_point.items() if len(x[1].strip()) > 0): if set(options.split(',')) == set(x[0] for x in data_point.items() if len(x[1].strip()) > 0):