mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-02-07 08:53:16 +01:00
completed top nsigma sampler implementation
This commit is contained in:
parent
ddc3c2208a
commit
da038d8715
@ -899,6 +899,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||||||
params.sampling.min_p = std::stof(value);
|
params.sampling.min_p = std::stof(value);
|
||||||
}
|
}
|
||||||
).set_sparam());
|
).set_sparam());
|
||||||
|
add_opt(common_arg(
|
||||||
|
{"--top-nsigma"}, "N",
|
||||||
|
string_format("top-n-sigma sampling (default: %d, -1 = disabled)", params.sampling.top_n_sigma),
|
||||||
|
[](common_params & params, const std::string & value) {
|
||||||
|
params.sampling.top_n_sigma = std::stof(value);
|
||||||
|
}
|
||||||
|
).set_sparam());
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
{"--xtc-probability"}, "N",
|
{"--xtc-probability"}, "N",
|
||||||
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
|
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
|
||||||
|
@ -95,7 +95,6 @@ enum common_sampler_type {
|
|||||||
COMMON_SAMPLER_TYPE_XTC = 8,
|
COMMON_SAMPLER_TYPE_XTC = 8,
|
||||||
COMMON_SAMPLER_TYPE_INFILL = 9,
|
COMMON_SAMPLER_TYPE_INFILL = 9,
|
||||||
COMMON_SAMPLER_TYPE_PENALTIES = 10,
|
COMMON_SAMPLER_TYPE_PENALTIES = 10,
|
||||||
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// dimensionality reduction methods, used by cvector-generator
|
// dimensionality reduction methods, used by cvector-generator
|
||||||
@ -129,7 +128,7 @@ struct common_params_sampling {
|
|||||||
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
|
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
|
||||||
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
|
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
|
||||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||||
int32_t top_n_sigma = 2;
|
int32_t top_n_sigma = -1; // -1 = disabled
|
||||||
float mirostat_tau = 5.00f; // target entropy
|
float mirostat_tau = 5.00f; // target entropy
|
||||||
float mirostat_eta = 0.10f; // learning rate
|
float mirostat_eta = 0.10f; // learning rate
|
||||||
bool ignore_eos = false;
|
bool ignore_eos = false;
|
||||||
@ -148,7 +147,6 @@ struct common_params_sampling {
|
|||||||
COMMON_SAMPLER_TYPE_MIN_P,
|
COMMON_SAMPLER_TYPE_MIN_P,
|
||||||
COMMON_SAMPLER_TYPE_XTC,
|
COMMON_SAMPLER_TYPE_XTC,
|
||||||
COMMON_SAMPLER_TYPE_TEMPERATURE,
|
COMMON_SAMPLER_TYPE_TEMPERATURE,
|
||||||
COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||||
|
@ -131,11 +131,11 @@ std::string common_params_sampling::print() const {
|
|||||||
// Format the sampling parameters into `result`, one group per line.
// Each conversion in the format string consumes the argument at the same position
// in the list below, so the two must be kept in the same order.
snprintf(result, sizeof(result),
        "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
        "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
        "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %d, temp = %.3f\n"
        // last group: no trailing separator after the final value
        "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
        penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
        dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
        top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
        mirostat, mirostat_eta, mirostat_tau);
|
||||||
|
|
||||||
return std::string(result);
|
return std::string(result);
|
||||||
@ -162,6 +162,10 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||||||
params.logit_bias.data()));
|
params.logit_bias.data()));
|
||||||
|
|
||||||
if (params.mirostat == 0) {
|
if (params.mirostat == 0) {
|
||||||
|
if(params.top_n_sigma >= 0) {
|
||||||
|
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
||||||
|
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma));
|
||||||
|
} else {
|
||||||
for (const auto & cnstr : params.samplers) {
|
for (const auto & cnstr : params.samplers) {
|
||||||
switch (cnstr) {
|
switch (cnstr) {
|
||||||
case COMMON_SAMPLER_TYPE_DRY:
|
case COMMON_SAMPLER_TYPE_DRY:
|
||||||
@ -199,14 +203,11 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||||||
case COMMON_SAMPLER_TYPE_PENALTIES:
|
case COMMON_SAMPLER_TYPE_PENALTIES:
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||||
break;
|
break;
|
||||||
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
|
|
||||||
// llama_sampler_chain_add(result->chain, )
|
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma))
|
|
||||||
break;
|
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false && "unknown sampler type");
|
GGML_ASSERT(false && "unknown sampler type");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
||||||
} else if (params.mirostat == 1) {
|
} else if (params.mirostat == 1) {
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
||||||
@ -411,7 +412,6 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
|
|||||||
case COMMON_SAMPLER_TYPE_XTC: return 'x';
|
case COMMON_SAMPLER_TYPE_XTC: return 'x';
|
||||||
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
|
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
|
||||||
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
|
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
|
||||||
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
|
|
||||||
default : return '?';
|
default : return '?';
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -427,7 +427,6 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
|
|||||||
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
|
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
|
||||||
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
|
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
|
||||||
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
|
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
|
||||||
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
|
|
||||||
default : return "";
|
default : return "";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -443,7 +442,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
|||||||
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
|
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
|
||||||
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
|
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
|
||||||
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
|
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
|
||||||
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// since samplers names are written multiple ways
|
// since samplers names are written multiple ways
|
||||||
@ -458,9 +456,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
|||||||
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
||||||
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
|
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
|
||||||
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
||||||
{ "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
|
|
||||||
{ "top-nsigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
|
|
||||||
{ "top_nsigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
|
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<common_sampler_type> samplers;
|
std::vector<common_sampler_type> samplers;
|
||||||
@ -494,7 +489,6 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
|
|||||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
|
||||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
|
||||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
|
||||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<common_sampler_type> samplers;
|
std::vector<common_sampler_type> samplers;
|
||||||
|
@ -1133,6 +1133,9 @@ extern "C" {
|
|||||||
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
|
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
|
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
|
||||||
|
|
||||||
|
/// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641
|
||||||
|
LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(int32_t n);
|
||||||
|
|
||||||
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||||
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||||
|
@ -301,6 +301,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
|
|||||||
cur_p->size = k;
|
cur_p->size = k;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
static uint32_t get_rng_seed(uint32_t seed) {
|
static uint32_t get_rng_seed(uint32_t seed) {
|
||||||
if (seed == LLAMA_DEFAULT_SEED) {
|
if (seed == LLAMA_DEFAULT_SEED) {
|
||||||
// use system clock if std::random_device is not a true RNG
|
// use system clock if std::random_device is not a true RNG
|
||||||
@ -1657,35 +1658,65 @@ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler *
|
|||||||
|
|
||||||
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
|
const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
|
||||||
llama_sampler_top_n_sigma_impl(cur_p, ctx->n);
|
// 1. Find max logit: M
|
||||||
|
// 2. Find standard deviation of logits: sig
|
||||||
|
// 3. Create a mask where m[i] = 1 if ith logit >= M - n (sig), else m[i] = 0
|
||||||
|
// 4. Apply mask: ith logit itself if m[i]==1, else ith logit = -inf
|
||||||
|
// 5. p = softmax(l)
|
||||||
|
|
||||||
|
// find max logit and calculate mean
|
||||||
|
int32_t max = cur_p->data[0].logit;
|
||||||
|
int32_t logits_sum = 0;
|
||||||
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
|
if(cur_p->data[i].logit > max){
|
||||||
|
max = cur_p->data[i].logit;
|
||||||
|
}
|
||||||
|
logits_sum += cur_p->data[i].logit;
|
||||||
|
}
|
||||||
|
int32_t mean = logits_sum/cur_p->size;
|
||||||
|
|
||||||
|
// calculate standard deviation
|
||||||
|
int32_t acc = 0;
|
||||||
|
for(size_t i = 0; i < cur_p->size; ++i){
|
||||||
|
acc += (cur_p->data[i].logit - mean) * (cur_p->data[i].logit - mean);
|
||||||
|
}
|
||||||
|
int32_t std = sqrt(acc/cur_p->size);
|
||||||
|
|
||||||
|
//apply mask
|
||||||
|
for(size_t i = 0; i < cur_p->size; ++i){
|
||||||
|
if(cur_p->data[i].logit < max - (ctx->n * std)) {
|
||||||
|
cur_p->data[i].logit = -INFINITY;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
llama_sampler_softmax_impl(cur_p);
|
||||||
}
|
}
|
||||||
|
|
||||||
// static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
|
static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl){
|
||||||
// const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
|
const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx;
|
||||||
// return llama_sampler_init_top_k(ctx->k);
|
return llama_sampler_init_top_n_sigma(ctx->n);
|
||||||
// }
|
}
|
||||||
|
|
||||||
// static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
|
static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
|
||||||
// delete (llama_sampler_top_k *) smpl->ctx;
|
delete (llama_sampler_top_n_sigma *) smpl->ctx;
|
||||||
// }
|
}
|
||||||
|
|
||||||
// static struct llama_sampler_i llama_sampler_top_k_i = {
|
static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
|
||||||
// /* .name = */ llama_sampler_top_k_name,
|
/* .name = */ llama_sampler_top_n_sigma_name,
|
||||||
// /* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
// /* .apply = */ llama_sampler_top_k_apply,
|
/* .apply = */ llama_sampler_top_n_sigma_apply,
|
||||||
// /* .reset = */ nullptr,
|
/* .reset = */ nullptr,
|
||||||
// /* .clone = */ llama_sampler_top_k_clone,
|
/* .clone = */ llama_sampler_top_n_sigma_clone,
|
||||||
// /* .free = */ llama_sampler_top_k_free,
|
/* .free = */ llama_sampler_top_n_sigma_free,
|
||||||
// };
|
};
|
||||||
|
|
||||||
// struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
|
struct llama_sampler * llama_sampler_init_top_n_sigma(int32_t n) {
|
||||||
// return new llama_sampler {
|
return new llama_sampler {
|
||||||
// /* .iface = */ &llama_sampler_top_k_i,
|
/* .iface = */ &llama_sampler_top_n_sigma_i,
|
||||||
// /* .ctx = */ new llama_sampler_top_k {
|
/* .ctx = */ new llama_sampler_top_n_sigma {
|
||||||
// /* .k = */ k,
|
/* .n = */ n,
|
||||||
// },
|
},
|
||||||
// };
|
};
|
||||||
// }
|
}
|
||||||
|
|
||||||
// DRY
|
// DRY
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user