Mirror of https://github.com/ggerganov/llama.cpp.git
fix nwarps > batch size
commit f4003cfba1
parent f08776041d
@@ -92,6 +92,10 @@ static __global__ void flash_attn_vec_ext_f16(
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
 
+        if (j0 + nwarps > ncols && j >= ncols) {
+            break;
+        }
+
         // Reuse KQ as temporary storage for converting Q to q8_1:
         int   * tmp_q_i32 = (int   *) &KQ[j*D];
         half2 * tmp_q_ds  = (half2 *) (tmp_q_i32 + D/sizeof(int));
@@ -92,6 +92,10 @@ static __global__ void flash_attn_vec_ext_f32(
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
 
+        if (j0 + nwarps > ncols && j >= ncols) {
+            break;
+        }
+
         // Reuse KQ as temporary storage for converting Q to q8_1:
         int    * tmp_q_i32 = (int    *) &KQ[j*D];
         float2 * tmp_q_ds  = (float2 *) (tmp_q_i32 + D/sizeof(int));
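The guard added in both kernels matters when the batch size ncols is smaller than nwarps (or not a multiple of it): warps whose row index j = j0 + threadIdx.y lands past the last column would otherwise index KQ[j*D] out of range. Below is a minimal host-side C++ sketch of that loop/guard logic only; the concrete values nwarps = 8 and ncols = 3 are illustrative assumptions, not taken from the commit.

// Host-side sketch of the per-warp loop from the patch, assuming
// hypothetical values nwarps = 8 and ncols = 3 (nwarps > batch size).
#include <cstdio>

int main() {
    const int nwarps = 8;  // warps per block (assumed value)
    const int ncols  = 3;  // columns (batch entries) handled by the block (assumed value)

    for (int warp = 0; warp < nwarps; ++warp) {       // stands in for threadIdx.y
        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
            const int j = j0 + warp;

            // The guard added by the commit: warps whose row falls past the
            // last column stop instead of touching KQ[j*D] out of bounds.
            if (j0 + nwarps > ncols && j >= ncols) {
                break;
            }

            printf("warp %d processes column %d\n", warp, j);
        }
    }
    return 0;
}

With these assumed values, only warps 0 to 2 process a column; warps 3 to 7 hit the guard on their first iteration and exit the loop.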