diff --git a/examples/server/server.cpp b/examples/server/server.cpp index cd21d8224..cb1b3a163 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -165,15 +165,14 @@ static bool server_verbose = false; #define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__) #define LOG_INFO(MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) +// helper class to manage prompt loading and truncation struct prompt_evaluator { llama_context * ctx; size_t n_ctx = 0; - //std::string prompt; std::vector embd; std::vector last_n_tokens; size_t num_prompt_tokens = 0; - //size_t num_tokens_predicted = 0; - //size_t n_remain = 0; + size_t repeat_last_n = 0; size_t n_past = 0; size_t n_keep = 0; bool truncated = false; @@ -226,11 +225,12 @@ struct prompt_evaluator { prompt_tokens = new_tokens; } - last_n_tokens.resize(n_last); + // fill the last n tokens from the input even if context is truncated + repeat_last_n = n_last; + last_n_tokens.clear(); if (n_last > 0) { - const size_t s = std::min(n_last, num_prompt_tokens); - std::fill(last_n_tokens.begin(), last_n_tokens.end() - s, 0); - std::copy(prompt_tokens.end() - s, prompt_tokens.end(), last_n_tokens.begin()); + last_n_tokens.insert(last_n_tokens.begin(), + std::max(prompt_tokens.begin(), prompt_tokens.end() - n_last), prompt_tokens.end()); } // compare the evaluated prompt with the new prompt @@ -295,8 +295,10 @@ struct prompt_evaluator { } void append_token(llama_token id) { - if (last_n_tokens.size() > 0) { - last_n_tokens.erase(last_n_tokens.begin()); + if (repeat_last_n > 0) { + if (last_n_tokens.size() >= repeat_last_n) { + last_n_tokens.erase(last_n_tokens.begin()); + } last_n_tokens.push_back(id); } embd.push_back(id); @@ -410,7 +412,6 @@ struct llama_server_context { const float top_p = params.top_p; const float tfs_z = params.tfs_z; const float typical_p = params.typical_p; - //const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n; const float repeat_penalty = params.repeat_penalty; const float alpha_presence = params.presence_penalty; const float alpha_frequency = params.frequency_penalty; @@ -444,7 +445,6 @@ struct llama_server_context { // Apply penalties float nl_logit = logits[llama_token_nl()]; - //auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx); llama_sample_repetition_penalty(ctx, &candidates_p, evaluator.last_n_tokens.data(), evaluator.last_n_tokens.size(), repeat_penalty);