From ddc3c2208acf0fb5a05f28205f1291486f922822 Mon Sep 17 00:00:00 2001 From: VJHack Date: Thu, 9 Jan 2025 23:04:28 -0600 Subject: [PATCH] initial sampling changes: --- common/common.h | 3 +++ common/sampling.cpp | 27 +++++++++++++++++++-------- src/llama-sampling.cpp | 42 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 8 deletions(-) diff --git a/common/common.h b/common/common.h index 0d452cf0f..3671751e9 100644 --- a/common/common.h +++ b/common/common.h @@ -95,6 +95,7 @@ enum common_sampler_type { COMMON_SAMPLER_TYPE_XTC = 8, COMMON_SAMPLER_TYPE_INFILL = 9, COMMON_SAMPLER_TYPE_PENALTIES = 10, + COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11 }; // dimensionality reduction methods, used by cvector-generator @@ -128,6 +129,7 @@ struct common_params_sampling { int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + int32_t top_n_sigma = 2; float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate bool ignore_eos = false; @@ -146,6 +148,7 @@ struct common_params_sampling { COMMON_SAMPLER_TYPE_MIN_P, COMMON_SAMPLER_TYPE_XTC, COMMON_SAMPLER_TYPE_TEMPERATURE, + COMMON_SAMPLER_TYPE_TOP_N_SIGMA, }; std::string grammar; // optional BNF-like grammar to constrain sampling diff --git a/common/sampling.cpp b/common/sampling.cpp index e83a971c7..15b08fe70 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -176,28 +176,32 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co } break; case COMMON_SAMPLER_TYPE_TOP_K: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); + llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); break; case COMMON_SAMPLER_TYPE_TOP_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); + llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_MIN_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); + llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_XTC: - llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); + llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); break; case COMMON_SAMPLER_TYPE_TYPICAL_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); + llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_TEMPERATURE: - llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); break; case COMMON_SAMPLER_TYPE_INFILL: - llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model)); + llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model)); break; case COMMON_SAMPLER_TYPE_PENALTIES: - llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + break; + case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: + // llama_sampler_chain_add(result->chain, ) + llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma)) break; default: GGML_ASSERT(false && "unknown sampler type"); @@ -407,6 +411,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) { case COMMON_SAMPLER_TYPE_XTC: return 'x'; case COMMON_SAMPLER_TYPE_INFILL: return 'i'; case COMMON_SAMPLER_TYPE_PENALTIES: return 'e'; + case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's'; default : return '?'; } } @@ -422,6 +427,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) { case COMMON_SAMPLER_TYPE_XTC: return "xtc"; case COMMON_SAMPLER_TYPE_INFILL: return "infill"; case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties"; + case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma"; default : return ""; } } @@ -437,6 +443,7 @@ std::vector common_sampler_types_from_names(const std::vect { "xtc", COMMON_SAMPLER_TYPE_XTC }, { "infill", COMMON_SAMPLER_TYPE_INFILL }, { "penalties", COMMON_SAMPLER_TYPE_PENALTIES }, + { "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, }; // since samplers names are written multiple ways @@ -451,6 +458,9 @@ std::vector common_sampler_types_from_names(const std::vect { "typ", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "min-p", COMMON_SAMPLER_TYPE_MIN_P }, { "temp", COMMON_SAMPLER_TYPE_TEMPERATURE }, + { "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, + { "top-nsigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, + { "top_nsigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, }; std::vector samplers; @@ -484,6 +494,7 @@ std::vector common_sampler_types_from_chars(const std::stri { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA} }; std::vector samplers; diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index ef5a576cc..d4e5e9be7 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1645,6 +1645,48 @@ struct llama_sampler * llama_sampler_init_penalties( }; } +// top-n-sigma + +struct llama_sampler_top_n_sigma { + const int32_t n; +}; + +static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) { + return "top-n-sigma"; +} + +static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx; + llama_sampler_top_n_sigma_impl(cur_p, ctx->n); +} + +// static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) { +// const auto * ctx = (const llama_sampler_top_k *) smpl->ctx; +// return llama_sampler_init_top_k(ctx->k); +// } + +// static void llama_sampler_top_k_free(struct llama_sampler * smpl) { +// delete (llama_sampler_top_k *) smpl->ctx; +// } + +// static struct llama_sampler_i llama_sampler_top_k_i = { +// /* .name = */ llama_sampler_top_k_name, +// /* .accept = */ nullptr, +// /* .apply = */ llama_sampler_top_k_apply, +// /* .reset = */ nullptr, +// /* .clone = */ llama_sampler_top_k_clone, +// /* .free = */ llama_sampler_top_k_free, +// }; + +// struct llama_sampler * llama_sampler_init_top_k(int32_t k) { +// return new llama_sampler { +// /* .iface = */ &llama_sampler_top_k_i, +// /* .ctx = */ new llama_sampler_top_k { +// /* .k = */ k, +// }, +// }; +// } + // DRY struct llama_sampler_dry {