mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 05:48:47 +01:00
CUDA: fix Pascal FA, deq. KV to FP16 for batch > 8 (#7681)
This commit is contained in:
parent
9b596417af
commit
750f60c03e
@ -1,6 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "common.cuh"
|
#include "common.cuh"
|
||||||
|
#include "convert.cuh"
|
||||||
#include "vecdotq.cuh"
|
#include "vecdotq.cuh"
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
@ -53,7 +54,7 @@ typedef float (*vec_dot_KQ_f32_t)(
|
|||||||
template<typename T, int D>
|
template<typename T, int D>
|
||||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
|
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
|
||||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||||
#if __CUDA_ARCH__ > MIN_CC_DP4A
|
#if __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||||
|
|
||||||
const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
|
const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
|
||||||
GGML_UNUSED(Q_v);
|
GGML_UNUSED(Q_v);
|
||||||
@ -95,13 +96,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
|
|||||||
GGML_UNUSED(Q_q8);
|
GGML_UNUSED(Q_q8);
|
||||||
GGML_UNUSED(Q_ds_v);
|
GGML_UNUSED(Q_ds_v);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#endif // __CUDA_ARCH__ > MIN_CC_DP4A
|
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T, int D>
|
template<typename T, int D>
|
||||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
|
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
|
||||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||||
#if __CUDA_ARCH__ > MIN_CC_DP4A
|
#if __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||||
|
|
||||||
const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
|
const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
|
||||||
GGML_UNUSED(Q_v);
|
GGML_UNUSED(Q_v);
|
||||||
@ -147,13 +148,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
|
|||||||
GGML_UNUSED(Q_q8);
|
GGML_UNUSED(Q_q8);
|
||||||
GGML_UNUSED(Q_ds_v);
|
GGML_UNUSED(Q_ds_v);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#endif // __CUDA_ARCH__ > MIN_CC_DP4A
|
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T, int D>
|
template<typename T, int D>
|
||||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
|
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
|
||||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||||
#if __CUDA_ARCH__ > MIN_CC_DP4A
|
#if __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||||
|
|
||||||
const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
|
const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
|
||||||
GGML_UNUSED(Q_v);
|
GGML_UNUSED(Q_v);
|
||||||
@ -202,13 +203,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
|
|||||||
GGML_UNUSED(Q_q8);
|
GGML_UNUSED(Q_q8);
|
||||||
GGML_UNUSED(Q_ds_v);
|
GGML_UNUSED(Q_ds_v);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#endif // __CUDA_ARCH__ > MIN_CC_DP4A
|
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T, int D>
|
template<typename T, int D>
|
||||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
|
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
|
||||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||||
#if __CUDA_ARCH__ > MIN_CC_DP4A
|
#if __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||||
|
|
||||||
const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
|
const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
|
||||||
GGML_UNUSED(Q_v);
|
GGML_UNUSED(Q_v);
|
||||||
@ -261,13 +262,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
|
|||||||
GGML_UNUSED(Q_q8);
|
GGML_UNUSED(Q_q8);
|
||||||
GGML_UNUSED(Q_ds_v);
|
GGML_UNUSED(Q_ds_v);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#endif // __CUDA_ARCH__ > MIN_CC_DP4A
|
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, int D>
|
template <typename T, int D>
|
||||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
|
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
|
||||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||||
#if __CUDA_ARCH__ > MIN_CC_DP4A
|
#if __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||||
|
|
||||||
const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
|
const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
|
||||||
GGML_UNUSED(Q_v);
|
GGML_UNUSED(Q_v);
|
||||||
@ -302,7 +303,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
|
|||||||
GGML_UNUSED(Q_q8);
|
GGML_UNUSED(Q_q8);
|
||||||
GGML_UNUSED(Q_ds_v);
|
GGML_UNUSED(Q_ds_v);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#endif // __CUDA_ARCH__ > MIN_CC_DP4A
|
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, int D>
|
template <typename T, int D>
|
||||||
@ -620,7 +621,10 @@ static void on_no_fattn_vec_case(const int D) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <int D, int parallel_blocks>
|
template <int D, int parallel_blocks>
|
||||||
void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) {
|
void launch_fattn(
|
||||||
|
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
|
||||||
|
const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
|
||||||
|
) {
|
||||||
const ggml_tensor * Q = dst->src[0];
|
const ggml_tensor * Q = dst->src[0];
|
||||||
const ggml_tensor * K = dst->src[1];
|
const ggml_tensor * K = dst->src[1];
|
||||||
const ggml_tensor * V = dst->src[2];
|
const ggml_tensor * V = dst->src[2];
|
||||||
@ -641,9 +645,49 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
|
|||||||
ggml_cuda_pool & pool = ctx.pool();
|
ggml_cuda_pool & pool = ctx.pool();
|
||||||
cudaStream_t main_stream = ctx.stream();
|
cudaStream_t main_stream = ctx.stream();
|
||||||
|
|
||||||
|
ggml_cuda_pool_alloc<half> K_f16(pool);
|
||||||
|
ggml_cuda_pool_alloc<half> V_f16(pool);
|
||||||
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
||||||
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
||||||
|
|
||||||
|
char * K_data = (char *) K->data;
|
||||||
|
size_t nb11 = K->nb[1];
|
||||||
|
size_t nb12 = K->nb[2];
|
||||||
|
size_t nb13 = K->nb[3];
|
||||||
|
|
||||||
|
char * V_data = (char *) V->data;
|
||||||
|
size_t nb21 = V->nb[1];
|
||||||
|
size_t nb22 = V->nb[2];
|
||||||
|
size_t nb23 = V->nb[3];
|
||||||
|
|
||||||
|
if (need_f16_K && K->type != GGML_TYPE_F16) {
|
||||||
|
K_f16.alloc(ggml_nelements(K));
|
||||||
|
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
|
||||||
|
to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
|
||||||
|
K_data = (char *) K_f16.ptr;
|
||||||
|
|
||||||
|
const size_t bs = ggml_blck_size(K->type);
|
||||||
|
const size_t ts = ggml_type_size(K->type);
|
||||||
|
|
||||||
|
nb11 = nb11*bs*sizeof(half)/ts;
|
||||||
|
nb12 = nb12*bs*sizeof(half)/ts;
|
||||||
|
nb13 = nb13*bs*sizeof(half)/ts;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (need_f16_V && V->type != GGML_TYPE_F16) {
|
||||||
|
V_f16.alloc(ggml_nelements(V));
|
||||||
|
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
||||||
|
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
|
||||||
|
V_data = (char *) V_f16.ptr;
|
||||||
|
|
||||||
|
const size_t bs = ggml_blck_size(V->type);
|
||||||
|
const size_t ts = ggml_type_size(V->type);
|
||||||
|
|
||||||
|
nb21 = nb21*bs*sizeof(half)/ts;
|
||||||
|
nb22 = nb22*bs*sizeof(half)/ts;
|
||||||
|
nb23 = nb23*bs*sizeof(half)/ts;
|
||||||
|
}
|
||||||
|
|
||||||
if (parallel_blocks > 1) {
|
if (parallel_blocks > 1) {
|
||||||
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||||
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
|
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
|
||||||
@ -667,8 +711,8 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
|
|||||||
|
|
||||||
fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
|
fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
|
||||||
(const char *) Q->data,
|
(const char *) Q->data,
|
||||||
(const char *) K->data,
|
K_data,
|
||||||
(const char *) V->data,
|
V_data,
|
||||||
mask ? ((const char *) mask->data) : nullptr,
|
mask ? ((const char *) mask->data) : nullptr,
|
||||||
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
||||||
scale, max_bias, m0, m1, n_head_log2,
|
scale, max_bias, m0, m1, n_head_log2,
|
||||||
@ -676,8 +720,8 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
|
|||||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||||
K->nb[1], K->nb[2], K->nb[3],
|
nb11, nb12, nb13,
|
||||||
V->nb[1], V->nb[2], V->nb[3],
|
nb21, nb22, nb23,
|
||||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||||
);
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
@ -278,13 +278,13 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||||||
constexpr int D = 64;
|
constexpr int D = 64;
|
||||||
constexpr int nwarps = 8;
|
constexpr int nwarps = 8;
|
||||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
|
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
|
||||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||||
} break;
|
} break;
|
||||||
case 128: {
|
case 128: {
|
||||||
constexpr int D = 128;
|
constexpr int D = 128;
|
||||||
constexpr int nwarps = 8;
|
constexpr int nwarps = 8;
|
||||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
|
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
|
||||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||||
} break;
|
} break;
|
||||||
default: {
|
default: {
|
||||||
GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
||||||
|
@ -275,13 +275,13 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||||||
constexpr int D = 64;
|
constexpr int D = 64;
|
||||||
constexpr int nwarps = 8;
|
constexpr int nwarps = 8;
|
||||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
|
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
|
||||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||||
} break;
|
} break;
|
||||||
case 128: {
|
case 128: {
|
||||||
constexpr int D = 128;
|
constexpr int D = 128;
|
||||||
constexpr int nwarps = 8;
|
constexpr int nwarps = 8;
|
||||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
|
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
|
||||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||||
} break;
|
} break;
|
||||||
default: {
|
default: {
|
||||||
GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
||||||
|
@ -290,7 +290,9 @@ template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml
|
|||||||
void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
constexpr int nwarps = D/WARP_SIZE;
|
constexpr int nwarps = D/WARP_SIZE;
|
||||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V>;
|
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V>;
|
||||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
constexpr bool need_f16_K = D != 128;
|
||||||
|
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||||
|
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int D, ggml_type type_K, ggml_type type_V>
|
template <int D, ggml_type type_K, ggml_type type_V>
|
||||||
|
@ -271,7 +271,9 @@ template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml
|
|||||||
void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
constexpr int nwarps = D/WARP_SIZE;
|
constexpr int nwarps = D/WARP_SIZE;
|
||||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
|
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
|
||||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
constexpr bool need_f16_K = D != 128;
|
||||||
|
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||||
|
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int D, ggml_type type_K, ggml_type type_V>
|
template <int D, ggml_type type_K, ggml_type type_V>
|
||||||
|
@ -438,18 +438,18 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
|
|||||||
if (4*blocks_num_pb1 < 2*nsm) {
|
if (4*blocks_num_pb1 < 2*nsm) {
|
||||||
constexpr int parallel_blocks = 4;
|
constexpr int parallel_blocks = 4;
|
||||||
fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
|
fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
|
||||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (2*blocks_num_pb1 < 2*nsm) {
|
if (2*blocks_num_pb1 < 2*nsm) {
|
||||||
constexpr int parallel_blocks = 2;
|
constexpr int parallel_blocks = 2;
|
||||||
fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
|
fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
|
||||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
constexpr int parallel_blocks = 1;
|
constexpr int parallel_blocks = 1;
|
||||||
fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
|
fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
|
||||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \
|
#define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \
|
||||||
|
@ -298,17 +298,13 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
|
|||||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * KQV = dst;
|
const ggml_tensor * KQV = dst;
|
||||||
const ggml_tensor * Q = dst->src[0];
|
const ggml_tensor * Q = dst->src[0];
|
||||||
const ggml_tensor * K = dst->src[1];
|
|
||||||
const ggml_tensor * V = dst->src[2];
|
|
||||||
|
|
||||||
ggml_cuda_set_device(ctx.device);
|
ggml_cuda_set_device(ctx.device);
|
||||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||||
const int32_t precision = KQV->op_params[2];
|
const int32_t precision = KQV->op_params[2];
|
||||||
|
|
||||||
const bool quantized_KV = ggml_is_quantized(K->type) || ggml_is_quantized(V->type);
|
|
||||||
|
|
||||||
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
||||||
if (cc >= CC_OFFSET_AMD || quantized_KV) {
|
if (cc >= CC_OFFSET_AMD) {
|
||||||
if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||||
} else {
|
} else {
|
||||||
|
Loading…
Reference in New Issue
Block a user