Mirror of https://github.com/ggerganov/llama.cpp.git
llama: refactor llama_decode_impl
This commit is contained in:
parent
564804b79b
commit
cd0aee8981
243  src/llama.cpp
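In short: the batch-preparation logic is pulled out of llama_decode_impl into two new static helpers, llama_prepare_sbatch (per-batch validation, output counting, sbatch setup and output-buffer reservation) and llama_prepare_ubatch (per-ubatch splitting, output counting and KV-cache slot search). The sketch below condenses the resulting control flow from the diff that follows; it relies on llama.cpp's internal types, and the wrapper name llama_decode_impl_sketch as well as the elided graph-compute step are illustrative placeholders, not code from the commit.

    // Condensed sketch of the refactored control flow (not verbatim from the commit).
    static int llama_decode_impl_sketch(llama_context & lctx, const llama_batch & batch) {
        uint32_t n_outputs = 0;

        // per-batch work: validate tokens, count outputs, set up lctx.sbatch, reserve the output buffer
        {
            const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
            if (ret != 0) {
                return ret; // -1: invalid token, -2: could not reserve the output buffer
            }
        }

        llama_kv_slot_restorer kv_slot_restorer(lctx.kv_self);

        while (lctx.sbatch.n_tokens > 0) {
            llama_ubatch ubatch;

            // per-ubatch work: split the sbatch, count outputs, find a KV-cache slot
            {
                const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
                if (ret != 0) {
                    return ret; // 1: no KV-cache slot found (warning)
                }
            }

            // ... build and compute the graph for this ubatch (unchanged by this commit) ...
        }

        return 0;
    }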
@@ -8432,13 +8432,141 @@ static enum ggml_status llama_graph_compute(
     return status;
 }
 
+static int llama_prepare_sbatch(
+        llama_context     & lctx,
+        const llama_batch & batch,
+        uint32_t          & n_outputs) {
+    const auto & model   = lctx.model;
+    const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
+
+    const uint32_t n_tokens_all = batch.n_tokens;
+    const int64_t  n_embd       = hparams.n_embd;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+    if (batch.token) {
+        for (uint32_t i = 0; i < n_tokens_all; ++i) {
+            if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
+                return -1;
+            }
+        }
+    }
+    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
+    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
+
+    lctx.n_queued_tokens += n_tokens_all;
+    lctx.embd_seq.clear();
+
+    // count outputs
+    if (batch.logits && !embd_pooled) {
+        for (uint32_t i = 0; i < n_tokens_all; ++i) {
+            n_outputs += batch.logits[i] != 0;
+        }
+    } else if (lctx.logits_all || embd_pooled) {
+        n_outputs = n_tokens_all;
+    } else {
+        // keep last output only
+        n_outputs = 1;
+    }
+
+    lctx.sbatch.from_batch(batch, n_embd,
+        /* simple_split */ !lctx.kv_self.recurrent,
+        /* logits_all   */ n_outputs == n_tokens_all);
+
+    // reserve output buffer
+    if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
+        return -2;
+    };
+
+    return 0;
+}
+
+static int llama_prepare_ubatch(
+        llama_context          & lctx,
+        llama_kv_slot_restorer & kv_slot_restorer,
+        llama_ubatch           & ubatch,
+        const uint32_t           n_outputs,
+        const uint32_t           n_tokens_all) {
+    GGML_ASSERT(lctx.sbatch.n_tokens > 0);
+
+    auto & kv_self = lctx.kv_self;
+    const auto & cparams = lctx.cparams;
+    const auto & hparams = lctx.model.hparams;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    if (lctx.kv_self.recurrent) {
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = lctx.sbatch.split_seq(cparams.n_ubatch);
+        } else {
+            // recurrent model architectures are easier to implement
+            // with equal-length sequences
+            ubatch = lctx.sbatch.split_equal(cparams.n_ubatch);
+        }
+    } else {
+        ubatch = lctx.sbatch.split_simple(cparams.n_ubatch);
+    }
+
+    // count the outputs in this u_batch
+    {
+        int32_t n_outputs_new = 0;
+
+        if (n_outputs == n_tokens_all) {
+            n_outputs_new = ubatch.n_tokens;
+        } else {
+            GGML_ASSERT(ubatch.output);
+            for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                n_outputs_new += int32_t(ubatch.output[i] != 0);
+            }
+        }
+
+        // needs to happen before the graph is built
+        lctx.n_outputs = n_outputs_new;
+    }
+
+    // non-causal masks do not use the KV cache
+    if (hparams.causal_attn) {
+        llama_kv_cache_update(&lctx);
+
+        // if we have enough unused cells before the current head ->
+        //   better to start searching from the beginning of the cache, hoping to fill it
+        if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
+            kv_self.head = 0;
+        }
+
+        const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+        if (!slot) {
+            return 1;
+        }
+        kv_slot_restorer.save(slot);
+
+        if (!kv_self.recurrent) {
+            // a heuristic, to avoid attending the full cache if it is not yet utilized
+            // after enough generations, the benefit from this heuristic disappears
+            // if we start defragmenting the cache, the benefit from this will be more important
+            const uint32_t pad = llama_kv_cache_get_padding(cparams);
+            kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
+            //kv_self.n = llama_kv_cache_cell_max(kv_self);
+        }
+    }
+
+    return 0;
+}
+
 // decode a batch of tokens by evaluating the transformer
 // in case of unsuccessful decoding (error or warning),
 //   the kv_cache state will be returned to its original state
 //   (for non-recurrent models) or cleaned (for recurrent models)
 //
 //   - lctx:      llama context
-//   - batch:     batch to evaluate
+//   - inp_batch: batch to evaluate
 //
 // return 0 on success
 // return positive int on warning
@@ -8455,37 +8583,18 @@ static int llama_decode_impl(
         return -1;
     }
 
-    // temporary allocate memory for the input batch if needed
+    // temporarily allocate memory for the input batch if needed
     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
-    const uint32_t n_tokens_all = batch.n_tokens;
 
     const auto & model   = lctx.model;
     const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
-
-    if (batch.token) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
-        }
-    }
-
-    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-
-    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
-
     if (lctx.t_compute_start_us == 0) {
         lctx.t_compute_start_us = ggml_time_us();
     }
-    lctx.n_queued_tokens += n_tokens_all;
 
     auto & kv_self = lctx.kv_self;
     llama_kv_slot_restorer kv_slot_restorer(kv_self);
 
@@ -8495,99 +8604,27 @@ static int llama_decode_impl(
     uint32_t n_outputs = 0;
     uint32_t n_outputs_prev = 0;
 
-    const auto n_ubatch = cparams.n_ubatch;
-
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
-    lctx.embd_seq.clear();
-
-    // count outputs
-    if (batch.logits && !embd_pooled) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs += batch.logits[i] != 0;
+    {
+        const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
+        if (ret != 0) {
+            return ret;
         }
-    } else if (lctx.logits_all || embd_pooled) {
-        n_outputs = n_tokens_all;
-    } else {
-        // keep last output only
-        n_outputs = 1;
     }
-
-    lctx.sbatch.from_batch(batch, n_embd,
-        /* simple_split */ !kv_self.recurrent,
-        /* logits_all   */ n_outputs == n_tokens_all);
-
-    // reserve output buffer
-    if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
-        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
-        return -2;
-    };
 
     while (lctx.sbatch.n_tokens > 0) {
         llama_ubatch ubatch;
-        if (kv_self.recurrent) {
-            if (embd_pooled) {
-                // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = lctx.sbatch.split_seq(n_ubatch);
-            } else {
-                // recurrent model architectures are easier to implement
-                // with equal-length sequences
-                ubatch = lctx.sbatch.split_equal(n_ubatch);
-            }
-        } else {
-            ubatch = lctx.sbatch.split_simple(n_ubatch);
-        }
-        const uint32_t n_tokens = ubatch.n_tokens;
-
-        // count the outputs in this u_batch
         {
-            int32_t n_outputs_new = 0;
-
-            if (n_outputs == n_tokens_all) {
-                n_outputs_new = n_tokens;
-            } else {
-                GGML_ASSERT(ubatch.output);
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
-                }
+            const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
+            if (ret != 0) {
+                return ret;
             }
-
-            // needs to happen before the graph is built
-            lctx.n_outputs = n_outputs_new;
         }
 
-        int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-        ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+        const int n_threads = ubatch.n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+        ggml_threadpool_t threadpool = ubatch.n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
 
         GGML_ASSERT(n_threads > 0);
 
-        // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
-            llama_kv_cache_update(&lctx);
-
-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self.head > kv_self.used + 2*n_tokens) {
-                kv_self.head = 0;
-            }
-
-            const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
-            if (!slot) {
-                return 1;
-            }
-            kv_slot_restorer.save(slot);
-
-            if (!kv_self.recurrent) {
-                // a heuristic, to avoid attending the full cache if it is not yet utilized
-                // after enough generations, the benefit from this heuristic disappears
-                // if we start defragmenting the cache, the benefit from this will be more important
-                const uint32_t pad = llama_kv_cache_get_padding(cparams);
-                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
-                //kv_self.n = llama_kv_cache_cell_max(kv_self);
-            }
-        }
-
         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
         ggml_backend_sched_reset(lctx.sched.get());
@@ -8640,7 +8677,7 @@ static int llama_decode_impl(
 
         // update the kv ring buffer
         {
-            kv_self.head += n_tokens;
+            kv_self.head += ubatch.n_tokens;
 
             // Ensure kv cache head points to a valid index.
            if (kv_self.head >= kv_self.size) {
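The return-code contract documented above llama_decode_impl (0 on success, positive int on warning, negative int on error) is unchanged by the refactor: llama_prepare_sbatch reports -1 for an invalid token and -2 when the output buffer cannot be reserved, while llama_prepare_ubatch reports 1 when no KV-cache slot is found. A minimal caller sketch, assuming an already-initialized context and batch; the helper decode_checked is hypothetical and only illustrates how these codes surface through the public llama_decode() API:

    #include "llama.h"

    // Hypothetical helper: interpret the codes that llama_decode_impl propagates
    // through the public llama_decode() entry point.
    static bool decode_checked(llama_context * ctx, llama_batch batch) {
        const int ret = llama_decode(ctx, batch);
        if (ret == 0) {
            return true;  // batch decoded successfully
        }
        if (ret > 0) {
            // warning, e.g. no KV-cache slot could be found for the batch;
            // per the comment above llama_decode_impl, the KV-cache state has been
            // restored (non-recurrent models) or cleaned (recurrent models),
            // so the caller may retry with a smaller batch
            return false;
        }
        // negative: error, e.g. invalid input token or failed allocation
        return false;
    }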