mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-29 07:34:18 +01:00
llama : fix expert weighting in the FFN
This commit is contained in:
parent
7ea36953ba
commit
8b185b7030
12
llama.cpp
12
llama.cpp
@@ -4250,11 +4250,13 @@ struct llm_build_context {

        ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]

        // select experts
-       ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
-       //ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [n_tokens, num_experts_per_tok, 1]
-       ggml_tensor * weights = ggml_get_rows(ctx0,
-               ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts);
-       weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [n_tokens, num_experts_per_tok, 1]
+       ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
+       ggml_tensor * weights =
+           ggml_reshape_2d(ctx0,
+               ggml_get_rows(ctx0,
+                   ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts),
+               n_experts_per_tok, n_tokens); // [n_tokens, num_experts_per_tok]
+       weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [n_tokens, num_experts_per_tok]

        // compute expert outputs
        ggml_tensor * moe_out;
Loading…
Reference in New Issue
Block a user