2023-10-11 21:35:46 +02:00
|
|
|
#include "sampling.h"
|
|
|
|
|
2024-08-05 09:08:25 +02:00
|
|
|
#include "common.h"
|
2023-10-18 15:21:57 +02:00
|
|
|
|
2024-08-05 09:08:25 +02:00
|
|
|
std::string gpt_sampling_params::print_all() const {
|
|
|
|
char result[1024];
|
2023-10-18 15:21:57 +02:00
|
|
|
|
2024-08-05 09:08:25 +02:00
|
|
|
snprintf(result, sizeof(result),
|
|
|
|
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
|
|
|
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
|
|
|
|
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
|
|
|
|
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
|
|
|
|
top_k, tfs_z, top_p, min_p, typ_p, temp,
|
|
|
|
mirostat, mirostat_eta, mirostat_tau);
|
2024-03-13 19:10:40 +01:00
|
|
|
|
2024-08-05 09:08:25 +02:00
|
|
|
return std::string(result);
|
|
|
|
}
|
2023-10-18 15:21:57 +02:00
|
|
|
|
2024-08-05 09:08:25 +02:00
|
|
|
std::string gpt_sampling_params::print_samplers() const {
|
|
|
|
std::string result = "CFG -> Penalties ";
|
|
|
|
if (mirostat == 0) {
|
|
|
|
for (const auto & sampler : samplers) {
|
|
|
|
const auto name = llama_sampling_type_to_str(sampler);
|
|
|
|
if (!name.empty()) {
|
|
|
|
result += "-> " + name + " ";
|
|
|
|
}
|
2024-06-25 21:07:28 +02:00
|
|
|
}
|
2024-08-05 09:08:25 +02:00
|
|
|
} else {
|
|
|
|
result += "-> mirostat ";
|
2023-10-11 21:35:46 +02:00
|
|
|
}
|
2023-10-18 15:21:57 +02:00
|
|
|
|
2024-08-05 09:08:25 +02:00
|
|
|
return result;
|
|
|
|
}
|
2024-05-07 23:07:58 +02:00
|
|
|
|
2024-08-05 09:08:25 +02:00
|
|
|
// Number of elements of a fixed-size C array; intentionally fails to compile if handed a pointer,
// so the capacity check below can never silently degrade to sizeof(pointer) arithmetic.
template <typename T, size_t N>
static size_t sampling_array_capacity(const T (&)[N]) {
    return N;
}

// Build a llama_sampling instance from the common-layer parameters.
// Every scalar setting is copied into a C-API llama_sampling_params (starting from the library
// defaults), then the sampler sequence is copied. After construction, the grammar (rooted at
// "root") and the logit biases are attached.
struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params) {
    llama_sampling_params lparams = llama_sampling_default_params();

    lparams.seed              = params.seed;
    lparams.n_prev            = params.n_prev;
    lparams.n_probs           = params.n_probs;
    lparams.min_keep          = params.min_keep;
    lparams.top_k             = params.top_k;
    lparams.top_p             = params.top_p;
    lparams.min_p             = params.min_p;
    lparams.tfs_z             = params.tfs_z;
    lparams.typ_p             = params.typ_p;
    lparams.temp              = params.temp;
    lparams.dynatemp_range    = params.dynatemp_range;
    lparams.dynatemp_exponent = params.dynatemp_exponent;
    lparams.penalty_last_n    = params.penalty_last_n;
    lparams.penalty_repeat    = params.penalty_repeat;
    lparams.penalty_freq      = params.penalty_freq;
    lparams.penalty_present   = params.penalty_present;
    lparams.mirostat          = params.mirostat;
    lparams.mirostat_tau      = params.mirostat_tau;
    lparams.mirostat_eta      = params.mirostat_eta;
    lparams.penalize_nl       = params.penalize_nl;
    lparams.ignore_eos        = params.ignore_eos;

    // lparams.samplers is a fixed-size array inside the C struct, whereas params.samplers is a vector
    // of arbitrary length (llama_sampling_types_from_chars yields one entry per recognized character).
    // Copy at most as many entries as the array can hold instead of writing past its end.
    const size_t n_max  = sampling_array_capacity(lparams.samplers);
    const size_t n_copy = std::min(params.samplers.size(), n_max);
    if (params.samplers.size() > n_max) {
        fprintf(stderr, "%s: warning: %zu samplers requested, only the first %zu will be used\n",
                __func__, params.samplers.size(), n_max);
    }

    lparams.n_samplers = n_copy;
    for (size_t i = 0; i < n_copy; i++) {
        lparams.samplers[i] = params.samplers[i];
    }

    struct llama_sampling * result = llama_sampling_init(model, lparams);

    llama_sampling_set_grammar   (result, params.grammar.c_str(), "root");
    llama_sampling_set_logit_bias(result, params.logit_bias.size(), params.logit_bias.data());

    return result;
}
|
|
|
|
|
2024-08-05 09:08:25 +02:00
|
|
|
// Make dst an independent copy of src, releasing the sampler dst previously pointed to (if any).
//
// Copying a sampler onto itself is a no-op. Otherwise the new copy is created before the old
// sampler is freed. The previous free-then-copy order read freed memory whenever src and dst
// referred to the same object.
void llama_sampling_cp(llama_sampling * src, llama_sampling *& dst) {
    if (src == dst) {
        return;
    }

    llama_sampling * copy = llama_sampling_cp(src);

    if (dst) {
        llama_sampling_free(dst);
    }

    dst = copy;
}
|
|
|
|
|
2024-08-05 09:08:25 +02:00
|
|
|
// Sample a token from the logits of output idx in ctx, honouring the sampler's grammar.
//
// Fast path: sample with no grammar constraints, then check only the chosen token against the
// grammar. Slow path, taken when the grammar rejects that token: reload the logits, apply the
// grammar to the whole candidate set, and sample again.
llama_token llama_sampling_sample(
        struct llama_sampling * smpl,
        struct llama_context * ctx,
        int idx) {
    // (re)load the logits of output idx into the sampler
    const auto load_logits = [&]() {
        llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx));
    };

    load_logits();

    const llama_token candidate = llama_sampling_sample(smpl, nullptr);

    // run the grammar over a one-element candidate list holding only the sampled token
    llama_token_data       probe      = { candidate, 1.0f, 0.0f };
    llama_token_data_array probe_list = { &probe, 1, false };

    llama_sampling_grammar(smpl, &probe_list);

    // the grammar marks a rejected token by setting its logit to -INFINITY
    if (probe_list.data[0].logit != -INFINITY) {
        return candidate;
    }

    // rejected: start again from freshly loaded logits (the first pass may have altered the
    // sampler's working state), constrain them with the grammar, and sample once more
    load_logits();
    llama_sampling_grammar(smpl, nullptr);

    return llama_sampling_sample(smpl, nullptr);
}
|
|
|
|
|
2024-08-05 09:08:25 +02:00
|
|
|
// Concatenate the text pieces of up to n tokens from the sampler's history, capped at the
// number of tokens the sampler currently remembers. Returns "" when that count is <= 0.
// Pieces are appended from history index n-1 down to index 0 (presumably index 0 is the most
// recent token, making the result chronological; confirm against llama_sampling_prev).
std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx_main, int n) {
    n = std::min(n, llama_sampling_n_prev(smpl));

    if (n <= 0) {
        return "";
    }

    std::string text;
    text.reserve(8*n); // rough guess of ~8 bytes per token piece; TODO: derive this from the vocab

    int back = n;
    while (back-- > 0) {
        const llama_token tok = llama_sampling_prev(smpl, back);

        GGML_ASSERT(tok != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");

        text += llama_token_to_piece(ctx_main, tok);
    }

    return text;
}
|
|
|
|
|
2024-08-05 09:08:25 +02:00
|
|
|
// Map a sampler type to its single-character code (the codes llama_sampling_types_from_chars
// looks up). Types without a code map to '?'.
char llama_sampling_type_to_chr(llama_sampler_type sampler) {
    static const struct {
        llama_sampler_type type;
        char               code;
    } codes[] = {
        { LLAMA_SAMPLER_TYPE_TOP_K,       'k' },
        { LLAMA_SAMPLER_TYPE_TFS_Z,       'f' },
        { LLAMA_SAMPLER_TYPE_TYPICAL_P,   'y' },
        { LLAMA_SAMPLER_TYPE_TOP_P,       'p' },
        { LLAMA_SAMPLER_TYPE_MIN_P,       'm' },
        { LLAMA_SAMPLER_TYPE_TEMPERATURE, 't' },
    };

    for (const auto & entry : codes) {
        if (entry.type == sampler) {
            return entry.code;
        }
    }

    return '?';
}
|
|
|
|
|
2024-08-05 09:08:25 +02:00
|
|
|
// Map a sampler type to its canonical name (the names llama_sampling_types_from_names always
// accepts). Types without a name map to the empty string.
std::string llama_sampling_type_to_str(llama_sampler_type sampler) {
    const char * label = "";

    switch (sampler) {
        case LLAMA_SAMPLER_TYPE_TOP_K:       label = "top_k";       break;
        case LLAMA_SAMPLER_TYPE_TFS_Z:       label = "tfs_z";       break;
        case LLAMA_SAMPLER_TYPE_TYPICAL_P:   label = "typ_p";       break;
        case LLAMA_SAMPLER_TYPE_TOP_P:       label = "top_p";       break;
        case LLAMA_SAMPLER_TYPE_MIN_P:       label = "min_p";       break;
        case LLAMA_SAMPLER_TYPE_TEMPERATURE: label = "temperature"; break;
        default:                                                    break;
    }

    return label;
}
|
|
|
|
|
|
|
|
std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
|
|
|
|
std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
|
2024-08-05 09:08:25 +02:00
|
|
|
{ "top_k", LLAMA_SAMPLER_TYPE_TOP_K },
|
|
|
|
{ "top_p", LLAMA_SAMPLER_TYPE_TOP_P },
|
|
|
|
{ "typ_p", LLAMA_SAMPLER_TYPE_TYPICAL_P },
|
|
|
|
{ "min_p", LLAMA_SAMPLER_TYPE_MIN_P },
|
|
|
|
{ "tfs_z", LLAMA_SAMPLER_TYPE_TFS_Z },
|
|
|
|
{ "temperature", LLAMA_SAMPLER_TYPE_TEMPERATURE },
|
2024-05-22 19:04:20 +02:00
|
|
|
};
|
|
|
|
|
|
|
|
// since samplers names are written multiple ways
|
|
|
|
// make it ready for both system names and input names
|
|
|
|
std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
|
2024-08-05 09:08:25 +02:00
|
|
|
{ "top-k", LLAMA_SAMPLER_TYPE_TOP_K },
|
|
|
|
{ "top-p", LLAMA_SAMPLER_TYPE_TOP_P },
|
|
|
|
{ "nucleus", LLAMA_SAMPLER_TYPE_TOP_P },
|
|
|
|
{ "typical-p", LLAMA_SAMPLER_TYPE_TYPICAL_P },
|
|
|
|
{ "typical", LLAMA_SAMPLER_TYPE_TYPICAL_P },
|
|
|
|
{ "typ-p", LLAMA_SAMPLER_TYPE_TYPICAL_P },
|
|
|
|
{ "typ", LLAMA_SAMPLER_TYPE_TYPICAL_P },
|
|
|
|
{ "min-p", LLAMA_SAMPLER_TYPE_MIN_P },
|
|
|
|
{ "tfs-z", LLAMA_SAMPLER_TYPE_TFS_Z },
|
|
|
|
{ "tfs", LLAMA_SAMPLER_TYPE_TFS_Z },
|
|
|
|
{ "temp", LLAMA_SAMPLER_TYPE_TEMPERATURE },
|
2024-05-22 19:04:20 +02:00
|
|
|
};
|
|
|
|
|
2024-08-05 09:08:25 +02:00
|
|
|
std::vector<llama_sampler_type> samplers;
|
|
|
|
samplers.reserve(names.size());
|
|
|
|
|
|
|
|
for (const auto & name : names) {
|
|
|
|
auto sampler = sampler_canonical_name_map.find(name);
|
|
|
|
if (sampler != sampler_canonical_name_map.end()) {
|
|
|
|
samplers.push_back(sampler->second);
|
|
|
|
} else {
|
|
|
|
if (allow_alt_names) {
|
|
|
|
sampler = sampler_alt_name_map.find(name);
|
|
|
|
if (sampler != sampler_alt_name_map.end()) {
|
|
|
|
samplers.push_back(sampler->second);
|
2024-05-22 19:04:20 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2024-08-05 09:08:25 +02:00
|
|
|
|
|
|
|
return samplers;
|
2024-05-22 19:04:20 +02:00
|
|
|
}
|
|
|
|
|
2024-08-05 09:08:25 +02:00
|
|
|
std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & chars) {
|
2024-05-22 19:04:20 +02:00
|
|
|
std::unordered_map<char, llama_sampler_type> sampler_name_map {
|
2024-08-05 09:08:25 +02:00
|
|
|
{ llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TOP_K), LLAMA_SAMPLER_TYPE_TOP_K },
|
|
|
|
{ llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TFS_Z), LLAMA_SAMPLER_TYPE_TFS_Z },
|
|
|
|
{ llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TYPICAL_P), LLAMA_SAMPLER_TYPE_TYPICAL_P },
|
|
|
|
{ llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TOP_P), LLAMA_SAMPLER_TYPE_TOP_P },
|
|
|
|
{ llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_MIN_P), LLAMA_SAMPLER_TYPE_MIN_P },
|
|
|
|
{ llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TEMPERATURE), LLAMA_SAMPLER_TYPE_TEMPERATURE }
|
2024-05-22 19:04:20 +02:00
|
|
|
};
|
|
|
|
|
2024-08-05 09:08:25 +02:00
|
|
|
std::vector<llama_sampler_type> samplers;
|
|
|
|
samplers.reserve(chars.size());
|
2023-10-18 15:21:57 +02:00
|
|
|
|
2024-08-05 09:08:25 +02:00
|
|
|
for (const auto & c : chars) {
|
|
|
|
const auto sampler = sampler_name_map.find(c);
|
|
|
|
if (sampler != sampler_name_map.end()) {
|
|
|
|
samplers.push_back(sampler->second);
|
2023-10-11 21:35:46 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-08-05 09:08:25 +02:00
|
|
|
return samplers;
|
2023-10-18 15:21:57 +02:00
|
|
|
}
|