mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 13:28:50 +01:00
Additional KL-divergence statistics (#5081)
* perplexity: add top-token probability * perplexity: add additional KL-divergence statistics * perplexity: a better organized KL-divergence statistics output --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
parent
9ecdd12e95
commit
44879ee885
@ -222,13 +222,18 @@ struct kl_divergence_result {
|
|||||||
double sum_kld2 = 0;
|
double sum_kld2 = 0;
|
||||||
double sum_nll_diff = 0;
|
double sum_nll_diff = 0;
|
||||||
double sum_nll_diff2 = 0;
|
double sum_nll_diff2 = 0;
|
||||||
|
size_t n_same_top = 0;
|
||||||
size_t count = 0;
|
size_t count = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
static void log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
|
static double log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
|
||||||
float max_logit = logits[0];
|
float max_logit = logits[0];
|
||||||
|
int imax = 0;
|
||||||
for (int i = 1; i < n_vocab; ++i) {
|
for (int i = 1; i < n_vocab; ++i) {
|
||||||
max_logit = std::max(max_logit, logits[i]);
|
if (logits[i] > max_logit) {
|
||||||
|
max_logit = logits[i];
|
||||||
|
imax = i;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
double sum_exp = 0.0;
|
double sum_exp = 0.0;
|
||||||
for (int i = 0; i < n_vocab; ++i) {
|
for (int i = 0; i < n_vocab; ++i) {
|
||||||
@ -247,8 +252,14 @@ static void log_softmax(int n_vocab, const float * logits, const uint16_t * base
|
|||||||
kld.sum_nll_diff2 += nll*nll;
|
kld.sum_nll_diff2 += nll*nll;
|
||||||
max_logit += log_sum_exp;
|
max_logit += log_sum_exp;
|
||||||
double sum = 0;
|
double sum = 0;
|
||||||
|
int imax_base = -1;
|
||||||
|
float p_log_base_max = 0;
|
||||||
for (int i = 0; i < n_vocab; ++i) {
|
for (int i = 0; i < n_vocab; ++i) {
|
||||||
const float p_log_base = scale*base_log_prob[i] + min_log_prob;
|
const float p_log_base = scale*base_log_prob[i] + min_log_prob;
|
||||||
|
if (i == 0 || p_log_base > p_log_base_max) {
|
||||||
|
p_log_base_max = p_log_base;
|
||||||
|
imax_base = i;
|
||||||
|
}
|
||||||
if (p_log_base > -16.f) {
|
if (p_log_base > -16.f) {
|
||||||
const float p_base = expf(p_log_base);
|
const float p_base = expf(p_log_base);
|
||||||
sum += p_base * (p_log_base - logits[i] + max_logit);
|
sum += p_base * (p_log_base - logits[i] + max_logit);
|
||||||
@ -257,14 +268,17 @@ static void log_softmax(int n_vocab, const float * logits, const uint16_t * base
|
|||||||
kld.sum_kld += sum;
|
kld.sum_kld += sum;
|
||||||
kld.sum_kld2 += sum*sum;
|
kld.sum_kld2 += sum*sum;
|
||||||
++kld.count;
|
++kld.count;
|
||||||
|
if (imax == imax_base) ++kld.n_same_top;
|
||||||
|
return sum;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token,
|
static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token,
|
||||||
std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld) {
|
std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld,
|
||||||
|
float * kld_values) {
|
||||||
std::mutex mutex;
|
std::mutex mutex;
|
||||||
const int nv = 2*((n_vocab + 1)/2) + 4;
|
const int nv = 2*((n_vocab + 1)/2) + 4;
|
||||||
int counter = 0;
|
int counter = 0;
|
||||||
auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv] () {
|
auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv, kld_values] () {
|
||||||
kl_divergence_result local_kld;
|
kl_divergence_result local_kld;
|
||||||
while (true) {
|
while (true) {
|
||||||
std::unique_lock<std::mutex> lock(mutex);
|
std::unique_lock<std::mutex> lock(mutex);
|
||||||
@ -276,11 +290,13 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens
|
|||||||
kld.sum_kld2 += local_kld.sum_kld2;
|
kld.sum_kld2 += local_kld.sum_kld2;
|
||||||
kld.sum_nll_diff += local_kld.sum_nll_diff;
|
kld.sum_nll_diff += local_kld.sum_nll_diff;
|
||||||
kld.sum_nll_diff2 += local_kld.sum_nll_diff2;
|
kld.sum_nll_diff2 += local_kld.sum_nll_diff2;
|
||||||
|
kld.n_same_top += local_kld.n_same_top;
|
||||||
kld.count += local_kld.count;
|
kld.count += local_kld.count;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
lock.unlock();
|
lock.unlock();
|
||||||
log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
|
double v = log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
|
||||||
|
kld_values[i] = (float)v;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
for (auto & w : workers) {
|
for (auto & w : workers) {
|
||||||
@ -1615,7 +1631,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
|
|||||||
in.read((char *)&n_vocab, sizeof(n_vocab));
|
in.read((char *)&n_vocab, sizeof(n_vocab));
|
||||||
in.read((char *)&n_chunk, sizeof(n_chunk));
|
in.read((char *)&n_chunk, sizeof(n_chunk));
|
||||||
if (in.fail()) {
|
if (in.fail()) {
|
||||||
fprintf(stderr, "%s: failed rwading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
|
fprintf(stderr, "%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (n_vocab != llama_n_vocab(llama_get_model(ctx))) {
|
if (n_vocab != llama_n_vocab(llama_get_model(ctx))) {
|
||||||
@ -1634,6 +1650,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
|
|||||||
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
|
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
|
||||||
|
|
||||||
std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
|
std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
|
||||||
|
std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
|
||||||
std::vector<float> logits;
|
std::vector<float> logits;
|
||||||
if (num_batches > 1) {
|
if (num_batches > 1) {
|
||||||
logits.reserve(n_ctx * n_vocab);
|
logits.reserve(n_ctx * n_vocab);
|
||||||
@ -1652,6 +1669,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
kl_divergence_result kld;
|
kl_divergence_result kld;
|
||||||
|
auto kld_ptr = kld_values.data();
|
||||||
|
|
||||||
for (int i = 0; i < n_chunk; ++i) {
|
for (int i = 0; i < n_chunk; ++i) {
|
||||||
const int start = i * n_ctx;
|
const int start = i * n_ctx;
|
||||||
@ -1705,20 +1723,24 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
|
|||||||
}
|
}
|
||||||
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
|
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
|
||||||
|
|
||||||
printf("\nchunk PPL ln(PPL(Q)/PPL(base)) KL-Divergence\n");
|
printf("\nchunk PPL ln(PPL(Q)/PPL(base)) KL-Divergence Same top\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
const int first = n_ctx/2;
|
const int first = n_ctx/2;
|
||||||
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
|
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
|
||||||
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
|
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
|
||||||
workers, log_probs_uint16, kld);
|
workers, log_probs_uint16, kld, kld_ptr);
|
||||||
|
kld_ptr += n_ctx - 1 - first;
|
||||||
|
|
||||||
auto ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
|
auto ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
|
||||||
auto log_ppl_ratio = mean_and_uncertainty(kld.sum_nll_diff, kld.sum_nll_diff2, kld.count);
|
auto log_ppl_ratio = mean_and_uncertainty(kld.sum_nll_diff, kld.sum_nll_diff2, kld.count);
|
||||||
auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
|
auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
|
||||||
|
auto p_top = 1.*kld.n_same_top/kld.count;
|
||||||
|
auto d_p_top = sqrt(p_top*(1 - p_top)/(kld.count - 1));
|
||||||
|
|
||||||
printf("%4d %10.4lf %10.5lf ± %10.5f %10.5f ± %10.5lf\n", i+1, exp(ppl.first),
|
printf("%4d %10.4lf %10.5lf ± %10.5f %10.5f ± %10.5lf %.5f ± %.5f\n", i+1, exp(ppl.first),
|
||||||
log_ppl_ratio.first, log_ppl_ratio.second, kl_div.first, kl_div.second);
|
log_ppl_ratio.first, log_ppl_ratio.second, kl_div.first, kl_div.second,
|
||||||
|
p_top, d_p_top);
|
||||||
|
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
|
|
||||||
@ -1726,6 +1748,35 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
|
|||||||
}
|
}
|
||||||
printf("\n");
|
printf("\n");
|
||||||
|
|
||||||
|
if (kld.count < 100) return; // we do not wish to do statistics on so few values
|
||||||
|
|
||||||
|
std::sort(kld_values.begin(), kld_values.end());
|
||||||
|
|
||||||
|
printf("===== KL-divergence statistics\n");
|
||||||
|
auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
|
||||||
|
printf("Average: %10.6f ±%10.6lf\n", kl_div.first, kl_div.second);
|
||||||
|
auto kld_median = kld_values.size()%2 == 0 ? 0.5f*(kld_values[kld_values.size()/2] + kld_values[kld_values.size()/2-1])
|
||||||
|
: kld_values[kld_values.size()/2];
|
||||||
|
printf("Median : %10.6f\n", kld_median);
|
||||||
|
|
||||||
|
auto percentile = [&kld_values] (float fraction) {
|
||||||
|
if (fraction <= 0) return kld_values.front();
|
||||||
|
if (fraction >= 1) return kld_values.back();
|
||||||
|
float p = fraction*(kld_values.size() - 1);
|
||||||
|
size_t ip = size_t(p); p -= ip;
|
||||||
|
return (1 - p)*kld_values[ip] + p*kld_values[std::min(ip+1, kld_values.size()-1)];
|
||||||
|
};
|
||||||
|
|
||||||
|
printf("Maximum: %10.6f\n", kld_values.back());
|
||||||
|
printf("KLD_99 : %10.6f\n", percentile(0.99f));
|
||||||
|
printf("KLD_95 : %10.6f\n", percentile(0.95f));
|
||||||
|
printf("KLD_90 : %10.6f\n", percentile(0.90f));
|
||||||
|
|
||||||
|
printf("Minimum: %10.6f\n", kld_values.front());
|
||||||
|
printf("KLD_01 : %10.6f\n", percentile(0.01f));
|
||||||
|
printf("KLD_05 : %10.6f\n", percentile(0.05f));
|
||||||
|
printf("KLD_10 : %10.6f\n", percentile(0.10f));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
|
Loading…
Reference in New Issue
Block a user