mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 11:23:56 +01:00
Port CFG to server.
This commit is contained in:
parent
3a13d1e829
commit
4cae9f5673
@ -161,13 +161,18 @@ struct llama_server_context {
|
|||||||
size_t num_prompt_tokens = 0;
|
size_t num_prompt_tokens = 0;
|
||||||
size_t num_tokens_predicted = 0;
|
size_t num_tokens_predicted = 0;
|
||||||
size_t n_past = 0;
|
size_t n_past = 0;
|
||||||
|
size_t n_past_guidance = 0;
|
||||||
|
int n_keep_guidance = 0;
|
||||||
size_t n_remain = 0;
|
size_t n_remain = 0;
|
||||||
|
bool cfg_enabled = false;
|
||||||
|
|
||||||
std::vector<llama_token> embd;
|
std::vector<llama_token> embd;
|
||||||
|
std::vector<llama_token> embd_guidance;
|
||||||
std::vector<llama_token> last_n_tokens;
|
std::vector<llama_token> last_n_tokens;
|
||||||
|
|
||||||
llama_model * model = nullptr;
|
llama_model * model = nullptr;
|
||||||
llama_context * ctx = nullptr;
|
llama_context * ctx = nullptr;
|
||||||
|
llama_context * ctx_guidance = nullptr;
|
||||||
gpt_params params;
|
gpt_params params;
|
||||||
|
|
||||||
bool truncated = false;
|
bool truncated = false;
|
||||||
@ -188,6 +193,10 @@ struct llama_server_context {
|
|||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
ctx = nullptr;
|
ctx = nullptr;
|
||||||
}
|
}
|
||||||
|
if (ctx_guidance) {
|
||||||
|
llama_free(ctx_guidance);
|
||||||
|
ctx_guidance = nullptr;
|
||||||
|
}
|
||||||
if (model) {
|
if (model) {
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
model = nullptr;
|
model = nullptr;
|
||||||
@ -210,6 +219,8 @@ struct llama_server_context {
|
|||||||
|
|
||||||
n_remain = 0;
|
n_remain = 0;
|
||||||
n_past = 0;
|
n_past = 0;
|
||||||
|
cfg_enabled = false;
|
||||||
|
n_past_guidance = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool loadModel(const gpt_params & params_) {
|
bool loadModel(const gpt_params & params_) {
|
||||||
@ -220,6 +231,9 @@ struct llama_server_context {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
|
||||||
|
ctx_guidance = llama_new_context_with_model(model, lparams);
|
||||||
|
|
||||||
last_n_tokens.resize(params.n_ctx);
|
last_n_tokens.resize(params.n_ctx);
|
||||||
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
|
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
|
||||||
return true;
|
return true;
|
||||||
@ -236,7 +250,7 @@ struct llama_server_context {
|
|||||||
params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
|
params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
|
||||||
|
|
||||||
// if input prompt is too big, truncate like normal
|
// if input prompt is too big, truncate like normal
|
||||||
if (num_prompt_tokens>= (size_t)params.n_ctx) {
|
if (num_prompt_tokens >= (size_t)params.n_ctx) {
|
||||||
const int n_left = (params.n_ctx - params.n_keep) / 2;
|
const int n_left = (params.n_ctx - params.n_keep) / 2;
|
||||||
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
|
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
|
||||||
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
|
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
|
||||||
@ -275,6 +289,48 @@ struct llama_server_context {
|
|||||||
has_next_token = true;
|
has_next_token = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void loadGuidancePrompt() {
|
||||||
|
params.cfg_negative_prompt.insert(0, 1, ' '); // always add a first space
|
||||||
|
std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, true);
|
||||||
|
num_prompt_tokens = prompt_tokens.size();
|
||||||
|
|
||||||
|
if (n_keep_guidance < 0) {
|
||||||
|
n_keep_guidance = (int)num_prompt_tokens;
|
||||||
|
}
|
||||||
|
n_keep_guidance = std::min(params.n_ctx - 4, n_keep_guidance);
|
||||||
|
|
||||||
|
// if input prompt is too big, truncate like normal
|
||||||
|
if (num_prompt_tokens >= (size_t)params.n_ctx) {
|
||||||
|
const int n_left = (params.n_ctx - n_keep_guidance) / 2;
|
||||||
|
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + n_keep_guidance);
|
||||||
|
const int erased_blocks = (num_prompt_tokens - n_keep_guidance - n_left - 1) / n_left;
|
||||||
|
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + n_keep_guidance + erased_blocks * n_left, prompt_tokens.end());
|
||||||
|
|
||||||
|
LOG_VERBOSE("guidance truncated", {
|
||||||
|
{ "n_ctx", params.n_ctx },
|
||||||
|
{ "n_keep", n_keep_guidance },
|
||||||
|
{ "n_left", n_left },
|
||||||
|
{ "new_tokens", tokens_to_str(ctx_guidance, new_tokens.cbegin(), new_tokens.cend()) },
|
||||||
|
});
|
||||||
|
|
||||||
|
prompt_tokens = new_tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
// compare the evaluated prompt with the new prompt
|
||||||
|
n_past_guidance = common_part(embd_guidance, prompt_tokens);
|
||||||
|
embd_guidance = prompt_tokens;
|
||||||
|
if (n_past_guidance == num_prompt_tokens) {
|
||||||
|
// we have to evaluate at least 1 token to generate logits.
|
||||||
|
n_past_guidance--;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_VERBOSE("guidance prompt ingested", {
|
||||||
|
{ "n_past", n_past_guidance },
|
||||||
|
{ "cached", tokens_to_str(ctx_guidance, embd.cbegin(), embd.cbegin() + n_past) },
|
||||||
|
{ "to_eval", tokens_to_str(ctx_guidance, embd.cbegin() + n_past, embd.cend()) },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
void beginCompletion() {
|
void beginCompletion() {
|
||||||
// number of tokens to keep when resetting context
|
// number of tokens to keep when resetting context
|
||||||
n_remain = params.n_predict;
|
n_remain = params.n_predict;
|
||||||
@ -320,9 +376,45 @@ struct llama_server_context {
|
|||||||
n_past += n_eval;
|
n_past += n_eval;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (cfg_enabled) {
|
||||||
|
if (embd_guidance.size() >= (size_t)params.n_ctx) {
|
||||||
|
// Reset context
|
||||||
|
const int n_left = (params.n_ctx - n_keep_guidance) / 2;
|
||||||
|
|
||||||
|
std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + n_keep_guidance);
|
||||||
|
new_tokens.insert(new_tokens.end(), embd_guidance.end() - n_left, embd_guidance.end());
|
||||||
|
embd_guidance = new_tokens;
|
||||||
|
n_past_guidance = n_keep_guidance;
|
||||||
|
LOG_VERBOSE("guidance truncated", {
|
||||||
|
{ "n_ctx", params.n_ctx },
|
||||||
|
{ "n_keep", n_keep_guidance },
|
||||||
|
{ "n_left", n_left },
|
||||||
|
{ "new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()) },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
while (n_past_guidance < embd_guidance.size()) {
|
||||||
|
int n_eval = (int)embd_guidance.size() - n_past_guidance;
|
||||||
|
if (n_eval > params.n_batch) {
|
||||||
|
n_eval = params.n_batch;
|
||||||
|
}
|
||||||
|
if (llama_eval(ctx_guidance, &embd_guidance[n_past_guidance], n_eval, n_past_guidance, params.n_threads)) {
|
||||||
|
LOG_ERROR("failed to eval", {
|
||||||
|
{ "n_eval", n_eval },
|
||||||
|
{ "n_past", n_past_guidance },
|
||||||
|
{ "n_threads", params.n_threads },
|
||||||
|
{ "embd", tokens_to_str(ctx_guidance, embd_guidance.cbegin() + n_past_guidance, embd_guidance.cend()) },
|
||||||
|
});
|
||||||
|
has_next_token = false;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
n_past_guidance += n_eval;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (params.n_predict == 0) {
|
if (params.n_predict == 0) {
|
||||||
has_next_token = false;
|
has_next_token = false;
|
||||||
result.tok = llama_token_eos();
|
//result.tok = llama_token_eos();
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -359,6 +451,11 @@ struct llama_server_context {
|
|||||||
|
|
||||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||||
|
|
||||||
|
if (cfg_enabled) {
|
||||||
|
llama_sample_classifier_free_guidance(
|
||||||
|
ctx, &candidates_p, ctx_guidance, params.cfg_scale, params.cfg_smooth_factor);
|
||||||
|
}
|
||||||
|
|
||||||
// Apply penalties
|
// Apply penalties
|
||||||
float nl_logit = logits[llama_token_nl()];
|
float nl_logit = logits[llama_token_nl()];
|
||||||
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
|
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
|
||||||
@ -410,6 +507,9 @@ struct llama_server_context {
|
|||||||
|
|
||||||
// add it to the context
|
// add it to the context
|
||||||
embd.push_back(result.tok);
|
embd.push_back(result.tok);
|
||||||
|
if (cfg_enabled) {
|
||||||
|
embd_guidance.push_back(result.tok);
|
||||||
|
}
|
||||||
// decrement remaining sampling budget
|
// decrement remaining sampling budget
|
||||||
--n_remain;
|
--n_remain;
|
||||||
|
|
||||||
@ -747,6 +847,9 @@ static json format_generation_settings(llama_server_context & llama) {
|
|||||||
{ "stream", llama.stream },
|
{ "stream", llama.stream },
|
||||||
{ "logit_bias", llama.params.logit_bias },
|
{ "logit_bias", llama.params.logit_bias },
|
||||||
{ "n_probs", llama.params.n_probs },
|
{ "n_probs", llama.params.n_probs },
|
||||||
|
{ "cfg_scale", llama.params.cfg_scale },
|
||||||
|
{ "cfg_smooth_factor", llama.params.cfg_smooth_factor },
|
||||||
|
{ "cfg_n_keep", llama.n_keep_guidance },
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -759,7 +862,7 @@ static json format_embedding_response(llama_server_context & llama) {
|
|||||||
static json format_timings(llama_server_context & llama) {
|
static json format_timings(llama_server_context & llama) {
|
||||||
const auto timings = llama_get_timings(llama.ctx);
|
const auto timings = llama_get_timings(llama.ctx);
|
||||||
|
|
||||||
assert(timings.n_eval == llama.num_tokens_predicted);
|
//assert(timings.n_eval == llama.num_tokens_predicted);
|
||||||
|
|
||||||
return json {
|
return json {
|
||||||
{ "prompt_n", timings.n_eval },
|
{ "prompt_n", timings.n_eval },
|
||||||
@ -784,13 +887,13 @@ static json format_final_response(llama_server_context & llama, const std::strin
|
|||||||
{ "tokens_evaluated", llama.num_prompt_tokens },
|
{ "tokens_evaluated", llama.num_prompt_tokens },
|
||||||
{ "generation_settings", format_generation_settings(llama) },
|
{ "generation_settings", format_generation_settings(llama) },
|
||||||
{ "prompt", llama.params.prompt },
|
{ "prompt", llama.params.prompt },
|
||||||
|
{ "cfg_negative_prompt", llama.params.cfg_negative_prompt },
|
||||||
{ "truncated", llama.truncated },
|
{ "truncated", llama.truncated },
|
||||||
{ "stopped_eos", llama.stopped_eos },
|
{ "stopped_eos", llama.stopped_eos },
|
||||||
{ "stopped_word", llama.stopped_word },
|
{ "stopped_word", llama.stopped_word },
|
||||||
{ "stopped_limit", llama.stopped_limit },
|
{ "stopped_limit", llama.stopped_limit },
|
||||||
{ "stopping_word", llama.stopping_word },
|
{ "stopping_word", llama.stopping_word },
|
||||||
{ "tokens_cached", llama.n_past },
|
{ "tokens_cached", llama.n_past },
|
||||||
{ "tokens_predicted", llama.num_tokens_predicted },
|
|
||||||
{ "timings", format_timings(llama) },
|
{ "timings", format_timings(llama) },
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -841,6 +944,10 @@ static void parse_options_completion(const json & body, llama_server_context & l
|
|||||||
llama.params.n_keep = body.value("n_keep", default_params.n_keep);
|
llama.params.n_keep = body.value("n_keep", default_params.n_keep);
|
||||||
llama.params.seed = body.value("seed", default_params.seed);
|
llama.params.seed = body.value("seed", default_params.seed);
|
||||||
llama.params.prompt = body.value("prompt", default_params.prompt);
|
llama.params.prompt = body.value("prompt", default_params.prompt);
|
||||||
|
llama.params.cfg_negative_prompt = body.value("cfg_negative_prompt", default_params.cfg_negative_prompt);
|
||||||
|
llama.params.cfg_scale = body.value("cfg_scale", default_params.cfg_scale);
|
||||||
|
llama.params.cfg_smooth_factor = body.value("cfg_smooth_factor", default_params.cfg_smooth_factor);
|
||||||
|
llama.n_keep_guidance = body.value("cfg_n_keep", 0);
|
||||||
llama.params.n_probs = body.value("n_probs", default_params.n_probs);
|
llama.params.n_probs = body.value("n_probs", default_params.n_probs);
|
||||||
|
|
||||||
llama.params.logit_bias.clear();
|
llama.params.logit_bias.clear();
|
||||||
@ -963,6 +1070,11 @@ int main(int argc, char ** argv) {
|
|||||||
llama.loadPrompt();
|
llama.loadPrompt();
|
||||||
llama.beginCompletion();
|
llama.beginCompletion();
|
||||||
|
|
||||||
|
if (llama.params.cfg_negative_prompt.size() > 0) {
|
||||||
|
llama.cfg_enabled = true;
|
||||||
|
llama.loadGuidancePrompt();
|
||||||
|
}
|
||||||
|
|
||||||
if (!llama.stream) {
|
if (!llama.stream) {
|
||||||
size_t stop_pos = std::string::npos;
|
size_t stop_pos = std::string::npos;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user