diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 910e2243d..daea125fe 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -32,6 +32,38 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
     return relative_bucket;
 }
 
+enum ggml_status llama_context::compute_graph(
+        ggml_cgraph * graph,
+        bool batched) {
+    int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads;
+    ggml_threadpool_t tp = batched ? threadpool_batch : threadpool;
+
+    if (backend_cpu != nullptr) {
+        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
+        auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
+        set_threadpool_fn(backend_cpu, tp);
+    }
+
+    // set the number of threads for all the backends
+    for (const auto & set_n_threads_fn : set_n_threads_fns) {
+        set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
+    }
+
+    auto status = ggml_backend_sched_graph_compute_async(sched.get(), graph);
+    if (status != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
+    }
+
+    // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched));
+
+    return status;
+}
+
+
+llama_pos llama_context::pos_max() const {
+    return kv_self.pos_max();
+}
+
 // TODO: improve
 void llama_context::reset() {
     inp_tokens = nullptr;
@@ -540,6 +572,93 @@ ggml_tensor * llama_context::build_lora_mm_id(
     return res;
 }
 
+bool llama_context::kv_self_update() {
+    bool need_reserve = false;
+
+    auto & kv = kv_self;
+
+    if (kv.has_shift) {
+        if (!kv.can_shift) {
+            GGML_ABORT("The current context does not support K-shift");
+        }
+
+        // apply K-shift if needed
+        if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
+            prepare_k_shift();
+
+            ggml_backend_sched_reset(sched.get());
+
+            struct ggml_init_params params = {
+                /*.mem_size   =*/ buf_compute_meta.size(),
+                /*.mem_buffer =*/ buf_compute_meta.data(),
+                /*.no_alloc   =*/ true,
+            };
+
+            ggml_context * ctx0 = ggml_init(params);
+
+            reset();
+
+            ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+            build_k_shift(ctx0, gf);
+
+            ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+            set_inputs({});
+
+            compute_graph(gf, false);
+
+            ggml_free(ctx0);
+
+            need_reserve = true;
+        }
+
+        {
+            kv.has_shift = false;
+
+            for (uint32_t i = 0; i < kv.size; ++i) {
+                kv.cells[i].delta = 0;
+            }
+        }
+    }
+
+    // defragment the KV cache if needed
+    if (kv.do_defrag) {
+        prepare_defrag();
+
+        ggml_backend_sched_reset(sched.get());
+
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ buf_compute_meta.size(),
+            /*.mem_buffer =*/ buf_compute_meta.data(),
+            /*.no_alloc   =*/ true,
+        };
+
+        ggml_context * ctx0 = ggml_init(params);
+
+        reset();
+
+        ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+        build_defrag(ctx0, gf);
+
+        ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+        // no input
+        //set_inputs({});
+
+        compute_graph(gf, false);
+
+        ggml_free(ctx0);
+
+        need_reserve = true;
+
+        kv.do_defrag = false;
+    }
+
+    return need_reserve;
+}
+
 void llama_context::build_attn_inp(
         ggml_context * ctx0,
         int32_t n_tokens,
diff --git a/src/llama-context.h b/src/llama-context.h
index a2f41b5c8..bc33fc6ef 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -79,6 +79,13 @@ struct llama_context {
     ggml_abort_callback abort_callback = nullptr;
     void * abort_callback_data = nullptr;
 
+    // returns the result of ggml_backend_sched_graph_compute_async execution
+    enum ggml_status compute_graph(
+            ggml_cgraph * graph,
+            bool batched);
+
+    llama_pos pos_max() const;
+
     void reset();
 
     void prepare_k_shift();
@@ -129,6 +136,9 @@ struct llama_context {
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
     struct ggml_tensor * inp_K_shift; // I32 [kv_size]
 
+    // return true if need to reserve new worst-case graph
+    bool kv_self_update();
+
     void build_attn_inp(
             ggml_context * ctx0,
             int32_t n_tokens,
diff --git a/src/llama.cpp b/src/llama.cpp
index 3abc9a0b2..2b2d8f4b1 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -110,7 +110,6 @@ struct llm_build_context {
     const llama_hparams & hparams;
     const llama_cparams & cparams;
     const llama_ubatch & ubatch;
-    //const llama_kv_cache & kv_self;
     const llama_adapter_cvec & cvec;
     const llama_loras & loras;
 
@@ -137,8 +136,6 @@ struct llm_build_context {
     const float norm_rms_eps;
 
     const int32_t n_tokens;
-    //const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size)
-    //const int32_t kv_head; // index of where we store new KV data in the cache
     const int32_t n_outputs;
     const int32_t n_outputs_enc;
     const int32_t n_ctx_orig;
@@ -166,7 +163,6 @@ struct llm_build_context {
         hparams (model.hparams),
         cparams (lctx.cparams),
         ubatch (ubatch),
-        //kv_self (lctx.kv_self),
         cvec (lctx.cvec),
         loras (lctx.loras),
         n_embd (hparams.n_embd),
@@ -190,8 +186,6 @@ struct llm_build_context {
         norm_eps (hparams.f_norm_eps),
         norm_rms_eps (hparams.f_norm_rms_eps),
         n_tokens (ubatch.n_tokens),
-        //n_kv (worst_case ? kv_self.size : kv_self.n),
-        //kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
         n_outputs (worst_case ? n_tokens : lctx.n_outputs),
         n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
         n_ctx_orig (cparams.n_ctx_orig_yarn),
@@ -7532,40 +7526,6 @@ struct llm_build_context {
     }
 };
 
-static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx) {
-    llama_ubatch dummy = {};
-    dummy.equal_seqs = true;
-
-    llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
-
-    struct llm_build_context llm(lctx, dummy, cb, false);
-
-    llm.init();
-
-    struct ggml_cgraph * result = llm.build_defrag();
-
-    llm.free();
-
-    return result;
-}
-
-static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
-    llama_ubatch dummy = {};
-    dummy.equal_seqs = true;
-
-    llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
-
-    struct llm_build_context llm(lctx, dummy, cb, false);
-
-    llm.init();
-
-    struct ggml_cgraph * result = llm.build_k_shift();
-
-    llm.free();
-
-    return result;
-}
-
 static struct ggml_cgraph * llama_build_graph(
         llama_context & lctx,
         const llama_ubatch & ubatch,
@@ -7836,33 +7796,6 @@ static struct ggml_cgraph * llama_build_graph(
     return result;
 }
 
-// returns the result of ggml_backend_sched_graph_compute_async execution
-static enum ggml_status llama_graph_compute(
-        llama_context & lctx,
-        ggml_cgraph * gf,
-        int n_threads,
-        ggml_threadpool * threadpool) {
-    if (lctx.backend_cpu != nullptr) {
-        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(lctx.backend_cpu));
-        auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
-        set_threadpool_fn(lctx.backend_cpu, threadpool);
-    }
-
-    // set the number of threads for all the backends
-    for (const auto & set_n_threads_fn : lctx.set_n_threads_fns) {
-        set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
-    }
-
-    auto status = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
-    if (status != GGML_STATUS_SUCCESS) {
-        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
-    }
-
-    // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
-
-    return status;
-}
-
 // decode a batch of tokens by evaluating the transformer
 // in case of unsuccessful decoding (error or warning),
 // the kv_cache state will be returned to its original state
@@ -7887,7 +7820,7 @@ static int llama_decode_impl(
     }
 
     // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.pos_max() + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
     const uint32_t n_tokens_all = batch.n_tokens;
@@ -7989,16 +7922,11 @@ static int llama_decode_impl(
             lctx.n_outputs = n_outputs_new;
         }
 
-        int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-        ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
-
-        GGML_ASSERT(n_threads > 0);
-
         lctx.prepare_decode(ubatch);
 
         // non-causal masks do not use the KV cache
         if (hparams.causal_attn) {
-            llama_kv_self_update(&lctx); // TODO: lctx->kv_self_update()
+            llama_kv_self_update(&lctx);
 
             // if we have enough unused cells before the current head ->
             //   better to start searching from the beginning of the cache, hoping to fill it
@@ -8058,7 +7986,7 @@ static int llama_decode_impl(
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
 
-        const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
+        const auto compute_status = lctx.compute_graph(gf, n_tokens > 1);
         if (compute_status != GGML_STATUS_SUCCESS) {
             kv_slot_restorer.restore(kv_self);
             switch (compute_status) {
@@ -8226,7 +8154,7 @@ static int llama_encode_impl(
     }
 
     // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.pos_max() + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
     const uint32_t n_tokens = batch.n_tokens;
@@ -8274,11 +8202,6 @@ static int llama_encode_impl(
     lctx.inp_embd_enc = NULL;
     lctx.n_outputs = n_tokens;
 
-    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-    ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
-
-    GGML_ASSERT(n_threads > 0);
-
     lctx.prepare_decode(ubatch);
 
     ggml_backend_sched_reset(lctx.sched.get());
@@ -8310,7 +8233,7 @@ static int llama_encode_impl(
         }
     }
 
-    const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
+    const auto compute_status = lctx.compute_graph(gf, n_tokens > 1);
     switch (compute_status) {
         case GGML_STATUS_SUCCESS:
             break;
@@ -8397,76 +8320,6 @@ static int llama_encode_impl(
     return 0;
 }
 
-// TODO: move to llama_context
-static void llama_kv_self_update_impl(llama_context & lctx) {
-    bool need_reserve = false;
-
-    auto & kv = lctx.kv_self;
-
-    if (kv.has_shift) {
-        if (!kv.can_shift) {
-            GGML_ABORT("The current context does not support K-shift");
-        }
-
-        // apply K-shift if needed
-        if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
-            lctx.prepare_k_shift();
-
-            ggml_backend_sched_reset(lctx.sched.get());
-
-            ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
-
-            ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
-
-            lctx.set_inputs({});
-
-            llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
-
-            need_reserve = true;
-        }
-
-        {
-            kv.has_shift = false;
-
-            for (uint32_t i = 0; i < kv.size; ++i) {
-                kv.cells[i].delta = 0;
-            }
-        }
-    }
-
-    // defragment the KV cache if needed
-    if (kv.do_defrag) {
-        lctx.prepare_defrag();
-
-        ggml_backend_sched_reset(lctx.sched.get());
-
-        ggml_cgraph * gf = llama_build_graph_defrag(lctx);
-
-        llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
-
-        need_reserve = true;
-
-        kv.do_defrag = false;
-    }
-
-    // reserve a worst case graph again
-    if (need_reserve) {
-        // TODO: extract to a function
-        // build worst-case graph
-        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
-        uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
-        llama_token token = lctx.model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-        ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
-
-        // initialize scheduler with the worst-case graph
-        ggml_backend_sched_reset(lctx.sched.get());
-        if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
-            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
-        }
-    }
-}
-
 int32_t llama_set_adapter_lora(
         struct llama_context * ctx,
         struct llama_adapter_lora * adapter,
@@ -9214,9 +9067,30 @@ void llama_kv_cache_update(llama_context * ctx) {
     llama_kv_self_update(ctx);
 }
 
-// TODO: move to llama-context
 void llama_kv_self_update(llama_context * ctx) {
-    llama_kv_self_update_impl(*ctx);
+    const bool need_reserve = ctx->kv_self_update();
+
+    // reserve a worst case graph again
+    if (need_reserve) {
+        // TODO: extract to a function
+        const auto & cparams = ctx->cparams;
+        const auto & model = ctx->model;
+
+        // build worst-case graph
+        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+        ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true);
+
+        // initialize scheduler with the worst-case graph
+        ggml_backend_sched_reset(ctx->sched.get());
+        if (!ggml_backend_sched_reserve(ctx->sched.get(), gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+        }
+    }
 }
 
 ///
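
Note on the remaining "// TODO: extract to a function" in llama_kv_self_update(): below is a minimal sketch of what that extraction could look like. The helper name llama_reserve_worst_case_graph and its bool return are assumptions for illustration only, not part of this patch; every call it makes (llama_build_graph, ggml_backend_sched_reset, ggml_backend_sched_reserve) is taken from the hunks above.

// Hypothetical helper (name and signature are assumptions, not in this patch);
// it mirrors the inlined worst-case reservation block in llama_kv_self_update().
static bool llama_reserve_worst_case_graph(llama_context & lctx) {
    const auto & cparams = lctx.cparams;
    const auto & model   = lctx.model;

    // build a worst-case graph: a single sequence filling one full ubatch
    uint32_t n_seqs   = 1; // TODO: worst-case number of sequences
    uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

    // the token value is not used by llama_build_graph; it only selects the token-input path
    llama_token token = model.vocab.token_bos();
    llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};

    ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);

    // initialize the scheduler with the worst-case graph
    ggml_backend_sched_reset(lctx.sched.get());
    if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
        LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
        return false;
    }

    return true;
}

// llama_kv_self_update() would then reduce to:
void llama_kv_self_update(llama_context * ctx) {
    if (ctx->kv_self_update()) {
        llama_reserve_worst_case_graph(*ctx);
    }
}

The reservation stays in llama.cpp (rather than moving into llama_context::kv_self_update()) because llama_build_graph is still a static function in llama.cpp, which appears to be why kv_self_update() returns a bool instead of re-reserving the worst-case graph itself.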