From 1f7dc122228271615924759101f9657b08c0e9ba Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 17 Jan 2025 14:42:09 +0200 Subject: [PATCH] context : initial need_reserve logic ggml-ci --- src/llama-context.cpp | 173 +++++++++++++++++++++- src/llama-context.h | 4 +- src/llama.cpp | 337 +++++++++++++----------------------------- 3 files changed, 269 insertions(+), 245 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 5cb31abc0..d696090cc 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -576,9 +576,7 @@ ggml_tensor * llama_context::build_lora_mm_id( return res; } -bool llama_context::kv_self_update() { - bool need_reserve = false; - +void llama_context::kv_self_update() { auto & kv = kv_self; if (kv.has_shift) { @@ -655,12 +653,14 @@ bool llama_context::kv_self_update() { ggml_free(ctx0); - need_reserve = true; - kv.do_defrag = false; - } - return need_reserve; + need_reserve = true; + } +} + +void llama_kv_self_update(llama_context * ctx) { + ctx->kv_self_update(); } void llama_context::build_attn_inp( @@ -1824,6 +1824,165 @@ int32_t llama_apply_adapter_cvec( return ctx->cvec.apply(ctx->model, data, len, n_embd, il_start, il_end); } +// +// kv cache view +// + +struct llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) { + return llama_kv_cache_view_init(ctx->kv_self, n_seq_max); +} + +void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) { + llama_kv_cache_view_update(view, ctx->kv_self); +} + +// +// kv cache +// + +// deprecated +int32_t llama_get_kv_cache_token_count(const llama_context * ctx) { + return llama_kv_self_n_tokens(ctx); +} + +int32_t llama_kv_self_n_tokens(const llama_context * ctx) { + return llama_kv_cache_n_tokens(&ctx->kv_self); +} + +// deprecated +int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) { + return llama_kv_self_used_cells(ctx); +} + +int32_t llama_kv_self_used_cells(const llama_context * ctx) { + return llama_kv_cache_used_cells(&ctx->kv_self); +} + +// deprecated +void llama_kv_cache_clear(llama_context * ctx) { + llama_kv_self_clear(ctx); +} + +void llama_kv_self_clear(llama_context * ctx) { + llama_kv_cache_clear(&ctx->kv_self); +} + +// deprecated +bool llama_kv_cache_seq_rm( + llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { + return llama_kv_self_seq_rm(ctx, seq_id, p0, p1); +} + +bool llama_kv_self_seq_rm( + llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { + return llama_kv_cache_seq_rm(&ctx->kv_self, seq_id, p0, p1); +} + +// deprecated +void llama_kv_cache_seq_cp( + llama_context * ctx, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { + return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); +} + +void llama_kv_self_seq_cp( + llama_context * ctx, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { + return llama_kv_cache_seq_cp(&ctx->kv_self, seq_id_src, seq_id_dst, p0, p1); +} + +// deprecated +void llama_kv_cache_seq_keep( + llama_context * ctx, + llama_seq_id seq_id) { + return llama_kv_self_seq_keep(ctx, seq_id); +} + +void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { + return llama_kv_cache_seq_keep(&ctx->kv_self, seq_id); +} + +// deprecated +void llama_kv_cache_seq_add( + llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { + return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta); +} + +void llama_kv_self_seq_add( + llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { + return llama_kv_cache_seq_add(&ctx->kv_self, seq_id, p0, p1, delta); +} + +// deprecated +void llama_kv_cache_seq_div( + llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { + return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d); +} + +void llama_kv_self_seq_div( + llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { + return llama_kv_cache_seq_div(&ctx->kv_self, seq_id, p0, p1, d); +} + +// deprecated +llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { + return llama_kv_self_seq_pos_max(ctx, seq_id); +} + +llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { + return llama_kv_cache_seq_pos_max(&ctx->kv_self, seq_id); +} + +// deprecated +void llama_kv_cache_defrag(llama_context * ctx) { + return llama_kv_self_defrag(ctx); +} + +void llama_kv_self_defrag(llama_context * ctx) { + return llama_kv_cache_defrag(&ctx->kv_self); +} + +// deprecated +bool llama_kv_cache_can_shift(const llama_context * ctx) { + return llama_kv_self_can_shift(ctx); +} + +bool llama_kv_self_can_shift(const llama_context * ctx) { + return llama_kv_cache_can_shift(&ctx->kv_self); +} + +// deprecated +void llama_kv_cache_update(llama_context * ctx) { + llama_kv_self_update(ctx); +} // llama state API diff --git a/src/llama-context.h b/src/llama-context.h index 45eaafaad..eb9a17391 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -62,6 +62,7 @@ struct llama_context { int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch bool logits_all = false; + bool need_reserve = false; // embeddings output (2-dimensional array: [n_outputs][n_embd]) // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE @@ -87,6 +88,7 @@ struct llama_context { // max token position across all sequences in the current context llama_pos pos_max() const; + // certain implementations could require a padding for the context size uint32_t get_ctx_padding(const llama_cparams & cparams) const; void reset(); @@ -140,7 +142,7 @@ struct llama_context { struct ggml_tensor * inp_K_shift; // I32 [kv_size] // return true if need to reserve new worst-case graph - bool kv_self_update(); + void kv_self_update(); void build_attn_inp( ggml_context * ctx0, diff --git a/src/llama.cpp b/src/llama.cpp index 836731a4c..16f67fd8f 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -28,57 +28,6 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif -// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback -static int llama_model_load(const std::string & fname, std::vector & splits, llama_model & model, llama_model_params & params) { - // loading time will be recalculated after the first eval, so - // we take page faults deferred by mmap() into consideration - model.t_load_us = 0; - time_meas tm(model.t_load_us); - - model.t_start_us = tm.t_start_us; - - try { - llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides); - - ml.print_info(); - - model.hparams.vocab_only = params.vocab_only; - - try { - model.load_arch(ml); - } catch(const std::exception & e) { - throw std::runtime_error("error loading model architecture: " + std::string(e.what())); - } - try { - model.load_hparams(ml); - } catch(const std::exception & e) { - throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what())); - } - try { - model.load_vocab(ml); - } catch(const std::exception & e) { - throw std::runtime_error("error loading model vocabulary: " + std::string(e.what())); - } - - model.load_stats(ml); - model.print_info(); - - if (params.vocab_only) { - LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__); - return 0; - } - - if (!model.load_tensors(ml)) { - return -2; - } - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what()); - return -1; - } - - return 0; -} - // // llm_build // @@ -7951,6 +7900,30 @@ static int llama_decode_impl( } } + // reserve a worst case graph if needed + // TODO: extract to a function + if (lctx.need_reserve) { + const auto & cparams = lctx.cparams; + const auto & model = lctx.model; + + // build worst-case graph + uint32_t n_seqs = 1; // TODO: worst-case number of sequences + uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); + + llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + + ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true); + + // initialize scheduler with the worst-case graph + ggml_backend_sched_reset(lctx.sched.get()); + if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); + } + + lctx.need_reserve = false; + } + //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); ggml_backend_sched_reset(lctx.sched.get()); @@ -8206,6 +8179,31 @@ static int llama_encode_impl( lctx.prepare_decode(ubatch); + // reserve a worst case graph if needed + // TODO: extract to a function + if (lctx.need_reserve) { + // TODO: extract to a function + const auto & cparams = lctx.cparams; + const auto & model = lctx.model; + + // build worst-case graph + uint32_t n_seqs = 1; // TODO: worst-case number of sequences + uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); + + llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + + ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true); + + // initialize scheduler with the worst-case graph + ggml_backend_sched_reset(lctx.sched.get()); + if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); + } + + lctx.need_reserve = false; + } + ggml_backend_sched_reset(lctx.sched.get()); ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); @@ -8419,6 +8417,57 @@ int64_t llama_time_us(void) { return ggml_time_us(); } +// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback +static int llama_model_load(const std::string & fname, std::vector & splits, llama_model & model, llama_model_params & params) { + // loading time will be recalculated after the first eval, so + // we take page faults deferred by mmap() into consideration + model.t_load_us = 0; + time_meas tm(model.t_load_us); + + model.t_start_us = tm.t_start_us; + + try { + llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides); + + ml.print_info(); + + model.hparams.vocab_only = params.vocab_only; + + try { + model.load_arch(ml); + } catch(const std::exception & e) { + throw std::runtime_error("error loading model architecture: " + std::string(e.what())); + } + try { + model.load_hparams(ml); + } catch(const std::exception & e) { + throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what())); + } + try { + model.load_vocab(ml); + } catch(const std::exception & e) { + throw std::runtime_error("error loading model vocabulary: " + std::string(e.what())); + } + + model.load_stats(ml); + model.print_info(); + + if (params.vocab_only) { + LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__); + return 0; + } + + if (!model.load_tensors(ml)) { + return -2; + } + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what()); + return -1; + } + + return 0; +} + static struct llama_model * llama_model_load_from_file_impl( const std::string & path_model, std::vector & splits, @@ -8879,192 +8928,6 @@ struct llama_context * llama_new_context_with_model( return llama_init_from_model(model, params); } -// -// kv cache view -// - -struct llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) { - return llama_kv_cache_view_init(ctx->kv_self, n_seq_max); -} - -void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) { - llama_kv_cache_view_update(view, ctx->kv_self); -} - -// -// kv cache -// - -// deprecated -int32_t llama_get_kv_cache_token_count(const llama_context * ctx) { - return llama_kv_self_n_tokens(ctx); -} - -int32_t llama_kv_self_n_tokens(const llama_context * ctx) { - return llama_kv_cache_n_tokens(&ctx->kv_self); -} - -// deprecated -int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) { - return llama_kv_self_used_cells(ctx); -} - -int32_t llama_kv_self_used_cells(const llama_context * ctx) { - return llama_kv_cache_used_cells(&ctx->kv_self); -} - -// deprecated -void llama_kv_cache_clear(llama_context * ctx) { - llama_kv_self_clear(ctx); -} - -void llama_kv_self_clear(llama_context * ctx) { - llama_kv_cache_clear(&ctx->kv_self); -} - -// deprecated -bool llama_kv_cache_seq_rm( - llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1) { - return llama_kv_self_seq_rm(ctx, seq_id, p0, p1); -} - -bool llama_kv_self_seq_rm( - llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1) { - return llama_kv_cache_seq_rm(&ctx->kv_self, seq_id, p0, p1); -} - -// deprecated -void llama_kv_cache_seq_cp( - llama_context * ctx, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1) { - return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); -} - -void llama_kv_self_seq_cp( - llama_context * ctx, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1) { - return llama_kv_cache_seq_cp(&ctx->kv_self, seq_id_src, seq_id_dst, p0, p1); -} - -// deprecated -void llama_kv_cache_seq_keep( - llama_context * ctx, - llama_seq_id seq_id) { - return llama_kv_self_seq_keep(ctx, seq_id); -} - -void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { - return llama_kv_cache_seq_keep(&ctx->kv_self, seq_id); -} - -// deprecated -void llama_kv_cache_seq_add( - llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta) { - return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta); -} - -void llama_kv_self_seq_add( - llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta) { - return llama_kv_cache_seq_add(&ctx->kv_self, seq_id, p0, p1, delta); -} - -// deprecated -void llama_kv_cache_seq_div( - llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d) { - return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d); -} - -void llama_kv_self_seq_div( - llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d) { - return llama_kv_cache_seq_div(&ctx->kv_self, seq_id, p0, p1, d); -} - -// deprecated -llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { - return llama_kv_self_seq_pos_max(ctx, seq_id); -} - -llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { - return llama_kv_cache_seq_pos_max(&ctx->kv_self, seq_id); -} - -// deprecated -void llama_kv_cache_defrag(llama_context * ctx) { - return llama_kv_self_defrag(ctx); -} - -void llama_kv_self_defrag(llama_context * ctx) { - return llama_kv_cache_defrag(&ctx->kv_self); -} - -// deprecated -bool llama_kv_cache_can_shift(const llama_context * ctx) { - return llama_kv_self_can_shift(ctx); -} - -bool llama_kv_self_can_shift(const llama_context * ctx) { - return llama_kv_cache_can_shift(&ctx->kv_self); -} - -// deprecated -void llama_kv_cache_update(llama_context * ctx) { - llama_kv_self_update(ctx); -} - -void llama_kv_self_update(llama_context * ctx) { - const bool need_reserve = ctx->kv_self_update(); - - // reserve a worst case graph again - if (need_reserve) { - // TODO: extract to a function - const auto & cparams = ctx->cparams; - const auto & model = ctx->model; - - // build worst-case graph - uint32_t n_seqs = 1; // TODO: worst-case number of sequences - uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - - llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true); - - // initialize scheduler with the worst-case graph - ggml_backend_sched_reset(ctx->sched.get()); - if (!ggml_backend_sched_reserve(ctx->sched.get(), gf)) { - LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); - } - } -} - /// int32_t llama_encode(