From e8457c90a07067cdc3a2dc6e2ba91119dd0b8e15 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 Dec 2023 16:29:52 +0200 Subject: [PATCH] cuda : wip --- ggml-cuda.cu | 67 ++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 62 insertions(+), 5 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 16d17f801..c2ce1769c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4582,12 +4582,43 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { } } -// TODO: generalize for all quants -template +static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q4_0 * dsti = (block_q4_0 *) cdsti; + + float amax = 0.0f; + float max = 0.0f; + + for (int j = 0; j < QK4_0; ++j) { + const float v = xi[j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float d = max / -8; + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + + for (int j = 0; j < QK4_0/2; ++j) { + const float x0 = xi[0 + j]*id; + const float x1 = xi[QK4_0/2 + j]*id; + + const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f)); + const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f)); + + dsti->qs[j] = xi0; + dsti->qs[j] |= xi1 << 4; + } +} + +template static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) { - const int i = (blockDim.x*blockIdx.x + threadIdx.x)*QK8_0; + const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; if (i >= ne) { return; @@ -4600,7 +4631,7 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, const int i12 = i / (ne10*ne11); const int i11 = (i - i12*ne10*ne11) / ne10; - const int i10 = (i - i12*ne10*ne11 - i11*ne10)/QK8_0; + const int i10 = (i - i12*ne10*ne11 - i11*ne10)/qk; const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12; cpy_blck(cx + x_offset, cdst + dst_offset); @@ -5791,7 +5822,29 @@ static void ggml_cpy_f32_q8_0_cuda( GGML_ASSERT(ne % QK8_0 == 0); const int num_blocks = ne / QK8_0; - cpy_f32_q<<>> + cpy_f32_q<<>> + (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); +} + +static void ggml_cpy_f32_q4_0_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { + + GGML_ASSERT(ne % QK4_0 == 0); + const int num_blocks = ne / QK4_0; + cpy_f32_q<<>> + (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); +} + +static void ggml_cpy_f32_q4_1_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { + + GGML_ASSERT(ne % QK4_1 == 0); + const int num_blocks = ne / QK4_1; + cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); } @@ -7836,6 +7889,10 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { + ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { + ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); } else {