diff --git a/ggml-cuda/fattn-vec-f16.cu b/ggml-cuda/fattn-vec-f16.cu index 71581509c..c427e18ab 100644 --- a/ggml-cuda/fattn-vec-f16.cu +++ b/ggml-cuda/fattn-vec-f16.cu @@ -92,6 +92,10 @@ static __global__ void flash_attn_vec_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; + if (j0 + nwarps > ncols && j >= ncols) { + break; + } + // Reuse KQ as temporary storage for converting Q to q8_1: int * tmp_q_i32 = (int *) &KQ[j*D]; half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); diff --git a/ggml-cuda/fattn-vec-f32.cu b/ggml-cuda/fattn-vec-f32.cu index dded24320..1b6197d05 100644 --- a/ggml-cuda/fattn-vec-f32.cu +++ b/ggml-cuda/fattn-vec-f32.cu @@ -92,6 +92,10 @@ static __global__ void flash_attn_vec_ext_f32( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; + if (j0 + nwarps > ncols && j >= ncols) { + break; + } + // Reuse KQ as temporary storage for converting Q to q8_1: int * tmp_q_i32 = (int *) &KQ[j*D]; float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));