server : accept extra_context for the infill endpoint (#9874)

* server : accept extra_context for the infill endpoint

ggml-ci

* server : update readme [no ci]

* server : use repo-level FIM pattern if possible

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-10-13 21:31:35 +03:00 committed by GitHub
parent c7181bd294
commit d4c19c0f5c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 153 additions and 26 deletions

View File

@ -524,9 +524,30 @@ Takes a prefix and a suffix and returns the predicted completion as stream.
- `input_prefix`: Set the prefix of the code to infill. - `input_prefix`: Set the prefix of the code to infill.
- `input_suffix`: Set the suffix of the code to infill. - `input_suffix`: Set the suffix of the code to infill.
- `prompt`: Added after the `FIM_MID` token
- `extra_context`: Additional context inserted before the FIM prefix. See https://github.com/ggerganov/llama.cpp/pull/9874
It also accepts all the options of `/completion`. It also accepts all the options of `/completion`.
If the model has `FIM_REPO` and `FIM_FILE_SEP` tokens, the [repo-level pattern](https://arxiv.org/pdf/2409.12186) is used:
```txt
<FIM_REP>myproject
<FIM_SEP>{chunk 0 filename}
{chunk 0 text}
<FIM_SEP>{chunk 1 filename}
{chunk 1 text}
...
<FIM_SEP>filename
<FIM_PRE>[input_prefix]<FIM_SUF>[input_suffix]<FIM_MID>[prompt]
```
If the tokens are missing, then the extra context is simply prefixed at the start:
```txt
[extra_context]<FIM_PRE>[input_prefix]<FIM_SUF>[input_suffix]<FIM_MID>[prompt]
```
### **GET** `/props`: Get server global properties. ### **GET** `/props`: Get server global properties.
This endpoint is public (no API key check). By default, it is read-only. To make POST request to change global properties, you need to start server with `--props` This endpoint is public (no API key check). By default, it is read-only. To make POST request to change global properties, you need to start server with `--props`

View File

@ -139,6 +139,7 @@ struct slot_params {
json input_prefix; json input_prefix;
json input_suffix; json input_suffix;
json extra_context;
}; };
struct server_slot { struct server_slot {
@ -170,6 +171,7 @@ struct server_slot {
// when a task is submitted, we first tokenize the prompt and store it here // when a task is submitted, we first tokenize the prompt and store it here
std::vector<llama_token> prompt_tokens; std::vector<llama_token> prompt_tokens;
std::vector<llama_token> extra_tokens;
std::string generated_text; std::string generated_text;
std::vector<llama_token> cache_tokens; std::vector<llama_token> cache_tokens;
@ -906,8 +908,26 @@ struct server_context {
} }
// infill // infill
slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix); slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix); slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
slot.params.extra_context = json_value(data, "extra_context", default_params.extra_context);
SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.params.extra_context.size());
for (const auto & chunk : slot.params.extra_context) {
// { "text": string, "filename": string }
if (!chunk.contains("text") || !chunk["text"].is_string()) {
send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
return false;
}
// filename is optional
if (chunk.contains("filename") && !chunk["filename"].is_string()) {
send_error(task, "extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST);
return false;
}
SLT_DBG(slot, "extra_context chunk in file '%s':\n%s\n", chunk.value("filename", "").c_str(), chunk.value("text", "").c_str());
}
// get prompt // get prompt
if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) { if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
@ -1934,13 +1954,66 @@ struct server_context {
} break; } break;
case SERVER_TASK_CMPL_TYPE_INFILL: case SERVER_TASK_CMPL_TYPE_INFILL:
{ {
// use FIM repo-level pattern:
// ref: https://arxiv.org/pdf/2409.12186
//
// [FIM_REP]myproject
// [FIM_SEP]filename0
// extra chunk 0
// [FIM_SEP]filename1
// extra chunk 1
// ...
// [FIM_SEP]filename
// [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]
//
auto prefix_tokens = tokenize(slot.params.input_prefix, false, false); auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
auto suffix_tokens = tokenize(slot.params.input_suffix, false, false); auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
// for now pick context to fit in a single batch (ratio prefix:suffix = 3:1, TODO: configurable?) slot.extra_tokens.clear();
const int n_suffix_take = std::min<int>(suffix_tokens.size(), n_batch/4); if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
static const auto k_fim_repo = tokenize("myproject\n", false, false);
slot.extra_tokens.push_back(llama_token_fim_rep(model));
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
}
for (const auto & chunk : slot.params.extra_context) {
// { "text": string, "filename": string }
const std::string text = chunk.value("text", "");
const std::string filename = chunk.value("filename", "tmp");
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
const auto k_fim_file = tokenize(filename + "\n", false, false);
slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
} else {
// chunk separator in binary form to avoid confusing the AI
static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false);
slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
}
const auto chunk_tokens = tokenize(text, false, false);
slot.extra_tokens.insert(slot.extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
}
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
// TODO: current filename
static const auto k_fim_file = tokenize("filename\n", false, false);
slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
}
// for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
const int n_suffix_take = std::min<int>(suffix_tokens.size(), (n_batch)/4);
const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take); const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
// fill the rest of the context with extra chunks
const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take); prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
suffix_tokens.resize(n_suffix_take); suffix_tokens.resize(n_suffix_take);
@ -1954,6 +2027,11 @@ struct server_context {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model)); embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
} }
SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size());
// put the extra context before the FIM prefix
embd_inp.insert(embd_inp.begin(), slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end());
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
embd_inp.push_back(llama_token_fim_mid(model)); embd_inp.push_back(llama_token_fim_mid(model));
@ -2058,11 +2136,15 @@ struct server_context {
while (head_c < slot.cache_tokens.size() && while (head_c < slot.cache_tokens.size() &&
head_p < prompt_tokens.size()) { head_p < prompt_tokens.size()) {
if (llama_token_is_control(model, slot.cache_tokens[head_c])) { if (llama_token_is_control(model, slot.cache_tokens[head_c]) &&
slot.cache_tokens[head_c] != llama_token_fim_rep(model) &&
slot.cache_tokens[head_c] != llama_token_fim_sep(model)) {
break; break;
} }
if (llama_token_is_control(model, prompt_tokens[head_p])) { if (llama_token_is_control(model, prompt_tokens[head_p]) &&
prompt_tokens[head_p] != llama_token_fim_rep(model) &&
prompt_tokens[head_p] != llama_token_fim_sep(model)) {
break; break;
} }
@ -2071,11 +2153,15 @@ struct server_context {
while (head_c + n_match < slot.cache_tokens.size() && while (head_c + n_match < slot.cache_tokens.size() &&
head_p + n_match < prompt_tokens.size() && head_p + n_match < prompt_tokens.size() &&
slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) { slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match])) { if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match]) &&
slot.cache_tokens[head_c + n_match] != llama_token_fim_rep(model) &&
slot.cache_tokens[head_c + n_match] != llama_token_fim_sep(model)) {
break; break;
} }
if (llama_token_is_control(model, prompt_tokens[head_p + n_match])) { if (llama_token_is_control(model, prompt_tokens[head_p + n_match]) &&
prompt_tokens[head_p + n_match] != llama_token_fim_rep(model) &&
prompt_tokens[head_p + n_match] != llama_token_fim_sep(model)) {
break; break;
} }

View File

@ -6596,8 +6596,8 @@ static void llm_load_vocab(
) { ) {
vocab.special_eot_id = t.second; vocab.special_eot_id = t.second;
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
__func__, t.first.c_str()); __func__, t.second, t.first.c_str());
vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
} }
} }
@ -6610,8 +6610,8 @@ static void llm_load_vocab(
) { ) {
vocab.special_eom_id = t.second; vocab.special_eom_id = t.second;
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
__func__, t.first.c_str()); __func__, t.second, t.first.c_str());
vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
} }
} }
@ -6627,8 +6627,8 @@ static void llm_load_vocab(
) { ) {
vocab.special_fim_pre_id = t.second; vocab.special_fim_pre_id = t.second;
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
__func__, t.first.c_str()); __func__, t.second, t.first.c_str());
vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
} }
} }
@ -6644,8 +6644,8 @@ static void llm_load_vocab(
) { ) {
vocab.special_fim_suf_id = t.second; vocab.special_fim_suf_id = t.second;
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
__func__, t.first.c_str()); __func__, t.second, t.first.c_str());
vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
} }
} }
@ -6661,8 +6661,8 @@ static void llm_load_vocab(
) { ) {
vocab.special_fim_mid_id = t.second; vocab.special_fim_mid_id = t.second;
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
__func__, t.first.c_str()); __func__, t.second, t.first.c_str());
vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
} }
} }
@ -6677,8 +6677,8 @@ static void llm_load_vocab(
) { ) {
vocab.special_fim_pad_id = t.second; vocab.special_fim_pad_id = t.second;
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
__func__, t.first.c_str()); __func__, t.second, t.first.c_str());
vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
} }
} }
@ -6694,8 +6694,8 @@ static void llm_load_vocab(
) { ) {
vocab.special_fim_rep_id = t.second; vocab.special_fim_rep_id = t.second;
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
__func__, t.first.c_str()); __func__, t.second, t.first.c_str());
vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
} }
} }
@ -6708,8 +6708,8 @@ static void llm_load_vocab(
) { ) {
vocab.special_fim_sep_id = t.second; vocab.special_fim_sep_id = t.second;
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
__func__, t.first.c_str()); __func__, t.second, t.first.c_str());
vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
} }
} }
@ -6720,6 +6720,19 @@ static void llm_load_vocab(
// this is currently determined based on the token text, which is obviously not ideal // this is currently determined based on the token text, which is obviously not ideal
// ref: https://github.com/ggerganov/llama.cpp/issues/9606 // ref: https://github.com/ggerganov/llama.cpp/issues/9606
vocab.special_eog_ids.clear(); vocab.special_eog_ids.clear();
if (vocab.special_fim_pad_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_pad_id) == 0) {
vocab.special_eog_ids.insert(vocab.special_fim_pad_id);
}
if (vocab.special_fim_rep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_rep_id) == 0) {
vocab.special_eog_ids.insert(vocab.special_fim_rep_id);
}
if (vocab.special_fim_sep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_sep_id) == 0) {
vocab.special_eog_ids.insert(vocab.special_fim_sep_id);
}
for (const auto & t : vocab.token_to_id) { for (const auto & t : vocab.token_to_id) {
if (false if (false
|| t.first == "<|eot_id|>" || t.first == "<|eot_id|>"
@ -6732,13 +6745,20 @@ static void llm_load_vocab(
) { ) {
vocab.special_eog_ids.insert(t.second); vocab.special_eog_ids.insert(t.second);
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
__func__, t.first.c_str()); __func__, t.second, t.first.c_str());
vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
} }
} else {
// token is control, but not marked as EOG -> print a warning
if (vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && vocab.special_eog_ids.count(t.second) == 0) {
LLAMA_LOG_WARN("%s: control token: %6d '%s' is not marked as EOG\n",
__func__, t.second, t.first.c_str());
}
} }
} }
// sanity checks
if (vocab.special_eos_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eos_id) == 0) { if (vocab.special_eos_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eos_id) == 0) {
vocab.special_eog_ids.insert(vocab.special_eos_id); vocab.special_eog_ids.insert(vocab.special_eos_id);
LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__); LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);