cuda : minor

commit 1846e92a90 (parent ef68fac2a8)
mirror of https://github.com/ggerganov/llama.cpp.git

ggml-cuda.cu: 30 changed lines
@@ -6399,10 +6399,10 @@ static __global__ void flash_attn_f32(
 }

 #if __CUDA_ARCH__ >= CC_VOLTA
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_bT;
-typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_bT;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc;
 #endif

 // based on metal version
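Note on the typedefs above: half16x16_a, half16x16_b, half16x16_bT and half16x16_acc are the stock nvcuda::wmma fragment types for 16x16x16 half-precision tensor-core tiles, which is why the block is guarded by __CUDA_ARCH__ >= CC_VOLTA. A minimal, illustrative sketch of how such fragments are used (one warp computes one 16x16 tile; the kernel name and the flat row-major 16x16 input buffers are assumptions for the example, not code from this commit):

#include <mma.h>
using namespace nvcuda;

// Launch with exactly one warp (<<<1, 32>>>): computes C = A * B^T for a single 16x16 tile.
__global__ void wmma_tile_example(const half * A, const half * B, half * C) {
    wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a;   // cf. half16x16_a
    wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::col_major> bT;  // cf. half16x16_bT
    wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc;                  // cf. half16x16_acc

    wmma::fill_fragment(acc, 0.0f);       // zero the accumulator tile
    wmma::load_matrix_sync(a,  A, 16);    // leading dimension 16 (one dense tile)
    wmma::load_matrix_sync(bT, B, 16);    // row-major data read as col_major => implicit transpose
    wmma::mma_sync(acc, a, bT, acc);      // acc += a * bT
    wmma::store_matrix_sync(C, acc, 16, wmma::mem_row_major);
}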
@@ -6443,15 +6443,17 @@ static __global__ void flash_attn_ext_f16(
 const int iq2 = blockIdx.y;
 const int iq1 = blockIdx.x * Q;

+const int D2 = D/2;
 const int D16 = D/16;
 const int Q16 = Q/16;
+const int C16 = C/16;
+
 const int NW = WARP_SIZE;
 const int SH = (C + Q); // shared memory per simdgroup in (half)

 const int T = D + num_warps*SH; // shared memory size per query in (half)
 const int T2 = T/2; // shared memory size per query in (half2)
 const int C2 = C/2;
-const int D2 = D/2;

 extern __shared__ half __flash_attn_f16_shmem[];
 // pq
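For a sense of scale, the constants above combine with the host-side size computed further down (shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2)) as follows. The concrete numbers are assumptions chosen only to make the arithmetic concrete (head size D = 128, Q = nqpb = 16 queries per block, C = ncpw = 32 cache items per iteration, num_warps = nwarps = 4); the diff does not fix these values:

// Worked size check with the assumed values above:
SH = C + Q            = 32 + 16    =  48 halves   (per-simdgroup scratch)
T  = D + num_warps*SH = 128 + 4*48 = 320 halves   (per query row)
dynamic shared memory = Q * T * sizeof(half) = 16 * 320 * 2 = 10240 bytes

This is the same number the host-side formula produces, since sizeof(float)/2 == sizeof(half): 16*(128 + 4*(32 + 16))*2 = 10240.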
@@ -6571,7 +6573,7 @@ static __global__ void flash_attn_ext_f16(
 // Q*K^T
 {
 #pragma unroll
-for (int cc = 0; cc < C/16; ++cc) {
+for (int cc = 0; cc < C16; ++cc) {
 half16x16_acc mqk[Q16];
 for (int j = 0; j < Q16; ++j) {
 nvcuda::wmma::fill_fragment(mqk[j], 0);
@@ -6684,7 +6686,7 @@ static __global__ void flash_attn_ext_f16(

 // O = O + (Q*K^T)*V
 {
-for (int cc = 0; cc < C/16; ++cc) {
+for (int cc = 0; cc < C16; ++cc) {
 const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));

 half16x16_b mv[D16];
@@ -6707,11 +6709,9 @@ static __global__ void flash_attn_ext_f16(
 }

 // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
-for (int j = 0; j < Q; ++j) {
-if (lane_id == j) {
-ss[j*T + 0] = S;
-ss[j*T + 1] = M[j];
-}
+if (lane_id < Q) {
+ss[lane_id*T + 0] = S;
+ss[lane_id*T + 1] = M[lane_id];
 }
 }

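The hunk above replaces a loop in which only the lane with lane_id == j wrote row j by a single predicated write; for Q <= WARP_SIZE the two forms touch exactly the same elements of ss. A small illustration (not part of the commit) of why they are equivalent:

// Old form: Q iterations, and on iteration j only lane j passes the test,
// so across the warp each lane with lane_id < Q writes its own row exactly once.
for (int j = 0; j < Q; ++j) {
    if (lane_id == j) {
        ss[j*T + 0] = S;       // here j == lane_id
        ss[j*T + 1] = M[j];    // M[j] == M[lane_id]
    }
}

// New form: the same set of writes, without the loop, assuming Q <= WARP_SIZE.
if (lane_id < Q) {
    ss[lane_id*T + 0] = S;
    ss[lane_id*T + 1] = M[lane_id];
}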
@@ -10939,6 +10939,10 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *

 const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);

+// increase shared memory limit to 96KB
+//const size_t shmem_max = 96*1024;
+//cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
+
 switch (Q->ne[0]) {
 case 64:
 flash_attn_ext_f16<64, NQPB, NCPW>
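The lines added here are commented out, but they point at the standard mechanism for letting a kernel use more than the default 48 KB of dynamic shared memory on Volta and newer GPUs: cudaFuncSetAttribute with cudaFuncAttributeMaxDynamicSharedMemorySize, called once for the exact kernel instantiation before it is launched with a larger shmem argument. A sketch of what enabling them would look like (illustrative only; it reuses the CUDA_CHECK macro and the flash_attn_ext_f16<128, NQPB, NCPW> instantiation that appear elsewhere in this diff):

// Raise the per-launch dynamic shared-memory cap for this instantiation to 96 KB,
// so that a later <<<blocks, threads, shmem, stream>>> launch with shmem > 48*1024 succeeds.
const size_t shmem_max = 96*1024;
CUDA_CHECK(cudaFuncSetAttribute(
    flash_attn_ext_f16<128, NQPB, NCPW>,          // the template instantiation being launched
    cudaFuncAttributeMaxDynamicSharedMemorySize,  // attribute that caps dynamic shared memory
    shmem_max));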
@@ -11045,6 +11049,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
 default:
 break;
 }
+
+CUDA_CHECK(cudaGetLastError());
 }

 static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {