diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
index 1054ff95d..45702ad65 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -561,7 +561,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
             ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
             break;
         // case 256:
-        //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+        //     ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
         //     break;
         default:
             GGML_ABORT("fatal error");
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index b1e66d470..b0cf152f5 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -235,7 +235,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }
 
-    if (!new_mma_available(cc)) {
+    if (!fp16_mma_available(cc)) {
         if (prec == GGML_PREC_DEFAULT) {
             if (Q->ne[1] <= 8) {
                 ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
@@ -265,6 +265,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
     if (cc == GGML_CUDA_CC_VOLTA) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+        return;
     }
 
     ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
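
For context, here is a minimal standalone C++ sketch of the dispatch control flow the second and third hunks fix. The capability constant and the kernel/check bodies below are stand-in stubs with hypothetical values, not the real ggml-cuda implementations; only the branch structure mirrors the patched `ggml_cuda_flash_attn_ext`. The point of the added `return;` is that, without it, a Volta GPU would fall through and launch the MMA path after the WMMA path, even though MMA requires Turing or newer.

```cpp
#include <cstdio>

// Stand-ins for ggml's compute-capability helpers (stub values for
// illustration only; the real definitions live in ggml-cuda):
constexpr int GGML_CUDA_CC_VOLTA = 700;
static bool fp16_mma_available(int cc) { return cc >= GGML_CUDA_CC_VOLTA; }

// Stub kernels standing in for the real launch functions:
static void flash_attn_vec_f16 () { std::puts("vector kernel (pre-Volta)"); }
static void flash_attn_wmma_f16() { std::puts("WMMA kernel (Volta)"); }
static void flash_attn_mma_f16 () { std::puts("MMA kernel (Turing+)"); }

static void dispatch(int cc) {
    // No fp16 tensor-core MMA at all: fall back to the vector kernel.
    if (!fp16_mma_available(cc)) {
        flash_attn_vec_f16();
        return;
    }
    // The MMA implementation needs Turing or newer; use WMMA for Volta.
    if (cc == GGML_CUDA_CC_VOLTA) {
        flash_attn_wmma_f16();
        return; // the fix: stop here instead of falling through to MMA
    }
    flash_attn_mma_f16();
}

int main() {
    dispatch(610); // pre-Volta -> vector kernel only
    dispatch(700); // Volta     -> WMMA only (one kernel, not two)
    dispatch(750); // Turing    -> MMA
}
```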