From a7b471569bdf4e09e97b2d02c27989b8cb801861 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 15:17:49 +0200 Subject: [PATCH] cuda : switch to 1 warp for bs > 16 --- ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 1fed9d23e..c98b551b3 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -10933,7 +10933,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why - const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 2; + const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1; dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); dim3 block_dim(32, nwarps, 1);