wip minicpmv

Xuan Son Nguyen 2025-01-19 22:33:05 +01:00
parent d0068ef0ed
commit 4a7ab89d75
9 changed files with 491 additions and 77 deletions

View File

@@ -17,7 +17,7 @@ from hashlib import sha256
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
from itertools import chain
from transformers import AutoConfig, AutoImageProcessor
from transformers import AutoConfig
import math
import numpy as np
import torch
@@ -134,6 +134,16 @@ class Model:
return None
raise KeyError(f"could not find any of: {keys}")
def find_vparams(self, keys: Iterable[str], optional: bool = False) -> Any:
if self.vparams is None:
raise ValueError("vision model parameters not set")
key = next((k for k in keys if k in self.vparams), None)
if key is not None:
return self.vparams[key]
if optional:
return None
raise KeyError(f"(vision) could not find any of: {keys}")
def set_vocab(self):
self._set_vocab_gpt2()
@@ -269,6 +279,20 @@ class Model:
self.gguf_writer.add_key_length(head_dim)
self.gguf_writer.add_value_length(head_dim)
# Vision model parameters
if self.vparams is not None and self.preprocessor_config is not None and self.vision_arch is not None:
self.gguf_writer.add_vision_type("clip-vit")
self.gguf_writer.add_vision_image_size(self.vparams["image_size"])
self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"])
self.gguf_writer.add_vision_clip_architecture(gguf.MODEL_ARCH_NAMES[self.vision_arch])
self.gguf_writer.add_vision_clip_block_count(self.vparams["num_hidden_layers"])
self.gguf_writer.add_vision_clip_embedding_length(self.vparams["hidden_size"])
self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"])
self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"])
self.gguf_writer.add_vision_clip_image_mean(self.preprocessor_config["image_mean"])
self.gguf_writer.add_vision_clip_image_std(self.preprocessor_config["image_std"])
self.gguf_writer.add_vision_clip_select_layer(self.find_hparam(["vision_feature_layer", "mm_vision_select_layer"]))
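For orientation, a minimal sketch of the HF-side fields this block consumes; the numeric values are illustrative only, not taken from any specific checkpoint:

vparams = {                      # config.json -> "vision_config"
    "image_size": 448,
    "patch_size": 14,
    "num_hidden_layers": 27,
    "hidden_size": 1152,
    "intermediate_size": 4304,
    "num_attention_heads": 16,
}
preprocessor_config = {          # preprocessor_config.json
    "image_mean": [0.5, 0.5, 0.5],
    "image_std": [0.5, 0.5, 0.5],
}
# each entry above feeds one of the add_vision_* calls in the block above;
# the select layer comes from vision_feature_layer / mm_vision_select_layer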
self.gguf_writer.add_file_type(self.ftype)
logger.info(f"gguf: file type = {self.ftype}")
@@ -488,17 +512,14 @@ class Model:
return hparams
@staticmethod
def load_preprocessor_config(dir_or_model_id: Path | str):
def load_preprocessor_config(dir_model: Path):
# TODO: this varies vastly among models, need to handle more cases in the future
if isinstance(dir_or_model_id, Path):
file_path = dir_model / "preprocessor_config.json"
file_path = dir_or_model_id / "preprocessor_config.json"
if os.path.exists(file_path):
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
else:
raise Exception(f"Preprocessor config not found at {file_path}")
else:
return AutoImageProcessor.from_pretrained(dir_or_model_id).to_dict()
@classmethod
def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
@@ -551,7 +572,9 @@ class Model:
toktypes: list[int] = []
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
# DEBIAN_FRONTEND=noninteractive means that the script is running in a non-interactive environment (i.e. CI), so we cannot answer Y/N when it asks for user input
is_cli_non_interactive = os.environ.get("DEBIAN_FRONTEND", "") == "noninteractive"
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=is_cli_non_interactive)
vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
assert max(tokenizer.vocab.values()) < vocab_size
@@ -1607,9 +1630,10 @@ class LlamaModel(Model):
# only tested with https://huggingface.co/mtgv/MobileVLM_V2-1.7B
if "mm_vision_tower" in self.hparams and model_type == "mobilevlm":
from transformers import AutoImageProcessor
vision_model_id = self.hparams["mm_vision_tower"]
self.vparams = AutoConfig.from_pretrained(vision_model_id).to_dict()["vision_config"]
self.preprocessor_config = self.load_preprocessor_config(vision_model_id)
self.preprocessor_config = AutoImageProcessor.from_pretrained(vision_model_id).to_dict()
self.vision_arch = gguf.MODEL_ARCH.VISION_MOBILEVLM
if self.vparams is not None and self.vision_arch is not None:
@@ -1648,34 +1672,6 @@ class LlamaModel(Model):
if self.hparams.get("vocab_size", 32000) == 49152:
self.gguf_writer.add_add_bos_token(False)
# For vision model
if self.vparams is not None and self.preprocessor_config is not None and self.vision_arch is not None:
self.gguf_writer.add_vision_type("clip-vit")
self.gguf_writer.add_vision_image_size(self.vparams["image_size"])
self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"])
self.gguf_writer.add_vision_clip_architecture(gguf.MODEL_ARCH_NAMES[self.vision_arch])
self.gguf_writer.add_vision_clip_block_count(self.vparams["num_hidden_layers"])
self.gguf_writer.add_vision_clip_embedding_length(self.vparams["hidden_size"])
self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"])
self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"])
self.gguf_writer.add_vision_clip_image_mean(self.preprocessor_config["image_mean"])
self.gguf_writer.add_vision_clip_image_std(self.preprocessor_config["image_std"])
self.gguf_writer.add_vision_clip_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1
self.gguf_writer.add_vision_clip_max_position_embeddings(max_pos_embd)
if "vision_feature_layer" in self.hparams:
self.gguf_writer.add_vision_clip_select_layer(self.hparams["vision_feature_layer"])
elif "mm_vision_select_layer" in self.hparams:
self.gguf_writer.add_vision_clip_select_layer(self.hparams["mm_vision_select_layer"])
else:
raise ValueError("gguf: can not find vision_feature_layer parameter.")
# TODO: should not hardcode these, but they are currently missing from config.json
if self.vision_arch == gguf.MODEL_ARCH.VISION_LLAVA:
self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP)
if self.vision_arch == gguf.MODEL_ARCH.VISION_MOBILEVLM:
self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.LDPV2)
self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05)
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
@@ -1692,6 +1688,18 @@ class LlamaModel(Model):
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
# For vision model
if self.vparams is not None:
self.gguf_writer.add_vision_clip_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
# TODO: should not hardcode these, but they are currently missing from config.json
if self.vision_arch == gguf.MODEL_ARCH.VISION_LLAVA:
self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP)
if self.vision_arch == gguf.MODEL_ARCH.VISION_MOBILEVLM:
self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.LDPV2)
self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05)
max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1
self.gguf_writer.add_vision_clip_max_position_embeddings(max_pos_embd)
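As a concrete check of the position-count formula above, with illustrative CLIP ViT-L/14-336 sizes: a 336x336 image in 14x14 patches gives 576 patch positions, plus one for the class-token slot used by CLIP-style encoders.

image_size, patch_size = 336, 14                  # illustrative CLIP ViT-L/14-336 values
max_pos_embd = (image_size // patch_size) ** 2 + 1
assert max_pos_embd == 577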
@staticmethod
def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
if n_head_kv is not None and n_head != n_head_kv:
@@ -2132,16 +2140,50 @@ class DbrxModel(Model):
@Model.register("MiniCPMForCausalLM", "MiniCPMV")
class MiniCPMModel(Model):
model_arch = gguf.MODEL_ARCH.MINICPM
proj_type: gguf.constants.CLIPProjectorType | None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
model_type = self.hparams.get("model_type", None)
# only tested with https://huggingface.co/openbmb/MiniCPM-V-2_6
if "vision_config" in self.hparams and model_type == "minicpmv":
self.vparams = self.hparams["vision_config"]
self.preprocessor_config = self.load_preprocessor_config(self.dir_model)
self.vision_arch = gguf.MODEL_ARCH.VISION_MINICPMV
version = str(self.hparams.get("version", "unknown"))
if version == "2.5":
self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_5
elif version == "2.6":
self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_6
else:
raise ValueError(f"Unsupported MiniCPM-V version: {version}")
if self.vparams is not None and self.vision_arch is not None and self.preprocessor_config is not None:
self.preprocessor_config["image_mean"] = [0.5, 0.5, 0.5]
self.preprocessor_config["image_std"] = [0.5, 0.5, 0.5]
self.hparams["vision_feature_layer"] = 0
self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"])
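Note that the version field in the HF config may be stored as a number rather than a string; the str() call above normalizes both before the comparison. A trivial illustration:

for raw in (2.6, "2.6"):          # a float and a string normalize the same way
    assert str(raw) == "2.6"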
def set_gguf_parameters(self):
super().set_gguf_parameters()
embedding_scale = float(self.hparams["scale_emb"])
# scale_emb
embedding_scale = float(self.hparams.get("scale_emb", 1.0))
self.gguf_writer.add_embedding_scale(embedding_scale)
logger.info(f"gguf: (minicpm) embedding_scale = {embedding_scale}")
# scale_depth
if "scale_depth" in self.hparams:
residual_scale = self.hparams["scale_depth"] / self.hparams["num_hidden_layers"] ** 0.5
else:
residual_scale = 1.0
self.gguf_writer.add_residual_scale(residual_scale)
logger.info(f"gguf: (minicpm) residual_scale = {residual_scale}")
# logit_scale
if "dim_model_base" in self.hparams:
logit_scale = self.hparams["hidden_size"] / self.hparams["dim_model_base"]
else:
logit_scale = 1.0
self.gguf_writer.add_logit_scale(logit_scale)
logger.info(f"gguf: (minicpm) logit_scale = {logit_scale}")
if self.hparams.get("rope_scaling") is not None: if self.hparams.get("rope_scaling") is not None:
@@ -2149,6 +2191,15 @@ class MiniCPMModel(Model):
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LONGROPE)
logger.info(f"gguf: (minicpm) rope_scaling_type = {gguf.RopeScalingType.LONGROPE}")
# For vision model
if self.vparams is not None and self.proj_type is not None:
self.gguf_writer.add_vision_clip_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
self.gguf_writer.add_vision_clip_projector_type(self.proj_type)
self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-06)
max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2
self.gguf_writer.add_vision_clip_max_position_embeddings(max_pos_embd)
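Unlike the CLIP/LLaVA path earlier in this diff, there is no +1 here: the SigLIP-style encoder used by MiniCPM-V has no class token, so the position count is just the patch grid (illustrative sizes):

image_size, patch_size = 448, 14
assert (image_size // patch_size) ** 2 == 1024    # no extra CLS position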
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
rope_dims = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
@@ -2167,18 +2218,33 @@ class MiniCPMModel(Model):
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))
def set_vocab(self):
if self.vision_arch == gguf.MODEL_ARCH.VISION_MINICPMV:
# undocumented anywhere, I only found this thanks to https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf
self._set_vocab_gpt2()
else:
self._set_vocab_sentencepiece()
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
# For vision model
if name.startswith("llm."):
name = name.replace("llm.", "")
# attention: someone messed up and used an underscore instead of a dot
if name.endswith("in_proj_weight"):
name = name.replace("_weight", ".weight")
if name.endswith("in_proj_bias"):
name = name.replace("_bias", ".bias")
if "post_layernorm" in name:
return [] # skip post_layernorm
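A small, self-contained illustration of the renames above, using tensor names of the shape found in MiniCPM-V checkpoints (the concrete example names are assumptions for illustration):

def fixup(name: str) -> str:
    # mirror of the string fix-ups in modify_tensors() above
    if name.startswith("llm."):
        name = name.replace("llm.", "")
    if name.endswith("in_proj_weight"):
        name = name.replace("_weight", ".weight")
    if name.endswith("in_proj_bias"):
        name = name.replace("_bias", ".bias")
    return name

assert fixup("llm.model.layers.0.self_attn.q_proj.weight") == "model.layers.0.self_attn.q_proj.weight"
assert fixup("resampler.attn.in_proj_weight") == "resampler.attn.in_proj.weight"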
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
# HF models permute some of the tensors, so we need to undo that
if name.endswith(("q_proj.weight")): if not name.startswith("vpm") and name.endswith(("q_proj.weight")):
data_torch = LlamaModel.permute(data_torch, n_head, n_head) data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight")): if not name.startswith("vpm") and name.endswith(("k_proj.weight")):
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
return [(self.map_tensor_name(name), data_torch)] return [(self.map_tensor_name(name), data_torch)]
@@ -5064,7 +5130,7 @@ class LazyTorchTensor(gguf.LazyBase):
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Convert a huggingface model to a GGML compatible file") description="Convert a huggingface model to a GGML compatible file\n\nNote: When converting vision models, this script may use internet connection to download configuration files via Hugging Face.")
parser.add_argument(
"--vocab-only", action="store_true",
help="extract only the vocab",

View File

@@ -310,6 +310,7 @@ class MODEL_ARCH(IntEnum):
# vision models
VISION_LLAVA = auto()
VISION_MOBILEVLM = auto()
VISION_MINICPMV = auto()
class MODEL_TENSOR(IntEnum):
@@ -455,6 +456,15 @@ class MODEL_TENSOR(IntEnum):
V_ENC_FFN_DOWN = auto()
V_PRE_NORM = auto()
V_POST_NORM = auto()
V_RESMPL_POS_EMBD_K = auto() # minicpmv
V_RESMPL_ATTN_IN = auto() # minicpmv
V_RESMPL_ATTN_OUT = auto() # minicpmv
V_RESMPL_KV_PROJ = auto() # minicpmv
V_RESMPL_NORM_POST = auto() # minicpmv
V_RESMPL_NORM_KV = auto() # minicpmv
V_RESMPL_NORM_Q = auto() # minicpmv
V_RESMPL_PROJ = auto() # minicpmv
V_RESMPL_QUERY = auto() # minicpmv
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -518,6 +528,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
# vision
MODEL_ARCH.VISION_LLAVA: "llava",
MODEL_ARCH.VISION_MOBILEVLM: "mobilevlm",
MODEL_ARCH.VISION_MINICPMV: "minicpmv",
}
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -662,6 +673,15 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.V_ENC_FFN_DOWN: "v.enc.blk.{bid}.ffn_down",
MODEL_TENSOR.V_PRE_NORM: "v.pre_norm",
MODEL_TENSOR.V_POST_NORM: "v.post_norm",
MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "v.resmpl.pos_embd_k",
MODEL_TENSOR.V_RESMPL_ATTN_IN: "v.resmpl.attn_in",
MODEL_TENSOR.V_RESMPL_ATTN_OUT: "v.resmpl.attn_out",
MODEL_TENSOR.V_RESMPL_KV_PROJ: "v.resmpl.kv_proj",
MODEL_TENSOR.V_RESMPL_NORM_POST: "v.resmpl.norm_post",
MODEL_TENSOR.V_RESMPL_NORM_KV: "v.resmpl.norm_kv",
MODEL_TENSOR.V_RESMPL_NORM_Q: "v.resmpl.norm_q",
MODEL_TENSOR.V_RESMPL_PROJ: "v.resmpl.proj",
MODEL_TENSOR.V_RESMPL_QUERY: "v.resmpl.query",
}
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -1636,6 +1656,26 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.V_PRE_NORM,
MODEL_TENSOR.V_POST_NORM,
],
MODEL_ARCH.VISION_MINICPMV: [
MODEL_TENSOR.V_ENC_EMBD_PATCH,
MODEL_TENSOR.V_ENC_EMBD_POS,
MODEL_TENSOR.V_ENC_ATTN_Q,
MODEL_TENSOR.V_ENC_ATTN_K,
MODEL_TENSOR.V_ENC_ATTN_V,
MODEL_TENSOR.V_ENC_INPUT_NORM,
MODEL_TENSOR.V_ENC_OUTPUT,
MODEL_TENSOR.V_ENC_OUTPUT_NORM,
MODEL_TENSOR.V_ENC_FFN_UP,
MODEL_TENSOR.V_ENC_FFN_DOWN,
MODEL_TENSOR.V_RESMPL_ATTN_IN,
MODEL_TENSOR.V_RESMPL_ATTN_OUT,
MODEL_TENSOR.V_RESMPL_KV_PROJ,
MODEL_TENSOR.V_RESMPL_NORM_POST,
MODEL_TENSOR.V_RESMPL_NORM_KV,
MODEL_TENSOR.V_RESMPL_NORM_Q,
MODEL_TENSOR.V_RESMPL_PROJ,
MODEL_TENSOR.V_RESMPL_QUERY,
],
# TODO
}
@@ -1720,6 +1760,8 @@ class PoolingType(IntEnum):
class CLIPProjectorType(Enum):
MLP = 'mlp'
LDPV2 = 'ldpv2'
MINICPMV_2_5 = 'minicpmv-2.5' # resampler
MINICPMV_2_6 = 'minicpmv-2.6' # resampler
class CLIPPatchMergeType(Enum):

View File

@@ -808,42 +808,52 @@ class TensorNameMap:
MODEL_TENSOR.V_ENC_EMBD_PATCH: (
"vision_tower.vision_model.embeddings.patch_embedding",
"vpm.embeddings.patch_embedding",
),
MODEL_TENSOR.V_ENC_EMBD_POS: (
"vision_tower.vision_model.embeddings.position_embedding",
"vpm.embeddings.position_embedding",
),
MODEL_TENSOR.V_ENC_ATTN_Q: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj",
"vpm.encoder.layers.{bid}.self_attn.q_proj",
),
MODEL_TENSOR.V_ENC_ATTN_K: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj",
"vpm.encoder.layers.{bid}.self_attn.k_proj",
),
MODEL_TENSOR.V_ENC_ATTN_V: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj",
"vpm.encoder.layers.{bid}.self_attn.v_proj",
),
MODEL_TENSOR.V_ENC_INPUT_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm1",
"vpm.encoder.layers.{bid}.layer_norm1",
),
MODEL_TENSOR.V_ENC_OUTPUT: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
"vpm.encoder.layers.{bid}.self_attn.out_proj",
),
MODEL_TENSOR.V_ENC_OUTPUT_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
"vpm.encoder.layers.{bid}.layer_norm2",
),
MODEL_TENSOR.V_ENC_FFN_UP: (
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
"vpm.encoder.layers.{bid}.mlp.fc1",
),
MODEL_TENSOR.V_ENC_FFN_DOWN: (
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2",
"vpm.encoder.layers.{bid}.mlp.fc2",
),
MODEL_TENSOR.V_PRE_NORM: (
@@ -853,6 +863,42 @@ class TensorNameMap:
MODEL_TENSOR.V_POST_NORM: (
"vision_tower.vision_model.post_layernorm",
),
MODEL_TENSOR.V_RESMPL_POS_EMBD_K: (
"resampler.pos_embed_k",
),
MODEL_TENSOR.V_RESMPL_ATTN_IN: (
"resampler.attn.in_proj",
),
MODEL_TENSOR.V_RESMPL_ATTN_OUT: (
"resampler.attn.out_proj",
),
MODEL_TENSOR.V_RESMPL_KV_PROJ: (
"resampler.kv_proj",
),
MODEL_TENSOR.V_RESMPL_NORM_POST: (
"resampler.ln_post",
),
MODEL_TENSOR.V_RESMPL_NORM_KV: (
"resampler.ln_kv",
),
MODEL_TENSOR.V_RESMPL_NORM_Q: (
"resampler.ln_q",
),
MODEL_TENSOR.V_RESMPL_PROJ: (
"resampler.proj",
),
MODEL_TENSOR.V_RESMPL_QUERY: (
"resampler.query",
),
}
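To show how these mappings are consumed during conversion, here is a minimal stand-in for the lookup; the real TensorNameMap handles more cases, so treat the names and flow below as a simplified sketch:

import re

hf_to_gguf = {
    "resampler.ln_q": "v.resmpl.norm_q",
    "vpm.encoder.layers.{bid}.self_attn.q_proj": "v.enc.blk.{bid}.attn_q",
}

def map_name(name: str) -> str:
    base, suffix = name.rsplit(".", 1)                  # strip ".weight" / ".bias"
    m = re.search(r"\.(\d+)\.", base)
    key = re.sub(r"\.(\d+)\.", ".{bid}.", base) if m else base
    out = hf_to_gguf[key]
    return (out.format(bid=m.group(1)) if m else out) + "." + suffix

assert map_name("resampler.ln_q.weight") == "v.resmpl.norm_q.weight"
assert map_name("vpm.encoder.layers.7.self_attn.q_proj.weight") == "v.enc.blk.7.attn_q.weight"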
# architecture-specific block mappings

View File

@@ -3,6 +3,7 @@
#include "llama-impl.h"
#include <map>
#include <exception>
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_LLAMA, "llama" },
@@ -65,12 +66,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
static const std::map<vision_arch, const char *> VISION_ARCH_NAMES = {
{ VISION_ARCH_LLAVA, "llava" },
{ VISION_ARCH_MOBILEVLM, "mobilevlm" },
{ VISION_ARCH_UNKNOWN, "(unknown)" },
};
static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_GENERAL_TYPE, "general.type" },
{ LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" },
@@ -1367,6 +1362,30 @@ static const std::map<vision_arch, std::map<vision_tensor, const char *>> VISION
{ VISION_TENSOR_POST_NORM, "v.post_norm" },
}
},
{
VISION_ARCH_MINICPMV,
{
{ VISION_TENSOR_ENC_EMBD_PATCH, "v.enc.embd.patch" },
{ VISION_TENSOR_ENC_EMBD_POS, "v.enc.embd.pos" },
{ VISION_TENSOR_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" },
{ VISION_TENSOR_ENC_ATTN_K, "v.enc.blk.%d.attn_k" },
{ VISION_TENSOR_ENC_ATTN_V, "v.enc.blk.%d.attn_v" },
{ VISION_TENSOR_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" },
{ VISION_TENSOR_ENC_OUTPUT, "v.enc.blk.%d.output" },
{ VISION_TENSOR_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" },
{ VISION_TENSOR_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" },
{ VISION_TENSOR_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" },
{ VISION_TENSOR_RESMPL_POS_EMBD_K, "v.resmpl.pos_embd_k" },
{ VISION_TENSOR_RESMPL_ATTN_IN, "v.resmpl.attn_in" },
{ VISION_TENSOR_RESMPL_ATTN_OUT, "v.resmpl.attn_out" },
{ VISION_TENSOR_RESMPL_KV_PROJ, "v.resmpl.kv_proj" },
{ VISION_TENSOR_RESMPL_NORM_POST, "v.resmpl.norm_post" },
{ VISION_TENSOR_RESMPL_NORM_KV, "v.resmpl.norm_kv" },
{ VISION_TENSOR_RESMPL_NORM_Q, "v.resmpl.norm_q" },
{ VISION_TENSOR_RESMPL_PROJ, "v.resmpl.proj" },
{ VISION_TENSOR_RESMPL_QUERY, "v.resmpl.query" },
}
},
};
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
@@ -1576,16 +1595,6 @@ llm_arch llm_arch_from_string(const std::string & name) {
return LLM_ARCH_UNKNOWN;
}
vision_arch vision_arch_from_string(const std::string & name) {
for (const auto & kv : VISION_ARCH_NAMES) { // NOLINT
if (kv.second == name) {
return kv.first;
}
}
return VISION_ARCH_UNKNOWN;
}
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
return LLM_TENSOR_INFOS.at(tensor);
}
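One invariant worth keeping in mind here: the printf-style templates in the C++ table above must expand to exactly the strings that the gguf-py TENSOR_NAMES templates produced at conversion time. A quick sanity check for layer 3 (illustrative):

py_side = "v.enc.blk.{bid}.attn_q".format(bid=3) + ".weight"
cpp_side = "v.enc.blk.%d.attn_q" % 3 + ".weight"
assert py_side == cpp_side == "v.enc.blk.3.attn_q.weight"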

View File

@@ -73,6 +73,7 @@ enum vision_arch {
VISION_ARCH_UNKNOWN,
VISION_ARCH_LLAVA,
VISION_ARCH_MOBILEVLM,
VISION_ARCH_MINICPMV,
};
enum llm_kv {
@@ -372,6 +373,16 @@ enum vision_tensor {
VISION_TENSOR_ENC_FFN_DOWN,
VISION_TENSOR_PRE_NORM,
VISION_TENSOR_POST_NORM,
// minicpmv
VISION_TENSOR_RESMPL_POS_EMBD_K,
VISION_TENSOR_RESMPL_ATTN_IN,
VISION_TENSOR_RESMPL_ATTN_OUT,
VISION_TENSOR_RESMPL_KV_PROJ,
VISION_TENSOR_RESMPL_NORM_POST,
VISION_TENSOR_RESMPL_NORM_KV,
VISION_TENSOR_RESMPL_NORM_Q,
VISION_TENSOR_RESMPL_PROJ,
VISION_TENSOR_RESMPL_QUERY,
};
enum llm_tensor_layer {

View File

@@ -96,7 +96,7 @@ struct llama_hparams {
float f_max_alibi_bias = 0.0f;
float f_logit_scale = 0.0f;
// Additional scale factors (Granite/Granite MoE)
// Additional scale factors (Granite/Granite MoE/MiniCPM)
float f_residual_scale = 0.0f;
float f_embedding_scale = 0.0f;
float f_attention_scale = 0.0f;

View File

@@ -2,6 +2,7 @@
#include "llama-impl.h"
#include "llama-mmap.h"
#include "llama-vision.h"
#include "llama-model-loader.h"
#include "ggml-cpp.h"
@@ -1263,6 +1264,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_VISION_CLIP_HEAD_COUNT, vparams.n_head, true);
ml.get_key(LLM_KV_VISION_CLIP_LAYERNORM_EPS, vparams.eps, true);
ml.get_key(LLM_KV_VISION_CLIP_SELECT_LAYER, vparams.select_layer, true);
ml.get_key(LLM_KV_VISION_CLIP_MAX_POS_EMBD, vparams.max_pos_embd, true);
{
std::string name;
ml.get_key(LLM_KV_VISION_CLIP_PROJECTOR_TYPE, name, true);
@@ -1289,14 +1291,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
}
// arch-specific CLIP hparams
switch (vparams.arch) {
case VISION_ARCH_LLAVA:
case VISION_ARCH_MOBILEVLM:
{
ml.get_key(LLM_KV_VISION_CLIP_MAX_POS_EMBD, vparams.max_pos_embd, true);
} break;
default: (void)0;
}
// switch (vparams.arch) {
// case VISION_ARCH_LLAVA:
// default: (void)0;
// }
}
void llama_model::load_vocab(llama_model_loader & ml) {
@@ -3457,6 +3455,37 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
clip.post_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "weight"), {n_vembd}, llama_model_loader::TENSOR_NOT_REQUIRED);
clip.post_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "bias" ), {n_vembd}, llama_model_loader::TENSOR_NOT_REQUIRED);
for (int i = 0; i < n_vlayer; ++i) {
auto & layer = clip.layers[i];
layer.k_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_K, "weight", i), {n_vembd, n_vembd});
layer.k_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_K, "bias" , i), {n_vembd});
layer.v_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_V, "weight", i), {n_vembd, n_vembd});
layer.v_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_V, "bias" , i), {n_vembd});
layer.q_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_Q, "weight", i), {n_vembd, n_vembd});
layer.q_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_ATTN_Q, "bias" , i), {n_vembd});
layer.ffn_up_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_FFN_UP, "weight", i), {n_vembd, n_vff});
layer.ffn_up_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_FFN_UP, "bias" , i), {n_vff});
layer.ffn_down_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_FFN_DOWN, "weight", i), {n_vff, n_vembd});
layer.ffn_down_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_FFN_DOWN, "bias" , i), {n_vembd});
layer.norm_in_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_INPUT_NORM, "weight", i), {n_vembd});
layer.norm_in_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_INPUT_NORM, "bias" , i), {n_vembd});
layer.norm_out_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_OUTPUT_NORM, "weight", i), {n_vembd});
layer.norm_out_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_OUTPUT_NORM, "bias" , i), {n_vembd});
layer.output_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_OUTPUT, "weight", i), {n_vembd, n_vembd});
layer.output_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_OUTPUT, "bias" , i), {n_vembd});
}
} break;
case VISION_ARCH_MINICPMV:
{
clip.patch_embeddings = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd});
clip.position_embeddings = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd});
// TODO: load all resampler tensors
for (int i = 0; i < n_vlayer; ++i) {
auto & layer = clip.layers[i];

View File

@@ -63,6 +63,10 @@ uint32_t clip_n_mmproj_embd(const clip_vision_model & clip_model) {
return clip_model.mm_2_b->ne[0];
} else if (proj_type == CLIP_PROJECTOR_TYPE_LDPV2) {
return clip_model.mm_model_peg_0_b->ne[0];
} else if (proj_type == CLIP_PROJECTOR_TYPE_MINICPMV_2_5) {
return 4096;
} else if (proj_type == CLIP_PROJECTOR_TYPE_MINICPMV_2_6) {
return 3584;
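The hard-coded widths above correspond to the hidden size of the language model each MiniCPM-V release is paired with (an assumption based on the published checkpoints: a Llama-3-8B backbone for 2.5 and a Qwen2-7B backbone for 2.6), since the resampler projects image features straight into the LLM embedding space:

minicpmv_llm_embd = {
    "minicpmv-2.5": 4096,   # Llama-3-8B hidden size (assumed)
    "minicpmv-2.6": 3584,   # Qwen2-7B hidden size (assumed)
}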
} else {
GGML_ASSERT(false && "invalid proj type");
}
@@ -243,6 +247,173 @@ static void normalize_image_u8_to_f32(const clip_image_u8 & src, std::vector<flo
}
}
#define LLAMA_LOG_DEBUG LLAMA_LOG_INFO
// minicpmv preprocessor
struct minicpmv_preprocessor {
int ensure_divide(int length, int patch_size) {
return std::max(static_cast<int>(std::round(static_cast<float>(length) / patch_size) * patch_size), patch_size);
}
std::pair<int, int> uhd_find_best_resize(std::pair<int, int> original_size, int scale_resolution, int patch_size, bool allow_upscale = false) {
int width = original_size.first;
int height = original_size.second;
if ((width * height > scale_resolution * scale_resolution) || allow_upscale) {
float r = static_cast<float>(width) / height;
height = static_cast<int>(scale_resolution / std::sqrt(r));
width = static_cast<int>(height * r);
}
int best_width = ensure_divide(width, patch_size);
int best_height = ensure_divide(height, patch_size);
return std::make_pair(best_width, best_height);
}
std::pair<int, int> uhd_get_refine_size(std::pair<int, int> original_size, std::pair<int, int> grid, int scale_resolution, int patch_size, bool allow_upscale = false) {
int width, height;
std::tie(width, height) = original_size;
int grid_x, grid_y;
std::tie(grid_x, grid_y) = grid;
int refine_width = ensure_divide(width, grid_x);
int refine_height = ensure_divide(height, grid_y);
int grid_width = refine_width / grid_x;
int grid_height = refine_height / grid_y;
// auto best_grid_size = find_best_resize(std::make_tuple(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); (old line)
auto best_grid_size = uhd_find_best_resize(std::make_pair(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); // (new line) => switched from std::make_tuple to std::make_pair
int best_grid_width, best_grid_height;
std::tie(best_grid_width, best_grid_height) = best_grid_size;
// std::pair<int, int> refine_size = std::make_tuple(best_grid_width * grid_x, best_grid_height * grid_y); (old line)
std::pair<int, int> refine_size = std::make_pair(best_grid_width * grid_x, best_grid_height * grid_y); // (new line)
return refine_size;
}
std::pair<int, int> uhd_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) {
std::vector<int> candidate_split_grids_nums;
for (int i : {multiple - 1, multiple, multiple + 1}) {
if (i == 1 || i > max_slice_nums) {
continue;
}
candidate_split_grids_nums.push_back(i);
}
std::vector<std::pair<int, int>> candidate_grids;
for (int split_grids_nums : candidate_split_grids_nums) {
int m = 1;
while (m <= split_grids_nums) {
if (split_grids_nums % m == 0) {
candidate_grids.emplace_back(m, split_grids_nums / m);
}
++m;
}
}
std::pair<int, int> best_grid{1, 1};
float min_error = std::numeric_limits<float>::infinity();
for (const auto& grid : candidate_grids) {
float error = std::abs(log_ratio - std::log(1.0 * grid.first / grid.second));
if (error < min_error) {
best_grid = grid;
min_error = error;
}
}
return best_grid;
}
std::vector<std::vector<clip_image_u8>> uhd_slice_image(
const clip_image_u8 & img,
const int max_slice_nums = 9,
const int scale_resolution = 448,
const int patch_size = 14) {
const std::pair<int, int> original_size={img.nx,img.ny};
const int original_width = img.nx;
const int original_height = img.ny;
const float log_ratio = log(1.0*original_width/original_height);
const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution);
const int multiple = fmin(ceil(ratio), max_slice_nums);
std::vector<std::vector<clip_image_u8>> images;
LLAMA_LOG_DEBUG("%s: multiple %d\n", __func__, multiple);
images.push_back(std::vector<clip_image_u8>());
if (multiple <= 1) {
auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size, true);
clip_image_u8 source_image;
bicubic_resize(img, source_image, best_size.first, best_size.second);
// source_image = image.resize(best_size, Image.Resampling.BICUBIC)
images[images.size()-1].push_back(source_image);
}
else if (multiple > 1) {
auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size);
clip_image_u8 source_image;
bicubic_resize(img, source_image, best_size.first, best_size.second);
// source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
LLAMA_LOG_DEBUG("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img.nx, img.ny, best_size.first, best_size.second);
images[images.size()-1].push_back(source_image);
std::pair<int, int> best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio);
LLAMA_LOG_DEBUG("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img.nx, img.ny, best_grid.first, best_grid.second);
auto refine_size = uhd_get_refine_size(original_size, best_grid, scale_resolution, patch_size, true);
clip_image_u8 refine_image;
bicubic_resize(img, refine_image, refine_size.first, refine_size.second);
LLAMA_LOG_DEBUG("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image.nx, refine_image.ny, refine_size.first, refine_size.second);
// split_to_patches
int width = refine_image.nx;
int height = refine_image.ny;
int grid_x = int(width / best_grid.first);
int grid_y = int(height / best_grid.second);
for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.second; patches_i += grid_y, ic += 1){
images.push_back(std::vector<clip_image_u8>());
for(int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.first; patches_j += grid_x, jc += 1){
clip_image_u8 patch;
patch.nx = grid_x;
patch.ny = grid_y;
patch.buf.resize(3 * patch.nx * patch.ny);
for (int y = patches_i; y < patches_i + grid_y; ++y) {
for (int x = patches_j; x < patches_j + grid_x; ++x) {
const int i = 3 * (y * refine_image.nx + x);
const int j = 3 * ((y-patches_i) * patch.nx + (x-patches_j));
patch.buf[j] = refine_image.buf[i];
patch.buf[j+1] = refine_image.buf[i+1];
patch.buf[j+2] = refine_image.buf[i+2];
}
}
images[images.size()-1].push_back(patch);
}
}
}
return images;
}
};
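To make the slicing math above concrete, here is a small Python sketch of the same formulas with one worked example; it mirrors the C++ logic above, the helper names are mine, and the input size is illustrative:

import math

def ensure_divide(length, patch_size):
    return max(int(round(length / patch_size) * patch_size), patch_size)

def find_best_resize(w, h, scale_resolution=448, patch_size=14, allow_upscale=False):
    if w * h > scale_resolution ** 2 or allow_upscale:
        r = w / h
        h = int(scale_resolution / math.sqrt(r))
        w = int(h * r)
    return ensure_divide(w, patch_size), ensure_divide(h, patch_size)

w, h = 1920, 1080                                     # example input image
multiple = min(math.ceil(w * h / 448 ** 2), 9)        # ~10.3 -> capped at max_slice_nums = 9
print(find_best_resize(w, h))                         # (602, 336): the downscaled overview image

# grid search over factorizations of {multiple - 1, multiple, multiple + 1}
log_ratio = math.log(w / h)
candidates = [n for n in (multiple - 1, multiple, multiple + 1) if 1 < n <= 9]
grids = [(m, n // m) for n in candidates for m in range(1, n + 1) if n % m == 0]
best = min(grids, key=lambda g: abs(log_ratio - math.log(g[0] / g[1])))
print(best)                                           # (4, 2): 4 columns x 2 rows of slices

grid_x, grid_y = best
slice_w, slice_h = find_best_resize(ensure_divide(w, grid_x) // grid_x,
                                    ensure_divide(h, grid_y) // grid_y, allow_upscale=True)
print(slice_w * grid_x, slice_h * grid_y)             # 1680 952: refine image, cut into 420x476 slices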
static llama_vision_patches clip_image_preprocess_minicpmv(const clip_context & ctx, const clip_image_u8 & img) {
auto & params = ctx.model->hparams;
GGML_ASSERT(params.arch == VISION_ARCH_MINICPMV);
static const int max_slice_nums = 9;
minicpmv_preprocessor preprocessor;
std::vector<std::vector<clip_image_u8>> imgs = preprocessor.uhd_slice_image(img, max_slice_nums);
llama_vision_patches output_patches;
output_patches.n_px = clip_n_patches_x(ctx);
output_patches.n_py = clip_n_patches_y(ctx);
output_patches.px = params.patch_size;
output_patches.py = params.patch_size;
for (size_t i = 0; i < imgs.size(); ++i) {
for (size_t j = 0; j < imgs[i].size(); ++j) {
std::vector<float> res;
normalize_image_u8_to_f32(imgs[i][j], res, params.image_mean, params.image_std);
output_patches.buf.push_back(res);
}
}
return output_patches;
}
// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
// res_imgs memory is being allocated here, previous allocations will be freed if found
static llama_vision_patches clip_image_preprocess(const clip_context & ctx, const clip_image_u8 & img) {
@@ -724,8 +895,10 @@ struct llama_vision_patches * llama_vision_patches_init(
struct llama_context * ctx,
llama_vision_bitmap * bmp) {
clip_context & vctx = ctx->vctx;
llama_vision_patches p = clip_image_preprocess(vctx, *bmp);
return new llama_vision_patches(p);
if (vctx.model->hparams.arch == VISION_ARCH_MINICPMV) {
return new llama_vision_patches(clip_image_preprocess_minicpmv(vctx, *bmp));
}
return new llama_vision_patches(clip_image_preprocess(vctx, *bmp));
}
void llama_vision_patches_free(llama_vision_patches * p) {

View File

@@ -11,6 +11,8 @@ enum clip_projector_type {
CLIP_PROJECTOR_TYPE_UNKNOWN,
CLIP_PROJECTOR_TYPE_MLP,
CLIP_PROJECTOR_TYPE_LDPV2,
CLIP_PROJECTOR_TYPE_MINICPMV_2_5,
CLIP_PROJECTOR_TYPE_MINICPMV_2_6,
};
enum mm_patch_merge {
@@ -36,7 +38,7 @@ struct clip_hparams {
float eps;
clip_projector_type proj_type = CLIP_PROJECTOR_TYPE_UNKNOWN;
mm_patch_merge mm_patch_merge_type = MM_PATCH_MERGE_FLAT;
mm_patch_merge mm_patch_merge_type = MM_PATCH_MERGE_UNKNOWN;
std::array<float, 3> image_mean;
std::array<float, 3> image_std;
@@ -107,6 +109,26 @@ struct clip_vision_model {
struct ggml_tensor * mm_model_peg_0_w = nullptr;
struct ggml_tensor * mm_model_peg_0_b = nullptr;
// MINICPMV projection
struct ggml_tensor * mm_model_pos_embed_k;
struct ggml_tensor * mm_model_query;
struct ggml_tensor * mm_model_proj;
struct ggml_tensor * mm_model_kv_proj;
struct ggml_tensor * mm_model_attn_q_w;
struct ggml_tensor * mm_model_attn_q_b;
struct ggml_tensor * mm_model_attn_k_w;
struct ggml_tensor * mm_model_attn_k_b;
struct ggml_tensor * mm_model_attn_v_w;
struct ggml_tensor * mm_model_attn_v_b;
struct ggml_tensor * mm_model_attn_o_w;
struct ggml_tensor * mm_model_attn_o_b;
struct ggml_tensor * mm_model_ln_q_w;
struct ggml_tensor * mm_model_ln_q_b;
struct ggml_tensor * mm_model_ln_kv_w;
struct ggml_tensor * mm_model_ln_kv_b;
struct ggml_tensor * mm_model_ln_post_w;
struct ggml_tensor * mm_model_ln_post_b;
struct ggml_tensor * image_newline = nullptr;
};
@@ -135,6 +157,18 @@ struct llama_vision_patches {
std::vector<std::vector<float>> buf; // preprocessed image data
};
inline vision_arch vision_arch_from_string(const std::string & name) {
if (name == "llava") {
return VISION_ARCH_LLAVA;
} else if (name == "mobilevlm") {
return VISION_ARCH_MOBILEVLM;
} else if (name == "minicpmv") {
return VISION_ARCH_MINICPMV;
}
return VISION_ARCH_UNKNOWN;
}
inline mm_patch_merge mm_patch_merge_from_name(std::string & name) {
if (name == "flat") {
return MM_PATCH_MERGE_FLAT;
@@ -149,6 +183,10 @@ inline clip_projector_type clip_projector_type_from_name(std::string & name) {
return CLIP_PROJECTOR_TYPE_MLP;
} else if (name == "ldpv2") {
return CLIP_PROJECTOR_TYPE_LDPV2;
} else if (name == "minicpmv-2.5") {
return CLIP_PROJECTOR_TYPE_MINICPMV_2_5;
} else if (name == "minicpmv-2.6") {
return CLIP_PROJECTOR_TYPE_MINICPMV_2_6;
}
return CLIP_PROJECTOR_TYPE_UNKNOWN;
}