mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 21:37:19 +01:00
CUDA: optimize MMQ int8 tensor core performance (#8062)
* CUDA: optimize MMQ int8 tensor core performance * only a single get_mma_tile_x_k function * simplify code, make functions constexpr
This commit is contained in:
parent
52fc8705a0
commit
9a590c8226
@ -643,7 +643,7 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
|
|||||||
static constexpr int qi = QI3_S;
|
static constexpr int qi = QI3_S;
|
||||||
};
|
};
|
||||||
|
|
||||||
static int get_mmq_x_max_host(const int cc) {
|
static constexpr int get_mmq_x_max_host(int cc) {
|
||||||
#ifdef CUDA_USE_TENSOR_CORES
|
#ifdef CUDA_USE_TENSOR_CORES
|
||||||
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
|
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
|
||||||
#else
|
#else
|
||||||
@ -652,7 +652,7 @@ static int get_mmq_x_max_host(const int cc) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Round rows to this value for --split-mode row:
|
// Round rows to this value for --split-mode row:
|
||||||
static int get_mmq_y_host(const int cc) {
|
static constexpr int get_mmq_y_host(int cc) {
|
||||||
return cc >= CC_VOLTA ? 128 : 64;
|
return cc >= CC_VOLTA ? 128 : 64;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -20,6 +20,20 @@ struct mma_int_A_I16K4 {
|
|||||||
GGML_CUDA_ASSUME(ret < K);
|
GGML_CUDA_ASSUME(ret < K);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
|
||||||
|
#if defined(INT8_MMA_AVAILABLE)
|
||||||
|
const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
|
||||||
|
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
|
||||||
|
: "+r"(x[0]), "+r"(x[1])
|
||||||
|
: "l"(xs));
|
||||||
|
#else
|
||||||
|
#pragma unroll
|
||||||
|
for (int l = 0; l < ne; ++l) {
|
||||||
|
x[l] = xs0[get_i(l)*stride + get_k(l)];
|
||||||
|
}
|
||||||
|
#endif // defined(INT8_MMA_AVAILABLE)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct mma_int_A_I16K8 {
|
struct mma_int_A_I16K8 {
|
||||||
@ -42,6 +56,20 @@ struct mma_int_A_I16K8 {
|
|||||||
GGML_CUDA_ASSUME(ret < K);
|
GGML_CUDA_ASSUME(ret < K);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
|
||||||
|
#if defined(INT8_MMA_AVAILABLE)
|
||||||
|
const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
|
||||||
|
asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
|
||||||
|
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
|
||||||
|
: "l"(xs));
|
||||||
|
#else
|
||||||
|
#pragma unroll
|
||||||
|
for (int l = 0; l < ne; ++l) {
|
||||||
|
x[l] = xs0[get_i(l)*stride + get_k(l)];
|
||||||
|
}
|
||||||
|
#endif // defined(INT8_MMA_AVAILABLE)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct mma_int_B_J8K4 {
|
struct mma_int_B_J8K4 {
|
||||||
@ -64,6 +92,20 @@ struct mma_int_B_J8K4 {
|
|||||||
GGML_CUDA_ASSUME(ret < K);
|
GGML_CUDA_ASSUME(ret < K);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
|
||||||
|
#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
|
||||||
|
const int * xs = xs0 + (threadIdx.x%J)*stride;
|
||||||
|
asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
|
||||||
|
: "+r"(x[0])
|
||||||
|
: "l"(xs));
|
||||||
|
#else
|
||||||
|
#pragma unroll
|
||||||
|
for (int l = 0; l < ne; ++l) {
|
||||||
|
x[l] = xs0[get_j(l)*stride + get_k(l)];
|
||||||
|
}
|
||||||
|
#endif // defined(INT8_MMA_AVAILABLE)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct mma_int_B_J8K8 {
|
struct mma_int_B_J8K8 {
|
||||||
@ -86,6 +128,20 @@ struct mma_int_B_J8K8 {
|
|||||||
GGML_CUDA_ASSUME(ret < K);
|
GGML_CUDA_ASSUME(ret < K);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
|
||||||
|
#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
|
||||||
|
const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
|
||||||
|
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
|
||||||
|
: "+r"(x[0]), "+r"(x[1])
|
||||||
|
: "l"(xs));
|
||||||
|
#else
|
||||||
|
#pragma unroll
|
||||||
|
for (int l = 0; l < ne; ++l) {
|
||||||
|
x[l] = xs0[get_j(l)*stride + get_k(l)];
|
||||||
|
}
|
||||||
|
#endif // defined(INT8_MMA_AVAILABLE)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct mma_int_C_I16J8 {
|
struct mma_int_C_I16J8 {
|
||||||
|
1412
ggml-cuda/mmq.cuh
1412
ggml-cuda/mmq.cuh
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user