refactor evaluation logic

This commit is contained in:
Henri Vasserman 2023-07-19 23:45:40 +03:00
parent 4cae9f5673
commit 2cb8469e7f
No known key found for this signature in database
GPG Key ID: 2995FC0F58B1A986

View File

@ -21,6 +21,7 @@
using namespace httplib; using namespace httplib;
using json = nlohmann::json; using json = nlohmann::json;
using ordered_json = nlohmann::ordered_json;
struct server_params { struct server_params {
std::string hostname = "127.0.0.1"; std::string hostname = "127.0.0.1";
@ -82,9 +83,19 @@ static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
return ret; return ret;
} }
#define SERVER_LOG_PRETTY 0
static void server_log(const char * level, const char * function, int line, static void server_log(const char * level, const char * function, int line,
const char * message, const nlohmann::ordered_json & extra) { const char * message, const ordered_json & extra) {
nlohmann::ordered_json log {
#if SERVER_LOG_PRETTY == 1
fprintf(stdout, ANSI_COLOR_MAGENTA ANSI_BOLD "[%s] " ANSI_COLOR_RESET " %s@%d: %s\n", level, function, line, message);
for (auto & it : extra.items()) {
fprintf(stdout, "%s=" ANSI_COLOR_YELLOW ANSI_BOLD "%s " ANSI_COLOR_RESET, it.key().c_str(), it.value().dump().c_str());
}
fprintf(stdout, "\n\n");
#else
ordered_json log {
{ "timestamp", time(nullptr) }, { "timestamp", time(nullptr) },
{ "level", level }, { "level", level },
{ "function", function }, { "function", function },
@ -98,6 +109,8 @@ static void server_log(const char * level, const char * function, int line,
const std::string str = log.dump(-1, ' ', false, json::error_handler_t::replace); const std::string str = log.dump(-1, ' ', false, json::error_handler_t::replace);
fprintf(stdout, "%.*s\n", (int)str.size(), str.data()); fprintf(stdout, "%.*s\n", (int)str.size(), str.data());
#endif
fflush(stdout); fflush(stdout);
} }
@ -152,124 +165,72 @@ static bool server_verbose = false;
#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__) #define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_INFO(MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) #define LOG_INFO(MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
struct llama_server_context { struct prompt_evaluator {
bool stream = false; llama_context * ctx;
bool has_next_token = false; size_t n_ctx = 0;
std::string generated_text; //std::string prompt;
std::vector<completion_token_output> generated_token_probs;
size_t num_prompt_tokens = 0;
size_t num_tokens_predicted = 0;
size_t n_past = 0;
size_t n_past_guidance = 0;
int n_keep_guidance = 0;
size_t n_remain = 0;
bool cfg_enabled = false;
std::vector<llama_token> embd; std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance;
std::vector<llama_token> last_n_tokens; std::vector<llama_token> last_n_tokens;
size_t num_prompt_tokens = 0;
llama_model * model = nullptr; //size_t num_tokens_predicted = 0;
llama_context * ctx = nullptr; //size_t n_remain = 0;
llama_context * ctx_guidance = nullptr; size_t n_past = 0;
gpt_params params; size_t n_keep = 0;
bool truncated = false; bool truncated = false;
bool stopped_eos = false;
bool stopped_word = false;
bool stopped_limit = false;
std::string stopping_word;
int32_t multibyte_pending = 0;
std::mutex mutex; void set_context(llama_context * ctx) {
this->ctx = ctx;
std::unique_lock<std::mutex> lock() { this->n_ctx = llama_n_ctx(ctx);
return std::unique_lock<std::mutex>(mutex);
} }
~llama_server_context() { ~prompt_evaluator() {
if (ctx) { if (ctx) {
llama_free(ctx); llama_free(ctx);
ctx = nullptr; ctx = nullptr;
} }
if (ctx_guidance) {
llama_free(ctx_guidance);
ctx_guidance = nullptr;
}
if (model) {
llama_free_model(model);
model = nullptr;
}
} }
void rewind() { void rewind() {
params.antiprompt.clear();
num_prompt_tokens = 0; num_prompt_tokens = 0;
num_tokens_predicted = 0; //num_tokens_predicted = 0;
generated_text = "";
generated_text.reserve(params.n_ctx);
generated_token_probs.clear();
truncated = false; truncated = false;
stopped_eos = false; //n_remain = 0;
stopped_word = false;
stopped_limit = false;
stopping_word = "";
multibyte_pending = 0;
n_remain = 0;
n_past = 0; n_past = 0;
cfg_enabled = false;
n_past_guidance = 0;
} }
bool loadModel(const gpt_params & params_) { void load_prompt(std::string &prompt, int keep, size_t n_last) {
params = params_; prompt.insert(0, 1, ' '); // always add a first space
std::tie(model, ctx) = llama_init_from_gpt_params(params); std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, prompt, true);
if (model == nullptr) {
LOG_ERROR("unable to load model", {{ "model", params_.model }});
return false;
}
struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
ctx_guidance = llama_new_context_with_model(model, lparams);
last_n_tokens.resize(params.n_ctx);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
return true;
}
void loadPrompt() {
params.prompt.insert(0, 1, ' '); // always add a first space
std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
num_prompt_tokens = prompt_tokens.size(); num_prompt_tokens = prompt_tokens.size();
if (params.n_keep < 0) { if (keep < 0) {
params.n_keep = (int)num_prompt_tokens; keep = (int)num_prompt_tokens;
} }
params.n_keep = std::min(params.n_ctx - 4, params.n_keep); n_keep = std::min(n_ctx - 4, (size_t)keep);
// if input prompt is too big, truncate like normal // if input prompt is too big, truncate like normal
if (num_prompt_tokens >= (size_t)params.n_ctx) { if (num_prompt_tokens >= n_ctx) {
const int n_left = (params.n_ctx - params.n_keep) / 2; const size_t n_left = (n_ctx - n_keep) / 2;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + n_keep);
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; const size_t erased_blocks = (num_prompt_tokens - n_keep - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + n_keep + erased_blocks * n_left, prompt_tokens.end());
std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
LOG_VERBOSE("input truncated", { LOG_VERBOSE("input truncated", {
{ "n_ctx", params.n_ctx }, { "n_ctx", n_ctx },
{ "n_keep", params.n_keep }, { "n_keep", n_keep },
{ "n_left", n_left }, { "n_left", n_left },
{ "new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()) }, { "new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()) },
}); });
truncated = true; truncated = true;
prompt_tokens = new_tokens; prompt_tokens = new_tokens;
} else { }
const size_t ps = num_prompt_tokens;
std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0); last_n_tokens.resize(n_last);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); if (n_last > 0) {
const size_t s = std::min(n_last, num_prompt_tokens);
std::fill(last_n_tokens.begin(), last_n_tokens.end() - s, 0);
std::copy(prompt_tokens.end() - s, prompt_tokens.end(), last_n_tokens.begin());
} }
// compare the evaluated prompt with the new prompt // compare the evaluated prompt with the new prompt
@ -285,50 +246,141 @@ struct llama_server_context {
{ "cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past) }, { "cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past) },
{ "to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend()) }, { "to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend()) },
}); });
}
bool evaluate(size_t n_threads, size_t n_batch) {
if (embd.size() >= n_ctx) {
// Reset context
const size_t n_left = (n_ctx - n_keep) / 2;
std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + n_keep);
new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
embd = new_tokens;
n_past = n_keep;
truncated = true;
LOG_VERBOSE("input truncated", {
{ "n_ctx", n_ctx },
{ "n_keep", n_keep },
{ "n_left", n_left },
{ "new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()) },
});
}
while (n_past < embd.size()) {
size_t n_eval = embd.size() - n_past;
if (n_eval > n_batch) {
n_eval = n_batch;
}
//LOG_VERBOSE("eval", {
// { "n_eval", n_eval },
// { "n_past", n_past },
// { "n_threads", n_threads },
// { "embd", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend()) },
//});
if (llama_eval(ctx, &embd[n_past], n_eval, n_past, n_threads)) {
LOG_ERROR("failed to eval", {
{ "n_eval", n_eval },
{ "n_past", n_past },
{ "n_threads", n_threads },
{ "embd", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend()) },
});
return false;
}
n_past += n_eval;
}
return true;
}
void append_token(llama_token id) {
if (last_n_tokens.size() > 0) {
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
}
embd.push_back(id);
}
};
struct llama_server_context {
bool stream = false;
bool has_next_token = false;
std::string generated_text;
std::vector<completion_token_output> generated_token_probs;
size_t num_tokens_predicted = 0;
int n_keep_guidance = 0;
size_t n_remain = 0;
bool cfg_enabled = false;
llama_model * model = nullptr;
llama_context * ctx = nullptr;
gpt_params params;
prompt_evaluator evaluator;
prompt_evaluator evaluator_guidance;
bool stopped_eos = false;
bool stopped_word = false;
bool stopped_limit = false;
std::string stopping_word;
int32_t multibyte_pending = 0;
std::mutex mutex;
std::unique_lock<std::mutex> lock() {
return std::unique_lock<std::mutex>(mutex);
}
~llama_server_context() {
if (model) {
llama_free_model(model);
model = nullptr;
}
}
void rewind() {
params.antiprompt.clear();
num_tokens_predicted = 0;
generated_text = "";
generated_text.reserve(params.n_ctx);
generated_token_probs.clear();
stopped_eos = false;
stopped_word = false;
stopped_limit = false;
stopping_word = "";
multibyte_pending = 0;
n_remain = 0;
cfg_enabled = false;
evaluator.rewind();
evaluator_guidance.rewind();
}
bool loadModel(const gpt_params & params_) {
params = params_;
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (model == nullptr) {
LOG_ERROR("unable to load model", {{ "model", params_.model }});
return false;
}
evaluator.set_context(ctx);
struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
llama_context * ctx_guidance = llama_new_context_with_model(model, lparams);
evaluator_guidance.set_context(ctx_guidance);
return true;
}
void loadPrompt() {
evaluator.load_prompt(params.prompt, params.n_keep, params.repeat_last_n);
has_next_token = true; has_next_token = true;
} }
void loadGuidancePrompt() { void loadGuidancePrompt() {
params.cfg_negative_prompt.insert(0, 1, ' '); // always add a first space evaluator_guidance.load_prompt(params.cfg_negative_prompt, n_keep_guidance, 0);
std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, true);
num_prompt_tokens = prompt_tokens.size();
if (n_keep_guidance < 0) {
n_keep_guidance = (int)num_prompt_tokens;
}
n_keep_guidance = std::min(params.n_ctx - 4, n_keep_guidance);
// if input prompt is too big, truncate like normal
if (num_prompt_tokens >= (size_t)params.n_ctx) {
const int n_left = (params.n_ctx - n_keep_guidance) / 2;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + n_keep_guidance);
const int erased_blocks = (num_prompt_tokens - n_keep_guidance - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + n_keep_guidance + erased_blocks * n_left, prompt_tokens.end());
LOG_VERBOSE("guidance truncated", {
{ "n_ctx", params.n_ctx },
{ "n_keep", n_keep_guidance },
{ "n_left", n_left },
{ "new_tokens", tokens_to_str(ctx_guidance, new_tokens.cbegin(), new_tokens.cend()) },
});
prompt_tokens = new_tokens;
}
// compare the evaluated prompt with the new prompt
n_past_guidance = common_part(embd_guidance, prompt_tokens);
embd_guidance = prompt_tokens;
if (n_past_guidance == num_prompt_tokens) {
// we have to evaluate at least 1 token to generate logits.
n_past_guidance--;
}
LOG_VERBOSE("guidance prompt ingested", {
{ "n_past", n_past_guidance },
{ "cached", tokens_to_str(ctx_guidance, embd.cbegin(), embd.cbegin() + n_past) },
{ "to_eval", tokens_to_str(ctx_guidance, embd.cbegin() + n_past, embd.cend()) },
});
} }
void beginCompletion() { void beginCompletion() {
@ -341,80 +393,14 @@ struct llama_server_context {
completion_token_output result; completion_token_output result;
result.tok = -1; result.tok = -1;
if (embd.size() >= (size_t)params.n_ctx) { evaluator.evaluate(params.n_threads, params.n_batch);
// Reset context
const int n_left = (params.n_ctx - params.n_keep) / 2;
std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + params.n_keep);
new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
embd = new_tokens;
n_past = params.n_keep;
truncated = true;
LOG_VERBOSE("input truncated", {
{ "n_ctx", params.n_ctx },
{ "n_keep", params.n_keep },
{ "n_left", n_left },
{ "new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()) },
});
}
while (n_past < embd.size()) {
int n_eval = (int)embd.size() - n_past;
if (n_eval > params.n_batch) {
n_eval = params.n_batch;
}
if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads)) {
LOG_ERROR("failed to eval", {
{ "n_eval", n_eval },
{ "n_past", n_past },
{ "n_threads", params.n_threads },
{ "embd", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend()) },
});
has_next_token = false;
return result;
}
n_past += n_eval;
}
if (cfg_enabled) { if (cfg_enabled) {
if (embd_guidance.size() >= (size_t)params.n_ctx) { evaluator_guidance.evaluate(params.n_threads, params.n_batch);
// Reset context
const int n_left = (params.n_ctx - n_keep_guidance) / 2;
std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + n_keep_guidance);
new_tokens.insert(new_tokens.end(), embd_guidance.end() - n_left, embd_guidance.end());
embd_guidance = new_tokens;
n_past_guidance = n_keep_guidance;
LOG_VERBOSE("guidance truncated", {
{ "n_ctx", params.n_ctx },
{ "n_keep", n_keep_guidance },
{ "n_left", n_left },
{ "new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()) },
});
}
while (n_past_guidance < embd_guidance.size()) {
int n_eval = (int)embd_guidance.size() - n_past_guidance;
if (n_eval > params.n_batch) {
n_eval = params.n_batch;
}
if (llama_eval(ctx_guidance, &embd_guidance[n_past_guidance], n_eval, n_past_guidance, params.n_threads)) {
LOG_ERROR("failed to eval", {
{ "n_eval", n_eval },
{ "n_past", n_past_guidance },
{ "n_threads", params.n_threads },
{ "embd", tokens_to_str(ctx_guidance, embd_guidance.cbegin() + n_past_guidance, embd_guidance.cend()) },
});
has_next_token = false;
return result;
}
n_past_guidance += n_eval;
}
} }
if (params.n_predict == 0) { if (params.n_predict == 0) {
has_next_token = false; has_next_token = false;
//result.tok = llama_token_eos();
return result; return result;
} }
@ -424,7 +410,7 @@ struct llama_server_context {
const float top_p = params.top_p; const float top_p = params.top_p;
const float tfs_z = params.tfs_z; const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p; const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n; //const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty; const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.presence_penalty; const float alpha_presence = params.presence_penalty;
const float alpha_frequency = params.frequency_penalty; const float alpha_frequency = params.frequency_penalty;
@ -435,7 +421,7 @@ struct llama_server_context {
const int32_t n_probs = params.n_probs; const int32_t n_probs = params.n_probs;
{ {
auto * logits = llama_get_logits(ctx); auto logits = llama_get_logits(ctx);
auto n_vocab = llama_n_vocab(ctx); auto n_vocab = llama_n_vocab(ctx);
// Apply params.logit_bias map // Apply params.logit_bias map
@ -453,18 +439,18 @@ struct llama_server_context {
if (cfg_enabled) { if (cfg_enabled) {
llama_sample_classifier_free_guidance( llama_sample_classifier_free_guidance(
ctx, &candidates_p, ctx_guidance, params.cfg_scale, params.cfg_smooth_factor); ctx, &candidates_p, evaluator_guidance.ctx, params.cfg_scale, params.cfg_smooth_factor);
} }
// Apply penalties // Apply penalties
float nl_logit = logits[llama_token_nl()]; float nl_logit = logits[llama_token_nl()];
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx); //auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
llama_sample_repetition_penalty(ctx, &candidates_p, llama_sample_repetition_penalty(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, evaluator.last_n_tokens.data(), evaluator.last_n_tokens.size(),
last_n_repeat, repeat_penalty); repeat_penalty);
llama_sample_frequency_and_presence_penalties(ctx, &candidates_p, llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, evaluator.last_n_tokens.data(), evaluator.last_n_tokens.size(),
last_n_repeat, alpha_frequency, alpha_presence); alpha_frequency, alpha_presence);
if (!penalize_nl) { if (!penalize_nl) {
logits[llama_token_nl()] = nl_logit; logits[llama_token_nl()] = nl_logit;
} }
@ -500,21 +486,19 @@ struct llama_server_context {
for (size_t i = 0; i < std::min(candidates_p.size, (size_t) n_probs); ++i) { for (size_t i = 0; i < std::min(candidates_p.size, (size_t) n_probs); ++i) {
result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p}); result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
} }
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(result.tok);
num_tokens_predicted++; num_tokens_predicted++;
} }
// add it to the context // add it to the context
embd.push_back(result.tok); evaluator.append_token(result.tok);
if (cfg_enabled) { if (cfg_enabled) {
embd_guidance.push_back(result.tok); evaluator_guidance.append_token(result.tok);
} }
// decrement remaining sampling budget // decrement remaining sampling budget
--n_remain; --n_remain;
if (!embd.empty() && embd.back() == llama_token_eos()) { if (result.tok == llama_token_eos()) {
//stopping_word = llama_token_to_str(ctx, embd.back()); stopping_word = "";
has_next_token = false; has_next_token = false;
stopped_eos = true; stopped_eos = true;
LOG_VERBOSE("eos token found", {}); LOG_VERBOSE("eos token found", {});
@ -591,6 +575,7 @@ struct llama_server_context {
LOG_VERBOSE("next token", { LOG_VERBOSE("next token", {
{ "token", token_with_probs.tok }, { "token", token_with_probs.tok },
{ "token_text", tokens_to_output_formatted_string(ctx, token_with_probs.tok) }, { "token_text", tokens_to_output_formatted_string(ctx, token_with_probs.tok) },
{ "n_past", evaluator.n_past },
{ "has_next_token", has_next_token }, { "has_next_token", has_next_token },
{ "n_remain", n_remain }, { "n_remain", n_remain },
{ "num_tokens_predicted", num_tokens_predicted }, { "num_tokens_predicted", num_tokens_predicted },
@ -884,16 +869,16 @@ static json format_final_response(llama_server_context & llama, const std::strin
{ "stop", true }, { "stop", true },
{ "model", llama.params.model_alias }, { "model", llama.params.model_alias },
{ "tokens_predicted", llama.num_tokens_predicted }, { "tokens_predicted", llama.num_tokens_predicted },
{ "tokens_evaluated", llama.num_prompt_tokens }, { "tokens_evaluated", llama.evaluator.num_prompt_tokens },
{ "generation_settings", format_generation_settings(llama) }, { "generation_settings", format_generation_settings(llama) },
{ "prompt", llama.params.prompt }, { "prompt", llama.params.prompt },
{ "cfg_negative_prompt", llama.params.cfg_negative_prompt }, { "cfg_negative_prompt", llama.params.cfg_negative_prompt },
{ "truncated", llama.truncated }, { "truncated", llama.evaluator.truncated },
{ "stopped_eos", llama.stopped_eos }, { "stopped_eos", llama.stopped_eos },
{ "stopped_word", llama.stopped_word }, { "stopped_word", llama.stopped_word },
{ "stopped_limit", llama.stopped_limit }, { "stopped_limit", llama.stopped_limit },
{ "stopping_word", llama.stopping_word }, { "stopping_word", llama.stopping_word },
{ "tokens_cached", llama.n_past }, { "tokens_cached", llama.evaluator.n_past },
{ "timings", format_timings(llama) }, { "timings", format_timings(llama) },
}; };