From c8c07d658a6cefc5a50cfdf6be7d726503612303 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Tue, 22 Oct 2024 16:59:02 +0200
Subject: [PATCH] llama : fix empty batch causing llama_batch_allocr to crash
 (#9966)

* llama : fix empty batch causing llama_batch_allocr to crash

* move batch_allocr inside decode/encode_internal

* fix build

* add GGML_ASSERT

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov

---------

Co-authored-by: Georgi Gerganov
---
 src/llama.cpp | 128 ++++++++++++++++++++++++++------------------------
 1 file changed, 67 insertions(+), 61 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 7a5a46dce..24e1f1f01 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -5177,6 +5177,57 @@ struct llama_model_loader {
     }
 };
 
+// temporary allocate memory for the input batch if needed
+static const llama_seq_id batch_default_seq_id = 0;
+struct llama_batch_allocr {
+    std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
+    std::vector<llama_pos>      pos;
+    std::vector<int32_t>        n_seq_id;
+    std::vector<llama_seq_id *> seq_id;
+    std::vector<int8_t>         logits;
+    struct llama_batch batch;
+    // optionally fulfill the batch returned by llama_batch_get_one
+    llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) {
+        batch = in_batch;
+        GGML_ASSERT(batch.n_tokens > 0);
+        if (!batch.pos) {
+            // determine the last position in KV cache
+            llama_pos last_pos = -1;
+            for (const auto & cell : ctx.kv_self.cells) {
+                if (cell.has_seq_id(batch_default_seq_id)) {
+                    last_pos = std::max(last_pos, cell.pos);
+                }
+            }
+            last_pos++; // next position
+            pos.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                pos[i] = i+last_pos;
+            }
+            batch.pos = pos.data();
+        }
+        if (!batch.n_seq_id) {
+            n_seq_id.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                n_seq_id[i] = seq_id_0.size();
+            }
+            batch.n_seq_id = n_seq_id.data();
+        }
+        if (!batch.seq_id) {
+            seq_id.resize(batch.n_tokens + 1);
+            seq_id[batch.n_tokens] = NULL;
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                seq_id[i] = seq_id_0.data();
+            }
+            batch.seq_id = seq_id.data();
+        }
+        if (!batch.logits) {
+            logits.resize(batch.n_tokens);
+            logits[logits.size() - 1] = true;
+            batch.logits = logits.data();
+        }
+    }
+};
+
 template<>
 bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
     uint32_t tmp;
@@ -17095,16 +17146,20 @@ static void llama_graph_compute(
 //
 static int llama_decode_internal(
          llama_context & lctx,
-           llama_batch   batch) {
+           llama_batch   inp_batch) {
 
     lctx.is_encoding = false;
 
-    const uint32_t n_tokens_all = batch.n_tokens;
-
-    if (n_tokens_all == 0) {
+    if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
+    // temporary allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(lctx, inp_batch);
+    const llama_batch & batch = batch_allocr.batch;
+
+    const uint32_t n_tokens_all = batch.n_tokens;
+
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
@@ -17409,17 +17464,20 @@ static int llama_decode_internal(
 //
 static int llama_encode_internal(
          llama_context & lctx,
-           llama_batch   batch) {
+           llama_batch   inp_batch) {
 
     lctx.is_encoding = true;
 
-    const uint32_t n_tokens = batch.n_tokens;
-
-    if (n_tokens == 0) {
+    if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
+    // temporary allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(lctx, inp_batch);
+    const llama_batch & batch = batch_allocr.batch;
+
+    const uint32_t n_tokens = batch.n_tokens;
+
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
@@ -21090,61 +21148,10 @@ void llama_batch_free(struct llama_batch batch) {
     if (batch.logits)   free(batch.logits);
 }
 
-// temporary allocate memory for the input batch if needed
-static const llama_seq_id batch_default_seq_id = 0;
-struct llama_batch_allocr {
-    std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
-    std::vector<llama_pos>      pos;
-    std::vector<int32_t>        n_seq_id;
-    std::vector<llama_seq_id *> seq_id;
-    std::vector<int8_t>         logits;
-    struct llama_batch batch;
-    // optionally fulfill the batch returned by llama_batch_get_one
-    llama_batch_allocr(struct llama_context * ctx, struct llama_batch in_batch) {
-        batch = in_batch;
-        if (!batch.pos) {
-            // determine the last position in KV cache
-            llama_pos last_pos = -1;
-            for (const auto & cell : ctx->kv_self.cells) {
-                if (cell.has_seq_id(batch_default_seq_id)) {
-                    last_pos = std::max(last_pos, cell.pos);
-                }
-            }
-            last_pos++; // next position
-            pos.resize(batch.n_tokens);
-            for (int32_t i = 0; i < batch.n_tokens; i++) {
-                pos[i] = i+last_pos;
-            }
-            batch.pos = pos.data();
-        }
-        if (!batch.n_seq_id) {
-            n_seq_id.resize(batch.n_tokens);
-            for (int32_t i = 0; i < batch.n_tokens; i++) {
-                n_seq_id[i] = seq_id_0.size();
-            }
-            batch.n_seq_id = n_seq_id.data();
-        }
-        if (!batch.seq_id) {
-            seq_id.resize(batch.n_tokens + 1);
-            seq_id[batch.n_tokens] = NULL;
-            for (int32_t i = 0; i < batch.n_tokens; i++) {
-                seq_id[i] = seq_id_0.data();
-            }
-            batch.seq_id = seq_id.data();
-        }
-        if (!batch.logits) {
-            logits.resize(batch.n_tokens);
-            logits[logits.size() - 1] = true;
-            batch.logits = logits.data();
-        }
-    }
-};
-
 int32_t llama_encode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
-    llama_batch_allocr batch_allocr(ctx, batch);
-    const int ret = llama_encode_internal(*ctx, batch_allocr.batch);
+    const int ret = llama_encode_internal(*ctx, batch);
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
     }
@@ -21155,8 +21162,7 @@ int32_t llama_encode(
 int32_t llama_decode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
-    llama_batch_allocr batch_allocr(ctx, batch);
-    const int ret = llama_decode_internal(*ctx, batch_allocr.batch);
+    const int ret = llama_decode_internal(*ctx, batch);
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
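
Caller-side note (not part of the patch): the sketch below illustrates how a batch produced by `llama_batch_get_one` reaches the decoder after this change. It assumes the two-argument `llama_batch_get_one` signature from this same patch series, a valid `llama_context`, and a hypothetical helper name `decode_prompt`; it is illustrative, not an exact excerpt from the repository.

```cpp
// Minimal sketch: decode a plain token array through the public API.
// llama_batch_get_one fills only `token` and `n_tokens`; pos, n_seq_id,
// seq_id and logits remain NULL and are now filled in by llama_batch_allocr
// inside llama_decode_internal rather than in llama_decode itself.
#include "llama.h"

static int decode_prompt(struct llama_context * ctx, llama_token * tokens, int32_t n_tokens) {
    struct llama_batch batch = llama_batch_get_one(tokens, n_tokens);

    // With this patch, an empty batch (n_tokens == 0) is rejected up front
    // with a -1 return instead of reaching llama_batch_allocr and crashing.
    return llama_decode(ctx, batch);
}
```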