From 6c1ca58f071a53e545eeddeac4cc78e4d34689a3 Mon Sep 17 00:00:00 2001 From: VJHack Date: Sun, 19 Jan 2025 22:40:54 -0600 Subject: [PATCH] changed sigma to float --- common/common.h | 2 +- common/sampling.cpp | 2 +- include/llama.h | 2 +- src/llama-sampling.cpp | 6 +++--- tests/test-sampling.cpp | 6 +++--- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/common/common.h b/common/common.h index 1cce234df..9ea5d4d8b 100644 --- a/common/common.h +++ b/common/common.h @@ -134,7 +134,7 @@ struct common_params_sampling { int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - int32_t top_n_sigma = -1; // -1 = disabled + float top_n_sigma = -1.00f;// -1.0 = disabled float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate bool ignore_eos = false; diff --git a/common/sampling.cpp b/common/sampling.cpp index ddface5ed..0bd682774 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -134,7 +134,7 @@ std::string common_params_sampling::print() const { snprintf(result, sizeof(result), "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n" - "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %d, temp = %.3f\n" + "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n" "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", penalty_last_n, penalty_repeat, penalty_freq, penalty_present, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, diff --git a/include/llama.h b/include/llama.h index acc177231..76dff06c4 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1164,7 +1164,7 @@ extern "C" { LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed); /// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641 - LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(int32_t n); + LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float n); /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 9826bdd09..876e51d5c 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1651,7 +1651,7 @@ struct llama_sampler * llama_sampler_init_penalties( // top-n-sigma struct llama_sampler_top_n_sigma { - const int32_t n; + const float n; }; static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) { @@ -1681,7 +1681,7 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t //apply mask for(size_t i = 0; i < cur_p->size; ++i){ - if(cur_p->data[i].logit < max - ctx->n * std) { + if(cur_p->data[i].logit < max - (ctx->n * std)) { cur_p->data[i].logit = -INFINITY; } } @@ -1706,7 +1706,7 @@ static struct llama_sampler_i llama_sampler_top_n_sigma_i = { /* .free = */ llama_sampler_top_n_sigma_free, }; -struct llama_sampler * llama_sampler_init_top_n_sigma(int32_t n) { +struct llama_sampler * llama_sampler_init_top_n_sigma(float n) { return new llama_sampler { /* .iface = */ &llama_sampler_top_n_sigma_i, /* .ctx = */ new llama_sampler_top_n_sigma { diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index d2459f91d..9b4f2341c 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -360,9 +360,9 @@ int main(void) { test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {}); test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {}); - test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1); - test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0); - test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3); + test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1.00f); + test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.00f); + test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3.00f); test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f); test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);