Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2024-11-22 08:07:56 +01:00
Fix various bugs for LoRA training (#5161)
Commit: b80e6365d0
Parent: f6a204d7c9
@@ -341,7 +341,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
     # Populate target_modules list with chosen X_proj modules. Llama-based models only atm, non-llama will revert to default behavior.
     def list_target_modules(model_id):
-        if model_id != "llama":
+        if model_id != "llama" and model_id != "mistral":
             return model_to_lora_modules[model_id]
 
         available_modules = {
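This hunk routes Mistral models through the same custom-module path as Llama: before the change, any model_id other than "llama" returned its default entry from model_to_lora_modules, so the user's X_proj checkbox selections were silently ignored for Mistral. A minimal sketch of the dispatch, with a hypothetical model_to_lora_modules table standing in for the repo's real one (the module lists here are illustrative assumptions):

# Hypothetical stand-in for the repo's model_to_lora_modules mapping.
model_to_lora_modules = {
    "llama": ["q_proj", "v_proj"],
    "mistral": ["q_proj", "v_proj"],
    "gpt_neox": ["query_key_value"],
}

def list_target_modules(model_id):
    # Architectures without per-module checkboxes keep their defaults;
    # llama and mistral fall through to the user-selected set.
    if model_id != "llama" and model_id != "mistral":
        return model_to_lora_modules[model_id]
    # ... in the real function, build the list from the q_proj_en,
    # v_proj_en, etc. checkboxes; a fixed pair is used here for brevity.
    return ["q_proj", "v_proj"]

print(list_target_modules("mistral"))  # now honors the custom selection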
@@ -517,7 +517,8 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
     # == Start prepping the model itself ==
     if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
         logger.info("Getting model ready")
-        prepare_model_for_kbit_training(shared.model)
+        if 'quantization_config' in shared.model.config.to_dict():
+            prepare_model_for_kbit_training(shared.model)
 
     # base model is now frozen and should not be reused for any other LoRA training than this one
     shared.model_dirty_from_training = True
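This hunk stops prepare_model_for_kbit_training from running on full-precision models: PEFT's k-bit preparation freezes and casts layers in ways that only make sense for quantized checkpoints, and those are exactly the models whose config carries a quantization_config entry. A small sketch of the guard, assuming a Transformers-style model with a .config object:

from peft import prepare_model_for_kbit_training

def maybe_prepare_for_kbit(model):
    # Models loaded in 4-bit/8-bit via bitsandbytes expose a
    # 'quantization_config' key in their serialized config; plain
    # fp16/fp32 models do not, and are left untouched.
    if 'quantization_config' in model.config.to_dict():
        model = prepare_model_for_kbit_training(model)
    return model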
@@ -615,7 +616,8 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
         warmup_steps=math.ceil(warmup_steps / gradient_accumulation_steps),
         num_train_epochs=epochs,
         learning_rate=actual_lr,
-        fp16=False if shared.args.cpu else True,
+        fp16=False if shared.args.cpu or shared.args.bf16 else True,
+        bf16=shared.args.bf16,
         optim=optimizer,
         logging_steps=2 if stop_at_loss > 0 else 5,
         evaluation_strategy="steps" if eval_data is not None else "no",
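This hunk makes the precision flags mutually exclusive: transformers.TrainingArguments rejects fp16=True together with bf16=True, and before the change passing --bf16 still left fp16 enabled. A standalone sketch of the flag derivation, where the cpu and bf16 parameters mirror the shared.args fields:

def precision_flags(cpu: bool, bf16: bool) -> dict:
    # fp16 is only used when neither CPU mode nor bf16 is requested;
    # TrainingArguments raises an error if both fp16 and bf16 are set.
    return {"fp16": not (cpu or bf16), "bf16": bf16}

assert precision_flags(cpu=False, bf16=True) == {"fp16": False, "bf16": True}
assert precision_flags(cpu=True, bf16=False) == {"fp16": False, "bf16": False}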
@@ -627,7 +629,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
         # TODO: Enable multi-device support
         ddp_find_unused_parameters=None,
         no_cuda=shared.args.cpu,
-        use_ipex=True if is_torch_xpu_available and not shared.args.cpu else False
+        use_ipex=True if is_torch_xpu_available() and not shared.args.cpu else False
     ),
     data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
     callbacks=list([Callbacks()])
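The last hunk fixes a classic truthiness bug: is_torch_xpu_available is a function, and a bare function object is always truthy, so use_ipex was being enabled on every machine regardless of hardware. Calling it performs the actual availability check. Illustrated:

from transformers import is_torch_xpu_available

buggy = bool(is_torch_xpu_available)  # always True: function objects are truthy
fixed = is_torch_xpu_available()      # True only when an Intel XPU is present

With the parentheses in place, the full expression reduces to is_torch_xpu_available() and not shared.args.cpu, which is what the option intends.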