llama : temp fix for clearing "future" tokens from the KV cache

This commit is contained in:
Georgi Gerganov 2023-10-02 16:42:14 +03:00
parent 6a9fe3dfac
commit 0f332a9104
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -7478,6 +7478,25 @@ void llama_batch_free(struct llama_batch batch) {
int llama_decode( int llama_decode(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_batch batch) { struct llama_batch batch) {
// TODO: temporary solution to auto clear "future" tokens from the cache
// ref: https://github.com/ggerganov/llama.cpp/pull/3400
if (batch.pos) {
std::map<llama_seq_id, llama_pos> seq_min_pos;
for (int i = 0; i < batch.n_tokens; i++) {
if (seq_min_pos.count(batch.seq_id[i]) == 0) {
seq_min_pos[batch.seq_id[i]] = batch.pos[i];
} else {
seq_min_pos[batch.seq_id[i]] = std::min(seq_min_pos[batch.seq_id[i]], batch.pos[i]);
}
}
for (auto & kv : seq_min_pos) {
llama_kv_cache_seq_rm(ctx->kv_self, kv.first, kv.second, ctx->cparams.n_ctx);
}
} else {
llama_kv_cache_seq_rm(ctx->kv_self, batch.all_seq_id, batch.all_pos_0, ctx->cparams.n_ctx);
}
const int ret = llama_decode_internal(*ctx, batch); const int ret = llama_decode_internal(*ctx, batch);
if (ret < 0) { if (ret < 0) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);