Fix handling of nwarps > batch size

This commit is contained in:
Johannes Gäßler 2024-05-26 23:00:15 +02:00
parent f08776041d
commit f4003cfba1
2 changed files with 8 additions and 0 deletions

View File

@@ -92,6 +92,10 @@ static __global__ void flash_attn_vec_ext_f16(
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (j0 + nwarps > ncols && j >= ncols) {
break;
}
// Reuse KQ as temporary storage for converting Q to q8_1:
int * tmp_q_i32 = (int *) &KQ[j*D];
half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int));

View File

@@ -92,6 +92,10 @@ static __global__ void flash_attn_vec_ext_f32(
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (j0 + nwarps > ncols && j >= ncols) {
break;
}
// Reuse KQ as temporary storage for converting Q to q8_1:
int * tmp_q_i32 = (int *) &KQ[j*D];
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));