convert-llama-h5-to-gguf.py : add 70b gqa support

This commit is contained in:
klosax 2023-08-15 00:43:10 +02:00 committed by GitHub
parent ca4758290c
commit 2dd5d2c92c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,4 +1,4 @@
# HF llama --> gguf conversion, GQA/70b not supported # HF llama --> gguf conversion
import gguf import gguf
import gguf_namemap as tmap import gguf_namemap as tmap
@ -10,7 +10,7 @@ import json
import numpy as np import numpy as np
import torch import torch
from typing import Any, List from typing import Any, List, Optional
from pathlib import Path from pathlib import Path
from sentencepiece import SentencePieceProcessor from sentencepiece import SentencePieceProcessor
@ -18,8 +18,8 @@ from sentencepiece import SentencePieceProcessor
# compatible with python < 3.9 # compatible with python < 3.9
NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]' NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
def permute(weights: NDArray, n_head: int) -> NDArray: if n_kv_head is not None and n_head != n_kv_head: n_head //= n_kv_head
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2) .swapaxes(1, 2)
.reshape(weights.shape)) .reshape(weights.shape))
@ -220,7 +220,7 @@ for part_name in part_names:
# permute these # permute these
if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"): if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"):
data = permute(data,head_count) data = permute(data, head_count, head_count_kv)
# map tensor names # map tensor names
if name.endswith(".weight") and name[:-7] in tensor_map: if name.endswith(".weight") and name[:-7] in tensor_map:
@ -289,7 +289,7 @@ for part_name in part_names:
# permute these # permute these
if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"): if name.endswith(".q_proj.weight") or name.endswith(".k_proj.weight"):
data = permute(data, head_count) data = permute(data, head_count, head_count_kv)
# map tensor names # map tensor names
if name.endswith(".weight") and name[:-7] in tensor_map: if name.endswith(".weight") and name[:-7] in tensor_map: