mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-02-02 15:02:47 +01:00)
llama : extend llama_kv_cache API
parent 6952a460b9
commit 4d76d762ef
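This commit replaces the single llama_kv_clear(ctx, p0, p1) call with a small family of KV-cache operations: llama_kv_cache_rm_tokens removes a range of cells, llama_kv_cache_rm_seq drops all cells that belong to one sequence, and llama_kv_cache_keep_seq drops every cell that does not belong to one sequence. A minimal sketch of the intended call pattern, assuming an already-created llama_context; the helper name and the variables finished_seq_id / n_past are illustrative, not part of the commit:

#include "llama.h"

// Sketch only (not part of this commit): how the extended KV-cache API is
// meant to be driven by a caller that already owns a llama_context.
static void kv_cache_examples(struct llama_context * ctx,
                              llama_seq_id finished_seq_id,
                              int32_t n_past) {
    // Start a fresh, independent evaluation: no ordinary sequence uses id -1,
    // so "keep only seq -1" invalidates every cell. This is exactly what the
    // perplexity/hellaswag examples below do before each chunk.
    llama_kv_cache_keep_seq(ctx, -1);

    // A sequence is finished: drop only its cells; cells shared with other
    // sequences keep their remaining sequence ids.
    llama_kv_cache_rm_seq(ctx, finished_seq_id);

    // Rewind: remove cells [n_past, end), mirroring what llama_eval() now
    // does internally via llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1).
    llama_kv_cache_rm_tokens(ctx, n_past, -1);
}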
@@ -79,7 +79,7 @@ bool eval_float(void * model, float * input, int N){
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, false };
+        llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, };
         if (llama_decode(ctx, batch, params.n_threads)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return false;
@@ -79,7 +79,9 @@ static void write_logfile(
 static std::vector<float> softmax(const std::vector<float>& logits) {
     std::vector<float> probs(logits.size());
     float max_logit = logits[0];
-    for (float v : logits) max_logit = std::max(max_logit, v);
+    for (float v : logits) {
+        max_logit = std::max(max_logit, v);
+    }
     double sum_exp = 0.0;
     for (size_t i = 0; i < logits.size(); i++) {
         // Subtract the maximum logit value from the current logit value for numerical stability
@@ -88,15 +90,21 @@ static std::vector<float> softmax(const std::vector<float>& logits) {
         sum_exp += exp_logit;
         probs[i] = exp_logit;
     }
-    for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp;
+    for (size_t i = 0; i < probs.size(); i++) {
+        probs[i] /= sum_exp;
+    }
     return probs;
 }

 static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
     float max_logit = logits[0];
-    for (int i = 1; i < n_vocab; ++i) max_logit = std::max(max_logit, logits[i]);
+    for (int i = 1; i < n_vocab; ++i) {
+        max_logit = std::max(max_logit, logits[i]);
+    }
     double sum_exp = 0.0;
-    for (int i = 0; i < n_vocab; ++i) sum_exp += expf(logits[i] - max_logit);
+    for (int i = 0; i < n_vocab; ++i) {
+        sum_exp += expf(logits[i] - max_logit);
+    }
     return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
 }

@@ -107,7 +115,8 @@ static void process_logits(
     std::mutex mutex;
     int counter = 0;
     auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
-        double local_nll = 0, local_nll2 = 0;
+        double local_nll  = 0;
+        double local_nll2 = 0;
         while (true) {
             std::unique_lock<std::mutex> lock(mutex);
             int i = counter++;
@@ -125,10 +134,13 @@ static void process_logits(
             prob_history[i] = results.prob;
         }
     };
-    for (auto & w : workers) w = std::thread(compute);
+    for (auto & w : workers) {
+        w = std::thread(compute);
+    }
     compute();
-    for (auto & w : workers) w.join();
+    for (auto & w : workers) {
+        w.join();
+    }
 }

 static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
@@ -194,6 +206,9 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &

         const auto t_start = std::chrono::high_resolution_clock::now();

+        // clear the KV cache
+        llama_kv_cache_keep_seq(ctx, -1);
+
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);
@@ -319,6 +334,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par

         const auto t_start = std::chrono::high_resolution_clock::now();

+        // clear the KV cache
+        llama_kv_cache_keep_seq(ctx, -1);
+
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);
@@ -549,6 +567,9 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             query_embd.resize(32);
         }

+        // clear the KV cache
+        llama_kv_cache_keep_seq(ctx, -1);
+
         auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
         if (logits.empty()) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
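In all three call sites the next chunk (or the next HellaSwag query) must not see cached keys/values from the previous one, so the cache is wiped first with llama_kv_cache_keep_seq(ctx, -1). A condensed, hedged sketch of that call-site pattern; the helper name and loop are illustrative, the real code feeds the tokens through llama_decode / the example's eval helpers:

#include <algorithm>
#include "llama.h"

// Illustrative only: evaluate one independent chunk from an empty cache,
// as the perplexity and hellaswag examples now do.
static bool eval_independent_chunk(struct llama_context * ctx,
                                   const llama_token * tokens, int n_tokens,
                                   int n_batch) {
    // clear the KV cache: keep_seq(-1) frees every cell, because no cell is
    // ever assigned seq_id == -1
    llama_kv_cache_keep_seq(ctx, -1);

    for (int start = 0; start < n_tokens; start += n_batch) {
        const int n_eval = std::min(n_batch, n_tokens - start);
        (void) n_eval;
        // ... feed tokens[start .. start + n_eval) to the model here ...
    }
    return true;
}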
llama.cpp (62 changed lines)
@@ -1316,7 +1316,8 @@ static bool llama_kv_cache_find_slot(
     return true;
 }

-void llama_kv_cache_update_cell_max(struct llama_kv_cache & cache) {
+void llama_kv_cache_update(struct llama_kv_cache & cache) {
+    // compute new cell_max
     cache.cell_max = 0;

     for (uint32_t i = 0; i < cache.size; i++) {
@@ -1326,18 +1327,40 @@ void llama_kv_cache_update_cell_max(struct llama_kv_cache & cache) {
         }
     }
 }

-void llama_kv_cache_clear(struct llama_kv_cache & cache, int32_t p0, int32_t p1) {
-    cache.head = p0;
-
-    if (p0 < 0) p0 = 0;
-    if (p1 < 0) p1 = cache.size;
-
-    for (int32_t i = p0; i < p1; ++i) {
+void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
+    if (c0 < 0) c0 = 0;
+    if (c1 < 0) c1 = cache.size;
+
+    for (int32_t i = c0; i < c1; ++i) {
         cache.cells[i].pos = -1;
         cache.cells[i].seq_id.clear();
     }

-    llama_kv_cache_update_cell_max(cache);
+    llama_kv_cache_update(cache);
+}
+
+void llama_kv_cache_rm_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].has_seq_id(seq_id)) {
+            cache.cells[i].seq_id.erase(seq_id);
+            if (cache.cells[i].seq_id.empty()) {
+                cache.cells[i].pos = -1;
+            }
+        }
+    }
+
+    llama_kv_cache_update(cache);
+}
+
+void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (!cache.cells[i].has_seq_id(seq_id)) {
+            cache.cells[i].pos = -1;
+            cache.cells[i].seq_id.clear();
+        }
+    }
+
+    llama_kv_cache_update(cache);
 }

 //
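The new sequence-level operations work directly on the cache cells: each cell carries a position and a set of sequence ids, and has_seq_id tests membership. A toy model of that logic, assuming a simplified cell type (the real llama_kv_cell lives in llama.cpp and carries more bookkeeping than shown here):

#include <cstdint>
#include <set>
#include <vector>

// Simplified stand-in for llama_kv_cell / llama_kv_cache (assumption: the
// real structs also track head, size, cell_max, etc.).
using toy_seq_id = int32_t;

struct toy_kv_cell {
    int32_t pos = -1;              // -1 means the cell is free
    std::set<toy_seq_id> seq_id;   // sequences that reference this cell

    bool has_seq_id(toy_seq_id id) const { return seq_id.count(id) > 0; }
};

// Mirrors llama_kv_cache_keep_seq: any cell NOT referenced by `id` is freed.
// Passing an id that no cell carries (e.g. -1) therefore frees every cell,
// which is why the examples use keep_seq(ctx, -1) as "clear the cache".
static void keep_seq(std::vector<toy_kv_cell> & cells, toy_seq_id id) {
    for (auto & c : cells) {
        if (!c.has_seq_id(id)) {
            c.pos = -1;
            c.seq_id.clear();
        }
    }
}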
@@ -3968,10 +3991,6 @@ static bool llama_eval_internal(
         batch.seq_id = seq_id.data();
     }

-    if (batch.clear_kv) {
-        llama_kv_cache_clear(kv_self, 0, -1);
-    }
-
     if (!llama_kv_cache_find_slot(kv_self, batch)) {
         return false;
     }
@@ -6803,8 +6822,16 @@ int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
     return ctx->kv_self.head;
 }

-void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1) {
-    llama_kv_cache_clear(ctx->kv_self, p0, p1);
+void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1) {
+    llama_kv_cache_rm_tokens(ctx->kv_self, c0, c1);
+}
+
+void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id) {
+    llama_kv_cache_rm_seq(ctx->kv_self, seq_id);
+}
+
+void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id) {
+    llama_kv_cache_keep_seq(ctx->kv_self, seq_id);
 }

 // Returns the *maximum* size of the state
@@ -7203,7 +7230,7 @@ int llama_eval(
                 uint32_t   n_tokens,
                      int   n_past,
                      int   n_threads) {
-    llama_kv_cache_clear(ctx->kv_self, n_past, -1);
+    llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1);

     if (!llama_eval_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads)) {
         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
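llama_eval() (above) and llama_eval_embd() (next hunk) keep the old n_past contract: before decoding, everything cached from cell n_past onward is discarded, the only change being that the call is now llama_kv_cache_rm_tokens instead of the removed llama_kv_cache_clear. A caller-side sketch of the same rewind; the helper name is illustrative, not part of the commit:

#include "llama.h"

// Illustrative helper: discard cached state from cell n_past onward before
// re-evaluating, the same call llama_eval()/llama_eval_embd() make internally.
static void rewind_kv_cache(struct llama_context * ctx, int32_t n_past) {
    // c1 = -1 means "to the end of the cache" (see llama_kv_cache_rm_tokens above)
    llama_kv_cache_rm_tokens(ctx, n_past, -1);
}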
@@ -7226,9 +7253,9 @@ int llama_eval_embd(
                 uint32_t   n_tokens,
                      int   n_past,
                      int   n_threads) {
-    llama_kv_cache_clear(ctx->kv_self, n_past, -1);
+    llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1);

-    llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, n_past, 1, 0, n_past == 0, };
+    llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, n_past, 1, 0, };

     if (!llama_eval_internal(*ctx, batch, n_threads)) {
         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
@@ -7259,7 +7286,6 @@ struct llama_batch llama_batch_get_one(
         /*all_pos_0  =*/ pos_0,
         /*all_pos_1  =*/ 1,
         /*all_seq_id =*/ seq_id,
-        /*clear_kv   =*/ pos_0 == 0,
     };
 }

llama.h (11 changed lines)
@@ -84,8 +84,6 @@ extern "C" {
         llama_pos all_pos_0;     // used if pos == NULL
         llama_pos all_pos_1;     // used if pos == NULL
         llama_seq_id all_seq_id; // used if seq_id == NULL
-
-        bool clear_kv; // if true, clear the entire KV cache. common usage for perplexity calculations
     } llama_seq;

     enum llama_log_level {
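With clear_kv gone from the struct, the positional llama_batch initializers in the examples and in llama_eval_embd simply drop their trailing boolean. Read against the member comments visible in this hunk and in llama_batch_get_one, the eval_embd initializer lines up roughly as below; only all_pos_0 / all_pos_1 / all_seq_id are confirmed by this capture, the leading field names and the wrapper are assumptions for illustration:

#include "llama.h"

// Hypothetical annotated form of the initializer used in llama_eval_embd
// after this commit (field order before all_pos_0 is assumed, not shown here).
static llama_batch make_embd_batch(uint32_t n_tokens, float * embd, llama_pos n_past) {
    llama_batch batch = {
        /*n_tokens   =*/ n_tokens,
        /*token      =*/ nullptr,   // embeddings are fed directly, no token ids
        /*embd       =*/ embd,
        /*pos        =*/ nullptr,   // positions derived from all_pos_0 / all_pos_1
        /*seq_id     =*/ nullptr,   // every token goes to all_seq_id
        /*all_pos_0  =*/ n_past,
        /*all_pos_1  =*/ 1,
        /*all_seq_id =*/ 0,
        // no trailing clear_kv any more: callers clear the cache explicitly,
        // e.g. llama_kv_cache_keep_seq(ctx, -1)
    };
    return batch;
}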
@@ -323,7 +321,14 @@ extern "C" {
     LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
             "avoid using this, it will be removed in the future, instead - count the tokens in user code");

-    LLAMA_API void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1);
+    // Remove all tokens between cells [c0, c1)
+    LLAMA_API void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1);
+
+    // Removes all tokens that belong to the specified sequence
+    LLAMA_API void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id);
+
+    // Removes all tokens that do not belong to the specified sequence
+    LLAMA_API void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id);

     //
     // State / sessions