From f4003cfba1d95511ae59820ff5ea1b62e3cce225 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 26 May 2024 23:00:15 +0200 Subject: [PATCH] fix nwarps > batch size --- ggml-cuda/fattn-vec-f16.cu | 4 ++++ ggml-cuda/fattn-vec-f32.cu | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/ggml-cuda/fattn-vec-f16.cu b/ggml-cuda/fattn-vec-f16.cu index 71581509c..c427e18ab 100644 --- a/ggml-cuda/fattn-vec-f16.cu +++ b/ggml-cuda/fattn-vec-f16.cu @@ -92,6 +92,10 @@ static __global__ void flash_attn_vec_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; + if (j0 + nwarps > ncols && j >= ncols) { + break; + } + // Reuse KQ as temporary storage for converting Q to q8_1: int * tmp_q_i32 = (int *) &KQ[j*D]; half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); diff --git a/ggml-cuda/fattn-vec-f32.cu b/ggml-cuda/fattn-vec-f32.cu index dded24320..1b6197d05 100644 --- a/ggml-cuda/fattn-vec-f32.cu +++ b/ggml-cuda/fattn-vec-f32.cu @@ -92,6 +92,10 @@ static __global__ void flash_attn_vec_ext_f32( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; + if (j0 + nwarps > ncols && j >= ncols) { + break; + } + // Reuse KQ as temporary storage for converting Q to q8_1: int * tmp_q_i32 = (int *) &KQ[j*D]; float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));