mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 15:18:26 +01:00
cuda : fix matrix names
This commit is contained in:
parent
cfd9732b2e
commit
ef68fac2a8
10
ggml-cuda.cu
10
ggml-cuda.cu
@ -6687,19 +6687,19 @@ static __global__ void flash_attn_ext_f16(
|
||||
for (int cc = 0; cc < C/16; ++cc) {
|
||||
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
half16x16_b mk[D16];
|
||||
half16x16_b mv[D16];
|
||||
for (int i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half));
|
||||
nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half));
|
||||
}
|
||||
|
||||
half16x16_a mv[Q16];
|
||||
half16x16_a ms[Q16];
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T);
|
||||
nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T);
|
||||
}
|
||||
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
for (int i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]);
|
||||
nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user