mirror of https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 15:18:26 +01:00

cuda : minor

commit 1846e92a90
parent ef68fac2a8

1 changed file: ggml-cuda.cu (30)
@@ -6399,10 +6399,10 @@ static __global__ void flash_attn_f32(
 }

 #if __CUDA_ARCH__ >= CC_VOLTA
 typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
 typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
 typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_bT;
 typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc;
 #endif

 // based on metal version
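For context, a minimal sketch of how these WMMA fragment types are typically used for a single 16x16x16 half-precision multiply-accumulate. This is not part of the commit; tile_a, tile_b, tile_c and the helper name are illustrative, and the call must be executed by all 32 lanes of a warp:

// Sketch only: one 16x16x16 tensor-core multiply-accumulate using the
// fragment typedefs above. Assumes 16x16 row-major tiles already in memory.
static __device__ void wmma_16x16_sketch(const half * tile_a, const half * tile_b, half * tile_c) {
#if __CUDA_ARCH__ >= CC_VOLTA
    half16x16_a   ma;
    half16x16_b   mb;
    half16x16_acc mc;

    nvcuda::wmma::fill_fragment(mc, 0.0);
    nvcuda::wmma::load_matrix_sync(ma, tile_a, 16); // leading dimension = 16
    nvcuda::wmma::load_matrix_sync(mb, tile_b, 16);
    nvcuda::wmma::mma_sync(mc, ma, mb, mc);         // mc = ma*mb + mc
    nvcuda::wmma::store_matrix_sync(tile_c, mc, 16, nvcuda::wmma::mem_row_major);
#endif
}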
@@ -6443,15 +6443,17 @@ static __global__ void flash_attn_ext_f16(
     const int iq2 = blockIdx.y;
     const int iq1 = blockIdx.x * Q;

-    const int D2 = D/2;
     const int D16 = D/16;
     const int Q16 = Q/16;
+    const int C16 = C/16;

     const int NW = WARP_SIZE;
     const int SH = (C + Q); // shared memory per simdgroup in (half)

     const int T  = D + num_warps*SH; // shared memory size per query in (half)
     const int T2 = T/2;              // shared memory size per query in (half2)
     const int C2 = C/2;
+    const int D2 = D/2;

     extern __shared__ half __flash_attn_f16_shmem[];
     // pq
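To make the sizing comments in this hunk concrete, here is a small worked example. The values of D, Q, C and num_warps below are illustrative only, not necessarily the parameters the kernel is launched with:

// Illustrative arithmetic only -- D, Q, C and num_warps are assumed values.
const int D         = 128; // head dimension
const int Q         = 16;  // queries processed per block (nqpb on the host side)
const int C         = 32;  // KV cache items processed per warp (ncpw on the host side)
const int num_warps = 4;

const int SH = C + Q;            // 48 halfs of scratch per simdgroup (warp)
const int T  = D + num_warps*SH; // 128 + 4*48 = 320 halfs per query row
const int T2 = T/2;              // 160 half2 per query row

// Total dynamic shared memory for the block; this matches the host-side formula
// nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2) further down in the diff:
const size_t shmem_bytes = (size_t) Q*T*sizeof(half); // 16*320*2 = 10240 bytes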
@@ -6571,7 +6573,7 @@ static __global__ void flash_attn_ext_f16(
         // Q*K^T
         {
 #pragma unroll
-            for (int cc = 0; cc < C/16; ++cc) {
+            for (int cc = 0; cc < C16; ++cc) {
                 half16x16_acc mqk[Q16];
                 for (int j = 0; j < Q16; ++j) {
                     nvcuda::wmma::fill_fragment(mqk[j], 0);
@@ -6684,7 +6686,7 @@ static __global__ void flash_attn_ext_f16(

         // O = O + (Q*K^T)*V
         {
-            for (int cc = 0; cc < C/16; ++cc) {
+            for (int cc = 0; cc < C16; ++cc) {
                 const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));

                 half16x16_b mv[D16];
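The pv pointer in this hunk follows the usual ggml addressing convention: nb21/nb22/nb23 are byte strides, so the offset is computed on a char pointer and only then cast to the element type. A generic sketch of that pattern (names are placeholders, not the kernel's variables):

// Sketch of ggml's byte-stride addressing: strides are in bytes, so pointer
// arithmetic is done on char* before casting to the element type.
static __device__ const half * row_ptr_sketch(const void * data, int i1, int i2, int i3,
                                              size_t nb1, size_t nb2, size_t nb3) {
    return (const half *) ((const char *) data + i1*nb1 + i2*nb2 + i3*nb3);
}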
@@ -6707,11 +6709,9 @@ static __global__ void flash_attn_ext_f16(
             }

         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
-        for (int j = 0; j < Q; ++j) {
-            if (lane_id == j) {
-                ss[j*T + 0] = S;
-                ss[j*T + 1] = M[j];
-            }
+        if (lane_id < Q) {
+            ss[lane_id*T + 0] = S;
+            ss[lane_id*T + 1] = M[lane_id];
         }
     }

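The rewritten store is equivalent to the removed loop: each query row j was written by exactly one lane (lane_id == j), so letting the first Q lanes write their own row directly drops the loop without changing the result. The implicit assumption is Q <= WARP_SIZE, one lane per row; a sketch with that assumption spelled out:

// Sketch only: the same store as in the hunk above, with the implicit
// one-lane-per-query-row assumption made explicit.
static_assert(Q <= WARP_SIZE, "need one warp lane per query row");
if (lane_id < Q) {
    ss[lane_id*T + 0] = S;          // running sum for row lane_id
    ss[lane_id*T + 1] = M[lane_id]; // running max for row lane_id
}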
@@ -10939,6 +10939,10 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *

     const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);

+    // increase shared memory limit to 96KB
+    //const size_t shmem_max = 96*1024;
+    //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
+
     switch (Q->ne[0]) {
         case 64:
             flash_attn_ext_f16<64, NQPB, NCPW>
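The commented-out lines are the standard CUDA way of opting a kernel into more than the default 48 KB of dynamic shared memory on Volta and newer GPUs. A hedged sketch of how that call would look if enabled (the size and the instantiation are taken from the comment above; the commit itself leaves these lines disabled):

// Sketch only: raise the per-kernel dynamic shared memory cap, then launch
// with a larger dynamic allocation. CUDA_CHECK is the file's error-check macro.
const size_t shmem_max = 96*1024;
CUDA_CHECK(cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>,
                                cudaFuncAttributeMaxDynamicSharedMemorySize,
                                (int) shmem_max));
// flash_attn_ext_f16<128, NQPB, NCPW><<<blocks, block_dim, shmem, main_stream>>>(...);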
@@ -11045,6 +11049,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
         default:
             break;
     }
+
+    CUDA_CHECK(cudaGetLastError());
 }

 static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
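Kernel launches are asynchronous, so the added CUDA_CHECK(cudaGetLastError()) surfaces launch failures (for example an invalid configuration or an over-sized dynamic shared memory request) right after the switch instead of at some later, unrelated CUDA call. For reference, a hypothetical sketch of what a macro like CUDA_CHECK typically expands to; the real macro in ggml-cuda.cu may differ in detail:

// Hypothetical sketch of a CUDA error-checking macro; not ggml's exact
// definition. Requires <cstdio> and <cstdlib>.
#define CUDA_CHECK_SKETCH(err)                                           \
    do {                                                                 \
        cudaError_t err_ = (err);                                        \
        if (err_ != cudaSuccess) {                                       \
            fprintf(stderr, "CUDA error %d at %s:%d: %s\n",              \
                    err_, __FILE__, __LINE__, cudaGetErrorString(err_)); \
            exit(1);                                                     \
        }                                                                \
    } while (0)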