From d0068ef0eda5f43c65258dd2eefda5bacea412fb Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 19 Jan 2025 16:29:20 +0100 Subject: [PATCH] add mobilevlm --- convert_hf_to_gguf.py | 66 ++++++++++++++++++-------- gguf-py/gguf/constants.py | 34 +++++++++++-- gguf-py/gguf/gguf_writer.py | 2 +- gguf-py/gguf/tensor_mapping.py | 8 ++++ src/llama-arch.cpp | 31 +++++++++++- src/llama-arch.h | 3 ++ src/llama-model.cpp | 87 ++++++++++++++++++++-------------- src/llama-vision.cpp | 29 +++++++++++- src/llama-vision.h | 11 +++++ 9 files changed, 210 insertions(+), 61 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 9e36cad61..89e62f5ce 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -17,7 +17,7 @@ from hashlib import sha256 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast from itertools import chain -from transformers import AutoConfig +from transformers import AutoConfig, AutoImageProcessor import math import numpy as np import torch @@ -68,9 +68,10 @@ class Model: dir_model_card: Path # for vision model + vision_arch: gguf.MODEL_ARCH | None = None preprocessor_config: dict[str, Any] | None = None vparams: dict[str, Any] | None = None - v_tensor_map: gguf.TensorNameMap + v_tensor_map: gguf.TensorNameMap | None = None v_tensor_names: set[str] | None # subclasses should define this! @@ -102,7 +103,6 @@ class Model: self.metadata_override = metadata_override self.model_name = model_name self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py - self.preprocessor_config = self.load_preprocessor_config(self.dir_model) # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type if self.ftype == gguf.LlamaFileType.GUESSED: @@ -218,7 +218,7 @@ class Model: def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str: new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes) - new_name_vision = self.v_tensor_map.get_name(key=name, try_suffixes=try_suffixes) + new_name_vision = self.v_tensor_map.get_name(key=name, try_suffixes=try_suffixes) if self.v_tensor_map is not None else None if new_name is not None: return new_name elif new_name_vision is not None: @@ -488,14 +488,17 @@ class Model: return hparams @staticmethod - def load_preprocessor_config(dir_model: Path): + def load_preprocessor_config(dir_or_model_id: Path | str): # TODO: this varies vastly among models, need to handle more cases in the future - file_path = dir_model / "preprocessor_config.json" - if os.path.exists(file_path): - with open(file_path, "r", encoding="utf-8") as f: - return json.load(f) + if isinstance(dir_or_model_id, Path): + file_path = dir_or_model_id / "preprocessor_config.json" + if os.path.exists(file_path): + with open(file_path, "r", encoding="utf-8") as f: + return json.load(f) + else: + raise Exception(f"Preprocessor config not found at {file_path}") else: - return None + return AutoImageProcessor.from_pretrained(dir_or_model_id).to_dict() @classmethod def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]: @@ -1586,16 +1589,31 @@ class StableLMModel(Model): raise ValueError(f"Unprocessed norms: {norms}") -@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration") +@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration", 
"MobileLlamaForCausalLM") class LlamaModel(Model): model_arch = gguf.MODEL_ARCH.LLAMA def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if "vision_config" in self.hparams: + + model_type = self.hparams.get("model_type", None) + self.vision_arch = None + + # only tested with https://huggingface.co/llava-hf/llava-1.5-7b-hf + if "vision_config" in self.hparams and model_type == "llava": self.vparams = self.hparams["vision_config"] - if self.vparams is not None: - self.v_tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAVA_VISION, self.vparams["num_hidden_layers"]) + self.preprocessor_config = self.load_preprocessor_config(self.dir_model) + self.vision_arch = gguf.MODEL_ARCH.VISION_LLAVA + + # only tested with https://huggingface.co/mtgv/MobileVLM_V2-1.7B + if "mm_vision_tower" in self.hparams and model_type == "mobilevlm": + vision_model_id = self.hparams["mm_vision_tower"] + self.vparams = AutoConfig.from_pretrained(vision_model_id).to_dict()["vision_config"] + self.preprocessor_config = self.load_preprocessor_config(vision_model_id) + self.vision_arch = gguf.MODEL_ARCH.VISION_MOBILEVLM + + if self.vparams is not None and self.vision_arch is not None: + self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"]) def set_vocab(self): try: @@ -1631,23 +1649,31 @@ class LlamaModel(Model): self.gguf_writer.add_add_bos_token(False) # For vision model - if self.vparams is not None and self.preprocessor_config is not None: + if self.vparams is not None and self.preprocessor_config is not None and self.vision_arch is not None: self.gguf_writer.add_vision_type("clip-vit") self.gguf_writer.add_vision_image_size(self.vparams["image_size"]) self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"]) - self.gguf_writer.add_vision_clip_architecture("llava") + self.gguf_writer.add_vision_clip_architecture(gguf.MODEL_ARCH_NAMES[self.vision_arch]) self.gguf_writer.add_vision_clip_block_count(self.vparams["num_hidden_layers"]) self.gguf_writer.add_vision_clip_embedding_length(self.vparams["hidden_size"]) self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"]) self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"]) self.gguf_writer.add_vision_clip_image_mean(self.preprocessor_config["image_mean"]) self.gguf_writer.add_vision_clip_image_std(self.preprocessor_config["image_std"]) - self.gguf_writer.add_vision_clip_select_layer(self.hparams["vision_feature_layer"]) self.gguf_writer.add_vision_clip_patch_merge_type(gguf.CLIPPatchMergeType.FLAT) max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1 self.gguf_writer.add_vision_clip_max_position_embeddings(max_pos_embd) + if "vision_feature_layer" in self.hparams: + self.gguf_writer.add_vision_clip_select_layer(self.hparams["vision_feature_layer"]) + elif "mm_vision_select_layer" in self.hparams: + self.gguf_writer.add_vision_clip_select_layer(self.hparams["mm_vision_select_layer"]) + else: + raise ValueError("gguf: can not find vision_feature_layer parameter.") # TODO: should not hardcode these, but they are currently missing from config.json - self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP) + if self.vision_arch == gguf.MODEL_ARCH.VISION_LLAVA: + self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP) + if self.vision_arch == gguf.MODEL_ARCH.VISION_MOBILEVLM: + 
self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.LDPV2) self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05) def set_gguf_parameters(self): @@ -1683,6 +1709,8 @@ class LlamaModel(Model): # For vision model if name.startswith("language_model"): name = name.replace("language_model.", "") + else: + name = name.replace("model.vision_tower.", "") if "post_layernorm" in name: return [] # skip post_layernorm @@ -2101,7 +2129,7 @@ class DbrxModel(Model): return n_dims > 1 -@Model.register("MiniCPMForCausalLM") +@Model.register("MiniCPMForCausalLM", "MiniCPMV") class MiniCPMModel(Model): model_arch = gguf.MODEL_ARCH.MINICPM diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 411c89e7f..7007ecfd8 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -308,7 +308,8 @@ class MODEL_ARCH(IntEnum): CHAMELEON = auto() WAVTOKENIZER_DEC = auto() # vision models - LLAVA_VISION = auto() + VISION_LLAVA = auto() + VISION_MOBILEVLM = auto() class MODEL_TENSOR(IntEnum): @@ -439,6 +440,8 @@ class MODEL_TENSOR(IntEnum): POSNET_ATTN_OUT = auto() # vision V_MMPROJ = auto() + V_MMPROJ_MLP = auto() + V_MMPROJ_PEG = auto() V_ENC_EMBD_CLS = auto() V_ENC_EMBD_PATCH = auto() V_ENC_EMBD_POS = auto() @@ -512,6 +515,9 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.GRANITE_MOE: "granitemoe", MODEL_ARCH.CHAMELEON: "chameleon", MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", + # vision + MODEL_ARCH.VISION_LLAVA: "llava", + MODEL_ARCH.VISION_MOBILEVLM: "mobilevlm", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -641,6 +647,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output", # vision MODEL_TENSOR.V_MMPROJ: "v.mmproj_{bid}", + MODEL_TENSOR.V_MMPROJ_MLP: "v.mmproj.mlp.{bid}", + MODEL_TENSOR.V_MMPROJ_PEG: "v.mmproj.peg.{bid}", MODEL_TENSOR.V_ENC_EMBD_CLS: "v.enc.embd.cls", MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.enc.embd.patch", MODEL_TENSOR.V_ENC_EMBD_POS: "v.enc.embd.pos", @@ -1595,7 +1603,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.POSNET_ATTN_V, MODEL_TENSOR.POSNET_ATTN_OUT, ], - MODEL_ARCH.LLAVA_VISION: [ + MODEL_ARCH.VISION_LLAVA: [ MODEL_TENSOR.V_MMPROJ, MODEL_TENSOR.V_ENC_EMBD_CLS, MODEL_TENSOR.V_ENC_EMBD_PATCH, @@ -1611,6 +1619,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.V_PRE_NORM, MODEL_TENSOR.V_POST_NORM, ], + MODEL_ARCH.VISION_MOBILEVLM: [ + MODEL_TENSOR.V_MMPROJ_MLP, + MODEL_TENSOR.V_MMPROJ_PEG, + MODEL_TENSOR.V_ENC_EMBD_CLS, + MODEL_TENSOR.V_ENC_EMBD_PATCH, + MODEL_TENSOR.V_ENC_EMBD_POS, + MODEL_TENSOR.V_ENC_ATTN_Q, + MODEL_TENSOR.V_ENC_ATTN_K, + MODEL_TENSOR.V_ENC_ATTN_V, + MODEL_TENSOR.V_ENC_INPUT_NORM, + MODEL_TENSOR.V_ENC_OUTPUT, + MODEL_TENSOR.V_ENC_OUTPUT_NORM, + MODEL_TENSOR.V_ENC_FFN_UP, + MODEL_TENSOR.V_ENC_FFN_DOWN, + MODEL_TENSOR.V_PRE_NORM, + MODEL_TENSOR.V_POST_NORM, + ], # TODO } @@ -1693,11 +1718,12 @@ class PoolingType(IntEnum): class CLIPProjectorType(Enum): - MLP = 'mlp' + MLP = 'mlp' + LDPV2 = 'ldpv2' class CLIPPatchMergeType(Enum): - FLAT = 'flat' + FLAT = 'flat' SPATIAL_UNPAD = 'spatial_unpad' diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 5438acd06..4b9a0c966 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -876,7 +876,7 @@ class GGUFWriter: def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None: self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap) - + def add_vision_type(self, value: str) -> None: 
self.add_string(Keys.Vision.TYPE, value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 813f8f7e0..f7ff9a032 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -794,6 +794,14 @@ class TensorNameMap: "multi_modal_projector.linear_{bid}", ), + MODEL_TENSOR.V_MMPROJ_MLP: ( + "model.mm_projector.mlp.mlp.{bid}", + ), + + MODEL_TENSOR.V_MMPROJ_PEG: ( + "model.mm_projector.peg.peg.{bid}", + ), + MODEL_TENSOR.V_ENC_EMBD_CLS: ( "vision_tower.vision_model.embeddings.class_embedding", ), diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index dcfbdab3e..b474e0750 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -67,6 +67,7 @@ static const std::map LLM_ARCH_NAMES = { static const std::map VISION_ARCH_NAMES = { { VISION_ARCH_LLAVA, "llava" }, + { VISION_ARCH_MOBILEVLM, "mobilevlm" }, { VISION_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1345,7 +1346,27 @@ static const std::map> VISION { VISION_TENSOR_PRE_NORM, "v.pre_norm" }, { VISION_TENSOR_POST_NORM, "v.post_norm" }, } - } + }, + { + VISION_ARCH_MOBILEVLM, + { + { VISION_TENSOR_MMPROJ_MLP, "v.mmproj.mlp.%d" }, + { VISION_TENSOR_MMPROJ_PEG, "v.mmproj.peg.%d" }, + { VISION_TENSOR_ENC_EMBD_CLS, "v.enc.embd.cls" }, + { VISION_TENSOR_ENC_EMBD_PATCH, "v.enc.embd.patch" }, + { VISION_TENSOR_ENC_EMBD_POS, "v.enc.embd.pos" }, + { VISION_TENSOR_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" }, + { VISION_TENSOR_ENC_ATTN_K, "v.enc.blk.%d.attn_k" }, + { VISION_TENSOR_ENC_ATTN_V, "v.enc.blk.%d.attn_v" }, + { VISION_TENSOR_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" }, + { VISION_TENSOR_ENC_OUTPUT, "v.enc.blk.%d.output" }, + { VISION_TENSOR_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" }, + { VISION_TENSOR_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" }, + { VISION_TENSOR_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" }, + { VISION_TENSOR_PRE_NORM, "v.pre_norm" }, + { VISION_TENSOR_POST_NORM, "v.post_norm" }, + } + }, }; static const std::map LLM_TENSOR_INFOS = { @@ -1499,6 +1520,10 @@ std::string LLM_KV::operator()(llm_kv kv) const { template<> std::string BASE_TN_IMPL::str() const { + if (LLM_TENSOR_NAMES.find(arch) == LLM_TENSOR_NAMES.end()) { + throw std::runtime_error(format("Cannot find tensor name mapping for arch %d", arch)); + } + if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { return "__missing__"; } @@ -1515,6 +1540,10 @@ std::string BASE_TN_IMPL::str() const { template<> std::string BASE_TN_IMPL::str() const { + if (VISION_TENSOR_NAMES.find(arch) == VISION_TENSOR_NAMES.end()) { + throw std::runtime_error(format("Cannot find tensor name mapping for arch %d", arch)); + } + if (VISION_TENSOR_NAMES.at(arch).find(tensor) == VISION_TENSOR_NAMES.at(arch).end()) { return "__missing__"; } diff --git a/src/llama-arch.h b/src/llama-arch.h index ce89b15f5..87966b11f 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -72,6 +72,7 @@ enum llm_arch { enum vision_arch { VISION_ARCH_UNKNOWN, VISION_ARCH_LLAVA, + VISION_ARCH_MOBILEVLM, }; enum llm_kv { @@ -356,6 +357,8 @@ enum llm_tensor { enum vision_tensor { VISION_TENSOR_MMPROJ, + VISION_TENSOR_MMPROJ_MLP, + VISION_TENSOR_MMPROJ_PEG, VISION_TENSOR_ENC_EMBD_CLS, VISION_TENSOR_ENC_EMBD_PATCH, VISION_TENSOR_ENC_EMBD_POS, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 42cc230ce..cd669744f 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1280,6 +1280,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { std::string arch; ml.get_key(LLM_KV_VISION_CLIP_ARCHITECTURE, arch, true); vparams.arch = 
vision_arch_from_string(arch); + if (vparams.arch == VISION_ARCH_UNKNOWN) { + throw std::runtime_error(format("unsupported vision arch: %s", arch.c_str())); + } } } else if (!vision_type.empty()) { throw std::runtime_error(format("unsupported vision type: %s", vision_type.c_str())); @@ -1288,6 +1291,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // arch-specific CLIP hparams switch (vparams.arch) { case VISION_ARCH_LLAVA: + case VISION_ARCH_MOBILEVLM: { ml.get_key(LLM_KV_VISION_CLIP_MAX_POS_EMBD, vparams.max_pos_embd, true); } break; @@ -3410,58 +3414,71 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // load tensors for vision model auto & vparams = clip.hparams; if (has_vision) { - const int64_t n_layer = vparams.n_layer; - const int64_t n_embd = vparams.hidden_size; - const int64_t n_ff = vparams.n_intermediate; - const int64_t max_pos_embd = vparams.max_pos_embd; - const int64_t n_channel = 3; // always RGB - const int64_t patch_size = vparams.patch_size; + // language params + const int64_t n_embd = hparams.n_embd; + // vision params + const int64_t n_vlayer = vparams.n_layer; + const int64_t n_vembd = vparams.hidden_size; + const int64_t n_vff = vparams.n_intermediate; + const int64_t max_pos_embd = vparams.max_pos_embd; + const int64_t n_channel = 3; // always RGB + const int64_t patch_size = vparams.patch_size; const auto tn = VISION_TN(vparams.arch); // clip is CPU-only for now clip.buft = ggml_backend_cpu_buffer_type(); ggml_context * ctx_vision = ctx_map.at(clip.buft); - clip.layers.resize(n_layer); + clip.layers.resize(n_vlayer); switch (vparams.arch) { case VISION_ARCH_LLAVA: + case VISION_ARCH_MOBILEVLM: { - clip.mm_1_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ, "weight", 1), {n_embd, n_ff}); - clip.mm_1_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ, "bias" , 1), {n_ff}); - clip.mm_2_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ, "weight", 2), {n_ff, n_ff}); - clip.mm_2_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ, "bias" , 2), {n_ff}); + if (vparams.arch == VISION_ARCH_LLAVA) { + clip.mm_1_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ, "weight", 1), {n_vembd, n_vff}); + clip.mm_1_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ, "bias" , 1), {n_vff}); + clip.mm_2_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ, "weight", 2), {n_vff, n_vff}); + clip.mm_2_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ, "bias" , 2), {n_vff}); + } else if (vparams.arch == VISION_ARCH_MOBILEVLM) { + clip.mm_model_mlp_0_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ_MLP, "weight", 0), {n_vembd, n_embd}); + clip.mm_model_mlp_0_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ_MLP, "bias", 0), {n_embd}); + clip.mm_model_mlp_2_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ_MLP, "weight", 2), {n_embd, n_embd}); + clip.mm_model_mlp_2_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ_MLP, "bias", 2), {n_embd}); + clip.mm_model_peg_0_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ_PEG, "weight", 0), {n_channel, n_channel, 1, n_embd}); + clip.mm_model_peg_0_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ_PEG, "bias", 0), {n_embd}); + } - clip.class_embedding = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_CLS ), {n_embd}); - clip.patch_embeddings = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_embd}); - clip.position_embeddings = ml.create_tensor(ctx_vision, 
tn(VISION_TENSOR_ENC_EMBD_POS, "weight"), {n_embd, max_pos_embd}); + clip.class_embedding = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_CLS ), {n_vembd}); + clip.patch_embeddings = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd}); + clip.position_embeddings = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd}); - clip.pre_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "weight"), {n_embd}); - clip.pre_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "bias" ), {n_embd}); - clip.post_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "weight"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - clip.post_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "bias" ), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + clip.pre_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "weight"), {n_vembd}); + clip.pre_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "bias" ), {n_vembd}); + clip.post_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "weight"), {n_vembd}, llama_model_loader::TENSOR_NOT_REQUIRED); + clip.post_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "bias" ), {n_vembd}, llama_model_loader::TENSOR_NOT_REQUIRED); - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_vlayer; ++i) { auto & layer = clip.layers[i]; - layer.k_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd}); - layer.k_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_K, "bias" , i), {n_embd}); - layer.v_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd}); - layer.v_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_V, "bias" , i), {n_embd}); - layer.q_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.q_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_Q, "bias" , i), {n_embd}); + layer.k_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_K, "weight", i), {n_vembd, n_vembd}); + layer.k_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_K, "bias" , i), {n_vembd}); + layer.v_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_V, "weight", i), {n_vembd, n_vembd}); + layer.v_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_V, "bias" , i), {n_vembd}); + layer.q_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_Q, "weight", i), {n_vembd, n_vembd}); + layer.q_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_Q, "bias" , i), {n_vembd}); - layer.ffn_up_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_FFN_UP, "bias" , i), {n_ff}); - layer.ffn_down_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_FFN_DOWN, "bias" , i), {n_embd}); + layer.ffn_up_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_FFN_UP, "weight", i), {n_vembd, n_vff}); + layer.ffn_up_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_FFN_UP, "bias" , i), {n_vff}); + layer.ffn_down_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_FFN_DOWN, "weight", i), {n_vff, n_vembd}); + layer.ffn_down_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_FFN_DOWN, "bias" , 
i), {n_vembd}); - layer.norm_in_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_INPUT_NORM, "weight", i), {n_embd}); - layer.norm_in_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_INPUT_NORM, "bias" , i), {n_embd}); - layer.norm_out_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_OUTPUT_NORM, "weight", i), {n_embd}); - layer.norm_out_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_OUTPUT_NORM, "bias" , i), {n_embd}); + layer.norm_in_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_INPUT_NORM, "weight", i), {n_vembd}); + layer.norm_in_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_INPUT_NORM, "bias" , i), {n_vembd}); + layer.norm_out_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_OUTPUT_NORM, "weight", i), {n_vembd}); + layer.norm_out_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_OUTPUT_NORM, "bias" , i), {n_vembd}); - layer.output_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_OUTPUT, "weight", i), {n_embd, n_embd}); - layer.output_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_OUTPUT, "bias" , i), {n_embd}); + layer.output_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_OUTPUT, "weight", i), {n_vembd, n_vembd}); + layer.output_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_OUTPUT, "bias" , i), {n_vembd}); } } break; default: diff --git a/src/llama-vision.cpp b/src/llama-vision.cpp index b419627e6..9b78ec1a6 100644 --- a/src/llama-vision.cpp +++ b/src/llama-vision.cpp @@ -58,8 +58,11 @@ static int clip_n_patches(const clip_context & ctx) { } uint32_t clip_n_mmproj_embd(const clip_vision_model & clip_model) { - if (clip_model.hparams.proj_type == CLIP_PROJECTOR_TYPE_MLP) { + auto & proj_type = clip_model.hparams.proj_type; + if (proj_type == CLIP_PROJECTOR_TYPE_MLP) { return clip_model.mm_2_b->ne[0]; + } else if (proj_type == CLIP_PROJECTOR_TYPE_LDPV2) { + return clip_model.mm_model_peg_0_b->ne[0]; } else { GGML_ASSERT(false && "invalid proj type"); } @@ -559,6 +562,30 @@ static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size, embeddings = ggml_gelu(ctx0, embeddings); embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); + + } else if (hparams.proj_type == CLIP_PROJECTOR_TYPE_LDPV2) { + int n_patch = 24; + struct ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings); + mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b); + mlp_0 = ggml_gelu(ctx0, mlp_0); + struct ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0); + mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b); + // mlp_2 ne = [2048, 576, 1, 1] + // // AVG Pool Layer 2*2, strides = 2 + mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 0, 2, 3)); + // mlp_2 ne = [576, 2048, 1, 1] + mlp_2 = ggml_reshape_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]); + // mlp_2 ne [24, 24, 2048, 1] + mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0); + // weight ne = [3, 3, 2048, 1] + struct ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1); + peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3)); + peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b); + mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3)); + peg_0 = ggml_add(ctx0, peg_0, mlp_2); + peg_0 = ggml_reshape_3d(ctx0, peg_0, peg_0->ne[0], peg_0->ne[1] * peg_0->ne[2], peg_0->ne[3]); + embeddings = peg_0; + } else { GGML_ASSERT(false && "unsupported proj type"); } diff --git 
a/src/llama-vision.h b/src/llama-vision.h index ced58dd0b..5401cb51a 100644 --- a/src/llama-vision.h +++ b/src/llama-vision.h @@ -10,6 +10,7 @@ enum clip_projector_type { CLIP_PROJECTOR_TYPE_UNKNOWN, CLIP_PROJECTOR_TYPE_MLP, + CLIP_PROJECTOR_TYPE_LDPV2, }; enum mm_patch_merge { @@ -98,6 +99,14 @@ struct clip_vision_model { struct ggml_tensor * mm_2_w = nullptr; struct ggml_tensor * mm_2_b = nullptr; + // MobileVLM_V2 projection + struct ggml_tensor * mm_model_mlp_0_w = nullptr; + struct ggml_tensor * mm_model_mlp_0_b = nullptr; + struct ggml_tensor * mm_model_mlp_2_w = nullptr; + struct ggml_tensor * mm_model_mlp_2_b = nullptr; + struct ggml_tensor * mm_model_peg_0_w = nullptr; + struct ggml_tensor * mm_model_peg_0_b = nullptr; + struct ggml_tensor * image_newline = nullptr; }; @@ -138,6 +147,8 @@ inline mm_patch_merge mm_patch_merge_from_name(std::string & name) { inline clip_projector_type clip_projector_type_from_name(std::string & name) { if (name == "mlp") { return CLIP_PROJECTOR_TYPE_MLP; + } else if (name == "ldpv2") { + return CLIP_PROJECTOR_TYPE_LDPV2; } return CLIP_PROJECTOR_TYPE_UNKNOWN; }
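
For reference, the LDPv2 projector path added in clip_image_build_graph above can be summarized by the following PyTorch-style sketch. It is illustrative only: the n_patch = 24 grid, the 2x2 average pool, and the depthwise 3x3 PEG convolution with residual follow the comments in the ggml graph, while the function name and the (out, in) / (out, 1, kh, kw) weight layouts are assumptions of PyTorch convention, not part of the patch.

    import torch
    import torch.nn.functional as F

    def ldpv2_project(patches, mlp_0_w, mlp_0_b, mlp_2_w, mlp_2_b, peg_w, peg_b, n_patch=24):
        """Illustrative sketch of the MobileVLM_V2 LDPv2 projector.

        patches:  [n_patch*n_patch, n_vembd] CLIP patch embeddings (CLS token dropped)
        mlp_0_w:  [n_embd, n_vembd], mlp_2_w: [n_embd, n_embd]  (PyTorch (out, in) layout)
        peg_w:    [n_embd, 1, 3, 3] depthwise conv kernel, peg_b: [n_embd]
        """
        # two-layer MLP with GELU, projecting to the language model width n_embd
        x = F.gelu(patches @ mlp_0_w.T + mlp_0_b)
        x = x @ mlp_2_w.T + mlp_2_b                       # [576, n_embd]
        n_embd = x.shape[-1]

        # back to a 2D patch grid, then 2x2 average pool: 24x24 -> 12x12 tokens
        x = x.T.reshape(1, n_embd, n_patch, n_patch)
        x = F.avg_pool2d(x, kernel_size=2, stride=2)      # [1, n_embd, 12, 12]

        # PEG: 3x3 depthwise conv + bias, added back as a residual
        x = x + F.conv2d(x, peg_w, peg_b, stride=1, padding=1, groups=n_embd)

        # flatten the spatial grid into the projected image tokens
        return x.flatten(2).transpose(1, 2).squeeze(0)    # [144, n_embd]

In other words, the 24x24 grid of 576 CLIP patch embeddings is projected to the language-model embedding width, pooled down to 12x12 = 144 image tokens, and refined by the PEG block, which is why clip_n_mmproj_embd reports mm_model_peg_0_b->ne[0] for the LDPV2 projector type.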