fix commented-out kernel variants

This commit is contained in:
Johannes Gäßler 2024-05-26 20:14:55 +02:00
parent 462add6a01
commit 3194a01058

View File

@ -382,37 +382,38 @@ void launch_fattn_vec_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const int32_t precision = KQV->op_params[2];
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
// if (Q->ne[1] == 1) {
// constexpr int cols_per_block = 1;
// constexpr int parallel_blocks = 4;
// launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
// return;
// }
if (Q->ne[1] == 1) {
constexpr int cols_per_block = 1;
constexpr int parallel_blocks = 4;
launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
return;
}
// if (Q->ne[1] == 2) {
// constexpr int cols_per_block = 2;
// constexpr int parallel_blocks = 4;
// launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
// return;
// }
if (Q->ne[1] == 2) {
constexpr int cols_per_block = 2;
constexpr int parallel_blocks = 4;
launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
return;
}
// if (Q->ne[1] <= 4) {
// constexpr int cols_per_block = 4;
// constexpr int parallel_blocks = 4;
// launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
// return;
// }
if (Q->ne[1] <= 4) {
constexpr int cols_per_block = 4;
constexpr int parallel_blocks = 4;
launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
return;
}
// if (Q->ne[1] <= 8) {
// constexpr int cols_per_block = 8;
// constexpr int parallel_blocks = 4;
// launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
// return;
// }
if (Q->ne[1] <= 8) {
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 4;
launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
return;
}
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 1;