mirror of https://github.com/oobabooga/text-generation-webui.git
Training_PRO fix: add if 'quantization_config' in shared.model.config.to_dict()
This commit is contained in:
parent c0f600c887
commit 03a0f236a4
@@ -789,7 +789,11 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
         logger.info("Getting model ready...")
         # here we can disable gradient checkpoint, by default = true, use_gradient_checkpointing=True
-        prepare_model_for_kbit_training(shared.model)
+        if 'quantization_config' in shared.model.config.to_dict():
+            print(f"Method: {RED}QLORA{RESET}")
+            prepare_model_for_kbit_training(shared.model)
+        else:
+            print(f"Method: {RED}LoRA{RESET}")

         # base model is now frozen and should not be reused for any other LoRA training than this one
         shared.model_dirty_from_training = True
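For context, the check works because a model loaded quantized (e.g. via bitsandbytes 4-bit/8-bit loading) carries a 'quantization_config' entry in its transformers config, while a full- or half-precision load does not. Below is a minimal sketch of the same branching, assuming transformers and peft are installed and a model is already loaded as `model`; the helper name `get_model_ready` is illustrative and not part of the repo.

from peft import prepare_model_for_kbit_training

def get_model_ready(model):
    # Quantized loads expose 'quantization_config' in the model config -> QLoRA path.
    # prepare_model_for_kbit_training() upcasts norm/embedding layers and enables
    # input gradients so LoRA adapters can be trained on top of k-bit weights.
    if 'quantization_config' in model.config.to_dict():
        prepare_model_for_kbit_training(model)
    # Full/half-precision loads take the plain LoRA path and skip the k-bit preparation.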
|
Loading…
x
Reference in New Issue
Block a user