llama : fix kv shift bug (#3835)

ggml-ci
This commit is contained in:
Georgi Gerganov 2023-10-29 18:32:51 +02:00 committed by GitHub
parent d69d777c02
commit 71a09da301
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1552,14 +1552,14 @@ static void llama_kv_cache_seq_shift(
for (uint32_t i = 0; i < cache.size; ++i) { for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
cache.cells[i].pos += delta; cache.has_shift = true;
cache.cells[i].pos += delta;
cache.cells[i].delta += delta;
if (cache.cells[i].pos < 0) { if (cache.cells[i].pos < 0) {
cache.cells[i].pos = -1; cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear(); cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i; if (new_head == cache.size) new_head = i;
} else {
cache.has_shift = true;
cache.cells[i].delta = delta;
} }
} }
} }
@ -6073,11 +6073,20 @@ static int llama_decode_internal(
#endif #endif
// update the kv ring buffer // update the kv ring buffer
lctx.kv_self.has_shift = false; {
lctx.kv_self.head += n_tokens; if (kv_self.has_shift) {
// Ensure kv cache head points to a valid index. kv_self.has_shift = false;
if (lctx.kv_self.head >= lctx.kv_self.size) { for (uint32_t i = 0; i < kv_self.size; ++i) {
lctx.kv_self.head = 0; kv_self.cells[i].delta = 0;
}
}
kv_self.head += n_tokens;
// Ensure kv cache head points to a valid index.
if (kv_self.head >= kv_self.size) {
kv_self.head = 0;
}
} }
#ifdef GGML_PERF #ifdef GGML_PERF