mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-26 03:12:23 +01:00
server : allow to specify custom prompt for penalty calculation (#3727)
This commit is contained in:
parent
b9ec82d262
commit
6123979952
@ -203,12 +203,14 @@ static llama_token llama_sampling_sample_impl(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// apply penalties
|
// apply penalties
|
||||||
if (!prev.empty()) {
|
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
|
||||||
|
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
|
||||||
|
if (penalty_tokens_used_size) {
|
||||||
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
|
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
|
||||||
|
|
||||||
llama_sample_repetition_penalties(ctx_main, &cur_p,
|
llama_sample_repetition_penalties(ctx_main, &cur_p,
|
||||||
prev.data() + prev.size() - penalty_last_n,
|
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
|
||||||
penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
|
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
|
||||||
|
|
||||||
if (!penalize_nl) {
|
if (!penalize_nl) {
|
||||||
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
||||||
|
@ -36,6 +36,9 @@ typedef struct llama_sampling_params {
|
|||||||
float cfg_scale = 1.f; // how strong is guidance
|
float cfg_scale = 1.f; // how strong is guidance
|
||||||
|
|
||||||
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
|
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
|
||||||
|
|
||||||
|
std::vector<llama_token> penalty_prompt_tokens;
|
||||||
|
bool use_penalty_prompt_tokens = false;
|
||||||
} llama_sampling_params;
|
} llama_sampling_params;
|
||||||
|
|
||||||
// general sampler context
|
// general sampler context
|
||||||
|
@ -148,6 +148,8 @@ node index.js
|
|||||||
|
|
||||||
`frequency_penalty`: Repeat alpha frequency penalty (default: 0.0, 0.0 = disabled);
|
`frequency_penalty`: Repeat alpha frequency penalty (default: 0.0, 0.0 = disabled);
|
||||||
|
|
||||||
|
`penalty_prompt`: This will replace the `prompt` for the purpose of the penalty evaluation. Can be either `null`, a string or an array of numbers representing tokens (default: `null` = use the original `prompt`).
|
||||||
|
|
||||||
`mirostat`: Enable Mirostat sampling, controlling perplexity during text generation (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0).
|
`mirostat`: Enable Mirostat sampling, controlling perplexity during text generation (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0).
|
||||||
|
|
||||||
`mirostat_tau`: Set the Mirostat target entropy, parameter tau (default: 5.0).
|
`mirostat_tau`: Set the Mirostat target entropy, parameter tau (default: 5.0).
|
||||||
|
@ -761,6 +761,42 @@ struct llama_server_context
|
|||||||
slot->prompt = "";
|
slot->prompt = "";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
slot->sparams.penalty_prompt_tokens.clear();
|
||||||
|
slot->sparams.use_penalty_prompt_tokens = false;
|
||||||
|
const auto &penalty_prompt = data.find("penalty_prompt");
|
||||||
|
if (penalty_prompt != data.end())
|
||||||
|
{
|
||||||
|
if (penalty_prompt->is_string())
|
||||||
|
{
|
||||||
|
const auto penalty_prompt_string = penalty_prompt->get<std::string>();
|
||||||
|
auto penalty_tokens = llama_tokenize(model, penalty_prompt_string, false);
|
||||||
|
slot->sparams.penalty_prompt_tokens.swap(penalty_tokens);
|
||||||
|
if (slot->params.n_predict > 0)
|
||||||
|
{
|
||||||
|
slot->sparams.penalty_prompt_tokens.reserve(slot->sparams.penalty_prompt_tokens.size() + slot->params.n_predict);
|
||||||
|
}
|
||||||
|
slot->sparams.use_penalty_prompt_tokens = true;
|
||||||
|
}
|
||||||
|
else if (penalty_prompt->is_array())
|
||||||
|
{
|
||||||
|
const auto n_tokens = penalty_prompt->size();
|
||||||
|
slot->sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot->params.n_predict));
|
||||||
|
const int n_vocab = llama_n_vocab(model);
|
||||||
|
for (const auto &penalty_token : *penalty_prompt)
|
||||||
|
{
|
||||||
|
if (penalty_token.is_number_integer())
|
||||||
|
{
|
||||||
|
const auto tok = penalty_token.get<llama_token>();
|
||||||
|
if (tok >= 0 && tok < n_vocab)
|
||||||
|
{
|
||||||
|
slot->sparams.penalty_prompt_tokens.push_back(tok);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
slot->sparams.use_penalty_prompt_tokens = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
slot->sparams.logit_bias.clear();
|
slot->sparams.logit_bias.clear();
|
||||||
|
|
||||||
if (json_value(data, "ignore_eos", false))
|
if (json_value(data, "ignore_eos", false))
|
||||||
@ -992,6 +1028,12 @@ struct llama_server_context
|
|||||||
slot.generated_text += token_str;
|
slot.generated_text += token_str;
|
||||||
slot.has_next_token = true;
|
slot.has_next_token = true;
|
||||||
|
|
||||||
|
if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1)
|
||||||
|
{
|
||||||
|
// we can change penalty_prompt_tokens because it is always created from scratch each request
|
||||||
|
slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
|
||||||
|
}
|
||||||
|
|
||||||
// check if there is incomplete UTF-8 character at the end
|
// check if there is incomplete UTF-8 character at the end
|
||||||
bool incomplete = false;
|
bool incomplete = false;
|
||||||
for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i)
|
for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i)
|
||||||
@ -1183,6 +1225,8 @@ struct llama_server_context
|
|||||||
{"repeat_penalty", slot.sparams.penalty_repeat},
|
{"repeat_penalty", slot.sparams.penalty_repeat},
|
||||||
{"presence_penalty", slot.sparams.penalty_present},
|
{"presence_penalty", slot.sparams.penalty_present},
|
||||||
{"frequency_penalty", slot.sparams.penalty_freq},
|
{"frequency_penalty", slot.sparams.penalty_freq},
|
||||||
|
{"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
|
||||||
|
{"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
|
||||||
{"mirostat", slot.sparams.mirostat},
|
{"mirostat", slot.sparams.mirostat},
|
||||||
{"mirostat_tau", slot.sparams.mirostat_tau},
|
{"mirostat_tau", slot.sparams.mirostat_tau},
|
||||||
{"mirostat_eta", slot.sparams.mirostat_eta},
|
{"mirostat_eta", slot.sparams.mirostat_eta},
|
||||||
|
Loading…
Reference in New Issue
Block a user