mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-06 02:48:57 +01:00
llama : fix cell_max logic + rename functions
This commit is contained in:
parent
36714e16d0
commit
ddad227782
31
llama.cpp
31
llama.cpp
@ -1319,7 +1319,7 @@ static bool llama_kv_cache_find_slot(
|
|||||||
|
|
||||||
// find how many cells are currently in use
|
// find how many cells are currently in use
|
||||||
static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
|
static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
|
||||||
for (uint32_t i = cache.size - 2; i > 0; --i) {
|
for (uint32_t i = cache.size - 1; i > 0; --i) {
|
||||||
if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
|
if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
|
||||||
return i + 1;
|
return i + 1;
|
||||||
}
|
}
|
||||||
@ -2606,7 +2606,7 @@ static struct ggml_cgraph * llm_build_llama(
|
|||||||
const int n_gpu_layers = model.n_gpu_layers;
|
const int n_gpu_layers = model.n_gpu_layers;
|
||||||
|
|
||||||
const int32_t n_tokens = batch.n_tokens;
|
const int32_t n_tokens = batch.n_tokens;
|
||||||
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max + n_tokens;
|
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max;
|
||||||
const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
|
const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
|
||||||
|
|
||||||
const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift;
|
const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift;
|
||||||
@ -2994,7 +2994,7 @@ static struct ggml_cgraph * llm_build_baichaun(
|
|||||||
const int n_gpu_layers = model.n_gpu_layers;
|
const int n_gpu_layers = model.n_gpu_layers;
|
||||||
|
|
||||||
const int32_t n_tokens = batch.n_tokens;
|
const int32_t n_tokens = batch.n_tokens;
|
||||||
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max + n_tokens;
|
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max;
|
||||||
const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
|
const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
|
||||||
|
|
||||||
const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift;
|
const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift;
|
||||||
@ -3397,7 +3397,7 @@ static struct ggml_cgraph * llm_build_falcon(
|
|||||||
const int n_gpu_layers = model.n_gpu_layers;
|
const int n_gpu_layers = model.n_gpu_layers;
|
||||||
|
|
||||||
const int32_t n_tokens = batch.n_tokens;
|
const int32_t n_tokens = batch.n_tokens;
|
||||||
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max + n_tokens;
|
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max;
|
||||||
const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
|
const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
|
||||||
|
|
||||||
const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift;
|
const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift;
|
||||||
@ -3758,7 +3758,7 @@ static struct ggml_cgraph * llm_build_starcoder(
|
|||||||
const float norm_eps = hparams.f_norm_eps;
|
const float norm_eps = hparams.f_norm_eps;
|
||||||
|
|
||||||
const int32_t n_tokens = batch.n_tokens;
|
const int32_t n_tokens = batch.n_tokens;
|
||||||
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max + n_tokens;
|
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max;
|
||||||
const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
|
const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
|
||||||
|
|
||||||
auto & buf_compute = lctx.buf_compute;
|
auto & buf_compute = lctx.buf_compute;
|
||||||
@ -4013,13 +4013,13 @@ static struct ggml_cgraph * llama_build_graph(
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// evaluate the transformer
|
// decode a batch of tokens by evaluating the transformer
|
||||||
//
|
//
|
||||||
// - lctx: llama context
|
// - lctx: llama context
|
||||||
// - batch: batch to evaluate
|
// - batch: batch to evaluate
|
||||||
// - n_threads: number of threads to use
|
// - n_threads: number of threads to use
|
||||||
//
|
//
|
||||||
static bool llama_eval_internal(
|
static bool llama_decode_internal(
|
||||||
llama_context & lctx,
|
llama_context & lctx,
|
||||||
llama_batch batch,
|
llama_batch batch,
|
||||||
int n_threads) {
|
int n_threads) {
|
||||||
@ -4051,6 +4051,8 @@ static bool llama_eval_internal(
|
|||||||
const int64_t n_embd = hparams.n_embd;
|
const int64_t n_embd = hparams.n_embd;
|
||||||
const int64_t n_vocab = hparams.n_vocab;
|
const int64_t n_vocab = hparams.n_vocab;
|
||||||
|
|
||||||
|
// helpers for smoother batch API transistion
|
||||||
|
// after deprecating the llama_eval calls, these will be removed
|
||||||
std::vector<llama_pos> pos;
|
std::vector<llama_pos> pos;
|
||||||
std::vector<llama_seq_id> seq_id;
|
std::vector<llama_seq_id> seq_id;
|
||||||
|
|
||||||
@ -4076,14 +4078,15 @@ static bool llama_eval_internal(
|
|||||||
// TODO: better strategies can be implemented
|
// TODO: better strategies can be implemented
|
||||||
kv_self.head = 0;
|
kv_self.head = 0;
|
||||||
|
|
||||||
|
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||||
// after enough generations, the benefit from this heuristic disappears
|
// after enough generations, the benefit from this heuristic disappears
|
||||||
// if we start defragmenting the cache, the benefit from this will be more important
|
// if we start defragmenting the cache, the benefit from this will be more important
|
||||||
kv_self.cell_max = llama_kv_cache_cell_max(kv_self);
|
kv_self.cell_max = llama_kv_cache_cell_max(kv_self);
|
||||||
|
//printf("kv_self.cell_max = %d\n", kv_self.cell_max);
|
||||||
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_allocr_reset(lctx.alloc);
|
ggml_allocr_reset(lctx.alloc);
|
||||||
|
|
||||||
@ -7329,7 +7332,7 @@ int llama_eval(
|
|||||||
int n_threads) {
|
int n_threads) {
|
||||||
llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1);
|
llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1);
|
||||||
|
|
||||||
if (!llama_eval_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads)) {
|
if (!llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads)) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
@ -7354,7 +7357,7 @@ int llama_eval_embd(
|
|||||||
|
|
||||||
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, };
|
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, };
|
||||||
|
|
||||||
if (!llama_eval_internal(*ctx, batch, n_threads)) {
|
if (!llama_decode_internal(*ctx, batch, n_threads)) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
@ -7391,7 +7394,7 @@ int llama_decode(
|
|||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
struct llama_batch batch,
|
struct llama_batch batch,
|
||||||
int n_threads) {
|
int n_threads) {
|
||||||
if (!llama_eval_internal(*ctx, batch, n_threads)) {
|
if (!llama_decode_internal(*ctx, batch, n_threads)) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user