mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-27 04:23:06 +01:00
server : fix crash when prompt exceeds context size (#3996)
This commit is contained in:
parent
34b0a08207
commit
d96ca7ded7
@ -1557,6 +1557,35 @@ struct llama_server_context
|
|||||||
|
|
||||||
slot.num_prompt_tokens = prompt_tokens.size();
|
slot.num_prompt_tokens = prompt_tokens.size();
|
||||||
|
|
||||||
|
if (slot.params.n_keep < 0)
|
||||||
|
{
|
||||||
|
slot.params.n_keep = slot.num_prompt_tokens;
|
||||||
|
}
|
||||||
|
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
|
||||||
|
|
||||||
|
// if input prompt is too big, truncate it
|
||||||
|
if (slot.num_prompt_tokens >= slot.n_ctx)
|
||||||
|
{
|
||||||
|
const int n_left = slot.n_ctx - slot.params.n_keep;
|
||||||
|
const int n_block_size = n_left / 2;
|
||||||
|
const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
|
||||||
|
|
||||||
|
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
|
||||||
|
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
|
||||||
|
|
||||||
|
LOG_VERBOSE("input truncated", {
|
||||||
|
{"n_ctx", slot.n_ctx},
|
||||||
|
{"n_keep", slot.params.n_keep},
|
||||||
|
{"n_left", n_left},
|
||||||
|
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
|
||||||
|
});
|
||||||
|
slot.truncated = true;
|
||||||
|
prompt_tokens = new_tokens;
|
||||||
|
|
||||||
|
slot.num_prompt_tokens = prompt_tokens.size();
|
||||||
|
GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
|
||||||
|
}
|
||||||
|
|
||||||
if (!slot.params.cache_prompt)
|
if (!slot.params.cache_prompt)
|
||||||
{
|
{
|
||||||
llama_sampling_reset(slot.ctx_sampling);
|
llama_sampling_reset(slot.ctx_sampling);
|
||||||
@ -1566,35 +1595,6 @@ struct llama_server_context
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
if (slot.params.n_keep < 0)
|
|
||||||
{
|
|
||||||
slot.params.n_keep = slot.num_prompt_tokens;
|
|
||||||
}
|
|
||||||
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
|
|
||||||
|
|
||||||
// if input prompt is too big, truncate it
|
|
||||||
if (slot.num_prompt_tokens >= slot.n_ctx)
|
|
||||||
{
|
|
||||||
const int n_left = slot.n_ctx - slot.params.n_keep;
|
|
||||||
const int n_block_size = n_left / 2;
|
|
||||||
const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
|
|
||||||
|
|
||||||
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
|
|
||||||
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
|
|
||||||
|
|
||||||
LOG_VERBOSE("input truncated", {
|
|
||||||
{"n_ctx", slot.n_ctx},
|
|
||||||
{"n_keep", slot.params.n_keep},
|
|
||||||
{"n_left", n_left},
|
|
||||||
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
|
|
||||||
});
|
|
||||||
slot.truncated = true;
|
|
||||||
prompt_tokens = new_tokens;
|
|
||||||
|
|
||||||
slot.num_prompt_tokens = prompt_tokens.size();
|
|
||||||
GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
// push the prompt into the sampling context (do not apply grammar)
|
// push the prompt into the sampling context (do not apply grammar)
|
||||||
for (auto &token : prompt_tokens)
|
for (auto &token : prompt_tokens)
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user