mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 05:17:21 +01:00
Perplexity: Compute scores correlated to HellaSwag (#2312)
* Add parameter --perplexity-lines to perplexity.cpp
This commit is contained in:
parent
24baa54ac1
commit
b5fe67f8c6
@ -387,6 +387,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||||||
params.antiprompt.push_back(argv[i]);
|
params.antiprompt.push_back(argv[i]);
|
||||||
} else if (arg == "--perplexity") {
|
} else if (arg == "--perplexity") {
|
||||||
params.perplexity = true;
|
params.perplexity = true;
|
||||||
|
} else if (arg == "--perplexity-lines") {
|
||||||
|
params.perplexity_lines = true;
|
||||||
} else if (arg == "--ignore-eos") {
|
} else if (arg == "--ignore-eos") {
|
||||||
params.logit_bias[llama_token_eos()] = -INFINITY;
|
params.logit_bias[llama_token_eos()] = -INFINITY;
|
||||||
} else if (arg == "--no-penalize-nl") {
|
} else if (arg == "--no-penalize-nl") {
|
||||||
@ -512,7 +514,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||||||
fprintf(stderr, " not recommended: doubles context memory required and no measurable increase in quality\n");
|
fprintf(stderr, " not recommended: doubles context memory required and no measurable increase in quality\n");
|
||||||
fprintf(stderr, " --temp N temperature (default: %.1f)\n", (double)params.temp);
|
fprintf(stderr, " --temp N temperature (default: %.1f)\n", (double)params.temp);
|
||||||
fprintf(stderr, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
fprintf(stderr, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
||||||
fprintf(stderr, " --perplexity compute perplexity over the prompt\n");
|
fprintf(stderr, " --perplexity compute perplexity over each ctx window of the prompt\n");
|
||||||
|
fprintf(stderr, " --perplexity-lines compute perplexity over each line of the prompt\n");
|
||||||
fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
|
fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
|
||||||
fprintf(stderr, " --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
|
fprintf(stderr, " --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
|
||||||
if (llama_mlock_supported()) {
|
if (llama_mlock_supported()) {
|
||||||
|
@ -82,6 +82,7 @@ struct gpt_params {
|
|||||||
bool instruct = false; // instruction mode (used for Alpaca models)
|
bool instruct = false; // instruction mode (used for Alpaca models)
|
||||||
bool penalize_nl = true; // consider newlines as a repeatable token
|
bool penalize_nl = true; // consider newlines as a repeatable token
|
||||||
bool perplexity = false; // compute perplexity over the prompt
|
bool perplexity = false; // compute perplexity over the prompt
|
||||||
|
bool perplexity_lines = false; // compute perplexity over each line of the prompt
|
||||||
bool use_mmap = true; // use mmap for faster loads
|
bool use_mmap = true; // use mmap for faster loads
|
||||||
bool use_mlock = false; // use mlock to keep model in memory
|
bool use_mlock = false; // use mlock to keep model in memory
|
||||||
bool mem_test = false; // compute maximum memory usage
|
bool mem_test = false; // compute maximum memory usage
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <ctime>
|
#include <ctime>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
#if defined(_MSC_VER)
|
#if defined(_MSC_VER)
|
||||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||||
@ -120,6 +121,77 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
|
|||||||
printf("\n");
|
printf("\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void perplexity_lines(llama_context * ctx, const gpt_params & params) {
|
||||||
|
// Calculates perplexity over each line of the prompt
|
||||||
|
|
||||||
|
std::vector<std::string> prompt_lines;
|
||||||
|
std::istringstream strstream(params.prompt);
|
||||||
|
std::string line;
|
||||||
|
|
||||||
|
while (std::getline(strstream,line,'\n')) {
|
||||||
|
prompt_lines.push_back(line);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int n_vocab = llama_n_vocab(ctx);
|
||||||
|
|
||||||
|
int counttotal = 0;
|
||||||
|
size_t n_lines = prompt_lines.size();
|
||||||
|
|
||||||
|
double nll = 0.0;
|
||||||
|
|
||||||
|
fprintf(stderr, "%s: calculating perplexity over %lu lines\n", __func__, n_lines);
|
||||||
|
|
||||||
|
printf("\nLine\tPPL line\tPPL cumulative\n");
|
||||||
|
|
||||||
|
for (size_t i = 0; i < n_lines; ++i) {
|
||||||
|
|
||||||
|
// Tokenize and insert BOS at start
|
||||||
|
std::vector<int> batch_embd = ::llama_tokenize(ctx, prompt_lines[i], true);
|
||||||
|
|
||||||
|
size_t batch_size = batch_embd.size();
|
||||||
|
|
||||||
|
// Stop if line is too long
|
||||||
|
if( batch_size > (size_t)params.n_ctx ) {
|
||||||
|
fprintf(stderr, "%s : tokens in line %lu > n_ctxl\n", __func__, i);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (llama_eval(ctx, batch_embd.data(), batch_size, 0, params.n_threads)) {
|
||||||
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto batch_logits = llama_get_logits(ctx);
|
||||||
|
std::vector<float> logits;
|
||||||
|
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
|
||||||
|
|
||||||
|
double nllline = 0.0;
|
||||||
|
int countline = 0;
|
||||||
|
|
||||||
|
// Perplexity over second half of the line
|
||||||
|
for (size_t j = batch_size/2; j < batch_size - 1; ++j) {
|
||||||
|
// Calculate probability of next token, given the previous ones.
|
||||||
|
const std::vector<float> tok_logits(
|
||||||
|
logits.begin() + (j + 0) * n_vocab,
|
||||||
|
logits.begin() + (j + 1) * n_vocab);
|
||||||
|
|
||||||
|
const float prob = softmax(tok_logits)[batch_embd[ j + 1]];
|
||||||
|
|
||||||
|
nllline += -std::log(prob);
|
||||||
|
++countline;
|
||||||
|
}
|
||||||
|
|
||||||
|
nll += nllline;
|
||||||
|
counttotal += countline;
|
||||||
|
|
||||||
|
// perplexity is e^(average negative log-likelihood)
|
||||||
|
printf("%lu\t%.8lf\t%.8lf\n", i + 1, std::exp(nllline/countline), std::exp(nll / counttotal) );
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("\n");
|
||||||
|
}
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
gpt_params params;
|
gpt_params params;
|
||||||
|
|
||||||
@ -168,7 +240,11 @@ int main(int argc, char ** argv) {
|
|||||||
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
||||||
}
|
}
|
||||||
|
|
||||||
perplexity(ctx, params);
|
if (params.perplexity_lines) {
|
||||||
|
perplexity_lines(ctx, params);
|
||||||
|
} else {
|
||||||
|
perplexity(ctx, params);
|
||||||
|
}
|
||||||
|
|
||||||
llama_print_timings(ctx);
|
llama_print_timings(ctx);
|
||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user