py: a bit cleaner

This commit is contained in:
Xuan Son Nguyen 2025-01-23 23:07:08 +01:00
parent c3a654c0fb
commit b986af80de

View File

@ -1734,17 +1734,18 @@ class LlamaModel(Model):
n_kv_head = self.hparams.get("num_key_value_heads")
is_vision_tensor = "vision_tower" in name or "vision_model" in name
# For vision model
if name.startswith("language_model"):
name = name.replace("language_model.", "")
if name.startswith("model.text_model"):
name = name.replace("text_model.", "") # for SmolVLM
else:
name = name.replace("model.vision_tower.", "")
if "post_layernorm" in name and self.vision_arch != gguf.MODEL_ARCH.VISION_IDEFICS3:
return [] # skip post_layernorm
if is_vision_tensor:
if name.startswith("model.text_model"):
name = name.replace("text_model.", "") # for SmolVLM
else:
name = name.replace("model.vision_tower.", "")
if "post_layernorm" in name and self.vision_arch != gguf.MODEL_ARCH.VISION_IDEFICS3:
return [] # skip post_layernorm
if not is_vision_tensor:
if name.startswith("language_model"):
# language model tensors, remove the prefix
name = name.replace("language_model.", "")
if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):