cleanup pr and remove explicit floats

This commit is contained in:
VJHack 2025-01-13 22:21:38 -06:00
parent f08e6f5bdc
commit 6664d4709f

View File

@ -1657,7 +1657,6 @@ struct llama_sampler_top_n_sigma {
static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) { static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
return "top-n-sigma"; return "top-n-sigma";
} }
#include <iostream>
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx; const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
@ -1671,18 +1670,18 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
} }
logits_sum += cur_p->data[i].logit; logits_sum += cur_p->data[i].logit;
} }
float mean = (float)logits_sum/cur_p->size; float mean = logits_sum/cur_p->size;
// calculate standard deviation // calculate standard deviation
float acc = 0; float acc = 0;
for(size_t i = 0; i < cur_p->size; ++i){ for(size_t i = 0; i < cur_p->size; ++i){
acc += pow(cur_p->data[i].logit - mean, 2); acc += pow(cur_p->data[i].logit - mean, 2);
} }
float std = sqrt((float)acc/cur_p->size); float std = sqrt(acc/cur_p->size);
//apply mask //apply mask
for(size_t i = 0; i < cur_p->size; ++i){ for(size_t i = 0; i < cur_p->size; ++i){
if(cur_p->data[i].logit < max - ((float)ctx->n * std)) { if(cur_p->data[i].logit < max - ctx->n * std) {
cur_p->data[i].logit = -INFINITY; cur_p->data[i].logit = -INFINITY;
} }
} }