mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-22 09:39:08 +01:00
last n tokens done
This commit is contained in:
parent
42591a0acd
commit
dd3cf5760a
@ -165,15 +165,14 @@ static bool server_verbose = false;
|
||||
#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
|
||||
#define LOG_INFO(MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
|
||||
|
||||
// helper class to manage prompt loading and truncation
|
||||
struct prompt_evaluator {
|
||||
llama_context * ctx;
|
||||
size_t n_ctx = 0;
|
||||
//std::string prompt;
|
||||
std::vector<llama_token> embd;
|
||||
std::vector<llama_token> last_n_tokens;
|
||||
size_t num_prompt_tokens = 0;
|
||||
//size_t num_tokens_predicted = 0;
|
||||
//size_t n_remain = 0;
|
||||
size_t repeat_last_n = 0;
|
||||
size_t n_past = 0;
|
||||
size_t n_keep = 0;
|
||||
bool truncated = false;
|
||||
@ -226,11 +225,12 @@ struct prompt_evaluator {
|
||||
prompt_tokens = new_tokens;
|
||||
}
|
||||
|
||||
last_n_tokens.resize(n_last);
|
||||
// fill the last n tokens from the input even if context is truncated
|
||||
repeat_last_n = n_last;
|
||||
last_n_tokens.clear();
|
||||
if (n_last > 0) {
|
||||
const size_t s = std::min(n_last, num_prompt_tokens);
|
||||
std::fill(last_n_tokens.begin(), last_n_tokens.end() - s, 0);
|
||||
std::copy(prompt_tokens.end() - s, prompt_tokens.end(), last_n_tokens.begin());
|
||||
last_n_tokens.insert(last_n_tokens.begin(),
|
||||
std::max(prompt_tokens.begin(), prompt_tokens.end() - n_last), prompt_tokens.end());
|
||||
}
|
||||
|
||||
// compare the evaluated prompt with the new prompt
|
||||
@ -295,8 +295,10 @@ struct prompt_evaluator {
|
||||
}
|
||||
|
||||
void append_token(llama_token id) {
|
||||
if (last_n_tokens.size() > 0) {
|
||||
last_n_tokens.erase(last_n_tokens.begin());
|
||||
if (repeat_last_n > 0) {
|
||||
if (last_n_tokens.size() >= repeat_last_n) {
|
||||
last_n_tokens.erase(last_n_tokens.begin());
|
||||
}
|
||||
last_n_tokens.push_back(id);
|
||||
}
|
||||
embd.push_back(id);
|
||||
@ -410,7 +412,6 @@ struct llama_server_context {
|
||||
const float top_p = params.top_p;
|
||||
const float tfs_z = params.tfs_z;
|
||||
const float typical_p = params.typical_p;
|
||||
//const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
|
||||
const float repeat_penalty = params.repeat_penalty;
|
||||
const float alpha_presence = params.presence_penalty;
|
||||
const float alpha_frequency = params.frequency_penalty;
|
||||
@ -444,7 +445,6 @@ struct llama_server_context {
|
||||
|
||||
// Apply penalties
|
||||
float nl_logit = logits[llama_token_nl()];
|
||||
//auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
|
||||
llama_sample_repetition_penalty(ctx, &candidates_p,
|
||||
evaluator.last_n_tokens.data(), evaluator.last_n_tokens.size(),
|
||||
repeat_penalty);
|
||||
|
Loading…
Reference in New Issue
Block a user