llama : fix handling of "future" tokens when loading sessions

This commit is contained in:
Georgi Gerganov 2023-10-03 18:29:22 +03:00
parent 0f332a9104
commit 337120cc0d
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
6 changed files with 41 additions and 40 deletions

View File

@ -543,6 +543,9 @@ int main(int argc, char ** argv) {
if (i > 0) {
embd.erase(embd.begin(), embd.begin() + i);
}
// remove any "future" tokens that we might have inherited from the session from the KV cache
llama_kv_cache_tokens_rm(ctx, n_past, -1);
}
// evaluate tokens in batches

View File

@ -332,7 +332,7 @@ int main(int argc, char ** argv) {
}
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, n_ctx);
llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1);
const auto t_main_end = ggml_time_us();

View File

@ -448,7 +448,7 @@ struct llama_server_context
n_past = common_part(embd, prompt_tokens);
// since #3228 we now have to manually manage the KV cache
llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx);
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
embd = prompt_tokens;
if (n_past == num_prompt_tokens)

View File

@ -172,7 +172,7 @@ int main(int argc, char ** argv) {
LOG("out of drafted tokens\n");
}
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0));
++n_past_dft;
@ -257,7 +257,7 @@ int main(int argc, char ** argv) {
}
// evaluate the drafted token on the draft model
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, -1);
llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0));
++n_past_cur;
@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
}
// evaluate the target model on the drafted tokens
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx);
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, -1);
llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0));
++n_past_tgt;

View File

@ -1281,8 +1281,8 @@ static bool llama_kv_cache_init(
// find an empty slot of size "n_tokens" in the cache
// updates the cache head
static bool llama_kv_cache_find_slot(
struct llama_kv_cache & cache,
const struct llama_batch & batch) {
struct llama_kv_cache & cache,
const struct llama_batch & batch) {
const uint32_t n_ctx = cache.size;
const uint32_t n_tokens = batch.n_tokens;
@ -1350,10 +1350,13 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0,
}
static void llama_kv_cache_seq_rm(
struct llama_kv_cache & cache,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
struct llama_kv_cache & cache,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
cache.cells[i].seq_id.erase(seq_id);
@ -1365,11 +1368,14 @@ static void llama_kv_cache_seq_rm(
}
static void llama_kv_cache_seq_cp(
struct llama_kv_cache & cache,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
struct llama_kv_cache & cache,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
cache.cells[i].seq_id.insert(seq_id_dst);
@ -1387,11 +1393,14 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
}
static void llama_kv_cache_seq_shift(
struct llama_kv_cache & cache,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta) {
struct llama_kv_cache & cache,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta) {
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
cache.cells[i].pos += delta;
@ -7478,25 +7487,6 @@ void llama_batch_free(struct llama_batch batch) {
int llama_decode(
struct llama_context * ctx,
struct llama_batch batch) {
// TODO: temporary solution to auto clear "future" tokens from the cache
// ref: https://github.com/ggerganov/llama.cpp/pull/3400
if (batch.pos) {
std::map<llama_seq_id, llama_pos> seq_min_pos;
for (int i = 0; i < batch.n_tokens; i++) {
if (seq_min_pos.count(batch.seq_id[i]) == 0) {
seq_min_pos[batch.seq_id[i]] = batch.pos[i];
} else {
seq_min_pos[batch.seq_id[i]] = std::min(seq_min_pos[batch.seq_id[i]], batch.pos[i]);
}
}
for (auto & kv : seq_min_pos) {
llama_kv_cache_seq_rm(ctx->kv_self, kv.first, kv.second, ctx->cparams.n_ctx);
}
} else {
llama_kv_cache_seq_rm(ctx->kv_self, batch.all_seq_id, batch.all_pos_0, ctx->cparams.n_ctx);
}
const int ret = llama_decode_internal(*ctx, batch);
if (ret < 0) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);

View File

@ -330,12 +330,16 @@ extern "C" {
"avoid using this, it will be removed in the future, instead - count the tokens in user code");
// Remove all tokens data of cells in [c0, c1)
// c0 < -1 : [0, c1]
// c1 < -1 : [c0, inf)
LLAMA_API void llama_kv_cache_tokens_rm(
struct llama_context * ctx,
int32_t c0,
int32_t c1);
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
// p0 < -1 : [0, p1]
// p1 < -1 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_rm(
struct llama_context * ctx,
llama_seq_id seq_id,
@ -344,6 +348,8 @@ extern "C" {
// Copy all tokens that belong to the specified sequence to another sequence
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
// p0 < -1 : [0, p1]
// p1 < -1 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_cp(
struct llama_context * ctx,
llama_seq_id seq_id_src,
@ -358,6 +364,8 @@ extern "C" {
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
// If the KV cache is RoPEd, the KV data is updated accordingly
// p0 < -1 : [0, p1]
// p1 < -1 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_shift(
struct llama_context * ctx,
llama_seq_id seq_id,