From c3a654c0fbad4c7eeeaf669fc708d40aef6f341c Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Thu, 23 Jan 2025 15:51:30 +0100
Subject: [PATCH] add SmolVLM

---
 convert_hf_to_gguf.py          | 37 ++++++++++++++++++++++--------
 gguf-py/gguf/constants.py      | 19 +++++++++++++++
 gguf-py/gguf/gguf_writer.py    |  3 +++
 gguf-py/gguf/tensor_mapping.py | 15 ++++++++++++
 src/llama-arch.cpp             | 21 +++++++++++++++++
 src/llama-arch.h               |  3 +++
 src/llama-model.cpp            | 38 ++++++++++++++++++++++++++++++
 src/llama-vision.cpp           | 42 +++++++++++++++++++++++++++++++++-
 src/llama-vision.h             |  3 +++
 9 files changed, 171 insertions(+), 10 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index e703cd33d..27bf2c1f2 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -292,7 +292,10 @@ class Model:
         self.gguf_writer.add_vision_vit_head_count(self.vparams["num_attention_heads"])
         self.gguf_writer.add_vision_vit_image_mean(self.preprocessor_config["image_mean"])
         self.gguf_writer.add_vision_vit_image_std(self.preprocessor_config["image_std"])
-        self.gguf_writer.add_vision_vit_select_layer(self.find_hparam(["vision_feature_layer", "mm_vision_select_layer"]))
+        try:
+            self.gguf_writer.add_vision_vit_select_layer(self.find_hparam(["vision_feature_layer", "mm_vision_select_layer"]))
+        except KeyError:
+            self.gguf_writer.add_vision_vit_select_layer(0)
         self.gguf_writer.add_file_type(self.ftype)
         logger.info(f"gguf: file type = {self.ftype}")

@@ -506,8 +509,9 @@ class Model:
             hparams = json.load(f)
         if "text_config" in hparams:
             text_config = hparams["text_config"]
+            model_id = text_config.get("_name_or_path", None)
             # for example, llava-1.5-7b-hf misses the language model config, need to retrieve it via model ID
-            if "_name_or_path" in text_config:
+            if model_id is not None and model_id != "None" and model_id != "":
                 text_config = AutoConfig.from_pretrained(text_config["_name_or_path"]).to_dict()
             hparams = {**text_config, **hparams}
         return hparams
@@ -1616,7 +1620,7 @@ class StableLMModel(Model):
         raise ValueError(f"Unprocessed norms: {norms}")


-@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration", "MobileLlamaForCausalLM")
+@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration", "MobileLlamaForCausalLM", "Idefics3ForConditionalGeneration")
 class LlamaModel(Model):
     model_arch = gguf.MODEL_ARCH.LLAMA

@@ -1640,6 +1644,11 @@ class LlamaModel(Model):
             self.preprocessor_config = AutoImageProcessor.from_pretrained(vision_model_id).to_dict()
             self.vision_arch = gguf.MODEL_ARCH.VISION_MOBILEVLM

+        if "vision_config" in self.hparams and model_type == "idefics3":
+            self.vparams = self.hparams["vision_config"]
+            self.preprocessor_config = self.load_preprocessor_config(self.dir_model)
+            self.vision_arch = gguf.MODEL_ARCH.VISION_IDEFICS3
+
         if self.vparams is not None and self.vision_arch is not None:
             self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"])

@@ -1694,14 +1703,20 @@ class LlamaModel(Model):

         # For vision model
         if self.vparams is not None:
+            max_pos_embd = -1
             self.gguf_writer.add_vision_vit_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
             # TODO: should not hardcode these, but they are currently missing from config.json
             if self.vision_arch == gguf.MODEL_ARCH.VISION_LLAVA:
                 self.gguf_writer.add_vision_vit_projector_type(gguf.constants.CLIPProjectorType.MLP)
self.vparams["patch_size"])**2 + 1 if self.vision_arch == gguf.MODEL_ARCH.VISION_MOBILEVLM: self.gguf_writer.add_vision_vit_projector_type(gguf.constants.CLIPProjectorType.LDPV2) + max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1 + if self.vision_arch == gguf.MODEL_ARCH.VISION_IDEFICS3: + self.gguf_writer.add_vision_vit_projector_type(gguf.constants.CLIPProjectorType.MLP) + self.gguf_writer.add_vision_vit_scale_factor(self.hparams["scale_factor"]) + max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 self.gguf_writer.add_vision_vit_layer_norm_epsilon(1e-05) - max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1 self.gguf_writer.add_vision_vit_max_position_embeddings(max_pos_embd) @staticmethod @@ -1717,19 +1732,23 @@ class LlamaModel(Model): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") + is_vision_tensor = "vision_tower" in name or "vision_model" in name # For vision model if name.startswith("language_model"): name = name.replace("language_model.", "") + if name.startswith("model.text_model"): + name = name.replace("text_model.", "") # for SmolVLM else: name = name.replace("model.vision_tower.", "") - if "post_layernorm" in name: + if "post_layernorm" in name and self.vision_arch != gguf.MODEL_ARCH.VISION_IDEFICS3: return [] # skip post_layernorm - if name.endswith(("q_proj.weight", "q_proj.bias")): - data_torch = LlamaModel.permute(data_torch, n_head, n_head) - if name.endswith(("k_proj.weight", "k_proj.bias")): - data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + if not is_vision_tensor: + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) # process the experts separately if name.find("block_sparse_moe.experts") != -1: diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index f4da3e234..cc11aa56d 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -238,6 +238,7 @@ class Keys: PATCH_MERGE_TYPE = "vision.vit.patch_merge_type" HEAD_COUNT = "vision.vit.attention.head_count" LAYERNORM_EPS = "vision.vit.attention.layer_norm_epsilon" + SCALE_FACTOR = "vision.vit.scale_factor" # only used by idefics3 for now # # recommended mapping of model tensor names for storage in gguf @@ -311,6 +312,7 @@ class MODEL_ARCH(IntEnum): VISION_LLAVA = auto() VISION_MOBILEVLM = auto() VISION_MINICPMV = auto() + VISION_IDEFICS3 = auto() class MODEL_TENSOR(IntEnum): @@ -441,6 +443,7 @@ class MODEL_TENSOR(IntEnum): POSNET_ATTN_OUT = auto() # vision V_MMPROJ = auto() + V_MMPROJ_FC = auto() V_MMPROJ_MLP = auto() V_MMPROJ_PEG = auto() V_ENC_EMBD_CLS = auto() @@ -535,6 +538,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.VISION_LLAVA: "llava", MODEL_ARCH.VISION_MOBILEVLM: "mobilevlm", MODEL_ARCH.VISION_MINICPMV: "minicpmv", + MODEL_ARCH.VISION_IDEFICS3: "idefics3", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -664,6 +668,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output", # vision MODEL_TENSOR.V_MMPROJ: "v.mmproj_{bid}", + MODEL_TENSOR.V_MMPROJ_FC: "v.mmproj.fc", MODEL_TENSOR.V_MMPROJ_MLP: "v.mmproj.mlp.{bid}", MODEL_TENSOR.V_MMPROJ_PEG: "v.mmproj.peg.{bid}", MODEL_TENSOR.V_ENC_EMBD_CLS: 
"v.enc.embd.cls", @@ -1695,6 +1700,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.V_TOK_EMBD_SLICE, MODEL_TENSOR.V_TOK_EMBD_END_SLICE, ], + MODEL_ARCH.VISION_IDEFICS3: [ + MODEL_TENSOR.V_MMPROJ_FC, + MODEL_TENSOR.V_ENC_EMBD_PATCH, + MODEL_TENSOR.V_ENC_EMBD_POS, + MODEL_TENSOR.V_ENC_ATTN_Q, + MODEL_TENSOR.V_ENC_ATTN_K, + MODEL_TENSOR.V_ENC_ATTN_V, + MODEL_TENSOR.V_ENC_INPUT_NORM, + MODEL_TENSOR.V_ENC_OUTPUT, + MODEL_TENSOR.V_ENC_OUTPUT_NORM, + MODEL_TENSOR.V_ENC_FFN_UP, + MODEL_TENSOR.V_ENC_FFN_DOWN, + MODEL_TENSOR.V_POST_NORM, + ], # TODO } diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 65d0e8f30..a31ab736b 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -928,6 +928,9 @@ class GGUFWriter: def add_vision_vit_image_std(self, value: Sequence[float]) -> None: self.add_array(Keys.Vision.IMAGE_STD, value) + def add_vision_vit_scale_factor(self, value: int) -> None: + self.add_int32(Keys.Vision.Vit.SCALE_FACTOR, value) + def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None: if not isinstance(value, str): template_default = None diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 0228e8400..3f247d787 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -794,6 +794,10 @@ class TensorNameMap: "multi_modal_projector.linear_{bid}", ), + MODEL_TENSOR.V_MMPROJ_FC: ( + "model.connector.modality_projection.proj", # SmolVLM + ), + MODEL_TENSOR.V_MMPROJ_MLP: ( "model.mm_projector.mlp.mlp.{bid}", ), @@ -809,51 +813,61 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_EMBD_PATCH: ( "vision_tower.vision_model.embeddings.patch_embedding", "vpm.embeddings.patch_embedding", + "model.vision_model.embeddings.patch_embedding", # SmolVLM ), MODEL_TENSOR.V_ENC_EMBD_POS: ( "vision_tower.vision_model.embeddings.position_embedding", "vpm.embeddings.position_embedding", + "model.vision_model.embeddings.position_embedding", # SmolVLM ), MODEL_TENSOR.V_ENC_ATTN_Q: ( "vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj", "vpm.encoder.layers.{bid}.self_attn.q_proj", + "model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM ), MODEL_TENSOR.V_ENC_ATTN_K: ( "vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj", "vpm.encoder.layers.{bid}.self_attn.k_proj", + "model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM ), MODEL_TENSOR.V_ENC_ATTN_V: ( "vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj", "vpm.encoder.layers.{bid}.self_attn.v_proj", + "model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM ), MODEL_TENSOR.V_ENC_INPUT_NORM: ( "vision_tower.vision_model.encoder.layers.{bid}.layer_norm1", "vpm.encoder.layers.{bid}.layer_norm1", + "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM ), MODEL_TENSOR.V_ENC_OUTPUT: ( "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj", "vpm.encoder.layers.{bid}.self_attn.out_proj", + "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM ), MODEL_TENSOR.V_ENC_OUTPUT_NORM: ( "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2", "vpm.encoder.layers.{bid}.layer_norm2", + "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM ), MODEL_TENSOR.V_ENC_FFN_UP: ( "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1", "vpm.encoder.layers.{bid}.mlp.fc1", + "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM ), MODEL_TENSOR.V_ENC_FFN_DOWN: ( 
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2", "vpm.encoder.layers.{bid}.mlp.fc2", + "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM ), MODEL_TENSOR.V_PRE_NORM: ( @@ -862,6 +876,7 @@ class TensorNameMap: MODEL_TENSOR.V_POST_NORM: ( "vision_tower.vision_model.post_layernorm", + "model.vision_model.post_layernorm", # SmolVLM ), MODEL_TENSOR.V_RESMPL_POS_EMBD_K: ( diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 1a6d45331..92e488f57 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -66,6 +66,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_VISION_LLAVA, "llava" }, { LLM_ARCH_VISION_MOBILEVLM, "mobilevlm" }, { LLM_ARCH_VISION_MINICPMV, "minicpmv" }, + { LLM_ARCH_VISION_IDEFICS3, "idefics3" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -214,6 +215,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_VISION_VIT_PATCH_MERGE_TYPE, "vision.vit.patch_merge_type" }, { LLM_KV_VISION_VIT_HEAD_COUNT, "vision.vit.attention.head_count" }, { LLM_KV_VISION_VIT_LAYERNORM_EPS, "vision.vit.attention.layer_norm_epsilon" }, + { LLM_KV_VISION_VIT_SCALE_FACTOR, "vision.vit.scale_factor" }, // deprecated { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, @@ -1388,6 +1390,25 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_V_TOK_EMBD_END_SLICE, "v.tok_embd.end_slice" }, } }, + { + LLM_ARCH_VISION_IDEFICS3, + { + { LLM_TENSOR_V_MMPROJ_FC, "v.mmproj.fc" }, + { LLM_TENSOR_V_ENC_EMBD_CLS, "v.enc.embd.cls" }, + { LLM_TENSOR_V_ENC_EMBD_PATCH, "v.enc.embd.patch" }, + { LLM_TENSOR_V_ENC_EMBD_POS, "v.enc.embd.pos" }, + { LLM_TENSOR_V_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" }, + { LLM_TENSOR_V_ENC_ATTN_K, "v.enc.blk.%d.attn_k" }, + { LLM_TENSOR_V_ENC_ATTN_V, "v.enc.blk.%d.attn_v" }, + { LLM_TENSOR_V_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" }, + { LLM_TENSOR_V_ENC_OUTPUT, "v.enc.blk.%d.output" }, + { LLM_TENSOR_V_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" }, + { LLM_TENSOR_V_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" }, + { LLM_TENSOR_V_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" }, + { LLM_TENSOR_V_PRE_NORM, "v.pre_norm" }, + { LLM_TENSOR_V_POST_NORM, "v.post_norm" }, + } + }, { LLM_ARCH_UNKNOWN, { diff --git a/src/llama-arch.h b/src/llama-arch.h index 3440ded53..c3fc32032 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -70,6 +70,7 @@ enum llm_arch { LLM_ARCH_VISION_LLAVA, LLM_ARCH_VISION_MOBILEVLM, LLM_ARCH_VISION_MINICPMV, + LLM_ARCH_VISION_IDEFICS3, LLM_ARCH_UNKNOWN, }; @@ -218,6 +219,7 @@ enum llm_kv { LLM_KV_VISION_VIT_PATCH_MERGE_TYPE, LLM_KV_VISION_VIT_HEAD_COUNT, LLM_KV_VISION_VIT_LAYERNORM_EPS, + LLM_KV_VISION_VIT_SCALE_FACTOR, // deprecated: LLM_KV_TOKENIZER_PREFIX_ID, @@ -354,6 +356,7 @@ enum llm_tensor { LLM_TENSOR_POS_NET_ATTN_OUT, // vision LLM_TENSOR_V_MMPROJ, + LLM_TENSOR_V_MMPROJ_FC, LLM_TENSOR_V_MMPROJ_MLP, LLM_TENSOR_V_MMPROJ_PEG, LLM_TENSOR_V_ENC_EMBD_CLS, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 4aed37d89..6a6b65618 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1265,6 +1265,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_VISION_VIT_LAYERNORM_EPS, vparams.eps, true); ml.get_key(LLM_KV_VISION_VIT_SELECT_LAYER, vparams.select_layer, true); ml.get_key(LLM_KV_VISION_VIT_MAX_POS_EMBD, vparams.max_pos_embd, true); + ml.get_key(LLM_KV_VISION_VIT_SCALE_FACTOR, vparams.scale_factor, false); { std::string name; ml.get_key(LLM_KV_VISION_VIT_PROJECTOR_TYPE, name, true); @@ -3555,6 +3556,42 @@ bool llama_model::load_tensors(llama_model_loader & ml) { vit.mm_tok_embd_slice = 
                     vit.mm_tok_embd_slice     = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_TOK_EMBD_SLICE,     "weight"), {n_embd});
                     vit.mm_tok_embd_end_slice = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_TOK_EMBD_END_SLICE, "weight"), {n_embd});

+                    for (int i = 0; i < n_vlayer; ++i) {
+                        auto & layer = vit.layers[i];
+
+                        layer.k_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_K, "weight", i), {n_vembd, n_vembd}, 0);
+                        layer.k_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_K, "bias"  , i), {n_vembd}, 0);
+                        layer.v_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_V, "weight", i), {n_vembd, n_vembd}, 0);
+                        layer.v_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_V, "bias"  , i), {n_vembd}, 0);
+                        layer.q_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_Q, "weight", i), {n_vembd, n_vembd}, 0);
+                        layer.q_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_Q, "bias"  , i), {n_vembd}, 0);
+
+                        layer.ffn_up_w   = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_UP,   "weight", i), {n_vembd, n_vff}, 0);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_UP,   "bias"  , i), {n_vff}, 0);
+                        layer.ffn_down_w = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_DOWN, "weight", i), {n_vff, n_vembd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_DOWN, "bias"  , i), {n_vembd}, 0);
+
+                        layer.norm_in_w  = create_tensor(tn(LLM_TENSOR_V_ENC_INPUT_NORM,  "weight", i), {n_vembd}, 0);
+                        layer.norm_in_b  = create_tensor(tn(LLM_TENSOR_V_ENC_INPUT_NORM,  "bias"  , i), {n_vembd}, 0);
+                        layer.norm_out_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "weight", i), {n_vembd}, 0);
+                        layer.norm_out_b = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "bias"  , i), {n_vembd}, 0);
+
+                        layer.output_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "weight", i), {n_vembd, n_vembd}, 0);
+                        layer.output_b = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "bias"  , i), {n_vembd}, 0);
+                    }
+                } break;
+            case LLM_ARCH_VISION_IDEFICS3:
+                {
+                    int scale_factor = vit.hparams.scale_factor;
+                    vit.projection = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_FC, "weight"), {n_vembd * scale_factor * scale_factor, n_embd});
+
+                    vit.patch_embeddings    = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd});
+                    vit.patch_bias          = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "bias"  ), {n_vembd});
+                    vit.position_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_POS,   "weight"), {n_vembd, max_pos_embd});
+
+                    vit.post_norm_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_POST_NORM, "weight"), {n_vembd});
+                    vit.post_norm_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_POST_NORM, "bias"  ), {n_vembd});
+
                     for (int i = 0; i < n_vlayer; ++i) {
                         auto & layer = vit.layers[i];
@@ -4085,6 +4122,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
         case LLM_ARCH_VISION_LLAVA:
         case LLM_ARCH_VISION_MOBILEVLM:
         case LLM_ARCH_VISION_MINICPMV:
+        case LLM_ARCH_VISION_IDEFICS3:
             GGML_ABORT("vision arch does not use RoPE");

         // all model arches should be listed explicitly here
diff --git a/src/llama-vision.cpp b/src/llama-vision.cpp
index 209d0b137..2da15e2eb 100644
--- a/src/llama-vision.cpp
+++ b/src/llama-vision.cpp
@@ -42,7 +42,9 @@ struct llama_image_u8 {
 uint32_t llama_vision_n_mmproj_embd(const llama_vision_model & vmodel) {
     auto & proj_type = vmodel.hparams.proj_type;
     if (proj_type == VISION_PROJECTOR_TYPE_MLP) {
-        return vmodel.mm_2_b->ne[0];
+        return vmodel.mm_2_b
+            ? vmodel.mm_2_b->ne[0]
+            : vmodel.projection->ne[1]; // idefics3
     } else if (proj_type == VISION_PROJECTOR_TYPE_LDPV2) {
         return vmodel.mm_model_peg_0_b->ne[0];
     } else if (proj_type == VISION_PROJECTOR_TYPE_MINICPMV_2_5) {
@@ -903,6 +905,40 @@ struct llama_vision_graph_builder {

         return gf;
     }
+
+    struct ggml_cgraph * build_idefics3() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, VISION_GRAPH_MAX_NODE, false);
+        struct ggml_tensor * cur = build_vit();
+
+        // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
+        {
+            const int scale_factor = model.hparams.scale_factor;
+            const int n_embd = cur->ne[0];
+            const int seq    = cur->ne[1];
+            const int bsz    = 1; // batch size, always 1 for now since we don't support batching
+            const int height = std::sqrt(seq);
+            const int width  = std::sqrt(seq);
+            cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                height / scale_factor,
+                width / scale_factor,
+                bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                seq / (scale_factor * scale_factor),
+                bsz);
+
+            cur = ggml_mul_mat(ctx0, model.projection, cur);
+        }
+
+        ggml_set_name(cur, "output");
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };

 static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_vision_tokens & inp) {
@@ -933,6 +969,9 @@ static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_
         case LLM_ARCH_VISION_MINICPMV:
             gf = builder.build_minicpmv();
             break;
+        case LLM_ARCH_VISION_IDEFICS3:
+            gf = builder.build_idefics3();
+            break;
         default:
             GGML_ASSERT(false && "unsupported vision arch");
     }
@@ -1064,6 +1103,7 @@ struct llama_vision_tokens * llama_vision_tokenize(
     switch (vctx.model->hparams.arch) {
         case LLM_ARCH_VISION_LLAVA:
         case LLM_ARCH_VISION_MOBILEVLM:
+        case LLM_ARCH_VISION_IDEFICS3:
             return new llama_vision_tokens(llama_vision_processor_llava(vctx).tokenize(*bmp));
         case LLM_ARCH_VISION_MINICPMV:
             //return new llama_vision_tokens(llama_vision_processor_uhd(vctx).tokenize(*bmp));
diff --git a/src/llama-vision.h b/src/llama-vision.h
index 948c8d0ed..953ec5795 100644
--- a/src/llama-vision.h
+++ b/src/llama-vision.h
@@ -48,6 +48,9 @@ struct llama_vision_model {
         std::array image_grid_pinpoints; // TODO: should this be array of (x, y) pairs?
         int32_t image_crop_resolution;
+
+        // idefics3
+        int scale_factor = 0;
     };
     struct vision_hparams hparams;

     ggml_backend_buffer_type_t buft;
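
Note on the connector math in build_idefics3(): the reshape/permute sequence above is the Idefics3 "pixel shuffle" that the linked transformers code performs. It merges each scale_factor x scale_factor block of neighbouring image patches into a single token whose feature width grows by scale_factor^2, and the v.mmproj.fc weight then projects that wider token to the text embedding size. Below is a minimal NumPy sketch of the same operation, for reference only and not part of the patch; the array shapes and the toy proj weight are illustrative assumptions.

import numpy as np

def pixel_shuffle(x: np.ndarray, scale_factor: int) -> np.ndarray:
    # x: (bsz, seq, embd) with seq = height * width patches and height == width
    bsz, seq, embd = x.shape
    height = width = int(np.sqrt(seq))
    x = x.reshape(bsz, height, width // scale_factor, embd * scale_factor)
    x = x.transpose(0, 2, 1, 3)
    x = x.reshape(bsz, width // scale_factor, height // scale_factor,
                  embd * scale_factor * scale_factor)
    x = x.transpose(0, 2, 1, 3)
    # each output token now stacks a scale_factor x scale_factor block of patches
    return x.reshape(bsz, seq // (scale_factor ** 2), embd * scale_factor ** 2)

# toy example: a 4x4 patch grid with embd = 8 and scale_factor = 2
x = np.random.randn(1, 16, 8)           # (bsz, seq, embd)
proj = np.random.randn(8 * 2 * 2, 32)   # stand-in for the v.mmproj.fc weight
tokens = pixel_shuffle(x, 2) @ proj     # -> (1, 4, 32) image tokens
print(tokens.shape)

With scale_factor = 2 the 16 patches collapse into 4 image tokens, which is why the max position embedding count and the projection input width both depend on scale_factor in the conversion code above.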