diff --git a/common/arg.cpp b/common/arg.cpp
index 27886b84e..6c6be6ef7 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -899,6 +899,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.min_p = std::stof(value);
         }
     ).set_sparam());
+    add_opt(common_arg(
+        {"--top-nsigma"}, "N",
+        string_format("top-n-sigma sampling (default: %d, -1 = disabled)", params.sampling.top_n_sigma),
+        [](common_params & params, const std::string & value) {
+            params.sampling.top_n_sigma = std::stoi(value);
+        }
+    ).set_sparam());
     add_opt(common_arg(
         {"--xtc-probability"}, "N",
         string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
diff --git a/common/common.h b/common/common.h
index 3671751e9..50f1def67 100644
--- a/common/common.h
+++ b/common/common.h
@@ -95,7 +95,6 @@ enum common_sampler_type {
    COMMON_SAMPLER_TYPE_XTC         = 8,
    COMMON_SAMPLER_TYPE_INFILL      = 9,
    COMMON_SAMPLER_TYPE_PENALTIES   = 10,
-   COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11
 };
 
 // dimensionality reduction methods, used by cvector-generator
@@ -129,7 +128,7 @@ struct common_params_sampling {
    int32_t dry_allowed_length = 2;     // tokens extending repetitions beyond this receive penalty
    int32_t dry_penalty_last_n = -1;    // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
    int32_t mirostat           = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-   int32_t top_n_sigma        = 2;
+   int32_t top_n_sigma        = -1;    // -1 = disabled
    float   mirostat_tau       = 5.00f; // target entropy
    float   mirostat_eta       = 0.10f; // learning rate
    bool    ignore_eos         = false;
@@ -148,7 +147,6 @@ struct common_params_sampling {
        COMMON_SAMPLER_TYPE_MIN_P,
        COMMON_SAMPLER_TYPE_XTC,
        COMMON_SAMPLER_TYPE_TEMPERATURE,
-       COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
    };
 
    std::string grammar; // optional BNF-like grammar to constrain sampling
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 15b08fe70..9d58c1680 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -131,11 +131,11 @@ std::string common_params_sampling::print() const {
     snprintf(result, sizeof(result),
             "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
             "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
-            "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
+            "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %d, temp = %.3f\n"
             "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
             penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
             dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
-            top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
+            top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
             mirostat, mirostat_eta, mirostat_tau);
 
     return std::string(result);
@@ -162,49 +162,50 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 params.logit_bias.data()));
 
     if (params.mirostat == 0) {
-        for (const auto & cnstr : params.samplers) {
-            switch (cnstr) {
-                case COMMON_SAMPLER_TYPE_DRY:
-                    {
-                        std::vector<const char *> c_breakers;
-                        c_breakers.reserve(params.dry_sequence_breakers.size());
-                        for (const auto & str : params.dry_sequence_breakers) {
-                            c_breakers.push_back(str.c_str());
-                        }
+        if (params.top_n_sigma >= 0) {
+            llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
+            llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma));
+        } else {
+            for (const auto & cnstr : params.samplers) {
+                switch (cnstr) {
+                    case COMMON_SAMPLER_TYPE_DRY:
+                        {
+                            std::vector<const char *> c_breakers;
+                            c_breakers.reserve(params.dry_sequence_breakers.size());
+                            for (const auto & str : params.dry_sequence_breakers) {
+                                c_breakers.push_back(str.c_str());
+                            }
 
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
-                    }
-                    break;
-                case COMMON_SAMPLER_TYPE_TOP_K:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
-                    break;
-                case COMMON_SAMPLER_TYPE_TOP_P:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
-                    break;
-                case COMMON_SAMPLER_TYPE_MIN_P:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
-                    break;
-                case COMMON_SAMPLER_TYPE_XTC:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
-                    break;
-                case COMMON_SAMPLER_TYPE_TYPICAL_P:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
-                    break;
-                case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
-                    break;
-                case COMMON_SAMPLER_TYPE_INFILL:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
-                    break;
-                case COMMON_SAMPLER_TYPE_PENALTIES:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
-                    break;
-                case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
-                    // llama_sampler_chain_add(result->chain, )
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma))
-                    break;
-                default:
-                    GGML_ASSERT(false && "unknown sampler type");
+                            llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
+                        }
+                        break;
+                    case COMMON_SAMPLER_TYPE_TOP_K:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
+                        break;
+                    case COMMON_SAMPLER_TYPE_TOP_P:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
+                        break;
+                    case COMMON_SAMPLER_TYPE_MIN_P:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
+                        break;
+                    case COMMON_SAMPLER_TYPE_XTC:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
+                        break;
+                    case COMMON_SAMPLER_TYPE_TYPICAL_P:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
+                        break;
+                    case COMMON_SAMPLER_TYPE_TEMPERATURE:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                        break;
+                    case COMMON_SAMPLER_TYPE_INFILL:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
+                        break;
+                    case COMMON_SAMPLER_TYPE_PENALTIES:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+                        break;
+                    default:
+                        GGML_ASSERT(false && "unknown sampler type");
+                }
             }
         }
         llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
@@ -411,7 +412,6 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
        case COMMON_SAMPLER_TYPE_XTC:         return 'x';
        case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
        case COMMON_SAMPLER_TYPE_PENALTIES:   return 'e';
-       case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
        default : return '?';
    }
 }
@@ -427,7 +427,6 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
        case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
        case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
        case COMMON_SAMPLER_TYPE_PENALTIES:   return "penalties";
-       case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
        default : return "";
    }
 }
@@ -443,7 +442,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
        { "xtc",         COMMON_SAMPLER_TYPE_XTC },
        { "infill",      COMMON_SAMPLER_TYPE_INFILL },
        { "penalties",   COMMON_SAMPLER_TYPE_PENALTIES },
-       { "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
    };
 
    // since samplers names are written multiple ways
@@ -458,9 +456,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
        { "typ",         COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "min-p",       COMMON_SAMPLER_TYPE_MIN_P },
        { "temp",        COMMON_SAMPLER_TYPE_TEMPERATURE },
-       { "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
-       { "top-nsigma",  COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
-       { "top_nsigma",  COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
    };
 
    std::vector<common_sampler_type> samplers;
@@ -494,7 +489,6 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES),   COMMON_SAMPLER_TYPE_PENALTIES },
-       { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA},
    };
 
    std::vector<common_sampler_type> samplers;
diff --git a/include/llama.h b/include/llama.h
index 0295a51fb..7100d1ab0 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1133,6 +1133,9 @@ extern "C" {
     /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
     LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
 
+    /// @details Top-nσ sampling as described in the paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641
+    LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(int32_t n);
+
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
     /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
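Note: the fragment below is a minimal sketch of how the sampler declared above might be used through the public API, mirroring the temperature -> top-n-sigma -> dist ordering that common/sampling.cpp sets up; the context handle and the concrete values (temperature 0.8, n = 1) are illustrative assumptions, not part of this change.

    #include "llama.h"

    // Sketch only: assumes an already-created llama_context with logits available.
    static llama_token sample_with_top_n_sigma(struct llama_context * ctx) {
        struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

        llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));               // scale logits by temperature first
        llama_sampler_chain_add(chain, llama_sampler_init_top_n_sigma(1));           // keep logits within 1 sigma of the max
        llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); // then sample from what remains

        const llama_token tok = llama_sampler_sample(chain, ctx, -1);                // sample from the last set of logits
        llama_sampler_free(chain);
        return tok;
    }
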
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index d4e5e9be7..1eb7df950 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -301,6 +301,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
     cur_p->size = k;
 }
 
+
 static uint32_t get_rng_seed(uint32_t seed) {
     if (seed == LLAMA_DEFAULT_SEED) {
         // use system clock if std::random_device is not a true RNG
@@ -1657,35 +1658,65 @@ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler *
 
 static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
-    llama_sampler_top_n_sigma_impl(cur_p, ctx->n);
+    // 1. find the maximum logit: M
+    // 2. find the standard deviation of the logits: sig
+    // 3. build a mask where m[i] = 1 if the ith logit >= M - n*sig, else m[i] = 0
+    // 4. apply the mask: keep the ith logit if m[i] == 1, else set it to -inf
+    // 5. p = softmax(l)
+
+    // find max logit and calculate mean (logits are floats, so accumulate in float to avoid truncation)
+    float max = cur_p->data[0].logit;
+    float logits_sum = 0;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].logit > max) {
+            max = cur_p->data[i].logit;
+        }
+        logits_sum += cur_p->data[i].logit;
+    }
+    float mean = logits_sum/cur_p->size;
+
+    // calculate standard deviation
+    float acc = 0;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        acc += (cur_p->data[i].logit - mean) * (cur_p->data[i].logit - mean);
+    }
+    float std = sqrtf(acc/cur_p->size);
+
+    // apply mask: drop candidates more than n standard deviations below the max logit
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].logit < max - (ctx->n * std)) {
+            cur_p->data[i].logit = -INFINITY;
+        }
+    }
+    llama_sampler_softmax_impl(cur_p);
 }
 
-// static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
-//     const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
-//     return llama_sampler_init_top_k(ctx->k);
-// }
+static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx;
+    return llama_sampler_init_top_n_sigma(ctx->n);
+}
 
-// static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
-//     delete (llama_sampler_top_k *) smpl->ctx;
-// }
+static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_top_n_sigma *) smpl->ctx;
+}
 
-// static struct llama_sampler_i llama_sampler_top_k_i = {
-//     /* .name   = */ llama_sampler_top_k_name,
-//     /* .accept = */ nullptr,
-//     /* .apply  = */ llama_sampler_top_k_apply,
-//     /* .reset  = */ nullptr,
-//     /* .clone  = */ llama_sampler_top_k_clone,
-//     /* .free   = */ llama_sampler_top_k_free,
-// };
+static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
+    /* .name   = */ llama_sampler_top_n_sigma_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_top_n_sigma_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_top_n_sigma_clone,
+    /* .free   = */ llama_sampler_top_n_sigma_free,
+};
 
-// struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
-//     return new llama_sampler {
-//         /* .iface = */ &llama_sampler_top_k_i,
-//         /* .ctx   = */ new llama_sampler_top_k {
-//             /* .k = */ k,
-//         },
-//     };
-// }
+struct llama_sampler * llama_sampler_init_top_n_sigma(int32_t n) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_top_n_sigma_i,
+        /* .ctx   = */ new llama_sampler_top_n_sigma {
+            /* .n = */ n,
+        },
+    };
+}
 
 // DRY
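As a sanity check on the statistics computed in llama_sampler_top_n_sigma_apply, here is a small standalone sketch (toy logits and n chosen purely for illustration; not part of the patch). With logits {2.0, 1.0, 0.0, -4.0} the mean is -0.25 and sigma ≈ 2.28, so with n = 1 the cutoff is max - n*sigma ≈ -0.28 and only the first three candidates survive the mask. On the CLI side, the same code path is enabled with the new --top-nsigma flag added in common/arg.cpp above (negative values leave it disabled).

    // Standalone sketch of the top-n-sigma cutoff on made-up values; not llama.cpp code.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        const std::vector<float> logits = {2.0f, 1.0f, 0.0f, -4.0f}; // toy logits
        const float n = 1.0f;                                        // sigma multiplier

        float max = logits[0];
        float sum = 0.0f;
        for (float l : logits) {
            max  = std::max(max, l);
            sum += l;
        }
        const float mean = sum / logits.size();

        float acc = 0.0f;
        for (float l : logits) {
            acc += (l - mean) * (l - mean);
        }
        const float sigma = std::sqrt(acc / logits.size());

        // a candidate survives if its logit is within n standard deviations of the maximum
        for (float l : logits) {
            std::printf("logit %+5.2f -> %s\n", l, l >= max - n * sigma ? "kept" : "masked");
        }
        return 0;
    }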