llama : extend llama_kv_cache API

This commit is contained in:
Georgi Gerganov 2023-09-18 15:53:03 +03:00
parent 6952a460b9
commit 4d76d762ef
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
4 changed files with 84 additions and 32 deletions

View File

@ -79,7 +79,7 @@ bool eval_float(void * model, float * input, int N){
if (n_eval > n_batch) { if (n_eval > n_batch) {
n_eval = n_batch; n_eval = n_batch;
} }
llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, false }; llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, };
if (llama_decode(ctx, batch, params.n_threads)) { if (llama_decode(ctx, batch, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);
return false; return false;

View File

@ -79,7 +79,9 @@ static void write_logfile(
static std::vector<float> softmax(const std::vector<float>& logits) { static std::vector<float> softmax(const std::vector<float>& logits) {
std::vector<float> probs(logits.size()); std::vector<float> probs(logits.size());
float max_logit = logits[0]; float max_logit = logits[0];
for (float v : logits) max_logit = std::max(max_logit, v); for (float v : logits) {
max_logit = std::max(max_logit, v);
}
double sum_exp = 0.0; double sum_exp = 0.0;
for (size_t i = 0; i < logits.size(); i++) { for (size_t i = 0; i < logits.size(); i++) {
// Subtract the maximum logit value from the current logit value for numerical stability // Subtract the maximum logit value from the current logit value for numerical stability
@ -88,15 +90,21 @@ static std::vector<float> softmax(const std::vector<float>& logits) {
sum_exp += exp_logit; sum_exp += exp_logit;
probs[i] = exp_logit; probs[i] = exp_logit;
} }
for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp; for (size_t i = 0; i < probs.size(); i++) {
probs[i] /= sum_exp;
}
return probs; return probs;
} }
static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) { static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
float max_logit = logits[0]; float max_logit = logits[0];
for (int i = 1; i < n_vocab; ++i) max_logit = std::max(max_logit, logits[i]); for (int i = 1; i < n_vocab; ++i) {
max_logit = std::max(max_logit, logits[i]);
}
double sum_exp = 0.0; double sum_exp = 0.0;
for (int i = 0; i < n_vocab; ++i) sum_exp += expf(logits[i] - max_logit); for (int i = 0; i < n_vocab; ++i) {
sum_exp += expf(logits[i] - max_logit);
}
return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp}; return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
} }
@ -107,7 +115,8 @@ static void process_logits(
std::mutex mutex; std::mutex mutex;
int counter = 0; int counter = 0;
auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () { auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
double local_nll = 0, local_nll2 = 0; double local_nll = 0;
double local_nll2 = 0;
while (true) { while (true) {
std::unique_lock<std::mutex> lock(mutex); std::unique_lock<std::mutex> lock(mutex);
int i = counter++; int i = counter++;
@ -125,10 +134,13 @@ static void process_logits(
prob_history[i] = results.prob; prob_history[i] = results.prob;
} }
}; };
for (auto & w : workers) w = std::thread(compute); for (auto & w : workers) {
w = std::thread(compute);
}
compute(); compute();
for (auto & w : workers) w.join(); for (auto & w : workers) {
w.join();
}
} }
static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) { static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
@ -194,6 +206,9 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_keep_seq(ctx, -1);
for (int j = 0; j < num_batches; ++j) { for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch; const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch); const int batch_size = std::min(end - batch_start, n_batch);
@ -319,6 +334,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_keep_seq(ctx, -1);
for (int j = 0; j < num_batches; ++j) { for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch; const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch); const int batch_size = std::min(end - batch_start, n_batch);
@ -549,6 +567,9 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
query_embd.resize(32); query_embd.resize(32);
} }
// clear the KV cache
llama_kv_cache_keep_seq(ctx, -1);
auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads); auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
if (logits.empty()) { if (logits.empty()) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);

View File

@ -1316,7 +1316,8 @@ static bool llama_kv_cache_find_slot(
return true; return true;
} }
void llama_kv_cache_update_cell_max(struct llama_kv_cache & cache) { void llama_kv_cache_update(struct llama_kv_cache & cache) {
// compute new cell_max
cache.cell_max = 0; cache.cell_max = 0;
for (uint32_t i = 0; i < cache.size; i++) { for (uint32_t i = 0; i < cache.size; i++) {
@ -1326,18 +1327,40 @@ void llama_kv_cache_update_cell_max(struct llama_kv_cache & cache) {
} }
} }
void llama_kv_cache_clear(struct llama_kv_cache & cache, int32_t p0, int32_t p1) { void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
cache.head = p0; if (c0 < 0) c0 = 0;
if (c1 < 0) c1 = cache.size;
if (p0 < 0) p0 = 0; for (int32_t i = c0; i < c1; ++i) {
if (p1 < 0) p1 = cache.size;
for (int32_t i = p0; i < p1; ++i) {
cache.cells[i].pos = -1; cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear(); cache.cells[i].seq_id.clear();
} }
llama_kv_cache_update_cell_max(cache); llama_kv_cache_update(cache);
}
void llama_kv_cache_rm_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id)) {
cache.cells[i].seq_id.erase(seq_id);
if (cache.cells[i].seq_id.empty()) {
cache.cells[i].pos = -1;
}
}
}
llama_kv_cache_update(cache);
}
void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
for (uint32_t i = 0; i < cache.size; ++i) {
if (!cache.cells[i].has_seq_id(seq_id)) {
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
}
}
llama_kv_cache_update(cache);
} }
// //
@ -3968,10 +3991,6 @@ static bool llama_eval_internal(
batch.seq_id = seq_id.data(); batch.seq_id = seq_id.data();
} }
if (batch.clear_kv) {
llama_kv_cache_clear(kv_self, 0, -1);
}
if (!llama_kv_cache_find_slot(kv_self, batch)) { if (!llama_kv_cache_find_slot(kv_self, batch)) {
return false; return false;
} }
@ -6803,8 +6822,16 @@ int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
return ctx->kv_self.head; return ctx->kv_self.head;
} }
void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1) { void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1) {
llama_kv_cache_clear(ctx->kv_self, p0, p1); llama_kv_cache_rm_tokens(ctx->kv_self, c0, c1);
}
void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id) {
llama_kv_cache_rm_seq(ctx->kv_self, seq_id);
}
void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id) {
llama_kv_cache_keep_seq(ctx->kv_self, seq_id);
} }
// Returns the *maximum* size of the state // Returns the *maximum* size of the state
@ -7203,7 +7230,7 @@ int llama_eval(
uint32_t n_tokens, uint32_t n_tokens,
int n_past, int n_past,
int n_threads) { int n_threads) {
llama_kv_cache_clear(ctx->kv_self, n_past, -1); llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1);
if (!llama_eval_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads)) { if (!llama_eval_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads)) {
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
@ -7226,9 +7253,9 @@ int llama_eval_embd(
uint32_t n_tokens, uint32_t n_tokens,
int n_past, int n_past,
int n_threads) { int n_threads) {
llama_kv_cache_clear(ctx->kv_self, n_past, -1); llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1);
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, n_past, 1, 0, n_past == 0, }; llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, n_past, 1, 0, };
if (!llama_eval_internal(*ctx, batch, n_threads)) { if (!llama_eval_internal(*ctx, batch, n_threads)) {
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
@ -7259,7 +7286,6 @@ struct llama_batch llama_batch_get_one(
/*all_pos_0 =*/ pos_0, /*all_pos_0 =*/ pos_0,
/*all_pos_1 =*/ 1, /*all_pos_1 =*/ 1,
/*all_seq_id =*/ seq_id, /*all_seq_id =*/ seq_id,
/*clear_kv =*/ pos_0 == 0,
}; };
} }

11
llama.h
View File

@ -84,8 +84,6 @@ extern "C" {
llama_pos all_pos_0; // used if pos == NULL llama_pos all_pos_0; // used if pos == NULL
llama_pos all_pos_1; // used if pos == NULL llama_pos all_pos_1; // used if pos == NULL
llama_seq_id all_seq_id; // used if seq_id == NULL llama_seq_id all_seq_id; // used if seq_id == NULL
bool clear_kv; // if true, clear the entire KV cache. common usage for perplexity calculations
} llama_seq; } llama_seq;
enum llama_log_level { enum llama_log_level {
@ -323,7 +321,14 @@ extern "C" {
LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx), LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
"avoid using this, it will be removed in the future, instead - count the tokens in user code"); "avoid using this, it will be removed in the future, instead - count the tokens in user code");
LLAMA_API void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1); // Remove all tokens between cells [c0, c1)
LLAMA_API void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1);
// Removes all tokens that belong to the specified sequence
LLAMA_API void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id);
// Removes all tokens that do not belong to the specified sequence
LLAMA_API void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id);
// //
// State / sessions // State / sessions