From 76d32cca59fcf205f48f61cf5c2b467bb866d0e2 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Fri, 15 Sep 2023 11:42:16 +0800 Subject: [PATCH] convert MQA to MHA --- convert-starcoder-hf-to-gguf.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/convert-starcoder-hf-to-gguf.py b/convert-starcoder-hf-to-gguf.py index 4416b5d9e..00e4f0d92 100755 --- a/convert-starcoder-hf-to-gguf.py +++ b/convert-starcoder-hf-to-gguf.py @@ -212,6 +212,24 @@ for part_name in part_names: data = data.squeeze().numpy() + if name.endswith(".attn.c_attn.weight") or name.endswith(".attn.c_attn.bias"): + print("Duplicate K,V heads to use MHA instead of MQA for", name) + + embed_dim = hparams["n_embd"] + head_dim = embed_dim // hparams["n_head"] + + # ((n_heads + 2) * head_dim, hidden_dim) -> (3 * n_heads * head_dim, hidden_dim) + q, k ,v = np.split(data, (hparams["n_head"] * head_dim, (hparams["n_head"] + 1) * head_dim), axis=0) + # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim) + if len(k.shape) == 2: + k = np.tile(k, (hparams["n_head"], 1)) + v = np.tile(v, (hparams["n_head"], 1)) + elif len(k.shape) == 1: + k = np.tile(k, (hparams["n_head"])) + v = np.tile(v, (hparams["n_head"])) + # concat q, k, v along the first axis (n_heads * head_dim, hidden_dim) -> (3 * n_heads * head_dim, hidden_dim) + data = np.concatenate((q, k, v), axis=0) + # map tensor names new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias")) if new_name is None: