From 1fc5bf5bcb9b9fbfcec8a4f1768349ca9f28a3fa Mon Sep 17 00:00:00 2001
From: XingXing Qiao
Date: Mon, 17 Jun 2024 10:08:52 +0800
Subject: [PATCH] support glm-4-9b-chat

Signed-off-by: XingXing Qiao
---
 convert-hf-to-gguf.py        | 96 ++++++++++++++++++++++++++++++++++--
 examples/server/server.cpp   |  2 +-
 llama.cpp                    | 20 ++++++--
 llama.h                      |  1 +
 tests/test-chat-template.cpp |  4 ++
 5 files changed, 116 insertions(+), 7 deletions(-)

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 7591da6ef..b36b5193c 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -476,6 +476,9 @@ class Model:
         if chkhsh == "c136ed14d01c2745d4f60a9596ae66800e2b61fa45643e72436041855ad4089d":
             # ref: https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct
             res = "smaug-bpe"
+        if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
+            # ref: https://huggingface.co/THUDM/glm-4-9b-chat
+            res = "chatglm-bpe"
 
         if res is None:
             logger.warning("\n")
@@ -2714,7 +2717,7 @@ class DeepseekV2Model(Model):
 class ChatGLMModel(Model):
     model_arch = gguf.MODEL_ARCH.CHATGLM
 
-    def set_vocab(self):
+    def set_vocab_chatglm3(self):
         dir_model = self.dir_model
         hparams = self.hparams
         tokens: list[bytearray] = []
@@ -2725,7 +2728,8 @@ class ChatGLMModel(Model):
         tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
         vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab()))
         assert max(tokenizer.get_vocab().values()) < vocab_size
-
+        print(vocab_size)
+        print(max(tokenizer.get_vocab().values()))
         for token_id in range(vocab_size):
             piece = tokenizer._convert_id_to_token(token_id)
             if token_id == 0:
@@ -2774,6 +2778,91 @@ class ChatGLMModel(Model):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
 
+    @staticmethod
+    def token_bytes_to_string(b):
+        from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
+        byte_encoder = bytes_to_unicode()
+        return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
+
+    @staticmethod
+    def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
+        parts = [bytes([b]) for b in token]
+        while True:
+            min_idx = None
+            min_rank = None
+            for i, pair in enumerate(zip(parts[:-1], parts[1:])):
+                rank = mergeable_ranks.get(pair[0] + pair[1])
+                if rank is not None and (min_rank is None or rank < min_rank):
+                    min_idx = i
+                    min_rank = rank
+            if min_rank is None or (max_rank is not None and min_rank >= max_rank):
+                break
+            assert min_idx is not None
+            parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
+        return parts
+
+    def set_vocab(self):
+        if "THUDM/chatglm3-6b" in self.hparams.get("_name_or_path", ""):
+            self.set_vocab_chatglm3()
+            return
+
+        dir_model = self.dir_model
+        hparams = self.hparams
+        tokens: list[str] = []
+        toktypes: list[int] = []
+
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
+        vocab_size = hparams["padded_vocab_size"]
+        assert max(tokenizer.get_vocab().values()) < vocab_size
+
+        tokpre = self.get_vocab_base_pre(tokenizer)
+
+        merges = []
+        vocab = {}
+        mergeable_ranks = tokenizer.mergeable_ranks
+        for token, rank in mergeable_ranks.items():
+            vocab[ChatGLMModel.token_bytes_to_string(token)] = rank
+            if len(token) == 1:
+                continue
+            merged = ChatGLMModel.bpe(mergeable_ranks, token, max_rank=rank)
+            assert len(merged) >= 2 and len(merged) <= 7
+            merges.append(' '.join(map(ChatGLMModel.token_bytes_to_string, merged)))
+
+        # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
+        added_vocab = tokenizer.get_added_vocab()
+        reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()}
+
+        for i in range(vocab_size):
+            if i not in reverse_vocab:
+                tokens.append(f"[PAD{i}]")
+                toktypes.append(gguf.TokenType.USER_DEFINED)
+            elif reverse_vocab[i] in added_vocab:
+                tokens.append(reverse_vocab[i])
+                if tokenizer.added_tokens_decoder[i].special:
+                    toktypes.append(gguf.TokenType.CONTROL)
+                else:
+                    toktypes.append(gguf.TokenType.USER_DEFINED)
+            else:
+                tokens.append(reverse_vocab[i])
+                toktypes.append(gguf.TokenType.NORMAL)
+
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
+        special_vocab.chat_template = "ChatGLM4"
+        special_vocab.merges = merges
+        # only add special tokens when they were not already loaded from config.json
+        if len(special_vocab.special_token_ids) == 0:
+            special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
+            special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
+        # this one is usually not in config.json anyway
+        special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab.add_to_gguf(self.gguf_writer)
+
     def set_gguf_parameters(self):
         self.gguf_writer.add_name(self.dir_model.name)
         n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
@@ -2934,7 +3023,8 @@ def main() -> None:
     with torch.inference_mode():
         model_class = Model.from_model_architecture(hparams["architectures"][0])
         model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy)
-
+        print(model_class)
+        print(model_instance)
         logger.info("Set model parameters")
         model_instance.set_gguf_parameters()
 
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index e9904263d..078806357 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -3056,7 +3056,7 @@ int main(int argc, char ** argv) {
         chat.push_back({{"role", "user"}, {"content", "Hello"}});
         chat.push_back({{"role", "assistant"}, {"content", "Hi there"}});
         chat.push_back({{"role", "user"}, {"content", "How are you?"}});
-
+        printf("sparams.chat_template: #%s#\n", sparams.chat_template.c_str());
         const std::string chat_example = format_chat(ctx_server.model, sparams.chat_template, chat);
 
         LOG_INFO("chat template", {
diff --git a/llama.cpp b/llama.cpp
index 154168ef3..a0255bac8 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4508,6 +4508,7 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 switch (hparams.n_layer) {
                     case 28: model.type = e_model::MODEL_7B; break;
+                    case 40: model.type = e_model::MODEL_8B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
@@ -4636,9 +4637,9 @@ static void llm_load_vocab(
         if (merges_keyidx == -1) {
             throw std::runtime_error("cannot find tokenizer merges in model file\n");
         }
-
+        printf("merges_keyidx: %d\n", merges_keyidx);
         const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
-
+        printf("n_merges: %d\n", n_merges);
         for (int i = 0; i < n_merges; i++) {
             const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
             GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
@@ -4728,6 +4729,9 @@ static void llm_load_vocab(
            } else if (
                tokenizer_pre == "smaug-bpe") {
                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG;
+           } else if (
+               tokenizer_pre == "chatglm-bpe") {
+               vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
            } else {
                throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
            }
@@ -11449,7 +11453,7 @@ struct llm_build_context {
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
-
+                //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor);
                 Qcur = ggml_rope_ext(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
@@ -13032,6 +13036,7 @@ struct llm_tokenizer_bpe {
                 break;
             case LLAMA_VOCAB_PRE_TYPE_DBRX:
             case LLAMA_VOCAB_PRE_TYPE_SMAUG:
+            case LLAMA_VOCAB_PRE_TYPE_CHATGLM4:
                 word_collection = unicode_regex_split(text, {
                     // same as llama3
                     "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
@@ -18741,6 +18746,15 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|assistant|>";
         }
+    } else if (tmpl == "ChatGLM4") {
+        ss << "[gMASK]" << "<sop>";
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|" << role << "|>" << "\n" << message->content;
+        }
+        if (add_ass) {
+            ss << "<|assistant|>";
+        }
     } else {
         // template not supported
         return -1;
diff --git a/llama.h b/llama.h
index 3e4474bb9..a670e1911 100644
--- a/llama.h
+++ b/llama.h
@@ -86,6 +86,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_OLMO       = 12,
         LLAMA_VOCAB_PRE_TYPE_DBRX       = 13,
         LLAMA_VOCAB_PRE_TYPE_SMAUG      = 14,
+        LLAMA_VOCAB_PRE_TYPE_CHATGLM4   = 15,
     };
 
     // note: these values should be synchronized with ggml_rope
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index 87f39f103..0fe4d2967 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -59,6 +59,8 @@ int main(void) {
         "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
         // ChatGLM3
         "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
+        // ChatGLM4
+        "ChatGLM4",
     };
     std::vector<std::string> expected_output = {
         // teknium/OpenHermes-2.5-Mistral-7B
@@ -97,6 +99,8 @@ int main(void) {
         "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
         // ChatGLM3
         "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>",
+        // ChatGLM4
+        "[gMASK]<sop><|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
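
Note (illustration only, not part of the patch): ChatGLMModel.bpe() above recovers the BPE merge list
from a tiktoken-style rank table by re-running the merge loop for every multi-byte token while refusing
any merge whose rank is not lower than that token's own rank. The standalone sketch below shows the same
idea on an invented toy mergeable_ranks table; it mirrors the helper in the patch but is not GLM-4's
real tokenizer data.

from __future__ import annotations

def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
    # Start from single bytes and repeatedly merge the lowest-ranked adjacent pair,
    # stopping once every remaining merge has rank >= max_rank (the token's own rank).
    parts = [bytes([b]) for b in token]
    while True:
        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank
        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
            break
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
    return parts

# Invented toy rank table: single bytes first, then merged tokens in learning order.
mergeable_ranks = {b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"abc": 4}

for token, rank in mergeable_ranks.items():
    if len(token) == 1:
        continue
    merged = bpe(mergeable_ranks, token, max_rank=rank)
    # Here every token splits back into the pieces it was last merged from; the patch
    # tolerates 2..7 pieces, presumably because GLM-4's rank table is not a strictly
    # pair-wise BPE vocabulary.
    print(token, "->", merged)
# prints:
#   b'ab' -> [b'a', b'b']
#   b'abc' -> [b'ab', b'c']

Joining each recovered split with ' ' is what produces the merges array written into the GGUF file,
which llm_load_vocab later reads back entry by entry via gguf_get_arr_str().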