diff --git a/common/common.h b/common/common.h
index 0618eea74..9252a4b63 100644
--- a/common/common.h
+++ b/common/common.h
@@ -135,7 +135,7 @@ struct gpt_params {
     bool   multiple_choice = false;   // compute TruthfulQA score over random tasks from datafile supplied in prompt
     size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed
 
-    bool   kl_divergence   = false; // compute KL-divergence
+    bool   kl_divergence   = false; // compute KL divergence
 
     bool random_prompt = false; // do not randomize prompt if none provided
     bool use_color     = false; // use color to distinguish generations and inputs
diff --git a/examples/perplexity/README.md b/examples/perplexity/README.md
index 1a8c0dd64..c5e2bc5de 100644
--- a/examples/perplexity/README.md
+++ b/examples/perplexity/README.md
@@ -1,8 +1,118 @@
-# perplexity
+# Perplexity
 
-TODO
+The `perplexity` example can be used to calculate the so-called perplexity value of a language model over a given text corpus.
+Perplexity measures how well the model can predict the next token, with lower values being better.
+Note that perplexity is **not** directly comparable between models, especially if they use different tokenizers.
+Also note that finetunes typically result in a higher perplexity value even though the human-rated quality of outputs increases.
+
+Within llama.cpp the perplexity of base models is used primarily to judge the quality loss from e.g. quantized models vs. FP16.
+The convention among contributors is to use the Wikitext-2 test set for testing unless noted otherwise (can be obtained with `scripts/get-wikitext-2.sh`).
+
+By default only the mean perplexity value and the corresponding uncertainty are calculated.
+The uncertainty is determined empirically by assuming a Gaussian distribution of the "correct" logits per token and then applying error propagation.
+
+More statistics can be obtained by recording the logits from the FP16 version of a model.
+To do this, supply `perplexity` with `--kl-divergence-base path/to/logit/binary/file.kld`.
+The program will then record all logits and save them to the provided path in binary format.
+**The logit file will be very large, 11 GiB for LLaMA 2 or 37 GiB for LLaMA 3 when using the Wikitext-2 test set.**
+Once you have the file, supply `perplexity` with the quantized model, the logits file via `--kl-divergence-base`,
+and finally the `--kl-divergence` argument to indicate that the program should calculate the so-called Kullback-Leibler divergence.
+This is a measure of how similar the FP16 and the quantized logit distributions are, with a value of 0 indicating that the distributions are the same.
+The uncertainty on the mean KL divergence is calculated by assuming the KL divergence per token follows a Gaussian distribution.
+
+In addition to the KL divergence the following statistics are calculated with `--kl-divergence`:
+
+* Ratio of mean FP16 PPL and quantized PPL. Uncertainty is estimated on logits, then propagated. The logarithm of this metric is also calculated and printed; it is 0 if the logit distributions are the same.
+* Difference of mean FP16 PPL and quantized PPL. Uncertainty is estimated on logits, then propagated.
+* Mean change in "correct" token probability. Positive values mean the model gets better at prediction, negative values mean it gets worse.
+* Pearson correlation coefficient of the "correct" token probabilities between models.
+* Percentiles of change in "correct" token probability. Positive values mean the model gets better at prediction, negative values mean it gets worse. Can be used to judge noise vs. quality loss from quantization. If the percentiles are symmetric then the quantization is essentially just adding noise. If the negative values are significantly larger in magnitude than the positive values then this indicates that the model is actually becoming worse from the quantization.
+* The root mean square of the change in token probabilities. If you were to assume that the quantization simply causes Gaussian noise on the token probabilities then this would be the standard deviation of said noise. The uncertainty on the value is calculated by assuming that the change in token probabilities follows a Gaussian distribution. Related discussion: https://github.com/ggerganov/llama.cpp/discussions/2875.
+* Same top p: Percentage of how often the token was assigned the highest probability by both models. The uncertainty is calculated from the Gaussian approximation of the binomial distribution.
+
+## LLaMA 3 8b Scoreboard
+
+Results are sorted by Kullback-Leibler divergence relative to FP16.
+The "WT" importance matrices were created using varying numbers of Wikitext tokens and can be found [here](https://huggingface.co/JohannesGaessler/llama.cpp_importance_matrices/blob/main/imatrix-llama_3-8b-f16-2.7m_tokens.dat).
+
+| Quantization | imatrix | Model size [GiB] | PPL | ΔPPL | KLD | Mean Δp | RMS Δp |
+|--------------|---------|------------------|------------------------|------------------------|-----------------------|-------------------|------------------|
+| f16 | None | 14.97 | 6.233160 ± 0.037828 | - | - | - | - |
+| q8_0 | None | 7.96 | 6.234284 ± 0.037878 | 0.002650 ± 0.001006 | 0.001355 ± 0.000006 | -0.019 ± 0.003 % | 1.198 ± 0.007 % |
+| q6_K | None | 6.14 | 6.253382 ± 0.038078 | 0.021748 ± 0.001852 | 0.005452 ± 0.000035 | -0.007 ± 0.006 % | 2.295 ± 0.019 % |
+| q5_K_M | None | 5.33 | 6.288607 ± 0.038338 | 0.056974 ± 0.002598 | 0.010762 ± 0.000079 | -0.114 ± 0.008 % | 3.160 ± 0.031 % |
+| q5_K_S | None | 5.21 | 6.336598 ± 0.038755 | 0.104964 ± 0.003331 | 0.016595 ± 0.000122 | -0.223 ± 0.010 % | 3.918 ± 0.036 % |
+| q5_1 | None | 5.65 | 6.337857 ± 0.038677 | 0.106223 ± 0.003476 | 0.018045 ± 0.000139 | -0.287 ± 0.011 % | 4.123 ± 0.039 % |
+| q5_0 | None | 5.21 | 6.363224 ± 0.038861 | 0.131591 ± 0.003894 | 0.022239 ± 0.000166 | -0.416 ± 0.012 % | 4.634 ± 0.043 % |
+| q4_K_M | WT 10m | 4.58 | 6.382937 ± 0.039055 | 0.151303 ± 0.004429 | 0.028152 ± 0.000240 | -0.389 ± 0.014 % | 5.251 ± 0.049 % |
+| q4_K_M | None | 4.58 | 6.407115 ± 0.039119 | 0.175482 ± 0.004620 | 0.031273 ± 0.000238 | -0.596 ± 0.014 % | 5.519 ± 0.050 % |
+| q4_K_S | WT 10m | 4.37 | 6.409697 ± 0.039189 | 0.178064 ± 0.004744 | 0.031951 ± 0.000259 | -0.531 ± 0.015 % | 5.645 ± 0.051 % |
+| iq4_NL | WT 10m | 4.35 | 6.455593 ± 0.039630 | 0.223959 ± 0.005201 | 0.035742 ± 0.000288 | -0.590 ± 0.016 % | 5.998 ± 0.054 % |
+| iq4_XS | WT 10m | 4.14 | 6.459705 ± 0.039595 | 0.228071 ± 0.005207 | 0.036334 ± 0.000284 | -0.668 ± 0.016 % | 6.044 ± 0.054 % |
+| q4_K_S | None | 4.37 | 6.500529 ± 0.039778 | 0.268895 ± 0.005638 | 0.043136 ± 0.000314 | -0.927 ± 0.017 % | 6.562 ± 0.055 % |
+| q4_1 | None | 4.78 | 6.682737 ± 0.041285 | 0.451103 ± 0.008030 | 0.071683 ± 0.000505 | -0.927 ± 0.017 % | 8.512 ± 0.063 % |
+| q4_0 | None | 4.34 | 6.700147 ± 0.041226 | 0.468514 ± 0.007951 | 0.071940 ± 0.000491 | -1.588 ± 0.022 % | 8.434 ± 0.061 % |
+| q3_K_L | WT 10m | 4.03 | 6.671223 ± 0.041427 | 0.439590 ± 0.008154 | 0.073077 ± 0.000529 | -0.940 ±
0.023 % | 8.662 ± 0.064 % | +| q3_K_M | WT 10m | 3.74 | 6.734255 ± 0.041838 | 0.502622 ± 0.008901 | 0.084358 ± 0.000588 | -1.198 ± 0.024 % | 9.292 ± 0.065 % | +| q3_K_L | None | 4.03 | 6.787876 ± 0.042104 | 0.556242 ± 0.009171 | 0.087176 ± 0.000614 | -1.532 ± 0.025 % | 9.432 ± 0.067 % | +| q3_K_M | None | 3.74 | 6.888498 ± 0.042669 | 0.656864 ± 0.010071 | 0.101913 ± 0.000677 | -1.990 ± 0.026 % | 10.203 ± 0.068 % | +| iq3_M | WT 10m | 3.53 | 6.898327 ± 0.041643 | 0.666694 ± 0.009449 | 0.102534 ± 0.000663 | -3.178 ± 0.026 % | 10.513 ± 0.066 % | +| iq3_S | WT 10m | 3.42 | 6.965501 ± 0.042406 | 0.733867 ± 0.010245 | 0.111278 ± 0.000710 | -3.066 ± 0.027 % | 10.845 ± 0.068 % | +| iq3_XS | WT 10m | 3.28 | 7.163043 ± 0.043772 | 0.931409 ± 0.012084 | 0.138693 ± 0.000857 | -3.667 ± 0.031 % | 12.148 ± 0.070 % | +| iq3_XXS | WT 10m | 3.05 | 7.458436 ± 0.046404 | 1.226803 ± 0.015234 | 0.183625 ± 0.001042 | -3.918 ± 0.035 % | 13.836 ± 0.074 % | +| q3_K_S | WT 10m | 3.41 | 7.602878 ± 0.046848 | 1.371244 ± 0.015688 | 0.199821 ± 0.001008 | -5.046 ± 0.037 % | 14.980 ± 0.070 % | +| q3_K_S | None | 3.41 | 7.863786 ± 0.048885 | 1.632152 ± 0.017733 | 0.228217 ± 0.001079 | -5.604 ± 0.038 % | 15.541 ± 0.070 % | +| iq2_M | WT 10m | 2.74 | 8.600799 ± 0.055124 | 2.369166 ± 0.025244 | 0.325989 ± 0.00160 | -6.463 ± 0.046 % | 18.519 ± 0.080 % | +| q2_K | WT 10k | 2.96 | 8.652290 ± 0.055572 | 2.420657 ± 0.025587 | 0.331393 ± 0.001562 | -6.606 ± 0.046 % | 18.790 ± 0.078 % | +| q2_K | WT 100k | 2.96 | 8.641993 ± 0.055406 | 2.410359 ± 0.025495 | 0.331672 ± 0.001569 | -6.628 ± 0.047 % | 18.856 ± 0.078 % | +| q2_K | WT 10m | 2.96 | 8.647825 ± 0.055610 | 2.416191 ± 0.025683 | 0.332223 ± 0.001572 | -6.500 ± 0.047 % | 18.881 ± 0.078 % | +| q2_K | WT 1m | 2.96 | 8.674365 ± 0.055743 | 2.442732 ± 0.025843 | 0.335308 ± 0.001576 | -6.634 ± 0.047 % | 19.009 ± 0.079 % | +| q2_K | WT 1k | 2.96 | 8.682605 ± 0.055916 | 2.450972 ± 0.026069 | 0.337093 ± 0.001596 | -6.596 ± 0.047 % | 18.977 ± 0.079 % | +| q2_K_S | WT 10m | 2.96 | 9.323778 ± 0.061551 | 3.092145 ± 0.031914 | 0.403360 ± 0.001787 | -7.131 ± 0.049 % | 20.050 ± 0.081 % | +| q2_K_S | WT 1m | 2.96 | 9.329321 ± 0.061378 | 3.097688 ± 0.031816 | 0.403590 ± 0.001797 | -7.289 ± 0.049 % | 20.123 ± 0.081 % | +| q2_K_S | WT 100k | 2.96 | 9.362973 ± 0.061740 | 3.131339 ± 0.032169 | 0.408367 ± 0.001802 | -7.198 ± 0.050 % | 20.132 ± 0.081 % | +| q2_K_S | WT 10k | 2.96 | 9.376479 ± 0.062045 | 3.144846 ± 0.032464 | 0.408662 ± 0.001819 | -7.141 ± 0.050 % | 20.120 ± 0.081 % | +| q2_K_S | WT 1k | 2.96 | 9.415200 ± 0.062475 | 3.183567 ± 0.032993 | 0.415865 ± 0.001846 | -7.153 ± 0.050 % | 20.311 ± 0.082 % | +| iq2_S | WT 10m | 2.56 | 9.650781 ± 0.063209 | 3.419148 ± 0.034017 | 0.439197 ± 0.001976 | -8.319 ± 0.052 % | 21.491 ± 0.083 % | +| q2_K | None | 2.96 | 9.751568 ± 0.063312 | 3.519934 ± 0.033863 | 0.445132 ± 0.001835 | -9.123 ± 0.051 % | 21.421 ± 0.079 % | +| iq2_XS | WT 10m | 2.43 | 10.761424 ± 0.071056 | 4.529791 ± 0.042229 | 0.546290 ± 0.002133 | -10.576 ± 0.056 % | 23.872 ± 0.082 % | +| iq2_XXS | WT 10m | 2.24 | 14.091782 ± 0.098396 | 7.860148 ± 0.070752 | 0.812022 ± 0.002741 | -14.363 ± 0.065 % | 28.576 ± 0.084 % | +| iq1_M | WT 10m | 2.01 | 25.493722 ± 0.177903 | 19.262089 ± 0.152396 | 1.393084 ± 0.003529 | -24.672 ± 0.077 % | 38.287 ± 0.084 % | +| iq1_S | WT 1m | 1.88 | 58.097760 ± 0.438604 | 51.866126 ± 0.416604 | 2.211278 ± 0.004688 | -32.471 ± 0.087 % | 46.418 ± 0.085 % | +| iq1_S | WT 1k | 1.88 | 58.267851 ± 0.446208 | 52.036218 ± 0.424373 | 2.214858 ± 0.004778 | -31.880 ± 0.089 % | 
46.330 ± 0.086 % |
+| iq1_S | WT 100k | 1.88 | 58.581498 ± 0.453145 | 52.349864 ± 0.431360 | 2.220834 ± 0.004818 | -32.261 ± 0.089 % | 46.002 ± 0.086 % |
+| iq1_S | WT 10m | 1.88 | 60.694593 ± 0.471290 | 54.462959 ± 0.449644 | 2.254554 ± 0.004868 | -31.973 ± 0.088 % | 46.271 ± 0.086 % |
+| iq1_S | WT 10k | 1.88 | 63.221324 ± 0.493077 | 56.989691 ± 0.471423 | 2.293527 ± 0.004885 | -32.261 ± 0.089 % | 46.562 ± 0.086 % |
+
+There seems to be no consistent improvement from using more Wikitext tokens for the importance matrix.
+Compared to the legacy quants, the K-quants score better on mean Δp than e.g. the KL divergence would suggest.
+
+## LLaMA 2 vs. LLaMA 3 Quantization comparison
+
+| Metric | L2 7b q2_K | L3 8b q2_K | L2 7b q4_K_M | L3 8b q4_K_M | L2 7b q6_K | L3 8b q6_K | L2 7b q8_0 | L3 8b q8_0 |
+|-----------------|---------------------|---------------------|---------------------|---------------------|---------------------|---------------------|---------------------|---------------------|
+| Mean PPL | 5.794552 ± 0.032298 | 9.751568 ± 0.063312 | 5.877078 ± 0.032781 | 6.407115 ± 0.039119 | 5.808494 ± 0.032425 | 6.253382 ± 0.038078 | 5.798542 ± 0.032366 | 6.234284 ± 0.037878 |
+| Mean PPL ratio | 1.107955 ± 0.001427 | 1.564849 ± 0.004525 | 1.014242 ± 0.000432 | 1.028160 ± 0.000723 | 1.002406 ± 0.000191 | 1.003490 ± 0.000296 | 1.000689 ± 0.000107 | 1.000425 ± 0.000161 |
+| Mean ΔPPL | 0.625552 ± 0.008725 | 3.519934 ± 0.033863 | 0.082526 ± 0.002530 | 0.175482 ± 0.004620 | 0.013941 ± 0.001110 | 0.021748 ± 0.001852 | 0.003990 ± 0.000624 | 0.002650 ± 0.001006 |
+| PPL correlation | 97.36% | 89.62% | 99.71% | 99.34% | 99.94% | 99.88% | 99.98% | 99.96% |
+| Mean KLD | 0.108903 ± 0.000645 | 0.445132 ± 0.001835 | 0.012686 ± 0.000079 | 0.031273 ± 0.000238 | 0.002098 ± 0.000014 | 0.005452 ± 0.000035 | 0.000369 ± 0.000007 | 0.001355 ± 0.000006 |
+| Mean Δp | -2.710 ± 0.023 % | -9.123 ± 0.051 % | -0.416 ± 0.008 % | -0.596 ± 0.014 % | -0.035 ± 0.003 % | -0.007 ± 0.006 % | -0.005 ± 0.002 % | -0.019 ± 0.003 % |
+| Maximum Δp | 85.136% | 94.268% | 45.209% | 95.054% | 23.593% | 53.601% | 43.925% | 28.734% |
+| 99.9% Δp | 37.184% | 50.003% | 17.461% | 27.084% | 7.798% | 13.613% | 3.387% | 6.402% |
+| 99.0% Δp | 18.131% | 25.875% | 7.798% | 12.084% | 3.838% | 6.407% | 1.867% | 3.544% |
+| Median Δp | -0.391% | -2.476% | -0.026% | -0.024% | -0.001% | 0.000% | -0.000% | -0.000% |
+| 1.0% Δp | -39.762% | -87.173% | -11.433% | -19.567% | -4.222% | -6.767% | -1.862% | -3.698% |
+| 0.1% Δp | -79.002% | -98.897% | -26.433% | -56.054% | -9.091% | -16.584% | -3.252% | -6.579% |
+| Minimum Δp | -99.915% | -99.965% | -83.383% | -98.699% | -43.142% | -68.487% | -9.343% | -24.301% |
+| RMS Δp | 9.762 ± 0.053 % | 21.421 ± 0.079 % | 3.252 ± 0.024 % | 5.519 ± 0.050 % | 1.339 ± 0.010 % | 2.295 ± 0.019 % | 0.618 ± 0.011 % | 1.198 ± 0.007 % |
+| Same top p | 85.584 ± 0.086 % | 71.138 ± 0.119 % | 94.665 ± 0.055 % | 91.901 ± 0.072 % | 97.520 ± 0.038 % | 96.031 ± 0.051 % | 98.846 ± 0.026 % | 97.674 ± 0.040 % |
+
+
+## Old Numbers
+
+<details>
+<summary>Llama 2 70B Scoreboard</summary>
-## Llama 2 70B Scorechart
 
 | Quantization | Model size (GiB) | Perplexity | Delta to fp16 |
 |--------------|------------------|------------|---------------|
 | Q4_0 | 36.20 | 3.5550 | 3.61% |
@@ -18,3 +128,5 @@ TODO
 | Q5_K_M | 45.41 | 3.4451 | 0.40% |
 | Q6_K | 52.70 | 3.4367 | 0.16% |
 | fp16 | 128.5 | 3.4313 | - |
+
+</details>
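The mean PPL and its uncertainty described in the README above boil down to a Gaussian assumption on the per-token negative log-likelihood plus error propagation through the exponential, which is what `mean_and_uncertainty()` and the `ppl_unc = ppl_val * log_ppl.second` line in `perplexity.cpp` below implement. A minimal, self-contained sketch of that calculation follows; the NLL values are made up and none of the names come from the diff:

```cpp
// Minimal sketch of the PPL uncertainty estimate described above (assumed
// workflow, not code from the diff): treat the per-token NLL values as
// Gaussian, take the standard error of their mean as the uncertainty on
// ln(PPL), and propagate it through exp() to get the uncertainty on PPL.
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    // made-up per-token negative log-likelihoods
    const std::vector<double> nll = {2.1, 1.7, 2.4, 1.9, 2.0, 2.2, 1.8, 2.3};

    double sum = 0.0, sum2 = 0.0;
    for (double x : nll) {
        sum  += x;
        sum2 += x*x;
    }
    const size_t n = nll.size();

    const double mean_nll = sum/n;                      // ln(PPL)
    const double var_nll  = sum2/n - mean_nll*mean_nll; // variance of the per-token NLL
    const double sem_nll  = sqrt(var_nll/(n - 1));      // standard error of the mean

    const double ppl     = exp(mean_nll);               // PPL = exp(mean NLL)
    const double ppl_unc = ppl * sem_nll;               // error propagation: d exp(x)/dx = exp(x)

    printf("PPL = %.4f +- %.4f\n", ppl, ppl_unc);
    return 0;
}
```

Applied to the per-token NLL of a full Wikitext-2 run, this is the quantity reported as `Mean PPL(Q)` and `Mean PPL(base)` in the output below.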
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 9a3374fdc..db6e0949d 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -216,17 +216,22 @@ static void process_logits(std::ostream& out, int n_vocab, const float * logits,
 }
 
 struct kl_divergence_result {
-    double sum_nll = 0;
-    double sum_nll2 = 0;
-    double sum_kld = 0;
-    double sum_kld2 = 0;
-    double sum_nll_diff = 0;
-    double sum_nll_diff2 = 0;
-    size_t n_same_top = 0;
-    size_t count = 0;
+    double sum_nll          = 0.0;
+    double sum_nll2         = 0.0;
+    double sum_nll_base     = 0.0;
+    double sum_nll_base2    = 0.0;
+    double sum_nll_nll_base = 0.0;
+    double sum_kld          = 0.0;
+    double sum_kld2         = 0.0;
+    double sum_p_diff       = 0.0;
+    double sum_p_diff2      = 0.0;
+    double sum_p_diff4      = 0.0;
+    float  max_p_diff       = 0.0f;
+    size_t n_same_top       = 0;
+    size_t count            = 0;
 };
 
-static double log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
+static std::pair<double, float> log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
     float max_logit = logits[0];
     int imax = 0;
     for (int i = 1; i < n_vocab; ++i) {
@@ -244,12 +249,17 @@ static double log_softmax(int n_vocab, const float * logits, const uint16_t * ba
     const float scale = d[0];
     const float min_log_prob = d[1];
     base_log_prob += 4;
-    float nll = max_logit + log_sum_exp - logits[tok];
+
+    const float nll = max_logit + log_sum_exp - logits[tok];
     kld.sum_nll += nll;
     kld.sum_nll2 += nll*nll;
-    nll += (scale*base_log_prob[tok] + min_log_prob);
-    kld.sum_nll_diff += nll;
-    kld.sum_nll_diff2 += nll*nll;
+
+    const float nll_base = -(scale*base_log_prob[tok] + min_log_prob);
+    kld.sum_nll_base += nll_base;
+    kld.sum_nll_base2 += nll_base*nll_base;
+
+    kld.sum_nll_nll_base += nll*nll_base;
+
     max_logit += log_sum_exp;
     double sum = 0;
     int imax_base = -1;
@@ -269,34 +279,50 @@ static double log_softmax(int n_vocab, const float * logits, const uint16_t * ba
     kld.sum_kld2 += sum*sum;
     ++kld.count;
     if (imax == imax_base) ++kld.n_same_top;
-    return sum;
+
+    const float p_base = expf(-nll_base);
+    const float p = expf(-nll);
+    const float p_diff = p - p_base;
+    kld.sum_p_diff += p_diff;
+    const double p_diff2 = p_diff*p_diff;
+    kld.sum_p_diff2 += p_diff2;
+    kld.sum_p_diff4 += p_diff2*p_diff2;
+    kld.max_p_diff = std::max(kld.max_p_diff, std::fabs(p_diff));
+
+    return std::make_pair(sum, p_diff);
 }
 
 static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token,
         std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld,
-        float * kld_values) {
+        float * kld_values, float * p_diff_values) {
     std::mutex mutex;
     const int nv = 2*((n_vocab + 1)/2) + 4;
     int counter = 0;
-    auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv, kld_values] () {
+    auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv, kld_values, p_diff_values] () {
         kl_divergence_result local_kld;
         while (true) {
             std::unique_lock<std::mutex> lock(mutex);
             int i = counter++;
             if (i >= n_token) {
-                kld.sum_nll += local_kld.sum_nll;
-                kld.sum_nll2 += local_kld.sum_nll2;
-                kld.sum_kld += local_kld.sum_kld;
-                kld.sum_kld2 += local_kld.sum_kld2;
-                kld.sum_nll_diff += local_kld.sum_nll_diff;
-                kld.sum_nll_diff2 += local_kld.sum_nll_diff2;
-                kld.n_same_top += local_kld.n_same_top;
-                kld.count += local_kld.count;
+                kld.sum_nll          += local_kld.sum_nll;
+                kld.sum_nll2         += local_kld.sum_nll2;
+                kld.sum_nll_base     += local_kld.sum_nll_base;
+                kld.sum_nll_base2    += local_kld.sum_nll_base2;
+                kld.sum_nll_nll_base += local_kld.sum_nll_nll_base;
+                kld.sum_kld          += local_kld.sum_kld;
+                kld.sum_kld2         += local_kld.sum_kld2;
+                kld.sum_p_diff       += local_kld.sum_p_diff;
+                kld.sum_p_diff2      += local_kld.sum_p_diff2;
+                kld.sum_p_diff4      += local_kld.sum_p_diff4;
+                kld.n_same_top       += local_kld.n_same_top;
+                kld.max_p_diff        = std::max(kld.max_p_diff, local_kld.max_p_diff);
+                kld.count            += local_kld.count;
                 break;
             }
             lock.unlock();
-            double v = log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
-            kld_values[i] = (float)v;
+            std::pair<double, float> v = log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
+            kld_values[i]    = (float)v.first;
+            p_diff_values[i] = v.second;
         }
     };
     for (auto & w : workers) {
@@ -1711,7 +1737,8 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
     GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1);
 
     std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
-    std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
+    std::vector<float>    kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
+    std::vector<float> p_diff_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
     std::vector<float> logits;
     if (num_batches > 1) {
         logits.reserve(n_ctx * n_vocab);
@@ -1728,9 +1755,18 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
         df = df > 0 && count > 10 ? sqrt(df/(count-1)) : 0.;
         return std::make_pair(f, df);
     };
+    auto covariance = [] (double suma, double sumb, double sumab, size_t count) {
+        if (count < 10) {
+            return 0.0;
+        }
+        double var = sumab/count - (suma/count)*(sumb/count);
+        var /= count - 1;
+        return var;
+    };
 
     kl_divergence_result kld;
-    auto kld_ptr = kld_values.data();
+    auto kld_ptr    = kld_values.data();
+    auto p_diff_ptr = p_diff_values.data();
 
     for (int i = 0; i < n_chunk; ++i) {
         const int start = i * n_ctx;
@@ -1785,24 +1821,42 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
             }
             fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
 
-            printf("\nchunk PPL ln(PPL(Q)/PPL(base)) KL-Divergence Same top\n");
+            printf("\nchunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n");
         }
 
         const int first = n_ctx/2;
         const float * all_logits = num_batches > 1 ?
logits.data() : llama_get_logits(ctx); process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, - workers, log_probs_uint16, kld, kld_ptr); - kld_ptr += n_ctx - 1 - first; + workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr); + p_diff_ptr += n_ctx - 1 - first; + kld_ptr += n_ctx - 1 - first; - auto ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count); - auto log_ppl_ratio = mean_and_uncertainty(kld.sum_nll_diff, kld.sum_nll_diff2, kld.count); - auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count); - auto p_top = 1.*kld.n_same_top/kld.count; - auto d_p_top = sqrt(p_top*(1 - p_top)/(kld.count - 1)); + printf("%4d", i+1); - printf("%4d %10.4lf %10.5lf ± %10.5f %10.5f ± %10.5lf %.5f ± %.5f\n", i+1, exp(ppl.first), - log_ppl_ratio.first, log_ppl_ratio.second, kl_div.first, kl_div.second, - p_top, d_p_top); + auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count); + const double ppl_val = exp(log_ppl.first); + const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 ) + printf(" %9.4lf ± %9.4lf", ppl_val, ppl_unc); + + auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count); + const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count); + const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first; + const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov); + printf(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc); + + auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count); + printf(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second); + + auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count); + const double p_diff_rms_val = sqrt(p_diff_mse.first); + const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second; + printf(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc); + + double p_top_val = 1.*kld.n_same_top/kld.count; + double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1)); + printf(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc); + + printf("\n"); fflush(stdout); @@ -1813,31 +1867,97 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { if (kld.count < 100) return; // we do not wish to do statistics on so few values std::sort(kld_values.begin(), kld_values.end()); + std::sort(p_diff_values.begin(), p_diff_values.end()); - printf("===== KL-divergence statistics\n"); + printf("====== Perplexity statistics ======\n"); + + auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count); + const double ppl_val = exp(log_ppl.first); + const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 ) + printf("Mean PPL(Q) : %10.6lf ± %10.6lf\n", ppl_val, ppl_unc); + + auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count); + const double ppl_base_val = exp(log_ppl_base.first); + const double ppl_base_unc = ppl_base_val * log_ppl_base.second; // ppl_base_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl_base.second ** 2 ) + printf("Mean PPL(base) : %10.6lf ± %10.6lf\n", ppl_base_val, ppl_base_unc); + + const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count); + // printf("Cov(ln(PPL(Q)), ln(PPL(base))): %10.6lf\n", log_ppl_cov); + const double 
log_ppl_cor = log_ppl_cov / (log_ppl.second*log_ppl_base.second);
+    printf("Cor(ln(PPL(Q)), ln(PPL(base))): %6.2lf%%\n", 100.0*log_ppl_cor);
+
+    const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
+    const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
+    printf("Mean ln(PPL(Q)/PPL(base)) : %10.6lf ± %10.6lf\n", log_ppl_ratio_val, log_ppl_ratio_unc);
+
+    const double ppl_ratio_val = exp(log_ppl_ratio_val);
+    const double ppl_ratio_unc = ppl_ratio_val * log_ppl_ratio_unc; // ppl_ratio_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl_ratio.second ** 2 )
+    printf("Mean PPL(Q)/PPL(base) : %10.6lf ± %10.6lf\n", ppl_ratio_val, ppl_ratio_unc);
+
+    const double ppl_cov = ppl_val * ppl_base_val * log_ppl_cov;
+    const double ppl_diff_val = ppl_val - ppl_base_val;
+    const double ppl_diff_unc = sqrt(ppl_unc*ppl_unc + ppl_base_unc*ppl_base_unc - 2.0*ppl_cov);
+    printf("Mean PPL(Q)-PPL(base) : %10.6lf ± %10.6lf\n", ppl_diff_val, ppl_diff_unc);
+
+    printf("\n");
+
+    printf("====== KL divergence statistics ======\n");
     auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
-    printf("Average: %10.6f ±%10.6lf\n", kl_div.first, kl_div.second);
+    printf("Mean KLD: %10.6lf ± %10.6lf\n", kl_div.first, kl_div.second);
     auto kld_median = kld_values.size()%2 == 0 ? 0.5f*(kld_values[kld_values.size()/2] + kld_values[kld_values.size()/2-1])
         : kld_values[kld_values.size()/2];
-    printf("Median : %10.6f\n", kld_median);
 
-    auto percentile = [&kld_values] (float fraction) {
-        if (fraction <= 0) return kld_values.front();
-        if (fraction >= 1) return kld_values.back();
-        float p = fraction*(kld_values.size() - 1);
+    auto percentile = [] (std::vector<float> values, float fraction) {
+        if (fraction <= 0) return values.front();
+        if (fraction >= 1) return values.back();
+        float p = fraction*(values.size() - 1);
         size_t ip = size_t(p);
         p -= ip;
-        return (1 - p)*kld_values[ip] + p*kld_values[std::min(ip+1, kld_values.size()-1)];
+        return (1 - p)*values[ip] + p*values[std::min(ip+1, values.size()-1)];
     };
 
-    printf("Maximum: %10.6f\n", kld_values.back());
-    printf("KLD_99 : %10.6f\n", percentile(0.99f));
-    printf("KLD_95 : %10.6f\n", percentile(0.95f));
-    printf("KLD_90 : %10.6f\n", percentile(0.90f));
+    printf("Maximum KLD: %10.6f\n", kld_values.back());
+    printf("99.9%% KLD: %10.6f\n", percentile(kld_values, 0.999f));
+    printf("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f));
+    printf("Median KLD: %10.6f\n", kld_median);
+    printf("10.0%% KLD: %10.6f\n", percentile(kld_values, 0.100f));
+    printf(" 5.0%% KLD: %10.6f\n", percentile(kld_values, 0.050f));
+    printf(" 1.0%% KLD: %10.6f\n", percentile(kld_values, 0.010f));
+    printf("Minimum KLD: %10.6f\n", kld_values.front());
 
-    printf("Minimum: %10.6f\n", kld_values.front());
-    printf("KLD_01 : %10.6f\n", percentile(0.01f));
-    printf("KLD_05 : %10.6f\n", percentile(0.05f));
-    printf("KLD_10 : %10.6f\n", percentile(0.10f));
+    printf("\n");
+
+    printf("====== Token probability statistics ======\n");
+
+    auto p_diff = mean_and_uncertainty(kld.sum_p_diff, kld.sum_p_diff2, kld.count);
+    printf("Mean Δp: %6.3lf ± %5.3lf %%\n", 100.0*p_diff.first, 100.0*p_diff.second);
+
+    auto p_diff_median = p_diff_values.size()%2 == 0 ?
0.5f*(p_diff_values[p_diff_values.size()/2] + p_diff_values[p_diff_values.size()/2-1]) + : p_diff_values[p_diff_values.size()/2]; + + printf("Maximum Δp: %6.3lf%%\n", 100.0*p_diff_values.back()); + printf("99.9%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.999f)); + printf("99.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.990f)); + printf("95.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.950f)); + printf("90.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.900f)); + printf("75.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.750f)); + printf("Median Δp: %6.3lf%%\n", 100.0*p_diff_median); + printf("25.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.250f)); + printf("10.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.100f)); + printf(" 5.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.050f)); + printf(" 1.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.010f)); + printf(" 0.1%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.001f)); + printf("Minimum Δp: %6.3lf%%\n", 100.0*p_diff_values.front()); + + auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count); + // printf("MSE Δp : %10.6lf ± %10.6lf\n", p_diff_mse.first, p_diff_mse.second); + + const double p_diff_rms_val = sqrt(p_diff_mse.first); + const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second; + printf("RMS Δp : %6.3lf ± %5.3lf %%\n", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc); + + const double same_top_p = 1.0*kld.n_same_top/kld.count; + printf("Same top p: %6.3lf ± %5.3lf %%\n", 100.0*same_top_p, 100.0*sqrt(same_top_p*(1.0 - same_top_p)/(kld.count - 1))); }
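For reference, the per-token quantities that the new `log_softmax()` accumulates above reduce to the following sketch: the KL divergence of the quantized distribution from the base (FP16) distribution, the change Δp in the probability of the "correct" token, and whether both models put the highest probability on the same token. The sketch uses plain float logits and an ordinary softmax rather than the compressed uint16 log-probabilities the example actually stores, and all helper names in it are illustrative only.

```cpp
// Minimal sketch of the per-token statistics (assumed simplification, not code
// from the diff): KL divergence D_KL(base || quantized), Δp of the "correct"
// token, and top-token agreement, computed from plain float logits.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

struct token_stats {
    double kld;      // KL divergence of the quantized distribution from the base distribution
    double p_diff;   // p_quantized(correct token) - p_base(correct token)
    bool   same_top; // do both distributions put the highest probability on the same token?
};

static std::vector<double> softmax(const std::vector<float> & logits) {
    double max_logit = logits[0];
    for (float l : logits) {
        max_logit = std::max(max_logit, (double) l);
    }
    std::vector<double> p(logits.size());
    double sum = 0.0;
    for (size_t i = 0; i < logits.size(); ++i) {
        p[i] = exp(logits[i] - max_logit);
        sum += p[i];
    }
    for (double & x : p) {
        x /= sum;
    }
    return p;
}

static token_stats compare_token(const std::vector<float> & logits_base, const std::vector<float> & logits_q, int tok) {
    const std::vector<double> p_base = softmax(logits_base);
    const std::vector<double> p_q    = softmax(logits_q);

    token_stats ts = {0.0, p_q[tok] - p_base[tok], true};
    size_t top_base = 0;
    size_t top_q    = 0;
    for (size_t i = 0; i < p_base.size(); ++i) {
        ts.kld += p_base[i] * (log(p_base[i]) - log(p_q[i]));
        if (p_base[i] > p_base[top_base]) { top_base = i; }
        if (p_q[i]    > p_q[top_q])       { top_q    = i; }
    }
    ts.same_top = top_base == top_q;
    return ts;
}

int main() {
    // toy 4-token vocabulary, the "correct" token is index 2
    const token_stats ts = compare_token({0.5f, 1.0f, 3.0f, -1.0f}, {0.4f, 1.3f, 2.6f, -0.9f}, 2);
    printf("KLD = %.6f, dp = %+.4f, same top = %s\n", ts.kld, ts.p_diff, ts.same_top ? "yes" : "no");
    return 0;
}
```

Averaging `kld` and `p_diff` (and their squares, for the uncertainties) over all evaluated tokens is what produces the `Mean KLD`, `Mean Δp`, and `RMS Δp` numbers shown in the scoreboard tables above.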