Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-11-22 08:07:56 +01:00)
Fix various bugs for LoRA training (#5161)

commit b80e6365d0
parent f6a204d7c9
@@ -341,7 +341,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
 
     # Populate target_modules list with chosen X_proj modules. Llama-based models only atm, non-llama will revert to default behavior.
     def list_target_modules(model_id):
-        if model_id != "llama":
+        if model_id != "llama" and model_id != "mistral":
             return model_to_lora_modules[model_id]
 
         available_modules = {
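This hunk lets Mistral models share the checkbox-driven projection selection that was previously Llama-only; other model families keep falling back to model_to_lora_modules. A minimal, self-contained sketch of that dispatch pattern (the function name, dictionary contents, and defaults below are illustrative assumptions, not the repo's exact code):

DEFAULT_LORA_MODULES = {
    # Illustrative stand-in for the diff's model_to_lora_modules mapping.
    "gpt_neox": ["query_key_value"],
    "opt": ["q_proj", "v_proj"],
}

def pick_target_modules(model_id, chosen_projections):
    # Llama and Mistral share the same projection-layer names, so both can honor
    # the UI checkbox selection; every other family uses its per-family default.
    if model_id != "llama" and model_id != "mistral":
        return DEFAULT_LORA_MODULES.get(model_id, ["q_proj", "v_proj"])
    return chosen_projections

print(pick_target_modules("mistral", ["q_proj", "k_proj", "v_proj"]))  # checkbox-driven
print(pick_target_modules("gpt_neox", ["q_proj"]))                     # family default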
@@ -517,7 +517,8 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
     # == Start prepping the model itself ==
     if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
         logger.info("Getting model ready")
-        prepare_model_for_kbit_training(shared.model)
+        if 'quantization_config' in shared.model.config.to_dict():
+            prepare_model_for_kbit_training(shared.model)
 
     # base model is now frozen and should not be reused for any other LoRA training than this one
     shared.model_dirty_from_training = True
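This hunk stops prepare_model_for_kbit_training from running on models that were not loaded with quantization, which the original code did unconditionally. A minimal sketch of the guard, assuming a Hugging Face transformers model in place of shared.model:

from peft import prepare_model_for_kbit_training

def maybe_prepare_for_kbit(model):
    # A quantized checkpoint records its settings under the 'quantization_config'
    # key of the serialized config; skip the k-bit preparation otherwise.
    if 'quantization_config' in model.config.to_dict():
        model = prepare_model_for_kbit_training(model)
    return model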
@@ -615,7 +616,8 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
             warmup_steps=math.ceil(warmup_steps / gradient_accumulation_steps),
             num_train_epochs=epochs,
             learning_rate=actual_lr,
-            fp16=False if shared.args.cpu else True,
+            fp16=False if shared.args.cpu or shared.args.bf16 else True,
+            bf16=shared.args.bf16,
             optim=optimizer,
             logging_steps=2 if stop_at_loss > 0 else 5,
             evaluation_strategy="steps" if eval_data is not None else "no",
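This hunk wires the shared.args.bf16 setting into the TrainingArguments and keeps it mutually exclusive with fp16, which transformers enforces (it raises if both are True). A small sketch of the precision selection, with cpu and bf16 standing in for the shared.args flags:

def precision_kwargs(cpu: bool, bf16: bool) -> dict:
    # transformers.TrainingArguments rejects fp16=True together with bf16=True,
    # so fp16 is only enabled when neither CPU mode nor bf16 is requested.
    return {"fp16": not (cpu or bf16), "bf16": bf16}

assert precision_kwargs(cpu=False, bf16=True) == {"fp16": False, "bf16": True}
assert precision_kwargs(cpu=False, bf16=False) == {"fp16": True, "bf16": False}
assert precision_kwargs(cpu=True, bf16=False) == {"fp16": False, "bf16": False}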
@@ -627,7 +629,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
             # TODO: Enable multi-device support
             ddp_find_unused_parameters=None,
             no_cuda=shared.args.cpu,
-            use_ipex=True if is_torch_xpu_available and not shared.args.cpu else False
+            use_ipex=True if is_torch_xpu_available() and not shared.args.cpu else False
         ),
         data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
         callbacks=list([Callbacks()])
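This hunk fixes a missing call: the bare name is_torch_xpu_available is a function object and therefore always truthy, so use_ipex could be enabled even without an Intel XPU present. A short illustration of the difference (the import path is assumed for this sketch):

from transformers import is_torch_xpu_available  # import path assumed for this sketch

print(bool(is_torch_xpu_available))    # always True: this is just the function object
print(bool(is_torch_xpu_available()))  # True only when an Intel XPU device is present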