mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-31 06:03:11 +01:00
added tests and fixed nsigma impl
This commit is contained in:
parent
8fb681bf9a
commit
54ef105c85
@ -1655,36 +1655,32 @@ struct llama_sampler_top_n_sigma {
|
|||||||
static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
|
static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
|
||||||
return "top-n-sigma";
|
return "top-n-sigma";
|
||||||
}
|
}
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
|
const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
|
||||||
// 1. Find max logit: M
|
|
||||||
// 2. Find standard deviation of logits: sig
|
|
||||||
// 3. Create a mask where m[i] = 1 if ith logit >= M - n (sig), else m[i] = 0
|
|
||||||
// 4. Apply mask: ith logit itself if m[i]==1, else ith logit = -inf
|
|
||||||
// 5. p = softmax(l)
|
|
||||||
|
|
||||||
// find max logit and calculate mean
|
// find max logit and calculate mean
|
||||||
int32_t max = cur_p->data[0].logit;
|
float max = cur_p->data[0].logit;
|
||||||
int32_t logits_sum = 0;
|
float logits_sum = 0;
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
if(cur_p->data[i].logit > max){
|
if(cur_p->data[i].logit > max){
|
||||||
max = cur_p->data[i].logit;
|
max = cur_p->data[i].logit;
|
||||||
}
|
}
|
||||||
logits_sum += cur_p->data[i].logit;
|
logits_sum += cur_p->data[i].logit;
|
||||||
}
|
}
|
||||||
int32_t mean = logits_sum/cur_p->size;
|
float mean = (float)logits_sum/cur_p->size;
|
||||||
|
|
||||||
// calculate standard deviation
|
// calculate standard deviation
|
||||||
int32_t acc = 0;
|
float acc = 0;
|
||||||
for(size_t i = 0; i < cur_p->size; ++i){
|
for(size_t i = 0; i < cur_p->size; ++i){
|
||||||
acc += (cur_p->data[i].logit - mean) * (cur_p->data[i].logit - mean);
|
acc += pow(cur_p->data[i].logit - mean, 2);
|
||||||
}
|
}
|
||||||
int32_t std = sqrt(acc/cur_p->size);
|
float std = sqrt((float)acc/cur_p->size);
|
||||||
|
|
||||||
//apply mask
|
//apply mask
|
||||||
for(size_t i = 0; i < cur_p->size; ++i){
|
for(size_t i = 0; i < cur_p->size; ++i){
|
||||||
if(cur_p->data[i].logit < max - (ctx->n * std)) {
|
if(cur_p->data[i].logit < max - ((float)ctx->n * std)) {
|
||||||
cur_p->data[i].logit = -INFINITY;
|
cur_p->data[i].logit = -INFINITY;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -182,6 +182,17 @@ static void test_dry(
|
|||||||
tester.check();
|
tester.check();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void test_top_n_sigma(const std::vector<float> & probs, const std::vector<float> & probs_expected, int n) {
|
||||||
|
sampler_tester tester(probs, probs_expected);
|
||||||
|
|
||||||
|
DUMP(&tester.cur_p);
|
||||||
|
tester.apply(llama_sampler_init_top_n_sigma(n));
|
||||||
|
tester.apply(llama_sampler_init_dist (0));
|
||||||
|
DUMP(&tester.cur_p);
|
||||||
|
|
||||||
|
tester.check();
|
||||||
|
}
|
||||||
|
|
||||||
static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
|
static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
|
||||||
) {
|
) {
|
||||||
sampler_tester tester(n_vocab);
|
sampler_tester tester(n_vocab);
|
||||||
@ -349,6 +360,14 @@ int main(void) {
|
|||||||
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {});
|
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {});
|
||||||
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
|
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
|
||||||
|
|
||||||
|
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1);
|
||||||
|
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0);
|
||||||
|
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3);
|
||||||
|
|
||||||
|
// test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
|
||||||
|
// test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
|
||||||
|
// test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
|
||||||
|
|
||||||
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
|
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
|
||||||
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
|
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
|
||||||
test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);
|
test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);
|
||||||
|
Loading…
Reference in New Issue
Block a user