mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 13:27:21 +01:00
llama : Add support for DeepSeek V3 (#11049)
* convert : extend DEEPSEEK2 model architecture to support DeepseekV3ForCausalLM by adding EXPERT_WEIGHTS_NORM and EXPERT_GATING_FUNC model parameters and FFN_EXP_PROBS_B tensor type * vocab : add DeepSeek V3 pre-tokenizer regexes * unicode : handle ACCENT_MARK and SYMBOL categories in regex * llama : add DeepSeek V3 chat template, handle new model parameters and tensor types --------- Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
This commit is contained in:
parent
f922a9c542
commit
9394bbd484
@ -687,6 +687,9 @@ class Model:
|
|||||||
if chkhsh == "d4c8f286ea6b520b3d495c4455483cfa2302c0cfcd4be05d781b6a8a0a7cdaf1":
|
if chkhsh == "d4c8f286ea6b520b3d495c4455483cfa2302c0cfcd4be05d781b6a8a0a7cdaf1":
|
||||||
# ref: https://huggingface.co/Infinigence/Megrez-3B-Instruct
|
# ref: https://huggingface.co/Infinigence/Megrez-3B-Instruct
|
||||||
res = "megrez"
|
res = "megrez"
|
||||||
|
if chkhsh == "877081d19cf6996e2c4ff0e1236341e9b7bde288f5311a56a937f0afbbb3aeb5":
|
||||||
|
# ref: https://huggingface.co/deepseek-ai/DeepSeek-V3
|
||||||
|
res = "deepseek-v3"
|
||||||
|
|
||||||
if res is None:
|
if res is None:
|
||||||
logger.warning("\n")
|
logger.warning("\n")
|
||||||
@ -3849,6 +3852,7 @@ class DeepseekModel(Model):
|
|||||||
|
|
||||||
|
|
||||||
@Model.register("DeepseekV2ForCausalLM")
|
@Model.register("DeepseekV2ForCausalLM")
|
||||||
|
@Model.register("DeepseekV3ForCausalLM")
|
||||||
class DeepseekV2Model(Model):
|
class DeepseekV2Model(Model):
|
||||||
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
|
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
|
||||||
|
|
||||||
@ -3870,6 +3874,15 @@ class DeepseekV2Model(Model):
|
|||||||
self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
|
self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
|
||||||
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
|
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
|
||||||
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
|
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
|
||||||
|
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
|
||||||
|
|
||||||
|
if hparams["scoring_func"] == "sigmoid":
|
||||||
|
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||||
|
elif hparams["scoring_func"] == "softmax":
|
||||||
|
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}")
|
||||||
|
|
||||||
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
||||||
|
|
||||||
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
|
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
|
||||||
@ -3882,6 +3895,16 @@ class DeepseekV2Model(Model):
|
|||||||
_experts: list[dict[str, Tensor]] | None = None
|
_experts: list[dict[str, Tensor]] | None = None
|
||||||
|
|
||||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||||
|
# rename e_score_correction_bias tensors
|
||||||
|
if name.endswith("e_score_correction_bias"):
|
||||||
|
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
|
||||||
|
|
||||||
|
# skip Multi-Token Prediction (MTP) layers
|
||||||
|
block_count = self.hparams["num_hidden_layers"]
|
||||||
|
match = re.match(r"model.layers.(\d+)", name)
|
||||||
|
if match and int(match.group(1)) >= block_count:
|
||||||
|
return []
|
||||||
|
|
||||||
# process the experts separately
|
# process the experts separately
|
||||||
if name.find("mlp.experts") != -1:
|
if name.find("mlp.experts") != -1:
|
||||||
n_experts = self.hparams["n_routed_experts"]
|
n_experts = self.hparams["n_routed_experts"]
|
||||||
|
@ -107,6 +107,7 @@ models = [
|
|||||||
{"name": "roberta-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"},
|
{"name": "roberta-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"},
|
||||||
{"name": "gigachat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct"},
|
{"name": "gigachat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct"},
|
||||||
{"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"},
|
{"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"},
|
||||||
|
{"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -102,6 +102,8 @@ class Keys:
|
|||||||
EXPERT_USED_COUNT = "{arch}.expert_used_count"
|
EXPERT_USED_COUNT = "{arch}.expert_used_count"
|
||||||
EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
|
EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
|
||||||
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
|
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
|
||||||
|
EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm"
|
||||||
|
EXPERT_GATING_FUNC = "{arch}.expert_gating_func"
|
||||||
POOLING_TYPE = "{arch}.pooling_type"
|
POOLING_TYPE = "{arch}.pooling_type"
|
||||||
LOGIT_SCALE = "{arch}.logit_scale"
|
LOGIT_SCALE = "{arch}.logit_scale"
|
||||||
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
|
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
|
||||||
@ -313,6 +315,7 @@ class MODEL_TENSOR(IntEnum):
|
|||||||
FFN_GATE_SHEXP = auto()
|
FFN_GATE_SHEXP = auto()
|
||||||
FFN_DOWN_SHEXP = auto()
|
FFN_DOWN_SHEXP = auto()
|
||||||
FFN_UP_SHEXP = auto()
|
FFN_UP_SHEXP = auto()
|
||||||
|
FFN_EXP_PROBS_B = auto()
|
||||||
ATTN_Q_NORM = auto()
|
ATTN_Q_NORM = auto()
|
||||||
ATTN_K_NORM = auto()
|
ATTN_K_NORM = auto()
|
||||||
LAYER_OUT_NORM = auto()
|
LAYER_OUT_NORM = auto()
|
||||||
@ -498,6 +501,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
|||||||
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
|
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
|
||||||
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
|
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
|
||||||
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
|
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
|
||||||
|
MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
|
||||||
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
|
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
|
||||||
MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
|
MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
|
||||||
MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
|
MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
|
||||||
@ -1290,6 +1294,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||||
|
MODEL_TENSOR.FFN_EXP_PROBS_B,
|
||||||
],
|
],
|
||||||
MODEL_ARCH.CHATGLM : [
|
MODEL_ARCH.CHATGLM : [
|
||||||
MODEL_TENSOR.TOKEN_EMBD,
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
@ -1590,6 +1595,11 @@ class GGMLQuantizationType(IntEnum):
|
|||||||
TQ2_0 = 35
|
TQ2_0 = 35
|
||||||
|
|
||||||
|
|
||||||
|
class ExpertGatingFuncType(IntEnum):
|
||||||
|
SOFTMAX = 1
|
||||||
|
SIGMOID = 2
|
||||||
|
|
||||||
|
|
||||||
# TODO: add GGMLFileType from ggml_ftype in ggml.h
|
# TODO: add GGMLFileType from ggml_ftype in ggml.h
|
||||||
|
|
||||||
|
|
||||||
|
@ -26,6 +26,7 @@ from .constants import (
|
|||||||
RopeScalingType,
|
RopeScalingType,
|
||||||
PoolingType,
|
PoolingType,
|
||||||
TokenType,
|
TokenType,
|
||||||
|
ExpertGatingFuncType,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .quants import quant_shape_from_byte_shape
|
from .quants import quant_shape_from_byte_shape
|
||||||
@ -715,6 +716,12 @@ class GGUFWriter:
|
|||||||
def add_expert_weights_scale(self, value: float) -> None:
|
def add_expert_weights_scale(self, value: float) -> None:
|
||||||
self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
|
self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
|
||||||
|
|
||||||
|
def add_expert_weights_norm(self, value: bool) -> None:
|
||||||
|
self.add_bool(Keys.LLM.EXPERT_WEIGHTS_NORM.format(arch=self.arch), value)
|
||||||
|
|
||||||
|
def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
|
||||||
|
self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)
|
||||||
|
|
||||||
def add_swin_norm(self, value: bool) -> None:
|
def add_swin_norm(self, value: bool) -> None:
|
||||||
self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
|
self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
|
||||||
|
|
||||||
|
@ -276,6 +276,10 @@ class TensorNameMap:
|
|||||||
"model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe
|
"model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe
|
||||||
),
|
),
|
||||||
|
|
||||||
|
MODEL_TENSOR.FFN_EXP_PROBS_B: (
|
||||||
|
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3
|
||||||
|
),
|
||||||
|
|
||||||
# Feed-forward up
|
# Feed-forward up
|
||||||
MODEL_TENSOR.FFN_UP: (
|
MODEL_TENSOR.FFN_UP: (
|
||||||
"gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
|
"gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
|
||||||
|
@ -105,6 +105,7 @@ extern "C" {
|
|||||||
LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
|
LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
|
||||||
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
|
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
|
||||||
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
|
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
|
||||||
|
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
|
||||||
};
|
};
|
||||||
|
|
||||||
enum llama_rope_type {
|
enum llama_rope_type {
|
||||||
|
@ -92,6 +92,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
|||||||
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
|
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
|
||||||
{ LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" },
|
{ LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" },
|
||||||
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
|
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
|
||||||
|
{ LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
|
||||||
|
{ LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
|
||||||
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
|
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
|
||||||
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
|
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
|
||||||
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
|
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
|
||||||
@ -984,6 +986,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||||
|
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1366,6 +1369,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
|||||||
{LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
{LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
||||||
{LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
{LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
||||||
{LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
{LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
||||||
|
{LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||||
// this tensor is loaded for T5, but never used
|
// this tensor is loaded for T5, but never used
|
||||||
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
|
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
|
||||||
{LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
|
{LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
|
||||||
|
@ -96,6 +96,8 @@ enum llm_kv {
|
|||||||
LLM_KV_EXPERT_USED_COUNT,
|
LLM_KV_EXPERT_USED_COUNT,
|
||||||
LLM_KV_EXPERT_SHARED_COUNT,
|
LLM_KV_EXPERT_SHARED_COUNT,
|
||||||
LLM_KV_EXPERT_WEIGHTS_SCALE,
|
LLM_KV_EXPERT_WEIGHTS_SCALE,
|
||||||
|
LLM_KV_EXPERT_WEIGHTS_NORM,
|
||||||
|
LLM_KV_EXPERT_GATING_FUNC,
|
||||||
LLM_KV_POOLING_TYPE,
|
LLM_KV_POOLING_TYPE,
|
||||||
LLM_KV_LOGIT_SCALE,
|
LLM_KV_LOGIT_SCALE,
|
||||||
LLM_KV_DECODER_START_TOKEN_ID,
|
LLM_KV_DECODER_START_TOKEN_ID,
|
||||||
@ -231,6 +233,7 @@ enum llm_tensor {
|
|||||||
LLM_TENSOR_FFN_DOWN_SHEXP,
|
LLM_TENSOR_FFN_DOWN_SHEXP,
|
||||||
LLM_TENSOR_FFN_GATE_SHEXP,
|
LLM_TENSOR_FFN_GATE_SHEXP,
|
||||||
LLM_TENSOR_FFN_UP_SHEXP,
|
LLM_TENSOR_FFN_UP_SHEXP,
|
||||||
|
LLM_TENSOR_FFN_EXP_PROBS_B,
|
||||||
LLM_TENSOR_ATTN_Q_NORM,
|
LLM_TENSOR_ATTN_Q_NORM,
|
||||||
LLM_TENSOR_ATTN_K_NORM,
|
LLM_TENSOR_ATTN_K_NORM,
|
||||||
LLM_TENSOR_LAYER_OUT_NORM,
|
LLM_TENSOR_LAYER_OUT_NORM,
|
||||||
|
@ -45,6 +45,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
|||||||
{ "vicuna-orca", LLM_CHAT_TEMPLATE_VICUNA_ORCA },
|
{ "vicuna-orca", LLM_CHAT_TEMPLATE_VICUNA_ORCA },
|
||||||
{ "deepseek", LLM_CHAT_TEMPLATE_DEEPSEEK },
|
{ "deepseek", LLM_CHAT_TEMPLATE_DEEPSEEK },
|
||||||
{ "deepseek2", LLM_CHAT_TEMPLATE_DEEPSEEK_2 },
|
{ "deepseek2", LLM_CHAT_TEMPLATE_DEEPSEEK_2 },
|
||||||
|
{ "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 },
|
||||||
{ "command-r", LLM_CHAT_TEMPLATE_COMMAND_R },
|
{ "command-r", LLM_CHAT_TEMPLATE_COMMAND_R },
|
||||||
{ "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 },
|
{ "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 },
|
||||||
{ "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 },
|
{ "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 },
|
||||||
@ -148,6 +149,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
|||||||
return LLM_CHAT_TEMPLATE_MINICPM;
|
return LLM_CHAT_TEMPLATE_MINICPM;
|
||||||
} else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
|
} else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
|
||||||
return LLM_CHAT_TEMPLATE_DEEPSEEK_2;
|
return LLM_CHAT_TEMPLATE_DEEPSEEK_2;
|
||||||
|
} else if (tmpl_contains(LU8("'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'"))) {
|
||||||
|
return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
|
||||||
} else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
|
} else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
|
||||||
// ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
|
// ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
|
||||||
// EXAONE-3.0-7.8B-Instruct
|
// EXAONE-3.0-7.8B-Instruct
|
||||||
@ -453,6 +456,21 @@ int32_t llm_chat_apply_template(
|
|||||||
if (add_ass) {
|
if (add_ass) {
|
||||||
ss << "Assistant:";
|
ss << "Assistant:";
|
||||||
}
|
}
|
||||||
|
} else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_3) {
|
||||||
|
// DeepSeek-V3
|
||||||
|
for (auto message : chat) {
|
||||||
|
std::string role(message->role);
|
||||||
|
if (role == "system") {
|
||||||
|
ss << message->content << "\n\n";
|
||||||
|
} else if (role == "user") {
|
||||||
|
ss << LU8("<|User|>") << message->content;
|
||||||
|
} else if (role == "assistant") {
|
||||||
|
ss << LU8("<|Assistant|>") << message->content << LU8("<|end▁of▁sentence|>");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (add_ass) {
|
||||||
|
ss << LU8("<|Assistant|>");
|
||||||
|
}
|
||||||
} else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) {
|
} else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) {
|
||||||
// ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
|
// ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
|
||||||
// EXAONE-3.0-7.8B-Instruct
|
// EXAONE-3.0-7.8B-Instruct
|
||||||
|
@ -25,6 +25,7 @@ enum llm_chat_template {
|
|||||||
LLM_CHAT_TEMPLATE_VICUNA_ORCA,
|
LLM_CHAT_TEMPLATE_VICUNA_ORCA,
|
||||||
LLM_CHAT_TEMPLATE_DEEPSEEK,
|
LLM_CHAT_TEMPLATE_DEEPSEEK,
|
||||||
LLM_CHAT_TEMPLATE_DEEPSEEK_2,
|
LLM_CHAT_TEMPLATE_DEEPSEEK_2,
|
||||||
|
LLM_CHAT_TEMPLATE_DEEPSEEK_3,
|
||||||
LLM_CHAT_TEMPLATE_COMMAND_R,
|
LLM_CHAT_TEMPLATE_COMMAND_R,
|
||||||
LLM_CHAT_TEMPLATE_LLAMA_3,
|
LLM_CHAT_TEMPLATE_LLAMA_3,
|
||||||
LLM_CHAT_TEMPLATE_CHATGML_3,
|
LLM_CHAT_TEMPLATE_CHATGML_3,
|
||||||
|
@ -6,7 +6,13 @@
|
|||||||
|
|
||||||
// bump if necessary
|
// bump if necessary
|
||||||
#define LLAMA_MAX_LAYERS 512
|
#define LLAMA_MAX_LAYERS 512
|
||||||
#define LLAMA_MAX_EXPERTS 160 // DeepSeekV2
|
#define LLAMA_MAX_EXPERTS 256 // DeepSeekV3
|
||||||
|
|
||||||
|
enum llama_expert_gating_func_type {
|
||||||
|
LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,
|
||||||
|
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1,
|
||||||
|
LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
|
||||||
|
};
|
||||||
|
|
||||||
struct llama_hparams_posnet {
|
struct llama_hparams_posnet {
|
||||||
uint32_t n_embd;
|
uint32_t n_embd;
|
||||||
@ -54,7 +60,9 @@ struct llama_hparams {
|
|||||||
uint32_t n_expert_shared = 0;
|
uint32_t n_expert_shared = 0;
|
||||||
uint32_t n_norm_groups = 0;
|
uint32_t n_norm_groups = 0;
|
||||||
|
|
||||||
float expert_weights_scale = 0.0;
|
float expert_weights_scale = 0.0;
|
||||||
|
bool expert_weights_norm = false;
|
||||||
|
uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
|
||||||
|
|
||||||
float f_norm_eps;
|
float f_norm_eps;
|
||||||
float f_norm_rms_eps;
|
float f_norm_rms_eps;
|
||||||
|
@ -66,6 +66,7 @@ const char * llm_type_name(llm_type type) {
|
|||||||
case MODEL_70B: return "70B";
|
case MODEL_70B: return "70B";
|
||||||
case MODEL_236B: return "236B";
|
case MODEL_236B: return "236B";
|
||||||
case MODEL_314B: return "314B";
|
case MODEL_314B: return "314B";
|
||||||
|
case MODEL_671B: return "671B";
|
||||||
case MODEL_SMALL: return "0.1B";
|
case MODEL_SMALL: return "0.1B";
|
||||||
case MODEL_MEDIUM: return "0.4B";
|
case MODEL_MEDIUM: return "0.4B";
|
||||||
case MODEL_LARGE: return "0.8B";
|
case MODEL_LARGE: return "0.8B";
|
||||||
@ -125,6 +126,14 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static const char * llama_expert_gating_func_name(llama_expert_gating_func_type type) {
|
||||||
|
switch (type) {
|
||||||
|
case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX: return "softmax";
|
||||||
|
case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID: return "sigmoid";
|
||||||
|
default: return "unknown";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::string llama_model_arch_name (const llama_model & model) {
|
std::string llama_model_arch_name (const llama_model & model) {
|
||||||
return llm_arch_name(model.arch);
|
return llm_arch_name(model.arch);
|
||||||
}
|
}
|
||||||
@ -933,11 +942,19 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
|
|||||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||||
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
|
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
|
||||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
|
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
|
||||||
|
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
|
||||||
|
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
|
||||||
|
if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
|
||||||
|
// for compatibility with existing DeepSeek V2 and V2.5 GGUFs
|
||||||
|
// that have no expert_gating_func model parameter set
|
||||||
|
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX;
|
||||||
|
}
|
||||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul);
|
ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul);
|
||||||
|
|
||||||
switch (hparams.n_layer) {
|
switch (hparams.n_layer) {
|
||||||
case 27: model.type = e_model::MODEL_16B; break;
|
case 27: model.type = e_model::MODEL_16B; break;
|
||||||
case 60: model.type = e_model::MODEL_236B; break;
|
case 60: model.type = e_model::MODEL_236B; break;
|
||||||
|
case 61: model.type = e_model::MODEL_671B; break;
|
||||||
default: model.type = e_model::MODEL_UNKNOWN;
|
default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
@ -1259,6 +1276,10 @@ void llm_load_vocab(llama_model_loader & ml, llama_model & model) {
|
|||||||
tokenizer_pre == "deepseek-coder") {
|
tokenizer_pre == "deepseek-coder") {
|
||||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER;
|
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER;
|
||||||
vocab.tokenizer_clean_spaces = false;
|
vocab.tokenizer_clean_spaces = false;
|
||||||
|
} else if (
|
||||||
|
tokenizer_pre == "deepseek-v3") {
|
||||||
|
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM;
|
||||||
|
vocab.tokenizer_clean_spaces = false;
|
||||||
} else if (
|
} else if (
|
||||||
tokenizer_pre == "falcon") {
|
tokenizer_pre == "falcon") {
|
||||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON;
|
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON;
|
||||||
@ -1941,6 +1962,8 @@ void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
|
|||||||
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
|
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
|
||||||
LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
|
LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
|
||||||
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
|
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
|
||||||
|
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
|
||||||
|
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((enum llama_expert_gating_func_type) hparams.expert_gating_func));
|
||||||
LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul);
|
LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -63,6 +63,7 @@ enum llm_type {
|
|||||||
MODEL_70B,
|
MODEL_70B,
|
||||||
MODEL_236B,
|
MODEL_236B,
|
||||||
MODEL_314B,
|
MODEL_314B,
|
||||||
|
MODEL_671B,
|
||||||
MODEL_SMALL,
|
MODEL_SMALL,
|
||||||
MODEL_MEDIUM,
|
MODEL_MEDIUM,
|
||||||
MODEL_LARGE,
|
MODEL_LARGE,
|
||||||
@ -213,6 +214,7 @@ struct llama_layer {
|
|||||||
struct ggml_tensor * ffn_down_b = nullptr; // b2
|
struct ggml_tensor * ffn_down_b = nullptr; // b2
|
||||||
struct ggml_tensor * ffn_up_b = nullptr; // b3
|
struct ggml_tensor * ffn_up_b = nullptr; // b3
|
||||||
struct ggml_tensor * ffn_act = nullptr;
|
struct ggml_tensor * ffn_act = nullptr;
|
||||||
|
struct ggml_tensor * ffn_exp_probs_b = nullptr;
|
||||||
|
|
||||||
// mamba proj
|
// mamba proj
|
||||||
struct ggml_tensor * ssm_in = nullptr;
|
struct ggml_tensor * ssm_in = nullptr;
|
||||||
|
@ -382,6 +382,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
|||||||
"\\p{N}+",
|
"\\p{N}+",
|
||||||
};
|
};
|
||||||
break;
|
break;
|
||||||
|
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM:
|
||||||
|
regex_exprs = {
|
||||||
|
"\\p{N}{1,3}",
|
||||||
|
"[一-龥-ゟ゠-ヿ]+",
|
||||||
|
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
|
||||||
|
};
|
||||||
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
|
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
|
||||||
regex_exprs = {
|
regex_exprs = {
|
||||||
"[\r\n]",
|
"[\r\n]",
|
||||||
|
@ -1857,6 +1857,7 @@ static bool llm_load_tensors(
|
|||||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||||
} else {
|
} else {
|
||||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
||||||
|
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
|
|
||||||
if (n_expert == 0) {
|
if (n_expert == 0) {
|
||||||
throw std::runtime_error("n_expert must be > 0");
|
throw std::runtime_error("n_expert must be > 0");
|
||||||
@ -2837,12 +2838,14 @@ static struct ggml_tensor * llm_build_moe_ffn(
|
|||||||
struct ggml_tensor * up_exps,
|
struct ggml_tensor * up_exps,
|
||||||
struct ggml_tensor * gate_exps,
|
struct ggml_tensor * gate_exps,
|
||||||
struct ggml_tensor * down_exps,
|
struct ggml_tensor * down_exps,
|
||||||
|
struct ggml_tensor * exp_probs_b,
|
||||||
int64_t n_expert,
|
int64_t n_expert,
|
||||||
int64_t n_expert_used,
|
int64_t n_expert_used,
|
||||||
llm_ffn_op_type type_op,
|
llm_ffn_op_type type_op,
|
||||||
bool norm_w,
|
bool norm_w,
|
||||||
bool scale_w,
|
bool scale_w,
|
||||||
float w_scale,
|
float w_scale,
|
||||||
|
llama_expert_gating_func_type gating_op,
|
||||||
const llm_build_cb & cb,
|
const llm_build_cb & cb,
|
||||||
int il) {
|
int il) {
|
||||||
int64_t n_embd = cur->ne[0];
|
int64_t n_embd = cur->ne[0];
|
||||||
@ -2851,11 +2854,31 @@ static struct ggml_tensor * llm_build_moe_ffn(
|
|||||||
ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
|
ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
|
||||||
cb(logits, "ffn_moe_logits", il);
|
cb(logits, "ffn_moe_logits", il);
|
||||||
|
|
||||||
ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
|
ggml_tensor * probs = nullptr;
|
||||||
|
switch (gating_op) {
|
||||||
|
case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
|
||||||
|
{
|
||||||
|
probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
|
||||||
|
} break;
|
||||||
|
case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
|
||||||
|
{
|
||||||
|
probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
}
|
||||||
cb(probs, "ffn_moe_probs", il);
|
cb(probs, "ffn_moe_probs", il);
|
||||||
|
|
||||||
|
// add experts selection bias - introduced in DeepSeek V3
|
||||||
|
// leave probs unbiased as it's later used to get expert weights
|
||||||
|
ggml_tensor * selection_probs = probs;
|
||||||
|
if (exp_probs_b != nullptr) {
|
||||||
|
selection_probs = ggml_add(ctx, probs, exp_probs_b);
|
||||||
|
cb(selection_probs, "ffn_moe_probs_biased", il);
|
||||||
|
}
|
||||||
|
|
||||||
// select experts
|
// select experts
|
||||||
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
|
ggml_tensor * selected_experts = ggml_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
|
||||||
cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
||||||
cb(selected_experts, "ffn_moe_topk", il);
|
cb(selected_experts, "ffn_moe_topk", il);
|
||||||
|
|
||||||
@ -3976,9 +3999,11 @@ struct llm_build_context {
|
|||||||
model.layers[il].ffn_up_exps,
|
model.layers[il].ffn_up_exps,
|
||||||
model.layers[il].ffn_gate_exps,
|
model.layers[il].ffn_gate_exps,
|
||||||
model.layers[il].ffn_down_exps,
|
model.layers[il].ffn_down_exps,
|
||||||
|
nullptr,
|
||||||
n_expert, n_expert_used,
|
n_expert, n_expert_used,
|
||||||
LLM_FFN_SILU, true,
|
LLM_FFN_SILU, true,
|
||||||
false, 0.0,
|
false, 0.0,
|
||||||
|
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||||
cb, il);
|
cb, il);
|
||||||
cb(cur, "ffn_moe_out", il);
|
cb(cur, "ffn_moe_out", il);
|
||||||
}
|
}
|
||||||
@ -4628,9 +4653,11 @@ struct llm_build_context {
|
|||||||
model.layers[il].ffn_up_exps,
|
model.layers[il].ffn_up_exps,
|
||||||
model.layers[il].ffn_gate_exps,
|
model.layers[il].ffn_gate_exps,
|
||||||
model.layers[il].ffn_down_exps,
|
model.layers[il].ffn_down_exps,
|
||||||
|
nullptr,
|
||||||
n_expert, n_expert_used,
|
n_expert, n_expert_used,
|
||||||
LLM_FFN_GELU, true,
|
LLM_FFN_GELU, true,
|
||||||
false, 0.0,
|
false, 0.0,
|
||||||
|
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||||
cb, il);
|
cb, il);
|
||||||
cb(cur, "ffn_moe_out", il);
|
cb(cur, "ffn_moe_out", il);
|
||||||
|
|
||||||
@ -4769,9 +4796,11 @@ struct llm_build_context {
|
|||||||
model.layers[il].ffn_up_exps,
|
model.layers[il].ffn_up_exps,
|
||||||
model.layers[il].ffn_gate_exps,
|
model.layers[il].ffn_gate_exps,
|
||||||
model.layers[il].ffn_down_exps,
|
model.layers[il].ffn_down_exps,
|
||||||
|
nullptr,
|
||||||
n_expert, n_expert_used,
|
n_expert, n_expert_used,
|
||||||
LLM_FFN_SILU, true,
|
LLM_FFN_SILU, true,
|
||||||
false, 0.0,
|
false, 0.0,
|
||||||
|
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||||
cb, il);
|
cb, il);
|
||||||
cb(cur, "ffn_moe_out", il);
|
cb(cur, "ffn_moe_out", il);
|
||||||
|
|
||||||
@ -6017,9 +6046,11 @@ struct llm_build_context {
|
|||||||
model.layers[il].ffn_up_exps,
|
model.layers[il].ffn_up_exps,
|
||||||
model.layers[il].ffn_gate_exps,
|
model.layers[il].ffn_gate_exps,
|
||||||
model.layers[il].ffn_down_exps,
|
model.layers[il].ffn_down_exps,
|
||||||
|
nullptr,
|
||||||
n_expert, n_expert_used,
|
n_expert, n_expert_used,
|
||||||
LLM_FFN_SILU, false,
|
LLM_FFN_SILU, false,
|
||||||
false, 0.0,
|
false, 0.0,
|
||||||
|
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||||
cb, il);
|
cb, il);
|
||||||
cb(cur, "ffn_moe_out", il);
|
cb(cur, "ffn_moe_out", il);
|
||||||
|
|
||||||
@ -8142,9 +8173,11 @@ struct llm_build_context {
|
|||||||
model.layers[il].ffn_up_exps,
|
model.layers[il].ffn_up_exps,
|
||||||
model.layers[il].ffn_gate_exps,
|
model.layers[il].ffn_gate_exps,
|
||||||
model.layers[il].ffn_down_exps,
|
model.layers[il].ffn_down_exps,
|
||||||
|
nullptr,
|
||||||
n_expert, n_expert_used,
|
n_expert, n_expert_used,
|
||||||
LLM_FFN_SILU, false,
|
LLM_FFN_SILU, false,
|
||||||
false, 0.0,
|
false, 0.0,
|
||||||
|
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||||
cb, il);
|
cb, il);
|
||||||
cb(cur, "ffn_moe_out", il);
|
cb(cur, "ffn_moe_out", il);
|
||||||
|
|
||||||
@ -8539,9 +8572,11 @@ struct llm_build_context {
|
|||||||
model.layers[il].ffn_up_exps,
|
model.layers[il].ffn_up_exps,
|
||||||
model.layers[il].ffn_gate_exps,
|
model.layers[il].ffn_gate_exps,
|
||||||
model.layers[il].ffn_down_exps,
|
model.layers[il].ffn_down_exps,
|
||||||
|
nullptr,
|
||||||
n_expert, n_expert_used,
|
n_expert, n_expert_used,
|
||||||
LLM_FFN_SILU, true,
|
LLM_FFN_SILU, true,
|
||||||
false, 0.0,
|
false, 0.0,
|
||||||
|
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||||
cb, il);
|
cb, il);
|
||||||
cb(cur, "ffn_moe_out", il);
|
cb(cur, "ffn_moe_out", il);
|
||||||
|
|
||||||
@ -8680,9 +8715,11 @@ struct llm_build_context {
|
|||||||
model.layers[il].ffn_up_exps,
|
model.layers[il].ffn_up_exps,
|
||||||
model.layers[il].ffn_gate_exps,
|
model.layers[il].ffn_gate_exps,
|
||||||
model.layers[il].ffn_down_exps,
|
model.layers[il].ffn_down_exps,
|
||||||
|
nullptr,
|
||||||
n_expert, n_expert_used,
|
n_expert, n_expert_used,
|
||||||
LLM_FFN_SILU, false,
|
LLM_FFN_SILU, false,
|
||||||
false, hparams.expert_weights_scale,
|
false, hparams.expert_weights_scale,
|
||||||
|
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||||
cb, il);
|
cb, il);
|
||||||
cb(moe_out, "ffn_moe_out", il);
|
cb(moe_out, "ffn_moe_out", il);
|
||||||
|
|
||||||
@ -8909,9 +8946,11 @@ struct llm_build_context {
|
|||||||
model.layers[il].ffn_up_exps,
|
model.layers[il].ffn_up_exps,
|
||||||
model.layers[il].ffn_gate_exps,
|
model.layers[il].ffn_gate_exps,
|
||||||
model.layers[il].ffn_down_exps,
|
model.layers[il].ffn_down_exps,
|
||||||
|
model.layers[il].ffn_exp_probs_b,
|
||||||
n_expert, n_expert_used,
|
n_expert, n_expert_used,
|
||||||
LLM_FFN_SILU, false,
|
LLM_FFN_SILU, hparams.expert_weights_norm,
|
||||||
true, hparams.expert_weights_scale,
|
true, hparams.expert_weights_scale,
|
||||||
|
(enum llama_expert_gating_func_type) hparams.expert_gating_func,
|
||||||
cb, il);
|
cb, il);
|
||||||
cb(moe_out, "ffn_moe_out", il);
|
cb(moe_out, "ffn_moe_out", il);
|
||||||
|
|
||||||
|
@ -667,18 +667,24 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
|
|||||||
{ "\\p{N}", unicode_cpt_flags::NUMBER },
|
{ "\\p{N}", unicode_cpt_flags::NUMBER },
|
||||||
{ "\\p{L}", unicode_cpt_flags::LETTER },
|
{ "\\p{L}", unicode_cpt_flags::LETTER },
|
||||||
{ "\\p{P}", unicode_cpt_flags::PUNCTUATION },
|
{ "\\p{P}", unicode_cpt_flags::PUNCTUATION },
|
||||||
|
{ "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
|
||||||
|
{ "\\p{S}", unicode_cpt_flags::SYMBOL },
|
||||||
};
|
};
|
||||||
|
|
||||||
static const std::map<int, int> k_ucat_cpt = {
|
static const std::map<int, int> k_ucat_cpt = {
|
||||||
{ unicode_cpt_flags::NUMBER, 0xD1 },
|
{ unicode_cpt_flags::NUMBER, 0xD1 },
|
||||||
{ unicode_cpt_flags::LETTER, 0xD2 },
|
{ unicode_cpt_flags::LETTER, 0xD2 },
|
||||||
{ unicode_cpt_flags::PUNCTUATION, 0xD3 },
|
{ unicode_cpt_flags::PUNCTUATION, 0xD3 },
|
||||||
|
{ unicode_cpt_flags::ACCENT_MARK, 0xD4 },
|
||||||
|
{ unicode_cpt_flags::SYMBOL, 0xD5 },
|
||||||
};
|
};
|
||||||
|
|
||||||
static const std::map<int, std::string> k_ucat_map = {
|
static const std::map<int, std::string> k_ucat_map = {
|
||||||
{ unicode_cpt_flags::NUMBER, "\x30-\x39" }, // 0-9
|
{ unicode_cpt_flags::NUMBER, "\x30-\x39" }, // 0-9
|
||||||
{ unicode_cpt_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
|
{ unicode_cpt_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
|
||||||
{ unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
|
{ unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
|
||||||
|
{ unicode_cpt_flags::ACCENT_MARK, "" }, // no sub-128 codepoints
|
||||||
|
{ unicode_cpt_flags::SYMBOL, "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`|
|
||||||
};
|
};
|
||||||
|
|
||||||
// compute collapsed codepoints only if needed by at least one regex
|
// compute collapsed codepoints only if needed by at least one regex
|
||||||
|
Loading…
x
Reference in New Issue
Block a user