From 4cae9f5673f9bba17beae659e0b0cd52cc794dbf Mon Sep 17 00:00:00 2001 From: Henri Vasserman Date: Thu, 13 Jul 2023 22:37:57 +0300 Subject: [PATCH] Port CFG to server. --- examples/server/server.cpp | 120 +++++++++++++++++++++++++++++++++++-- 1 file changed, 116 insertions(+), 4 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 481804d54..57fef514f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -161,13 +161,18 @@ struct llama_server_context { size_t num_prompt_tokens = 0; size_t num_tokens_predicted = 0; size_t n_past = 0; + size_t n_past_guidance = 0; + int n_keep_guidance = 0; size_t n_remain = 0; + bool cfg_enabled = false; std::vector embd; + std::vector embd_guidance; std::vector last_n_tokens; llama_model * model = nullptr; llama_context * ctx = nullptr; + llama_context * ctx_guidance = nullptr; gpt_params params; bool truncated = false; @@ -188,6 +193,10 @@ struct llama_server_context { llama_free(ctx); ctx = nullptr; } + if (ctx_guidance) { + llama_free(ctx_guidance); + ctx_guidance = nullptr; + } if (model) { llama_free_model(model); model = nullptr; @@ -210,6 +219,8 @@ struct llama_server_context { n_remain = 0; n_past = 0; + cfg_enabled = false; + n_past_guidance = 0; } bool loadModel(const gpt_params & params_) { @@ -220,6 +231,9 @@ struct llama_server_context { return false; } + struct llama_context_params lparams = llama_context_params_from_gpt_params(params); + ctx_guidance = llama_new_context_with_model(model, lparams); + last_n_tokens.resize(params.n_ctx); std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); return true; @@ -236,7 +250,7 @@ struct llama_server_context { params.n_keep = std::min(params.n_ctx - 4, params.n_keep); // if input prompt is too big, truncate like normal - if (num_prompt_tokens>= (size_t)params.n_ctx) { + if (num_prompt_tokens >= (size_t)params.n_ctx) { const int n_left = (params.n_ctx - params.n_keep) / 2; std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; @@ -275,6 +289,48 @@ struct llama_server_context { has_next_token = true; } + void loadGuidancePrompt() { + params.cfg_negative_prompt.insert(0, 1, ' '); // always add a first space + std::vector prompt_tokens = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, true); + num_prompt_tokens = prompt_tokens.size(); + + if (n_keep_guidance < 0) { + n_keep_guidance = (int)num_prompt_tokens; + } + n_keep_guidance = std::min(params.n_ctx - 4, n_keep_guidance); + + // if input prompt is too big, truncate like normal + if (num_prompt_tokens >= (size_t)params.n_ctx) { + const int n_left = (params.n_ctx - n_keep_guidance) / 2; + std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + n_keep_guidance); + const int erased_blocks = (num_prompt_tokens - n_keep_guidance - n_left - 1) / n_left; + new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + n_keep_guidance + erased_blocks * n_left, prompt_tokens.end()); + + LOG_VERBOSE("guidance truncated", { + { "n_ctx", params.n_ctx }, + { "n_keep", n_keep_guidance }, + { "n_left", n_left }, + { "new_tokens", tokens_to_str(ctx_guidance, new_tokens.cbegin(), new_tokens.cend()) }, + }); + + prompt_tokens = new_tokens; + } + + // compare the evaluated prompt with the new prompt + n_past_guidance = common_part(embd_guidance, prompt_tokens); + embd_guidance = prompt_tokens; + if (n_past_guidance == num_prompt_tokens) { + // we have to evaluate at least 1 token to generate logits. + n_past_guidance--; + } + + LOG_VERBOSE("guidance prompt ingested", { + { "n_past", n_past_guidance }, + { "cached", tokens_to_str(ctx_guidance, embd.cbegin(), embd.cbegin() + n_past) }, + { "to_eval", tokens_to_str(ctx_guidance, embd.cbegin() + n_past, embd.cend()) }, + }); + } + void beginCompletion() { // number of tokens to keep when resetting context n_remain = params.n_predict; @@ -320,9 +376,45 @@ struct llama_server_context { n_past += n_eval; } + if (cfg_enabled) { + if (embd_guidance.size() >= (size_t)params.n_ctx) { + // Reset context + const int n_left = (params.n_ctx - n_keep_guidance) / 2; + + std::vector new_tokens(embd.begin(), embd.begin() + n_keep_guidance); + new_tokens.insert(new_tokens.end(), embd_guidance.end() - n_left, embd_guidance.end()); + embd_guidance = new_tokens; + n_past_guidance = n_keep_guidance; + LOG_VERBOSE("guidance truncated", { + { "n_ctx", params.n_ctx }, + { "n_keep", n_keep_guidance }, + { "n_left", n_left }, + { "new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()) }, + }); + } + + while (n_past_guidance < embd_guidance.size()) { + int n_eval = (int)embd_guidance.size() - n_past_guidance; + if (n_eval > params.n_batch) { + n_eval = params.n_batch; + } + if (llama_eval(ctx_guidance, &embd_guidance[n_past_guidance], n_eval, n_past_guidance, params.n_threads)) { + LOG_ERROR("failed to eval", { + { "n_eval", n_eval }, + { "n_past", n_past_guidance }, + { "n_threads", params.n_threads }, + { "embd", tokens_to_str(ctx_guidance, embd_guidance.cbegin() + n_past_guidance, embd_guidance.cend()) }, + }); + has_next_token = false; + return result; + } + n_past_guidance += n_eval; + } + } + if (params.n_predict == 0) { has_next_token = false; - result.tok = llama_token_eos(); + //result.tok = llama_token_eos(); return result; } @@ -359,6 +451,11 @@ struct llama_server_context { llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + if (cfg_enabled) { + llama_sample_classifier_free_guidance( + ctx, &candidates_p, ctx_guidance, params.cfg_scale, params.cfg_smooth_factor); + } + // Apply penalties float nl_logit = logits[llama_token_nl()]; auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx); @@ -410,6 +507,9 @@ struct llama_server_context { // add it to the context embd.push_back(result.tok); + if (cfg_enabled) { + embd_guidance.push_back(result.tok); + } // decrement remaining sampling budget --n_remain; @@ -747,6 +847,9 @@ static json format_generation_settings(llama_server_context & llama) { { "stream", llama.stream }, { "logit_bias", llama.params.logit_bias }, { "n_probs", llama.params.n_probs }, + { "cfg_scale", llama.params.cfg_scale }, + { "cfg_smooth_factor", llama.params.cfg_smooth_factor }, + { "cfg_n_keep", llama.n_keep_guidance }, }; } @@ -759,7 +862,7 @@ static json format_embedding_response(llama_server_context & llama) { static json format_timings(llama_server_context & llama) { const auto timings = llama_get_timings(llama.ctx); - assert(timings.n_eval == llama.num_tokens_predicted); + //assert(timings.n_eval == llama.num_tokens_predicted); return json { { "prompt_n", timings.n_eval }, @@ -784,13 +887,13 @@ static json format_final_response(llama_server_context & llama, const std::strin { "tokens_evaluated", llama.num_prompt_tokens }, { "generation_settings", format_generation_settings(llama) }, { "prompt", llama.params.prompt }, + { "cfg_negative_prompt", llama.params.cfg_negative_prompt }, { "truncated", llama.truncated }, { "stopped_eos", llama.stopped_eos }, { "stopped_word", llama.stopped_word }, { "stopped_limit", llama.stopped_limit }, { "stopping_word", llama.stopping_word }, { "tokens_cached", llama.n_past }, - { "tokens_predicted", llama.num_tokens_predicted }, { "timings", format_timings(llama) }, }; @@ -841,6 +944,10 @@ static void parse_options_completion(const json & body, llama_server_context & l llama.params.n_keep = body.value("n_keep", default_params.n_keep); llama.params.seed = body.value("seed", default_params.seed); llama.params.prompt = body.value("prompt", default_params.prompt); + llama.params.cfg_negative_prompt = body.value("cfg_negative_prompt", default_params.cfg_negative_prompt); + llama.params.cfg_scale = body.value("cfg_scale", default_params.cfg_scale); + llama.params.cfg_smooth_factor = body.value("cfg_smooth_factor", default_params.cfg_smooth_factor); + llama.n_keep_guidance = body.value("cfg_n_keep", 0); llama.params.n_probs = body.value("n_probs", default_params.n_probs); llama.params.logit_bias.clear(); @@ -963,6 +1070,11 @@ int main(int argc, char ** argv) { llama.loadPrompt(); llama.beginCompletion(); + if (llama.params.cfg_negative_prompt.size() > 0) { + llama.cfg_enabled = true; + llama.loadGuidancePrompt(); + } + if (!llama.stream) { size_t stop_pos = std::string::npos;