From 1f2e0ee01226bf3fd1717f5b8e1c9b688c216f83 Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Thu, 6 Jun 2024 12:28:11 +0800 Subject: [PATCH] finish bitnet e2e --- convert-hf-to-gguf.py | 21 ++++++++++++++------- llama.cpp | 12 ++++++------ 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 8ba7119fc..d39bb3bd1 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1398,14 +1398,21 @@ class BitnetModel(Model): def set_gguf_parameters(self): super().set_gguf_parameters() - hparams = self.hparams - self.gguf_writer.add_vocab_size(hparams["vocab_size"]) - self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"]) + self.gguf_writer.add_name("Bitnet") + self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"]) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) - if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: - if self.hparams["rope_scaling"].get("type") == "linear": - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) - self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(1.0) + self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"]) def weight_quant(self, weight): dtype = weight.dtype diff --git a/llama.cpp b/llama.cpp index 5dc25e3de..38dbf31e0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -11587,16 +11587,16 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, - n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, - n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il);