mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-26 03:12:23 +01:00
llama : prepare next graph while the current one is being evaluated
This commit is contained in:
parent
fc4ca27b25
commit
33a5c8e37c
@ -2739,6 +2739,9 @@ struct llama_context {
|
||||
std::vector<uint8_t> buf_compute_meta;
|
||||
ggml_backend_sched_t sched = nullptr;
|
||||
|
||||
std::vector<uint8_t> buf_compute_meta_next;
|
||||
struct ggml_cgraph * gf_next = nullptr;
|
||||
|
||||
ggml_abort_callback abort_callback = nullptr;
|
||||
void * abort_callback_data = nullptr;
|
||||
|
||||
@ -8383,7 +8386,7 @@ struct llm_build_context {
|
||||
const float norm_rms_eps;
|
||||
|
||||
const int32_t n_tokens;
|
||||
const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size)
|
||||
int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size)
|
||||
const int32_t n_outputs;
|
||||
const int32_t n_outputs_enc;
|
||||
const int32_t kv_head; // index of where we store new KV data in the cache
|
||||
@ -8405,7 +8408,8 @@ struct llm_build_context {
|
||||
llama_context & lctx,
|
||||
const llama_batch & batch,
|
||||
const llm_build_cb & cb,
|
||||
bool worst_case) :
|
||||
bool worst_case,
|
||||
bool prepare_only = false) :
|
||||
model (lctx.model),
|
||||
lctx (lctx),
|
||||
hparams (model.hparams),
|
||||
@ -8442,8 +8446,12 @@ struct llm_build_context {
|
||||
pooling_type (cparams.pooling_type),
|
||||
rope_type (hparams.rope_type),
|
||||
cb (cb),
|
||||
buf_compute_meta (lctx.buf_compute_meta) {
|
||||
buf_compute_meta (prepare_only ? lctx.buf_compute_meta_next : lctx.buf_compute_meta) {
|
||||
// all initializations should be done in init()
|
||||
if (prepare_only) {
|
||||
const uint32_t pad = llama_kv_cache_get_padding(cparams);
|
||||
n_kv = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self) + 1, pad)));
|
||||
}
|
||||
}
|
||||
|
||||
void init() {
|
||||
@ -13805,7 +13813,8 @@ static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
|
||||
static struct ggml_cgraph * llama_build_graph(
|
||||
llama_context & lctx,
|
||||
const llama_batch & batch,
|
||||
bool worst_case) {
|
||||
bool worst_case,
|
||||
bool prepare_only = false) {
|
||||
const auto & model = lctx.model;
|
||||
|
||||
// this callback allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
|
||||
@ -13841,7 +13850,7 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
|
||||
struct ggml_cgraph * result = NULL;
|
||||
|
||||
struct llm_build_context llm(lctx, batch, cb, worst_case);
|
||||
struct llm_build_context llm(lctx, batch, cb, worst_case, prepare_only);
|
||||
|
||||
llm.init();
|
||||
|
||||
@ -14536,7 +14545,8 @@ static void llama_graph_compute(
|
||||
//
|
||||
static int llama_decode_internal(
|
||||
llama_context & lctx,
|
||||
llama_batch batch_all) { // TODO: rename back to batch
|
||||
llama_batch batch_all, // TODO: rename back to batch
|
||||
bool prepare_only = false) {
|
||||
|
||||
lctx.is_encoding = false;
|
||||
const uint32_t n_tokens_all = batch_all.n_tokens;
|
||||
@ -14556,10 +14566,12 @@ static int llama_decode_internal(
|
||||
|
||||
GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
|
||||
|
||||
if (!prepare_only) {
|
||||
if (lctx.t_compute_start_us == 0) {
|
||||
lctx.t_compute_start_us = ggml_time_us();
|
||||
}
|
||||
lctx.n_queued_tokens += n_tokens_all;
|
||||
}
|
||||
|
||||
auto & kv_self = lctx.kv_self;
|
||||
|
||||
@ -14612,6 +14624,10 @@ static int llama_decode_internal(
|
||||
}
|
||||
}
|
||||
|
||||
if (n_tokens_all != 1) {
|
||||
lctx.gf_next = nullptr;
|
||||
}
|
||||
|
||||
for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
|
||||
const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
|
||||
llama_batch u_batch = {
|
||||
@ -14678,7 +14694,7 @@ static int llama_decode_internal(
|
||||
}
|
||||
|
||||
// non-causal masks do not use the KV cache
|
||||
if (hparams.causal_attn) {
|
||||
if (hparams.causal_attn && !prepare_only) {
|
||||
llama_kv_cache_update(&lctx);
|
||||
|
||||
// if we have enough unused cells before the current head ->
|
||||
@ -14703,10 +14719,23 @@ static int llama_decode_internal(
|
||||
|
||||
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
|
||||
|
||||
|
||||
ggml_cgraph * gf = lctx.gf_next;
|
||||
|
||||
if (!gf) {
|
||||
ggml_backend_sched_reset(lctx.sched);
|
||||
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
|
||||
gf = llama_build_graph(lctx, u_batch, false, prepare_only);
|
||||
ggml_backend_sched_alloc_graph(lctx.sched, gf);
|
||||
}
|
||||
|
||||
if (prepare_only) {
|
||||
lctx.gf_next = gf;
|
||||
return 0;
|
||||
}
|
||||
|
||||
lctx.gf_next = nullptr;
|
||||
|
||||
ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
|
||||
|
||||
// the output is always the last tensor in the graph
|
||||
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
||||
@ -14732,7 +14761,6 @@ static int llama_decode_internal(
|
||||
}
|
||||
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
||||
|
||||
ggml_backend_sched_alloc_graph(lctx.sched, gf);
|
||||
|
||||
llama_set_inputs(lctx, u_batch);
|
||||
|
||||
@ -14836,6 +14864,15 @@ static int llama_decode_internal(
|
||||
// overlap with device computation.
|
||||
ggml_backend_sched_reset(lctx.sched);
|
||||
|
||||
if (n_tokens_all == 1 && !prepare_only) {
|
||||
// prepare graph for the next token
|
||||
llama_token next_token_dummy = 0;
|
||||
llama_pos n_past = batch_all.all_pos_0 + 1;
|
||||
llama_seq_id seq_id = 0;
|
||||
llama_batch batch_next = llama_batch_get_one(&next_token_dummy, 1, n_past, seq_id);
|
||||
llama_decode_internal(lctx, batch_next, true);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -16940,6 +16977,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
|
||||
// buffer used to store the computation graph and the tensor meta data
|
||||
ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
|
||||
ctx->buf_compute_meta_next.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
|
||||
|
||||
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
|
||||
bool pipeline_parallel =
|
||||
|
Loading…
Reference in New Issue
Block a user