diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index a6751cc80..c5cb8bbec 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -80,7 +80,7 @@ class Model: if not self.is_safetensors: self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin") self.hparams = Model.load_hparams(self.dir_model) - self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"]) + self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) self.tensor_names = None if self.ftype == gguf.LlamaFileType.GUESSED: @@ -483,6 +483,9 @@ class Model: if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a": # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code res = "jina-v2-code" + if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b": + # ref: https://huggingface.co/THUDM/glm-4-9b-chat + res = "chatglm-bpe" if res is None: logger.warning("\n") @@ -2725,6 +2728,188 @@ class DeepseekV2Model(Model): raise ValueError(f"Unprocessed experts: {experts}") +@Model.register("ChatGLMModel") +class ChatGLMModel(Model): + model_arch = gguf.MODEL_ARCH.CHATGLM + + def set_vocab_chatglm3(self): + dir_model = self.dir_model + hparams = self.hparams + tokens: list[bytearray] = [] + toktypes: list[int] = [] + scores: list[float] = [] + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) + vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab())) + assert max(tokenizer.get_vocab().values()) < vocab_size + role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"] + special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens + print(vocab_size) + print(max(tokenizer.get_vocab().values())) + for token_id in range(vocab_size): + piece = tokenizer._convert_id_to_token(token_id) + if token_id == 0: + piece = "" + elif token_id == 1: + piece = "" + elif token_id == 2: + piece = "" + + text = piece.encode("utf-8") + score = 0.0 + # Referencing the tokenizer Python implementation(https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py), + # it is only valid if it is less than tokenizer.tokenizer.sp_model.vocab_size() + if len(piece) != 0 and token_id < tokenizer.tokenizer.sp_model.vocab_size(): + score = tokenizer.tokenizer.sp_model.get_score(token_id) + + if len(piece) == 0: + text = f"[PAD{token_id}]".encode("utf-8") + + if token_id >= tokenizer.tokenizer.sp_model.vocab_size(): + if piece in special_tokens: + # show special tokens in prompt + toktype = SentencePieceTokenTypes.USER_DEFINED + else: + toktype = SentencePieceTokenTypes.UNKNOWN + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + continue + + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.tokenizer.sp_model.is_unknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.tokenizer.sp_model.is_control(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.tokenizer.sp_model.is_unused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.tokenizer.sp_model.is_byte(token_id): + toktype = SentencePieceTokenTypes.BYTE + + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + @staticmethod + def token_bytes_to_string(b): + from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode + byte_encoder = bytes_to_unicode() + return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')]) + + @staticmethod + def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]: + parts = [bytes([b]) for b in token] + while True: + min_idx = None + min_rank = None + for i, pair in enumerate(zip(parts[:-1], parts[1:])): + rank = mergeable_ranks.get(pair[0] + pair[1]) + if rank is not None and (min_rank is None or rank < min_rank): + min_idx = i + min_rank = rank + if min_rank is None or (max_rank is not None and min_rank >= max_rank): + break + assert min_idx is not None + parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:] + return parts + + def set_vocab(self): + if "THUDM/chatglm3-6b" in self.hparams.get("_name_or_path", ""): + self.set_vocab_chatglm3() + return + + dir_model = self.dir_model + hparams = self.hparams + tokens: list[str] = [] + toktypes: list[int] = [] + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) + vocab_size = hparams["padded_vocab_size"] + assert max(tokenizer.get_vocab().values()) < vocab_size + + tokpre = self.get_vocab_base_pre(tokenizer) + + merges = [] + vocab = {} + mergeable_ranks = tokenizer.mergeable_ranks + for token, rank in mergeable_ranks.items(): + vocab[ChatGLMModel.token_bytes_to_string(token)] = rank + if len(token) == 1: + continue + merged = ChatGLMModel.bpe(mergeable_ranks, token, max_rank=rank) + assert len(merged) >= 2 and len(merged) <= 7 + merges.append(' '.join(map(ChatGLMModel.token_bytes_to_string, merged))) + + # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined + added_vocab = tokenizer.get_added_vocab() + reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()} + + for i in range(vocab_size): + if i not in reverse_vocab: + tokens.append(f"[PAD{i}]") + toktypes.append(gguf.TokenType.USER_DEFINED) + elif reverse_vocab[i] in added_vocab: + tokens.append(reverse_vocab[i]) + if tokenizer.added_tokens_decoder[i].special: + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.USER_DEFINED) + else: + tokens.append(reverse_vocab[i]) + toktypes.append(gguf.TokenType.NORMAL) + + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(dir_model, load_merges=False) + special_vocab.chat_template = "ChatGLM4" + special_vocab.merges = merges + # only add special tokens when they were not already loaded from config.json + # if len(special_vocab.special_token_ids) == 0: + special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) + # this one is usually not in config.json anyway + special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab.add_to_gguf(self.gguf_writer) + + def set_gguf_parameters(self): + self.gguf_writer.add_name(self.dir_model.name) + n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) + n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) + n_head_kv = self.hparams.get("multi_query_group_num", n_head) + self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed)) + self.gguf_writer.add_embedding_length(n_embed) + self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", 4 * n_embed)) + self.gguf_writer.add_block_count(self.hparams["num_layers"]) + self.gguf_writer.add_head_count(n_head) + self.gguf_writer.add_head_count_kv(n_head_kv) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layernorm_epsilon"]) + self.gguf_writer.add_file_type(self.ftype) + self.gguf_writer.add_rope_dimension_count(64) + self.gguf_writer.add_add_bos_token(False) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if name.endswith(".rotary_pos_emb.inv_freq"): + return [] + + name = name.removeprefix("transformer.") + return [(self.map_tensor_name(name), data_torch)] + + ###### CONVERSION LOGIC ###### diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index fb20cfabb..65aa3298d 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -149,6 +149,7 @@ class MODEL_ARCH(IntEnum): OLMO = auto() ARCTIC = auto() DEEPSEEK2 = auto() + CHATGLM = auto() class MODEL_TENSOR(IntEnum): @@ -237,6 +238,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.OLMO: "olmo", MODEL_ARCH.ARCTIC: "arctic", MODEL_ARCH.DEEPSEEK2: "deepseek2", + MODEL_ARCH.CHATGLM: "chatglm", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -808,6 +810,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, ], + MODEL_ARCH.CHATGLM : [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], # TODO } @@ -845,6 +859,9 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_ROT_EMBD, ], + MODEL_ARCH.CHATGLM: [ + MODEL_TENSOR.ROPE_FREQS, + ], } # diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 81b4992a5..2703657e0 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -24,6 +24,7 @@ class TensorNameMap: "backbone.embedding", # mamba "backbone.embeddings", # mamba-hf "transformer.in_out_embed", # Grok + "embedding.word_embeddings", # chatglm ), # Token type embeddings @@ -52,6 +53,7 @@ class TensorNameMap: "output", # llama-pth bloom internlm2 "word_embeddings_for_head", # persimmon "lm_head.linear", # phi2 + "output_layer", # chatglm ), # Output norm @@ -68,11 +70,13 @@ class TensorNameMap: "model.norm_f", # mamba-qbert "backbone.norm_f", # mamba "transformer.rms_norm", # Grok + "encoder.final_layernorm", # chatglm ), # Rope frequencies MODEL_TENSOR.ROPE_FREQS: ( "rope.freqs", # llama-pth + "rotary_pos_emb.inv_freq", # chatglm ), } @@ -97,6 +101,7 @@ class TensorNameMap: "backbone.layers.{bid}.norm", # mamba "transformer.decoder_layer.{bid}.rms_norm", # Grok "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx + "encoder.layers.{bid}.input_layernorm", # chatglm ), # Attention norm 2 @@ -118,7 +123,8 @@ class TensorNameMap: "h.{bid}.attn.c_attn", # gpt2 "transformer.h.{bid}.mixer.Wqkv", # phi2 "encoder.layers.{bid}.attn.Wqkv", # nomic-bert - "model.layers.{bid}.self_attn.qkv_proj" # phi3 + "model.layers.{bid}.self_attn.qkv_proj", # phi3 + "encoder.layers.{bid}.self_attention.query_key_value", # chatglm ), # Attention query @@ -129,7 +135,7 @@ class TensorNameMap: "transformer.h.{bid}.attn.q_proj", # gpt-j "model.layers.layers.{bid}.self_attn.q_proj", # plamo "model.layers.{bid}.attention.wq", # internlm2 - "transformer.decoder_layer.{bid}.multi_head_attention.query" # Grok + "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok ), # Attention key @@ -141,7 +147,7 @@ class TensorNameMap: "transformer.h.{bid}.attn.k", # refact "model.layers.layers.{bid}.self_attn.k_proj", # plamo "model.layers.{bid}.attention.wk", # internlm2 - "transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok + "transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok ), # Attention value @@ -176,6 +182,7 @@ class TensorNameMap: "encoder.layers.{bid}.attn.out_proj", # nomic-bert "transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx + "encoder.layers.{bid}.self_attention.dense", # chatglm ), # Attention output norm @@ -207,6 +214,7 @@ class TensorNameMap: "h.{bid}.ln_2", # gpt2 "model.layers.{bid}.ffn_norm", # internlm2 "transformer.decoder_layer.{bid}.rms_norm_2", # Grok + "encoder.layers.{bid}.post_attention_layernorm", # chatglm ), MODEL_TENSOR.FFN_GATE_INP: ( @@ -246,6 +254,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.c_fc", # starcoder2 "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 "model.layers.{bid}.residual_mlp.w3", # arctic + "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -313,6 +322,7 @@ class TensorNameMap: "encoder.layer.{bid}.mlp.wo", # jina-bert-v2 "model.layers.{bid}.residual_mlp.w2", # arctic "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2 + "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm ), MODEL_TENSOR.FFN_DOWN_EXP: ( diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml index 36e63ee3b..62129126b 100644 --- a/gguf-py/pyproject.toml +++ b/gguf-py/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "gguf" -version = "0.9.0" +version = "0.9.1" description = "Read and write ML models in GGUF for GGML" authors = ["GGML "] packages = [ diff --git a/llama.cpp b/llama.cpp index 8818c6928..a2df298a8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -225,6 +225,7 @@ enum llm_arch { LLM_ARCH_OLMO, LLM_ARCH_ARCTIC, LLM_ARCH_DEEPSEEK2, + LLM_ARCH_CHATGLM, LLM_ARCH_UNKNOWN, }; @@ -263,6 +264,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_OLMO, "olmo" }, { LLM_ARCH_ARCTIC, "arctic" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, + { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1113,6 +1115,21 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, }, + { + LLM_ARCH_CHATGLM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1917,9 +1934,11 @@ enum e_model { MODEL_2_8B, MODEL_3B, MODEL_4B, + MODEL_6B, MODEL_6_9B, MODEL_7B, MODEL_8B, + MODEL_9B, MODEL_12B, MODEL_13B, MODEL_14B, @@ -4120,9 +4139,11 @@ static const char * llama_model_type_name(e_model type) { case MODEL_2_8B: return "2.8B"; case MODEL_3B: return "3B"; case MODEL_4B: return "4B"; + case MODEL_6B: return "6B"; case MODEL_6_9B: return "6.9B"; case MODEL_7B: return "7B"; case MODEL_8B: return "8B"; + case MODEL_9B: return "9B"; case MODEL_12B: return "12B"; case MODEL_13B: return "13B"; case MODEL_14B: return "14B"; @@ -4708,6 +4729,15 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_CHATGLM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 28: model.type = e_model::MODEL_6B; break; + case 40: model.type = e_model::MODEL_9B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -4798,9 +4828,9 @@ static void llm_load_vocab( if (merges_keyidx == -1) { throw std::runtime_error("cannot find tokenizer merges in model file\n"); } - + printf("merges_keyidx: %d\n", merges_keyidx); const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); - + printf("n_merges: %d\n", n_merges); for (int i = 0; i < n_merges; i++) { const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); @@ -4897,6 +4927,9 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "poro-chat") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO; + } else if ( + tokenizer_pre == "chatglm-bpe") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -6650,6 +6683,36 @@ static bool llm_load_tensors( } } } break; + case LLM_ARCH_CHATGLM: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + (hparams.n_embd_head_k << 2)}); + layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + (hparams.n_embd_head_k << 2)}); + + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}); + + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -6874,6 +6937,7 @@ enum llm_ffn_op_type { LLM_FFN_GELU, LLM_FFN_RELU, LLM_FFN_RELU_SQR, + LLM_FFN_SWIGLU, }; enum llm_ffn_gate_type { @@ -7064,6 +7128,19 @@ static struct ggml_tensor * llm_build_ffn( cur = ggml_sqr(ctx, cur); cb(cur, "ffn_sqr(relu)", il); } break; + case LLM_FFN_SWIGLU: + { + // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + int64_t split_point = cur->ne[0] / 2; + struct ggml_tensor * x0 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], 0)); + struct ggml_tensor * x1 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); + + x0 = ggml_silu(ctx, x0); + cb(cur, "ffn_silu", il); + + cur = ggml_mul(ctx, x0, x1); + cb(cur, "ffn_mul", il); + } break; } if (type_gate == LLM_FFN_PAR) { @@ -11684,6 +11761,119 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_chatglm() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + struct ggml_tensor * Qcur = nullptr; + struct ggml_tensor * Kcur = nullptr; + struct ggml_tensor * Vcur = nullptr; + + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor); + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur_rope", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur_rope", il); + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Add the input + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + NULL, NULL, + model.layers[il].ffn_down, NULL, + NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + + } + + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, + NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -11907,6 +12097,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_deepseek2(); } break; + case LLM_ARCH_CHATGLM: + { + result = llm.build_chatglm(); + } break; default: GGML_ASSERT(false); } @@ -13252,6 +13446,11 @@ struct llm_tokenizer_bpe { " ?[^(\\s|.,!?…。,、।۔،)]+", }; break; + case LLAMA_VOCAB_PRE_TYPE_CHATGLM4: + regex_exprs = { + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -16704,6 +16903,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_OLMO: case LLM_ARCH_ARCTIC: case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_CHATGLM: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 @@ -18324,6 +18524,19 @@ llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_to } bool llama_token_is_eog(const struct llama_model * model, llama_token token) { + auto arch_name = llama_model_arch_name(model->arch); + auto vocab_type = model->vocab.type; + if (strcmp(arch_name, "chatglm") == 0) { + if (LLAMA_VOCAB_TYPE_BPE == vocab_type) { // glm4 + return token != -1 && ( + token == llama_token_eos(model) || + token == llama_token_eot(model) || + token == 151329 || + token == 151336 || + token == 151338 + ); + } + } return token != -1 && ( token == llama_token_eos(model) || token == llama_token_eot(model) @@ -18386,8 +18599,18 @@ int32_t llama_tokenize( int32_t n_tokens_max, bool add_special, bool parse_special) { - auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_special, parse_special); - + auto arch_name = llama_model_arch_name(model->arch); + auto prompt = std::move(std::string(text, text_len)); + auto vocab_type = model->vocab.type; + if (strcmp(arch_name, "chatglm") == 0) { + // chatglm3 + if (LLAMA_VOCAB_TYPE_SPM == vocab_type) { + prompt = "[gMASK]sop<|user|>\n" + prompt + "<|assistant|>"; + } else if (LLAMA_VOCAB_TYPE_BPE == vocab_type) { // glm4 + prompt = "[gMASK]<|user|>\n" + prompt + "<|assistant|>"; + } + } + auto res = llama_tokenize_internal(model->vocab, prompt, add_special, parse_special); if (n_tokens_max < (int) res.size()) { // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); return -((int) res.size()); @@ -18714,6 +18937,28 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; } + } else if (tmpl == "chatglm3" || + (tmpl.find("add_generation_prompt") != std::string::npos && + tmpl.find("for message in messages") != std::string::npos && + tmpl.find("loop.first") != std::string::npos)) { + // chatglm3-6b + ss << "[gMASK]" << "sop"; + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>" << "\n " << message->content; + } + if (add_ass) { + ss << "<|assistant|>"; + } + } else if (tmpl == "ChatGLM4") { + ss << "[gMASK]" << ""; + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>" << "\n" << message->content; + } + if (add_ass) { + ss << "<|assistant|>"; + } } else { // template not supported return -1; diff --git a/llama.h b/llama.h index da310ffaf..b1ff05bd7 100644 --- a/llama.h +++ b/llama.h @@ -87,6 +87,7 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_DBRX = 13, LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, LLAMA_VOCAB_PRE_TYPE_PORO = 15, + LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 16, }; // note: these values should be synchronized with ggml_rope diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index cef9a650b..0fe4d2967 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -56,7 +56,11 @@ int main(void) { //Phi-3-medium "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", //Phi-3-vision - "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}" + "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}", + // ChatGLM3 + "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", + // ChatGLM4 + "ChatGLM4", }; std::vector expected_output = { // teknium/OpenHermes-2.5-Mistral-7B @@ -93,6 +97,10 @@ int main(void) { "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", //Phi-3-vision "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + // ChatGLM3 + "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>", + // ChatGLM4 + "[gMASK]<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>", }; std::vector formatted_chat(1024); int32_t res;