lora : improve compat with mergekit-extract-lora (#11131)

* (wip) support mergekit-extracted lora

* support mergekit-extract-lora

* use lora->get_scale

* correct comment

* correct norm name & condition

* add some hints
This commit is contained in:
Xuan Son Nguyen 2025-01-08 15:59:53 +01:00 committed by GitHub
parent c07d437bbd
commit 4d2b3d8804
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 74 additions and 12 deletions

View File

@ -226,6 +226,9 @@ def get_base_tensor_name(lora_tensor_name: str) -> str:
base_name = lora_tensor_name.replace("base_model.model.", "") base_name = lora_tensor_name.replace("base_model.model.", "")
base_name = base_name.replace(".lora_A.weight", ".weight") base_name = base_name.replace(".lora_A.weight", ".weight")
base_name = base_name.replace(".lora_B.weight", ".weight") base_name = base_name.replace(".lora_B.weight", ".weight")
# models produced by mergekit-extract-lora have token embeddings in the adapter
base_name = base_name.replace(".lora_embedding_A", ".weight")
base_name = base_name.replace(".lora_embedding_B", ".weight")
return base_name return base_name
@ -260,6 +263,10 @@ def parse_args() -> argparse.Namespace:
"--base", type=Path, "--base", type=Path,
help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required. If base model is unspecified, it will be loaded from Hugging Face hub based on the adapter config", help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required. If base model is unspecified, it will be loaded from Hugging Face hub based on the adapter config",
) )
parser.add_argument(
"--base-model-id", type=str,
help="the model ID of the base model, if it is not available locally or in the adapter config. If specified, it will ignore --base and load the base model config from the Hugging Face hub (Example: 'meta-llama/Llama-3.2-1B-Instruct')",
)
parser.add_argument( parser.add_argument(
"lora_path", type=Path, "lora_path", type=Path,
help="directory containing Hugging Face PEFT LoRA config (adapter_model.json) and weights (adapter_model.safetensors or adapter_model.bin)", help="directory containing Hugging Face PEFT LoRA config (adapter_model.json) and weights (adapter_model.safetensors or adapter_model.bin)",
@ -290,6 +297,7 @@ if __name__ == '__main__':
dir_base_model: Path | None = args.base dir_base_model: Path | None = args.base
dir_lora: Path = args.lora_path dir_lora: Path = args.lora_path
base_model_id: str | None = args.base_model_id
lora_config = dir_lora / "adapter_config.json" lora_config = dir_lora / "adapter_config.json"
input_model = dir_lora / "adapter_model.safetensors" input_model = dir_lora / "adapter_model.safetensors"
@ -313,7 +321,10 @@ if __name__ == '__main__':
lparams: dict[str, Any] = json.load(f) lparams: dict[str, Any] = json.load(f)
# load base model # load base model
if dir_base_model is None: if base_model_id is not None:
logger.info(f"Loading base model from Hugging Face: {base_model_id}")
hparams = load_hparams_from_hf(base_model_id)
elif dir_base_model is None:
if "base_model_name_or_path" in lparams: if "base_model_name_or_path" in lparams:
model_id = lparams["base_model_name_or_path"] model_id = lparams["base_model_name_or_path"]
logger.info(f"Loading base model from Hugging Face: {model_id}") logger.info(f"Loading base model from Hugging Face: {model_id}")
@ -371,11 +382,16 @@ if __name__ == '__main__':
if self.lazy: if self.lazy:
tensor = LazyTorchTensor.from_eager(tensor) tensor = LazyTorchTensor.from_eager(tensor)
base_name = get_base_tensor_name(name) base_name = get_base_tensor_name(name)
is_lora_a = ".lora_A.weight" in name # note: mergekit-extract-lora also adds token embeddings to the adapter
is_lora_b = ".lora_B.weight" in name is_lora_a = ".lora_A.weight" in name or ".lora_embedding_A" in name
is_lora_b = ".lora_B.weight" in name or ".lora_embedding_B" in name
if not is_lora_a and not is_lora_b: if not is_lora_a and not is_lora_b:
if ".base_layer.weight" in name: if ".base_layer.weight" in name:
continue continue
# mergekit-extract-lora add these layernorm to the adapter, we need to keep them
if "_layernorm" in name or ".norm" in name:
yield (base_name, tensor)
continue
logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor") logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
if ".embed_tokens.weight" in name or ".lm_head.weight" in name: if ".embed_tokens.weight" in name or ".lm_head.weight" in name:
logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning") logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning")
@ -407,9 +423,21 @@ if __name__ == '__main__':
if name == "lm_head.weight" and len(dest) == 0: if name == "lm_head.weight" and len(dest) == 0:
raise ValueError("lm_head is present in adapter, but is ignored in base model") raise ValueError("lm_head is present in adapter, but is ignored in base model")
for dest_name, dest_data in dest: for dest_name, dest_data in dest:
# mergekit-extract-lora add these layernorm to the adapter
if "_norm" in dest_name:
assert dest_data.dim() == 1
yield (dest_name, dest_data)
continue
# otherwise, we must get the lora_A and lora_B tensors
assert isinstance(dest_data, LoraTorchTensor) assert isinstance(dest_data, LoraTorchTensor)
lora_a, lora_b = dest_data.get_lora_A_B() lora_a, lora_b = dest_data.get_lora_A_B()
# note: mergekit-extract-lora flip and transpose A and B
# here we only need to transpose token_embd.lora_a, see llm_build_inp_embd()
if "token_embd.weight" in dest_name:
lora_a = lora_a.T
yield (dest_name + ".lora_a", lora_a) yield (dest_name + ".lora_a", lora_a)
yield (dest_name + ".lora_b", lora_b) yield (dest_name + ".lora_b", lora_b)

View File

@ -242,6 +242,10 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char
} else { } else {
ab_map[name].b = cur; ab_map[name].b = cur;
} }
} else if (str_endswith(name, "_norm.weight")) {
// TODO: add support for norm vector
// for now, we don't really care because most adapters still work fine without it
continue;
} else { } else {
throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix"); throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix");
} }
@ -251,6 +255,7 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char
for (auto & it : ab_map) { for (auto & it : ab_map) {
const std::string & name = it.first; const std::string & name = it.first;
llama_lora_weight & w = it.second; llama_lora_weight & w = it.second;
bool is_token_embd = str_endswith(name, "token_embd.weight");
if (!w.a || !w.b) { if (!w.a || !w.b) {
throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component"); throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component");
@ -259,17 +264,24 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char
// device buft and device ctx // device buft and device ctx
auto * model_tensor = llama_model_get_tensor(model, name.c_str()); auto * model_tensor = llama_model_get_tensor(model, name.c_str());
if (!model_tensor) { if (!model_tensor) {
throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model"); throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)");
} }
struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer)); struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
// validate tensor shape // validate tensor shape
if (is_token_embd) {
// expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd()
if (model_tensor->ne[0] != w.b->ne[1] || model_tensor->ne[1] != w.a->ne[1]) {
throw std::runtime_error("tensor '" + name + "' has incorrect shape (hint: maybe wrong base model?)");
}
} else {
if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) { if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) {
throw std::runtime_error("tensor '" + name + "' has incorrect shape"); throw std::runtime_error("tensor '" + name + "' has incorrect shape (hint: maybe wrong base model?)");
} }
if (w.a->ne[1] != w.b->ne[0]) { if (w.a->ne[1] != w.b->ne[0]) {
throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)"); throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)");
} }
}
// save tensor to adapter // save tensor to adapter
struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a); struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);

View File

@ -45,6 +45,13 @@ struct llama_lora_weight {
struct ggml_tensor * a = nullptr; struct ggml_tensor * a = nullptr;
struct ggml_tensor * b = nullptr; struct ggml_tensor * b = nullptr;
// get actual scale based on rank and alpha
float get_scale(float alpha, float adapter_scale) {
const float rank = (float) b->ne[0];
const float scale = alpha ? adapter_scale * alpha / rank : adapter_scale;
return scale;
}
llama_lora_weight() = default; llama_lora_weight() = default;
llama_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {} llama_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {}
}; };

View File

@ -2545,6 +2545,21 @@ static struct ggml_tensor * llm_build_inp_embd(
ggml_set_input(lctx.inp_tokens); ggml_set_input(lctx.inp_tokens);
inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens); inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
// apply lora for embedding tokens if needed
for (auto & it : lctx.lora_adapters) {
struct llama_lora_weight * lora = it.first->get_weight(tok_embd);
if (lora == nullptr) {
continue;
}
const float adapter_scale = it.second;
const float scale = lora->get_scale(it.first->alpha, adapter_scale);
struct ggml_tensor * inpL_delta = ggml_scale(ctx, ggml_mul_mat(
ctx, lora->b, // non-transposed lora_b
ggml_get_rows(ctx, lora->a, lctx.inp_tokens)
), scale);
inpL = ggml_add(ctx, inpL, inpL_delta);
}
} else { } else {
lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens); lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
inpL = lctx.inp_embd; inpL = lctx.inp_embd;
@ -2617,9 +2632,8 @@ static struct ggml_tensor * llm_build_lora_mm(
if (lora == nullptr) { if (lora == nullptr) {
continue; continue;
} }
const float alpha = it.first->alpha; const float adapter_scale = it.second;
const float rank = (float) lora->b->ne[0]; const float scale = lora->get_scale(it.first->alpha, adapter_scale);
const float scale = alpha ? it.second * alpha / rank : it.second;
struct ggml_tensor * ab_cur = ggml_mul_mat( struct ggml_tensor * ab_cur = ggml_mul_mat(
ctx0, lora->b, ctx0, lora->b,
ggml_mul_mat(ctx0, lora->a, cur) ggml_mul_mat(ctx0, lora->a, cur)
@ -3967,6 +3981,7 @@ struct llm_build_context {
// feed-forward network // feed-forward network
if (model.layers[il].ffn_gate_inp == nullptr) { if (model.layers[il].ffn_gate_inp == nullptr) {
cur = llm_build_norm(ctx0, ffn_inp, hparams, cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL, model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il); LLM_NORM_RMS, cb, il);