mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 05:48:47 +01:00
server : chunked prefill support
ggml-ci
This commit is contained in:
parent
62e84d9848
commit
a6648b9df7
@ -2418,6 +2418,14 @@ struct server_context {
|
||||
int32_t n_batch = llama_n_batch(ctx);
|
||||
int32_t n_ubatch = llama_n_ubatch(ctx);
|
||||
|
||||
// there are currently slots with ongoing text generation
|
||||
const bool is_tg = batch.n_tokens > 0;
|
||||
|
||||
// limit the batch to avoid blocking the processing
|
||||
if (is_tg) {
|
||||
n_batch = 32; // TODO: configurable
|
||||
}
|
||||
|
||||
// track if this is an embedding or non-embedding batch
|
||||
// if we've added sampled tokens above, we are in non-embedding mode
|
||||
// -1: none, 0: non-embedding, 1: embedding
|
||||
@ -2426,6 +2434,18 @@ struct server_context {
|
||||
|
||||
// next, batch any pending prompts without exceeding n_batch
|
||||
if (params_base.cont_batching || batch.n_tokens == 0) {
|
||||
// count how many slots are currently processing prompt
|
||||
int n_slots_pp = 0;
|
||||
for (auto & slot : slots) {
|
||||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
|
||||
n_slots_pp++;
|
||||
}
|
||||
}
|
||||
|
||||
// determine the chunk size of the chunk prefill
|
||||
// a slot cannot submit more than this number of tokens in a single batch if other slots are processing
|
||||
const int32_t n_chunk_pp = std::max(n_slots_pp > 0 ? (n_batch / n_slots_pp) : n_batch, 8);
|
||||
|
||||
for (auto & slot : slots) {
|
||||
// this slot still has a prompt to be processed
|
||||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
|
||||
@ -2609,8 +2629,10 @@ struct server_context {
|
||||
// remove the non-common part from the cache
|
||||
slot.cache_tokens.resize(slot.n_past);
|
||||
|
||||
int n_cur = 0;
|
||||
|
||||
// add prompt tokens for processing in the current batch
|
||||
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
|
||||
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch && n_cur < n_chunk_pp) {
|
||||
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
|
||||
|
||||
if (slot.params.cache_prompt) {
|
||||
@ -2619,6 +2641,8 @@ struct server_context {
|
||||
|
||||
slot.n_prompt_tokens_processed++;
|
||||
slot.n_past++;
|
||||
|
||||
n_cur++;
|
||||
}
|
||||
|
||||
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
|
||||
|
Loading…
Reference in New Issue
Block a user