mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 01:30:20 +01:00
make 'model' variables less ambiguous
This commit is contained in:
parent
8da237223e
commit
f1ba2196b1
@ -59,15 +59,13 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
|
||||
return "**Missing format choice input, cannot continue.**"
|
||||
gradientAccumulationSteps = batchSize // microBatchSize
|
||||
actualLR = float(learningRate)
|
||||
model = shared.model
|
||||
tokenizer = shared.tokenizer
|
||||
tokenizer.pad_token = 0
|
||||
tokenizer.padding_side = "left"
|
||||
shared.tokenizer.pad_token = 0
|
||||
shared.tokenizer.padding_side = "left"
|
||||
# Prep the dataset, format, etc
|
||||
with open(cleanPath('training/formats', f'{format}.json'), 'r') as formatFile:
|
||||
formatData: dict[str, str] = json.load(formatFile)
|
||||
def tokenize(prompt):
|
||||
result = tokenizer(prompt, truncation=True, max_length=cutoffLen + 1, padding="max_length")
|
||||
result = shared.tokenizer(prompt, truncation=True, max_length=cutoffLen + 1, padding="max_length")
|
||||
return {
|
||||
"input_ids": result["input_ids"][:-1],
|
||||
"attention_mask": result["attention_mask"][:-1],
|
||||
@ -90,8 +88,8 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
|
||||
evalData = load_dataset("json", data_files=cleanPath('training/datasets', f'{evalDataset}.json'))
|
||||
evalData = evalData['train'].shuffle().map(generate_and_tokenize_prompt)
|
||||
# Start prepping the model itself
|
||||
if not hasattr(model, 'lm_head') or hasattr(model.lm_head, 'weight'):
|
||||
model = prepare_model_for_int8_training(model)
|
||||
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
||||
prepare_model_for_int8_training(shared.model)
|
||||
config = LoraConfig(
|
||||
r=loraRank,
|
||||
lora_alpha=loraAlpha,
|
||||
@ -101,9 +99,9 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
model = get_peft_model(model, config)
|
||||
loraModel = get_peft_model(shared.model, config)
|
||||
trainer = transformers.Trainer(
|
||||
model=model,
|
||||
model=loraModel,
|
||||
train_dataset=train_data,
|
||||
eval_dataset=evalData,
|
||||
args=transformers.TrainingArguments(
|
||||
@ -125,16 +123,16 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
|
||||
# TODO: Enable multi-device support
|
||||
ddp_find_unused_parameters=None,
|
||||
),
|
||||
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
|
||||
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
|
||||
)
|
||||
model.config.use_cache = False
|
||||
old_state_dict = model.state_dict
|
||||
model.state_dict = (
|
||||
loraModel.config.use_cache = False
|
||||
old_state_dict = loraModel.state_dict
|
||||
loraModel.state_dict = (
|
||||
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
|
||||
).__get__(model, type(model))
|
||||
).__get__(loraModel, type(loraModel))
|
||||
if torch.__version__ >= "2" and sys.platform != "win32":
|
||||
model = torch.compile(model)
|
||||
loraModel = torch.compile(loraModel)
|
||||
# Actually start and run and save at the end
|
||||
trainer.train()
|
||||
model.save_pretrained(loraName)
|
||||
loraModel.save_pretrained(loraName)
|
||||
return "Done!"
|
||||
|
Loading…
Reference in New Issue
Block a user