diff --git a/llama.cpp b/llama.cpp index 6333af4aa..3c4da6a1c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4250,11 +4250,13 @@ struct llm_build_context { ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts] // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok] - //ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [n_tokens, num_experts_per_tok, 1] - ggml_tensor * weights = ggml_get_rows(ctx0, - ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts); - weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [n_tokens, num_experts_per_tok, 1] + ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok] + ggml_tensor * weights = + ggml_reshape_2d(ctx0, + ggml_get_rows(ctx0, + ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts), + n_experts_per_tok, n_tokens); // [n_tokens, num_experts_per_tok] + weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [n_tokens, num_experts_per_tok] // compute expert outputs ggml_tensor * moe_out;