From 7420bef83e4f5077e7ef140174c1abf72668abb9 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 1 Nov 2023 08:51:43 +0200
Subject: [PATCH] wip wip wip

---
 llama.cpp | 101 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 101 insertions(+)

diff --git a/llama.cpp b/llama.cpp
index ead1d421d..84235e363 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -3438,6 +3438,107 @@ static struct ggml_tensor * llm_build_kqv(
     return cur;
 }
 
+// Per-batch state shared by the graph-building code: references into the
+// llama_context plus scalar values cached from hparams/cparams/kv_self.
+//
+// The constructor creates a ggml context (ctx0) that uses lctx.buf_compute as
+// its memory buffer; the destructor releases it, so the type is non-copyable.
+struct llm_build_context {
+    // lctx       - provides model, cparams, kv_self and buf_compute
+    // batch      - only batch.n_tokens is read
+    // cb         - stored by reference, so it must outlive this object
+    // worst_case - when true, n_kv = n_ctx, kv_head = n_ctx - n_tokens and
+    //              do_rope_shift = true, instead of the values from kv_self
+    llm_build_context(
+            llama_context & lctx,
+        const llama_batch & batch,
+       const llm_build_cb & cb,
+                     bool   worst_case) :
+        model         (lctx.model),
+        hparams       (model.hparams),
+        cparams       (lctx.cparams),
+        kv_self       (lctx.kv_self),
+        n_embd        (hparams.n_embd),
+        n_layer       (hparams.n_layer),
+        n_ctx         (cparams.n_ctx),
+        n_head        (hparams.n_head),
+        n_head_kv     (hparams.n_head_kv),
+        n_embd_head   (hparams.n_embd_head()),
+        freq_base     (cparams.rope_freq_base),
+        freq_scale    (cparams.rope_freq_scale),
+        norm_rms_eps  (hparams.f_norm_rms_eps),
+        n_tokens      (batch.n_tokens),
+        n_kv          (worst_case ? n_ctx            : kv_self.n),
+        kv_head       (worst_case ? n_ctx - n_tokens : kv_self.head),
+        do_rope_shift (worst_case || kv_self.has_shift),
+        cb            (cb),
+        buf_compute   (lctx.buf_compute),
+        ctx0          (nullptr) {
+        GGML_ASSERT(!!kv_self.ctx);
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ buf_compute.size,
+            /*.mem_buffer =*/ buf_compute.data,
+            /*.no_alloc   =*/ true,
+        };
+
+        // keep the handle in a member: a constructor-local pointer would be
+        // lost on return and the context could never be passed to ggml_free()
+        ctx0 = ggml_init(params);
+        GGML_ASSERT(ctx0 != nullptr);
+    }
+
+    // releases the ggml context created by the constructor
+    ~llm_build_context() {
+        ggml_free(ctx0);
+    }
+
+    // copying would make two objects free the same ctx0
+    llm_build_context(const llm_build_context &) = delete;
+    llm_build_context & operator=(const llm_build_context &) = delete;
+
+    const llama_model   & model;
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+
+    const llama_kv_cache & kv_self;
+
+    const int64_t n_embd;
+    const int64_t n_layer;
+    const int64_t n_ctx;
+    const int64_t n_head;
+    const int64_t n_head_kv;
+    const int64_t n_embd_head;
+
+    const float freq_base;
+    const float freq_scale;
+    const float norm_rms_eps;
+
+    const int32_t n_tokens;
+    const int32_t n_kv;
+    const int32_t kv_head;
+
+    const bool do_rope_shift;
+
+    const llm_build_cb & cb;
+
+    llama_buffer & buf_compute;
+
+    // ggml context backed by buf_compute; declared last to match the order of
+    // the constructor's initializer list
+    struct ggml_context * ctx0;
+
+    // Graph builder entry point - work in progress. Only checks that
+    // n_embd_head equals hparams.n_rot; no graph is constructed yet.
+    struct ggml_cgraph * llama() {
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        // explicit result: flowing off the end of a non-void function is
+        // undefined behavior
+        return nullptr;
+    }
+
+};
+
 static struct ggml_cgraph * llm_build_llama(
          llama_context & lctx,
      const llama_batch & batch,