diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index b3c2ce2b1..d7eab4c46 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1008,6 +1008,29 @@ class Model: self.gguf_writer.add_add_eos_token(field.parts[-1].tolist()[0]) +# TODO: maybe merge this with Model in the future +class VisionModelHelper: + model: Model + tok_embd_tensor: Tensor | None = None + + def __init__(self, model: Model): + self.model = model + # TODO: how to do this without reading the whole safetensor file? + for tname, tensor in model.get_tensors(): + if tname.endswith("embed_tokens.weight"): + self.tok_embd_tensor = tensor + + def get_embd_for_tokens(self, map_token_to_tensor_name: Iterable[tuple[str, gguf.MODEL_TENSOR]], tensor_name_postfix = '.weight') -> Iterable[tuple[str, Tensor]]: + if self.tok_embd_tensor is None: + raise ValueError("Token embedding tensor not found") + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.model.dir_model, trust_remote_code=True) + for token, tensor_name in map_token_to_tensor_name: + tok_id = tokenizer.get_vocab()[token] + row = self.tok_embd_tensor[tok_id] + yield gguf.TENSOR_NAMES[tensor_name] + tensor_name_postfix, row + + @Model.register("GPTNeoXForCausalLM") class GPTNeoXModel(Model): model_arch = gguf.MODEL_ARCH.GPTNEOX @@ -2355,11 +2378,11 @@ class Qwen2VLModel(Model): @Model.register("MiniCPMV") class MiniCPMVModel(Qwen2Model): - # based on minicpmv-surgery.py, not sure why it is Qwen2Model instead of MiniCPMModel + # MiniCPM-V 2.5 is Qwen2 and 2.6 is Qwen-2.5 model_arch = gguf.MODEL_ARCH.QWEN2 proj_type: gguf.constants.CLIPProjectorType | None resampler_n_embd = 0 - tok_embd_tensor: Tensor | None = None + vhelper: VisionModelHelper | None def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -2378,56 +2401,49 @@ class MiniCPMVModel(Qwen2Model): self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_6 else: raise ValueError(f"Unsupported MiniCPM-V version: {version}") + self.vhelper = VisionModelHelper(self) # TODO: how to do this without reading the whole safetensor file? 
for tname, tensor in self.get_tensors(): if tname == "resampler.ln_post.bias": self.resampler_n_embd = tensor.shape[0] - if tname.endswith("embed_tokens.weight"): - self.tok_embd_tensor = tensor if self.resampler_n_embd < 2: raise ValueError("Failed to detect resampler embedding size") else: raise ValueError("Expected vision_config, but not found") - if self.vparams is not None and self.vision_arch is not None and self.preprocessor_config is not None: - self.preprocessor_config["image_mean"] = [0.5, 0.5, 0.5] - self.preprocessor_config["image_std"] = [0.5, 0.5, 0.5] - self.hparams["vision_feature_layer"] = 0 - self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"]) - - def get_embd_of_tokens(self, map_token_to_tensor_name: Iterable[tuple[str, str]]) -> Iterable[tuple[str, Tensor]]: - if self.tok_embd_tensor is None: - raise ValueError("Token embedding tensor not found") - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) - for token, tensor_name in map_token_to_tensor_name: - tok_id = tokenizer.get_vocab()[token] - row = self.tok_embd_tensor[tok_id] - yield tensor_name, row + assert self.vparams is not None + assert self.vision_arch is not None + assert self.preprocessor_config is not None + self.preprocessor_config["image_mean"] = [0.5, 0.5, 0.5] + self.preprocessor_config["image_std"] = [0.5, 0.5, 0.5] + self.hparams["vision_feature_layer"] = 0 + self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"]) def set_gguf_parameters(self): super().set_gguf_parameters() - # For vision model - if self.vparams is not None and self.proj_type is not None: - self.gguf_writer.add_vision_vit_patch_merge_type(gguf.CLIPPatchMergeType.FLAT) - self.gguf_writer.add_vision_vit_projector_type(self.proj_type) - self.gguf_writer.add_vision_vit_layer_norm_epsilon(1e-06) - max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 - self.gguf_writer.add_vision_vit_max_position_embeddings(max_pos_embd) + assert self.vparams is not None and self.proj_type is not None + self.gguf_writer.add_vision_vit_patch_merge_type(gguf.CLIPPatchMergeType.FLAT) + self.gguf_writer.add_vision_vit_projector_type(self.proj_type) + self.gguf_writer.add_vision_vit_layer_norm_epsilon(1e-06) + max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + self.gguf_writer.add_vision_vit_max_position_embeddings(max_pos_embd) def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + # because the model operates exclusively on 70x70 patches for now, we should precompute the positional embeddings to gain performance + # in the future, we can do it in cpp if we figure out how to do it efficiently yield ( self.format_tensor_name(gguf.MODEL_TENSOR.V_RESMPL_POS_EMBD_K, is_vision=True), torch.from_numpy(self._get_2d_sincos_pos_embed(self.resampler_n_embd, (70, 70))) ) + assert self.vhelper is not None added_tokens = [ - ("<image>", gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_TOK_EMBD_IMAGE ] + ".weight"), - ("</image>", gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_TOK_EMBD_END_IMAGE] + ".weight"), - ("<slice>", gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_TOK_EMBD_SLICE ] + ".weight"), - ("</slice>", gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_TOK_EMBD_END_SLICE] + ".weight"), + ("<image>", gguf.MODEL_TENSOR.V_TOK_EMBD_IMAGE), + ("</image>", gguf.MODEL_TENSOR.V_TOK_EMBD_END_IMAGE), + ("<slice>", gguf.MODEL_TENSOR.V_TOK_EMBD_SLICE), + ("</slice>", gguf.MODEL_TENSOR.V_TOK_EMBD_END_SLICE), ] - for tensor_name, tensor in
self.get_embd_of_tokens(added_tokens): + for tensor_name, tensor in self.vhelper.get_embd_for_tokens(added_tokens): yield tensor_name, tensor def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 92e488f57..0da19fe67 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1559,9 +1559,9 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // vision - {LLM_TENSOR_V_MMPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_V_MMPROJ_MLP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_V_MMPROJ_PEG, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_MMPROJ, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_MMPROJ_MLP, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_MMPROJ_PEG, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, {LLM_TENSOR_V_ENC_EMBD_CLS, {LLM_TENSOR_LAYER_INPUT, GGML_OP_ADD}}, {LLM_TENSOR_V_ENC_EMBD_PATCH, {LLM_TENSOR_LAYER_INPUT, GGML_OP_ADD}}, {LLM_TENSOR_V_ENC_EMBD_POS, {LLM_TENSOR_LAYER_INPUT, GGML_OP_ADD}}, @@ -1575,7 +1575,22 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_V_ENC_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_V_PRE_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_V_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, - // TODO: add minicpmv resampler tensors + {LLM_TENSOR_V_RESMPL_POS_EMBD_K, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_ADD}}, + {LLM_TENSOR_V_RESMPL_ATTN_Q, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_RESMPL_ATTN_K, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_RESMPL_ATTN_V, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_RESMPL_ATTN_OUT, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_RESMPL_KV, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_RESMPL_KV_NORM, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL}}, + {LLM_TENSOR_V_RESMPL_POST_NORM, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL}}, + {LLM_TENSOR_V_RESMPL_Q_NORM, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL}}, + {LLM_TENSOR_V_RESMPL_PROJ, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_RESMPL_QUERY, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + // special token embeddings for image + {LLM_TENSOR_V_TOK_EMBD_IMAGE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_CONCAT}}, + {LLM_TENSOR_V_TOK_EMBD_END_IMAGE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_CONCAT}}, + {LLM_TENSOR_V_TOK_EMBD_SLICE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_CONCAT}}, + {LLM_TENSOR_V_TOK_EMBD_END_SLICE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_CONCAT}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-arch.h b/src/llama-arch.h index c3fc32032..a84e17b57 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -393,6 +393,7 @@ enum llm_tensor { enum llm_tensor_layer { LLM_TENSOR_LAYER_INPUT, LLM_TENSOR_LAYER_REPEATING, + LLM_TENSOR_LAYER_PROJECTION, LLM_TENSOR_LAYER_OUTPUT, }; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 6a6b65618..d5cd1eb04 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -217,6 +217,11 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1); op_tensor = ggml_im2col(ctx, w, b, 1, 
0, 0, 0, 1, 0, false, GGML_TYPE_F16); } break; + case GGML_OP_CONCAT: + { + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); + op_tensor = ggml_concat(ctx, w, b, 0); + } break; default: GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name); } @@ -1469,7 +1474,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // sanity checks - if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) { + if (info.layer == LLM_TENSOR_LAYER_PROJECTION) { + // nothing to check + } else if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) { if (tn.bid != -1) { GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str()); } @@ -1491,6 +1498,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_TENSOR_LAYER_REPEATING: buft_list = pimpl->dev_layer.at(tn.bid).buft_list; break; + case LLM_TENSOR_LAYER_PROJECTION: + buft_list = pimpl->dev_layer.back().buft_list; + break; default: GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str()); } @@ -3469,7 +3479,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // TODO: vit is cpu only for now vit.buft = ggml_backend_cpu_buffer_type(); - ggml_context * ctx_vision = ctx_map.at(vit.buft); vit.layers.resize(n_vlayer); switch (vparams.arch) { @@ -3477,27 +3486,27 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_VISION_MOBILEVLM: { if (vparams.arch == LLM_ARCH_VISION_LLAVA) { - vit.mm_1_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ, "weight", 1), {n_vembd, n_vff}); - vit.mm_1_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ, "bias" , 1), {n_vff}); - vit.mm_2_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ, "weight", 2), {n_vff, n_vff}); - vit.mm_2_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ, "bias" , 2), {n_vff}); + vit.mm_1_w = create_tensor(tn(LLM_TENSOR_V_MMPROJ, "weight", 1), {n_vembd, n_vff}, 0); + vit.mm_1_b = create_tensor(tn(LLM_TENSOR_V_MMPROJ, "bias" , 1), {n_vff}, 0); + vit.mm_2_w = create_tensor(tn(LLM_TENSOR_V_MMPROJ, "weight", 2), {n_vff, n_vff}, 0); + vit.mm_2_b = create_tensor(tn(LLM_TENSOR_V_MMPROJ, "bias" , 2), {n_vff}, 0); } else if (vparams.arch == LLM_ARCH_VISION_MOBILEVLM) { - vit.mm_model_mlp_0_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_MLP, "weight", 0), {n_vembd, n_embd}); - vit.mm_model_mlp_0_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_MLP, "bias", 0), {n_embd}); - vit.mm_model_mlp_2_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_MLP, "weight", 2), {n_embd, n_embd}); - vit.mm_model_mlp_2_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_MLP, "bias", 2), {n_embd}); - vit.mm_model_peg_0_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_PEG, "weight", 0), {n_channel, n_channel, 1, n_embd}); - vit.mm_model_peg_0_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_PEG, "bias", 0), {n_embd}); + vit.mm_model_mlp_0_w = create_tensor(tn(LLM_TENSOR_V_MMPROJ_MLP, "weight", 0), {n_vembd, n_embd}, 0); + vit.mm_model_mlp_0_b = create_tensor(tn(LLM_TENSOR_V_MMPROJ_MLP, "bias", 0), {n_embd}, 0); + vit.mm_model_mlp_2_w = create_tensor(tn(LLM_TENSOR_V_MMPROJ_MLP, "weight", 2), {n_embd, n_embd}, 0); + vit.mm_model_mlp_2_b = create_tensor(tn(LLM_TENSOR_V_MMPROJ_MLP, "bias", 2), {n_embd}, 0); + vit.mm_model_peg_0_w = create_tensor(tn(LLM_TENSOR_V_MMPROJ_PEG, "weight", 0), {n_channel, n_channel, 1, n_embd}, 0); + vit.mm_model_peg_0_b = 
create_tensor(tn(LLM_TENSOR_V_MMPROJ_PEG, "bias", 0), {n_embd}, 0); } - vit.class_embedding = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_CLS ), {n_vembd}); - vit.patch_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd}); - vit.position_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd}); + vit.class_embedding = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_CLS ), {n_vembd}, 0); + vit.patch_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd}, 0); + vit.position_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd}, 0); - vit.pre_norm_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_PRE_NORM, "weight"), {n_vembd}); - vit.pre_norm_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_PRE_NORM, "bias" ), {n_vembd}); - vit.post_norm_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_POST_NORM, "weight"), {n_vembd}, llama_model_loader::TENSOR_NOT_REQUIRED); - vit.post_norm_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_POST_NORM, "bias" ), {n_vembd}, llama_model_loader::TENSOR_NOT_REQUIRED); + vit.pre_norm_w = create_tensor(tn(LLM_TENSOR_V_PRE_NORM, "weight"), {n_vembd}, 0); + vit.pre_norm_b = create_tensor(tn(LLM_TENSOR_V_PRE_NORM, "bias" ), {n_vembd}, 0); + vit.post_norm_w = create_tensor(tn(LLM_TENSOR_V_POST_NORM, "weight"), {n_vembd}, llama_model_loader::TENSOR_NOT_REQUIRED); + vit.post_norm_b = create_tensor(tn(LLM_TENSOR_V_POST_NORM, "bias" ), {n_vembd}, llama_model_loader::TENSOR_NOT_REQUIRED); for (int i = 0; i < n_vlayer; ++i) { auto & layer = vit.layers[i]; @@ -3525,36 +3534,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_VISION_MINICPMV: { - vit.patch_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd}); - vit.patch_bias = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "bias" ), {n_vembd}); - vit.position_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd}); - - // resampler - int rs_n_embd = llama_vision_n_mmproj_embd(vit); - vit.mm_model_pos_embed_k = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_POS_EMBD_K, "weight"), {rs_n_embd, max_pos_embd}); - vit.mm_model_query = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_QUERY, "weight"), {rs_n_embd, 64}); // why 64? 
- vit.mm_model_proj = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_PROJ, "weight"), {rs_n_embd, rs_n_embd}); - vit.mm_model_kv_proj = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_KV, "weight"), {n_vembd, rs_n_embd}); - vit.mm_model_attn_q_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_Q, "weight"), {rs_n_embd, rs_n_embd}); - vit.mm_model_attn_q_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_Q, "bias" ), {rs_n_embd}); - vit.mm_model_attn_k_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_K, "weight"), {rs_n_embd, rs_n_embd}); - vit.mm_model_attn_k_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_K, "bias" ), {rs_n_embd}); - vit.mm_model_attn_v_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_V, "weight"), {rs_n_embd, rs_n_embd}); - vit.mm_model_attn_v_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_V, "bias" ), {rs_n_embd}); - vit.mm_model_attn_o_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_OUT, "weight"), {rs_n_embd, rs_n_embd}); - vit.mm_model_attn_o_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_OUT, "bias" ), {rs_n_embd}); - vit.mm_model_ln_q_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_Q_NORM, "weight"), {rs_n_embd}); - vit.mm_model_ln_q_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_Q_NORM, "bias" ), {rs_n_embd}); - vit.mm_model_ln_kv_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_KV_NORM, "weight"), {rs_n_embd}); - vit.mm_model_ln_kv_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_KV_NORM, "bias" ), {rs_n_embd}); - vit.mm_model_ln_post_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_POST_NORM, "weight"), {rs_n_embd}); - vit.mm_model_ln_post_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_POST_NORM, "bias" ), {rs_n_embd}); + vit.patch_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd}, 0); + vit.patch_bias = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "bias" ), {n_vembd}, 0); + vit.position_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd}, 0); // tok embd - vit.mm_tok_embd_image = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_TOK_EMBD_IMAGE, "weight"), {n_embd}); - vit.mm_tok_embd_end_image = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_TOK_EMBD_END_IMAGE, "weight"), {n_embd}); - vit.mm_tok_embd_slice = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_TOK_EMBD_SLICE, "weight"), {n_embd}); - vit.mm_tok_embd_end_slice = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_TOK_EMBD_END_SLICE, "weight"), {n_embd}); + vit.mm_tok_embd_image = create_tensor(tn(LLM_TENSOR_V_TOK_EMBD_IMAGE, "weight"), {n_embd}, 0); + vit.mm_tok_embd_end_image = create_tensor(tn(LLM_TENSOR_V_TOK_EMBD_END_IMAGE, "weight"), {n_embd}, 0); + vit.mm_tok_embd_slice = create_tensor(tn(LLM_TENSOR_V_TOK_EMBD_SLICE, "weight"), {n_embd}, 0); + vit.mm_tok_embd_end_slice = create_tensor(tn(LLM_TENSOR_V_TOK_EMBD_END_SLICE, "weight"), {n_embd}, 0); for (int i = 0; i < n_vlayer; ++i) { auto & layer = vit.layers[i]; @@ -3579,18 +3567,41 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.output_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "weight", i), {n_vembd, n_vembd}, 0); layer.output_b = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "bias" , i), {n_vembd}, 0); } + + // resampler, we consider it as one layer on top of the encoder + int il = n_vlayer - 1; + int rs_n_embd = llama_vision_n_mmproj_embd(vit); + vit.mm_model_pos_embed_k = 
create_tensor(tn(LLM_TENSOR_V_RESMPL_POS_EMBD_K, "weight", il), {rs_n_embd, max_pos_embd}, 0); + vit.mm_model_query = create_tensor(tn(LLM_TENSOR_V_RESMPL_QUERY, "weight", il), {rs_n_embd, 64}, 0); // why 64? + vit.mm_model_proj = create_tensor(tn(LLM_TENSOR_V_RESMPL_PROJ, "weight", il), {rs_n_embd, rs_n_embd}, 0); + vit.mm_model_kv_proj = create_tensor(tn(LLM_TENSOR_V_RESMPL_KV, "weight", il), {n_vembd, rs_n_embd}, 0); + vit.mm_model_attn_q_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_Q, "weight", il), {rs_n_embd, rs_n_embd}, 0); + vit.mm_model_attn_q_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_Q, "bias" , il), {rs_n_embd}, 0); + vit.mm_model_attn_k_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_K, "weight", il), {rs_n_embd, rs_n_embd}, 0); + vit.mm_model_attn_k_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_K, "bias" , il), {rs_n_embd}, 0); + vit.mm_model_attn_v_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_V, "weight", il), {rs_n_embd, rs_n_embd}, 0); + vit.mm_model_attn_v_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_V, "bias" , il), {rs_n_embd}, 0); + vit.mm_model_attn_o_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_OUT, "weight", il), {rs_n_embd, rs_n_embd}, 0); + vit.mm_model_attn_o_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_OUT, "bias" , il), {rs_n_embd}, 0); + vit.mm_model_ln_q_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_Q_NORM, "weight", il), {rs_n_embd}, 0); + vit.mm_model_ln_q_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_Q_NORM, "bias" , il), {rs_n_embd}, 0); + vit.mm_model_ln_kv_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_KV_NORM, "weight", il), {rs_n_embd}, 0); + vit.mm_model_ln_kv_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_KV_NORM, "bias" , il), {rs_n_embd}, 0); + vit.mm_model_ln_post_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_POST_NORM, "weight", il), {rs_n_embd}, 0); + vit.mm_model_ln_post_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_POST_NORM, "bias" , il), {rs_n_embd}, 0); + } break; case LLM_ARCH_VISION_IDEFICS3: { int scale_factor = vit.hparams.scale_factor; - vit.projection = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_FC, "weight"), {n_vembd * scale_factor * scale_factor, n_embd}); + vit.projection = create_tensor(tn(LLM_TENSOR_V_MMPROJ_FC, "weight"), {n_vembd * scale_factor * scale_factor, n_embd}, 0); - vit.patch_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd}); - vit.patch_bias = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "bias" ), {n_vembd}); - vit.position_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd}); + vit.patch_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd}, 0); + vit.patch_bias = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "bias" ), {n_vembd}, 0); + vit.position_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd}, 0); - vit.post_norm_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_POST_NORM, "weight"), {n_vembd}); - vit.post_norm_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_POST_NORM, "bias" ), {n_vembd}); + vit.post_norm_w = create_tensor(tn(LLM_TENSOR_V_POST_NORM, "weight"), {n_vembd}, 0); + vit.post_norm_b = create_tensor(tn(LLM_TENSOR_V_POST_NORM, "bias" ), {n_vembd}, 0); for (int i = 0; i < n_vlayer; ++i) { auto & layer = vit.layers[i]; diff --git a/src/llama-vision.cpp b/src/llama-vision.cpp index 2da15e2eb..bb6ffcf32 100644 --- a/src/llama-vision.cpp +++ b/src/llama-vision.cpp @@ 
-15,13 +15,13 @@ // export llama_image_u8 to bmp file for debugging // https://codereview.stackexchange.com/questions/195121/writing-a-bitmap-image-from-c -struct img_size; static int bmp_export(const struct llama_image_u8 &img, const std::string &location); #endif struct img_size { int width; int height; + img_size(int w, int h) : width(w), height(h) {} }; // RGB uint8 image @@ -89,7 +89,7 @@ static img_size select_best_resolution(const img_size & original_size, const std int original_width = original_size.width; int original_height = original_size.height; - img_size best_fit; + img_size best_fit(0, 0); int max_effective_resolution = 0; int min_wasted_resolution = std::numeric_limits<int>::max(); @@ -314,12 +314,12 @@ struct llama_vision_processor_llava : llama_vision_processor { // "spatial_unpad" with "anyres" processing for llava-1.6 std::vector<img_size> possible_resolutions; for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i += 2) { - img_size s; + img_size s(0, 0); s.width = params.image_grid_pinpoints[i]; s.height = params.image_grid_pinpoints[i+1]; possible_resolutions.push_back(s); } - img_size best_resolution = select_best_resolution({img.nx, img.ny}, possible_resolutions); + img_size best_resolution = select_best_resolution(img_size(img.nx, img.ny), possible_resolutions); // debug_image_save_to_bmp(*img, "input.bmp"); temp = resize_and_pad_image(img, best_resolution); // we do not pad with mean-bg color anymore in llava-1.6 // debug_image_save_to_bmp(*temp, "resized.bmp"); @@ -415,9 +415,9 @@ struct llama_vision_processor_uhd : llama_vision_processor { return std::max(static_cast<int>(std::round(static_cast<float>(length) / patch_size) * patch_size), patch_size); } - std::pair<int, int> find_best_resize(std::pair<int, int> original_size, int scale_resolution, int patch_size, bool allow_upscale = false) { - int width = original_size.first; - int height = original_size.second; + img_size find_best_resize(const img_size & original_size, int scale_resolution, int patch_size, bool allow_upscale = false) { + int width = original_size.width; + int height = original_size.height; if ((width * height > scale_resolution * scale_resolution) || allow_upscale) { float r = static_cast<float>(width) / height; height = static_cast<int>(scale_resolution / std::sqrt(r)); @@ -425,14 +425,14 @@ struct llama_vision_processor_uhd : llama_vision_processor { } int best_width = ensure_divide(width, patch_size); int best_height = ensure_divide(height, patch_size); - return std::make_pair(best_width, best_height); + return img_size(best_width, best_height); } - std::pair<int, int> get_refine_size(std::pair<int, int> original_size, std::pair<int, int> grid, int scale_resolution, int patch_size, bool allow_upscale = false) { - int width, height; - std::tie(width, height) = original_size; - int grid_x, grid_y; - std::tie(grid_x, grid_y) = grid; + img_size get_refine_size(const img_size & original_size, const img_size & grid, int scale_resolution, int patch_size, bool allow_upscale = false) { + int width = original_size.width; + int height = original_size.height; + int grid_x = grid.width; + int grid_y = grid.height; int refine_width = ensure_divide(width, grid_x); int refine_height = ensure_divide(height, grid_y); @@ -441,16 +441,14 @@ struct llama_vision_processor_uhd : llama_vision_processor { int grid_height = refine_height / grid_y; // auto best_grid_size = find_best_resize(std::make_tuple(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); (old line) - auto best_grid_size = find_best_resize(std::make_pair(grid_width, grid_height), scale_resolution,
patch_size, allow_upscale); // (new line) => fixes conversion for make_tuple to make_pair - int best_grid_width, best_grid_height; - std::tie(best_grid_width, best_grid_height) = best_grid_size; + auto best_grid = find_best_resize({grid_width, grid_height}, scale_resolution, patch_size, allow_upscale); // (new line) => fixes conversion for make_tuple to make_pair - // std::pair<int, int> refine_size = std::make_tuple(best_grid_width * grid_x, best_grid_height * grid_y); (old line) - std::pair<int, int> refine_size = std::make_pair(best_grid_width * grid_x, best_grid_height * grid_y); // (new line) + // img_size refine_size = std::make_tuple(best_grid_width * grid_x, best_grid_height * grid_y); (old line) + img_size refine_size = img_size(best_grid.width * grid_x, best_grid.height * grid_y); // (new line) return refine_size; } - std::pair<int, int> find_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) { + img_size find_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) { std::vector<int> candidate_split_grids_nums; for (int i : {multiple - 1, multiple, multiple + 1}) { if (i == 1 || i > max_slice_nums) { @@ -459,7 +457,7 @@ struct llama_vision_processor_uhd : llama_vision_processor { candidate_split_grids_nums.push_back(i); } - std::vector<std::pair<int, int>> candidate_grids; + std::vector<img_size> candidate_grids; for (int split_grids_nums : candidate_split_grids_nums) { int m = 1; while (m <= split_grids_nums) { @@ -470,10 +468,10 @@ struct llama_vision_processor_uhd : llama_vision_processor { } } - std::pair<int, int> best_grid{1, 1}; + img_size best_grid = img_size(1, 1); float min_error = std::numeric_limits<float>::infinity(); for (const auto& grid : candidate_grids) { - float error = std::abs(log_ratio - std::log(1.0 * grid.first / grid.second)); + float error = std::abs(log_ratio - std::log(1.0 * grid.width / grid.height)); if (error < min_error) { best_grid = grid; min_error = error; @@ -487,7 +485,7 @@ struct llama_vision_processor_uhd : llama_vision_processor { const int max_slice_nums = 9, const int scale_resolution = 448, const int patch_size = 14) { - const std::pair<int, int> original_size={img.nx,img.ny}; + const img_size original_size = img_size(img.nx, img.ny); const int original_width = img.nx; const int original_height = img.ny; const float log_ratio = log(1.0*original_width/original_height); @@ -501,34 +499,36 @@ struct llama_vision_processor_uhd : llama_vision_processor { if (multiple <= 1) { auto best_size = find_best_resize(original_size, scale_resolution, patch_size, true); llama_image_u8 source_image; - bicubic_resize(img, source_image, best_size.first, best_size.second); + bicubic_resize(img, source_image, best_size.width, best_size.height); // source_image = image.resize(best_size, Image.Resampling.BICUBIC) images.back().push_back(source_image); } else if (multiple > 1) { auto best_size = find_best_resize(original_size, scale_resolution, patch_size); llama_image_u8 source_image; - bicubic_resize(img, source_image, best_size.first, best_size.second); + bicubic_resize(img, source_image, best_size.width, best_size.height); // source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC) - LLAMA_LOG_DEBUG("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img.nx, img.ny, best_size.first, best_size.second); + LLAMA_LOG_DEBUG("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img.nx, img.ny, best_size.width, best_size.height); images.back().push_back(source_image); - std::pair<int, int> best_grid = find_best_grid(max_slice_nums, multiple, log_ratio); - LLAMA_LOG_DEBUG("%s:
image_size: %d %d; best_grid: %d %d\n", __func__, img.nx, img.ny, best_grid.first, best_grid.second); + img_size best_grid = find_best_grid(max_slice_nums, multiple, log_ratio); + LLAMA_LOG_DEBUG("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img.nx, img.ny, best_grid.width, best_grid.height); auto refine_size = get_refine_size(original_size, best_grid, scale_resolution, patch_size, true); llama_image_u8 refine_image; - bicubic_resize(img, refine_image, refine_size.first, refine_size.second); + // TODO: so far, we spend most of the time in bicubic_resize, we should optimize it + bicubic_resize(img, refine_image, refine_size.width, refine_size.height); - LLAMA_LOG_DEBUG("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image.nx, refine_image.ny, refine_size.first, refine_size.second); + LLAMA_LOG_DEBUG("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image.nx, refine_image.ny, refine_size.width, refine_size.height); // split_to_patches int width = refine_image.nx; int height = refine_image.ny; - int grid_x = int(width / best_grid.first); - int grid_y = int(height / best_grid.second); - for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.second; patches_i += grid_y, ic += 1){ + int grid_x = int(width / best_grid.width); + int grid_y = int(height / best_grid.height); + for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.height; patches_i += grid_y, ic += 1){ + std::vector<llama_image_u8> patches_out; images.push_back(std::vector<llama_image_u8>()); - for(int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.first; patches_j += grid_x, jc += 1){ + for (int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.width; patches_j += grid_x, jc += 1) { llama_image_u8 patch; patch.nx = grid_x; patch.ny = grid_y; @@ -542,8 +542,9 @@ struct llama_vision_processor_uhd : llama_vision_processor { patch.buf[j+2] = refine_image.buf[i+2]; } } - images.back().push_back(patch); + patches_out.push_back(std::move(patch)); } + images.push_back(std::move(patches_out)); } } return images; @@ -551,7 +552,6 @@ struct llama_vision_processor_uhd : llama_vision_processor { virtual llama_vision_tokens tokenize(const llama_image_u8 & img) override { auto & params = ctx.model->hparams; - GGML_ASSERT(params.arch == LLM_ARCH_VISION_MINICPMV); std::vector<std::vector<llama_image_u8>> imgs = slice_image(img); @@ -573,6 +573,10 @@ struct llama_vision_processor_uhd : llama_vision_processor { } }; +// +// cgraph builder +// + // TODO: move this to llm_build_context in llama.cpp struct llama_vision_graph_builder { llama_vision_context & ctx; @@ -590,6 +594,7 @@ struct llama_vision_graph_builder { int img_h; bool use_gelu; int n_layers; + int rs_n_embd; vision_projector_type proj_type; llama_vision_graph_builder(llama_vision_context & ctx, const llama_vision_tokens & inp) : ctx(ctx), model(*ctx.model) { @@ -950,7 +955,7 @@ static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_ GGML_ASSERT(batch_size == 1); // TODO: support multiple images } - img_size image_size{(int)hparams.image_size, (int)hparams.image_size}; + img_size image_size = img_size((int)hparams.image_size, (int)hparams.image_size); const int patch_size = hparams.patch_size; const int num_patches = ((image_size.width / patch_size) * (image_size.height / patch_size)); const int num_positions = num_patches + (model.class_embedding ?
1 : 0); @@ -1016,23 +1021,25 @@ static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_ // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "inp_pos"); - std::vector<int> pos_buf(ggml_nelements(positions)); - GGML_ASSERT(num_positions == (int)pos_buf.size()); + std::vector<int> buf(ggml_nelements(positions)); + GGML_ASSERT(num_positions == (int)buf.size()); int bucket_coords_h[70]; int bucket_coords_w[70]; - for (size_t i = 0; i < inp.n_py; i++) { - bucket_coords_h[i] = std::floor(70.0*i/inp.n_py); + size_t h = inp.py; + size_t w = inp.px; + for (size_t i = 0; i < h; i++) { + bucket_coords_h[i] = std::floor(70.0*i/h); } - for (size_t i = 0; i < inp.n_px; i++) { - bucket_coords_w[i] = std::floor(70.0*i/inp.n_px); + for (size_t i = 0; i < w; i++) { + bucket_coords_w[i] = std::floor(70.0*i/w); } - for (size_t i = 0, id = 0; i < inp.n_py; i++){ - for (size_t j = 0; j < inp.n_px; j++){ - pos_buf[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j]; + for (size_t i = 0, id = 0; i < h; i++){ + for (size_t j = 0; j < w; j++){ + buf[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j]; } } - ggml_backend_tensor_set(positions, pos_buf.data(), 0, ggml_nbytes(positions)); + ggml_backend_tensor_set(positions, buf.data(), 0, ggml_nbytes(positions)); } else { struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "inp_pos"); @@ -1055,6 +1062,7 @@ static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_ } // compute + LLAMA_LOG_DEBUG("%s: compute start\n", __func__); int64_t t_start = ggml_time_ms(); ggml_backend_sched_graph_compute(ctx.sched, gf); @@ -1106,7 +1114,6 @@ struct llama_vision_tokens * llama_vision_tokenize( case LLM_ARCH_VISION_IDEFICS3: return new llama_vision_tokens(llama_vision_processor_llava(vctx).tokenize(*bmp)); case LLM_ARCH_VISION_MINICPMV: - //return new llama_vision_tokens(llama_vision_processor_uhd(vctx).tokenize(*bmp)); return new llama_vision_tokens(llama_vision_processor_llava(vctx).tokenize(*bmp)); default: GGML_ASSERT(false && "unsupported arch");
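Note on the precomputed resampler position embeddings: `generate_extra_tensors` in the MiniCPM-V converter above exports `self._get_2d_sincos_pos_embed(self.resampler_n_embd, (70, 70))` as the `V_RESMPL_POS_EMBD_K` tensor, so the fixed 70x70 table does not have to be rebuilt at load time. `_get_2d_sincos_pos_embed` itself is defined elsewhere in convert_hf_to_gguf.py and is not part of this diff; the snippet below is only a minimal sketch of the usual MAE-style sin-cos formulation it is assumed to follow, with illustrative names.

```python
import numpy as np

def sincos_pos_embed_1d(dim: int, pos: np.ndarray) -> np.ndarray:
    # dim must be even; pos is a flat array of positions
    omega = np.arange(dim // 2, dtype=np.float32) / (dim / 2)
    omega = 1.0 / 10000 ** omega                                # (dim/2,)
    out = np.outer(pos, omega)                                  # (npos, dim/2)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (npos, dim)

def sincos_pos_embed_2d(dim: int, grid_hw: tuple[int, int]) -> np.ndarray:
    # returns (h*w, dim); the first half of dim encodes the row, the second half the column
    h, w = grid_hw
    grid_w, grid_h = np.meshgrid(np.arange(w, dtype=np.float32),
                                 np.arange(h, dtype=np.float32))
    emb_h = sincos_pos_embed_1d(dim // 2, grid_h.reshape(-1))
    emb_w = sincos_pos_embed_1d(dim // 2, grid_w.reshape(-1))
    return np.concatenate([emb_h, emb_w], axis=1)

# e.g. pos = sincos_pos_embed_2d(resampler_n_embd, (70, 70))  -> shape (4900, resampler_n_embd)
```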
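For reference, the resize and grid selection done by `llama_vision_processor_uhd` (`find_best_resize`, `find_best_grid`, `get_refine_size` in the diff) boils down to: scale the image so its area is roughly `scale_resolution^2`, snap both sides to multiples of the patch size, then pick the slice grid whose aspect ratio is closest to the original image. The Python below is an illustrative transcription of that math, not part of the patch; the helper names exist only for this sketch.

```python
import math

def ensure_divide(length: int, patch_size: int) -> int:
    return max(int(round(length / patch_size)) * patch_size, patch_size)

def find_best_resize(w: int, h: int, scale_resolution: int = 448,
                     patch_size: int = 14, allow_upscale: bool = False) -> tuple[int, int]:
    # shrink (or optionally grow) so the area is about scale_resolution^2,
    # then snap both sides to multiples of the ViT patch size
    if w * h > scale_resolution * scale_resolution or allow_upscale:
        r = w / h
        h = int(scale_resolution / math.sqrt(r))
        w = int(h * r)
    return ensure_divide(w, patch_size), ensure_divide(h, patch_size)

def find_best_grid(max_slice_nums: int, multiple: int, log_ratio: float) -> tuple[int, int]:
    # pick the (cols, rows) split whose aspect ratio is closest to the original image
    candidates = [i for i in (multiple - 1, multiple, multiple + 1)
                  if i != 1 and i <= max_slice_nums]
    grids = [(m, n // m) for n in candidates for m in range(1, n + 1) if n % m == 0]
    return min(grids, key=lambda g: abs(log_ratio - math.log(g[0] / g[1])), default=(1, 1))

# e.g. for a 1920x1080 image with multiple = 4:
# find_best_resize(1920, 1080) caps the area near 448**2 and snaps both sides to multiples of 14
# find_best_grid(9, 4, math.log(1920 / 1080)) -> (3, 1)
```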