mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 21:37:19 +01:00
model: Add support for PhiMoE arch (#11003)
* model: support phimoe * python linter * doc: minor Co-authored-by: ThiloteE <73715071+ThiloteE@users.noreply.github.com> * doc: minor Co-authored-by: ThiloteE <73715071+ThiloteE@users.noreply.github.com> * doc: add phimoe as supported model ggml-ci --------- Co-authored-by: ThiloteE <73715071+ThiloteE@users.noreply.github.com>
This commit is contained in:
parent
be0e950c91
commit
f8feb4b01a
@ -69,6 +69,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen)
|
||||
- [x] [PLaMo-13B](https://github.com/ggerganov/llama.cpp/pull/3557)
|
||||
- [x] [Phi models](https://huggingface.co/models?search=microsoft/phi)
|
||||
- [x] [PhiMoE](https://github.com/ggerganov/llama.cpp/pull/11003)
|
||||
- [x] [GPT-2](https://huggingface.co/gpt2)
|
||||
- [x] [Orion 14B](https://github.com/ggerganov/llama.cpp/pull/5118)
|
||||
- [x] [InternLM2](https://huggingface.co/models?search=internlm2)
|
||||
|
@ -2562,6 +2562,63 @@ class Phi3MiniModel(Model):
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))
|
||||
|
||||
|
||||
@Model.register("PhiMoEForCausalLM")
|
||||
class PhiMoeModel(Phi3MiniModel):
|
||||
model_arch = gguf.MODEL_ARCH.PHIMOE
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
|
||||
self.gguf_writer.add_expert_count(self.hparams["num_local_experts"])
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# process the experts separately
|
||||
if name.find("block_sparse_moe.experts") != -1:
|
||||
n_experts = self.hparams["num_local_experts"]
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
|
||||
# merge the experts into a single 3d tensor
|
||||
for w_name in ["w1", "w2", "w3"]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight"
|
||||
datas.append(self._experts[bid][ename])
|
||||
del self._experts[bid][ename]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
|
||||
merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight"
|
||||
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
|
||||
tensors.append((new_name, data_torch))
|
||||
return tensors
|
||||
else:
|
||||
return []
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
def prepare_tensors(self):
|
||||
super().prepare_tensors()
|
||||
|
||||
if self._experts is not None:
|
||||
# flatten `list[dict[str, Tensor]]` into `list[str]`
|
||||
experts = [k for d in self._experts for k in d.keys()]
|
||||
if len(experts) > 0:
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@Model.register("PlamoForCausalLM")
|
||||
class PlamoModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.PLAMO
|
||||
|
@ -28,7 +28,7 @@ The required steps to implement for an HF model are:
|
||||
```python
|
||||
@Model.register("MyModelForCausalLM")
|
||||
class MyModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.GROK
|
||||
model_arch = gguf.MODEL_ARCH.MYMODEL
|
||||
```
|
||||
|
||||
2. Define the layout of the GGUF tensors in [constants.py](/gguf-py/gguf/constants.py)
|
||||
@ -79,14 +79,14 @@ Depending on the model configuration, tokenizer, code and tensors layout, you wi
|
||||
- `Model#set_vocab`
|
||||
- `Model#write_tensors`
|
||||
|
||||
NOTE: Tensor names must end with `.weight` suffix, that is the convention and several tools like `quantize` expect this to proceed the weights.
|
||||
NOTE: Tensor names must end with `.weight` or `.bias` suffixes, that is the convention and several tools like `quantize` expect this to proceed the weights.
|
||||
|
||||
### 2. Define the model architecture in `llama.cpp`
|
||||
|
||||
The model params and tensors layout must be defined in `llama.cpp`:
|
||||
1. Define a new `llm_arch`
|
||||
2. Define the tensors layout in `LLM_TENSOR_NAMES`
|
||||
3. Add any non standard metadata in `llm_load_hparams`
|
||||
3. Add any non-standard metadata in `llm_load_hparams`
|
||||
4. Create the tensors for inference in `llm_load_tensors`
|
||||
5. If the model has a RoPE operation, add the rope type in `llama_rope_type`
|
||||
|
||||
@ -96,9 +96,9 @@ NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorc
|
||||
|
||||
This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`.
|
||||
|
||||
Have a look at existing implementation like `build_llama`, `build_dbrx` or `build_bert`.
|
||||
Have a look at existing implementations like `build_llama`, `build_dbrx` or `build_bert`.
|
||||
|
||||
When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support for missing backend operations can be added in another PR.
|
||||
Some `ggml` backends do not support all operations. Backend implementations can be added in a separate PR.
|
||||
|
||||
Note: to debug the inference graph: you can use [llama-eval-callback](/examples/eval-callback/).
|
||||
|
||||
|
@ -244,6 +244,7 @@ class MODEL_ARCH(IntEnum):
|
||||
QWEN2VL = auto()
|
||||
PHI2 = auto()
|
||||
PHI3 = auto()
|
||||
PHIMOE = auto()
|
||||
PLAMO = auto()
|
||||
CODESHELL = auto()
|
||||
ORION = auto()
|
||||
@ -428,6 +429,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.QWEN2VL: "qwen2vl",
|
||||
MODEL_ARCH.PHI2: "phi2",
|
||||
MODEL_ARCH.PHI3: "phi3",
|
||||
MODEL_ARCH.PHIMOE: "phimoe",
|
||||
MODEL_ARCH.PLAMO: "plamo",
|
||||
MODEL_ARCH.CODESHELL: "codeshell",
|
||||
MODEL_ARCH.ORION: "orion",
|
||||
@ -940,6 +942,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.PHIMOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FACTORS_LONG,
|
||||
MODEL_TENSOR.ROPE_FACTORS_SHORT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_QKV,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
MODEL_ARCH.CODESHELL: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.POS_EMBD,
|
||||
|
@ -55,7 +55,7 @@ class TensorNameMap:
|
||||
# Output
|
||||
MODEL_TENSOR.OUTPUT: (
|
||||
"embed_out", # gptneox
|
||||
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2
|
||||
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe
|
||||
"output", # llama-pth bloom internlm2
|
||||
"word_embeddings_for_head", # persimmon
|
||||
"lm_head.linear", # phi2
|
||||
@ -68,7 +68,7 @@ class TensorNameMap:
|
||||
MODEL_TENSOR.OUTPUT_NORM: (
|
||||
"gpt_neox.final_layer_norm", # gptneox
|
||||
"transformer.ln_f", # gpt2 gpt-j falcon jais exaone
|
||||
"model.norm", # llama-hf baichuan internlm2 olmoe olmo2
|
||||
"model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe
|
||||
"norm", # llama-pth
|
||||
"transformer.norm_f", # mpt dbrx
|
||||
"ln_f", # refact bloom qwen gpt2
|
||||
@ -108,7 +108,7 @@ class TensorNameMap:
|
||||
"transformer.h.{bid}.input_layernorm", # falcon7b
|
||||
"h.{bid}.input_layernorm", # bloom
|
||||
"transformer.h.{bid}.ln_mlp", # falcon40b
|
||||
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe
|
||||
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe
|
||||
"layers.{bid}.attention_norm", # llama-pth
|
||||
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
|
||||
"model.layers.{bid}.ln1", # yi
|
||||
@ -152,7 +152,7 @@ class TensorNameMap:
|
||||
|
||||
# Attention query
|
||||
MODEL_TENSOR.ATTN_Q: (
|
||||
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2
|
||||
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 phimoe
|
||||
"model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom
|
||||
"layers.{bid}.attention.wq", # llama-pth
|
||||
"encoder.layer.{bid}.attention.self.query", # bert
|
||||
@ -165,7 +165,7 @@ class TensorNameMap:
|
||||
|
||||
# Attention key
|
||||
MODEL_TENSOR.ATTN_K: (
|
||||
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2
|
||||
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 phimoe
|
||||
"model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom
|
||||
"layers.{bid}.attention.wk", # llama-pth
|
||||
"encoder.layer.{bid}.attention.self.key", # bert
|
||||
@ -179,7 +179,7 @@ class TensorNameMap:
|
||||
|
||||
# Attention value
|
||||
MODEL_TENSOR.ATTN_V: (
|
||||
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2
|
||||
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe
|
||||
"layers.{bid}.attention.wv", # llama-pth
|
||||
"encoder.layer.{bid}.attention.self.value", # bert
|
||||
"transformer.h.{bid}.attn.v_proj", # gpt-j
|
||||
@ -197,7 +197,7 @@ class TensorNameMap:
|
||||
"transformer.blocks.{bid}.attn.out_proj", # mpt
|
||||
"transformer.h.{bid}.self_attention.dense", # falcon
|
||||
"h.{bid}.self_attention.dense", # bloom
|
||||
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2
|
||||
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe
|
||||
"model.layers.{bid}.self_attn.linear_attn", # deci
|
||||
"layers.{bid}.attention.wo", # llama-pth
|
||||
"encoder.layer.{bid}.attention.output.dense", # bert
|
||||
@ -242,7 +242,7 @@ class TensorNameMap:
|
||||
"transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone
|
||||
"h.{bid}.post_attention_layernorm", # bloom
|
||||
"transformer.blocks.{bid}.norm_2", # mpt
|
||||
"model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe
|
||||
"model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe phimoe
|
||||
"layers.{bid}.ffn_norm", # llama-pth
|
||||
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
|
||||
"model.layers.{bid}.ln2", # yi
|
||||
@ -265,7 +265,7 @@ class TensorNameMap:
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP: (
|
||||
"layers.{bid}.feed_forward.gate", # mixtral
|
||||
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
|
||||
"model.layers.{bid}.block_sparse_moe.gate", # mixtral phimoe
|
||||
"model.layers.{bid}.mlp.gate", # qwen2moe olmoe
|
||||
"transformer.decoder_layer.{bid}.router", # Grok
|
||||
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
|
||||
@ -310,10 +310,11 @@ class TensorNameMap:
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_EXP: (
|
||||
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
|
||||
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
|
||||
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
|
||||
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_SHEXP: (
|
||||
@ -342,10 +343,11 @@ class TensorNameMap:
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_EXP: (
|
||||
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
|
||||
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
|
||||
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
|
||||
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP: (
|
||||
@ -387,6 +389,7 @@ class TensorNameMap:
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP: (
|
||||
|
@ -27,6 +27,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
|
||||
{ LLM_ARCH_PHI2, "phi2" },
|
||||
{ LLM_ARCH_PHI3, "phi3" },
|
||||
{ LLM_ARCH_PHIMOE, "phimoe" },
|
||||
{ LLM_ARCH_PLAMO, "plamo" },
|
||||
{ LLM_ARCH_CODESHELL, "codeshell" },
|
||||
{ LLM_ARCH_ORION, "orion" },
|
||||
@ -584,6 +585,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_PHIMOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" },
|
||||
{ LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_PLAMO,
|
||||
{
|
||||
|
@ -31,6 +31,7 @@ enum llm_arch {
|
||||
LLM_ARCH_QWEN2VL,
|
||||
LLM_ARCH_PHI2,
|
||||
LLM_ARCH_PHI3,
|
||||
LLM_ARCH_PHIMOE,
|
||||
LLM_ARCH_PLAMO,
|
||||
LLM_ARCH_CODESHELL,
|
||||
LLM_ARCH_ORION,
|
||||
|
@ -76,6 +76,7 @@ const char * llm_type_name(llm_type type) {
|
||||
case MODEL_8x7B: return "8x7B";
|
||||
case MODEL_8x22B: return "8x22B";
|
||||
case MODEL_16x12B: return "16x12B";
|
||||
case MODEL_16x3_8B: return "16x3.8B";
|
||||
case MODEL_10B_128x3_66B: return "10B+128x3.66B";
|
||||
case MODEL_57B_A14B: return "57B.A14B";
|
||||
case MODEL_27B: return "27B";
|
||||
@ -661,6 +662,15 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
|
||||
throw std::runtime_error("invalid value for sliding_window");
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_PHIMOE:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 32: model.type = e_model::MODEL_16x3_8B; break;
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_PLAMO:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
@ -2094,6 +2104,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_OLMOE:
|
||||
case LLM_ARCH_PHI2:
|
||||
case LLM_ARCH_PHI3:
|
||||
case LLM_ARCH_PHIMOE:
|
||||
case LLM_ARCH_GEMMA:
|
||||
case LLM_ARCH_GEMMA2:
|
||||
case LLM_ARCH_STARCODER2:
|
||||
|
@ -73,6 +73,7 @@ enum llm_type {
|
||||
MODEL_8x7B,
|
||||
MODEL_8x22B,
|
||||
MODEL_16x12B,
|
||||
MODEL_16x3_8B,
|
||||
MODEL_10B_128x3_66B,
|
||||
MODEL_57B_A14B,
|
||||
MODEL_27B,
|
||||
|
@ -1212,6 +1212,50 @@ static bool llm_load_tensors(
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);
|
||||
|
||||
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_PHIMOE:
|
||||
{
|
||||
const int64_t n_embd_head = n_embd / n_head;
|
||||
|
||||
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
|
||||
|
||||
// output
|
||||
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
|
||||
model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
|
||||
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
|
||||
model.output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), { n_vocab }, 0);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0);
|
||||
|
||||
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
if (layer.wqkv == nullptr) {
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
||||
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
|
||||
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
||||
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0);
|
||||
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
||||
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
|
||||
}
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
|
||||
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), { n_embd }, 0);
|
||||
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
||||
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0);
|
||||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
|
||||
|
||||
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||
}
|
||||
@ -6266,7 +6310,7 @@ struct llm_build_context {
|
||||
|
||||
struct ggml_tensor* attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.layers[il].attn_norm,
|
||||
NULL,
|
||||
model.layers[il].attn_norm_b,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(attn_norm_output, "attn_norm", il);
|
||||
|
||||
@ -6281,8 +6325,7 @@ struct llm_build_context {
|
||||
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd)));
|
||||
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd)));
|
||||
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
|
||||
Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
|
||||
Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
|
||||
@ -6326,14 +6369,12 @@ struct llm_build_context {
|
||||
residual = cur;
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// FF
|
||||
// special-case: the up and gate tensors are merged into a single tensor
|
||||
// TOOD: support into llm_build_ffn
|
||||
{
|
||||
// feed-forward network
|
||||
if (model.layers[il].ffn_gate_inp == nullptr) {
|
||||
cur = llm_build_ffn(ctx0, lctx, cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
NULL, NULL, NULL,
|
||||
@ -6341,6 +6382,20 @@ struct llm_build_context {
|
||||
NULL,
|
||||
LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
} else {
|
||||
// MoE branch
|
||||
cur = llm_build_moe_ffn(ctx0, lctx, cur,
|
||||
model.layers[il].ffn_gate_inp,
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
cb, il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, residual, cur);
|
||||
@ -6353,11 +6408,16 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.output_norm,
|
||||
NULL,
|
||||
model.output_norm_b,
|
||||
LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
||||
|
||||
if (model.output_b != nullptr) {
|
||||
cb(cur, "result_output_no_bias", -1);
|
||||
cur = ggml_add(ctx0, cur, model.output_b);
|
||||
}
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
@ -10536,6 +10596,7 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
result = llm.build_phi2();
|
||||
} break;
|
||||
case LLM_ARCH_PHI3:
|
||||
case LLM_ARCH_PHIMOE:
|
||||
{
|
||||
result = llm.build_phi3();
|
||||
} break;
|
||||
|
Loading…
x
Reference in New Issue
Block a user