Port CFG to server.

This commit is contained in:
Henri Vasserman 2023-07-13 22:37:57 +03:00
parent 3a13d1e829
commit 4cae9f5673
No known key found for this signature in database
GPG Key ID: 2995FC0F58B1A986

View File

@ -161,13 +161,18 @@ struct llama_server_context {
size_t num_prompt_tokens = 0;
size_t num_tokens_predicted = 0;
size_t n_past = 0;
size_t n_past_guidance = 0;
int n_keep_guidance = 0;
size_t n_remain = 0;
bool cfg_enabled = false;
std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance;
std::vector<llama_token> last_n_tokens;
llama_model * model = nullptr;
llama_context * ctx = nullptr;
llama_context * ctx_guidance = nullptr;
gpt_params params;
bool truncated = false;
@ -188,6 +193,10 @@ struct llama_server_context {
llama_free(ctx);
ctx = nullptr;
}
if (ctx_guidance) {
llama_free(ctx_guidance);
ctx_guidance = nullptr;
}
if (model) {
llama_free_model(model);
model = nullptr;
@ -210,6 +219,8 @@ struct llama_server_context {
n_remain = 0;
n_past = 0;
cfg_enabled = false;
n_past_guidance = 0;
}
bool loadModel(const gpt_params & params_) {
@ -220,6 +231,9 @@ struct llama_server_context {
return false;
}
struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
ctx_guidance = llama_new_context_with_model(model, lparams);
last_n_tokens.resize(params.n_ctx);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
return true;
@ -275,6 +289,48 @@ struct llama_server_context {
has_next_token = true;
}
void loadGuidancePrompt() {
params.cfg_negative_prompt.insert(0, 1, ' '); // always add a first space
std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, true);
num_prompt_tokens = prompt_tokens.size();
if (n_keep_guidance < 0) {
n_keep_guidance = (int)num_prompt_tokens;
}
n_keep_guidance = std::min(params.n_ctx - 4, n_keep_guidance);
// if input prompt is too big, truncate like normal
if (num_prompt_tokens >= (size_t)params.n_ctx) {
const int n_left = (params.n_ctx - n_keep_guidance) / 2;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + n_keep_guidance);
const int erased_blocks = (num_prompt_tokens - n_keep_guidance - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + n_keep_guidance + erased_blocks * n_left, prompt_tokens.end());
LOG_VERBOSE("guidance truncated", {
{ "n_ctx", params.n_ctx },
{ "n_keep", n_keep_guidance },
{ "n_left", n_left },
{ "new_tokens", tokens_to_str(ctx_guidance, new_tokens.cbegin(), new_tokens.cend()) },
});
prompt_tokens = new_tokens;
}
// compare the evaluated prompt with the new prompt
n_past_guidance = common_part(embd_guidance, prompt_tokens);
embd_guidance = prompt_tokens;
if (n_past_guidance == num_prompt_tokens) {
// we have to evaluate at least 1 token to generate logits.
n_past_guidance--;
}
LOG_VERBOSE("guidance prompt ingested", {
{ "n_past", n_past_guidance },
{ "cached", tokens_to_str(ctx_guidance, embd.cbegin(), embd.cbegin() + n_past) },
{ "to_eval", tokens_to_str(ctx_guidance, embd.cbegin() + n_past, embd.cend()) },
});
}
void beginCompletion() {
// number of tokens to keep when resetting context
n_remain = params.n_predict;
@ -320,9 +376,45 @@ struct llama_server_context {
n_past += n_eval;
}
if (cfg_enabled) {
if (embd_guidance.size() >= (size_t)params.n_ctx) {
// Reset context
const int n_left = (params.n_ctx - n_keep_guidance) / 2;
std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + n_keep_guidance);
new_tokens.insert(new_tokens.end(), embd_guidance.end() - n_left, embd_guidance.end());
embd_guidance = new_tokens;
n_past_guidance = n_keep_guidance;
LOG_VERBOSE("guidance truncated", {
{ "n_ctx", params.n_ctx },
{ "n_keep", n_keep_guidance },
{ "n_left", n_left },
{ "new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()) },
});
}
while (n_past_guidance < embd_guidance.size()) {
int n_eval = (int)embd_guidance.size() - n_past_guidance;
if (n_eval > params.n_batch) {
n_eval = params.n_batch;
}
if (llama_eval(ctx_guidance, &embd_guidance[n_past_guidance], n_eval, n_past_guidance, params.n_threads)) {
LOG_ERROR("failed to eval", {
{ "n_eval", n_eval },
{ "n_past", n_past_guidance },
{ "n_threads", params.n_threads },
{ "embd", tokens_to_str(ctx_guidance, embd_guidance.cbegin() + n_past_guidance, embd_guidance.cend()) },
});
has_next_token = false;
return result;
}
n_past_guidance += n_eval;
}
}
if (params.n_predict == 0) {
has_next_token = false;
result.tok = llama_token_eos();
//result.tok = llama_token_eos();
return result;
}
@ -359,6 +451,11 @@ struct llama_server_context {
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
if (cfg_enabled) {
llama_sample_classifier_free_guidance(
ctx, &candidates_p, ctx_guidance, params.cfg_scale, params.cfg_smooth_factor);
}
// Apply penalties
float nl_logit = logits[llama_token_nl()];
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
@ -410,6 +507,9 @@ struct llama_server_context {
// add it to the context
embd.push_back(result.tok);
if (cfg_enabled) {
embd_guidance.push_back(result.tok);
}
// decrement remaining sampling budget
--n_remain;
@ -747,6 +847,9 @@ static json format_generation_settings(llama_server_context & llama) {
{ "stream", llama.stream },
{ "logit_bias", llama.params.logit_bias },
{ "n_probs", llama.params.n_probs },
{ "cfg_scale", llama.params.cfg_scale },
{ "cfg_smooth_factor", llama.params.cfg_smooth_factor },
{ "cfg_n_keep", llama.n_keep_guidance },
};
}
@ -759,7 +862,7 @@ static json format_embedding_response(llama_server_context & llama) {
static json format_timings(llama_server_context & llama) {
const auto timings = llama_get_timings(llama.ctx);
assert(timings.n_eval == llama.num_tokens_predicted);
//assert(timings.n_eval == llama.num_tokens_predicted);
return json {
{ "prompt_n", timings.n_eval },
@ -784,13 +887,13 @@ static json format_final_response(llama_server_context & llama, const std::strin
{ "tokens_evaluated", llama.num_prompt_tokens },
{ "generation_settings", format_generation_settings(llama) },
{ "prompt", llama.params.prompt },
{ "cfg_negative_prompt", llama.params.cfg_negative_prompt },
{ "truncated", llama.truncated },
{ "stopped_eos", llama.stopped_eos },
{ "stopped_word", llama.stopped_word },
{ "stopped_limit", llama.stopped_limit },
{ "stopping_word", llama.stopping_word },
{ "tokens_cached", llama.n_past },
{ "tokens_predicted", llama.num_tokens_predicted },
{ "timings", format_timings(llama) },
};
@ -841,6 +944,10 @@ static void parse_options_completion(const json & body, llama_server_context & l
llama.params.n_keep = body.value("n_keep", default_params.n_keep);
llama.params.seed = body.value("seed", default_params.seed);
llama.params.prompt = body.value("prompt", default_params.prompt);
llama.params.cfg_negative_prompt = body.value("cfg_negative_prompt", default_params.cfg_negative_prompt);
llama.params.cfg_scale = body.value("cfg_scale", default_params.cfg_scale);
llama.params.cfg_smooth_factor = body.value("cfg_smooth_factor", default_params.cfg_smooth_factor);
llama.n_keep_guidance = body.value("cfg_n_keep", 0);
llama.params.n_probs = body.value("n_probs", default_params.n_probs);
llama.params.logit_bias.clear();
@ -963,6 +1070,11 @@ int main(int argc, char ** argv) {
llama.loadPrompt();
llama.beginCompletion();
if (llama.params.cfg_negative_prompt.size() > 0) {
llama.cfg_enabled = true;
llama.loadGuidancePrompt();
}
if (!llama.stream) {
size_t stop_pos = std::string::npos;