mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-04 01:57:53 +01:00
cuda : wip
This commit is contained in:
parent
6b58ae9892
commit
e8457c90a0
67
ggml-cuda.cu
67
ggml-cuda.cu
@ -4582,12 +4582,43 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: generalize for all quants
|
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
|
||||||
template <cpy_kernel_t cpy_blck>
|
const float * xi = (const float *) cxi;
|
||||||
|
block_q4_0 * dsti = (block_q4_0 *) cdsti;
|
||||||
|
|
||||||
|
float amax = 0.0f;
|
||||||
|
float max = 0.0f;
|
||||||
|
|
||||||
|
for (int j = 0; j < QK4_0; ++j) {
|
||||||
|
const float v = xi[j];
|
||||||
|
if (amax < fabsf(v)) {
|
||||||
|
amax = fabsf(v);
|
||||||
|
max = v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const float d = max / -8;
|
||||||
|
const float id = d ? 1.0f/d : 0.0f;
|
||||||
|
|
||||||
|
y[i].d = d;
|
||||||
|
|
||||||
|
for (int j = 0; j < QK4_0/2; ++j) {
|
||||||
|
const float x0 = xi[0 + j]*id;
|
||||||
|
const float x1 = xi[QK4_0/2 + j]*id;
|
||||||
|
|
||||||
|
const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
|
||||||
|
const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
|
||||||
|
|
||||||
|
dsti->qs[j] = xi0;
|
||||||
|
dsti->qs[j] |= xi1 << 4;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <cpy_kernel_t cpy_blck, int qk>
|
||||||
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
|
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
|
||||||
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
|
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
|
||||||
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
|
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
|
||||||
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*QK8_0;
|
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
|
||||||
|
|
||||||
if (i >= ne) {
|
if (i >= ne) {
|
||||||
return;
|
return;
|
||||||
@ -4600,7 +4631,7 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
|
|||||||
|
|
||||||
const int i12 = i / (ne10*ne11);
|
const int i12 = i / (ne10*ne11);
|
||||||
const int i11 = (i - i12*ne10*ne11) / ne10;
|
const int i11 = (i - i12*ne10*ne11) / ne10;
|
||||||
const int i10 = (i - i12*ne10*ne11 - i11*ne10)/QK8_0;
|
const int i10 = (i - i12*ne10*ne11 - i11*ne10)/qk;
|
||||||
const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
|
const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
|
||||||
|
|
||||||
cpy_blck(cx + x_offset, cdst + dst_offset);
|
cpy_blck(cx + x_offset, cdst + dst_offset);
|
||||||
@ -5791,7 +5822,29 @@ static void ggml_cpy_f32_q8_0_cuda(
|
|||||||
|
|
||||||
GGML_ASSERT(ne % QK8_0 == 0);
|
GGML_ASSERT(ne % QK8_0 == 0);
|
||||||
const int num_blocks = ne / QK8_0;
|
const int num_blocks = ne / QK8_0;
|
||||||
cpy_f32_q<cpy_blck_f32_q8_0><<<num_blocks, 1, 0, stream>>>
|
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
|
||||||
|
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_cpy_f32_q4_0_cuda(
|
||||||
|
const char * cx, char * cdst, const int ne,
|
||||||
|
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
|
||||||
|
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
|
||||||
|
|
||||||
|
GGML_ASSERT(ne % QK4_0 == 0);
|
||||||
|
const int num_blocks = ne / QK4_0;
|
||||||
|
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
|
||||||
|
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_cpy_f32_q4_1_cuda(
|
||||||
|
const char * cx, char * cdst, const int ne,
|
||||||
|
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
|
||||||
|
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
|
||||||
|
|
||||||
|
GGML_ASSERT(ne % QK4_1 == 0);
|
||||||
|
const int num_blocks = ne / QK4_1;
|
||||||
|
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
|
||||||
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
|
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -7836,6 +7889,10 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
|
|||||||
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
|
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
||||||
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
|
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
|
||||||
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||||
|
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
|
||||||
|
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||||
|
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
|
||||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||||
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
|
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
|
||||||
} else {
|
} else {
|
||||||
|
Loading…
Reference in New Issue
Block a user