CUDA: fix Volta FlashAttention logic (#11615)

Author: Johannes Gäßler, 2025-02-03 13:25:56 +01:00 (committed via GitHub)
parent d92cb67e37
commit 21c84b5d2d
2 changed files with 3 additions and 2 deletions


@@ -561,7 +561,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
             break;
         // case 256:
-        //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+        //     ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
         //     break;
         default:
             GGML_ABORT("fatal error");
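The only change in this hunk is inside commented-out code: the disabled head-size-256 case instantiated the kernel template with 128, a likely copy-paste slip from the case above, so if it were ever re-enabled it would have launched the head-size-128 kernel. A minimal standalone mock of this per-head-size template dispatch (hypothetical names, not the real ggml API), showing why the first template parameter must match the case label:

#include <cstdio>

// Hypothetical stand-in for ggml_cuda_flash_attn_ext_wmma_f16_case<D, cols, T>:
// the first template parameter is the head size the kernel is compiled for.
template <int D, int cols_per_block>
static void wmma_f16_case() {
    std::printf("launch WMMA kernel: head size %d, %d columns per block\n",
                D, cols_per_block);
}

static void dispatch(int head_size) {
    constexpr int cols_per_block = 16; // assumed example value
    switch (head_size) {
        case 128:
            wmma_f16_case<128, cols_per_block>();
            break;
        case 256: // disabled upstream; with the old <128, ...> argument this
                  // case would silently have launched the head-size-128 kernel
            wmma_f16_case<256, cols_per_block>();
            break;
        default:
            std::puts("fatal error: unsupported head size");
            break;
    }
}

int main() {
    dispatch(128);
    dispatch(256);
    return 0;
}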


@@ -235,7 +235,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }
-    if (!new_mma_available(cc)) {
+    if (!fp16_mma_available(cc)) {
         if (prec == GGML_PREC_DEFAULT) {
             if (Q->ne[1] <= 8) {
                 ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
@@ -265,6 +265,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
     if (cc == GGML_CUDA_CC_VOLTA) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+        return;
     }
     ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
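Taken together, the two hunks in this file route Volta through the WMMA kernel and then stop: the fallback gate becomes fp16_mma_available (true on Volta, which has fp16 tensor cores) instead of new_mma_available (true only on Turing and newer), and the added return prevents falling through to the MMA kernel after the WMMA launch. A minimal, simplified standalone sketch of the corrected dispatch (mock names and compute-capability values, not the real ggml API; the precision and batch-size sub-cases of the fallback path are elided):

#include <cstdio>

// Assumed compute-capability constants in the spirit of GGML_CUDA_CC_*.
enum { CC_PASCAL = 600, CC_VOLTA = 700, CC_TURING = 750 };

// Volta introduced fp16 tensor cores (usable via WMMA); the newer MMA path
// requires Turing or later.
static bool fp16_mma_available(int cc) { return cc >= CC_VOLTA; }

static void flash_attn_vec (void) { std::puts("vec kernel (no tensor cores)"); }
static void flash_attn_wmma(void) { std::puts("WMMA kernel (Volta)"); }
static void flash_attn_mma (void) { std::puts("MMA kernel (Turing+)"); }

static void dispatch(int cc) {
    if (!fp16_mma_available(cc)) { // was !new_mma_available(cc), which wrongly
        flash_attn_vec();          // sent Volta down this fallback path
        return;
    }
    // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
    if (cc == CC_VOLTA) {
        flash_attn_wmma();
        return; // the added return: without it, execution fell through to MMA
    }
    flash_attn_mma();
}

int main() {
    dispatch(CC_PASCAL); // vec kernel
    dispatch(CC_VOLTA);  // WMMA kernel only (without the fix: WMMA, then MMA too)
    dispatch(CC_TURING); // MMA kernel
    return 0;
}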