mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-29 19:09:32 +01:00
make 'model' variables less ambiguous
This commit is contained in:
parent
8da237223e
commit
f1ba2196b1
@ -59,15 +59,13 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
|
|||||||
return "**Missing format choice input, cannot continue.**"
|
return "**Missing format choice input, cannot continue.**"
|
||||||
gradientAccumulationSteps = batchSize // microBatchSize
|
gradientAccumulationSteps = batchSize // microBatchSize
|
||||||
actualLR = float(learningRate)
|
actualLR = float(learningRate)
|
||||||
model = shared.model
|
shared.tokenizer.pad_token = 0
|
||||||
tokenizer = shared.tokenizer
|
shared.tokenizer.padding_side = "left"
|
||||||
tokenizer.pad_token = 0
|
|
||||||
tokenizer.padding_side = "left"
|
|
||||||
# Prep the dataset, format, etc
|
# Prep the dataset, format, etc
|
||||||
with open(cleanPath('training/formats', f'{format}.json'), 'r') as formatFile:
|
with open(cleanPath('training/formats', f'{format}.json'), 'r') as formatFile:
|
||||||
formatData: dict[str, str] = json.load(formatFile)
|
formatData: dict[str, str] = json.load(formatFile)
|
||||||
def tokenize(prompt):
|
def tokenize(prompt):
|
||||||
result = tokenizer(prompt, truncation=True, max_length=cutoffLen + 1, padding="max_length")
|
result = shared.tokenizer(prompt, truncation=True, max_length=cutoffLen + 1, padding="max_length")
|
||||||
return {
|
return {
|
||||||
"input_ids": result["input_ids"][:-1],
|
"input_ids": result["input_ids"][:-1],
|
||||||
"attention_mask": result["attention_mask"][:-1],
|
"attention_mask": result["attention_mask"][:-1],
|
||||||
@ -90,8 +88,8 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
|
|||||||
evalData = load_dataset("json", data_files=cleanPath('training/datasets', f'{evalDataset}.json'))
|
evalData = load_dataset("json", data_files=cleanPath('training/datasets', f'{evalDataset}.json'))
|
||||||
evalData = evalData['train'].shuffle().map(generate_and_tokenize_prompt)
|
evalData = evalData['train'].shuffle().map(generate_and_tokenize_prompt)
|
||||||
# Start prepping the model itself
|
# Start prepping the model itself
|
||||||
if not hasattr(model, 'lm_head') or hasattr(model.lm_head, 'weight'):
|
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
||||||
model = prepare_model_for_int8_training(model)
|
prepare_model_for_int8_training(shared.model)
|
||||||
config = LoraConfig(
|
config = LoraConfig(
|
||||||
r=loraRank,
|
r=loraRank,
|
||||||
lora_alpha=loraAlpha,
|
lora_alpha=loraAlpha,
|
||||||
@ -101,9 +99,9 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
|
|||||||
bias="none",
|
bias="none",
|
||||||
task_type="CAUSAL_LM"
|
task_type="CAUSAL_LM"
|
||||||
)
|
)
|
||||||
model = get_peft_model(model, config)
|
loraModel = get_peft_model(shared.model, config)
|
||||||
trainer = transformers.Trainer(
|
trainer = transformers.Trainer(
|
||||||
model=model,
|
model=loraModel,
|
||||||
train_dataset=train_data,
|
train_dataset=train_data,
|
||||||
eval_dataset=evalData,
|
eval_dataset=evalData,
|
||||||
args=transformers.TrainingArguments(
|
args=transformers.TrainingArguments(
|
||||||
@ -125,16 +123,16 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
|
|||||||
# TODO: Enable multi-device support
|
# TODO: Enable multi-device support
|
||||||
ddp_find_unused_parameters=None,
|
ddp_find_unused_parameters=None,
|
||||||
),
|
),
|
||||||
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
|
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
|
||||||
)
|
)
|
||||||
model.config.use_cache = False
|
loraModel.config.use_cache = False
|
||||||
old_state_dict = model.state_dict
|
old_state_dict = loraModel.state_dict
|
||||||
model.state_dict = (
|
loraModel.state_dict = (
|
||||||
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
|
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
|
||||||
).__get__(model, type(model))
|
).__get__(loraModel, type(loraModel))
|
||||||
if torch.__version__ >= "2" and sys.platform != "win32":
|
if torch.__version__ >= "2" and sys.platform != "win32":
|
||||||
model = torch.compile(model)
|
loraModel = torch.compile(loraModel)
|
||||||
# Actually start and run and save at the end
|
# Actually start and run and save at the end
|
||||||
trainer.train()
|
trainer.train()
|
||||||
model.save_pretrained(loraName)
|
loraModel.save_pretrained(loraName)
|
||||||
return "Done!"
|
return "Done!"
|
||||||
|
Loading…
Reference in New Issue
Block a user