From 40b521b656d8f4cd19955ca8449c33ee22e90a9a Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 15 Jan 2025 10:54:21 +0200
Subject: [PATCH] context : minor

ggml-ci
---
 src/llama-context.cpp  | 36 +++++++++++-------------------------
 src/llama-context.h    |  8 +++-----
 src/llama-kv-cache.cpp |  1 +
 src/llama-kv-cache.h   |  6 +++---
 src/llama.cpp          | 33 +++++++++++++++++++--------------
 5 files changed, 37 insertions(+), 47 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 0004e214b..9eae6fe57 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -8,30 +8,6 @@
 #include
 #include
 
-void llama_set_k_shift(struct llama_context & lctx) {
-    const int64_t kv_size = lctx.kv_self.size;
-
-    assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
-
-    int32_t * data = (int32_t *) lctx.inp_K_shift->data;
-
-    for (int i = 0; i < kv_size; ++i) {
-        data[i] = lctx.kv_self.cells[i].delta;
-    }
-}
-
-void llama_set_s_copy(struct llama_context & lctx) {
-    const int64_t kv_size = lctx.kv_self.size;
-
-    assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
-
-    int32_t * data = (int32_t *) lctx.inp_s_copy->data;
-
-    for (int i = 0; i < kv_size; ++i) {
-        data[i] = lctx.kv_self.cells[i].src;
-    }
-}
-
 // llama input
 
 static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
@@ -58,6 +34,16 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
     return relative_bucket;
 }
 
+void llama_context::set_k_shift(llama_kv_cache & kv) {
+    assert(ggml_backend_buffer_is_host(inp_K_shift->buffer));
+
+    int32_t * data = (int32_t *) inp_K_shift->data;
+
+    for (uint32_t i = 0; i < kv.size; ++i) {
+        data[i] = kv.cells[i].delta;
+    }
+}
+
 void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
     //
     // set input data
@@ -134,7 +120,6 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
         const int64_t n_seq_tokens = ubatch.n_seq_tokens;
         const int64_t n_seqs       = ubatch.n_seqs;
 
-
         float * data     = nullptr;
         float * data_swa = nullptr;
 
@@ -599,6 +584,7 @@ uint32_t llama_n_ubatch(const struct llama_context * ctx) {
 }
 
 uint32_t llama_n_seq_max(const struct llama_context * ctx) {
+    // TODO: add notion of n_seq_max to llama_kv_cache and use it here
     return ctx->kv_self.size;
 }
 
diff --git a/src/llama-context.h b/src/llama-context.h
index a9268b292..73baa711f 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -18,7 +18,7 @@ struct llama_context {
     llama_context(const llama_model & model)
         : model(model)
         , t_start_us(model.t_start_us)
-        , t_load_us(model.t_load_us) {}
+        , t_load_us (model.t_load_us) {}
 
     const struct llama_model & model;
 
@@ -107,13 +107,11 @@ struct llama_context {
     struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
    struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
+
+    void set_k_shift(llama_kv_cache & kv);
 };
 
 // TODO: make these methods of llama_context
-void llama_set_k_shift(struct llama_context & lctx);
-
-void llama_set_s_copy(struct llama_context & lctx);
-
 void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch);
 
 // Make sure enough space is available for outputs.
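Note on the hunks above: the patch moves the K-shift input fill from a free function into llama_context::set_k_shift() and drops llama_set_s_copy() outright. Below is a minimal standalone sketch of the fill pattern that set_k_shift() performs; every type and name in it is a simplified stand-in for the real llama_kv_cache / ggml_tensor definitions, not the llama.cpp API.

// Sketch (assumed stand-in types): the K-shift input is an I32 tensor in a
// host-visible buffer, filled with one rotation delta per KV cell.
#include <cassert>
#include <cstdint>
#include <vector>

struct kv_cell_sketch {
    int32_t delta = 0; // accumulated position shift of this cell
};

struct kv_cache_sketch {
    uint32_t size = 0;
    std::vector<kv_cell_sketch> cells;
};

// Stand-in for llama_context::set_k_shift(); the real code instead asserts
// that inp_K_shift lives in a host buffer before touching ->data.
static void set_k_shift_sketch(const kv_cache_sketch & kv, int32_t * data) {
    assert(data != nullptr);

    for (uint32_t i = 0; i < kv.size; ++i) {
        data[i] = kv.cells[i].delta;
    }
}

int main() {
    kv_cache_sketch kv;
    kv.cells = {{0}, {-2}, {-2}, {0}}; // e.g. two cells shifted back by 2
    kv.size  = (uint32_t) kv.cells.size();

    std::vector<int32_t> inp_K_shift(kv.size);
    set_k_shift_sketch(kv, inp_K_shift.data());
    return 0;
}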
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index d2b81a022..b79c2ff93 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -6,6 +6,7 @@
 #include "llama-model.h"
 
 #include
+#include
 #include
 #include
 #include
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index 2e021d4ed..5ffee6281 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -88,11 +88,11 @@ struct llama_kv_cache {
 
     void clear();
 
-    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1);
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1);
     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1);
     void seq_keep(llama_seq_id seq_id);
-    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta);
-    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d);
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta);
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d);
 
     llama_pos seq_pos_max(llama_seq_id seq_id);
 
diff --git a/src/llama.cpp b/src/llama.cpp
index 23e7e83f0..5970195af 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -1142,18 +1142,18 @@ struct llm_build_context {
 
         ctx0 = ggml_init(params);
 
-        lctx.inp_tokens = nullptr;
-        lctx.inp_embd = nullptr;
-        lctx.inp_pos = nullptr;
-        lctx.inp_out_ids = nullptr;
-        lctx.inp_KQ_mask = nullptr;
-        lctx.inp_KQ_mask_swa = nullptr;
-        lctx.inp_K_shift = nullptr;
-        lctx.inp_mean = nullptr;
-        lctx.inp_cls = nullptr;
-        lctx.inp_s_copy = nullptr;
-        lctx.inp_s_mask = nullptr;
-        lctx.inp_s_seq = nullptr;
+        lctx.inp_tokens        = nullptr;
+        lctx.inp_embd          = nullptr;
+        lctx.inp_pos           = nullptr;
+        lctx.inp_out_ids       = nullptr;
+        lctx.inp_KQ_mask       = nullptr;
+        lctx.inp_KQ_mask_swa   = nullptr;
+        lctx.inp_K_shift       = nullptr;
+        lctx.inp_mean          = nullptr;
+        lctx.inp_cls           = nullptr;
+        lctx.inp_s_copy        = nullptr;
+        lctx.inp_s_mask        = nullptr;
+        lctx.inp_s_seq         = nullptr;
         lctx.inp_pos_bucket    = nullptr;
         lctx.inp_embd_enc      = nullptr;
         lctx.inp_KQ_mask_cross = nullptr;
@@ -1174,9 +1174,11 @@ struct llm_build_context {
         ggml_set_input(lctx.inp_K_shift);
 
         for (int il = 0; il < n_layer; ++il) {
-            const int64_t n_head_kv = hparams.n_head_kv(il);
+            const int64_t n_head_kv    = hparams.n_head_kv(il);
             const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
+            struct ggml_tensor * rope_factors = build_rope_factors(il);
+
             struct ggml_tensor * k =
                 ggml_view_3d(ctx0, kv_self.k_l[il],
                     n_embd_head_k, n_head_kv, n_ctx,
@@ -1189,6 +1191,7 @@ struct llm_build_context {
                 // dequantize to f32 -> RoPE -> quantize back
                 tmp = ggml_cast(ctx0, k, GGML_TYPE_F32);
                 cb(tmp, "K_f32", il);
+
                 for (auto & backend : lctx.backends) {
                     // Figure out which backend KV cache belongs to
                     if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(kv_self.k_l[il]->buffer))) {
@@ -1200,6 +1203,7 @@ struct llm_build_context {
                     lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(tmp, "K_shifted_f32", il);
+
                 tmp = ggml_cpy(ctx0, tmp, k);
             } else {
                 // we rotate only the first n_rot dimensions
@@ -1208,6 +1212,7 @@ struct llm_build_context {
                     ext_factor, attn_factor, beta_fast, beta_slow);
             }
             cb(tmp, "K_shifted", il);
+
             ggml_build_forward_expand(gf, tmp);
         }
 
@@ -9201,7 +9206,7 @@ static void llama_kv_self_update_impl(llama_context & lctx) {
 
         ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
 
-        llama_set_k_shift(lctx);
+        lctx.set_k_shift(kv);
 
         llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
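Note on the last hunk: lctx.set_k_shift(kv) runs only after ggml_backend_sched_alloc_graph(), because inp_K_shift has no backing data until the scheduler allocates the graph. A minimal standalone sketch of that ordering constraint follows; all types and names in it are simplified stand-ins, not the llama.cpp API.

// Sketch (assumed stand-in types): fill the K-shift input only after the
// graph buffers exist, mirroring alloc -> set_k_shift -> compute.
#include <cassert>
#include <cstdint>
#include <vector>

struct tensor_sketch {
    int32_t * data = nullptr; // null until "allocated"
};

struct context_sketch {
    tensor_sketch inp_K_shift;
    std::vector<int32_t> deltas = {0, -2, -2, 0}; // stand-in for kv.cells[i].delta
    std::vector<int32_t> storage;

    void alloc_graph() {  // stand-in for ggml_backend_sched_alloc_graph()
        storage.resize(deltas.size());
        inp_K_shift.data = storage.data();
    }

    void set_k_shift() {  // mirrors llama_context::set_k_shift()
        assert(inp_K_shift.data != nullptr); // would crash if called pre-alloc
        for (size_t i = 0; i < deltas.size(); ++i) {
            inp_K_shift.data[i] = deltas[i];
        }
    }
};

int main() {
    context_sketch ctx;
    ctx.alloc_graph();  // 1) allocate graph buffers
    ctx.set_k_shift();  // 2) only now write the K-shift input
    // 3) compute the graph (llama_graph_compute in the real code)
    return 0;
}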