llama : support quantum K cache (wip)

Georgi Gerganov 2023-12-03 21:31:05 +02:00
parent 66aaac9867
commit d04ee928a2
2 changed files with 20 additions and 12 deletions

ggml-metal.m

@@ -1114,7 +1114,7 @@ void ggml_metal_graph_compute(
  !ggml_is_transposed(src1) &&
  src1t == GGML_TYPE_F32 &&
  ne00 % 32 == 0 && ne00 >= 64 &&
- ne11 > ne11_mm_min) {
+ (ne11 > ne11_mm_min || ne12 > 1)) {
  //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
  switch (src0->type) {
      case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
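
Note (not part of the commit): this condition decides whether the Metal backend dispatches the matrix-matrix kernel instead of the matrix-vector one. Here ne00 is the shared dimension, ne11 the number of src1 rows, and ne12 the batch dimension; the change also routes batched multiplications (ne12 > 1) to the matrix-matrix path even when ne11 is small. A minimal sketch of the predicate, with a hypothetical helper name and the surrounding type checks reduced to a comment:

    #include <cstdint>
    #include <cstdio>

    // Sketch of the dispatch predicate above (simplified; the real check in
    // ggml_metal_graph_compute() also requires src1 to be F32 and not transposed).
    // ne00: shared dimension, ne11: number of src1 rows, ne12: batch dimension.
    static bool prefer_mul_mm(int64_t ne00, int64_t ne11, int64_t ne12, int64_t ne11_mm_min) {
        return ne00 % 32 == 0 && ne00 >= 64 &&
               (ne11 > ne11_mm_min || ne12 > 1); // new: batched cases also take the mm path
    }

    int main() {
        // A single src1 row but many batches: previously matrix-vector, now matrix-matrix.
        printf("%d\n", prefer_mul_mm(/*ne00=*/128, /*ne11=*/1, /*ne12=*/32, /*ne11_mm_min=*/2));
        return 0;
    }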

llama.cpp

@@ -1522,7 +1522,8 @@ struct llama_context {
  static bool llama_kv_cache_init(
  const struct llama_hparams & hparams,
  struct llama_kv_cache & cache,
- ggml_type wtype,
+ ggml_type ktype,
+ ggml_type vtype,
  uint32_t n_ctx,
  int n_gpu_layers,
  bool offload) {
@@ -1541,7 +1542,7 @@ static bool llama_kv_cache_init(
  cache.cells.clear();
  cache.cells.resize(n_ctx);
- cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*n_layer*ggml_tensor_overhead());
+ cache.buf.resize(n_elements*(ggml_type_sizef(ktype) + ggml_type_sizef(vtype)) + 2u*n_layer*ggml_tensor_overhead());
  memset(cache.buf.data, 0, cache.buf.size);
  struct ggml_init_params params;
@@ -1566,8 +1567,8 @@ static bool llama_kv_cache_init(
  GGML_UNUSED(offload);
  for (int i = 0; i < (int) n_layer; i++) {
- ggml_tensor * k = ggml_new_tensor_1d(cache.ctx, wtype, n_embd*n_ctx);
- ggml_tensor * v = ggml_new_tensor_1d(cache.ctx, wtype, n_embd*n_ctx);
+ ggml_tensor * k = ggml_new_tensor_1d(cache.ctx, ktype, n_embd*n_ctx);
+ ggml_tensor * v = ggml_new_tensor_1d(cache.ctx, vtype, n_embd*n_ctx);
  ggml_format_name(k, "cache_k_l%d", i);
  ggml_format_name(v, "cache_v_l%d", i);
  cache.k_l.push_back(k);
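
Note (not part of the commit): ggml_type_sizef() reports the per-element size in bytes (ggml_type_size() divided by ggml_blck_size()), which is what lets the buffer account for a quantized K cache: a Q4_0 block packs 32 elements into 18 bytes (a 2-byte scale plus 16 bytes of 4-bit quants), i.e. 0.5625 bytes per element versus 2 bytes for F16. A standalone sketch of the buffer-size arithmetic, with hypothetical example dimensions:

    #include <cstdio>

    int main() {
        // Hypothetical dimensions for illustration only (not taken from the commit).
        const size_t n_embd  = 4096;
        const size_t n_layer = 32;
        const size_t n_ctx   = 4096;
        const size_t n_elements = n_embd * n_ctx * n_layer;

        const double q4_0_sizef = 18.0 / 32.0; // bytes per element for GGML_TYPE_Q4_0
        const double f16_sizef  = 2.0;         // bytes per element for GGML_TYPE_F16

        // Old: one common type for K and V.  New: K and V sized independently.
        const double old_bytes = 2.0 * n_elements * f16_sizef;
        const double new_bytes = n_elements * (q4_0_sizef + f16_sizef);

        printf("F16 K+V cache: %.0f MiB, Q4_0 K + F16 V cache: %.0f MiB\n",
               old_bytes / (1024.0 * 1024.0), new_bytes / (1024.0 * 1024.0));
        return 0;
    }
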
@@ -3558,8 +3559,8 @@ static void llm_build_k_shift(
  ggml_rope_custom_inplace(ctx,
  ggml_view_3d(ctx, kv.k_l[il],
  n_embd_head, n_head_kv, n_ctx,
- ggml_element_size(kv.k_l[il])*n_embd_head,
- ggml_element_size(kv.k_l[il])*n_embd_gqa,
+ ggml_type_sizef(kv.k_l[il]->type)*n_embd_head,
+ ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa,
  0),
  K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
  ext_factor, attn_factor, beta_fast, beta_slow);
@@ -3588,7 +3589,7 @@ static void llm_build_kv_store(
  cb(v_cur_t, "v_cur_t", il);
  struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_gqa,
- (ggml_element_size(kv.k_l[il])*n_embd_gqa)*kv_head);
+ (ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa)*kv_head);
  cb(k_cache_view, "k_cache_view", il);
  struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_gqa,
@@ -3747,8 +3748,8 @@ static struct ggml_tensor * llm_build_kqv(
  struct ggml_tensor * k =
  ggml_view_3d(ctx, kv.k_l[il],
  n_embd_head, n_kv, n_head_kv,
- ggml_element_size(kv.k_l[il])*n_embd_gqa,
- ggml_element_size(kv.k_l[il])*n_embd_head,
+ ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa,
+ ggml_type_sizef(kv.k_l[il]->type)*n_embd_head,
  0);
  cb(k, "k", il);
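
Note (not part of the commit): ggml_element_size() returns ggml_type_size() of the tensor's type, i.e. the byte size of one block, which equals the per-element size only for non-quantized types. For a Q4_0 K cache it would return 18 (one block) and the view strides and offsets above would be overcounted by the 32-element block factor, hence the switch to ggml_type_sizef(). A standalone sketch of the stride arithmetic under that assumption:

    #include <cstdio>

    int main() {
        const int n_embd_head = 128; // hypothetical head size, a multiple of the Q4_0 block (32)

        // F16: blocks of one 2-byte element, so both expressions agree.
        const double f16_row = 2.0 * n_embd_head;
        // Q4_0: 18-byte blocks of 32 elements.
        const double q4_row_wrong = 18.0 * n_embd_head;           // ggml_element_size() would give 2304 bytes
        const double q4_row_right = (18.0 / 32.0) * n_embd_head;  // ggml_type_sizef() gives 72 bytes

        printf("f16 row: %.0f B, q4_0 row: %.0f B (not %.0f B)\n",
               f16_row, q4_row_right, q4_row_wrong);
        return 0;
    }
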
@@ -8734,11 +8735,18 @@ struct llama_context * llama_new_context_with_model(
  ctx->rng = std::mt19937(params.seed);
  ctx->logits_all = params.logits_all;
- ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
+ //const ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
+ // TODO: move as params
+ const ggml_type k_type = GGML_TYPE_Q4_0;
+ const ggml_type v_type = GGML_TYPE_F16;
+ GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(k_type) == 0);
+ GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(v_type) == 0);
  // reserve memory for context buffers
  if (!hparams.vocab_only) {
- if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, cparams.n_ctx, model->n_gpu_layers, cparams.offload_kqv)) {
+ if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, k_type, v_type, cparams.n_ctx, model->n_gpu_layers, cparams.offload_kqv)) {
  LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
  llama_free(ctx);
  return nullptr;
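
Note (not part of the commit): the two GGML_ASSERTs above guard the choice of cache types: each per-head row of the K cache (n_embd_head elements) must span whole quantization blocks, otherwise the views and byte offsets built in llm_build_k_shift(), llm_build_kv_store() and llm_build_kqv() would not land on block boundaries. A minimal sketch of that check with a hypothetical helper:

    #include <cassert>

    // Hypothetical helper (not from the commit): mirrors the GGML_ASSERTs above.
    // A cache type is usable only if the per-head row length is a whole number of
    // quantization blocks.
    static bool cache_type_ok(int n_embd_head, int blck_size) {
        return n_embd_head % blck_size == 0;
    }

    int main() {
        assert( cache_type_ok(128, 32)); // e.g. Q4_0 K cache (32-element blocks)
        assert( cache_type_ok(128,  1)); // F16 V cache (block size 1)
        assert(!cache_type_ok(100, 32)); // a head size of 100 would be rejected
        return 0;
    }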