mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 11:23:56 +01:00
last n tokens done
This commit is contained in:
parent
42591a0acd
commit
dd3cf5760a
@ -165,15 +165,14 @@ static bool server_verbose = false;
|
|||||||
#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
|
#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
|
||||||
#define LOG_INFO(MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
|
#define LOG_INFO(MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
|
||||||
|
|
||||||
|
// helper class to manage prompt loading and truncation
|
||||||
struct prompt_evaluator {
|
struct prompt_evaluator {
|
||||||
llama_context * ctx;
|
llama_context * ctx;
|
||||||
size_t n_ctx = 0;
|
size_t n_ctx = 0;
|
||||||
//std::string prompt;
|
|
||||||
std::vector<llama_token> embd;
|
std::vector<llama_token> embd;
|
||||||
std::vector<llama_token> last_n_tokens;
|
std::vector<llama_token> last_n_tokens;
|
||||||
size_t num_prompt_tokens = 0;
|
size_t num_prompt_tokens = 0;
|
||||||
//size_t num_tokens_predicted = 0;
|
size_t repeat_last_n = 0;
|
||||||
//size_t n_remain = 0;
|
|
||||||
size_t n_past = 0;
|
size_t n_past = 0;
|
||||||
size_t n_keep = 0;
|
size_t n_keep = 0;
|
||||||
bool truncated = false;
|
bool truncated = false;
|
||||||
@ -226,11 +225,12 @@ struct prompt_evaluator {
|
|||||||
prompt_tokens = new_tokens;
|
prompt_tokens = new_tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
last_n_tokens.resize(n_last);
|
// fill the last n tokens from the input even if context is truncated
|
||||||
|
repeat_last_n = n_last;
|
||||||
|
last_n_tokens.clear();
|
||||||
if (n_last > 0) {
|
if (n_last > 0) {
|
||||||
const size_t s = std::min(n_last, num_prompt_tokens);
|
last_n_tokens.insert(last_n_tokens.begin(),
|
||||||
std::fill(last_n_tokens.begin(), last_n_tokens.end() - s, 0);
|
std::max(prompt_tokens.begin(), prompt_tokens.end() - n_last), prompt_tokens.end());
|
||||||
std::copy(prompt_tokens.end() - s, prompt_tokens.end(), last_n_tokens.begin());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// compare the evaluated prompt with the new prompt
|
// compare the evaluated prompt with the new prompt
|
||||||
@ -295,8 +295,10 @@ struct prompt_evaluator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void append_token(llama_token id) {
|
void append_token(llama_token id) {
|
||||||
if (last_n_tokens.size() > 0) {
|
if (repeat_last_n > 0) {
|
||||||
last_n_tokens.erase(last_n_tokens.begin());
|
if (last_n_tokens.size() >= repeat_last_n) {
|
||||||
|
last_n_tokens.erase(last_n_tokens.begin());
|
||||||
|
}
|
||||||
last_n_tokens.push_back(id);
|
last_n_tokens.push_back(id);
|
||||||
}
|
}
|
||||||
embd.push_back(id);
|
embd.push_back(id);
|
||||||
@ -410,7 +412,6 @@ struct llama_server_context {
|
|||||||
const float top_p = params.top_p;
|
const float top_p = params.top_p;
|
||||||
const float tfs_z = params.tfs_z;
|
const float tfs_z = params.tfs_z;
|
||||||
const float typical_p = params.typical_p;
|
const float typical_p = params.typical_p;
|
||||||
//const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
|
|
||||||
const float repeat_penalty = params.repeat_penalty;
|
const float repeat_penalty = params.repeat_penalty;
|
||||||
const float alpha_presence = params.presence_penalty;
|
const float alpha_presence = params.presence_penalty;
|
||||||
const float alpha_frequency = params.frequency_penalty;
|
const float alpha_frequency = params.frequency_penalty;
|
||||||
@ -444,7 +445,6 @@ struct llama_server_context {
|
|||||||
|
|
||||||
// Apply penalties
|
// Apply penalties
|
||||||
float nl_logit = logits[llama_token_nl()];
|
float nl_logit = logits[llama_token_nl()];
|
||||||
//auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
|
|
||||||
llama_sample_repetition_penalty(ctx, &candidates_p,
|
llama_sample_repetition_penalty(ctx, &candidates_p,
|
||||||
evaluator.last_n_tokens.data(), evaluator.last_n_tokens.size(),
|
evaluator.last_n_tokens.data(), evaluator.last_n_tokens.size(),
|
||||||
repeat_penalty);
|
repeat_penalty);
|
||||||
|
Loading…
Reference in New Issue
Block a user