From 6664d4709fe66910ac82b07879ef67189fd31bbb Mon Sep 17 00:00:00 2001 From: VJHack Date: Mon, 13 Jan 2025 22:21:38 -0600 Subject: [PATCH] cleanup pr and remove explicit floats --- src/llama-sampling.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 9938f0f3d..9826bdd09 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1657,7 +1657,6 @@ struct llama_sampler_top_n_sigma { static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) { return "top-n-sigma"; } -#include static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx; @@ -1671,18 +1670,18 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t } logits_sum += cur_p->data[i].logit; } - float mean = (float)logits_sum/cur_p->size; + float mean = logits_sum/cur_p->size; // calculate standard deviation float acc = 0; for(size_t i = 0; i < cur_p->size; ++i){ acc += pow(cur_p->data[i].logit - mean, 2); } - float std = sqrt((float)acc/cur_p->size); + float std = sqrt(acc/cur_p->size); //apply mask for(size_t i = 0; i < cur_p->size; ++i){ - if(cur_p->data[i].logit < max - ((float)ctx->n * std)) { + if(cur_p->data[i].logit < max - ctx->n * std) { cur_p->data[i].logit = -INFINITY; } }