mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-14 06:19:02 +01:00
Fix eos tokens to glm4 and adapts to glm3
This commit is contained in:
commit
e773174052
@ -80,7 +80,7 @@ class Model:
|
||||
if not self.is_safetensors:
|
||||
self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
|
||||
self.hparams = Model.load_hparams(self.dir_model)
|
||||
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
|
||||
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
|
||||
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
||||
self.tensor_names = None
|
||||
if self.ftype == gguf.LlamaFileType.GUESSED:
|
||||
@ -483,6 +483,9 @@ class Model:
|
||||
if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a":
|
||||
# ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code
|
||||
res = "jina-v2-code"
|
||||
if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
|
||||
# ref: https://huggingface.co/THUDM/glm-4-9b-chat
|
||||
res = "chatglm-bpe"
|
||||
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
@ -2725,6 +2728,188 @@ class DeepseekV2Model(Model):
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@Model.register("ChatGLMModel")
|
||||
class ChatGLMModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.CHATGLM
|
||||
|
||||
def set_vocab_chatglm3(self):
|
||||
dir_model = self.dir_model
|
||||
hparams = self.hparams
|
||||
tokens: list[bytearray] = []
|
||||
toktypes: list[int] = []
|
||||
scores: list[float] = []
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
|
||||
vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab()))
|
||||
assert max(tokenizer.get_vocab().values()) < vocab_size
|
||||
role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
|
||||
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
|
||||
print(vocab_size)
|
||||
print(max(tokenizer.get_vocab().values()))
|
||||
for token_id in range(vocab_size):
|
||||
piece = tokenizer._convert_id_to_token(token_id)
|
||||
if token_id == 0:
|
||||
piece = "<unk>"
|
||||
elif token_id == 1:
|
||||
piece = "<bos>"
|
||||
elif token_id == 2:
|
||||
piece = "<eos>"
|
||||
|
||||
text = piece.encode("utf-8")
|
||||
score = 0.0
|
||||
# Referencing the tokenizer Python implementation(https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py),
|
||||
# it is only valid if it is less than tokenizer.tokenizer.sp_model.vocab_size()
|
||||
if len(piece) != 0 and token_id < tokenizer.tokenizer.sp_model.vocab_size():
|
||||
score = tokenizer.tokenizer.sp_model.get_score(token_id)
|
||||
|
||||
if len(piece) == 0:
|
||||
text = f"[PAD{token_id}]".encode("utf-8")
|
||||
|
||||
if token_id >= tokenizer.tokenizer.sp_model.vocab_size():
|
||||
if piece in special_tokens:
|
||||
# show special tokens in prompt
|
||||
toktype = SentencePieceTokenTypes.USER_DEFINED
|
||||
else:
|
||||
toktype = SentencePieceTokenTypes.UNKNOWN
|
||||
tokens.append(text)
|
||||
scores.append(score)
|
||||
toktypes.append(toktype)
|
||||
continue
|
||||
|
||||
toktype = SentencePieceTokenTypes.NORMAL
|
||||
if tokenizer.tokenizer.sp_model.is_unknown(token_id):
|
||||
toktype = SentencePieceTokenTypes.UNKNOWN
|
||||
elif tokenizer.tokenizer.sp_model.is_control(token_id):
|
||||
toktype = SentencePieceTokenTypes.CONTROL
|
||||
elif tokenizer.tokenizer.sp_model.is_unused(token_id):
|
||||
toktype = SentencePieceTokenTypes.UNUSED
|
||||
elif tokenizer.tokenizer.sp_model.is_byte(token_id):
|
||||
toktype = SentencePieceTokenTypes.BYTE
|
||||
|
||||
tokens.append(text)
|
||||
scores.append(score)
|
||||
toktypes.append(toktype)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("llama")
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_scores(scores)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
@staticmethod
|
||||
def token_bytes_to_string(b):
|
||||
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
|
||||
byte_encoder = bytes_to_unicode()
|
||||
return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
|
||||
|
||||
@staticmethod
|
||||
def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
|
||||
parts = [bytes([b]) for b in token]
|
||||
while True:
|
||||
min_idx = None
|
||||
min_rank = None
|
||||
for i, pair in enumerate(zip(parts[:-1], parts[1:])):
|
||||
rank = mergeable_ranks.get(pair[0] + pair[1])
|
||||
if rank is not None and (min_rank is None or rank < min_rank):
|
||||
min_idx = i
|
||||
min_rank = rank
|
||||
if min_rank is None or (max_rank is not None and min_rank >= max_rank):
|
||||
break
|
||||
assert min_idx is not None
|
||||
parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
|
||||
return parts
|
||||
|
||||
def set_vocab(self):
|
||||
if "THUDM/chatglm3-6b" in self.hparams.get("_name_or_path", ""):
|
||||
self.set_vocab_chatglm3()
|
||||
return
|
||||
|
||||
dir_model = self.dir_model
|
||||
hparams = self.hparams
|
||||
tokens: list[str] = []
|
||||
toktypes: list[int] = []
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
|
||||
vocab_size = hparams["padded_vocab_size"]
|
||||
assert max(tokenizer.get_vocab().values()) < vocab_size
|
||||
|
||||
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||
|
||||
merges = []
|
||||
vocab = {}
|
||||
mergeable_ranks = tokenizer.mergeable_ranks
|
||||
for token, rank in mergeable_ranks.items():
|
||||
vocab[ChatGLMModel.token_bytes_to_string(token)] = rank
|
||||
if len(token) == 1:
|
||||
continue
|
||||
merged = ChatGLMModel.bpe(mergeable_ranks, token, max_rank=rank)
|
||||
assert len(merged) >= 2 and len(merged) <= 7
|
||||
merges.append(' '.join(map(ChatGLMModel.token_bytes_to_string, merged)))
|
||||
|
||||
# for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()}
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
tokens.append(f"[PAD{i}]")
|
||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||
elif reverse_vocab[i] in added_vocab:
|
||||
tokens.append(reverse_vocab[i])
|
||||
if tokenizer.added_tokens_decoder[i].special:
|
||||
toktypes.append(gguf.TokenType.CONTROL)
|
||||
else:
|
||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||
else:
|
||||
tokens.append(reverse_vocab[i])
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("gpt2")
|
||||
self.gguf_writer.add_tokenizer_pre(tokpre)
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
|
||||
special_vocab.chat_template = "ChatGLM4"
|
||||
special_vocab.merges = merges
|
||||
# only add special tokens when they were not already loaded from config.json
|
||||
# if len(special_vocab.special_token_ids) == 0:
|
||||
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
# this one is usually not in config.json anyway
|
||||
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_name(self.dir_model.name)
|
||||
n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
|
||||
n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
|
||||
n_head_kv = self.hparams.get("multi_query_group_num", n_head)
|
||||
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
|
||||
self.gguf_writer.add_embedding_length(n_embed)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", 4 * n_embed))
|
||||
self.gguf_writer.add_block_count(self.hparams["num_layers"])
|
||||
self.gguf_writer.add_head_count(n_head)
|
||||
self.gguf_writer.add_head_count_kv(n_head_kv)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layernorm_epsilon"])
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
self.gguf_writer.add_rope_dimension_count(64)
|
||||
self.gguf_writer.add_add_bos_token(False)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
if name.endswith(".rotary_pos_emb.inv_freq"):
|
||||
return []
|
||||
|
||||
name = name.removeprefix("transformer.")
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
||||
|
@ -149,6 +149,7 @@ class MODEL_ARCH(IntEnum):
|
||||
OLMO = auto()
|
||||
ARCTIC = auto()
|
||||
DEEPSEEK2 = auto()
|
||||
CHATGLM = auto()
|
||||
|
||||
|
||||
class MODEL_TENSOR(IntEnum):
|
||||
@ -237,6 +238,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.OLMO: "olmo",
|
||||
MODEL_ARCH.ARCTIC: "arctic",
|
||||
MODEL_ARCH.DEEPSEEK2: "deepseek2",
|
||||
MODEL_ARCH.CHATGLM: "chatglm",
|
||||
}
|
||||
|
||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
@ -808,6 +810,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
],
|
||||
MODEL_ARCH.CHATGLM : [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_QKV,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
@ -845,6 +859,9 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
],
|
||||
MODEL_ARCH.CHATGLM: [
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
],
|
||||
}
|
||||
|
||||
#
|
||||
|
@ -24,6 +24,7 @@ class TensorNameMap:
|
||||
"backbone.embedding", # mamba
|
||||
"backbone.embeddings", # mamba-hf
|
||||
"transformer.in_out_embed", # Grok
|
||||
"embedding.word_embeddings", # chatglm
|
||||
),
|
||||
|
||||
# Token type embeddings
|
||||
@ -52,6 +53,7 @@ class TensorNameMap:
|
||||
"output", # llama-pth bloom internlm2
|
||||
"word_embeddings_for_head", # persimmon
|
||||
"lm_head.linear", # phi2
|
||||
"output_layer", # chatglm
|
||||
),
|
||||
|
||||
# Output norm
|
||||
@ -68,11 +70,13 @@ class TensorNameMap:
|
||||
"model.norm_f", # mamba-qbert
|
||||
"backbone.norm_f", # mamba
|
||||
"transformer.rms_norm", # Grok
|
||||
"encoder.final_layernorm", # chatglm
|
||||
),
|
||||
|
||||
# Rope frequencies
|
||||
MODEL_TENSOR.ROPE_FREQS: (
|
||||
"rope.freqs", # llama-pth
|
||||
"rotary_pos_emb.inv_freq", # chatglm
|
||||
),
|
||||
}
|
||||
|
||||
@ -97,6 +101,7 @@ class TensorNameMap:
|
||||
"backbone.layers.{bid}.norm", # mamba
|
||||
"transformer.decoder_layer.{bid}.rms_norm", # Grok
|
||||
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
|
||||
"encoder.layers.{bid}.input_layernorm", # chatglm
|
||||
),
|
||||
|
||||
# Attention norm 2
|
||||
@ -118,7 +123,8 @@ class TensorNameMap:
|
||||
"h.{bid}.attn.c_attn", # gpt2
|
||||
"transformer.h.{bid}.mixer.Wqkv", # phi2
|
||||
"encoder.layers.{bid}.attn.Wqkv", # nomic-bert
|
||||
"model.layers.{bid}.self_attn.qkv_proj" # phi3
|
||||
"model.layers.{bid}.self_attn.qkv_proj", # phi3
|
||||
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
|
||||
),
|
||||
|
||||
# Attention query
|
||||
@ -129,7 +135,7 @@ class TensorNameMap:
|
||||
"transformer.h.{bid}.attn.q_proj", # gpt-j
|
||||
"model.layers.layers.{bid}.self_attn.q_proj", # plamo
|
||||
"model.layers.{bid}.attention.wq", # internlm2
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.query" # Grok
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
|
||||
),
|
||||
|
||||
# Attention key
|
||||
@ -141,7 +147,7 @@ class TensorNameMap:
|
||||
"transformer.h.{bid}.attn.k", # refact
|
||||
"model.layers.layers.{bid}.self_attn.k_proj", # plamo
|
||||
"model.layers.{bid}.attention.wk", # internlm2
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
|
||||
),
|
||||
|
||||
# Attention value
|
||||
@ -176,6 +182,7 @@ class TensorNameMap:
|
||||
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
|
||||
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
|
||||
"encoder.layers.{bid}.self_attention.dense", # chatglm
|
||||
),
|
||||
|
||||
# Attention output norm
|
||||
@ -207,6 +214,7 @@ class TensorNameMap:
|
||||
"h.{bid}.ln_2", # gpt2
|
||||
"model.layers.{bid}.ffn_norm", # internlm2
|
||||
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
|
||||
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP: (
|
||||
@ -246,6 +254,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.c_fc", # starcoder2
|
||||
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
|
||||
"model.layers.{bid}.residual_mlp.w3", # arctic
|
||||
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_EXP: (
|
||||
@ -313,6 +322,7 @@ class TensorNameMap:
|
||||
"encoder.layer.{bid}.mlp.wo", # jina-bert-v2
|
||||
"model.layers.{bid}.residual_mlp.w2", # arctic
|
||||
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
|
||||
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_EXP: (
|
||||
|
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "gguf"
|
||||
version = "0.9.0"
|
||||
version = "0.9.1"
|
||||
description = "Read and write ML models in GGUF for GGML"
|
||||
authors = ["GGML <ggml@ggml.ai>"]
|
||||
packages = [
|
||||
|
253
llama.cpp
253
llama.cpp
@ -225,6 +225,7 @@ enum llm_arch {
|
||||
LLM_ARCH_OLMO,
|
||||
LLM_ARCH_ARCTIC,
|
||||
LLM_ARCH_DEEPSEEK2,
|
||||
LLM_ARCH_CHATGLM,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
@ -263,6 +264,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_OLMO, "olmo" },
|
||||
{ LLM_ARCH_ARCTIC, "arctic" },
|
||||
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
|
||||
{ LLM_ARCH_CHATGLM, "chatglm" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
@ -1113,6 +1115,21 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_CHATGLM,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
@ -1917,9 +1934,11 @@ enum e_model {
|
||||
MODEL_2_8B,
|
||||
MODEL_3B,
|
||||
MODEL_4B,
|
||||
MODEL_6B,
|
||||
MODEL_6_9B,
|
||||
MODEL_7B,
|
||||
MODEL_8B,
|
||||
MODEL_9B,
|
||||
MODEL_12B,
|
||||
MODEL_13B,
|
||||
MODEL_14B,
|
||||
@ -4120,9 +4139,11 @@ static const char * llama_model_type_name(e_model type) {
|
||||
case MODEL_2_8B: return "2.8B";
|
||||
case MODEL_3B: return "3B";
|
||||
case MODEL_4B: return "4B";
|
||||
case MODEL_6B: return "6B";
|
||||
case MODEL_6_9B: return "6.9B";
|
||||
case MODEL_7B: return "7B";
|
||||
case MODEL_8B: return "8B";
|
||||
case MODEL_9B: return "9B";
|
||||
case MODEL_12B: return "12B";
|
||||
case MODEL_13B: return "13B";
|
||||
case MODEL_14B: return "14B";
|
||||
@ -4708,6 +4729,15 @@ static void llm_load_hparams(
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_CHATGLM:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
switch (hparams.n_layer) {
|
||||
case 28: model.type = e_model::MODEL_6B; break;
|
||||
case 40: model.type = e_model::MODEL_9B; break;
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
default: (void)0;
|
||||
}
|
||||
|
||||
@ -4798,9 +4828,9 @@ static void llm_load_vocab(
|
||||
if (merges_keyidx == -1) {
|
||||
throw std::runtime_error("cannot find tokenizer merges in model file\n");
|
||||
}
|
||||
|
||||
printf("merges_keyidx: %d\n", merges_keyidx);
|
||||
const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
|
||||
|
||||
printf("n_merges: %d\n", n_merges);
|
||||
for (int i = 0; i < n_merges; i++) {
|
||||
const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
|
||||
GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
|
||||
@ -4897,6 +4927,9 @@ static void llm_load_vocab(
|
||||
} else if (
|
||||
tokenizer_pre == "poro-chat") {
|
||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO;
|
||||
} else if (
|
||||
tokenizer_pre == "chatglm-bpe") {
|
||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
|
||||
} else {
|
||||
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
|
||||
}
|
||||
@ -6650,6 +6683,36 @@ static bool llm_load_tensors(
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_CHATGLM:
|
||||
{
|
||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||
|
||||
// output
|
||||
{
|
||||
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||
|
||||
layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + (hparams.n_embd_head_k << 2)});
|
||||
layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + (hparams.n_embd_head_k << 2)});
|
||||
|
||||
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
|
||||
|
||||
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
|
||||
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2});
|
||||
|
||||
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
@ -6874,6 +6937,7 @@ enum llm_ffn_op_type {
|
||||
LLM_FFN_GELU,
|
||||
LLM_FFN_RELU,
|
||||
LLM_FFN_RELU_SQR,
|
||||
LLM_FFN_SWIGLU,
|
||||
};
|
||||
|
||||
enum llm_ffn_gate_type {
|
||||
@ -7064,6 +7128,19 @@ static struct ggml_tensor * llm_build_ffn(
|
||||
cur = ggml_sqr(ctx, cur);
|
||||
cb(cur, "ffn_sqr(relu)", il);
|
||||
} break;
|
||||
case LLM_FFN_SWIGLU:
|
||||
{
|
||||
// Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
|
||||
int64_t split_point = cur->ne[0] / 2;
|
||||
struct ggml_tensor * x0 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
||||
struct ggml_tensor * x1 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
||||
|
||||
x0 = ggml_silu(ctx, x0);
|
||||
cb(cur, "ffn_silu", il);
|
||||
|
||||
cur = ggml_mul(ctx, x0, x1);
|
||||
cb(cur, "ffn_mul", il);
|
||||
} break;
|
||||
}
|
||||
|
||||
if (type_gate == LLM_FFN_PAR) {
|
||||
@ -11684,6 +11761,119 @@ struct llm_build_context {
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * build_chatglm() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
struct ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.layers[il].attn_norm,
|
||||
NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
struct ggml_tensor * Qcur = nullptr;
|
||||
struct ggml_tensor * Kcur = nullptr;
|
||||
struct ggml_tensor * Vcur = nullptr;
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
|
||||
cb(cur, "wqkv", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
cb(cur, "bqkv", il);
|
||||
|
||||
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
//printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor);
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Qcur, "Qcur_rope", il);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Kcur, "Kcur_rope", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
|
||||
}
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
// Add the input
|
||||
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// FF
|
||||
{
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||
model.layers[il].ffn_norm,
|
||||
NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = llm_build_ffn(ctx0, cur,
|
||||
model.layers[il].ffn_up, NULL,
|
||||
NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
}
|
||||
|
||||
inpL = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(inpL, "l_out", il);
|
||||
}
|
||||
|
||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.output_norm,
|
||||
NULL,
|
||||
LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
};
|
||||
|
||||
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
|
||||
@ -11907,6 +12097,10 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
{
|
||||
result = llm.build_deepseek2();
|
||||
} break;
|
||||
case LLM_ARCH_CHATGLM:
|
||||
{
|
||||
result = llm.build_chatglm();
|
||||
} break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
@ -13252,6 +13446,11 @@ struct llm_tokenizer_bpe {
|
||||
" ?[^(\\s|.,!?…。,、।۔،)]+",
|
||||
};
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_CHATGLM4:
|
||||
regex_exprs = {
|
||||
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
default:
|
||||
// default regex for BPE tokenization pre-processing
|
||||
regex_exprs = {
|
||||
@ -16704,6 +16903,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_OLMO:
|
||||
case LLM_ARCH_ARCTIC:
|
||||
case LLM_ARCH_DEEPSEEK2:
|
||||
case LLM_ARCH_CHATGLM:
|
||||
return LLAMA_ROPE_TYPE_NORM;
|
||||
|
||||
// the pairs of head values are offset by n_rot/2
|
||||
@ -18324,6 +18524,19 @@ llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_to
|
||||
}
|
||||
|
||||
bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
|
||||
auto arch_name = llama_model_arch_name(model->arch);
|
||||
auto vocab_type = model->vocab.type;
|
||||
if (strcmp(arch_name, "chatglm") == 0) {
|
||||
if (LLAMA_VOCAB_TYPE_BPE == vocab_type) { // glm4
|
||||
return token != -1 && (
|
||||
token == llama_token_eos(model) ||
|
||||
token == llama_token_eot(model) ||
|
||||
token == 151329 ||
|
||||
token == 151336 ||
|
||||
token == 151338
|
||||
);
|
||||
}
|
||||
}
|
||||
return token != -1 && (
|
||||
token == llama_token_eos(model) ||
|
||||
token == llama_token_eot(model)
|
||||
@ -18386,8 +18599,18 @@ int32_t llama_tokenize(
|
||||
int32_t n_tokens_max,
|
||||
bool add_special,
|
||||
bool parse_special) {
|
||||
auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_special, parse_special);
|
||||
|
||||
auto arch_name = llama_model_arch_name(model->arch);
|
||||
auto prompt = std::move(std::string(text, text_len));
|
||||
auto vocab_type = model->vocab.type;
|
||||
if (strcmp(arch_name, "chatglm") == 0) {
|
||||
// chatglm3
|
||||
if (LLAMA_VOCAB_TYPE_SPM == vocab_type) {
|
||||
prompt = "[gMASK]sop<|user|>\n" + prompt + "<|assistant|>";
|
||||
} else if (LLAMA_VOCAB_TYPE_BPE == vocab_type) { // glm4
|
||||
prompt = "[gMASK]<sop><|user|>\n" + prompt + "<|assistant|>";
|
||||
}
|
||||
}
|
||||
auto res = llama_tokenize_internal(model->vocab, prompt, add_special, parse_special);
|
||||
if (n_tokens_max < (int) res.size()) {
|
||||
// LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
|
||||
return -((int) res.size());
|
||||
@ -18714,6 +18937,28 @@ static int32_t llama_chat_apply_template_internal(
|
||||
if (add_ass) {
|
||||
ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
|
||||
}
|
||||
} else if (tmpl == "chatglm3" ||
|
||||
(tmpl.find("add_generation_prompt") != std::string::npos &&
|
||||
tmpl.find("for message in messages") != std::string::npos &&
|
||||
tmpl.find("loop.first") != std::string::npos)) {
|
||||
// chatglm3-6b
|
||||
ss << "[gMASK]" << "sop";
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<|" << role << "|>" << "\n " << message->content;
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|assistant|>";
|
||||
}
|
||||
} else if (tmpl == "ChatGLM4") {
|
||||
ss << "[gMASK]" << "<sop>";
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<|" << role << "|>" << "\n" << message->content;
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|assistant|>";
|
||||
}
|
||||
} else {
|
||||
// template not supported
|
||||
return -1;
|
||||
|
1
llama.h
1
llama.h
@ -87,6 +87,7 @@ extern "C" {
|
||||
LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
|
||||
LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
|
||||
LLAMA_VOCAB_PRE_TYPE_PORO = 15,
|
||||
LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 16,
|
||||
};
|
||||
|
||||
// note: these values should be synchronized with ggml_rope
|
||||
|
@ -56,7 +56,11 @@ int main(void) {
|
||||
//Phi-3-medium
|
||||
"{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
|
||||
//Phi-3-vision
|
||||
"{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}"
|
||||
"{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
|
||||
// ChatGLM3
|
||||
"{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
|
||||
// ChatGLM4
|
||||
"ChatGLM4",
|
||||
};
|
||||
std::vector<std::string> expected_output = {
|
||||
// teknium/OpenHermes-2.5-Mistral-7B
|
||||
@ -93,6 +97,10 @@ int main(void) {
|
||||
"<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
|
||||
//Phi-3-vision
|
||||
"<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
|
||||
// ChatGLM3
|
||||
"[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>",
|
||||
// ChatGLM4
|
||||
"[gMASK]<sop><|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
|
||||
};
|
||||
std::vector<char> formatted_chat(1024);
|
||||
int32_t res;
|
||||
|
Loading…
Reference in New Issue
Block a user