mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 06:39:25 +01:00
CUDA: remove unnecessary warp reduce in FA (ggml/1032)
* kqmax_new_j in every thread within warp is same after operate at line 199,this reduce can be omit * same problem in vec32 --------- Co-authored-by: ZhaoXiaoYu <zhao.xiaoyu@zte.com.cn>
This commit is contained in:
parent
efb6ae9630
commit
e9e661bd59
@ -220,7 +220,6 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||||||
for (int j = 0; j < ncols; ++j) {
|
for (int j = 0; j < ncols; ++j) {
|
||||||
half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
|
half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
|
||||||
|
|
||||||
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
|
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
|
||||||
}
|
}
|
||||||
|
@ -206,7 +206,6 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||||||
for (int j = 0; j < ncols; ++j) {
|
for (int j = 0; j < ncols; ++j) {
|
||||||
float kqmax_new_j = kqmax_new_arr[j];
|
float kqmax_new_j = kqmax_new_arr[j];
|
||||||
|
|
||||||
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
|
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user