mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 13:27:21 +01:00
server : fix segfault on long system prompt
This commit is contained in:
parent
6e02327e8b
commit
7eda5583fa
@ -1136,28 +1136,19 @@ struct server_context {
|
||||
if (!system_prompt.empty()) {
|
||||
system_tokens = ::llama_tokenize(ctx, system_prompt, true);
|
||||
|
||||
llama_batch_clear(batch);
|
||||
|
||||
for (int i = 0; i < (int)system_tokens.size(); ++i) {
|
||||
llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
|
||||
}
|
||||
|
||||
const int32_t n_batch = llama_n_batch(ctx);
|
||||
const int32_t n_tokens_prompt = system_tokens.size();
|
||||
|
||||
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
||||
const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i);
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
batch.token + i,
|
||||
nullptr,
|
||||
batch.pos + i,
|
||||
batch.n_seq_id + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
0, 0, 0, // unused
|
||||
};
|
||||
for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
|
||||
const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);
|
||||
|
||||
if (llama_decode(ctx, batch_view) != 0) {
|
||||
llama_batch_clear(batch);
|
||||
|
||||
for (int32_t j = 0; j < n_tokens; ++j) {
|
||||
llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, batch) != 0) {
|
||||
LOG_ERROR("llama_decode() failed", {});
|
||||
return;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user