initial sampling changes:

This commit is contained in:
VJHack 2025-01-09 23:04:28 -06:00
parent f7cd13301c
commit ddc3c2208a
3 changed files with 64 additions and 8 deletions

View File

@ -95,6 +95,7 @@ enum common_sampler_type {
COMMON_SAMPLER_TYPE_XTC = 8, COMMON_SAMPLER_TYPE_XTC = 8,
COMMON_SAMPLER_TYPE_INFILL = 9, COMMON_SAMPLER_TYPE_INFILL = 9,
COMMON_SAMPLER_TYPE_PENALTIES = 10, COMMON_SAMPLER_TYPE_PENALTIES = 10,
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11
}; };
// dimensionality reduction methods, used by cvector-generator // dimensionality reduction methods, used by cvector-generator
@ -128,6 +129,7 @@ struct common_params_sampling {
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
int32_t top_n_sigma = 2;
float mirostat_tau = 5.00f; // target entropy float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate float mirostat_eta = 0.10f; // learning rate
bool ignore_eos = false; bool ignore_eos = false;
@ -146,6 +148,7 @@ struct common_params_sampling {
COMMON_SAMPLER_TYPE_MIN_P, COMMON_SAMPLER_TYPE_MIN_P,
COMMON_SAMPLER_TYPE_XTC, COMMON_SAMPLER_TYPE_XTC,
COMMON_SAMPLER_TYPE_TEMPERATURE, COMMON_SAMPLER_TYPE_TEMPERATURE,
COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
}; };
std::string grammar; // optional BNF-like grammar to constrain sampling std::string grammar; // optional BNF-like grammar to constrain sampling

View File

@ -199,6 +199,10 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
case COMMON_SAMPLER_TYPE_PENALTIES: case COMMON_SAMPLER_TYPE_PENALTIES:
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break; break;
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
// llama_sampler_chain_add(result->chain, )
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma))
break;
default: default:
GGML_ASSERT(false && "unknown sampler type"); GGML_ASSERT(false && "unknown sampler type");
} }
@ -407,6 +411,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_XTC: return 'x'; case COMMON_SAMPLER_TYPE_XTC: return 'x';
case COMMON_SAMPLER_TYPE_INFILL: return 'i'; case COMMON_SAMPLER_TYPE_INFILL: return 'i';
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e'; case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
default : return '?'; default : return '?';
} }
} }
@ -422,6 +427,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_XTC: return "xtc"; case COMMON_SAMPLER_TYPE_XTC: return "xtc";
case COMMON_SAMPLER_TYPE_INFILL: return "infill"; case COMMON_SAMPLER_TYPE_INFILL: return "infill";
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties"; case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
default : return ""; default : return "";
} }
} }
@ -437,6 +443,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "xtc", COMMON_SAMPLER_TYPE_XTC }, { "xtc", COMMON_SAMPLER_TYPE_XTC },
{ "infill", COMMON_SAMPLER_TYPE_INFILL }, { "infill", COMMON_SAMPLER_TYPE_INFILL },
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES }, { "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
}; };
// since samplers names are written multiple ways // since samplers names are written multiple ways
@ -451,6 +458,9 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P }, { "min-p", COMMON_SAMPLER_TYPE_MIN_P },
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE }, { "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
{ "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ "top-nsigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ "top_nsigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
}; };
std::vector<common_sampler_type> samplers; std::vector<common_sampler_type> samplers;
@ -484,6 +494,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA}
}; };
std::vector<common_sampler_type> samplers; std::vector<common_sampler_type> samplers;

View File

@ -1645,6 +1645,48 @@ struct llama_sampler * llama_sampler_init_penalties(
}; };
} }
// top-n-sigma
// Context object stored behind a top-n-sigma sampler's ctx pointer.
struct llama_sampler_top_n_sigma {
    // Parameter handed to llama_sampler_top_n_sigma_impl on every apply call.
    // It is const, so it is fixed once the context has been constructed.
    // NOTE(review): presumably the number of standard deviations to keep — confirm against the impl.
    int32_t const n;
};
// Returns the display name of the top-n-sigma sampler.
// The sampler argument is ignored; the same constant string is returned for every instance.
static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
    static const char sampler_name[] = "top-n-sigma";
    return sampler_name;
}
// Apply callback for the top-n-sigma sampler.
// Reads the sampler's context (a llama_sampler_top_n_sigma) and forwards the
// candidate array together with the stored n to llama_sampler_top_n_sigma_impl.
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    // ctx is reinterpreted as this sampler's own context type; it is only read here
    const auto & state = *(llama_sampler_top_n_sigma *) smpl->ctx;

    llama_sampler_top_n_sigma_impl(cur_p, state.n);
}
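// NOTE(review): the commented-out block below is a copy of the top-k sampler's
// clone/free/interface/init scaffolding, apparently kept as a template. No top-n-sigma
// equivalents are visible in this change — confirm that llama_sampler_init_top_n_sigma
// (and its clone/free/interface table) is defined before this sampler is enabled.
// static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {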
// const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
// return llama_sampler_init_top_k(ctx->k);
// }
// static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
// delete (llama_sampler_top_k *) smpl->ctx;
// }
// static struct llama_sampler_i llama_sampler_top_k_i = {
// /* .name = */ llama_sampler_top_k_name,
// /* .accept = */ nullptr,
// /* .apply = */ llama_sampler_top_k_apply,
// /* .reset = */ nullptr,
// /* .clone = */ llama_sampler_top_k_clone,
// /* .free = */ llama_sampler_top_k_free,
// };
// struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
// return new llama_sampler {
// /* .iface = */ &llama_sampler_top_k_i,
// /* .ctx = */ new llama_sampler_top_k {
// /* .k = */ k,
// },
// };
// }
// DRY // DRY
struct llama_sampler_dry { struct llama_sampler_dry {