diff --git a/README.md b/README.md index 5f7933c13..4577862e7 100644 --- a/README.md +++ b/README.md @@ -88,7 +88,7 @@ Typically finetunes of the base models below are supported as well. - [x] [Bitnet b1.58 models](https://huggingface.co/1bitLLM) - [x] [Flan T5](https://huggingface.co/models?search=flan-t5) - [x] [Open Elm models](https://huggingface.co/collections/apple/openelm-instruct-models-6619ad295d7ae9f868b759ca) -- [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b) +- [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b) + [GLMEdge-1.5b](https://huggingface.co/THUDM/glm-edge-1.5b-chat) + [GLMEdge-4b](https://huggingface.co/THUDM/glm-edge-4b-chat) - [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966) - [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct) - [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a) @@ -109,6 +109,7 @@ Typically finetunes of the base models below are supported as well. - [x] [Mini CPM](https://huggingface.co/models?search=MiniCPM) - [x] [Moondream](https://huggingface.co/vikhyatk/moondream2) - [x] [Bunny](https://github.com/BAAI-DCAI/Bunny) +- [x] [GLM-EDGE](https://huggingface.co/models?search=glm-edge) **Bindings:** diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index cc576b70c..c3931aa2f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4011,7 +4011,7 @@ class ChatGLMModel(Model): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused - if name.endswith(".rotary_pos_emb.inv_freq"): + if name.endswith(".rotary_pos_emb.inv_freq") or name.startswith("model.vision."): return [] name = name.removeprefix("transformer.") diff --git a/examples/llava/README-glmedge.md b/examples/llava/README-glmedge.md new file mode 100644 index 000000000..c36203c9e --- /dev/null +++ b/examples/llava/README-glmedge.md @@ -0,0 +1,43 @@ +# GLMV-EDGE + +Currently this implementation supports [glm-edge-v-2b](https://huggingface.co/THUDM/glm-edge-v-2b) and [glm-edge-v-5b](https://huggingface.co/THUDM/glm-edge-v-5b). + +## Usage +Build with cmake or run `make llama-llava-cli` to build it. + +After building, run `./llama-llava-cli` to see the usage. For example: + +```sh +./llama-llava-cli -m model_path/ggml-model-f16.gguf --mmproj model_path/mmproj-model-f16.gguf --image img_path/image.jpg -p "<|system|>\n system prompt <|user|>\n prompt <|assistant|>\n" +``` + +**note**: A lower temperature like 0.1 is recommended for better quality. Add `--temp 0.1` to the command to do so. +**note**: For GPU offloading, make sure to use the `-ngl` flag as usual. + +## GGUF conversion + +1. Clone a GLMV-EDGE model ([2B](https://huggingface.co/THUDM/glm-edge-v-2b) or [5B](https://huggingface.co/THUDM/glm-edge-v-5b)). For example: + +```sh +git clone https://huggingface.co/THUDM/glm-edge-v-5b  # or https://huggingface.co/THUDM/glm-edge-v-2b +``` + +2. Use `glmedge-surgery.py` to split the GLMV-EDGE model into its LLM and multimodal projector constituents: + +```sh +python ./examples/llava/glmedge-surgery.py -m ../model_path +``` + +4. 
Use `glmedge-convert-image-encoder-to-gguf.py` to convert the GLMV-EDGE image encoder to GGUF: + +```sh +python ./examples/llava/glmedge-convert-image-encoder-to-gguf.py -m ../model_path --llava-projector ../model_path/glm.projector --output-dir ../model_path +``` + +5. Use `examples/convert_hf_to_gguf.py` to convert the LLM part of GLMV-EDGE to GGUF: + +```sh +python convert_hf_to_gguf.py ../model_path +``` + +Now both the LLM part and the image encoder are in the `model_path` directory. \ No newline at end of file diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index aae49c965..b43aa3cc2 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -90,6 +90,7 @@ static std::string format(const char * fmt, ...) { #define KEY_HAS_VIS_ENC "clip.has_vision_encoder" #define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector" #define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector" +#define KEY_HAS_GLM_PROJ "clip.has_glm_projector" #define KEY_MINICPMV_VERSION "clip.minicpmv_version" #define KEY_USE_GELU "clip.use_gelu" #define KEY_N_EMBD "clip.%s.embedding_length" @@ -145,6 +146,15 @@ static std::string format(const char * fmt, ...) { #define TN_MINICPMV_ATTN "resampler.attn.%s.%s" #define TN_MINICPMV_LN "resampler.ln_%s.%s" +#define TN_GLM_ADAPER_CONV "adapter.conv.%s" +#define TN_GLM_ADAPTER_LINEAR "adapter.linear.linear.%s" +#define TN_GLM_ADAPTER_NORM_1 "adapter.linear.norm1.%s" +#define TN_GLM_ADAPTER_D_H_2_4H "adapter.linear.dense_h_to_4h.%s" +#define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s" +#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s" +#define TN_GLM_BOI_W "adapter.boi" +#define TN_GLM_EOI_W "adapter.eoi" + enum projector_type { PROJECTOR_TYPE_MLP, @@ -152,6 +162,7 @@ enum projector_type { PROJECTOR_TYPE_LDP, PROJECTOR_TYPE_LDPV2, PROJECTOR_TYPE_RESAMPLER, + PROJECTOR_TYPE_ADAPTER, PROJECTOR_TYPE_UNKNOWN, }; @@ -160,6 +171,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LDP, "ldp" }, { PROJECTOR_TYPE_LDPV2, "ldpv2"}, { PROJECTOR_TYPE_RESAMPLER, "resampler"}, + { PROJECTOR_TYPE_ADAPTER, "adapter"} }; @@ -482,6 +494,12 @@ struct clip_vision_model { struct ggml_tensor * mm_4_w = NULL; struct ggml_tensor * mm_4_b = NULL; + //GLMV-Edge projection + struct ggml_tensor * mm_model_adapter_conv_w; + struct ggml_tensor * mm_model_adapter_conv_b; + struct ggml_tensor * boi_w; + struct ggml_tensor * eoi_w; + // MobileVLM projection struct ggml_tensor * mm_model_mlp_1_w; struct ggml_tensor * mm_model_mlp_1_b; @@ -542,6 +560,7 @@ struct clip_ctx { bool has_vision_encoder = false; bool has_llava_projector = false; bool has_minicpmv_projector = false; + bool has_glm_projector = false; int minicpmv_version = 2; struct clip_vision_model vision_model; @@ -606,7 +625,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 const int batch_size = imgs->size; - if (ctx->has_llava_projector || ctx->has_minicpmv_projector) { + if (ctx->has_llava_projector || ctx->has_minicpmv_projector || ctx->has_glm_projector) { GGML_ASSERT(batch_size == 1); } @@ -677,7 +696,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 } // loop over layers - if (ctx->has_minicpmv_projector) { + if (ctx->has_minicpmv_projector || ctx->has_glm_projector) { n_layer += 1; } for (int il = 0; il < n_layer - 1; il++) { @@ -1019,6 +1038,33 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 GGML_ASSERT(false); } } + // glm projector + else if(ctx->has_glm_projector){ + if 
(ctx->proj_type == PROJECTOR_TYPE_ADAPTER){ + size_t gridsz = (size_t)sqrt(embeddings->ne[1]); + embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3)); + embeddings = ggml_reshape_3d(ctx0,embeddings,gridsz,gridsz,embeddings->ne[1]); + embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1); + embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size); + embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3)); + embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b); + //GLU + { + embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings); + embeddings = ggml_norm(ctx0, embeddings, eps); + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b); + embeddings = ggml_gelu_inplace(ctx0, embeddings); + struct ggml_tensor * x = embeddings; + embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings); + x = ggml_mul_mat(ctx0,model.mm_model_mlp_1_w,x); + embeddings = ggml_silu_inplace(ctx0,embeddings); + embeddings = ggml_mul(ctx0,embeddings,x); + embeddings = ggml_mul_mat(ctx0,model.mm_model_mlp_3_w,embeddings); + } + }else{ + GGML_ABORT("fatal error"); + } + } // build the graph ggml_build_forward_expand(gf, embeddings); @@ -1190,6 +1236,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx); } + idx = gguf_find_key(ctx, KEY_HAS_GLM_PROJ); + if (idx != -1) { + new_clip->has_glm_projector = gguf_get_val_bool(ctx, idx); + } + // GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search GGML_ASSERT(new_clip->has_vision_encoder); @@ -1203,6 +1254,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { LOG_INF("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder); LOG_INF("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector); LOG_INF("%s: minicpmv_projector: %d\n", __func__, new_clip->has_minicpmv_projector); + LOG_INF("%s: glm_projector: %d\n", __func__, new_clip->has_glm_projector); LOG_INF("%s: model size: %.2f MB\n", __func__, model_size / 1024.0 / 1024.0); LOG_INF("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0); } @@ -1465,6 +1517,19 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { vision_model.mm_model_ln_post_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "weight")); vision_model.mm_model_ln_post_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "bias")); } + else if(new_clip->proj_type == PROJECTOR_TYPE_ADAPTER){ + printf("adapter get data\n"); + vision_model.mm_model_adapter_conv_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPER_CONV, "weight")); + vision_model.mm_model_adapter_conv_b = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPER_CONV, "bias")); + vision_model.mm_model_mlp_0_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_LINEAR,"weight")); + vision_model.mm_model_ln_q_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_NORM_1,"weight")); + vision_model.mm_model_ln_q_b = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_NORM_1,"bias")); + vision_model.mm_model_mlp_1_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_D_H_2_4H,"weight")); + vision_model.mm_model_mlp_2_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_GATE,"weight")); + 
vision_model.mm_model_mlp_3_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_D_4H_2_H,"weight")); + vision_model.boi_w = get_tensor(new_clip->ctx_data, TN_GLM_BOI_W); + vision_model.eoi_w = get_tensor(new_clip->ctx_data, TN_GLM_EOI_W); + } else { std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type]; throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str())); @@ -1969,6 +2034,20 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli return true; } + if(ctx->has_glm_projector){ + res_imgs->size = 1; + res_imgs->data = new clip_image_f32[res_imgs->size]; + clip_image_u8 resized_image; + int32_t sz=ctx->vision_model.hparams.image_size; + bicubic_resize(*img, resized_image,sz,sz); + clip_image_f32 * res = clip_image_f32_init(); + //clip_image_save_to_bmp(resized_image, "resized.bmp"); + normalize_image_u8_to_f32(&resized_image, res, ctx->image_mean, ctx->image_std); + res_imgs->data[0] = *res; + clip_image_f32_free(res); + return true; + } + bool pad_to_square = true; if (!ctx->has_vision_encoder) { LOG_ERR("This gguf file seems to have no vision encoder\n"); @@ -2154,6 +2233,8 @@ void clip_free(clip_ctx * ctx) { } size_t clip_embd_nbytes(const struct clip_ctx * ctx) { + if(ctx->has_glm_projector) + return (clip_n_patches(ctx)+2) * clip_n_mmproj_embd(ctx) * sizeof(float); return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float); } @@ -2182,7 +2263,7 @@ int clip_n_patches(const struct clip_ctx * ctx) { int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size); - if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) { + if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_ADAPTER) { n_patches /= 4; } else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { if (ctx->minicpmv_version == 2) { @@ -2307,6 +2388,12 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima if (ctx->has_minicpmv_projector) { GGML_ASSERT(batch_size == 1); } + if(ctx->has_glm_projector){ + GGML_ASSERT(batch_size == 1); + ggml_tensor * boi = ctx->vision_model.boi_w; + ggml_backend_tensor_get(boi,vec,0,ggml_nbytes(boi)); + vec=(float*)(vec+ggml_nelements(boi)); //offset for boi + } // build the inference graph ggml_cgraph * gf = clip_image_build_graph(ctx, imgs, ctx->load_image_size, true); @@ -2430,7 +2517,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima free(positions_data); } - { + if (!ctx->has_glm_projector){ struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); int* patches_data = (int*)malloc(ggml_nbytes(patches)); for (int i = 0; i < num_patches; i++) { @@ -2453,6 +2540,13 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // copy the embeddings to the location passed by the user ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings)); + if(ctx->has_glm_projector){ + //eoi + ggml_tensor * eoi = ctx->vision_model.eoi_w; + int offset=ggml_nelements(eoi)*clip_n_patches(ctx); + ggml_backend_tensor_get(eoi,vec+offset,0,ggml_nbytes(eoi)); + } + return true; } @@ -2610,6 +2704,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return 3584; } } + if (ctx->proj_type == PROJECTOR_TYPE_ADAPTER){ + return ctx->vision_model.mm_model_mlp_3_w->ne[1]; + } std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type]; throw std::runtime_error(format("%s: 
don't support projector with: %s currently\n", __func__, proj_type.c_str())); @@ -2621,3 +2718,7 @@ int clip_is_minicpmv(const struct clip_ctx * ctx) { } return 0; } + +bool clip_is_glm(const struct clip_ctx * ctx) { + return ctx->has_glm_projector; +} \ No newline at end of file diff --git a/examples/llava/clip.h b/examples/llava/clip.h index 78588bdf1..c92cb5dee 100644 --- a/examples/llava/clip.h +++ b/examples/llava/clip.h @@ -87,6 +87,8 @@ CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx); +CLIP_API bool clip_is_glm(const struct clip_ctx * ctx); + #ifdef __cplusplus } #endif diff --git a/examples/llava/glmedge-convert-image-encoder-to-gguf.py b/examples/llava/glmedge-convert-image-encoder-to-gguf.py new file mode 100644 index 000000000..ac1911a14 --- /dev/null +++ b/examples/llava/glmedge-convert-image-encoder-to-gguf.py @@ -0,0 +1,280 @@ +import argparse +import os +import json +import re + +import torch +import numpy as np +from gguf import * + +TEXT = "clip.text" +VISION = "clip.vision" +from transformers import SiglipVisionModel, SiglipVisionConfig + +def k(raw_key: str, arch: str) -> str: + return raw_key.format(arch=arch) + + +def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: bool) -> bool: + if name in ( + "logit_scale", + "text_model.embeddings.position_ids", + "vision_model.embeddings.position_ids", + ): + return True + + if name in ( + "vision_model.head.probe", + "vision_model.head.attention.in_proj_weight", + "vision_model.head.attention.in_proj_bias", + "vision_model.head.attention.out_proj.weight", + "vision_model.head.attention.out_proj.bias", + "vision_model.head.layernorm.weight", + "vision_model.head.layernorm.bias", + "vision_model.head.mlp.fc1.weight", + "vision_model.head.mlp.fc1.bias", + "vision_model.head.mlp.fc2.weight", + "vision_model.head.mlp.fc2.bias" + ): + return True + + if name.startswith("v") and not has_vision: + return True + + if name.startswith("t") and not has_text: + return True + + return False + + +def get_tensor_name(name: str) -> str: + if "projection" in name: + return name + if "mm_projector" in name: + name = name.replace("model.mm_projector", "mm") + name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1) + name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1) + return name + + return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln") + + +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. 
+ """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +ap = argparse.ArgumentParser() +ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True) +ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16") +ap.add_argument("--text-only", action="store_true", required=False, + help="Save a text-only model. It can't be used to encode images") +ap.add_argument("--vision-only", action="store_true", required=False, + help="Save a vision-only model. It can't be used to encode texts") +ap.add_argument("--clip-model-is-vision", action="store_true", required=False, + help="The clip model is a pure vision model (ShareGPT4V vision extract for example)") +ap.add_argument("--clip-model-is-openclip", action="store_true", required=False, + help="The clip model is from openclip (for ViT-SO400M type))") +ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.") +ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2","adapter"], default="adapter") +ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None) +# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711 +# Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5 +default_image_mean = [0.5, 0.5, 0.5] +default_image_std = [0.5, 0.5, 0.5] +ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None) +ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None) + +# with proper +args = ap.parse_args() + + +if args.text_only and args.vision_only: + print("--text-only and --image-only arguments cannot be specified at the same time.") + exit(1) + +if args.use_f32: + print("WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.") + +# output in the same directory as the model if output_dir is None +dir_model = args.model_dir + +if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip: + vocab = None + tokens = None +else: + with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f: + vocab = json.load(f) + tokens = [key for key in vocab] + +with open(dir_model + "/config.json", "r", encoding="utf-8") as f: + config = json.load(f) + if args.clip_model_is_vision: + v_hparams = config + t_hparams = None + else: + v_hparams = config["vision_config"] + t_hparams = None + +# possible data types +# ftype == 0 -> float32 +# ftype == 1 -> float16 +# +# map from ftype to string +ftype_str = ["f32", "f16"] + +ftype = 1 +if args.use_f32: + ftype = 0 + +vision_config = SiglipVisionConfig(**v_hparams) +model = SiglipVisionModel(vision_config) +model.load_state_dict(torch.load(os.path.join(dir_model, "glm.clip"))) + +fname_middle = None +has_text_encoder = False +has_vision_encoder = True +has_glm_projector = True +if args.text_only: + fname_middle = "text-" + has_vision_encoder = False +elif 
args.llava_projector is not None: + fname_middle = "mmproj-" + has_text_encoder = False + has_glm_projector = True +elif args.vision_only: + fname_middle = "vision-" + has_text_encoder = False +else: + fname_middle = "" + +output_dir = args.output_dir if args.output_dir is not None else dir_model +os.makedirs(output_dir, exist_ok=True) +output_prefix = os.path.basename(output_dir).replace("ggml_", "") +fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf") +fout = GGUFWriter(path=fname_out, arch="clip") + +fout.add_bool("clip.has_text_encoder", has_text_encoder) +fout.add_bool("clip.has_vision_encoder", has_vision_encoder) +fout.add_bool("clip.has_glm_projector", has_glm_projector) +fout.add_file_type(ftype) +model_name = config["_name_or_path"] if "_name_or_path" in config else os.path.basename(dir_model) +fout.add_name(model_name) +if has_glm_projector: + fout.add_description("image encoder for glm4v") + fout.add_string("clip.projector_type", "adapter") +else: + fout.add_description("two-tower CLIP model") + +if has_text_encoder: + assert t_hparams is not None + assert tokens is not None + # text_model hparams + fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"]) + fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"]) + fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, TEXT), t_hparams["intermediate_size"]) + fout.add_uint32("clip.text.projection_dim", t_hparams.get("projection_dim", config["projection_dim"])) + fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, TEXT), t_hparams["num_attention_heads"]) + fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, TEXT), t_hparams["layer_norm_eps"]) + fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"]) + fout.add_token_list(tokens) + +if has_vision_encoder: + # vision_model hparams + fout.add_uint32("clip.vision.image_size", v_hparams["image_size"]) + fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"]) + fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"]) + fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), v_hparams["intermediate_size"]) + fout.add_uint32("clip.vision.projection_dim", 0) + fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"]) + fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6) + fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), v_hparams["num_hidden_layers"]) + + image_mean = args.image_mean if args.image_mean is not None else default_image_mean + image_std = args.image_std if args.image_std is not None else default_image_std + fout.add_array("clip.vision.image_mean", image_mean) + fout.add_array("clip.vision.image_std", image_std) + +fout.add_bool("clip.use_gelu", True) + + +if has_glm_projector: + # model.vision_model.encoder.layers.pop(-1) # pyright: ignore[reportAttributeAccessIssue] + projector = torch.load(args.llava_projector) + for name, data in projector.items(): + name = get_tensor_name(name) + # pw and dw conv ndim==4 + if data.ndim == 2 or data.ndim == 4: + data = data.squeeze().numpy().astype(np.float16) + else: + data = data.squeeze().numpy().astype(np.float32) + if name.startswith("vision."): + name=name.replace("vision.","") + fout.add_tensor(name, data) + print(f"Projector {name} - {data.dtype} - shape = {data.shape}") + # print(f"Projector {name} tensors added\n") + +state_dict = model.state_dict() # pyright: ignore[reportAttributeAccessIssue] +for name, data in state_dict.items(): + if should_skip_tensor(name, has_text_encoder, has_vision_encoder, 
has_glm_projector): + # we don't need this + print(f"skipping parameter: {name}") + continue + + name = get_tensor_name(name) + data = data.squeeze().numpy() + + n_dims = len(data.shape) + + # ftype == 0 -> float32, ftype == 1 -> float16 + ftype_cur = 0 + if n_dims == 4: + print(f"tensor {name} is always saved in f16") + data = data.astype(np.float16) + ftype_cur = 1 + elif ftype == 1: + if name[-7:] == ".weight" and n_dims == 2: + # print(" Converting to float16") + data = data.astype(np.float16) + ftype_cur = 1 + else: + # print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + else: + if data.dtype != np.float32: + # print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + print(f"siglip {name} - {data.dtype} - shape = {data.shape}") + # print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}") + fout.add_tensor(name, data) + + +fout.write_header_to_file() +fout.write_kv_data_to_file() +fout.write_tensors_to_file() +fout.close() + +print("Done. Output file: " + fname_out) \ No newline at end of file diff --git a/examples/llava/glmedge-surgery.py b/examples/llava/glmedge-surgery.py new file mode 100644 index 000000000..7d7dc6837 --- /dev/null +++ b/examples/llava/glmedge-surgery.py @@ -0,0 +1,33 @@ +import argparse +import os +import torch +from transformers import AutoModel, AutoTokenizer + +ap = argparse.ArgumentParser() +ap.add_argument("-m", "--model", help="Path to GLM model") +args = ap.parse_args() + +# find the model part that includes the the multimodal projector weights +model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True) +checkpoint = model.state_dict() + +# get a list of mm tensor names +mm_tensors = [k for k, v in checkpoint.items() if k.startswith("vision.adapter.")] + +# store these tensors in a new dictionary and torch.save them +projector = {name: checkpoint[name].float() for name in mm_tensors} +torch.save(projector, f"{args.model}/glm.projector") + +clip_tensors = [k for k, v in checkpoint.items() if k.startswith("vision.vit.model.vision_model.")] +if len(clip_tensors) > 0: + clip = {name.replace("vision.vit.model.", ""): checkpoint[name].float() for name in clip_tensors} + torch.save(clip, f"{args.model}/glm.clip") + + # added tokens should be removed to be able to convert Mistral models + if os.path.exists(f"{args.model}/added_tokens.json"): + with open(f"{args.model}/added_tokens.json", "w") as f: + f.write("{}\n") + +print("Done!") +print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.") +print(f"Also, use {args.model}glm.projector to prepare a glm-encoder.gguf file.") diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index be6988540..59354aafa 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -299,6 +299,20 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli clip_add_load_image_size(ctx_clip, load_image_size); LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height); } + else if (clip_is_glm(ctx_clip)){ + struct clip_image_size * load_image_size = clip_image_size_init(); + load_image_size->width = img_res_v.data[0].nx; + load_image_size->height = img_res_v.data[0].ny; + clip_add_load_image_size(ctx_clip, load_image_size); + + bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd); + int pos = int(load_image_size->width/clip_patch_size(ctx_clip)/2); + *n_img_pos = (pos * pos + 2); + if (!encoded) { + 
LOG_ERR("Unable to encode image \n"); + return false; + } + } else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) { // flat / default llava-1.5 type embedding *n_img_pos = clip_n_patches(ctx_clip); @@ -383,6 +397,9 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co if (clip_is_minicpmv(ctx_clip)) { num_max_patches = 10; } + if (clip_is_glm(ctx_clip)) { + num_max_patches = 1; + } float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*num_max_patches); // TODO: base on gridsize/llava model if (!image_embd) { LOG_ERR("Unable to allocate memory for image embeddings\n");