mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 13:27:21 +01:00
server : add self-extend support (#5104)
* Ported self extension to server example * Update server.cpp * Fixed prompt caching without self extend * Update server.cpp * Added description to server readme. * Update server.cpp * Update server.cpp * Update server.cpp * Update server.cpp * Update README.md * Changed descriptions * server : formatting * Update examples/server/server.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update examples/server/server.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update server.cpp * Update server.cpp --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
a1d6df129b
commit
ec903c0341
@ -30,7 +30,8 @@ Command line options:
|
|||||||
- `-cb`, `--cont-batching`: enable continuous batching (a.k.a dynamic batching) (default: disabled)
|
- `-cb`, `--cont-batching`: enable continuous batching (a.k.a dynamic batching) (default: disabled)
|
||||||
- `-spf FNAME`, `--system-prompt-file FNAME` Set a file to load "a system prompt (initial prompt of all slots), this is useful for chat applications. [See more](#change-system-prompt-on-runtime)
|
- `-spf FNAME`, `--system-prompt-file FNAME` Set a file to load "a system prompt (initial prompt of all slots), this is useful for chat applications. [See more](#change-system-prompt-on-runtime)
|
||||||
- `--mmproj MMPROJ_FILE`: Path to a multimodal projector file for LLaVA.
|
- `--mmproj MMPROJ_FILE`: Path to a multimodal projector file for LLaVA.
|
||||||
|
- `--grp-attn-n`: Set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`
|
||||||
|
- `--grp-attn-w`: Set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`
|
||||||
## Build
|
## Build
|
||||||
|
|
||||||
server is build alongside everything else from the root of the project
|
server is build alongside everything else from the root of the project
|
||||||
|
@ -184,6 +184,12 @@ struct llama_client_slot
|
|||||||
struct llama_sampling_params sparams;
|
struct llama_sampling_params sparams;
|
||||||
llama_sampling_context *ctx_sampling = nullptr;
|
llama_sampling_context *ctx_sampling = nullptr;
|
||||||
|
|
||||||
|
int32_t ga_i = 0; // group-attention state
|
||||||
|
int32_t ga_n = 1;// group-attention factor
|
||||||
|
int32_t ga_w = 512; // group-attention width
|
||||||
|
|
||||||
|
int32_t n_past_se = 0; // self-extend
|
||||||
|
|
||||||
// multimodal
|
// multimodal
|
||||||
std::vector<slot_image> images;
|
std::vector<slot_image> images;
|
||||||
|
|
||||||
@ -212,7 +218,8 @@ struct llama_client_slot
|
|||||||
sent_count = 0;
|
sent_count = 0;
|
||||||
sent_token_probs_index = 0;
|
sent_token_probs_index = 0;
|
||||||
infill = false;
|
infill = false;
|
||||||
|
ga_i = 0;
|
||||||
|
n_past_se = 0;
|
||||||
generated_token_probs.clear();
|
generated_token_probs.clear();
|
||||||
|
|
||||||
for (slot_image & img : images)
|
for (slot_image & img : images)
|
||||||
@ -399,9 +406,26 @@ struct llama_server_context
|
|||||||
|
|
||||||
slot.id = i;
|
slot.id = i;
|
||||||
slot.n_ctx = n_ctx_slot;
|
slot.n_ctx = n_ctx_slot;
|
||||||
slot.reset();
|
|
||||||
|
|
||||||
LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot);
|
LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot);
|
||||||
|
|
||||||
|
const int ga_n = params.grp_attn_n;
|
||||||
|
const int ga_w = params.grp_attn_w;
|
||||||
|
|
||||||
|
if (ga_n != 1) {
|
||||||
|
GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
|
||||||
|
GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
|
||||||
|
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
|
||||||
|
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
|
||||||
|
LOG_TEE(" -> Slot %i - self-extend: ga_n = %d, ga_w = %d\n", slot.id, ga_n, ga_w);
|
||||||
|
}
|
||||||
|
|
||||||
|
slot.ga_i = 0;
|
||||||
|
slot.ga_n = ga_n;
|
||||||
|
slot.ga_w = ga_w;
|
||||||
|
|
||||||
|
slot.reset();
|
||||||
|
|
||||||
slots.push_back(slot);
|
slots.push_back(slot);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1348,6 +1372,8 @@ struct llama_server_context
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (llama_client_slot &slot : slots)
|
for (llama_client_slot &slot : slots)
|
||||||
|
{
|
||||||
|
if (slot.ga_n == 1)
|
||||||
{
|
{
|
||||||
if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx)
|
if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx)
|
||||||
{
|
{
|
||||||
@ -1371,12 +1397,13 @@ struct llama_server_context
|
|||||||
slot.truncated = true;
|
slot.truncated = true;
|
||||||
|
|
||||||
LOG_VERBOSE("context shift", {
|
LOG_VERBOSE("context shift", {
|
||||||
{"n_ctx", n_ctx},
|
{ "n_ctx", n_ctx },
|
||||||
{"n_keep", params.n_keep},
|
{ "n_keep", params.n_keep },
|
||||||
{"n_left", n_left},
|
{ "n_left", n_left },
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// decode any currently ongoing sequences
|
// decode any currently ongoing sequences
|
||||||
for (auto & slot : slots)
|
for (auto & slot : slots)
|
||||||
@ -1401,7 +1428,8 @@ struct llama_server_context
|
|||||||
|
|
||||||
slot.i_batch = batch.n_tokens;
|
slot.i_batch = batch.n_tokens;
|
||||||
|
|
||||||
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, { slot.id }, true);
|
const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
|
||||||
|
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
|
||||||
|
|
||||||
slot.n_past += 1;
|
slot.n_past += 1;
|
||||||
}
|
}
|
||||||
@ -1499,6 +1527,8 @@ struct llama_server_context
|
|||||||
llama_sampling_reset(slot.ctx_sampling);
|
llama_sampling_reset(slot.ctx_sampling);
|
||||||
|
|
||||||
slot.n_past = 0;
|
slot.n_past = 0;
|
||||||
|
slot.n_past_se = 0;
|
||||||
|
slot.ga_i = 0;
|
||||||
slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
|
slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
@ -1512,6 +1542,25 @@ struct llama_server_context
|
|||||||
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
|
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
|
||||||
slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
|
slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
|
||||||
|
|
||||||
|
if (slot.ga_n != 1)
|
||||||
|
{
|
||||||
|
int ga_i = 0;
|
||||||
|
int32_t ga_n = slot.ga_n;
|
||||||
|
int32_t ga_w = slot.ga_w;
|
||||||
|
int32_t slot_npast = 0;
|
||||||
|
for (int k = 0; k < slot.n_past; ++k)
|
||||||
|
{
|
||||||
|
while (slot_npast >= ga_i + ga_w) {
|
||||||
|
const int bd = (ga_w/ga_n)*(ga_n - 1);
|
||||||
|
slot_npast -= bd;
|
||||||
|
ga_i += ga_w/ga_n;
|
||||||
|
}
|
||||||
|
slot_npast++;
|
||||||
|
}
|
||||||
|
slot.n_past_se = slot_npast;
|
||||||
|
slot.ga_i = ga_i;
|
||||||
|
}
|
||||||
|
|
||||||
LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
|
LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1526,6 +1575,10 @@ struct llama_server_context
|
|||||||
// we have to evaluate at least 1 token to generate logits.
|
// we have to evaluate at least 1 token to generate logits.
|
||||||
LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id);
|
LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id);
|
||||||
slot.n_past--;
|
slot.n_past--;
|
||||||
|
if (slot.ga_i > 0)
|
||||||
|
{
|
||||||
|
slot.n_past_se--;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_VERBOSE("prompt ingested", {
|
LOG_VERBOSE("prompt ingested", {
|
||||||
@ -1538,9 +1591,22 @@ struct llama_server_context
|
|||||||
|
|
||||||
// process the prefix of first image
|
// process the prefix of first image
|
||||||
std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
|
std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
|
||||||
|
int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
|
||||||
|
int ga_i = slot.ga_i;
|
||||||
|
int32_t ga_n = slot.ga_n;
|
||||||
|
int32_t ga_w = slot.ga_w;
|
||||||
for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
|
for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
|
||||||
{
|
{
|
||||||
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, { slot.id }, false);
|
if (slot.ga_n != 1)
|
||||||
|
{
|
||||||
|
while (slot_npast >= ga_i + ga_w) {
|
||||||
|
const int bd = (ga_w/ga_n)*(ga_n - 1);
|
||||||
|
slot_npast -= bd;
|
||||||
|
ga_i += ga_w/ga_n;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
|
||||||
|
slot_npast += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (has_images && !ingest_images(slot, n_batch))
|
if (has_images && !ingest_images(slot, n_batch))
|
||||||
@ -1570,6 +1636,36 @@ struct llama_server_context
|
|||||||
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
|
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
|
||||||
{
|
{
|
||||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||||
|
|
||||||
|
for (auto & slot : slots)
|
||||||
|
{
|
||||||
|
if (slot.ga_n != 1)
|
||||||
|
{
|
||||||
|
// context extension via Self-Extend
|
||||||
|
while (slot.n_past_se >= slot.ga_i + slot.ga_w)
|
||||||
|
{
|
||||||
|
const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
|
||||||
|
const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
|
||||||
|
const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
|
||||||
|
|
||||||
|
LOG_TEE("\n");
|
||||||
|
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
|
||||||
|
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
|
||||||
|
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
|
||||||
|
|
||||||
|
llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
|
||||||
|
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
|
||||||
|
llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
|
||||||
|
|
||||||
|
slot.n_past_se -= bd;
|
||||||
|
|
||||||
|
slot.ga_i += slot.ga_w / slot.ga_n;
|
||||||
|
|
||||||
|
LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
|
||||||
|
}
|
||||||
|
slot.n_past_se += n_tokens;
|
||||||
|
}
|
||||||
|
}
|
||||||
llama_batch batch_view =
|
llama_batch batch_view =
|
||||||
{
|
{
|
||||||
n_tokens,
|
n_tokens,
|
||||||
@ -1583,6 +1679,7 @@ struct llama_server_context
|
|||||||
};
|
};
|
||||||
|
|
||||||
const int ret = llama_decode(ctx, batch_view);
|
const int ret = llama_decode(ctx, batch_view);
|
||||||
|
|
||||||
if (ret != 0)
|
if (ret != 0)
|
||||||
{
|
{
|
||||||
if (n_batch == 1 || ret < 0)
|
if (n_batch == 1 || ret < 0)
|
||||||
@ -1728,6 +1825,8 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
|
|||||||
printf(" --override-kv KEY=TYPE:VALUE\n");
|
printf(" --override-kv KEY=TYPE:VALUE\n");
|
||||||
printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
|
printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
|
||||||
printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
|
printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
|
||||||
|
printf(" -gan N, --grp-attn-n N Set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
|
||||||
|
printf(" -gaw N, --grp-attn-w N Set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
|
||||||
printf("\n");
|
printf("\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1913,6 +2012,25 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
|||||||
}
|
}
|
||||||
params.n_threads = std::stoi(argv[i]);
|
params.n_threads = std::stoi(argv[i]);
|
||||||
}
|
}
|
||||||
|
else if (arg == "--grp-attn-n" || arg == "-gan")
|
||||||
|
{
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
params.grp_attn_n = std::stoi(argv[i]);
|
||||||
|
}
|
||||||
|
else if (arg == "--grp-attn-w" || arg == "-gaw")
|
||||||
|
{
|
||||||
|
if (++i >= argc)
|
||||||
|
{
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
params.grp_attn_w = std::stoi(argv[i]);
|
||||||
|
}
|
||||||
else if (arg == "--threads-batch" || arg == "-tb")
|
else if (arg == "--threads-batch" || arg == "-tb")
|
||||||
{
|
{
|
||||||
if (++i >= argc)
|
if (++i >= argc)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user