mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 21:37:19 +01:00
server : chunked prefill support
ggml-ci
This commit is contained in:
parent
62e84d9848
commit
a6648b9df7
@ -2418,6 +2418,14 @@ struct server_context {
|
|||||||
int32_t n_batch = llama_n_batch(ctx);
|
int32_t n_batch = llama_n_batch(ctx);
|
||||||
int32_t n_ubatch = llama_n_ubatch(ctx);
|
int32_t n_ubatch = llama_n_ubatch(ctx);
|
||||||
|
|
||||||
|
// there are currently slots with ongoing text generation
|
||||||
|
const bool is_tg = batch.n_tokens > 0;
|
||||||
|
|
||||||
|
// limit the batch to avoid blocking the processing
|
||||||
|
if (is_tg) {
|
||||||
|
n_batch = 32; // TODO: configurable
|
||||||
|
}
|
||||||
|
|
||||||
// track if this is an embedding or non-embedding batch
|
// track if this is an embedding or non-embedding batch
|
||||||
// if we've added sampled tokens above, we are in non-embedding mode
|
// if we've added sampled tokens above, we are in non-embedding mode
|
||||||
// -1: none, 0: non-embedding, 1: embedding
|
// -1: none, 0: non-embedding, 1: embedding
|
||||||
@ -2426,6 +2434,18 @@ struct server_context {
|
|||||||
|
|
||||||
// next, batch any pending prompts without exceeding n_batch
|
// next, batch any pending prompts without exceeding n_batch
|
||||||
if (params_base.cont_batching || batch.n_tokens == 0) {
|
if (params_base.cont_batching || batch.n_tokens == 0) {
|
||||||
|
// count how many slots are currently processing prompt
|
||||||
|
int n_slots_pp = 0;
|
||||||
|
for (auto & slot : slots) {
|
||||||
|
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
|
||||||
|
n_slots_pp++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// determine the per-slot chunk size for chunked prefill
|
||||||
|
// a slot cannot submit more than this number of prompt tokens in a single batch: n_batch is split evenly across the slots that are processing a prompt, with a minimum of 8 tokens per slot
|
||||||
|
const int32_t n_chunk_pp = std::max(n_slots_pp > 0 ? (n_batch / n_slots_pp) : n_batch, 8);
|
||||||
|
|
||||||
for (auto & slot : slots) {
|
for (auto & slot : slots) {
|
||||||
// this slot still has a prompt to be processed
|
// this slot still has a prompt to be processed
|
||||||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
|
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
|
||||||
@ -2609,8 +2629,10 @@ struct server_context {
|
|||||||
// remove the non-common part from the cache
|
// remove the non-common part from the cache
|
||||||
slot.cache_tokens.resize(slot.n_past);
|
slot.cache_tokens.resize(slot.n_past);
|
||||||
|
|
||||||
|
int n_cur = 0;
|
||||||
|
|
||||||
// add prompt tokens for processing in the current batch
|
// add prompt tokens for processing in the current batch
|
||||||
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
|
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch && n_cur < n_chunk_pp) {
|
||||||
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
|
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
|
||||||
|
|
||||||
if (slot.params.cache_prompt) {
|
if (slot.params.cache_prompt) {
|
||||||
@ -2619,6 +2641,8 @@ struct server_context {
|
|||||||
|
|
||||||
slot.n_prompt_tokens_processed++;
|
slot.n_prompt_tokens_processed++;
|
||||||
slot.n_past++;
|
slot.n_past++;
|
||||||
|
|
||||||
|
n_cur++;
|
||||||
}
|
}
|
||||||
|
|
||||||
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
|
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user