CUDA: quantized KV support for FA vec
This commit is contained in:
parent 10b1e45876
commit 672244a88b
@ -106,6 +106,7 @@ set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
                                             "llama: max. batch size for using peer access")
option(LLAMA_CUDA_NO_PEER_COPY    "llama: do not use peer to peer copies"              OFF)
option(LLAMA_CUDA_NO_VMM          "llama: do not try to use CUDA VMM"                  OFF)
option(LLAMA_CUDA_FA_ALL_QUANTS   "llama: compile all quants for FlashAttention"       OFF)

option(LLAMA_CURL                 "llama: use libcurl to download model from an URL"   OFF)
option(LLAMA_HIPBLAS              "llama: use hipBLAS"                                 OFF)
@ -427,6 +428,9 @@ if (LLAMA_CUDA)
    if (LLAMA_CUDA_NO_PEER_COPY)
        add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
    endif()
    if (LLAMA_CUDA_FA_ALL_QUANTS)
        add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
    endif()

    if (LLAMA_STATIC)
        if (WIN32)
Makefile
@ -493,7 +493,10 @@ ifdef LLAMA_CUDA_NO_PEER_COPY
endif # LLAMA_CUDA_NO_PEER_COPY
ifdef LLAMA_CUDA_CCBIN
	MK_NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
endif
endif # LLAMA_CUDA_CCBIN
ifdef LLAMA_CUDA_FA_ALL_QUANTS
	MK_NVCCFLAGS += -DGGML_CUDA_FA_ALL_QUANTS
endif # LLAMA_CUDA_FA_ALL_QUANTS

ifdef JETSON_EOL_MODULE_DETECT
define NVCC_COMPILE
@ -481,6 +481,7 @@ Building the program with BLAS support may lead to some performance improvements
| LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
| LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
| LLAMA_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. |
| LLAMA_CUDA_FA_ALL_QUANTS | Boolean | false | Compile support for all KV cache quantization type (combinations) for the FlashAttention CUDA kernels. More fine-grained control over KV cache size but compilation takes much longer. |

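The practical effect of a quantized KV cache is easiest to see with a rough size estimate. This is a minimal sketch with hypothetical model and runtime dimensions (n_layer, n_ctx, n_embd_kv are assumptions, not taken from this commit): the cache stores one K and one V vector per layer and per context position, so its size scales linearly with the bits per value (BPV) of the chosen cache type.

// Rough KV cache size estimate for hypothetical dimensions (illustrative only).
#include <cstdio>

int main() {
    const double n_layer = 32, n_ctx = 8192, n_embd_kv = 4096; // assumed model/runtime values
    const double bpv[]   = {16.0, 8.5, 4.5};                   // f16, q8_0, q4_0
    const char * name[]  = {"f16", "q8_0", "q4_0"};
    for (int i = 0; i < 3; ++i) {
        const double bytes = 2.0 * n_layer * n_ctx * n_embd_kv * bpv[i] / 8.0; // K and V
        std::printf("%-4s KV cache: %.2f GiB\n", name[i], bytes / (1024.0*1024.0*1024.0));
    }
    return 0;
}

With these assumed numbers the cache shrinks from 4.00 GiB (f16) to 2.13 GiB (q8_0) or 1.13 GiB (q4_0).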
- #### hipBLAS

@ -1,4 +1,5 @@
#include "common.cuh"
#include "vecdotq.cuh"

#include <cstdint>

@ -34,11 +35,463 @@ typedef void (* fattn_kernel_t)(
|
||||
const int nb11,
|
||||
const int nb12,
|
||||
const int nb13,
|
||||
const int nb21,
|
||||
const int nb22,
|
||||
const int nb23,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3);
|
||||
|
||||
typedef half (*vec_dot_KQ_f16_t)(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
typedef float (*vec_dot_KQ_f32_t)(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);

template<typename T, int D>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
#if __CUDA_ARCH__ > MIN_CC_DP4A

    const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
    GGML_UNUSED(Q_v);

    half sum = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
        const int k_KQ = k_KQ_0 + threadIdx.x;

        const int ib    = k_KQ /  QI8_1;
        const int iqs4  = k_KQ %  QI4_0;
        const int shift = k_KQ & (QI8_1/2);

        const int v = (get_int_from_uint8(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
        const int u = Q_q8[k_KQ_0/WARP_SIZE];

        const int sumi = __dp4a(v, u, 0);

#if FP16_AVAILABLE
        if (std::is_same<T, half>::value) {
            const half2 * Q_ds = (const half2 *) Q_ds_v;

            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
            sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
        } else
#endif // FP16_AVAILABLE
        {
            const float2 * Q_ds = (const float2 *) Q_ds_v;

            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
        }
    }

    return sum;
#else
    GGML_UNUSED(K_c);
    GGML_UNUSED(Q_v);
    GGML_UNUSED(Q_q8);
    GGML_UNUSED(Q_ds_v);
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ > MIN_CC_DP4A
}
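For reference, the quantity accumulated by the warp above can be written as a scalar host-side sketch (a hypothetical helper, not part of the commit): for one q4_0 block with scale d_K and 4-bit quants stored with an implicit -8 offset, dotted against 32 Q values that were quantized to q8_1 with per-group scale d_Q and sum of scaled values s_Q (as produced by quantize_q8_1_to_shared further down), the result is d_K * (d_Q * sum(q_j*u_j) - 8 * s_Q); the -8 offset factors out against the precomputed sum.

#include <cstdint>

// Hypothetical scalar reference for one q4_0 K block against one q8_1-quantized
// group of 32 Q values; mirrors what the warp accumulates in the kernel above.
static float ref_vec_dot_q4_0_q8_1(const uint8_t qs[16], float d_K,
                                   const int8_t u[32], float d_Q, float s_Q) {
    int sumi = 0;
    for (int j = 0; j < 32; ++j) {
        // q4_0 layout: element j < 16 is the low nibble of qs[j], element j >= 16 the high nibble of qs[j-16]
        const int q = j < 16 ? (qs[j] & 0x0F) : (qs[j - 16] >> 4);
        sumi += q * u[j];
    }
    // d_K*(q-8) dotted with d_Q*u; s_Q is the sum of the already-scaled Q values
    return d_K * (d_Q * sumi - 8.0f * s_Q);
}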
|
||||
template<typename T, int D>
|
||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
#if __CUDA_ARCH__ > MIN_CC_DP4A
|
||||
|
||||
const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
|
||||
GGML_UNUSED(Q_v);
|
||||
|
||||
T sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const int ib = k_KQ / QI8_1;
|
||||
const int iqs4 = k_KQ % QI4_1;
|
||||
const int shift = k_KQ & (QI8_1/2);
|
||||
|
||||
const int v = (get_int_from_uint8_aligned(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
|
||||
const int u = Q_q8[k_KQ_0/WARP_SIZE];
|
||||
|
||||
const int sumi = __dp4a(v, u, 0);
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||
|
||||
const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
|
||||
const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
|
||||
sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
|
||||
} else
|
||||
#endif // FP16_AVAILABLE
|
||||
{
|
||||
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
||||
|
||||
const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
|
||||
const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
|
||||
|
||||
sum += (T) (sumid4d8 + m4s8scaled);
|
||||
}
|
||||
}
|
||||
|
||||
return sum;
|
||||
#else
|
||||
GGML_UNUSED(K_c);
|
||||
GGML_UNUSED(Q_v);
|
||||
GGML_UNUSED(Q_q8);
|
||||
GGML_UNUSED(Q_ds_v);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // __CUDA_ARCH__ > MIN_CC_DP4A
|
||||
}
|
||||
|
||||
template<typename T, int D>
|
||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
#if __CUDA_ARCH__ > MIN_CC_DP4A
|
||||
|
||||
const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
|
||||
GGML_UNUSED(Q_v);
|
||||
|
||||
T sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const int ib = k_KQ / QI8_1;
|
||||
const int iqs4 = k_KQ % QI5_0;
|
||||
const int iqs8 = k_KQ % QI8_1;
|
||||
const int shift = k_KQ & (QI8_1/2);
|
||||
|
||||
int v = (get_int_from_uint8(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
|
||||
const int vh = get_int_from_uint8(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
|
||||
v |= (vh << 4) & 0x00000010; // 0 -> 4
|
||||
v |= (vh << 11) & 0x00001000; // 1 -> 12
|
||||
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
||||
v |= (vh << 25) & 0x10000000; // 3 -> 28
|
||||
|
||||
const int u = Q_q8[k_KQ_0/WARP_SIZE];
|
||||
|
||||
const int sumi = __dp4a(v, u, 0);
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||
|
||||
const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
|
||||
sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
|
||||
} else
|
||||
#endif // FP16_AVAILABLE
|
||||
{
|
||||
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
||||
|
||||
sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
|
||||
}
|
||||
}
|
||||
|
||||
return sum;
|
||||
#else
|
||||
GGML_UNUSED(K_c);
|
||||
GGML_UNUSED(Q_v);
|
||||
GGML_UNUSED(Q_q8);
|
||||
GGML_UNUSED(Q_ds_v);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // __CUDA_ARCH__ > MIN_CC_DP4A
|
||||
}
|
||||
|
||||
template<typename T, int D>
|
||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
#if __CUDA_ARCH__ > MIN_CC_DP4A
|
||||
|
||||
const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
|
||||
GGML_UNUSED(Q_v);
|
||||
|
||||
T sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const int ib = k_KQ / QI8_1;
|
||||
const int iqs4 = k_KQ % QI5_1;
|
||||
const int iqs8 = k_KQ % QI8_1;
|
||||
const int shift = k_KQ & (QI8_1/2);
|
||||
|
||||
int v = (get_int_from_uint8(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
|
||||
const int vh = get_int_from_uint8(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
|
||||
v |= (vh << 4) & 0x00000010; // 0 -> 4
|
||||
v |= (vh << 11) & 0x00001000; // 1 -> 12
|
||||
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
||||
v |= (vh << 25) & 0x10000000; // 3 -> 28
|
||||
|
||||
const int u = Q_q8[k_KQ_0/WARP_SIZE];
|
||||
|
||||
const int sumi = __dp4a(v, u, 0);
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||
|
||||
const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
|
||||
const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
|
||||
sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
|
||||
} else
|
||||
#endif // FP16_AVAILABLE
|
||||
{
|
||||
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
||||
|
||||
const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
|
||||
const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
|
||||
|
||||
sum += (T) (sumid5d8 + m5s8scaled);
|
||||
}
|
||||
}
|
||||
|
||||
return sum;
|
||||
#else
|
||||
GGML_UNUSED(K_c);
|
||||
GGML_UNUSED(Q_v);
|
||||
GGML_UNUSED(Q_q8);
|
||||
GGML_UNUSED(Q_ds_v);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // __CUDA_ARCH__ > MIN_CC_DP4A
|
||||
}
|
||||
|
||||
template <typename T, int D>
|
||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
||||
#if __CUDA_ARCH__ > MIN_CC_DP4A
|
||||
|
||||
const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
|
||||
GGML_UNUSED(Q_v);
|
||||
|
||||
T sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const int ib = k_KQ / QI8_0;
|
||||
const int iqs = k_KQ % QI8_0;
|
||||
|
||||
const int v = get_int_from_int8(K_q8_0[ib].qs, iqs);
|
||||
|
||||
T Q_d;
|
||||
if (std::is_same<T, half>::value) {
|
||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||
Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]);
|
||||
} else {
|
||||
const float2 * Q_ds = (const float2 *) Q_ds_v;
|
||||
Q_d = Q_ds[k_KQ_0/WARP_SIZE].x;
|
||||
}
|
||||
|
||||
sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d);
|
||||
}
|
||||
|
||||
return sum;
|
||||
#else
|
||||
GGML_UNUSED(K_c);
|
||||
GGML_UNUSED(Q_v);
|
||||
GGML_UNUSED(Q_q8);
|
||||
GGML_UNUSED(Q_ds_v);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // __CUDA_ARCH__ > MIN_CC_DP4A
|
||||
}
|
||||
|
||||
template <typename T, int D>
|
||||
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
|
||||
|
||||
const half2 * K_h2 = (const half2 *) K_c;
|
||||
GGML_UNUSED(Q_q8);
|
||||
GGML_UNUSED(Q_ds_v);
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
const half2 * Q_h2 = (const half2 *) Q_v;
|
||||
|
||||
half2 sum2 = make_half2(0.0f, 0.0f);
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const half2 K_ik = K_h2[k_KQ];
|
||||
sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
|
||||
}
|
||||
|
||||
return __low2half(sum2) + __high2half(sum2);
|
||||
}
|
||||
#endif // FP16_AVAILABLE
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) Q_v;
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const half2 K_ik = K_h2[k_KQ];
|
||||
sum += __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x;
|
||||
sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y;
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
template <typename Tds>
static __device__ __forceinline__ void quantize_q8_1_to_shared(
    const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) {

    float vals[sizeof(int)] = {0.0f};
#pragma unroll
    for (int l = 0; l < sizeof(int); ++l) {
        vals[l] = scale * x[4*threadIdx.x + l];
    }

    float amax = fabsf(vals[0]);
    float sum  = vals[0];
#pragma unroll
    for (int l = 1; l < sizeof(int); ++l) {
        amax = fmaxf(amax, fabsf(vals[l]));
        sum += vals[l];
    }
#pragma unroll
    for (int mask = QI8_1/2; mask > 0; mask >>= 1) {
        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32));
        sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, 32);
    }

    const float d = amax / 127;
    int q32 = 0;
    int8_t * q8 = (int8_t *) &q32;

    if (d != 0.0f) {
#pragma unroll
        for (int l = 0; l < sizeof(int); ++l) {
            q8[l] = roundf(vals[l] / d);
        }
    }

    yq32[threadIdx.x] = q32;
    if (threadIdx.x % QI8_1 == 0) {
        if (std::is_same<Tds, half2>::value) {
            ((half2 *) yds)[threadIdx.x/QI8_1] = make_half2(d, sum);
        } else {
            ((float2 *) yds)[threadIdx.x/QI8_1] = make_float2(d, sum);
        }
    }
}
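As a rough host-side reference (a minimal sketch, hypothetical helper, not part of the commit), the q8_1 quantization performed cooperatively above amounts to: pre-scale a group of 32 Q values, take d = max|v|/127, round each value to v/d, and keep (d, sum of the scaled values) so the K-side offset terms can be folded in later by the vec_dot functions.

#include <cmath>
#include <cstdint>

// Minimal scalar sketch of q8_1 quantization of one 32-value group (hypothetical helper).
static void ref_quantize_q8_1(const float * x, float scale, int8_t q[32], float & d, float & s) {
    float vals[32];
    float amax = 0.0f;
    s = 0.0f;
    for (int i = 0; i < 32; ++i) {
        vals[i] = scale * x[i];               // Q is pre-scaled exactly as in the kernel
        amax    = fmaxf(amax, fabsf(vals[i]));
        s      += vals[i];
    }
    d = amax / 127.0f;
    for (int i = 0; i < 32; ++i) {
        q[i] = d != 0.0f ? (int8_t) roundf(vals[i] / d) : 0;
    }
}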
|
||||
typedef half (*dequantize_1_f16_t)(const void *, const int64_t);
typedef float (*dequantize_1_f32_t)(const void *, const int64_t);

template <typename T>
static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) {
    const block_q4_0 * x = (const block_q4_0 *) vx;

    const int64_t ib    =  i          /  QK4_0;
    const int     iqs   =  i % (QK4_0/2);
    const int     shift = (i % QK4_0) / (QK4_0/2);

    const T   d  = x[ib].d;
    const int q0 = x[ib].qs[iqs];
    const int q  = ((q0 >> (4*shift)) & 0x0F) - 8;

    return d*((T) q);
}
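A concrete spot check of the q4_0 indexing handled above may help (hypothetical values, a sketch only): element i < 16 of a block comes from the low nibble of qs[i], element i >= 16 from the high nibble of qs[i-16], and both are shifted by the implicit -8 offset before scaling.

// Hypothetical spot check of the q4_0 layout used by dequantize_1_q4_0 above.
static void check_dequantize_q4_0_layout() {
    const unsigned char qs3 = 0xA7;            // byte 3 of a block, scale d = 0.25f assumed
    const float d = 0.25f;
    const float v3  = d * ((qs3 & 0x0F) - 8);  // element 3  -> 0.25 * (7  - 8) = -0.25
    const float v19 = d * ((qs3 >>   4) - 8);  // element 19 -> 0.25 * (10 - 8) =  0.5
    (void) v3; (void) v19;
}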
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) {
|
||||
const block_q4_1 * x = (const block_q4_1 *) vx;
|
||||
|
||||
const int64_t ib = i / QK4_1;
|
||||
const int iqs = i % (QK4_1/2);
|
||||
const int shift = (i % QK4_1) / (QK4_1/2);
|
||||
|
||||
const half2 dm = x[ib].dm;
|
||||
const int q0 = x[ib].qs[iqs];
|
||||
const int q = ((q0 >> (4*shift)) & 0x0F);
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
return __low2half(dm)*((half) q) + __high2half(dm);
|
||||
}
|
||||
#endif // FP16_AVAILABLE
|
||||
|
||||
return __low2float(dm)*((float) q) + __high2float(dm);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ vx, const int64_t i) {
|
||||
const block_q5_0 * x = (const block_q5_0 *) vx;
|
||||
|
||||
const int64_t ib = i / QK5_0;
|
||||
const int idq = i % QK5_0;
|
||||
const int iqs = i % (QK5_0/2);
|
||||
const int shift = (i % QK5_0) / (QK5_0/2);
|
||||
|
||||
const T d = x[ib].d;
|
||||
const int ql0 = x[ib].qs[iqs];
|
||||
const int qh0 = get_int_from_uint8(x[ib].qh, 0);
|
||||
const int ql = ((ql0 >> (4*shift)) & 0x0F);
|
||||
const int qh = ((qh0 >> idq) << 4) & 0x10;
|
||||
const int q = (ql | qh) - 16;
|
||||
|
||||
return d*((T) q);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ vx, const int64_t i) {
|
||||
const block_q5_1 * x = (const block_q5_1 *) vx;
|
||||
|
||||
const int64_t ib = i / QK5_1;
|
||||
const int idq = i % QK5_1;
|
||||
const int iqs = i % (QK5_1/2);
|
||||
const int shift = (i % QK5_1) / (QK5_1/2);
|
||||
|
||||
const half2 dm = x[ib].dm;
|
||||
const int ql0 = x[ib].qs[iqs];
|
||||
const int qh0 = get_int_from_uint8_aligned(x[ib].qh, 0);
|
||||
const int ql = ((ql0 >> (4*shift)) & 0x0F);
|
||||
const int qh = ((qh0 >> idq) << 4) & 0x10;
|
||||
const int q = (ql | qh);
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
return __low2half(dm)*((half) q) + __high2half(dm);
|
||||
}
|
||||
#endif // FP16_AVAILABLE
|
||||
|
||||
return __low2float(dm)*((float) q) + __high2float(dm);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) {
|
||||
const block_q8_0 * x = (const block_q8_0 *) vx;
|
||||
|
||||
const int64_t ib = i / QK8_0;
|
||||
const int iqs = i % QK8_0;
|
||||
|
||||
const T d = x[ib].d;
|
||||
const int q = x[ib].qs[iqs];
|
||||
|
||||
return d*((T) q);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) {
|
||||
const half * x = (const half *) vx;
|
||||
|
||||
return x[i];
|
||||
}
|
||||
|
||||
template<int D, int parallel_blocks> // D == head size
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
@ -83,6 +536,45 @@ static __global__ void flash_attn_combine_results(
|
||||
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
|
||||
}
|
||||
|
||||
// Aliases for FATTN_VEC_CASE macro:
static constexpr ggml_type ggml_type_q4_0 = GGML_TYPE_Q4_0;
static constexpr ggml_type ggml_type_q4_1 = GGML_TYPE_Q4_1;
static constexpr ggml_type ggml_type_q5_0 = GGML_TYPE_Q5_0;
static constexpr ggml_type ggml_type_q5_1 = GGML_TYPE_Q5_1;
static constexpr ggml_type ggml_type_q8_0 = GGML_TYPE_Q8_0;
static constexpr ggml_type ggml_type_f16  = GGML_TYPE_F16;

typedef half  f16;
typedef float f32;

#define FATTN_VEC_CASE(type_VKQ, D, type_suffix_K, type_suffix_V)                                             \
    if (Q->ne[0] == (D) && K->type == ggml_type_##type_suffix_K && V->type == ggml_type_##type_suffix_V) {    \
        constexpr int nwarps = (D)/WARP_SIZE;                                                                 \
        constexpr bool Q_q8_1 = ggml_type_##type_suffix_K != GGML_TYPE_F16;                                   \
        fattn_kernel_t fattn_kernel = flash_attn_vec_ext_##type_VKQ<                                          \
            (D), cols_per_block, parallel_blocks,                                                             \
            vec_dot_fattn_vec_KQ_##type_suffix_K<type_VKQ, (D)>, Q_q8_1, dequantize_1_##type_suffix_V<type_VKQ>>; \
        launch_fattn<(D), parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);                   \
        return;                                                                                               \
    }                                                                                                         \
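For readability, this is roughly what one such case expands to; a sketch of the preprocessor output for FATTN_VEC_CASE(f16, 128, q4_0, q8_0), with whitespace adjusted for clarity:

// Approximate expansion of FATTN_VEC_CASE(f16, 128, q4_0, q8_0):
if (Q->ne[0] == (128) && K->type == ggml_type_q4_0 && V->type == ggml_type_q8_0) {
    constexpr int nwarps = (128)/WARP_SIZE;
    constexpr bool Q_q8_1 = ggml_type_q4_0 != GGML_TYPE_F16; // true: Q gets quantized to q8_1
    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<
        (128), cols_per_block, parallel_blocks,
        vec_dot_fattn_vec_KQ_q4_0<f16, (128)>, Q_q8_1, dequantize_1_q8_0<f16>>;
    launch_fattn<(128), parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
    return;
}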
|
||||
static void on_no_fattn_vec_case(const int D) {
    if (D == 64) {
        fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
        fprintf(stderr, "By default only f16 KV cache is supported.\n");
        fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
        GGML_ASSERT(false);
    } else {
        fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
        fprintf(stderr, "Supported combinations:\n");
        fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n");
        fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n");
        fprintf(stderr, " - K == f16,  V == f16, 16.00 BPV\n");
        fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
        GGML_ASSERT(false);
    }
}
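The bits-per-value figures in these messages follow directly from the standard ggml block layouts (32 values per block with one fp16 scale for q4_0 and q8_0); a quick arithmetic check, written as a sketch:

// Quick check of the BPV numbers printed above (32-value blocks, one 16-bit scale assumed).
constexpr double bpv_q4_0 = (32*4 + 16) / 32.0; // 4 bit quants + 16 bit scale -> 4.50
constexpr double bpv_q8_0 = (32*8 + 16) / 32.0; // 8 bit quants + 16 bit scale -> 8.50
constexpr double bpv_f16  = 16.0;               // plain half precision        -> 16.00
static_assert(bpv_q4_0 == 4.5 && bpv_q8_0 == 8.5 && bpv_f16 == 16.0, "matches the messages above");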
|
||||
template <int D, int parallel_blocks>
|
||||
void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
@ -94,8 +586,6 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
|
||||
ggml_tensor * KQV = dst;
|
||||
|
||||
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(K->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(V->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
||||
@ -143,6 +633,7 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
V->nb[1], V->nb[2], V->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
@ -36,6 +36,9 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
const int nb11,
|
||||
const int nb12,
|
||||
const int nb13,
|
||||
const int nb21,
|
||||
const int nb22,
|
||||
const int nb23,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
|
@ -36,6 +36,9 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
const int nb11,
|
||||
const int nb12,
|
||||
const int nb13,
|
||||
const int nb21,
|
||||
const int nb22,
|
||||
const int nb23,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
|
@ -2,7 +2,7 @@
|
||||
#include "fattn-common.cuh"
|
||||
#include "fattn-vec-f16.cuh"
|
||||
|
||||
template<int D, int ncols, int parallel_blocks> // D == head size
|
||||
template<int D, int ncols, int parallel_blocks, vec_dot_KQ_f16_t vec_dot_KQ, bool Q_q8_1, dequantize_1_f16_t dequantize_1_v> // D == head size
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
@ -34,6 +34,9 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
const int nb11,
|
||||
const int nb12,
|
||||
const int nb13,
|
||||
const int nb21,
|
||||
const int nb22,
|
||||
const int nb23,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
@ -45,13 +48,11 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
||||
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
|
||||
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) mask + ne11*ic0;
|
||||
Q += nb02* blockIdx.y + nb01*ic0;
|
||||
K += nb12*(blockIdx.y / gqa_ratio);
|
||||
V += nb22*(blockIdx.y / gqa_ratio);
|
||||
|
||||
const int stride_KV = nb11 / sizeof(half);
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
const half * maskh = (const half *) mask + ne11*ic0;
|
||||
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
@ -62,10 +63,6 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
__builtin_assume(tid < D);
|
||||
|
||||
__shared__ half KQ[ncols*D];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
KQ[j*D + tid] = -HALF_MAX_HALF;
|
||||
}
|
||||
half2 * KQ2 = (half2 *) KQ;
|
||||
|
||||
half kqmax[ncols];
|
||||
@ -86,17 +83,76 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Convert Q to half2 and store in registers:
|
||||
half2 Q_h2[ncols][D/(2*WARP_SIZE)];
|
||||
// Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
|
||||
half2 Q_h2[ncols][D/(2*WARP_SIZE)];
|
||||
int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)];
|
||||
half2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
|
||||
if (Q_q8_1) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
// Reuse KQ as temporary storage for converting Q to q8_1:
|
||||
int * tmp_q_i32 = (int *) &KQ[j*D];
|
||||
half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int));
|
||||
|
||||
// Set memory to zero if out of bounds:
|
||||
if (ncols > 2 && ic0 + j >= ne01) {
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
tmp_q_i32[i] = 0;
|
||||
}
|
||||
if (threadIdx.x < D/QK8_1) {
|
||||
tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const float * Q_f = (const float *) (Q + j*nb01);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
|
||||
quantize_q8_1_to_shared<half2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
int * tmp_q_i32 = (int *) &KQ[j*D];
|
||||
half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int));
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
|
||||
Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
|
||||
Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
|
||||
Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
|
||||
}
|
||||
KQ[j*D + tid] = -HALF_MAX_HALF;
|
||||
}
|
||||
|
||||
half2 VKQ[ncols] = {{0.0f, 0.0f}};
|
||||
@ -123,22 +179,10 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
break;
|
||||
}
|
||||
|
||||
half2 sum2[ncols] = {{0.0f, 0.0f}};
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
sum2[j] += K_ik * Q_h2[j][k_KQ_0/WARP_SIZE];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
sum2[j] = warp_reduce_sum(sum2[j]);
|
||||
half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
|
||||
half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
|
||||
sum = warp_reduce_sum(sum);
|
||||
sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
|
||||
|
||||
if (ncols == 1) {
|
||||
@ -189,8 +233,8 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
}
|
||||
|
||||
half2 V_k;
|
||||
reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
|
||||
reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
|
||||
reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid);
|
||||
reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
|
||||
@ -248,19 +292,22 @@ void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tens
|
||||
case 64: {
|
||||
constexpr int D = 64;
|
||||
constexpr int nwarps = D/WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<
|
||||
D, cols_per_block, parallel_blocks, vec_dot_fattn_vec_KQ_f16<half, D>, false, dequantize_1_f16<half>>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
||||
} break;
|
||||
case 128: {
|
||||
constexpr int D = 128;
|
||||
constexpr int nwarps = D/WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<
|
||||
D, cols_per_block, parallel_blocks, vec_dot_fattn_vec_KQ_f16<half, D>, false, dequantize_1_f16<half>>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
||||
} break;
|
||||
case 256: {
|
||||
constexpr int D = 256;
|
||||
constexpr int nwarps = D/WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<
|
||||
D, cols_per_block, parallel_blocks, vec_dot_fattn_vec_KQ_f16<half, D>, false, dequantize_1_f16<half>>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
||||
} break;
|
||||
default:
|
||||
@ -272,57 +319,100 @@ void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tens
|
||||
template <int cols_per_block, int parallel_blocks>
|
||||
void launch_fattn_vec_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
switch (Q->ne[0]) {
|
||||
case 64: {
|
||||
constexpr int D = 64;
|
||||
constexpr int nwarps = D/WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
||||
} break;
|
||||
case 128: {
|
||||
constexpr int D = 128;
|
||||
constexpr int nwarps = D/WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
||||
} break;
|
||||
}
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
#ifdef GGML_CUDA_FA_ALL_QUANTS
|
||||
FATTN_VEC_CASE(f16, 64, f16, q4_0)
|
||||
FATTN_VEC_CASE(f16, 64, f16, q4_1)
|
||||
FATTN_VEC_CASE(f16, 64, f16, q5_0)
|
||||
FATTN_VEC_CASE(f16, 64, f16, q5_1)
|
||||
FATTN_VEC_CASE(f16, 64, f16, q8_0)
|
||||
FATTN_VEC_CASE(f16, 64, f16, f16)
|
||||
|
||||
FATTN_VEC_CASE(f16, 128, q4_0, q4_0)
|
||||
FATTN_VEC_CASE(f16, 128, q4_0, q4_1)
|
||||
FATTN_VEC_CASE(f16, 128, q4_0, q5_0)
|
||||
FATTN_VEC_CASE(f16, 128, q4_0, q5_1)
|
||||
FATTN_VEC_CASE(f16, 128, q4_0, q8_0)
|
||||
FATTN_VEC_CASE(f16, 128, q4_0, f16)
|
||||
|
||||
FATTN_VEC_CASE(f16, 128, q4_1, q4_0)
|
||||
FATTN_VEC_CASE(f16, 128, q4_1, q4_1)
|
||||
FATTN_VEC_CASE(f16, 128, q4_1, q5_0)
|
||||
FATTN_VEC_CASE(f16, 128, q4_1, q5_1)
|
||||
FATTN_VEC_CASE(f16, 128, q4_1, q8_0)
|
||||
FATTN_VEC_CASE(f16, 128, q4_1, f16)
|
||||
|
||||
FATTN_VEC_CASE(f16, 128, q5_0, q4_0)
|
||||
FATTN_VEC_CASE(f16, 128, q5_0, q4_1)
|
||||
FATTN_VEC_CASE(f16, 128, q5_0, q5_0)
|
||||
FATTN_VEC_CASE(f16, 128, q5_0, q5_1)
|
||||
FATTN_VEC_CASE(f16, 128, q5_0, q8_0)
|
||||
FATTN_VEC_CASE(f16, 128, q5_0, f16)
|
||||
|
||||
FATTN_VEC_CASE(f16, 128, q5_1, q4_0)
|
||||
FATTN_VEC_CASE(f16, 128, q5_1, q4_1)
|
||||
FATTN_VEC_CASE(f16, 128, q5_1, q5_0)
|
||||
FATTN_VEC_CASE(f16, 128, q5_1, q5_1)
|
||||
FATTN_VEC_CASE(f16, 128, q5_1, q8_0)
|
||||
FATTN_VEC_CASE(f16, 128, q5_1, f16)
|
||||
|
||||
FATTN_VEC_CASE(f16, 128, q8_0, q4_0)
|
||||
FATTN_VEC_CASE(f16, 128, q8_0, q4_1)
|
||||
FATTN_VEC_CASE(f16, 128, q8_0, q5_0)
|
||||
FATTN_VEC_CASE(f16, 128, q8_0, q5_1)
|
||||
FATTN_VEC_CASE(f16, 128, q8_0, q8_0)
|
||||
FATTN_VEC_CASE(f16, 128, q8_0, f16)
|
||||
|
||||
FATTN_VEC_CASE(f16, 128, f16, q4_0)
|
||||
FATTN_VEC_CASE(f16, 128, f16, q4_1)
|
||||
FATTN_VEC_CASE(f16, 128, f16, q5_0)
|
||||
FATTN_VEC_CASE(f16, 128, f16, q5_1)
|
||||
FATTN_VEC_CASE(f16, 128, f16, q8_0)
|
||||
FATTN_VEC_CASE(f16, 128, f16, f16)
|
||||
#else
|
||||
FATTN_VEC_CASE(f16, 128, q4_0, q4_0)
|
||||
FATTN_VEC_CASE(f16, 128, q8_0, q8_0)
|
||||
FATTN_VEC_CASE(f16, 128, f16, f16)
|
||||
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||
|
||||
on_no_fattn_vec_case(Q->ne[0]);
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
const int32_t precision = KQV->op_params[2];
|
||||
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
|
||||
|
||||
if (Q->ne[1] == 1) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
return;
|
||||
}
|
||||
// if (Q->ne[1] == 1) {
|
||||
// constexpr int cols_per_block = 1;
|
||||
// constexpr int parallel_blocks = 4;
|
||||
// launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
||||
// return;
|
||||
// }
|
||||
|
||||
if (Q->ne[1] == 2) {
|
||||
constexpr int cols_per_block = 2;
|
||||
constexpr int parallel_blocks = 4;
|
||||
launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
// if (Q->ne[1] == 2) {
|
||||
// constexpr int cols_per_block = 2;
|
||||
// constexpr int parallel_blocks = 4;
|
||||
// launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
||||
// return;
|
||||
// }
|
||||
|
||||
if (Q->ne[1] <= 4) {
|
||||
constexpr int cols_per_block = 4;
|
||||
constexpr int parallel_blocks = 4;
|
||||
launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
// if (Q->ne[1] <= 4) {
|
||||
// constexpr int cols_per_block = 4;
|
||||
// constexpr int parallel_blocks = 4;
|
||||
// launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
||||
// return;
|
||||
// }
|
||||
|
||||
if (Q->ne[1] <= 8) {
|
||||
constexpr int cols_per_block = 8;
|
||||
constexpr int parallel_blocks = 4;
|
||||
launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
// if (Q->ne[1] <= 8) {
|
||||
// constexpr int cols_per_block = 8;
|
||||
// constexpr int parallel_blocks = 4;
|
||||
// launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
||||
// return;
|
||||
// }
|
||||
|
||||
constexpr int cols_per_block = 8;
|
||||
constexpr int parallel_blocks = 1;
|
||||
|
@ -2,7 +2,7 @@
|
||||
#include "fattn-common.cuh"
|
||||
#include "fattn-vec-f32.cuh"
|
||||
|
||||
template<int D, int ncols, int parallel_blocks> // D == head size
|
||||
template<int D, int ncols, int parallel_blocks, vec_dot_KQ_f32_t vec_dot_KQ, bool Q_q8_1, dequantize_1_f32_t dequantize_1_v> // D == head size
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
@ -34,6 +34,9 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
const int nb11,
|
||||
const int nb12,
|
||||
const int nb13,
|
||||
const int nb21,
|
||||
const int nb22,
|
||||
const int nb23,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
@ -44,13 +47,10 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
||||
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
|
||||
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) mask + ne11*ic0;
|
||||
|
||||
const int stride_KV = nb11 / sizeof(half);
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
Q += nb02* blockIdx.y + nb01*ic0;
|
||||
K += nb12*(blockIdx.y / gqa_ratio);
|
||||
V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
|
||||
const half * maskh = (const half *) mask + ne11*ic0;
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
||||
|
||||
@ -83,17 +83,69 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Convert Q to half2 and store in registers:
|
||||
float2 Q_h2[ncols][D/(2*WARP_SIZE)];
|
||||
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
|
||||
float2 Q_f2[ncols][D/(2*WARP_SIZE)];
|
||||
int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)];
|
||||
float2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
|
||||
if (Q_q8_1) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
Q_h2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
|
||||
Q_h2[j][i0/WARP_SIZE].x *= scale;
|
||||
Q_h2[j][i0/WARP_SIZE].y *= scale;
|
||||
// Reuse KQ as temporary storage for converting Q to q8_1:
|
||||
int * tmp_q_i32 = (int *) &KQ[j*D];
|
||||
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
|
||||
|
||||
// Set memory to zero if out of bounds:
|
||||
if (ncols > 2 && ic0 + j >= ne01) {
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
tmp_q_i32[i] = 0;
|
||||
}
|
||||
if (threadIdx.x < D/QK8_1) {
|
||||
tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const float * Q_f = (const float *) (Q + j*nb01);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
|
||||
quantize_q8_1_to_shared<float2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
int * tmp_q_i32 = (int *) &KQ[j*D];
|
||||
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
|
||||
Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
|
||||
Q_f2[j][i0/WARP_SIZE].x *= scale;
|
||||
Q_f2[j][i0/WARP_SIZE].y *= scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -117,28 +169,16 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
break;
|
||||
}
|
||||
|
||||
float sum[ncols] = {0.0f};
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
sum[j] += __low2float(K_ik) * Q_h2[j][k_KQ_0/WARP_SIZE].x;
|
||||
sum[j] += __high2float(K_ik) * Q_h2[j][k_KQ_0/WARP_SIZE].y;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
sum[j] = warp_reduce_sum(sum[j]);
|
||||
sum[j] += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
|
||||
sum = warp_reduce_sum(sum);
|
||||
sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
|
||||
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum[j]);
|
||||
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
KQ[j*D + i_KQ] = sum[j];
|
||||
KQ[j*D + i_KQ] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -178,7 +218,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
break;
|
||||
}
|
||||
|
||||
const float V_ki = __half2float(V_h[(k_VKQ_0 + k)*stride_KV + tid]);
|
||||
const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
VKQ[j] += V_ki*KQ[j*D + k];
|
||||
@ -223,23 +263,65 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
template <int cols_per_block, int parallel_blocks>
|
||||
void launch_fattn_vec_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
switch (Q->ne[0]) {
|
||||
case 64: {
|
||||
constexpr int D = 64;
|
||||
constexpr int nwarps = D/WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
||||
} break;
|
||||
case 128: {
|
||||
constexpr int D = 128;
|
||||
constexpr int nwarps = D/WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
||||
} break;
|
||||
}
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
#ifdef GGML_CUDA_FA_ALL_QUANTS
|
||||
FATTN_VEC_CASE(f32, 64, f16, q4_0)
|
||||
FATTN_VEC_CASE(f32, 64, f16, q4_1)
|
||||
FATTN_VEC_CASE(f32, 64, f16, q5_0)
|
||||
FATTN_VEC_CASE(f32, 64, f16, q5_1)
|
||||
FATTN_VEC_CASE(f32, 64, f16, q8_0)
|
||||
FATTN_VEC_CASE(f32, 64, f16, f16)
|
||||
|
||||
FATTN_VEC_CASE(f32, 128, q4_0, q4_0)
|
||||
FATTN_VEC_CASE(f32, 128, q4_0, q4_1)
|
||||
FATTN_VEC_CASE(f32, 128, q4_0, q5_0)
|
||||
FATTN_VEC_CASE(f32, 128, q4_0, q5_1)
|
||||
FATTN_VEC_CASE(f32, 128, q4_0, q8_0)
|
||||
FATTN_VEC_CASE(f32, 128, q4_0, f16)
|
||||
|
||||
FATTN_VEC_CASE(f32, 128, q4_1, q4_0)
|
||||
FATTN_VEC_CASE(f32, 128, q4_1, q4_1)
|
||||
FATTN_VEC_CASE(f32, 128, q4_1, q5_0)
|
||||
FATTN_VEC_CASE(f32, 128, q4_1, q5_1)
|
||||
FATTN_VEC_CASE(f32, 128, q4_1, q8_0)
|
||||
FATTN_VEC_CASE(f32, 128, q4_1, f16)
|
||||
|
||||
FATTN_VEC_CASE(f32, 128, q5_0, q4_0)
|
||||
FATTN_VEC_CASE(f32, 128, q5_0, q4_1)
|
||||
FATTN_VEC_CASE(f32, 128, q5_0, q5_0)
|
||||
FATTN_VEC_CASE(f32, 128, q5_0, q5_1)
|
||||
FATTN_VEC_CASE(f32, 128, q5_0, q8_0)
|
||||
FATTN_VEC_CASE(f32, 128, q5_0, f16)
|
||||
|
||||
FATTN_VEC_CASE(f32, 128, q5_1, q4_0)
|
||||
FATTN_VEC_CASE(f32, 128, q5_1, q4_1)
|
||||
FATTN_VEC_CASE(f32, 128, q5_1, q5_0)
|
||||
FATTN_VEC_CASE(f32, 128, q5_1, q5_1)
|
||||
FATTN_VEC_CASE(f32, 128, q5_1, q8_0)
|
||||
FATTN_VEC_CASE(f32, 128, q5_1, f16)
|
||||
|
||||
FATTN_VEC_CASE(f32, 128, q8_0, q4_0)
|
||||
FATTN_VEC_CASE(f32, 128, q8_0, q4_1)
|
||||
FATTN_VEC_CASE(f32, 128, q8_0, q5_0)
|
||||
FATTN_VEC_CASE(f32, 128, q8_0, q5_1)
|
||||
FATTN_VEC_CASE(f32, 128, q8_0, q8_0)
|
||||
FATTN_VEC_CASE(f32, 128, q8_0, f16)
|
||||
|
||||
FATTN_VEC_CASE(f32, 128, f16, q4_0)
|
||||
FATTN_VEC_CASE(f32, 128, f16, q4_1)
|
||||
FATTN_VEC_CASE(f32, 128, f16, q5_0)
|
||||
FATTN_VEC_CASE(f32, 128, f16, q5_1)
|
||||
FATTN_VEC_CASE(f32, 128, f16, q8_0)
|
||||
FATTN_VEC_CASE(f32, 128, f16, f16)
|
||||
#else
|
||||
FATTN_VEC_CASE(f32, 128, q4_0, q4_0)
|
||||
FATTN_VEC_CASE(f32, 128, q8_0, q8_0)
|
||||
FATTN_VEC_CASE(f32, 128, f16, f16)
|
||||
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||
|
||||
on_no_fattn_vec_case(Q->ne[0]);
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
@ -45,6 +45,9 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int nb11,
|
||||
const int nb12,
|
||||
const int nb13,
|
||||
const int nb21,
|
||||
const int nb22,
|
||||
const int nb23,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
@ -457,14 +460,18 @@ void launch_fattn_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
ggml_cuda_set_device(ctx.device);
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const int32_t precision = KQV->op_params[2];
|
||||
|
||||
const bool quantized_KV = ggml_is_quantized(K->type) || ggml_is_quantized(V->type);
|
||||
|
||||
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
||||
if (cc >= CC_OFFSET_AMD) {
|
||||
if (precision == GGML_PREC_DEFAULT) {
|
||||
if (cc >= CC_OFFSET_AMD || quantized_KV) {
|
||||
if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
|
@ -386,7 +386,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
|
||||
u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
|
||||
}
|
||||
|
||||
return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>
|
||||
return vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
|
||||
(&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
|
||||
}
|
||||
|
||||
@ -547,7 +547,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
|
||||
const float * x_dmf = (const float *) x_dm;
|
||||
const float * y_df = (const float *) y_ds;
|
||||
|
||||
return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>
|
||||
return vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
|
||||
(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
|
||||
y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
|
||||
}
|
||||
|
@ -180,8 +180,8 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
|
||||
#define VDR_Q8_0_Q8_1_MMVQ 2
|
||||
#define VDR_Q8_0_Q8_1_MMQ 8
|
||||
|
||||
template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl(
|
||||
const int * v, const int * u, const float & d8_0, const float & d8_1) {
|
||||
template <typename T, int vdr> static __device__ __forceinline__ T vec_dot_q8_0_q8_1_impl(
|
||||
const int * v, const int * u, const T & d8_0, const T & d8_1) {
|
||||
|
||||
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
||||
int sumi = 0;
|
||||
@ -192,7 +192,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp
|
||||
sumi = __dp4a(v[i], u[i], sumi);
|
||||
}
|
||||
|
||||
return d8_0*d8_1 * sumi;
|
||||
return d8_0*d8_1 * ((T) sumi);
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
@ -656,7 +656,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
|
||||
u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
|
||||
}
|
||||
|
||||
return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds));
|
||||
return vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds));
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float vec_dot_q2_K_q8_1(