mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-31 14:13:09 +01:00
cleanup pr and remove explicit floats
This commit is contained in:
parent
f08e6f5bdc
commit
6664d4709f
@ -1657,7 +1657,6 @@ struct llama_sampler_top_n_sigma {
|
|||||||
static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
|
static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
|
||||||
return "top-n-sigma";
|
return "top-n-sigma";
|
||||||
}
|
}
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
|
const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
|
||||||
@ -1671,18 +1670,18 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
|
|||||||
}
|
}
|
||||||
logits_sum += cur_p->data[i].logit;
|
logits_sum += cur_p->data[i].logit;
|
||||||
}
|
}
|
||||||
float mean = (float)logits_sum/cur_p->size;
|
float mean = logits_sum/cur_p->size;
|
||||||
|
|
||||||
// calculate standard deviation
|
// calculate standard deviation
|
||||||
float acc = 0;
|
float acc = 0;
|
||||||
for(size_t i = 0; i < cur_p->size; ++i){
|
for(size_t i = 0; i < cur_p->size; ++i){
|
||||||
acc += pow(cur_p->data[i].logit - mean, 2);
|
acc += pow(cur_p->data[i].logit - mean, 2);
|
||||||
}
|
}
|
||||||
float std = sqrt((float)acc/cur_p->size);
|
float std = sqrt(acc/cur_p->size);
|
||||||
|
|
||||||
//apply mask
|
//apply mask
|
||||||
for(size_t i = 0; i < cur_p->size; ++i){
|
for(size_t i = 0; i < cur_p->size; ++i){
|
||||||
if(cur_p->data[i].logit < max - ((float)ctx->n * std)) {
|
if(cur_p->data[i].logit < max - ctx->n * std) {
|
||||||
cur_p->data[i].logit = -INFINITY;
|
cur_p->data[i].logit = -INFINITY;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user