mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 13:28:50 +01:00
KL-divergence (#5076)
* kl-divergence: be able to save all logits to a file * Add ability to compute KL-divergence --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
parent
780e24a22e
commit
6f9939d119
@ -672,6 +672,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
||||
if (params.logdir.back() != DIRECTORY_SEPARATOR) {
|
||||
params.logdir += DIRECTORY_SEPARATOR;
|
||||
}
|
||||
} else if (arg == "--save-all-logits" || arg == "--kl-divergence-base") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.logits_file = argv[i];
|
||||
} else if (arg == "--perplexity" || arg == "--all-logits") {
|
||||
params.logits_all = true;
|
||||
} else if (arg == "--ppl-stride") {
|
||||
@ -716,6 +722,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
||||
break;
|
||||
}
|
||||
params.multiple_choice_tasks = std::stoi(argv[i]);
|
||||
} else if (arg == "--kl-divergence") {
|
||||
params.kl_divergence = true;
|
||||
} else if (arg == "--ignore-eos") {
|
||||
params.ignore_eos = true;
|
||||
} else if (arg == "--no-penalize-nl") {
|
||||
@ -967,6 +975,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
||||
printf(" --winogrande-tasks N number of tasks to use when computing the Winogrande score (default: %zu)\n", params.winogrande_tasks);
|
||||
printf(" --multiple-choice compute multiple choice score over random tasks from datafile supplied with -f\n");
|
||||
printf(" --multiple-choice-tasks N number of tasks to use when computing the multiple choice score (default: %zu)\n", params.winogrande_tasks);
|
||||
printf(" --kl-divergence computes KL-divergence to logits provided via --kl-divergence-base");
|
||||
printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
|
||||
printf(" --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
|
||||
printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
|
||||
|
@ -91,6 +91,7 @@ struct gpt_params {
|
||||
std::string input_suffix = ""; // string to suffix user inputs with
|
||||
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
|
||||
std::string logdir = ""; // directory in which to save YAML log files
|
||||
std::string logits_file = ""; // file for saving *all* logits
|
||||
|
||||
std::vector<llama_model_kv_override> kv_overrides;
|
||||
|
||||
@ -111,6 +112,8 @@ struct gpt_params {
|
||||
bool multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt
|
||||
size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed
|
||||
|
||||
bool kl_divergence = false; // compute KL-divergence
|
||||
|
||||
bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS
|
||||
bool random_prompt = false; // do not randomize prompt if none provided
|
||||
bool use_color = false; // use color to distinguish generations and inputs
|
||||
|
@ -112,6 +112,43 @@ static results_log_softmax log_softmax(int n_vocab, const float * logits, int to
|
||||
return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
|
||||
}
|
||||
|
||||
static inline int nearest_int(float fval) {
|
||||
//assert(fval <= 4194303.f);
|
||||
float val = fval + 12582912.f;
|
||||
int i; memcpy(&i, &val, sizeof(int));
|
||||
return (i & 0x007fffff) - 0x00400000;
|
||||
}
|
||||
|
||||
static double log_softmax(int n_vocab, const float * logits, uint16_t * log_prob, int tok) {
|
||||
float max_logit = logits[0];
|
||||
float min_logit = logits[0];
|
||||
for (int i = 1; i < n_vocab; ++i) {
|
||||
max_logit = std::max(max_logit, logits[i]);
|
||||
min_logit = std::min(min_logit, logits[i]);
|
||||
}
|
||||
min_logit = std::max(min_logit, max_logit - 16);
|
||||
double sum_exp = 0.0;
|
||||
for (int i = 0; i < n_vocab; ++i) {
|
||||
sum_exp += expf(logits[i] - max_logit);
|
||||
}
|
||||
const float log_sum_exp = log(sum_exp);
|
||||
const float min_log_prob = min_logit - max_logit - log_sum_exp;
|
||||
const float scale = (max_logit - min_logit)/65535.f;
|
||||
float * d = (float *)log_prob;
|
||||
d[0] = scale;
|
||||
d[1] = min_log_prob;
|
||||
log_prob += 4;
|
||||
if (scale) {
|
||||
const float inv_scale = 1/scale;
|
||||
for (int i = 0; i < n_vocab; ++i) {
|
||||
log_prob[i] = logits[i] > min_logit ? nearest_int(inv_scale*(logits[i] - min_logit)) : 0;
|
||||
}
|
||||
} else {
|
||||
std::memset(log_prob, 0, n_vocab*sizeof(uint16_t));
|
||||
}
|
||||
return max_logit + log_sum_exp - logits[tok];
|
||||
}
|
||||
|
||||
static void process_logits(
|
||||
int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
|
||||
double & nll, double & nll2, float * logit_history, float * prob_history
|
||||
@ -147,6 +184,114 @@ static void process_logits(
|
||||
}
|
||||
}
|
||||
|
||||
static void process_logits(std::ostream& out, int n_vocab, const float * logits, const int * tokens, int n_token,
|
||||
std::vector<std::thread> & workers, std::vector<uint16_t> & log_probs, double & nll, double & nll2) {
|
||||
std::mutex mutex;
|
||||
const int nv = 2*((n_vocab + 1)/2) + 4;
|
||||
int counter = 0;
|
||||
auto compute = [&mutex, &counter, &log_probs, &nll, &nll2, n_vocab, logits, tokens, n_token, nv] () {
|
||||
double local_nll = 0;
|
||||
double local_nll2 = 0;
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
int i = counter++;
|
||||
if (i >= n_token) {
|
||||
nll += local_nll; nll2 += local_nll2;
|
||||
break;
|
||||
}
|
||||
lock.unlock();
|
||||
const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
|
||||
local_nll += v;
|
||||
local_nll2 += v*v;
|
||||
}
|
||||
};
|
||||
for (auto & w : workers) {
|
||||
w = std::thread(compute);
|
||||
}
|
||||
compute();
|
||||
for (auto & w : workers) {
|
||||
w.join();
|
||||
}
|
||||
out.write((const char *)log_probs.data(), n_token*nv*sizeof(uint16_t));
|
||||
}
|
||||
|
||||
struct kl_divergence_result {
|
||||
double sum_nll = 0;
|
||||
double sum_nll2 = 0;
|
||||
double sum_kld = 0;
|
||||
double sum_kld2 = 0;
|
||||
double sum_nll_diff = 0;
|
||||
double sum_nll_diff2 = 0;
|
||||
size_t count = 0;
|
||||
};
|
||||
|
||||
static void log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
|
||||
float max_logit = logits[0];
|
||||
for (int i = 1; i < n_vocab; ++i) {
|
||||
max_logit = std::max(max_logit, logits[i]);
|
||||
}
|
||||
double sum_exp = 0.0;
|
||||
for (int i = 0; i < n_vocab; ++i) {
|
||||
sum_exp += expf(logits[i] - max_logit);
|
||||
}
|
||||
const float log_sum_exp = log(sum_exp);
|
||||
const float * d = (const float *)base_log_prob;
|
||||
const float scale = d[0];
|
||||
const float min_log_prob = d[1];
|
||||
base_log_prob += 4;
|
||||
float nll = max_logit + log_sum_exp - logits[tok];
|
||||
kld.sum_nll += nll;
|
||||
kld.sum_nll2 += nll*nll;
|
||||
nll += (scale*base_log_prob[tok] + min_log_prob);
|
||||
kld.sum_nll_diff += nll;
|
||||
kld.sum_nll_diff2 += nll*nll;
|
||||
max_logit += log_sum_exp;
|
||||
double sum = 0;
|
||||
for (int i = 0; i < n_vocab; ++i) {
|
||||
const float p_log_base = scale*base_log_prob[i] + min_log_prob;
|
||||
if (p_log_base > -16.f) {
|
||||
const float p_base = expf(p_log_base);
|
||||
sum += p_base * (p_log_base - logits[i] + max_logit);
|
||||
}
|
||||
}
|
||||
kld.sum_kld += sum;
|
||||
kld.sum_kld2 += sum*sum;
|
||||
++kld.count;
|
||||
}
|
||||
|
||||
static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token,
|
||||
std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld) {
|
||||
std::mutex mutex;
|
||||
const int nv = 2*((n_vocab + 1)/2) + 4;
|
||||
int counter = 0;
|
||||
auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv] () {
|
||||
kl_divergence_result local_kld;
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
int i = counter++;
|
||||
if (i >= n_token) {
|
||||
kld.sum_nll += local_kld.sum_nll;
|
||||
kld.sum_nll2 += local_kld.sum_nll2;
|
||||
kld.sum_kld += local_kld.sum_kld;
|
||||
kld.sum_kld2 += local_kld.sum_kld2;
|
||||
kld.sum_nll_diff += local_kld.sum_nll_diff;
|
||||
kld.sum_nll_diff2 += local_kld.sum_nll_diff2;
|
||||
kld.count += local_kld.count;
|
||||
break;
|
||||
}
|
||||
lock.unlock();
|
||||
log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
|
||||
}
|
||||
};
|
||||
for (auto & w : workers) {
|
||||
w = std::thread(compute);
|
||||
}
|
||||
compute();
|
||||
for (auto & w : workers) {
|
||||
w.join();
|
||||
}
|
||||
}
|
||||
|
||||
static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
|
||||
// Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
|
||||
// Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
|
||||
@ -294,6 +439,18 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
|
||||
std::ofstream logits_stream;
|
||||
if (!params.logits_file.empty()) {
|
||||
logits_stream.open(params.logits_file.c_str());
|
||||
if (!logits_stream.is_open()) {
|
||||
fprintf(stderr, "%s: failed to open %s for writing\n", __func__, params.logits_file.c_str());
|
||||
return {};
|
||||
}
|
||||
fprintf(stderr, "%s: saving all logits to %s\n", __func__, params.logits_file.c_str());
|
||||
logits_stream.write("_logits_", 8);
|
||||
logits_stream.write((const char *)&n_ctx, sizeof(n_ctx));
|
||||
}
|
||||
|
||||
auto tim1 = std::chrono::high_resolution_clock::now();
|
||||
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
|
||||
|
||||
@ -336,6 +493,15 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||
|
||||
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
|
||||
|
||||
std::vector<uint16_t> log_probs;
|
||||
if (!params.logits_file.empty()) {
|
||||
logits_stream.write((const char *)&n_vocab, sizeof(n_vocab));
|
||||
logits_stream.write((const char *)&n_chunk, sizeof(n_chunk));
|
||||
logits_stream.write((const char *)tokens.data(), n_chunk*n_ctx*sizeof(tokens[0]));
|
||||
const int nv = 2*((n_vocab + 1)/2) + 4;
|
||||
log_probs.resize(n_ctx * nv);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_chunk; ++i) {
|
||||
const int start = i * n_ctx;
|
||||
const int end = start + n_ctx;
|
||||
@ -398,8 +564,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||
// process the entire prompt.
|
||||
const int first = n_ctx/2;
|
||||
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
|
||||
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
|
||||
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
|
||||
if (!params.logits_file.empty()) {
|
||||
process_logits(logits_stream, n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
|
||||
workers, log_probs, nll, nll2);
|
||||
} else {
|
||||
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
|
||||
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
|
||||
}
|
||||
count += n_ctx - first - 1;
|
||||
|
||||
// perplexity is e^(average negative log-likelihood)
|
||||
@ -1414,6 +1585,148 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
static void kl_divergence(llama_context * ctx, const gpt_params & params) {
|
||||
if (params.logits_file.empty()) {
|
||||
fprintf(stderr, "%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__);
|
||||
return;
|
||||
}
|
||||
std::ifstream in(params.logits_file.c_str(), std::ios::binary);
|
||||
if (!in) {
|
||||
fprintf(stderr, "%s: failed to open %s\n", __func__, params.logits_file.c_str());
|
||||
return;
|
||||
}
|
||||
{
|
||||
char check[9]; check[8] = 0;
|
||||
in.read(check, 8);
|
||||
if (in.fail() || strncmp("_logits_", check, 8) != 0) {
|
||||
fprintf(stderr, "%s: %s does not look like a file containing log-probabilities\n", __func__, params.logits_file.c_str());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t n_ctx;
|
||||
in.read((char *)&n_ctx, sizeof(n_ctx));
|
||||
if (n_ctx > llama_n_ctx(ctx)) {
|
||||
fprintf(stderr, "%s: %s has been computed with %d, while the current context is %d. Increase it with -c and retry\n",
|
||||
__func__, params.logits_file.c_str(), n_ctx, params.n_ctx);
|
||||
}
|
||||
|
||||
int n_vocab, n_chunk;
|
||||
in.read((char *)&n_vocab, sizeof(n_vocab));
|
||||
in.read((char *)&n_chunk, sizeof(n_chunk));
|
||||
if (in.fail()) {
|
||||
fprintf(stderr, "%s: failed rwading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
|
||||
return;
|
||||
}
|
||||
if (n_vocab != llama_n_vocab(llama_get_model(ctx))) {
|
||||
fprintf(stderr, "%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx)));
|
||||
}
|
||||
|
||||
std::vector<llama_token> tokens(n_ctx * n_chunk);
|
||||
if (in.read((char *)tokens.data(), tokens.size()*sizeof(tokens[0])).fail()) {
|
||||
fprintf(stderr, "%s: failed reading evaluation tokens from %s\n", __func__, params.logits_file.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
const int n_batch = params.n_batch;
|
||||
const int num_batches = (n_ctx + n_batch - 1)/n_batch;
|
||||
const int nv = 2*((n_vocab + 1)/2) + 4;
|
||||
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
|
||||
|
||||
std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
|
||||
std::vector<float> logits;
|
||||
if (num_batches > 1) {
|
||||
logits.reserve(n_ctx * n_vocab);
|
||||
}
|
||||
|
||||
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
|
||||
|
||||
auto mean_and_uncertainty = [] (double sum, double sum2, size_t count) {
|
||||
if (count < 1) {
|
||||
return std::make_pair(0., 0.);
|
||||
}
|
||||
double f = sum/count;
|
||||
double df = sum2/count - f*f;
|
||||
df = df > 0 && count > 10 ? sqrt(df/(count-1)) : 0.;
|
||||
return std::make_pair(f, df);
|
||||
};
|
||||
|
||||
kl_divergence_result kld;
|
||||
|
||||
for (int i = 0; i < n_chunk; ++i) {
|
||||
const int start = i * n_ctx;
|
||||
const int end = start + n_ctx;
|
||||
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
|
||||
fprintf(stderr, "%s: failed reading log-probs for chunk %d\n", __func__, i);
|
||||
return;
|
||||
}
|
||||
|
||||
// clear the KV cache
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
for (int j = 0; j < num_batches; ++j) {
|
||||
const int batch_start = start + j * n_batch;
|
||||
const int batch_size = std::min(end - batch_start, n_batch);
|
||||
|
||||
// save original token and restore it after eval
|
||||
const auto token_org = tokens[batch_start];
|
||||
|
||||
// add BOS token for the first batch of each chunk
|
||||
if (add_bos && j == 0) {
|
||||
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
// restore the original token in case it was set to BOS
|
||||
tokens[batch_start] = token_org;
|
||||
|
||||
if (num_batches > 1) {
|
||||
const auto * batch_logits = llama_get_logits(ctx);
|
||||
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
|
||||
}
|
||||
}
|
||||
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
|
||||
if (i == 0) {
|
||||
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
|
||||
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
|
||||
int total_seconds = (int)(t_total * n_chunk);
|
||||
if (total_seconds >= 60*60) {
|
||||
fprintf(stderr, "%d hours ", total_seconds / (60*60));
|
||||
total_seconds = total_seconds % (60*60);
|
||||
}
|
||||
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
|
||||
|
||||
printf("\nchunk PPL ln(PPL(Q)/PPL(base)) KL-Divergence\n");
|
||||
}
|
||||
|
||||
const int first = n_ctx/2;
|
||||
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
|
||||
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
|
||||
workers, log_probs_uint16, kld);
|
||||
|
||||
auto ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
|
||||
auto log_ppl_ratio = mean_and_uncertainty(kld.sum_nll_diff, kld.sum_nll_diff2, kld.count);
|
||||
auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
|
||||
|
||||
printf("%4d %10.4lf %10.5lf ± %10.5f %10.5f ± %10.5lf\n", i+1, exp(ppl.first),
|
||||
log_ppl_ratio.first, log_ppl_ratio.second, kl_div.first, kl_div.second);
|
||||
|
||||
fflush(stdout);
|
||||
|
||||
logits.clear();
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
gpt_params params;
|
||||
@ -1476,6 +1789,8 @@ int main(int argc, char ** argv) {
|
||||
winogrande_score(ctx, params);
|
||||
} else if (params.multiple_choice) {
|
||||
multiple_choice_score(ctx, params);
|
||||
} else if (params.kl_divergence) {
|
||||
kl_divergence(ctx, params);
|
||||
} else {
|
||||
results = perplexity(ctx, params);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user