From 4dbc8b9cb71876e005724f4e8f73a3544646bcf5 Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Fri, 17 Jan 2025 02:10:38 +0800 Subject: [PATCH] llama : add internlm3 support (#11233) * support internlm3 * fix lint --- convert_hf_to_gguf.py | 60 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 4dc9837ab..95f112043 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2882,6 +2882,66 @@ class InternLM2Model(Model): return [(self.map_tensor_name(name), data_torch)] +@Model.register("InternLM3ForCausalLM") +class InternLM3Model(Model): + model_arch = gguf.MODEL_ARCH.LLAMA + + def set_vocab(self): + tokens, scores, toktypes = self._create_vocab_sentencepiece() + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + if "add_prefix_space" in tokenizer_config_json: + self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"]) + + if "added_tokens_decoder" in tokenizer_config_json: + for token_id, token_data in tokenizer_config_json["added_tokens_decoder"].items(): + if token_data.get("special"): + token_id = int(token_id) + token = token_data["content"] + special_vocab._set_special_token(token, token_id) + # update eos token + if token == '<|im_end|>' and "eos" in special_vocab.special_token_ids: + special_vocab.special_token_ids["eos"] = token_id + + special_vocab.add_to_gguf(self.gguf_writer) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + + if "head_dim" in hparams: + rope_dim = hparams["head_dim"] + else: + rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(rope_dim) + + if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: + if self.hparams["rope_scaling"].get("type") == "linear" or self.hparams["rope_scaling"].get("rope_type") == "linear": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + return [(self.map_tensor_name(name), data_torch)] + + @Model.register("BertModel", "BertForMaskedLM", "CamembertModel") class BertModel(Model): model_arch = gguf.MODEL_ARCH.BERT