mirror of https://github.com/ggerganov/llama.cpp.git

cuda : add F32 -> Q8_0 copy kernel

ggml-ci

parent bcfebf241d
commit a1bf6c09f8

ggml-cuda.cu | 93
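For reference, the new cpy_blck_f32_q8_0 device function in the diff below quantizes each block of QK8_0 = 32 floats into a block_q8_0: it takes the block's absolute maximum, derives the scale d = amax/127, and stores q[j] = round(x[j]/d). The host-side C++ sketch here mirrors that scheme for illustration only; it is not part of the commit, the names ref_block_q8_0 and ref_quantize_block_q8_0 are invented for this sketch, and it stores the scale as a float where the real block_q8_0 uses half.

// Illustrative host-side reference of the Q8_0 block quantization performed by
// cpy_blck_f32_q8_0 in the diff below. Names and the float (not half) scale are
// assumptions made for this sketch only.
#include <math.h>
#include <stdint.h>
#include <stdio.h>

constexpr int QK8_0 = 32; // elements per Q8_0 block (matches ggml)

struct ref_block_q8_0 {
    float  d;          // per-block scale (ggml stores this as half)
    int8_t qs[QK8_0];  // quantized values
};

// Quantize one block of QK8_0 floats: d = amax/127, q[j] = round(x[j]/d).
static void ref_quantize_block_q8_0(const float * x, ref_block_q8_0 * dst) {
    float amax = 0.0f; // absolute max over the block
    for (int j = 0; j < QK8_0; j++) {
        amax = fmaxf(amax, fabsf(x[j]));
    }

    const float d  = amax / ((1 << 7) - 1); // map amax onto the int8 range [-127, 127]
    const float id = d ? 1.0f/d : 0.0f;     // avoid division by zero for all-zero blocks

    dst->d = d;
    for (int j = 0; j < QK8_0; j++) {
        dst->qs[j] = (int8_t) roundf(x[j]*id);
    }
}

int main() {
    float x[QK8_0];
    for (int j = 0; j < QK8_0; j++) {
        x[j] = 0.01f*j - 0.12f; // arbitrary test data
    }

    ref_block_q8_0 blk;
    ref_quantize_block_q8_0(x, &blk);

    // Dequantize (x ~ d*q) to inspect the round-trip error.
    for (int j = 0; j < QK8_0; j++) {
        printf("x = % .4f  ->  q = %4d  ->  d*q = % .4f\n", x[j], blk.qs[j], blk.d*blk.qs[j]);
    }
    return 0;
}

Compiling and running this prints each quantized value next to its dequantized approximation d*q, which is a quick way to see the rounding error the kernel introduces per block.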
ggml-cuda.cu

@@ -4559,6 +4559,53 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
    cpy_1(cx + x_offset, cdst + dst_offset);
}

+static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q8_0 * dsti = (block_q8_0 *) cdsti;
+
+    float amax = 0.0f; // absolute max
+
+    for (int j = 0; j < QK8_0; j++) {
+        const float v = xi[j];
+        amax = fmaxf(amax, fabsf(v));
+    }
+
+    const float d  = amax / ((1 << 7) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->d = d;
+
+    for (int j = 0; j < QK8_0; ++j) {
+        const float x0 = xi[j]*id;
+
+        dsti->qs[j] = roundf(x0);
+    }
+}
+
+// TODO: generalize for all quants
+template <cpy_kernel_t cpy_blck>
+static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
+                                 const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+                                 const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
+    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*QK8_0;
+
+    if (i >= ne) {
+        return;
+    }
+
+    const int i02 = i / (ne00*ne01);
+    const int i01 = (i - i02*ne01*ne00) / ne00;
+    const int i00 = (i - i02*ne01*ne00 - i01*ne00);
+    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
+
+    const int i12 = i / (ne10*ne11);
+    const int i11 = (i - i12*ne10*ne11) / ne10;
+    const int i10 = (i - i12*ne10*ne11 - i11*ne10)/QK8_0;
+    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
+
+    cpy_blck(cx + x_offset, cdst + dst_offset);
+}
+
static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
    const float y = (i0 / 2 - low) / max(0.001f, high - low);
    return 1.0f - min(1.0f, max(0.0f, y));
@@ -5737,6 +5784,17 @@ static void ggml_cpy_f32_f16_cuda(
        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
}

+static void ggml_cpy_f32_q8_0_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK8_0 == 0);
+    const int num_blocks = ne / QK8_0;
+    cpy_f32_q<cpy_blck_f32_q8_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
static void ggml_cpy_f16_f16_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@@ -6093,20 +6151,21 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
    const enum ggml_type type = src->type;
    const int64_t ts = ggml_type_size(type);
    const int64_t bs = ggml_blck_size(type);
-    int64_t i1_diff = i1_high - i1_low;
+    const int64_t i1_diff = i1_high - i1_low;

    const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
-    if (nb0 == ts && nb1 == ts*ne0/bs) {
+    if (nb0 == ts && nb1 == ts*(ne0/bs)) {
        return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream);
    }
    if (nb0 == ts) {
-        return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream);
+        return cudaMemcpy2DAsync(dst_ptr, ts*(ne0/bs), x, nb1, ts*(ne0/bs), i1_diff, kind, stream);
    }
+    GGML_ASSERT(bs == 1 && "TODO: implement bs != 1");
    for (int64_t i1 = 0; i1 < i1_diff; i1++) {
        const void * rx = (const void *) ((const char *) x + i1*nb1);
-        void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
+        void * rd = (void *) (dst_ptr + i1*ts*ne0);
        // pretend the row is a matrix with cols=1
-        cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream);
+        cudaError_t r = cudaMemcpy2DAsync(rd, ts, rx, nb0, ts, ne0, kind, stream);
        if (r != cudaSuccess) { return r; }
    }
    return cudaSuccess;
@@ -6533,7 +6592,8 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
    size_t ash;
    dfloat * src1_dfloat = nullptr; // dfloat == half

-    bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
+    bool src1_convert_f16 =
+        src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
        src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
        src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;

@@ -7103,10 +7163,9 @@ static void ggml_cuda_op_mul_mat(

    const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
    const bool src0_is_contiguous = ggml_is_contiguous(src0);
-
    const bool src1_is_contiguous = ggml_is_contiguous(src1);
-    const int64_t src1_padded_col_size = ne10 % MATRIX_ROW_PADDING == 0 ?
-        ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
+
+    const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);

    const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
    GGML_ASSERT(!(split && ne02 > 1));
@@ -7231,7 +7290,7 @@ static void ggml_cuda_op_mul_mat(
            const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;

            // for split tensors the data begins at i0 == i0_offset_low
-            char * src0_dd_i = src0_dd[id] + (i0/i02_divisor) * ne01*ne00*src0_ts/src0_bs;
+            char * src0_dd_i = src0_dd[id] + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
            float * src1_ddf_i = src1_ddf[id] + (i0*ne11 + src1_col_0) * ne10;
            char * src1_ddq_i = src1_ddq[id] + src1_ddq_i_offset;
            float * dst_dd_i = dst_dd[id] + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff);
@@ -7694,7 +7753,7 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
    } else if (src0->type == GGML_TYPE_F32) {
        ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
    } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
-        if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
+        if (ggml_nrows(src1) == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
#ifdef GGML_CUDA_FORCE_DMMV
            const bool use_mul_mat_vec_q = false;
#else
@@ -7770,14 +7829,13 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
    char * src1_ddc = (char *) src1_extra->data_device[g_main_device];

    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
-                              ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
-                              ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
-                              ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
    } else {
        fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
                ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -7788,6 +7846,7 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
}

static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    // TODO: why do we pass dst as src1 here?
    ggml_cuda_cpy(src0, dst, nullptr);
    (void) src1;
}
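A note on the index math in the new cpy_f32_q kernel above: each thread handles one block of QK8_0 source elements, so the flat element index i is decomposed into (i00, i01, i02) and combined with the source byte strides nb00/nb01/nb02, while on the destination side the innermost index is divided by QK8_0 because one Q8_0 block covers 32 logical columns and nb10 is the byte size of a whole block (as ggml defines nb[0] for quantized types). The host-side C++ sketch below replays that decomposition for a small contiguous example; it is illustrative only, the helper names and the example shapes/strides are assumptions, and the 34-byte block size assumes block_q8_0 is packed as one fp16 scale plus 32 int8 values.

// Illustrative host-side check of the index math used by cpy_f32_q above.
// The names (src_offset_bytes, dst_offset_bytes) and the example strides are
// assumptions for this sketch; only the decomposition mirrors the kernel.
#include <assert.h>
#include <stdio.h>

constexpr int QK8_0         = 32;         // elements per Q8_0 block (matches ggml)
constexpr int Q8_0_BLOCK_SZ = 2 + QK8_0;  // assumed sizeof(block_q8_0): fp16 scale + 32 int8

// Byte offset of element i in an F32 tensor of shape ne00 x ne01 x ne02.
static int src_offset_bytes(int i, int ne00, int ne01, int nb00, int nb01, int nb02) {
    const int i02 = i / (ne00*ne01);
    const int i01 = (i - i02*ne01*ne00) / ne00;
    const int i00 = (i - i02*ne01*ne00 - i01*ne00);
    return i00*nb00 + i01*nb01 + i02*nb02;
}

// Byte offset of the destination block holding element i; the column index is
// divided by QK8_0 because each destination "column step" covers a whole block.
static int dst_offset_bytes(int i, int ne10, int ne11, int nb10, int nb11, int nb12) {
    const int i12 = i / (ne10*ne11);
    const int i11 = (i - i12*ne10*ne11) / ne10;
    const int i10 = (i - i12*ne10*ne11 - i11*ne10)/QK8_0;
    return i10*nb10 + i11*nb11 + i12*nb12;
}

int main() {
    // Example: contiguous 64 x 4 x 2 F32 source copied into a Q8_0 destination
    // of the same logical shape.
    const int ne00 = 64, ne01 = 4, ne02 = 2;
    const int nb00 = 4, nb01 = nb00*ne00, nb02 = nb01*ne01;                      // F32 strides (bytes)
    const int nb10 = Q8_0_BLOCK_SZ, nb11 = (ne00/QK8_0)*nb10, nb12 = nb11*ne01;  // Q8_0 strides (bytes)

    for (int i = 0; i < ne00*ne01*ne02; i += QK8_0) { // one thread per block in the kernel
        const int so  = src_offset_bytes(i, ne00, ne01, nb00, nb01, nb02);
        const int dof = dst_offset_bytes(i, ne00, ne01, nb10, nb11, nb12);
        assert(so  % (QK8_0*4)     == 0); // each source block starts at a 32-float boundary
        assert(dof % Q8_0_BLOCK_SZ == 0); // each destination offset lands on a block boundary
        printf("i = %3d  src offset = %5d B  dst offset = %4d B\n", i, so, dof);
    }
    return 0;
}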
llama.cpp

@@ -1246,7 +1246,6 @@ struct llama_cparams {

    bool mul_mat_q;
    bool offload_kqv;
-
};

struct llama_layer {

@@ -1562,7 +1561,7 @@ static bool llama_kv_cache_init(
    cache.k_l.reserve(n_layer);
    cache.v_l.reserve(n_layer);

-    const int i_gpu_start = n_layer - n_gpu_layers; GGML_UNUSED(i_gpu_start);
+    const int i_gpu_start = (int) n_layer - n_gpu_layers; GGML_UNUSED(i_gpu_start);

    GGML_UNUSED(offload);

@@ -5696,6 +5695,7 @@ static int llama_decode_internal(
        // after enough generations, the benefit from this heuristic disappears
        // if we start defragmenting the cache, the benefit from this will be more important
        kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
+        //kv_self.n = llama_kv_cache_cell_max(kv_self);

        //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);