refactor minicpm-v support

Xuan Son Nguyen 2025-01-25 15:52:54 +01:00
parent 0959cc18ee
commit 90eefc2ba4
5 changed files with 186 additions and 136 deletions

View File

@ -1008,6 +1008,29 @@ class Model:
self.gguf_writer.add_add_eos_token(field.parts[-1].tolist()[0])
# TODO: maybe merge this with Model in the future
class VisionModelHelper:
model: Model
tok_embd_tensor: Tensor | None = None
def __init__(self, model: Model):
self.model = model
# TODO: how to do this without reading the whole safetensor file?
for tname, tensor in model.get_tensors():
if tname.endswith("embed_tokens.weight"):
self.tok_embd_tensor = tensor
def get_embd_for_tokens(self, map_token_to_tensor_name: Iterable[tuple[str, gguf.MODEL_TENSOR]], tensor_name_postfix = '.weight') -> Iterable[tuple[str, Tensor]]:
if self.tok_embd_tensor is None:
raise ValueError("Token embedding tensor not found")
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.model.dir_model, trust_remote_code=True)
for token, tensor_name in map_token_to_tensor_name:
tok_id = tokenizer.get_vocab()[token]
row = self.tok_embd_tensor[tok_id]
yield gguf.TENSOR_NAMES[tensor_name] + tensor_name_postfix, row
@Model.register("GPTNeoXForCausalLM")
class GPTNeoXModel(Model):
model_arch = gguf.MODEL_ARCH.GPTNEOX
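A side note on the TODO inside VisionModelHelper.__init__ ("how to do this without reading the whole safetensor file?"): one possible answer, sketched here purely as an illustration and not part of this commit, is the lazy safe_open API from the safetensors library, which memory-maps the checkpoint and only materializes the tensor that is requested.

# hypothetical sketch, not from this commit: read a single tensor lazily
from safetensors import safe_open

def load_single_tensor(path: str, name: str):
    # the checkpoint is memory-mapped; only the requested tensor is loaded
    with safe_open(path, framework="pt", device="cpu") as f:
        return f.get_tensor(name)

# usage (file name is hypothetical):
# row = load_single_tensor("model-00001-of-00002.safetensors", "model.embed_tokens.weight")[tok_id]

This would only help when the checkpoint is a .safetensors file; sharded checkpoints would additionally need the index that maps tensor names to shards.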
@ -2355,11 +2378,11 @@ class Qwen2VLModel(Model):
@Model.register("MiniCPMV")
class MiniCPMVModel(Qwen2Model):
# based on minicpmv-surgery.py, not sure why it is Qwen2Model instead of MiniCPMModel
# MiniCPM-V 2.5 is Qwen2 and 2.6 is Qwen-2.5
model_arch = gguf.MODEL_ARCH.QWEN2
proj_type: gguf.constants.CLIPProjectorType | None
resampler_n_embd = 0
tok_embd_tensor: Tensor | None = None
vhelper: VisionModelHelper | None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -2378,56 +2401,49 @@ class MiniCPMVModel(Qwen2Model):
self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_6
else:
raise ValueError(f"Unsupported MiniCPM-V version: {version}")
self.vhelper = VisionModelHelper(self)
# TODO: how to do this without reading the whole safetensor file?
for tname, tensor in self.get_tensors():
if tname == "resampler.ln_post.bias":
self.resampler_n_embd = tensor.shape[0]
if tname.endswith("embed_tokens.weight"):
self.tok_embd_tensor = tensor
if self.resampler_n_embd < 2:
raise ValueError("Failed to detect resampler embedding size")
else:
raise ValueError("Expected vision_config, but not found")
if self.vparams is not None and self.vision_arch is not None and self.preprocessor_config is not None:
self.preprocessor_config["image_mean"] = [0.5, 0.5, 0.5]
self.preprocessor_config["image_std"] = [0.5, 0.5, 0.5]
self.hparams["vision_feature_layer"] = 0
self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"])
def get_embd_of_tokens(self, map_token_to_tensor_name: Iterable[tuple[str, str]]) -> Iterable[tuple[str, Tensor]]:
if self.tok_embd_tensor is None:
raise ValueError("Token embedding tensor not found")
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
for token, tensor_name in map_token_to_tensor_name:
tok_id = tokenizer.get_vocab()[token]
row = self.tok_embd_tensor[tok_id]
yield tensor_name, row
assert self.vparams is not None
assert self.vision_arch is not None
assert self.preprocessor_config is not None
self.preprocessor_config["image_mean"] = [0.5, 0.5, 0.5]
self.preprocessor_config["image_std"] = [0.5, 0.5, 0.5]
self.hparams["vision_feature_layer"] = 0
self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"])
def set_gguf_parameters(self):
super().set_gguf_parameters()
# For vision model
if self.vparams is not None and self.proj_type is not None:
self.gguf_writer.add_vision_vit_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
self.gguf_writer.add_vision_vit_projector_type(self.proj_type)
self.gguf_writer.add_vision_vit_layer_norm_epsilon(1e-06)
max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2
self.gguf_writer.add_vision_vit_max_position_embeddings(max_pos_embd)
assert self.vparams is not None and self.proj_type is not None
self.gguf_writer.add_vision_vit_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
self.gguf_writer.add_vision_vit_projector_type(self.proj_type)
self.gguf_writer.add_vision_vit_layer_norm_epsilon(1e-06)
max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2
self.gguf_writer.add_vision_vit_max_position_embeddings(max_pos_embd)
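If, for example, the encoder resolution is 980 pixels with 14-pixel patches, this gives max_pos_embd = (980 // 14)**2 = 70**2 = 4900, matching the 70x70 grid precomputed in generate_extra_tensors below.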
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
# because the model operates exclusively on 70x70 patches for now, we should precompute the positional embeddings to gain performance
# in the future, we can do it in cpp if we figure out how to do it efficiently
yield (
self.format_tensor_name(gguf.MODEL_TENSOR.V_RESMPL_POS_EMBD_K, is_vision=True),
torch.from_numpy(self._get_2d_sincos_pos_embed(self.resampler_n_embd, (70, 70)))
)
assert self.vhelper is not None
added_tokens = [
("<image>", gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_TOK_EMBD_IMAGE ] + ".weight"),
("</image>", gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_TOK_EMBD_END_IMAGE] + ".weight"),
("<slice>", gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_TOK_EMBD_SLICE ] + ".weight"),
("</slice>", gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_TOK_EMBD_END_SLICE] + ".weight"),
("<image>", gguf.MODEL_TENSOR.V_TOK_EMBD_IMAGE),
("</image>", gguf.MODEL_TENSOR.V_TOK_EMBD_END_IMAGE),
("<slice>", gguf.MODEL_TENSOR.V_TOK_EMBD_SLICE),
("</slice>", gguf.MODEL_TENSOR.V_TOK_EMBD_END_SLICE),
]
for tensor_name, tensor in self.get_embd_of_tokens(added_tokens):
for tensor_name, tensor in self.vhelper.get_embd_for_tokens(added_tokens):
yield tensor_name, tensor
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
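The _get_2d_sincos_pos_embed helper called in generate_extra_tensors above is not shown in this diff; a minimal numpy sketch of the standard 2D sine-cosine embedding it presumably follows (an illustration in the spirit of the MAE reference implementation, not the commit's actual code):

import numpy as np

def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    # embed_dim must be even; pos holds the positions to encode
    omega = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000**omega                                  # (D/2,)
    out = np.einsum("m,d->md", pos.reshape(-1), omega)          # (M, D/2)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)   # (M, D)

def get_2d_sincos_pos_embed(embed_dim: int, grid_size: tuple[int, int]) -> np.ndarray:
    # half of the channels encode the row index, the other half the column index
    h, w = grid_size
    grid = np.meshgrid(np.arange(w, dtype=np.float64), np.arange(h, dtype=np.float64))
    grid = np.stack(grid, axis=0).reshape(2, h * w)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
    return np.concatenate([emb_h, emb_w], axis=1)               # (h*w, embed_dim)

The result is computed once for the fixed 70x70 grid and stored as the V_RESMPL_POS_EMBD_K tensor, so the runtime never has to regenerate it.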

View File

@ -1559,9 +1559,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
// vision
{LLM_TENSOR_V_MMPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_V_MMPROJ_MLP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_V_MMPROJ_PEG, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_V_MMPROJ, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}},
{LLM_TENSOR_V_MMPROJ_MLP, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}},
{LLM_TENSOR_V_MMPROJ_PEG, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}},
{LLM_TENSOR_V_ENC_EMBD_CLS, {LLM_TENSOR_LAYER_INPUT, GGML_OP_ADD}},
{LLM_TENSOR_V_ENC_EMBD_PATCH, {LLM_TENSOR_LAYER_INPUT, GGML_OP_ADD}},
{LLM_TENSOR_V_ENC_EMBD_POS, {LLM_TENSOR_LAYER_INPUT, GGML_OP_ADD}},
@ -1575,7 +1575,22 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_V_ENC_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_V_PRE_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_V_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
// TODO: add minicpmv resampler tensors
{LLM_TENSOR_V_RESMPL_POS_EMBD_K, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_ADD}},
{LLM_TENSOR_V_RESMPL_ATTN_Q, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}},
{LLM_TENSOR_V_RESMPL_ATTN_K, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}},
{LLM_TENSOR_V_RESMPL_ATTN_V, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}},
{LLM_TENSOR_V_RESMPL_ATTN_OUT, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}},
{LLM_TENSOR_V_RESMPL_KV, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}},
{LLM_TENSOR_V_RESMPL_KV_NORM, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL}},
{LLM_TENSOR_V_RESMPL_POST_NORM, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL}},
{LLM_TENSOR_V_RESMPL_Q_NORM, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL}},
{LLM_TENSOR_V_RESMPL_PROJ, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}},
{LLM_TENSOR_V_RESMPL_QUERY, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}},
// special token embeddings for image
{LLM_TENSOR_V_TOK_EMBD_IMAGE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_CONCAT}},
{LLM_TENSOR_V_TOK_EMBD_END_IMAGE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_CONCAT}},
{LLM_TENSOR_V_TOK_EMBD_SLICE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_CONCAT}},
{LLM_TENSOR_V_TOK_EMBD_END_SLICE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_CONCAT}},
};
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}

View File

@ -393,6 +393,7 @@ enum llm_tensor {
enum llm_tensor_layer {
LLM_TENSOR_LAYER_INPUT,
LLM_TENSOR_LAYER_REPEATING,
LLM_TENSOR_LAYER_PROJECTION,
LLM_TENSOR_LAYER_OUTPUT,
};

View File

@ -217,6 +217,11 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1);
op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16);
} break;
case GGML_OP_CONCAT:
{
ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
op_tensor = ggml_concat(ctx, w, b, 0);
} break;
default:
GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name);
}
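This new GGML_OP_CONCAT case mirrors the tensor-info entries added above: the four image/slice marker embeddings are declared as operands of a concat, so when weight_buft_supported probes which backend buffer types can host them it has to be able to build a dummy ggml_concat node, just as it already does for ggml_im2col and the other ops handled in this switch.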
@ -1469,7 +1474,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
// sanity checks
if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) {
if (info.layer == LLM_TENSOR_LAYER_PROJECTION) {
// nothing to check
} else if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) {
if (tn.bid != -1) {
GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str());
}
@ -1491,6 +1498,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
case LLM_TENSOR_LAYER_REPEATING:
buft_list = pimpl->dev_layer.at(tn.bid).buft_list;
break;
case LLM_TENSOR_LAYER_PROJECTION:
buft_list = pimpl->dev_layer.back().buft_list;
break;
default:
GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str());
}
@ -3469,7 +3479,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
// TODO: vit is cpu only for now
vit.buft = ggml_backend_cpu_buffer_type();
ggml_context * ctx_vision = ctx_map.at(vit.buft);
vit.layers.resize(n_vlayer);
switch (vparams.arch) {
@ -3477,27 +3486,27 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
case LLM_ARCH_VISION_MOBILEVLM:
{
if (vparams.arch == LLM_ARCH_VISION_LLAVA) {
vit.mm_1_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ, "weight", 1), {n_vembd, n_vff});
vit.mm_1_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ, "bias" , 1), {n_vff});
vit.mm_2_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ, "weight", 2), {n_vff, n_vff});
vit.mm_2_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ, "bias" , 2), {n_vff});
vit.mm_1_w = create_tensor(tn(LLM_TENSOR_V_MMPROJ, "weight", 1), {n_vembd, n_vff}, 0);
vit.mm_1_b = create_tensor(tn(LLM_TENSOR_V_MMPROJ, "bias" , 1), {n_vff}, 0);
vit.mm_2_w = create_tensor(tn(LLM_TENSOR_V_MMPROJ, "weight", 2), {n_vff, n_vff}, 0);
vit.mm_2_b = create_tensor(tn(LLM_TENSOR_V_MMPROJ, "bias" , 2), {n_vff}, 0);
} else if (vparams.arch == LLM_ARCH_VISION_MOBILEVLM) {
vit.mm_model_mlp_0_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_MLP, "weight", 0), {n_vembd, n_embd});
vit.mm_model_mlp_0_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_MLP, "bias", 0), {n_embd});
vit.mm_model_mlp_2_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_MLP, "weight", 2), {n_embd, n_embd});
vit.mm_model_mlp_2_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_MLP, "bias", 2), {n_embd});
vit.mm_model_peg_0_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_PEG, "weight", 0), {n_channel, n_channel, 1, n_embd});
vit.mm_model_peg_0_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_PEG, "bias", 0), {n_embd});
vit.mm_model_mlp_0_w = create_tensor(tn(LLM_TENSOR_V_MMPROJ_MLP, "weight", 0), {n_vembd, n_embd}, 0);
vit.mm_model_mlp_0_b = create_tensor(tn(LLM_TENSOR_V_MMPROJ_MLP, "bias", 0), {n_embd}, 0);
vit.mm_model_mlp_2_w = create_tensor(tn(LLM_TENSOR_V_MMPROJ_MLP, "weight", 2), {n_embd, n_embd}, 0);
vit.mm_model_mlp_2_b = create_tensor(tn(LLM_TENSOR_V_MMPROJ_MLP, "bias", 2), {n_embd}, 0);
vit.mm_model_peg_0_w = create_tensor(tn(LLM_TENSOR_V_MMPROJ_PEG, "weight", 0), {n_channel, n_channel, 1, n_embd}, 0);
vit.mm_model_peg_0_b = create_tensor(tn(LLM_TENSOR_V_MMPROJ_PEG, "bias", 0), {n_embd}, 0);
}
vit.class_embedding = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_CLS ), {n_vembd});
vit.patch_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd});
vit.position_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd});
vit.class_embedding = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_CLS ), {n_vembd}, 0);
vit.patch_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd}, 0);
vit.position_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd}, 0);
vit.pre_norm_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_PRE_NORM, "weight"), {n_vembd});
vit.pre_norm_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_PRE_NORM, "bias" ), {n_vembd});
vit.post_norm_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_POST_NORM, "weight"), {n_vembd}, llama_model_loader::TENSOR_NOT_REQUIRED);
vit.post_norm_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_POST_NORM, "bias" ), {n_vembd}, llama_model_loader::TENSOR_NOT_REQUIRED);
vit.pre_norm_w = create_tensor(tn(LLM_TENSOR_V_PRE_NORM, "weight"), {n_vembd}, 0);
vit.pre_norm_b = create_tensor(tn(LLM_TENSOR_V_PRE_NORM, "bias" ), {n_vembd}, 0);
vit.post_norm_w = create_tensor(tn(LLM_TENSOR_V_POST_NORM, "weight"), {n_vembd}, llama_model_loader::TENSOR_NOT_REQUIRED);
vit.post_norm_b = create_tensor(tn(LLM_TENSOR_V_POST_NORM, "bias" ), {n_vembd}, llama_model_loader::TENSOR_NOT_REQUIRED);
for (int i = 0; i < n_vlayer; ++i) {
auto & layer = vit.layers[i];
@ -3525,36 +3534,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
} break;
case LLM_ARCH_VISION_MINICPMV:
{
vit.patch_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd});
vit.patch_bias = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "bias" ), {n_vembd});
vit.position_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd});
// resampler
int rs_n_embd = llama_vision_n_mmproj_embd(vit);
vit.mm_model_pos_embed_k = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_POS_EMBD_K, "weight"), {rs_n_embd, max_pos_embd});
vit.mm_model_query = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_QUERY, "weight"), {rs_n_embd, 64}); // why 64?
vit.mm_model_proj = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_PROJ, "weight"), {rs_n_embd, rs_n_embd});
vit.mm_model_kv_proj = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_KV, "weight"), {n_vembd, rs_n_embd});
vit.mm_model_attn_q_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_Q, "weight"), {rs_n_embd, rs_n_embd});
vit.mm_model_attn_q_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_Q, "bias" ), {rs_n_embd});
vit.mm_model_attn_k_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_K, "weight"), {rs_n_embd, rs_n_embd});
vit.mm_model_attn_k_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_K, "bias" ), {rs_n_embd});
vit.mm_model_attn_v_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_V, "weight"), {rs_n_embd, rs_n_embd});
vit.mm_model_attn_v_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_V, "bias" ), {rs_n_embd});
vit.mm_model_attn_o_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_OUT, "weight"), {rs_n_embd, rs_n_embd});
vit.mm_model_attn_o_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_OUT, "bias" ), {rs_n_embd});
vit.mm_model_ln_q_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_Q_NORM, "weight"), {rs_n_embd});
vit.mm_model_ln_q_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_Q_NORM, "bias" ), {rs_n_embd});
vit.mm_model_ln_kv_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_KV_NORM, "weight"), {rs_n_embd});
vit.mm_model_ln_kv_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_KV_NORM, "bias" ), {rs_n_embd});
vit.mm_model_ln_post_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_POST_NORM, "weight"), {rs_n_embd});
vit.mm_model_ln_post_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_POST_NORM, "bias" ), {rs_n_embd});
vit.patch_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd}, 0);
vit.patch_bias = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "bias" ), {n_vembd}, 0);
vit.position_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd}, 0);
// tok embd
vit.mm_tok_embd_image = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_TOK_EMBD_IMAGE, "weight"), {n_embd});
vit.mm_tok_embd_end_image = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_TOK_EMBD_END_IMAGE, "weight"), {n_embd});
vit.mm_tok_embd_slice = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_TOK_EMBD_SLICE, "weight"), {n_embd});
vit.mm_tok_embd_end_slice = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_TOK_EMBD_END_SLICE, "weight"), {n_embd});
vit.mm_tok_embd_image = create_tensor(tn(LLM_TENSOR_V_TOK_EMBD_IMAGE, "weight"), {n_embd}, 0);
vit.mm_tok_embd_end_image = create_tensor(tn(LLM_TENSOR_V_TOK_EMBD_END_IMAGE, "weight"), {n_embd}, 0);
vit.mm_tok_embd_slice = create_tensor(tn(LLM_TENSOR_V_TOK_EMBD_SLICE, "weight"), {n_embd}, 0);
vit.mm_tok_embd_end_slice = create_tensor(tn(LLM_TENSOR_V_TOK_EMBD_END_SLICE, "weight"), {n_embd}, 0);
for (int i = 0; i < n_vlayer; ++i) {
auto & layer = vit.layers[i];
@ -3579,18 +3567,41 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.output_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "weight", i), {n_vembd, n_vembd}, 0);
layer.output_b = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "bias" , i), {n_vembd}, 0);
}
// resampler: we treat it as one layer on top of the encoder
int il = n_vlayer - 1;
int rs_n_embd = llama_vision_n_mmproj_embd(vit);
vit.mm_model_pos_embed_k = create_tensor(tn(LLM_TENSOR_V_RESMPL_POS_EMBD_K, "weight", il), {rs_n_embd, max_pos_embd}, 0);
vit.mm_model_query = create_tensor(tn(LLM_TENSOR_V_RESMPL_QUERY, "weight", il), {rs_n_embd, 64}, 0); // why 64?
vit.mm_model_proj = create_tensor(tn(LLM_TENSOR_V_RESMPL_PROJ, "weight", il), {rs_n_embd, rs_n_embd}, 0);
vit.mm_model_kv_proj = create_tensor(tn(LLM_TENSOR_V_RESMPL_KV, "weight", il), {n_vembd, rs_n_embd}, 0);
vit.mm_model_attn_q_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_Q, "weight", il), {rs_n_embd, rs_n_embd}, 0);
vit.mm_model_attn_q_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_Q, "bias" , il), {rs_n_embd}, 0);
vit.mm_model_attn_k_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_K, "weight", il), {rs_n_embd, rs_n_embd}, 0);
vit.mm_model_attn_k_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_K, "bias" , il), {rs_n_embd}, 0);
vit.mm_model_attn_v_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_V, "weight", il), {rs_n_embd, rs_n_embd}, 0);
vit.mm_model_attn_v_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_V, "bias" , il), {rs_n_embd}, 0);
vit.mm_model_attn_o_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_OUT, "weight", il), {rs_n_embd, rs_n_embd}, 0);
vit.mm_model_attn_o_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_OUT, "bias" , il), {rs_n_embd}, 0);
vit.mm_model_ln_q_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_Q_NORM, "weight", il), {rs_n_embd}, 0);
vit.mm_model_ln_q_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_Q_NORM, "bias" , il), {rs_n_embd}, 0);
vit.mm_model_ln_kv_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_KV_NORM, "weight", il), {rs_n_embd}, 0);
vit.mm_model_ln_kv_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_KV_NORM, "bias" , il), {rs_n_embd}, 0);
vit.mm_model_ln_post_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_POST_NORM, "weight", il), {rs_n_embd}, 0);
vit.mm_model_ln_post_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_POST_NORM, "bias" , il), {rs_n_embd}, 0);
} break;
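Treating the resampler as one extra layer on top of the encoder gives its tensors the last layer index (il = n_vlayer - 1); combined with their LLM_TENSOR_LAYER_PROJECTION classification from the LLM_TENSOR_INFOS hunk, they skip the repeating-layer sanity check and are assigned to the last layer's buffer list in load_tensors.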
case LLM_ARCH_VISION_IDEFICS3:
{
int scale_factor = vit.hparams.scale_factor;
vit.projection = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_FC, "weight"), {n_vembd * scale_factor * scale_factor, n_embd});
vit.projection = create_tensor(tn(LLM_TENSOR_V_MMPROJ_FC, "weight"), {n_vembd * scale_factor * scale_factor, n_embd}, 0);
vit.patch_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd});
vit.patch_bias = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "bias" ), {n_vembd});
vit.position_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd});
vit.patch_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd}, 0);
vit.patch_bias = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "bias" ), {n_vembd}, 0);
vit.position_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd}, 0);
vit.post_norm_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_POST_NORM, "weight"), {n_vembd});
vit.post_norm_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_POST_NORM, "bias" ), {n_vembd});
vit.post_norm_w = create_tensor(tn(LLM_TENSOR_V_POST_NORM, "weight"), {n_vembd}, 0);
vit.post_norm_b = create_tensor(tn(LLM_TENSOR_V_POST_NORM, "bias" ), {n_vembd}, 0);
for (int i = 0; i < n_vlayer; ++i) {
auto & layer = vit.layers[i];

View File

@ -15,13 +15,13 @@
// export llama_image_u8 to bmp file for debugging
// https://codereview.stackexchange.com/questions/195121/writing-a-bitmap-image-from-c
struct img_size;
static int bmp_export(const struct llama_image_u8 &img, const std::string &location);
#endif
struct img_size {
int width;
int height;
img_size(int w, int h) : width(w), height(h) {}
};
// RGB uint8 image
@ -89,7 +89,7 @@ static img_size select_best_resolution(const img_size & original_size, const std
int original_width = original_size.width;
int original_height = original_size.height;
img_size best_fit;
img_size best_fit(0, 0);
int max_effective_resolution = 0;
int min_wasted_resolution = std::numeric_limits<int>::max();
@ -314,12 +314,12 @@ struct llama_vision_processor_llava : llama_vision_processor {
// "spatial_unpad" with "anyres" processing for llava-1.6
std::vector<img_size> possible_resolutions;
for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i += 2) {
img_size s;
img_size s(0, 0);
s.width = params.image_grid_pinpoints[i];
s.height = params.image_grid_pinpoints[i+1];
possible_resolutions.push_back(s);
}
img_size best_resolution = select_best_resolution({img.nx, img.ny}, possible_resolutions);
img_size best_resolution = select_best_resolution(img_size(img.nx, img.ny), possible_resolutions);
// debug_image_save_to_bmp(*img, "input.bmp");
temp = resize_and_pad_image(img, best_resolution); // we do not pad with mean-bg color anymore in llava-1.6
// debug_image_save_to_bmp(*temp, "resized.bmp");
@ -415,9 +415,9 @@ struct llama_vision_processor_uhd : llama_vision_processor {
return std::max(static_cast<int>(std::round(static_cast<float>(length) / patch_size) * patch_size), patch_size);
}
std::pair<int, int> find_best_resize(std::pair<int, int> original_size, int scale_resolution, int patch_size, bool allow_upscale = false) {
int width = original_size.first;
int height = original_size.second;
img_size find_best_resize(const img_size & original_size, int scale_resolution, int patch_size, bool allow_upscale = false) {
int width = original_size.width;
int height = original_size.height;
if ((width * height > scale_resolution * scale_resolution) || allow_upscale) {
float r = static_cast<float>(width) / height;
height = static_cast<int>(scale_resolution / std::sqrt(r));
@ -425,14 +425,14 @@ struct llama_vision_processor_uhd : llama_vision_processor {
}
int best_width = ensure_divide(width, patch_size);
int best_height = ensure_divide(height, patch_size);
return std::make_pair(best_width, best_height);
return img_size(best_width, best_height);
}
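A rough worked example (numbers are illustrative, not from the commit): with the default scale_resolution = 448 and patch_size = 14, a 1000x1500 input exceeds 448*448 pixels, so it is scaled down while keeping its aspect ratio, giving a height of about 448 / sqrt(1000.0/1500) ≈ 548; ensure_divide then snaps each side to a multiple of the patch size, e.g. ensure_divide(548, 14) = round(548/14) * 14 = 39 * 14 = 546.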
std::pair<int, int> get_refine_size(std::pair<int, int> original_size, std::pair<int, int> grid, int scale_resolution, int patch_size, bool allow_upscale = false) {
int width, height;
std::tie(width, height) = original_size;
int grid_x, grid_y;
std::tie(grid_x, grid_y) = grid;
img_size get_refine_size(const img_size & original_size, const img_size & grid, int scale_resolution, int patch_size, bool allow_upscale = false) {
int width = original_size.width;
int height = original_size.height;
int grid_x = grid.width;
int grid_y = grid.height;
int refine_width = ensure_divide(width, grid_x);
int refine_height = ensure_divide(height, grid_y);
@ -441,16 +441,14 @@ struct llama_vision_processor_uhd : llama_vision_processor {
int grid_height = refine_height / grid_y;
// auto best_grid_size = find_best_resize(std::make_tuple(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); (old line)
auto best_grid_size = find_best_resize(std::make_pair(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); // (new line) => fixes conversion for make_tuple to make_pair
int best_grid_width, best_grid_height;
std::tie(best_grid_width, best_grid_height) = best_grid_size;
auto best_grid = find_best_resize({grid_width, grid_height}, scale_resolution, patch_size, allow_upscale); // (new line) => best_grid is now an img_size
// std::pair<int, int> refine_size = std::make_tuple(best_grid_width * grid_x, best_grid_height * grid_y); (old line)
std::pair<int, int> refine_size = std::make_pair(best_grid_width * grid_x, best_grid_height * grid_y); // (new line)
// img_size refine_size = std::make_tuple(best_grid_width * grid_x, best_grid_height * grid_y); (old line)
img_size refine_size = img_size(best_grid.width * grid_x, best_grid.height * grid_y); // (new line)
return refine_size;
}
std::pair<int, int> find_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) {
img_size find_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) {
std::vector<int> candidate_split_grids_nums;
for (int i : {multiple - 1, multiple, multiple + 1}) {
if (i == 1 || i > max_slice_nums) {
@ -459,7 +457,7 @@ struct llama_vision_processor_uhd : llama_vision_processor {
candidate_split_grids_nums.push_back(i);
}
std::vector<std::pair<int, int>> candidate_grids;
std::vector<img_size> candidate_grids;
for (int split_grids_nums : candidate_split_grids_nums) {
int m = 1;
while (m <= split_grids_nums) {
@ -470,10 +468,10 @@ struct llama_vision_processor_uhd : llama_vision_processor {
}
}
std::pair<int, int> best_grid{1, 1};
img_size best_grid = img_size(1, 1);
float min_error = std::numeric_limits<float>::infinity();
for (const auto& grid : candidate_grids) {
float error = std::abs(log_ratio - std::log(1.0 * grid.first / grid.second));
float error = std::abs(log_ratio - std::log(1.0 * grid.width / grid.height));
if (error < min_error) {
best_grid = grid;
min_error = error;
@ -487,7 +485,7 @@ struct llama_vision_processor_uhd : llama_vision_processor {
const int max_slice_nums = 9,
const int scale_resolution = 448,
const int patch_size = 14) {
const std::pair<int, int> original_size={img.nx,img.ny};
const img_size original_size = img_size(img.nx, img.ny);
const int original_width = img.nx;
const int original_height = img.ny;
const float log_ratio = log(1.0*original_width/original_height);
@ -501,34 +499,36 @@ struct llama_vision_processor_uhd : llama_vision_processor {
if (multiple <= 1) {
auto best_size = find_best_resize(original_size, scale_resolution, patch_size, true);
llama_image_u8 source_image;
bicubic_resize(img, source_image, best_size.first, best_size.second);
bicubic_resize(img, source_image, best_size.width, best_size.height);
// source_image = image.resize(best_size, Image.Resampling.BICUBIC)
images.back().push_back(source_image);
} else if (multiple > 1) {
auto best_size = find_best_resize(original_size, scale_resolution, patch_size);
llama_image_u8 source_image;
bicubic_resize(img, source_image, best_size.first, best_size.second);
bicubic_resize(img, source_image, best_size.width, best_size.height);
// source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
LLAMA_LOG_DEBUG("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img.nx, img.ny, best_size.first, best_size.second);
LLAMA_LOG_DEBUG("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img.nx, img.ny, best_size.width, best_size.height);
images.back().push_back(source_image);
std::pair<int, int> best_grid = find_best_grid(max_slice_nums, multiple, log_ratio);
LLAMA_LOG_DEBUG("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img.nx, img.ny, best_grid.first, best_grid.second);
img_size best_grid = find_best_grid(max_slice_nums, multiple, log_ratio);
LLAMA_LOG_DEBUG("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img.nx, img.ny, best_grid.width, best_grid.height);
auto refine_size = get_refine_size(original_size, best_grid, scale_resolution, patch_size, true);
llama_image_u8 refine_image;
bicubic_resize(img, refine_image, refine_size.first, refine_size.second);
// TODO: most of the time is currently spent in bicubic_resize; we should optimize it
bicubic_resize(img, refine_image, refine_size.width, refine_size.height);
LLAMA_LOG_DEBUG("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image.nx, refine_image.ny, refine_size.first, refine_size.second);
LLAMA_LOG_DEBUG("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image.nx, refine_image.ny, refine_size.width, refine_size.height);
// split_to_patches
int width = refine_image.nx;
int height = refine_image.ny;
int grid_x = int(width / best_grid.first);
int grid_y = int(height / best_grid.second);
for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.second; patches_i += grid_y, ic += 1){
int grid_x = int(width / best_grid.width);
int grid_y = int(height / best_grid.height);
for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.height; patches_i += grid_y, ic += 1){
std::vector<llama_image_u8> patches_out;
images.push_back(std::vector<llama_image_u8>());
for(int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.first; patches_j += grid_x, jc += 1){
for (int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.width; patches_j += grid_x, jc += 1) {
llama_image_u8 patch;
patch.nx = grid_x;
patch.ny = grid_y;
@ -542,8 +542,9 @@ struct llama_vision_processor_uhd : llama_vision_processor {
patch.buf[j+2] = refine_image.buf[i+2];
}
}
images.back().push_back(patch);
patches_out.push_back(std::move(patch));
}
images.push_back(std::move(patches_out));
}
}
return images;
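To make the grid selection concrete (values chosen for illustration only): for a 3:2 image, log_ratio = log(1.5) ≈ 0.405; with multiple = 4 and max_slice_nums = 9, the candidate split counts are {3, 4, 5}, whose factor pairs yield candidate grids (1,3), (3,1), (1,4), (2,2), (4,1), (1,5) and (5,1). The error |log_ratio - log(width/height)| is smallest for the 2x2 grid (0.405, versus 0.693 for 3x1), so slice_image produces the downscaled overview image plus a 2x2 set of refined slices.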
@ -551,7 +552,6 @@ struct llama_vision_processor_uhd : llama_vision_processor {
virtual llama_vision_tokens tokenize(const llama_image_u8 & img) override {
auto & params = ctx.model->hparams;
GGML_ASSERT(params.arch == LLM_ARCH_VISION_MINICPMV);
std::vector<std::vector<llama_image_u8>> imgs = slice_image(img);
@ -573,6 +573,10 @@ struct llama_vision_processor_uhd : llama_vision_processor {
}
};
//
// cgraph builder
//
// TODO: move this to llm_build_context in llama.cpp
struct llama_vision_graph_builder {
llama_vision_context & ctx;
@ -590,6 +594,7 @@ struct llama_vision_graph_builder {
int img_h;
bool use_gelu;
int n_layers;
int rs_n_embd;
vision_projector_type proj_type;
llama_vision_graph_builder(llama_vision_context & ctx, const llama_vision_tokens & inp) : ctx(ctx), model(*ctx.model) {
@ -950,7 +955,7 @@ static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_
GGML_ASSERT(batch_size == 1); // TODO: support multiple images
}
img_size image_size{(int)hparams.image_size, (int)hparams.image_size};
img_size image_size = img_size((int)hparams.image_size, (int)hparams.image_size);
const int patch_size = hparams.patch_size;
const int num_patches = ((image_size.width / patch_size) * (image_size.height / patch_size));
const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
@ -1016,23 +1021,25 @@ static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_
// -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
// -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "inp_pos");
std::vector<int> pos_buf(ggml_nelements(positions));
GGML_ASSERT(num_positions == (int)pos_buf.size());
std::vector<int> buf(ggml_nelements(positions));
GGML_ASSERT(num_positions == (int)buf.size());
int bucket_coords_h[70];
int bucket_coords_w[70];
for (size_t i = 0; i < inp.n_py; i++) {
bucket_coords_h[i] = std::floor(70.0*i/inp.n_py);
size_t h = inp.py;
size_t w = inp.px;
for (size_t i = 0; i < h; i++) {
bucket_coords_h[i] = std::floor(70.0*i/h);
}
for (size_t i = 0; i < inp.n_px; i++) {
bucket_coords_w[i] = std::floor(70.0*i/inp.n_px);
for (size_t i = 0; i < w; i++) {
bucket_coords_w[i] = std::floor(70.0*i/w);
}
for (size_t i = 0, id = 0; i < inp.n_py; i++){
for (size_t j = 0; j < inp.n_px; j++){
pos_buf[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
for (size_t i = 0, id = 0; i < h; i++){
for (size_t j = 0; j < w; j++){
buf[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
}
}
ggml_backend_tensor_set(positions, pos_buf.data(), 0, ggml_nbytes(positions));
ggml_backend_tensor_set(positions, buf.data(), 0, ggml_nbytes(positions));
} else {
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "inp_pos");
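To illustrate the bucket mapping in the MiniCPM-V branch above: if a slice is 448x448 pixels with 14-pixel patches, it spans 32 patches per side, so bucket_coords_h[i] = floor(70*i/32) spreads the 32 patch rows roughly evenly across the 70 rows of the precomputed 70x70 position-embedding table, and each patch ends up indexing entry bucket_coords_h[i]*70 + bucket_coords_w[j].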
@ -1055,6 +1062,7 @@ static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_
}
// compute
LLAMA_LOG_DEBUG("%s: compute start\n", __func__);
int64_t t_start = ggml_time_ms();
ggml_backend_sched_graph_compute(ctx.sched, gf);
@ -1106,7 +1114,6 @@ struct llama_vision_tokens * llama_vision_tokenize(
case LLM_ARCH_VISION_IDEFICS3:
return new llama_vision_tokens(llama_vision_processor_llava(vctx).tokenize(*bmp));
case LLM_ARCH_VISION_MINICPMV:
//return new llama_vision_tokens(llama_vision_processor_uhd(vctx).tokenize(*bmp));
return new llama_vision_tokens(llama_vision_processor_llava(vctx).tokenize(*bmp));
default:
GGML_ASSERT(false && "unsupported arch");