From 54ef105c85b1220908fa40c2b59f768a0392139b Mon Sep 17 00:00:00 2001 From: VJHack Date: Mon, 13 Jan 2025 17:48:35 -0600 Subject: [PATCH] added tests and fixed nsigma impl --- src/llama-sampling.cpp | 22 +++++++++------------- tests/test-sampling.cpp | 19 +++++++++++++++++++ 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 1eb7df950..e91401d75 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1655,36 +1655,32 @@ struct llama_sampler_top_n_sigma { static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) { return "top-n-sigma"; } +#include static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx; - // 1. Find max logit: M - // 2. Find standard deviation of logits: sig - // 3. Create a mask where m[i] = 1 if ith logit >= M - n (sig), else m[i] = 0 - // 4. Apply mask: ith logit itself if m[i]==1, else ith logit = -inf - // 5. p = softmax(l) // find max logit and calculate mean - int32_t max = cur_p->data[0].logit; - int32_t logits_sum = 0; + float max = cur_p->data[0].logit; + float logits_sum = 0; for (size_t i = 0; i < cur_p->size; ++i) { if(cur_p->data[i].logit > max){ max = cur_p->data[i].logit; } logits_sum += cur_p->data[i].logit; } - int32_t mean = logits_sum/cur_p->size; + float mean = (float)logits_sum/cur_p->size; // calculate standard deviation - int32_t acc = 0; + float acc = 0; for(size_t i = 0; i < cur_p->size; ++i){ - acc += (cur_p->data[i].logit - mean) * (cur_p->data[i].logit - mean); + acc += pow(cur_p->data[i].logit - mean, 2); } - int32_t std = sqrt(acc/cur_p->size); - + float std = sqrt((float)acc/cur_p->size); + //apply mask for(size_t i = 0; i < cur_p->size; ++i){ - if(cur_p->data[i].logit < max - (ctx->n * std)) { + if(cur_p->data[i].logit < max - ((float)ctx->n * std)) { cur_p->data[i].logit = -INFINITY; } } diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index c0dcb4848..59bde4d41 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -182,6 +182,17 @@ static void test_dry( tester.check(); } +static void test_top_n_sigma(const std::vector & probs, const std::vector & probs_expected, int n) { + sampler_tester tester(probs, probs_expected); + + DUMP(&tester.cur_p); + tester.apply(llama_sampler_init_top_n_sigma(n)); + tester.apply(llama_sampler_init_dist (0)); + DUMP(&tester.cur_p); + + tester.check(); +} + static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p ) { sampler_tester tester(n_vocab); @@ -349,6 +360,14 @@ int main(void) { test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {}); test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {}); + test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1); + test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0); + test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3); + + // test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3); + // test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4); + // test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0); + test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f); test_sampler_queue(10000, "k", 1, 1.0f, 1.0f); test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);