llama : no longer perform uninitialized access to the KV cache

This commit is contained in:
Georgi Gerganov 2023-10-08 11:49:38 +03:00
parent acead654d2
commit ee268b5446
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -1329,7 +1329,9 @@ static bool llama_kv_cache_init(
// cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*ggml_tensor_overhead()); // cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*ggml_tensor_overhead());
// change it and test that it works // change it and test that it works
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB); cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
memset(cache.buf.data, 0, cache.buf.size);
// this is not necessary, since we should not be accessing cache data that has not been initialized yet
//memset(cache.buf.data, 0, cache.buf.size);
struct ggml_init_params params; struct ggml_init_params params;
params.mem_size = cache.buf.size; params.mem_size = cache.buf.size;
@ -1430,7 +1432,7 @@ static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
} }
} }
return 0; return 1;
} }
static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, int32_t c1) { static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
@ -5020,8 +5022,7 @@ static int llama_decode_internal(
// a heuristic, to avoid attending the full cache if it is not yet utilized // a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears // after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important // if we start defragmenting the cache, the benefit from this will be more important
//kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA? kv_self.n = llama_kv_cache_cell_max(kv_self);
kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self)));
//printf("kv_self.n = %d\n", kv_self.n); //printf("kv_self.n = %d\n", kv_self.n);