mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-24 02:19:18 +01:00
sampling : refactor + optimize penalties sampler (#10803)
* sampling : refactor + optimize penalties sampler ggml-ci * common : apply ignore_eos as logit bias ggml-ci * batched : remove penalties sampler * params : allow penalty_last_n == -1 to be equal to context size ggml-ci * common : by default, move the penalties at the end of the sampling chain ggml-ci * common : ignore all EOG tokens Co-authored-by: Diego Devesa <slarengh@gmail.com> * common : move back the penalties at the front of the sampling chain ggml-ci * readme : restore hint about --ignore-eos flag [no ci] * llama : minor ggml-ci * webui : update --------- Co-authored-by: Diego Devesa <slarengh@gmail.com>
This commit is contained in:
parent
4ddd199f6f
commit
644fd71b44
@ -855,13 +855,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||||||
params.sampling.ignore_eos = true;
|
params.sampling.ignore_eos = true;
|
||||||
}
|
}
|
||||||
).set_sparam());
|
).set_sparam());
|
||||||
add_opt(common_arg(
|
|
||||||
{"--penalize-nl"},
|
|
||||||
string_format("penalize newline tokens (default: %s)", params.sampling.penalize_nl ? "true" : "false"),
|
|
||||||
[](common_params & params) {
|
|
||||||
params.sampling.penalize_nl = true;
|
|
||||||
}
|
|
||||||
).set_sparam());
|
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
{"--temp"}, "N",
|
{"--temp"}, "N",
|
||||||
string_format("temperature (default: %.1f)", (double)params.sampling.temp),
|
string_format("temperature (default: %.1f)", (double)params.sampling.temp),
|
||||||
@ -916,6 +909,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||||||
{"--repeat-last-n"}, "N",
|
{"--repeat-last-n"}, "N",
|
||||||
string_format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sampling.penalty_last_n),
|
string_format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sampling.penalty_last_n),
|
||||||
[](common_params & params, int value) {
|
[](common_params & params, int value) {
|
||||||
|
if (value < -1) {
|
||||||
|
throw std::runtime_error(string_format("error: invalid repeat-last-n = %d\n", value));
|
||||||
|
}
|
||||||
params.sampling.penalty_last_n = value;
|
params.sampling.penalty_last_n = value;
|
||||||
params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
|
params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
|
||||||
}
|
}
|
||||||
@ -970,6 +966,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||||||
{"--dry-penalty-last-n"}, "N",
|
{"--dry-penalty-last-n"}, "N",
|
||||||
string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sampling.dry_penalty_last_n),
|
string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sampling.dry_penalty_last_n),
|
||||||
[](common_params & params, int value) {
|
[](common_params & params, int value) {
|
||||||
|
if (value < -1) {
|
||||||
|
throw std::runtime_error(string_format("error: invalid dry-penalty-last-n = %d\n", value));
|
||||||
|
}
|
||||||
params.sampling.dry_penalty_last_n = value;
|
params.sampling.dry_penalty_last_n = value;
|
||||||
}
|
}
|
||||||
).set_sparam());
|
).set_sparam());
|
||||||
|
@ -940,6 +940,25 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||||||
params.sampling.ignore_eos = false;
|
params.sampling.ignore_eos = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (params.sampling.ignore_eos) {
|
||||||
|
for (llama_token i = 0; i < llama_n_vocab(model); i++) {
|
||||||
|
if (llama_token_is_eog(model, i)) {
|
||||||
|
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
|
||||||
|
params.sampling.logit_bias.push_back({i, -INFINITY});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.sampling.penalty_last_n == -1) {
|
||||||
|
LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||||
|
params.sampling.penalty_last_n = llama_n_ctx(lctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.sampling.dry_penalty_last_n == -1) {
|
||||||
|
LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||||
|
params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
|
||||||
|
}
|
||||||
|
|
||||||
if (params.warmup) {
|
if (params.warmup) {
|
||||||
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
|
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
|
||||||
|
|
||||||
|
@ -95,6 +95,7 @@ enum common_sampler_type {
|
|||||||
COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
|
COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
|
||||||
COMMON_SAMPLER_TYPE_XTC = 8,
|
COMMON_SAMPLER_TYPE_XTC = 8,
|
||||||
COMMON_SAMPLER_TYPE_INFILL = 9,
|
COMMON_SAMPLER_TYPE_INFILL = 9,
|
||||||
|
COMMON_SAMPLER_TYPE_PENALTIES = 10,
|
||||||
};
|
};
|
||||||
|
|
||||||
// dimensionality reduction methods, used by cvector-generator
|
// dimensionality reduction methods, used by cvector-generator
|
||||||
@ -130,7 +131,6 @@ struct common_params_sampling {
|
|||||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||||
float mirostat_tau = 5.00f; // target entropy
|
float mirostat_tau = 5.00f; // target entropy
|
||||||
float mirostat_eta = 0.10f; // learning rate
|
float mirostat_eta = 0.10f; // learning rate
|
||||||
bool penalize_nl = false; // consider newlines as a repeatable token
|
|
||||||
bool ignore_eos = false;
|
bool ignore_eos = false;
|
||||||
bool no_perf = false; // disable performance metrics
|
bool no_perf = false; // disable performance metrics
|
||||||
bool timing_per_token = false;
|
bool timing_per_token = false;
|
||||||
@ -139,6 +139,7 @@ struct common_params_sampling {
|
|||||||
|
|
||||||
|
|
||||||
std::vector<enum common_sampler_type> samplers = {
|
std::vector<enum common_sampler_type> samplers = {
|
||||||
|
COMMON_SAMPLER_TYPE_PENALTIES,
|
||||||
COMMON_SAMPLER_TYPE_DRY,
|
COMMON_SAMPLER_TYPE_DRY,
|
||||||
COMMON_SAMPLER_TYPE_TOP_K,
|
COMMON_SAMPLER_TYPE_TOP_K,
|
||||||
COMMON_SAMPLER_TYPE_TYPICAL_P,
|
COMMON_SAMPLER_TYPE_TYPICAL_P,
|
||||||
@ -194,9 +195,11 @@ struct common_params {
|
|||||||
|
|
||||||
// offload params
|
// offload params
|
||||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||||
|
|
||||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
|
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
|
||||||
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
||||||
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
|
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
|
||||||
|
|
||||||
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
|
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
|
||||||
|
|
||||||
struct cpu_params cpuparams;
|
struct cpu_params cpuparams;
|
||||||
|
@ -161,26 +161,14 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||||||
params.logit_bias.size(),
|
params.logit_bias.size(),
|
||||||
params.logit_bias.data()));
|
params.logit_bias.data()));
|
||||||
|
|
||||||
llama_sampler_chain_add(result->chain,
|
|
||||||
llama_sampler_init_penalties(
|
|
||||||
llama_n_vocab (model),
|
|
||||||
llama_token_eos(model),
|
|
||||||
llama_token_nl (model),
|
|
||||||
params.penalty_last_n,
|
|
||||||
params.penalty_repeat,
|
|
||||||
params.penalty_freq,
|
|
||||||
params.penalty_present,
|
|
||||||
params.penalize_nl,
|
|
||||||
params.ignore_eos));
|
|
||||||
|
|
||||||
if (params.mirostat == 0) {
|
if (params.mirostat == 0) {
|
||||||
for (const auto & cnstr : params.samplers) {
|
for (const auto & cnstr : params.samplers) {
|
||||||
switch (cnstr) {
|
switch (cnstr) {
|
||||||
case COMMON_SAMPLER_TYPE_DRY:
|
case COMMON_SAMPLER_TYPE_DRY:
|
||||||
{
|
{
|
||||||
std::vector<const char*> c_breakers;
|
std::vector<const char *> c_breakers;
|
||||||
c_breakers.reserve(params.dry_sequence_breakers.size());
|
c_breakers.reserve(params.dry_sequence_breakers.size());
|
||||||
for (const auto& str : params.dry_sequence_breakers) {
|
for (const auto & str : params.dry_sequence_breakers) {
|
||||||
c_breakers.push_back(str.c_str());
|
c_breakers.push_back(str.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -208,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||||||
case COMMON_SAMPLER_TYPE_INFILL:
|
case COMMON_SAMPLER_TYPE_INFILL:
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
|
||||||
break;
|
break;
|
||||||
|
case COMMON_SAMPLER_TYPE_PENALTIES:
|
||||||
|
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false && "unknown sampler type");
|
GGML_ASSERT(false && "unknown sampler type");
|
||||||
}
|
}
|
||||||
@ -415,6 +406,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
|
|||||||
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
|
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
|
||||||
case COMMON_SAMPLER_TYPE_XTC: return 'x';
|
case COMMON_SAMPLER_TYPE_XTC: return 'x';
|
||||||
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
|
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
|
||||||
|
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
|
||||||
default : return '?';
|
default : return '?';
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -429,6 +421,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
|
|||||||
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
|
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
|
||||||
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
|
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
|
||||||
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
|
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
|
||||||
|
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
|
||||||
default : return "";
|
default : return "";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -443,6 +436,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
|||||||
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
||||||
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
|
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
|
||||||
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
|
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
|
||||||
|
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
|
||||||
};
|
};
|
||||||
|
|
||||||
// since samplers names are written multiple ways
|
// since samplers names are written multiple ways
|
||||||
@ -489,6 +483,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
|
|||||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
|
||||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
|
||||||
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
|
||||||
|
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<common_sampler_type> samplers;
|
std::vector<common_sampler_type> samplers;
|
||||||
|
@ -65,6 +65,7 @@ int main(int argc, char ** argv) {
|
|||||||
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
|
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
|
||||||
|
|
||||||
auto sparams = llama_sampler_chain_default_params();
|
auto sparams = llama_sampler_chain_default_params();
|
||||||
|
sparams.no_perf = false;
|
||||||
|
|
||||||
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
||||||
|
|
||||||
|
@ -177,16 +177,11 @@ Example usage: `--temp 0`
|
|||||||
|
|
||||||
- `--repeat-penalty N`: Control the repetition of token sequences in the generated text default: 1.0, 1.0 = disabled).
|
- `--repeat-penalty N`: Control the repetition of token sequences in the generated text default: 1.0, 1.0 = disabled).
|
||||||
- `--repeat-last-n N`: Last n tokens to consider for penalizing repetition (default: 64, 0 = disabled, -1 = ctx-size).
|
- `--repeat-last-n N`: Last n tokens to consider for penalizing repetition (default: 64, 0 = disabled, -1 = ctx-size).
|
||||||
- `--no-penalize-nl`: Disable penalization for newline tokens when applying the repeat penalty.
|
|
||||||
|
|
||||||
The `repeat-penalty` option helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. The default value is 1.
|
The `repeat-penalty` option helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. The default value is 1.
|
||||||
|
|
||||||
The `repeat-last-n` option controls the number of tokens in the history to consider for penalizing repetition. A larger value will look further back in the generated text to prevent repetitions, while a smaller value will only consider recent tokens. A value of 0 disables the penalty, and a value of -1 sets the number of tokens considered equal to the context size (`ctx-size`).
|
The `repeat-last-n` option controls the number of tokens in the history to consider for penalizing repetition. A larger value will look further back in the generated text to prevent repetitions, while a smaller value will only consider recent tokens. A value of 0 disables the penalty, and a value of -1 sets the number of tokens considered equal to the context size (`ctx-size`).
|
||||||
|
|
||||||
Use the `--no-penalize-nl` option to disable newline penalization when applying the repeat penalty. This option is particularly useful for generating chat conversations, dialogues, code, poetry, or any text where newline tokens play a significant role in structure and formatting. Disabling newline penalization helps maintain the natural flow and intended formatting in these specific use cases.
|
|
||||||
|
|
||||||
Example usage: `--repeat-penalty 1.15 --repeat-last-n 128 --no-penalize-nl`
|
|
||||||
|
|
||||||
### DRY Repetition Penalty
|
### DRY Repetition Penalty
|
||||||
|
|
||||||
DRY (Don't Repeat Yourself) sampling is an effective technique for reducing repetition in generated text even across long contexts by penalizing tokens based on their recent usage patterns (original [PR link](https://github.com/oobabooga/text-generation-webui/pull/5677)).
|
DRY (Don't Repeat Yourself) sampling is an effective technique for reducing repetition in generated text even across long contexts by penalizing tokens based on their recent usage patterns (original [PR link](https://github.com/oobabooga/text-generation-webui/pull/5677)).
|
||||||
|
@ -104,7 +104,6 @@ The project is under active development, and we are [looking for feedback and co
|
|||||||
| `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) |
|
| `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) |
|
||||||
| `--sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: dkypmxt) |
|
| `--sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: dkypmxt) |
|
||||||
| `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) |
|
| `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) |
|
||||||
| `--penalize-nl` | penalize newline tokens (default: false) |
|
|
||||||
| `--temp N` | temperature (default: 0.8) |
|
| `--temp N` | temperature (default: 0.8) |
|
||||||
| `--top-k N` | top-k sampling (default: 40, 0 = disabled) |
|
| `--top-k N` | top-k sampling (default: 40, 0 = disabled) |
|
||||||
| `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) |
|
| `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) |
|
||||||
@ -393,8 +392,6 @@ These words will not be included in the completion, so make sure to add them to
|
|||||||
|
|
||||||
`repeat_last_n`: Last n tokens to consider for penalizing repetition. Default: `64`, where `0` is disabled and `-1` is ctx-size.
|
`repeat_last_n`: Last n tokens to consider for penalizing repetition. Default: `64`, where `0` is disabled and `-1` is ctx-size.
|
||||||
|
|
||||||
`penalize_nl`: Penalize newline tokens when applying the repeat penalty. Default: `true`
|
|
||||||
|
|
||||||
`presence_penalty`: Repeat alpha presence penalty. Default: `0.0`, which is disabled.
|
`presence_penalty`: Repeat alpha presence penalty. Default: `0.0`, which is disabled.
|
||||||
|
|
||||||
`frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.
|
`frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.
|
||||||
@ -655,7 +652,6 @@ This endpoint is public (no API key check). By default, it is read-only. To make
|
|||||||
"mirostat": 0,
|
"mirostat": 0,
|
||||||
"mirostat_tau": 5.0,
|
"mirostat_tau": 5.0,
|
||||||
"mirostat_eta": 0.10000000149011612,
|
"mirostat_eta": 0.10000000149011612,
|
||||||
"penalize_nl": false,
|
|
||||||
"stop": [],
|
"stop": [],
|
||||||
"max_tokens": -1,
|
"max_tokens": -1,
|
||||||
"n_keep": 0,
|
"n_keep": 0,
|
||||||
@ -845,7 +841,6 @@ Example:
|
|||||||
"mirostat": 0,
|
"mirostat": 0,
|
||||||
"mirostat_tau": 5.0,
|
"mirostat_tau": 5.0,
|
||||||
"mirostat_eta": 0.10000000149011612,
|
"mirostat_eta": 0.10000000149011612,
|
||||||
"penalize_nl": false,
|
|
||||||
"stop": [],
|
"stop": [],
|
||||||
"max_tokens": -1,
|
"max_tokens": -1,
|
||||||
"n_keep": 0,
|
"n_keep": 0,
|
||||||
|
Binary file not shown.
@ -39,7 +39,6 @@
|
|||||||
temperature: 0.8, // adapt all following parameters to optimized min-p requierements. If for non-english, set to 0.6 or lower
|
temperature: 0.8, // adapt all following parameters to optimized min-p requierements. If for non-english, set to 0.6 or lower
|
||||||
repeat_last_n: 0, // 0 = disable penalty, -1 = context size
|
repeat_last_n: 0, // 0 = disable penalty, -1 = context size
|
||||||
repeat_penalty: 1.0, // 1.0 = disabled
|
repeat_penalty: 1.0, // 1.0 = disabled
|
||||||
penalize_nl: false, // true only useful for infinite completion
|
|
||||||
dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
|
dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
|
||||||
dry_base: 1.75, // 0.0 = disabled
|
dry_base: 1.75, // 0.0 = disabled
|
||||||
dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
|
dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
|
||||||
|
@ -303,7 +303,6 @@
|
|||||||
temperature: 0.7,
|
temperature: 0.7,
|
||||||
repeat_last_n: 256, // 0 = disable penalty, -1 = context size
|
repeat_last_n: 256, // 0 = disable penalty, -1 = context size
|
||||||
repeat_penalty: 1.18, // 1.0 = disabled
|
repeat_penalty: 1.18, // 1.0 = disabled
|
||||||
penalize_nl: false,
|
|
||||||
dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
|
dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
|
||||||
dry_base: 1.75, // 0.0 = disabled
|
dry_base: 1.75, // 0.0 = disabled
|
||||||
dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
|
dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
|
||||||
@ -1006,7 +1005,6 @@
|
|||||||
${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
|
${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
|
||||||
${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
|
${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
|
||||||
${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
|
${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
|
||||||
${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
|
|
||||||
${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
|
${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
|
||||||
${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
|
${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
|
||||||
${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}
|
${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}
|
||||||
|
@ -135,7 +135,6 @@ struct slot_params {
|
|||||||
{"mirostat", sampling.mirostat},
|
{"mirostat", sampling.mirostat},
|
||||||
{"mirostat_tau", sampling.mirostat_tau},
|
{"mirostat_tau", sampling.mirostat_tau},
|
||||||
{"mirostat_eta", sampling.mirostat_eta},
|
{"mirostat_eta", sampling.mirostat_eta},
|
||||||
{"penalize_nl", sampling.penalize_nl},
|
|
||||||
{"stop", antiprompt},
|
{"stop", antiprompt},
|
||||||
{"max_tokens", n_predict}, // User configured n_predict
|
{"max_tokens", n_predict}, // User configured n_predict
|
||||||
{"n_keep", n_keep},
|
{"n_keep", n_keep},
|
||||||
@ -184,6 +183,7 @@ struct server_task {
|
|||||||
|
|
||||||
static slot_params params_from_json_cmpl(
|
static slot_params params_from_json_cmpl(
|
||||||
const llama_model * model,
|
const llama_model * model,
|
||||||
|
const llama_context * ctx,
|
||||||
const common_params & params_base,
|
const common_params & params_base,
|
||||||
const json & data) {
|
const json & data) {
|
||||||
slot_params params;
|
slot_params params;
|
||||||
@ -226,7 +226,6 @@ struct server_task {
|
|||||||
params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
|
params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
|
||||||
params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
|
params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
|
||||||
params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
|
params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
|
||||||
params.sampling.penalize_nl = json_value(data, "penalize_nl", defaults.sampling.penalize_nl);
|
|
||||||
params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
|
params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
|
||||||
params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
|
params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
|
||||||
params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
|
params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
|
||||||
@ -239,6 +238,25 @@ struct server_task {
|
|||||||
params.speculative.n_min = std::max(params.speculative.n_min, 2);
|
params.speculative.n_min = std::max(params.speculative.n_min, 2);
|
||||||
params.speculative.n_max = std::max(params.speculative.n_max, 0);
|
params.speculative.n_max = std::max(params.speculative.n_max, 0);
|
||||||
|
|
||||||
|
// TODO: add more sanity checks for the input parameters
|
||||||
|
|
||||||
|
if (params.sampling.penalty_last_n < -1) {
|
||||||
|
throw std::runtime_error("Error: repeat_last_n must be >= -1");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.sampling.dry_penalty_last_n < -1) {
|
||||||
|
throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.sampling.penalty_last_n == -1) {
|
||||||
|
// note: should be the slot's context and not the full context, but it's ok
|
||||||
|
params.sampling.penalty_last_n = llama_n_ctx(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.sampling.dry_penalty_last_n == -1) {
|
||||||
|
params.sampling.dry_penalty_last_n = llama_n_ctx(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
if (params.sampling.dry_base < 1.0f) {
|
if (params.sampling.dry_base < 1.0f) {
|
||||||
params.sampling.dry_base = defaults.sampling.dry_base;
|
params.sampling.dry_base = defaults.sampling.dry_base;
|
||||||
}
|
}
|
||||||
@ -1469,7 +1487,7 @@ struct server_context {
|
|||||||
n_ctx = llama_n_ctx(ctx);
|
n_ctx = llama_n_ctx(ctx);
|
||||||
|
|
||||||
add_bos_token = llama_add_bos_token(model);
|
add_bos_token = llama_add_bos_token(model);
|
||||||
has_eos_token = !llama_add_eos_token(model);
|
has_eos_token = llama_token_eos(model) != LLAMA_TOKEN_NULL;
|
||||||
|
|
||||||
if (!params_base.speculative.model.empty()) {
|
if (!params_base.speculative.model.empty()) {
|
||||||
SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
|
SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
|
||||||
@ -3381,7 +3399,7 @@ int main(int argc, char ** argv) {
|
|||||||
task.index = i;
|
task.index = i;
|
||||||
|
|
||||||
task.prompt_tokens = std::move(tokenized_prompts[i]);
|
task.prompt_tokens = std::move(tokenized_prompts[i]);
|
||||||
task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.params_base, data);
|
task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
|
||||||
task.id_selected_slot = json_value(data, "id_slot", -1);
|
task.id_selected_slot = json_value(data, "id_slot", -1);
|
||||||
|
|
||||||
// OAI-compat
|
// OAI-compat
|
||||||
|
@ -222,7 +222,6 @@
|
|||||||
temperature: 0.7,
|
temperature: 0.7,
|
||||||
repeat_last_n: 256, // 0 = disable penalty, -1 = context size
|
repeat_last_n: 256, // 0 = disable penalty, -1 = context size
|
||||||
repeat_penalty: 1.18, // 1.0 = disabled
|
repeat_penalty: 1.18, // 1.0 = disabled
|
||||||
penalize_nl: false,
|
|
||||||
top_k: 40, // <= 0 to use vocab size
|
top_k: 40, // <= 0 to use vocab size
|
||||||
top_p: 0.95, // 1.0 = disabled
|
top_p: 0.95, // 1.0 = disabled
|
||||||
min_p: 0.05, // 0 = disabled
|
min_p: 0.05, // 0 = disabled
|
||||||
@ -779,7 +778,6 @@
|
|||||||
${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
|
${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
|
||||||
${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
|
${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
|
||||||
${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
|
${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
|
||||||
${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
|
|
||||||
${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
|
${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
|
||||||
${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
|
${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
|
||||||
${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}
|
${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}
|
||||||
|
@ -225,7 +225,6 @@
|
|||||||
temperature: 0.7,
|
temperature: 0.7,
|
||||||
repeat_last_n: 256, // 0 = disable penalty, -1 = context size
|
repeat_last_n: 256, // 0 = disable penalty, -1 = context size
|
||||||
repeat_penalty: 1.18, // 1.0 = disabled
|
repeat_penalty: 1.18, // 1.0 = disabled
|
||||||
penalize_nl: false,
|
|
||||||
top_k: 40, // <= 0 to use vocab size
|
top_k: 40, // <= 0 to use vocab size
|
||||||
top_p: 0.95, // 1.0 = disabled
|
top_p: 0.95, // 1.0 = disabled
|
||||||
min_p: 0.05, // 0 = disabled
|
min_p: 0.05, // 0 = disabled
|
||||||
@ -782,7 +781,6 @@
|
|||||||
${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
|
${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
|
||||||
${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
|
${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
|
||||||
${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
|
${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
|
||||||
${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
|
|
||||||
${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
|
${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
|
||||||
${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
|
${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
|
||||||
${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}
|
${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}
|
||||||
|
@ -33,7 +33,7 @@ const CONFIG_DEFAULT = {
|
|||||||
systemMessage: 'You are a helpful assistant.',
|
systemMessage: 'You are a helpful assistant.',
|
||||||
showTokensPerSecond: false,
|
showTokensPerSecond: false,
|
||||||
// make sure these default values are in sync with `common.h`
|
// make sure these default values are in sync with `common.h`
|
||||||
samplers: 'dkypmxt',
|
samplers: 'edkypmxt',
|
||||||
temperature: 0.8,
|
temperature: 0.8,
|
||||||
dynatemp_range: 0.0,
|
dynatemp_range: 0.0,
|
||||||
dynatemp_exponent: 1.0,
|
dynatemp_exponent: 1.0,
|
||||||
|
@ -1139,16 +1139,12 @@ extern "C" {
|
|||||||
const char * grammar_str,
|
const char * grammar_str,
|
||||||
const char * grammar_root);
|
const char * grammar_root);
|
||||||
|
|
||||||
|
/// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
|
LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
|
||||||
int32_t n_vocab, // llama_n_vocab()
|
|
||||||
llama_token special_eos_id, // llama_token_eos()
|
|
||||||
llama_token linefeed_id, // llama_token_nl()
|
|
||||||
int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||||
float penalty_repeat, // 1.0 = disabled
|
float penalty_repeat, // 1.0 = disabled
|
||||||
float penalty_freq, // 0.0 = disabled
|
float penalty_freq, // 0.0 = disabled
|
||||||
float penalty_present, // 0.0 = disabled
|
float penalty_present); // 0.0 = disabled
|
||||||
bool penalize_nl, // consider newlines as a repeatable token
|
|
||||||
bool ignore_eos); // ignore the end-of-sequence token
|
|
||||||
|
|
||||||
/// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
|
/// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init_dry(
|
LLAMA_API struct llama_sampler * llama_sampler_init_dry(
|
||||||
|
@ -1396,19 +1396,15 @@ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab
|
|||||||
// penalties
|
// penalties
|
||||||
|
|
||||||
struct llama_sampler_penalties {
|
struct llama_sampler_penalties {
|
||||||
const int32_t n_vocab;
|
|
||||||
const llama_token special_eos_id;
|
|
||||||
const llama_token linefeed_id;
|
|
||||||
|
|
||||||
const int32_t penalty_last_n;
|
const int32_t penalty_last_n;
|
||||||
const float penalty_repeat;
|
const float penalty_repeat;
|
||||||
const float penalty_freq;
|
const float penalty_freq;
|
||||||
const float penalty_present;
|
const float penalty_present;
|
||||||
|
|
||||||
const bool penalize_nl;
|
|
||||||
const bool ignore_eos;
|
|
||||||
|
|
||||||
ring_buffer<llama_token> prev;
|
ring_buffer<llama_token> prev;
|
||||||
|
|
||||||
|
// a frequency map to count token occurrences
|
||||||
|
std::unordered_map<llama_token, int> token_count;
|
||||||
};
|
};
|
||||||
|
|
||||||
static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
|
static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
|
||||||
@ -1421,76 +1417,50 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_to
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx->token_count[token]++;
|
||||||
|
|
||||||
|
// if the ring buffer is full, remove the oldest token
|
||||||
|
if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
|
||||||
|
const auto old = ctx->prev.front();
|
||||||
|
|
||||||
|
ctx->token_count[old]--;
|
||||||
|
if (ctx->token_count[old] == 0) {
|
||||||
|
ctx->token_count.erase(old);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ctx->prev.push_back(token);
|
ctx->prev.push_back(token);
|
||||||
|
|
||||||
|
#if 0
|
||||||
|
// sanity check
|
||||||
|
std::unordered_map<llama_token, int> tmp;
|
||||||
|
for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
|
||||||
|
tmp[ctx->prev.rat(i)]++;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(ctx->token_count == tmp);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
|
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
|
||||||
|
|
||||||
if (ctx->ignore_eos) {
|
|
||||||
assert(ctx->special_eos_id >= 0);
|
|
||||||
|
|
||||||
// optimistically check if the candidates are not yet sorted/shuffled/truncated
|
|
||||||
if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
|
|
||||||
cur_p->data[ctx->special_eos_id].logit = -INFINITY;
|
|
||||||
} else {
|
|
||||||
// else, search for the special EOS token
|
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
||||||
if (cur_p->data[i].id == ctx->special_eos_id) {
|
|
||||||
cur_p->data[i].logit = -INFINITY;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((ctx->penalty_last_n == 0) ||
|
if ((ctx->penalty_last_n == 0) ||
|
||||||
(ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
|
(ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool nl_found = false;
|
|
||||||
size_t nl_idx = 0;
|
|
||||||
float nl_logit = -INFINITY;
|
|
||||||
if (!ctx->penalize_nl) {
|
|
||||||
assert(ctx->linefeed_id >= 0);
|
|
||||||
|
|
||||||
// optimistically check if the candidates are not yet sorted/shuffled/truncated
|
|
||||||
if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
|
|
||||||
nl_found = true;
|
|
||||||
nl_idx = ctx->linefeed_id;
|
|
||||||
nl_logit = cur_p->data[ctx->linefeed_id].logit;
|
|
||||||
} else {
|
|
||||||
// else, search for the linefeed token
|
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
|
||||||
if (cur_p->data[i].id == ctx->linefeed_id) {
|
|
||||||
nl_found = true;
|
|
||||||
nl_idx = i;
|
|
||||||
nl_logit = cur_p->data[i].logit;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a frequency map to count occurrences of each token in last_tokens
|
|
||||||
// TODO: optimize this by maintaining the token count in the sampler context
|
|
||||||
using llama_token_cnt = std::unordered_map<llama_token, int>;
|
|
||||||
llama_token_cnt token_count;
|
|
||||||
|
|
||||||
for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
|
|
||||||
token_count[ctx->prev.rat(i)]++;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply frequency and presence penalties to the cur_p
|
// Apply frequency and presence penalties to the cur_p
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
const auto token_iter = token_count.find(cur_p->data[i].id);
|
const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
|
||||||
if (token_iter == token_count.end()) {
|
if (token_iter == ctx->token_count.end()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int count = token_iter->second;
|
const int count = token_iter->second;
|
||||||
|
|
||||||
|
assert(count > 0 && count <= ctx->penalty_last_n);
|
||||||
|
|
||||||
// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
|
// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
|
||||||
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
|
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
|
||||||
if (cur_p->data[i].logit <= 0) {
|
if (cur_p->data[i].logit <= 0) {
|
||||||
@ -1503,30 +1473,21 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
|
|||||||
}
|
}
|
||||||
|
|
||||||
cur_p->sorted = false;
|
cur_p->sorted = false;
|
||||||
|
|
||||||
if (!ctx->penalize_nl && nl_found) {
|
|
||||||
// restore the logit of the newline token if it was penalized
|
|
||||||
cur_p->data[nl_idx].logit = nl_logit;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
|
static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
|
||||||
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
|
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
|
||||||
ctx->prev.clear();
|
ctx->prev.clear();
|
||||||
|
ctx->token_count.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
|
static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
|
||||||
const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
|
const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
|
||||||
auto * result = llama_sampler_init_penalties(
|
auto * result = llama_sampler_init_penalties(
|
||||||
ctx->n_vocab,
|
|
||||||
ctx->special_eos_id,
|
|
||||||
ctx->linefeed_id,
|
|
||||||
ctx->penalty_last_n,
|
ctx->penalty_last_n,
|
||||||
ctx->penalty_repeat,
|
ctx->penalty_repeat,
|
||||||
ctx->penalty_freq,
|
ctx->penalty_freq,
|
||||||
ctx->penalty_present,
|
ctx->penalty_present);
|
||||||
ctx->penalize_nl,
|
|
||||||
ctx->ignore_eos);
|
|
||||||
|
|
||||||
// copy the state
|
// copy the state
|
||||||
{
|
{
|
||||||
@ -1552,38 +1513,21 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_penalties(
|
struct llama_sampler * llama_sampler_init_penalties(
|
||||||
int32_t n_vocab,
|
|
||||||
llama_token special_eos_id,
|
|
||||||
llama_token linefeed_id,
|
|
||||||
int32_t penalty_last_n,
|
int32_t penalty_last_n,
|
||||||
float penalty_repeat,
|
float penalty_repeat,
|
||||||
float penalty_freq,
|
float penalty_freq,
|
||||||
float penalty_present,
|
float penalty_present) {
|
||||||
bool penalize_nl,
|
|
||||||
bool ignore_eos) {
|
|
||||||
if (linefeed_id == LLAMA_TOKEN_NULL) {
|
|
||||||
penalize_nl = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (special_eos_id == LLAMA_TOKEN_NULL) {
|
|
||||||
ignore_eos = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
penalty_last_n = std::max(penalty_last_n, 0);
|
penalty_last_n = std::max(penalty_last_n, 0);
|
||||||
|
|
||||||
return new llama_sampler {
|
return new llama_sampler {
|
||||||
/* .iface = */ &llama_sampler_penalties_i,
|
/* .iface = */ &llama_sampler_penalties_i,
|
||||||
/* .ctx = */ new llama_sampler_penalties {
|
/* .ctx = */ new llama_sampler_penalties {
|
||||||
/* .n_vocab = */ n_vocab,
|
|
||||||
/* .special_eos_id = */ special_eos_id,
|
|
||||||
/* .linefeed_id = */ linefeed_id,
|
|
||||||
/* .penalty_last_n = */ penalty_last_n,
|
/* .penalty_last_n = */ penalty_last_n,
|
||||||
/* .penalty_repeat = */ penalty_repeat,
|
/* .penalty_repeat = */ penalty_repeat,
|
||||||
/* .penalty_freq = */ penalty_freq,
|
/* .penalty_freq = */ penalty_freq,
|
||||||
/* .penalty_present = */ penalty_present,
|
/* .penalty_present = */ penalty_present,
|
||||||
/* .penalize_nl = */ penalize_nl,
|
|
||||||
/* .ignore_eos = */ ignore_eos,
|
|
||||||
/* .prev = */ ring_buffer<llama_token>(penalty_last_n),
|
/* .prev = */ ring_buffer<llama_token>(penalty_last_n),
|
||||||
|
/* .token_count = */ {},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -1611,7 +1555,8 @@ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std
|
|||||||
if (word.find(str) != std::string::npos) {
|
if (word.find(str) != std::string::npos) {
|
||||||
token_sequences.emplace(token_id, std::vector<llama_token>());
|
token_sequences.emplace(token_id, std::vector<llama_token>());
|
||||||
} else {
|
} else {
|
||||||
size_t word_len = word.size(), str_len = str.size();
|
size_t word_len = word.size();
|
||||||
|
size_t str_len = str.size();
|
||||||
size_t pos = -1;
|
size_t pos = -1;
|
||||||
while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
|
while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
|
||||||
bool match = true;
|
bool match = true;
|
||||||
|
@ -145,7 +145,7 @@ static void test_penalties(
|
|||||||
sampler_tester tester(probs, probs_expected);
|
sampler_tester tester(probs, probs_expected);
|
||||||
|
|
||||||
const size_t n_vocab = probs.size();
|
const size_t n_vocab = probs.size();
|
||||||
auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
|
auto * sampler = llama_sampler_init_penalties(last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
|
||||||
|
|
||||||
for (size_t i = 0; i < last_tokens.size(); i++) {
|
for (size_t i = 0; i < last_tokens.size(); i++) {
|
||||||
llama_sampler_accept(sampler, last_tokens[i]);
|
llama_sampler_accept(sampler, last_tokens[i]);
|
||||||
|
Loading…
Reference in New Issue
Block a user