From 0ede372a51fd8160688e01b587582666c14e94e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 18 Jun 2023 16:07:09 +0200 Subject: [PATCH] Fixed incorrectly applying RMS norm twice (#1925) --- llama.cpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/llama.cpp b/llama.cpp index dfbb85a68..45360cea3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1657,11 +1657,7 @@ static bool llama_eval_internal( { cur = ggml_rms_norm(ctx0, inpL); offload_func_nr(cur); - ggml_set_name(cur, "rms_norm_inpL"); - - cur = ggml_rms_norm(ctx0, cur); - offload_func_nr(cur); - ggml_set_name(cur, "rms_norm_after"); + ggml_set_name(cur, "rms_norm_2"); // cur = cur*norm(broadcasted) cur = ggml_mul(ctx0, cur, model.norm);