diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 9019a849f..3ad0d305d 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4559,6 +4559,53 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, cpy_1(cx + x_offset, cdst + dst_offset); } +static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + block_q8_0 * dsti = (block_q8_0 *) cdsti; + + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = xi[j]; + amax = fmaxf(amax, fabsf(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dsti->d = d; + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = xi[j]*id; + + dsti->qs[j] = roundf(x0); + } +} + +// TODO: generalize for all quants +template +static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) { + const int i = (blockDim.x*blockIdx.x + threadIdx.x)*QK8_0; + + if (i >= ne) { + return; + } + + const int i02 = i / (ne00*ne01); + const int i01 = (i - i02*ne01*ne00) / ne00; + const int i00 = (i - i02*ne01*ne00 - i01*ne00); + const int x_offset = i00*nb00 + i01*nb01 + i02*nb02; + + const int i12 = i / (ne10*ne11); + const int i11 = (i - i12*ne10*ne11) / ne10; + const int i10 = (i - i12*ne10*ne11 - i11*ne10)/QK8_0; + const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12; + + cpy_blck(cx + x_offset, cdst + dst_offset); +} + static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) { const float y = (i0 / 2 - low) / max(0.001f, high - low); return 1.0f - min(1.0f, max(0.0f, y)); @@ -5737,6 +5784,17 @@ static void ggml_cpy_f32_f16_cuda( (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); } +static void ggml_cpy_f32_q8_0_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { + + GGML_ASSERT(ne % QK8_0 == 0); + const int num_blocks = ne / QK8_0; + cpy_f32_q<<>> + (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); +} + static void ggml_cpy_f16_f16_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, @@ -6093,20 +6151,21 @@ static cudaError_t ggml_cuda_cpy_tensor_2d( const enum ggml_type type = src->type; const int64_t ts = ggml_type_size(type); const int64_t bs = ggml_blck_size(type); - int64_t i1_diff = i1_high - i1_low; + const int64_t i1_diff = i1_high - i1_low; const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3; - if (nb0 == ts && nb1 == ts*ne0/bs) { + if (nb0 == ts && nb1 == ts*(ne0/bs)) { return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream); } if (nb0 == ts) { - return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream); + return cudaMemcpy2DAsync(dst_ptr, ts*(ne0/bs), x, nb1, ts*(ne0/bs), i1_diff, kind, stream); } + GGML_ASSERT(bs == 1 && "TODO: implement bs != 1"); for (int64_t i1 = 0; i1 < i1_diff; i1++) { const void * rx = (const void *) ((const char *) x + i1*nb1); - void * rd = (void *) (dst_ptr + i1*ts*ne0/bs); + void * rd = (void *) (dst_ptr + i1*ts*ne0); // pretend the row is a matrix with cols=1 - cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream); + cudaError_t r = cudaMemcpy2DAsync(rd, ts, rx, nb0, ts, ne0, kind, stream); if (r != cudaSuccess) { return r; } } return cudaSuccess; @@ -6533,7 +6592,8 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec( size_t ash; dfloat * src1_dfloat = nullptr; // dfloat == half - bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || + bool src1_convert_f16 = + src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; @@ -7103,10 +7163,9 @@ static void ggml_cuda_op_mul_mat( const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT; const bool src0_is_contiguous = ggml_is_contiguous(src0); - const bool src1_is_contiguous = ggml_is_contiguous(src1); - const int64_t src1_padded_col_size = ne10 % MATRIX_ROW_PADDING == 0 ? - ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING; + + const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING); const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT; GGML_ASSERT(!(split && ne02 > 1)); @@ -7231,7 +7290,7 @@ static void ggml_cuda_op_mul_mat( const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs; // for split tensors the data begins at i0 == i0_offset_low - char * src0_dd_i = src0_dd[id] + (i0/i02_divisor) * ne01*ne00*src0_ts/src0_bs; + char * src0_dd_i = src0_dd[id] + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs; float * src1_ddf_i = src1_ddf[id] + (i0*ne11 + src1_col_0) * ne10; char * src1_ddq_i = src1_ddq[id] + src1_ddq_i_offset; float * dst_dd_i = dst_dd[id] + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff); @@ -7694,7 +7753,7 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 } else if (src0->type == GGML_TYPE_F32) { ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { - if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) { + if (ggml_nrows(src1) == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) { #ifdef GGML_CUDA_FORCE_DMMV const bool use_mul_mat_vec_q = false; #else @@ -7770,14 +7829,13 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg char * src1_ddc = (char *) src1_extra->data_device[g_main_device]; if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, - ne10, ne11, nb10, nb11, nb12, main_stream); + ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { - ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, - ne10, ne11, nb10, nb11, nb12, main_stream); + ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { + ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, - ne10, ne11, nb10, nb11, nb12, main_stream); + ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); } else { fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); @@ -7788,6 +7846,7 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg } static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + // TODO: why do we pass dst as src1 here? ggml_cuda_cpy(src0, dst, nullptr); (void) src1; } diff --git a/llama.cpp b/llama.cpp index ca23408b4..a70e40dba 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1246,7 +1246,6 @@ struct llama_cparams { bool mul_mat_q; bool offload_kqv; - }; struct llama_layer { @@ -1562,7 +1561,7 @@ static bool llama_kv_cache_init( cache.k_l.reserve(n_layer); cache.v_l.reserve(n_layer); - const int i_gpu_start = n_layer - n_gpu_layers; GGML_UNUSED(i_gpu_start); + const int i_gpu_start = (int) n_layer - n_gpu_layers; GGML_UNUSED(i_gpu_start); GGML_UNUSED(offload); @@ -5696,6 +5695,7 @@ static int llama_decode_internal( // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); + //kv_self.n = llama_kv_cache_cell_max(kv_self); //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);