lora training fixes: (#970)

Fix wrong input format being picked
Fix crash when an entry in the dataset has an attribute of value None
This commit is contained in:
Lukas 2023-04-12 16:38:01 +02:00 committed by GitHub
parent 4f7e88c043
commit 5ad92c940e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -185,10 +185,11 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
def generate_prompt(data_point: dict[str, str]): def generate_prompt(data_point: dict[str, str]):
for options, data in format_data.items(): for options, data in format_data.items():
if set(options.split(',')) == set(x[0] for x in data_point.items() if len(x[1].strip()) > 0): if set(options.split(',')) == set(x[0] for x in data_point.items() if (x[1] != None and len(x[1].strip()) > 0)):
for key, val in data_point.items(): for key, val in data_point.items():
data = data.replace(f'%{key}%', val) if val != None:
return data data = data.replace(f'%{key}%', val)
return data
raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"') raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
def generate_and_tokenize_prompt(data_point): def generate_and_tokenize_prompt(data_point):