llama : extend llama_kv_cache API

This commit is contained in:
Georgi Gerganov 2023-09-18 15:53:03 +03:00
parent 6952a460b9
commit 4d76d762ef
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
4 changed files with 84 additions and 32 deletions

View File

@ -79,7 +79,7 @@ bool eval_float(void * model, float * input, int N){
if (n_eval > n_batch) {
n_eval = n_batch;
}
llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, false };
llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, };
if (llama_decode(ctx, batch, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;

View File

@ -79,7 +79,9 @@ static void write_logfile(
static std::vector<float> softmax(const std::vector<float>& logits) {
std::vector<float> probs(logits.size());
float max_logit = logits[0];
for (float v : logits) max_logit = std::max(max_logit, v);
for (float v : logits) {
max_logit = std::max(max_logit, v);
}
double sum_exp = 0.0;
for (size_t i = 0; i < logits.size(); i++) {
// Subtract the maximum logit value from the current logit value for numerical stability
@ -88,15 +90,21 @@ static std::vector<float> softmax(const std::vector<float>& logits) {
sum_exp += exp_logit;
probs[i] = exp_logit;
}
for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp;
for (size_t i = 0; i < probs.size(); i++) {
probs[i] /= sum_exp;
}
return probs;
}
static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
float max_logit = logits[0];
for (int i = 1; i < n_vocab; ++i) max_logit = std::max(max_logit, logits[i]);
for (int i = 1; i < n_vocab; ++i) {
max_logit = std::max(max_logit, logits[i]);
}
double sum_exp = 0.0;
for (int i = 0; i < n_vocab; ++i) sum_exp += expf(logits[i] - max_logit);
for (int i = 0; i < n_vocab; ++i) {
sum_exp += expf(logits[i] - max_logit);
}
return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
}
@ -107,7 +115,8 @@ static void process_logits(
std::mutex mutex;
int counter = 0;
auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
double local_nll = 0, local_nll2 = 0;
double local_nll = 0;
double local_nll2 = 0;
while (true) {
std::unique_lock<std::mutex> lock(mutex);
int i = counter++;
@ -125,10 +134,13 @@ static void process_logits(
prob_history[i] = results.prob;
}
};
for (auto & w : workers) w = std::thread(compute);
for (auto & w : workers) {
w = std::thread(compute);
}
compute();
for (auto & w : workers) w.join();
for (auto & w : workers) {
w.join();
}
}
static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
@ -194,6 +206,9 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_keep_seq(ctx, -1);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
@ -319,6 +334,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_keep_seq(ctx, -1);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
@ -549,6 +567,9 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
query_embd.resize(32);
}
// clear the KV cache
llama_kv_cache_keep_seq(ctx, -1);
auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
if (logits.empty()) {
fprintf(stderr, "%s : failed to eval\n", __func__);

View File

@ -1316,7 +1316,8 @@ static bool llama_kv_cache_find_slot(
return true;
}
void llama_kv_cache_update_cell_max(struct llama_kv_cache & cache) {
void llama_kv_cache_update(struct llama_kv_cache & cache) {
// compute new cell_max
cache.cell_max = 0;
for (uint32_t i = 0; i < cache.size; i++) {
@ -1326,18 +1327,40 @@ void llama_kv_cache_update_cell_max(struct llama_kv_cache & cache) {
}
}
void llama_kv_cache_clear(struct llama_kv_cache & cache, int32_t p0, int32_t p1) {
cache.head = p0;
void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
if (c0 < 0) c0 = 0;
if (c1 < 0) c1 = cache.size;
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = cache.size;
for (int32_t i = p0; i < p1; ++i) {
for (int32_t i = c0; i < c1; ++i) {
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
}
llama_kv_cache_update_cell_max(cache);
llama_kv_cache_update(cache);
}
void llama_kv_cache_rm_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id)) {
cache.cells[i].seq_id.erase(seq_id);
if (cache.cells[i].seq_id.empty()) {
cache.cells[i].pos = -1;
}
}
}
llama_kv_cache_update(cache);
}
void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
for (uint32_t i = 0; i < cache.size; ++i) {
if (!cache.cells[i].has_seq_id(seq_id)) {
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
}
}
llama_kv_cache_update(cache);
}
//
@ -3968,10 +3991,6 @@ static bool llama_eval_internal(
batch.seq_id = seq_id.data();
}
if (batch.clear_kv) {
llama_kv_cache_clear(kv_self, 0, -1);
}
if (!llama_kv_cache_find_slot(kv_self, batch)) {
return false;
}
@ -6803,8 +6822,16 @@ int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
return ctx->kv_self.head;
}
void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1) {
llama_kv_cache_clear(ctx->kv_self, p0, p1);
void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1) {
llama_kv_cache_rm_tokens(ctx->kv_self, c0, c1);
}
void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id) {
llama_kv_cache_rm_seq(ctx->kv_self, seq_id);
}
void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id) {
llama_kv_cache_keep_seq(ctx->kv_self, seq_id);
}
// Returns the *maximum* size of the state
@ -7203,7 +7230,7 @@ int llama_eval(
uint32_t n_tokens,
int n_past,
int n_threads) {
llama_kv_cache_clear(ctx->kv_self, n_past, -1);
llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1);
if (!llama_eval_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads)) {
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
@ -7226,9 +7253,9 @@ int llama_eval_embd(
uint32_t n_tokens,
int n_past,
int n_threads) {
llama_kv_cache_clear(ctx->kv_self, n_past, -1);
llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1);
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, n_past, 1, 0, n_past == 0, };
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, n_past, 1, 0, };
if (!llama_eval_internal(*ctx, batch, n_threads)) {
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
@ -7259,7 +7286,6 @@ struct llama_batch llama_batch_get_one(
/*all_pos_0 =*/ pos_0,
/*all_pos_1 =*/ 1,
/*all_seq_id =*/ seq_id,
/*clear_kv =*/ pos_0 == 0,
};
}

11
llama.h
View File

@ -84,8 +84,6 @@ extern "C" {
llama_pos all_pos_0; // used if pos == NULL
llama_pos all_pos_1; // used if pos == NULL
llama_seq_id all_seq_id; // used if seq_id == NULL
bool clear_kv; // if true, clear the entire KV cache. common usage for perplexity calculations
} llama_seq;
enum llama_log_level {
@ -323,7 +321,14 @@ extern "C" {
LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
"avoid using this, it will be removed in the future, instead - count the tokens in user code");
LLAMA_API void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1);
// Remove all tokens between cells [c0, c1)
LLAMA_API void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1);
// Removes all tokens that belong to the specified sequence
LLAMA_API void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id);
// Removes all tokens that do not belong to the specified sequence
LLAMA_API void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id);
//
// State / sessions