server : refactor slot input data, move tokenizer to HTTP thread (#10023)

* server : refactor slot input data, move tokenizer to HTTP thread

* move prompt_tokens.empty() check

* fix incorrect if branch

* fix infinite generation loop

* bring back infill validation

* add infill test

* try fixing format_infill

* fix test

* remove redundant code

* rename completion to inference

* update docs

* use llama_tokens everywhere
This commit is contained in:
Xuan Son Nguyen 2024-10-24 21:51:22 +02:00 committed by GitHub
parent 40f2555797
commit 958367bf53
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 468 additions and 348 deletions

View File

@ -319,6 +319,18 @@ node index.js
- The prompt is a string or an array with the first element given as a string - The prompt is a string or an array with the first element given as a string
- The model's `tokenizer.ggml.add_bos_token` metadata is `true` - The model's `tokenizer.ggml.add_bos_token` metadata is `true`
These input shapes and data type are allowed for `prompt`:
- Single string: `"string"`
- Single sequence of tokens: `[12, 34, 56]`
- Mixed tokens and strings: `[12, 34, "string", 56, 78]`
Multiple prompts are also supported. In this case, the completion result will be an array.
- Only strings: `["string1", "string2"]`
- Strings and sequences of tokens: `["string1", [12, 34, 56]]`
- Mixed types: `[[12, 34, "string", 56, 78], [12, 34, 56], "string"]`
`temperature`: Adjust the randomness of the generated text. Default: `0.8` `temperature`: Adjust the randomness of the generated text. Default: `0.8`
`dynatemp_range`: Dynamic temperature range. The final temperature will be in the range of `[temperature - dynatemp_range; temperature + dynatemp_range]` Default: `0.0`, which is disabled. `dynatemp_range`: Dynamic temperature range. The final temperature will be in the range of `[temperature - dynatemp_range; temperature + dynatemp_range]` Default: `0.0`, which is disabled.

View File

@ -43,21 +43,6 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
enum stop_type { enum stop_type {
@ -68,6 +53,7 @@ enum stop_type {
// state diagram: https://github.com/ggerganov/llama.cpp/pull/9283 // state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
enum slot_state { enum slot_state {
SLOT_STATE_IDLE, SLOT_STATE_IDLE,
SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
SLOT_STATE_PROCESSING_PROMPT, SLOT_STATE_PROCESSING_PROMPT,
SLOT_STATE_DONE_PROMPT, SLOT_STATE_DONE_PROMPT,
SLOT_STATE_GENERATING, SLOT_STATE_GENERATING,
@ -79,7 +65,7 @@ enum server_state {
}; };
enum server_task_type { enum server_task_type {
SERVER_TASK_TYPE_COMPLETION, SERVER_TASK_TYPE_INFERENCE,
SERVER_TASK_TYPE_CANCEL, SERVER_TASK_TYPE_CANCEL,
SERVER_TASK_TYPE_NEXT_RESPONSE, SERVER_TASK_TYPE_NEXT_RESPONSE,
SERVER_TASK_TYPE_METRICS, SERVER_TASK_TYPE_METRICS,
@ -89,21 +75,22 @@ enum server_task_type {
SERVER_TASK_TYPE_SET_LORA, SERVER_TASK_TYPE_SET_LORA,
}; };
enum server_task_cmpl_type { enum server_task_inf_type {
SERVER_TASK_CMPL_TYPE_NORMAL, SERVER_TASK_INF_TYPE_COMPLETION,
SERVER_TASK_CMPL_TYPE_EMBEDDING, SERVER_TASK_INF_TYPE_EMBEDDING,
SERVER_TASK_CMPL_TYPE_RERANK, SERVER_TASK_INF_TYPE_RERANK,
SERVER_TASK_CMPL_TYPE_INFILL, SERVER_TASK_INF_TYPE_INFILL,
}; };
struct server_task { struct server_task {
int id = -1; // to be filled by server_queue int id = -1; // to be filled by server_queue
int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
llama_tokens prompt_tokens;
server_task_type type; server_task_type type;
json data; json data;
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
// utility function // utility function
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) { static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
@ -161,26 +148,20 @@ struct server_slot {
int32_t i_batch = -1; int32_t i_batch = -1;
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
// n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
int32_t n_prompt_tokens = 0; int32_t n_prompt_tokens = 0;
int32_t n_prompt_tokens_processed = 0; int32_t n_prompt_tokens_processed = 0;
json prompt; // can be either a string, array of strings or array of token ids // input prompt tokens
llama_tokens prompt_tokens;
json input_prefix;
json input_suffix;
json input_extra;
// when a task is submitted, we first tokenize the prompt and store it here
std::vector<llama_token> prompt_tokens;
std::vector<llama_token> extra_tokens;
size_t last_nl_pos = 0; size_t last_nl_pos = 0;
std::string generated_text; std::string generated_text;
std::vector<llama_token> cache_tokens; llama_tokens cache_tokens;
std::vector<completion_token_output> generated_token_probs; std::vector<completion_token_output> generated_token_probs;
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
bool has_next_token = true; bool has_next_token = true;
bool has_new_line = false; bool has_new_line = false;
@ -229,7 +210,7 @@ struct server_slot {
n_past = 0; n_past = 0;
n_sent_text = 0; n_sent_text = 0;
n_sent_token_probs = 0; n_sent_token_probs = 0;
cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
generated_token_probs.clear(); generated_token_probs.clear();
} }
@ -734,42 +715,6 @@ struct server_context {
metrics.init(); metrics.init();
} }
std::vector<llama_token> tokenize(const json & json_prompt, bool add_special, bool parse_special) const {
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
// or the first element of the json_prompt array is a string.
std::vector<llama_token> prompt_tokens;
if (json_prompt.is_array()) {
bool first = true;
for (const auto & p : json_prompt) {
if (p.is_string()) {
auto s = p.template get<std::string>();
std::vector<llama_token> p;
if (first) {
p = common_tokenize(ctx, s, add_special, parse_special);
first = false;
} else {
p = common_tokenize(ctx, s, false, parse_special);
}
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
} else {
if (first) {
first = false;
}
prompt_tokens.push_back(p.template get<llama_token>());
}
}
} else {
auto s = json_prompt.template get<std::string>();
prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
}
return prompt_tokens;
}
server_slot * get_slot_by_id(int id) { server_slot * get_slot_by_id(int id) {
for (server_slot & slot : slots) { for (server_slot & slot : slots) {
if (slot.id == id) { if (slot.id == id) {
@ -794,22 +739,16 @@ struct server_context {
continue; continue;
} }
// skip the slot if it does not contains prompt // skip the slot if it does not contains cached tokens
if (!slot.prompt.is_string()) { if (slot.prompt_tokens.empty()) {
continue; continue;
} }
// current slot's prompt
std::string slot_prompt = slot.prompt.get<std::string>();
// length of the current slot's prompt
int slot_prompt_len = slot_prompt.size();
// length of the Longest Common Prefix between the current slot's prompt and the input prompt // length of the Longest Common Prefix between the current slot's prompt and the input prompt
int lcp_len = longest_common_prefix(slot_prompt, prompt); int lcp_len = longest_common_prefix(slot.cache_tokens, slot.prompt_tokens);
// fraction of the common substring length compared to the current slot's prompt length // fraction of the common substring length compared to the current slot's prompt length
similarity = static_cast<float>(lcp_len) / slot_prompt_len; similarity = static_cast<float>(lcp_len) / static_cast<int>(slot.prompt_tokens.size());
// select the current slot if the criteria match // select the current slot if the criteria match
if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) { if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
@ -914,57 +853,6 @@ struct server_context {
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict); SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
} }
// infill
slot.input_prefix = json_value(data, "input_prefix", json());
slot.input_suffix = json_value(data, "input_suffix", json());
slot.input_extra = json_value(data, "input_extra", json());
SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.input_extra.size());
for (const auto & chunk : slot.input_extra) {
// { "text": string, "filename": string }
if (!chunk.contains("text") || !chunk["text"].is_string()) {
send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
return false;
}
// filename is optional
if (chunk.contains("filename") && !chunk["filename"].is_string()) {
send_error(task, "extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST);
return false;
}
SLT_DBG(slot, "extra_context chunk in file '%s':\n%s\n", chunk.value("filename", "").c_str(), chunk.value("text", "").c_str());
}
// get prompt
{
const auto & prompt = data.find("prompt");
if (prompt == data.end()) {
send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
return false;
}
if ((prompt->is_string()) ||
(prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
(prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
slot.prompt = *prompt;
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
slot.prompt = prompt->at(0);
} else if (prompt->is_array() && prompt->size() > 1) {
// array of strings
for (const auto & el : *prompt) {
if (!el.is_string()) {
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
return false;
}
}
slot.prompt = *prompt;
} else {
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
return false;
}
}
{ {
slot.sparams.logit_bias.clear(); slot.sparams.logit_bias.clear();
@ -1044,8 +932,7 @@ struct server_context {
} }
} }
slot.state = SLOT_STATE_PROCESSING_PROMPT; slot.state = SLOT_STATE_STARTED;
slot.prompt_tokens.clear();
SLT_INF(slot, "%s", "processing task\n"); SLT_INF(slot, "%s", "processing task\n");
@ -1297,7 +1184,7 @@ struct server_context {
}; };
if (slot.sparams.n_probs > 0) { if (slot.sparams.n_probs > 0) {
const std::vector<llama_token> to_send_toks = common_tokenize(ctx, tkn.text_to_send, false); const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size()); const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size()); const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
@ -1333,7 +1220,7 @@ struct server_context {
{"tokens_predicted", slot.n_decoded}, {"tokens_predicted", slot.n_decoded},
{"tokens_evaluated", slot.n_prompt_tokens}, {"tokens_evaluated", slot.n_prompt_tokens},
{"generation_settings", get_formated_generation(slot)}, {"generation_settings", get_formated_generation(slot)},
{"prompt", slot.prompt}, {"prompt", common_detokenize(ctx, slot.prompt_tokens)},
{"has_new_line", slot.has_new_line}, {"has_new_line", slot.has_new_line},
{"truncated", slot.truncated}, {"truncated", slot.truncated},
{"stopped_eos", slot.stopped_eos}, {"stopped_eos", slot.stopped_eos},
@ -1348,7 +1235,7 @@ struct server_context {
if (slot.sparams.n_probs > 0) { if (slot.sparams.n_probs > 0) {
std::vector<completion_token_output> probs; std::vector<completion_token_output> probs;
if (!slot.params.stream && slot.stopped_word) { if (!slot.params.stream && slot.stopped_word) {
const std::vector<llama_token> stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
probs = std::vector<completion_token_output>( probs = std::vector<completion_token_output>(
@ -1457,19 +1344,17 @@ struct server_context {
// Functions to create new task(s) and receive result(s) // Functions to create new task(s) and receive result(s)
// //
std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) { // break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
std::vector<server_task> create_tasks_inference(json data, server_task_inf_type inf_type) {
std::vector<server_task> tasks; std::vector<server_task> tasks;
auto create_task = [&](json & task_data, bool replace_prompt, json prompt) { auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
server_task task; server_task task;
task.id = queue_tasks.get_new_id(); task.id = queue_tasks.get_new_id();
task.cmpl_type = cmpl_type; task.inf_type = inf_type;
task.type = SERVER_TASK_TYPE_COMPLETION; task.type = SERVER_TASK_TYPE_INFERENCE;
if (replace_prompt) {
task.data = task_data; task.data = task_data;
task.data["prompt"] = std::move(prompt); task.prompt_tokens = std::move(prompt_tokens);
} else {
task.data = std::move(task_data);
}
tasks.push_back(std::move(task)); tasks.push_back(std::move(task));
}; };
@ -1478,42 +1363,50 @@ struct server_context {
throw std::runtime_error(error_msg); throw std::runtime_error(error_msg);
} }
json prompt = data.at("prompt"); // because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL;
// if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
if (prompt.is_string() || json_is_array_of_numbers(prompt)) { switch (inf_type) {
data["index"] = 0; case SERVER_TASK_INF_TYPE_RERANK:
create_task(data, false, nullptr); {
} else if (prompt.is_array()) {
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
std::vector<json> prompts = prompt;
if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
// prompts[0] is the question // prompts[0] is the question
// the rest are the answers/documents // the rest are the answers/documents
SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1); GGML_ASSERT(tokenized_prompts.size() > 1);
for (size_t i = 1; i < prompts.size(); i++) { SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) tokenized_prompts.size() - 1);
json qd; for (size_t i = 1; i < tokenized_prompts.size(); i++) {
qd.push_back(prompts[0]);
qd.push_back(prompts[i]);
data["index"] = i - 1; data["index"] = i - 1;
create_task(data, true, qd); auto tokens = format_rerank(model, tokenized_prompts[0], tokenized_prompts[i]);
create_task(data, tokens);
} }
} else { } break;
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size()); case SERVER_TASK_INF_TYPE_INFILL:
for (size_t i = 0; i < prompts.size(); i++) { {
const auto & e = prompts[i]; SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
if (e.is_string() || json_is_array_of_numbers(e)) { for (size_t i = 0; i < tokenized_prompts.size(); i++) {
data["index"] = i; data["index"] = i;
create_task(data, true, e); auto tokens = format_infill(
} else { ctx,
throw std::runtime_error(error_msg); data.at("input_prefix"),
data.at("input_suffix"),
data.at("input_extra"),
params.n_batch,
params.n_predict,
slots[0].n_ctx, // TODO: there should be a better way
params.spm_infill,
tokenized_prompts[i]
);
create_task(data, tokens);
}
} break;
default:
{
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
data["index"] = i;
create_task(data, tokenized_prompts[i]);
} }
} }
} }
} else {
// invalid case
throw std::runtime_error(error_msg);
}
return tasks; return tasks;
} }
@ -1534,7 +1427,7 @@ struct server_context {
queue_tasks.post(cancel_tasks, true); queue_tasks.post(cancel_tasks, true);
} }
// receive the results from task(s) created by create_tasks_cmpl // receive the results from task(s) created by create_tasks_inference
void receive_cmpl_results( void receive_cmpl_results(
const std::unordered_set<int> & id_tasks, const std::unordered_set<int> & id_tasks,
const std::function<void(std::vector<server_task_result>&)> & result_handler, const std::function<void(std::vector<server_task_result>&)> & result_handler,
@ -1558,7 +1451,7 @@ struct server_context {
result_handler(results); result_handler(results);
} }
// receive the results from task(s) created by create_tasks_cmpl, in stream mode // receive the results from task(s) created by create_tasks_inference, in stream mode
void receive_cmpl_results_stream( void receive_cmpl_results_stream(
const std::unordered_set<int> & id_tasks, const const std::unordered_set<int> & id_tasks, const
std::function<bool(server_task_result&)> & result_handler, const std::function<bool(server_task_result&)> & result_handler, const
@ -1591,7 +1484,7 @@ struct server_context {
void process_single_task(const server_task & task) { void process_single_task(const server_task & task) {
switch (task.type) { switch (task.type) {
case SERVER_TASK_TYPE_COMPLETION: case SERVER_TASK_TYPE_INFERENCE:
{ {
const int id_slot = json_value(task.data, "id_slot", -1); const int id_slot = json_value(task.data, "id_slot", -1);
@ -1624,8 +1517,9 @@ struct server_context {
slot->reset(); slot->reset();
slot->id_task = task.id; slot->id_task = task.id;
slot->cmpl_type = task.cmpl_type; slot->inf_type = task.inf_type;
slot->index = json_value(task.data, "index", 0); slot->index = json_value(task.data, "index", 0);
slot->prompt_tokens = std::move(task.prompt_tokens);
if (!launch_slot_with_task(*slot, task)) { if (!launch_slot_with_task(*slot, task)) {
SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
@ -1658,7 +1552,7 @@ struct server_context {
slot_data["id"] = slot.id; slot_data["id"] = slot.id;
slot_data["id_task"] = slot.id_task; slot_data["id_task"] = slot.id_task;
slot_data["state"] = slot.state; slot_data["state"] = slot.state;
slot_data["prompt"] = slot.prompt; slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens);
slot_data["next_token"] = { slot_data["next_token"] = {
{"has_next_token", slot.has_next_token}, {"has_next_token", slot.has_next_token},
{"has_new_line", slot.has_new_line}, {"has_new_line", slot.has_new_line},
@ -1785,9 +1679,6 @@ struct server_context {
} }
slot->cache_tokens.resize(token_count); slot->cache_tokens.resize(token_count);
// TODO: maybe detokenize the slot->cache_tokens instead?
slot->prompt = string_format("[restored %d tokens from file]", (int) token_count);
const int64_t t_end = ggml_time_us(); const int64_t t_end = ggml_time_us();
const double t_restore_ms = (t_end - t_start) / 1000.0; const double t_restore_ms = (t_end - t_start) / 1000.0;
@ -1954,142 +1845,18 @@ struct server_context {
if (params.cont_batching || batch.n_tokens == 0) { if (params.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) { for (auto & slot : slots) {
// this slot still has a prompt to be processed // this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_PROCESSING_PROMPT) { if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
auto & prompt_tokens = slot.prompt_tokens; auto & prompt_tokens = slot.prompt_tokens;
// we haven't tokenized the prompt yet - do it now: // TODO: maybe move branch to outside of this loop in the future
if (prompt_tokens.empty()) { if (slot.state == SLOT_STATE_STARTED) {
SLT_INF(slot, "tokenizing prompt, len = %d\n", (int) slot.prompt.size());
slot.t_start_process_prompt = ggml_time_us(); slot.t_start_process_prompt = ggml_time_us();
slot.t_start_generation = 0; slot.t_start_generation = 0;
switch (slot.cmpl_type) {
case SERVER_TASK_CMPL_TYPE_NORMAL:
case SERVER_TASK_CMPL_TYPE_EMBEDDING:
{
prompt_tokens = tokenize(slot.prompt, llama_add_bos_token(model), true);
} break;
case SERVER_TASK_CMPL_TYPE_RERANK:
{
// require slot.prompt to be array of 2 strings
if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
slot.release();
send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
continue;
}
// prompt: [BOS]query[EOS][SEP]doc[EOS]
prompt_tokens.clear();
prompt_tokens.push_back(llama_token_bos(model));
{
const auto part = tokenize(slot.prompt[0], false, false);
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
}
prompt_tokens.push_back(llama_token_eos(model));
prompt_tokens.push_back(llama_token_sep(model));
{
const auto part = tokenize(slot.prompt[1], false, false);
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
}
prompt_tokens.push_back(llama_token_eos(model));
} break;
case SERVER_TASK_CMPL_TYPE_INFILL:
{
// TODO: optimize this block by reducing memory allocations and movement
// use FIM repo-level pattern:
// ref: https://arxiv.org/pdf/2409.12186
//
// [FIM_REP]myproject
// [FIM_SEP]filename0
// extra chunk 0
// [FIM_SEP]filename1
// extra chunk 1
// ...
// [FIM_SEP]filename
// [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
//
auto tokens_prefix = tokenize(slot.input_prefix, false, false);
auto tokens_suffix = tokenize(slot.input_suffix, false, false);
auto tokens_prompt = tokenize(slot.prompt, false, false);
slot.extra_tokens.clear();
if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
static const auto k_fim_repo = tokenize("myproject\n", false, false);
slot.extra_tokens.push_back(llama_token_fim_rep(model));
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
}
for (const auto & chunk : slot.input_extra) {
// { "text": string, "filename": string }
const std::string text = chunk.value("text", "");
const std::string filename = chunk.value("filename", "tmp");
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
const auto k_fim_file = tokenize(filename + "\n", false, false);
slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
} else {
// chunk separator in binary form to avoid confusing the AI
static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false);
slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
}
const auto chunk_tokens = tokenize(text, false, false);
slot.extra_tokens.insert(slot.extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
}
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
// TODO: current filename
static const auto k_fim_file = tokenize("filename\n", false, false);
slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
}
// for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
const int n_suffix_take = std::min<int>(tokens_suffix.size(), (n_batch/4));
const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);
// fill the rest of the context with extra chunks
const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
tokens_suffix.resize(n_suffix_take);
tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
auto embd_inp = params.spm_infill ? tokens_suffix : tokens_prefix;
auto embd_end = params.spm_infill ? tokens_prefix : tokens_suffix;
if (llama_add_bos_token(model)) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
}
SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size());
// put the extra context before the FIM prefix
embd_inp.insert(embd_inp.begin(), slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end());
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
embd_inp.push_back(llama_token_fim_mid(model));
prompt_tokens = std::move(embd_inp);
} break;
}
slot.n_past = 0; slot.n_past = 0;
slot.n_prompt_tokens = prompt_tokens.size(); slot.n_prompt_tokens = prompt_tokens.size();
slot.state = SLOT_STATE_PROCESSING_PROMPT;
SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens); SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
// print prompt tokens (for debugging) // print prompt tokens (for debugging)
if (1) { if (1) {
@ -2114,7 +1881,7 @@ struct server_context {
continue; continue;
} }
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
// this prompt is too large to process - discard it // this prompt is too large to process - discard it
if (slot.n_prompt_tokens > n_ubatch) { if (slot.n_prompt_tokens > n_ubatch) {
slot.release(); slot.release();
@ -2144,7 +1911,7 @@ struct server_context {
const int n_block_size = n_left / 2; const int n_block_size = n_left / 2;
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
std::vector<llama_token> new_tokens( llama_tokens new_tokens(
prompt_tokens.begin(), prompt_tokens.begin(),
prompt_tokens.begin() + slot.params.n_keep); prompt_tokens.begin() + slot.params.n_keep);
@ -2225,7 +1992,7 @@ struct server_context {
} }
// non-causal tasks require to fit the entire prompt in the physical batch // non-causal tasks require to fit the entire prompt in the physical batch
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
// cannot fit the prompt in the current batch - will try next iter // cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
continue; continue;
@ -2234,8 +2001,8 @@ struct server_context {
// check that we are in the right batch_type, if not defer the slot // check that we are in the right batch_type, if not defer the slot
const bool slot_type = const bool slot_type =
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0; slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;
if (batch_type == -1) { if (batch_type == -1) {
batch_type = slot_type; batch_type = slot_type;
@ -2353,7 +2120,7 @@ struct server_context {
} }
if (slot.state == SLOT_STATE_DONE_PROMPT) { if (slot.state == SLOT_STATE_DONE_PROMPT) {
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
// prompt evaluated for embedding // prompt evaluated for embedding
send_embedding(slot, batch_view); send_embedding(slot, batch_view);
slot.release(); slot.release();
@ -2361,7 +2128,7 @@ struct server_context {
continue; // continue loop of slots continue; // continue loop of slots
} }
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) { if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
send_rerank(slot, batch_view); send_rerank(slot, batch_view);
slot.release(); slot.release();
slot.i_batch = -1; slot.i_batch = -1;
@ -2915,13 +2682,13 @@ int main(int argc, char ** argv) {
res_ok(res, {{ "success", true }}); res_ok(res, {{ "success", true }});
}; };
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) { const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
if (ctx_server.params.embedding || ctx_server.params.reranking) { if (ctx_server.params.embedding || ctx_server.params.reranking) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
return; return;
} }
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type); std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks); ctx_server.queue_tasks.post(tasks);
@ -2967,10 +2734,11 @@ int main(int argc, char ** argv) {
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) { const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body); json data = json::parse(req.body);
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res); return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
}; };
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) { const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
// check model compatibility
std::string err; std::string err;
if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) { if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
err += "prefix token is missing. "; err += "prefix token is missing. ";
@ -2981,14 +2749,42 @@ int main(int argc, char ** argv) {
if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) { if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
err += "middle token is missing. "; err += "middle token is missing. ";
} }
if (!err.empty()) { if (!err.empty()) {
res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED)); res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
return; return;
} }
json data = json::parse(req.body); json data = json::parse(req.body);
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
// validate input
if (!data.contains("input_prefix")) {
res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
}
if (!data.contains("input_suffix")) {
res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
}
if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
return;
}
json input_extra = json_value(data, "input_extra", json::array());
for (const auto & chunk : input_extra) {
// { "text": string, "filename": string }
if (!chunk.contains("text") || !chunk.at("text").is_string()) {
res_error(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
return;
}
// filename is optional
if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
res_error(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
return;
}
}
data["input_extra"] = input_extra; // default to empty array if it's not exist
return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
}; };
// TODO: maybe merge this function with "handle_completions_generic" // TODO: maybe merge this function with "handle_completions_generic"
@ -3000,7 +2796,7 @@ int main(int argc, char ** argv) {
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL); std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks); ctx_server.queue_tasks.post(tasks);
@ -3073,7 +2869,7 @@ int main(int argc, char ** argv) {
const bool add_special = json_value(body, "add_special", false); const bool add_special = json_value(body, "add_special", false);
const bool with_pieces = json_value(body, "with_pieces", false); const bool with_pieces = json_value(body, "with_pieces", false);
std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special, true); llama_tokens tokens = tokenize_mixed(ctx_server.ctx, body.at("content"), add_special, true);
if (with_pieces) { if (with_pieces) {
for (const auto& token : tokens) { for (const auto& token : tokens) {
@ -3110,7 +2906,7 @@ int main(int argc, char ** argv) {
std::string content; std::string content;
if (body.count("tokens") != 0) { if (body.count("tokens") != 0) {
const std::vector<llama_token> tokens = body.at("tokens"); const llama_tokens tokens = body.at("tokens");
content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend()); content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
} }
@ -3144,7 +2940,7 @@ int main(int argc, char ** argv) {
json responses = json::array(); json responses = json::array();
bool error = false; bool error = false;
{ {
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING); std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING);
ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks); ctx_server.queue_tasks.post(tasks);
@ -3221,7 +3017,7 @@ int main(int argc, char ** argv) {
json responses = json::array(); json responses = json::array();
bool error = false; bool error = false;
{ {
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK); std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK);
ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks); ctx_server.queue_tasks.post(tasks);

View File

@ -0,0 +1,36 @@
@llama.cpp
@infill
Feature: llama.cpp server
# The current model is made by adding FIM tokens to the existing stories260K
# We may want to use a better model in the future, maybe something like SmolLM 360M
Background: Server startup
Given a server listening on localhost:8080
And a model file tinyllamas/stories260K-infill.gguf from HF repo ggml-org/models
And a model file test-model-infill.gguf
And a model alias tinyllama-infill
And 42 as server seed
And 1024 as batch size
And 1024 as ubatch size
And 2048 KV cache size
And 64 max tokens to predict
And 0.0 temperature
Then the server is starting
Then the server is healthy
Scenario: Infill without input_extra
Given a prompt "Complete this"
And an infill input extra none none
And an infill input prefix "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_"
And an infill input suffix "}\n"
And an infill request with no api error
Then 64 tokens are predicted matching One|day|she|saw|big|scary|bird
Scenario: Infill with input_extra
Given a prompt "Complete this"
And an infill input extra "llama.h" "LLAMA_API int32_t llama_n_threads();\n"
And an infill input prefix "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_"
And an infill input suffix "}\n"
And an infill request with no api error
Then 64 tokens are predicted matching cuts|Jimmy|mom|came|into|the|room"

View File

@ -80,6 +80,11 @@ def step_server_config(context, server_fqdn: str, server_port: str):
context.lora_file = None context.lora_file = None
context.disable_ctx_shift = False context.disable_ctx_shift = False
# infill
context.infill_input_extra = None
context.infill_input_suffix = ''
context.infill_input_prefix = ''
context.tasks_result = [] context.tasks_result = []
context.concurrent_tasks = [] context.concurrent_tasks = []
context.prompts = [] context.prompts = []
@ -291,6 +296,28 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}" assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"
@step('an infill request with {api_error} api error')
@async_run_until_complete
async def step_request_completion(context, api_error: Literal['raised'] | str):
if api_error != 'no':
raise ValueError(f'api_error={api_error} is not yet implemented')
payload = {
"prompt": context.prompts[0],
"input_suffix": context.infill_input_suffix,
"input_prefix": context.infill_input_prefix,
"n_predict": context.n_predict,
"seed": context.seed,
"temperature": context.temperature,
}
if context.infill_input_extra is not None:
payload['input_extra'] = context.infill_input_extra
async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
async with session.post(f'{context.base_url}/infill',
json=payload) as response:
assert response.status == 200
context.tasks_result = [await response.json()]
@step('{predicted_n:d} tokens are predicted matching {re_content}') @step('{predicted_n:d} tokens are predicted matching {re_content}')
def step_n_tokens_predicted_with_content(context, predicted_n, re_content): def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
context.completion = context.tasks_result.pop() context.completion = context.tasks_result.pop()
@ -539,6 +566,25 @@ def step_a_prompt_prompt(context, prompt):
context.n_prompts = len(context.prompts) context.n_prompts = len(context.prompts)
# TODO: allow this to be repeated
@step('an infill input extra {filename} {text}')
def step_infill_input_extra(context, filename, text):
if filename == 'none':
context.infill_input_extra = None
else:
context.infill_input_extra = [{'filename': filename, 'text': text}]
@step('an infill input suffix {text}')
def step_infill_input_suffix(context, text):
context.infill_input_suffix = text
@step('an infill input prefix {text}')
def step_infill_input_prefix(context, text):
context.infill_input_prefix = text
@step('{num_prompts:d} prompts {prompt} with seed {seed:d}') @step('{num_prompts:d} prompts {prompt} with seed {seed:d}')
def step_many_prompts(context, num_prompts, prompt, seed): def step_many_prompts(context, num_prompts, prompt, seed):
if context.seed is None: if context.seed is None:

View File

@ -24,6 +24,22 @@
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613" #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
using llama_tokens = std::vector<llama_token>;
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
enum error_type { enum error_type {
@ -52,9 +68,235 @@ static T json_value(const json & body, const std::string & key, const T & defaul
} }
// //
// chat template utils // tokenizer and input processing utils
// //
static bool json_is_array_of_numbers(const json & data) {
if (data.is_array()) {
for (const auto & e : data) {
if (!e.is_number_integer()) {
return false;
}
}
return true;
}
return false;
}
// is array having BOTH numbers & strings?
static bool json_is_array_of_mixed_numbers_strings(const json & data) {
bool seen_string = false;
bool seen_number = false;
if (data.is_array()) {
for (const auto & e : data) {
seen_string |= e.is_string();
seen_number |= e.is_number_integer();
if (seen_number && seen_string) {
return true;
}
}
}
return false;
}
/**
* this handles 2 cases:
* - only string, example: "string"
* - mixed string and tokens, example: [12, 34, "string", 56, 78]
*/
static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
// or the first element of the json_prompt array is a string.
llama_tokens prompt_tokens;
if (json_prompt.is_array()) {
bool first = true;
for (const auto & p : json_prompt) {
if (p.is_string()) {
auto s = p.template get<std::string>();
llama_tokens p;
if (first) {
p = common_tokenize(ctx, s, add_special, parse_special);
first = false;
} else {
p = common_tokenize(ctx, s, false, parse_special);
}
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
} else {
if (first) {
first = false;
}
prompt_tokens.push_back(p.template get<llama_token>());
}
}
} else {
auto s = json_prompt.template get<std::string>();
prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
}
return prompt_tokens;
}
/**
* break the input "prompt" object into multiple prompt if needed, then tokenize them
* this supports these cases:
* - "prompt": "string"
* - "prompt": [12, 34, 56]
* - "prompt": [12, 34, "string", 56, 78]
* and multiple prompts (multi-tasks):
* - "prompt": ["string1", "string2"]
* - "prompt": ["string1", [12, 34, 56]]
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
*/
static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
std::vector<llama_tokens> result;
if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
// string or mixed
result.push_back(tokenize_mixed(ctx, json_prompt, add_special, parse_special));
} else if (json_is_array_of_numbers(json_prompt)) {
// array of tokens
result.push_back(json_prompt.get<llama_tokens>());
} else if (json_prompt.is_array()) {
// array of prompts
result.reserve(json_prompt.size());
for (const auto & p : json_prompt) {
if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) {
result.push_back(tokenize_mixed(ctx, p, add_special, parse_special));
} else if (json_is_array_of_numbers(p)) {
// array of tokens
result.push_back(p.get<llama_tokens>());
} else {
throw std::runtime_error("element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens");
}
}
} else {
throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
}
return result;
}
//
// template utils
//
// format rerank task: [BOS]query[EOS][SEP]doc[EOS]
static llama_tokens format_rerank(const struct llama_model * model, const llama_tokens & query, const llama_tokens & doc) {
llama_tokens result;
result.reserve(doc.size() + query.size() + 4);
result.push_back(llama_token_bos(model));
result.insert(result.end(), query.begin(), query.end());
result.push_back(llama_token_eos(model));
result.push_back(llama_token_sep(model));
result.insert(result.end(), doc.begin(), doc.end());
result.push_back(llama_token_eos(model));
return result;
}
// format infill task
static llama_tokens format_infill(
const llama_context * ctx,
const json & input_prefix,
const json & input_suffix,
const json & input_extra,
const int n_batch,
const int n_predict,
const int n_ctx,
const bool spm_infill,
const llama_tokens & tokens_prompt
) {
// TODO: optimize this block by reducing memory allocations and movement
// use FIM repo-level pattern:
// ref: https://arxiv.org/pdf/2409.12186
//
// [FIM_REP]myproject
// [FIM_SEP]filename0
// extra chunk 0
// [FIM_SEP]filename1
// extra chunk 1
// ...
// [FIM_SEP]filename
// [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
//
llama_tokens extra_tokens;
extra_tokens.reserve(n_ctx);
auto model = llama_get_model(ctx);
auto tokens_prefix = tokenize_mixed(ctx, input_prefix, false, false);
auto tokens_suffix = tokenize_mixed(ctx, input_suffix, false, false);
if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
// TODO: make project name an input
static const auto k_fim_repo = common_tokenize(ctx, "myproject\n", false, false);
extra_tokens.push_back(llama_token_fim_rep(model));
extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
}
for (const auto & chunk : input_extra) {
// { "text": string, "filename": string }
const std::string text = json_value(chunk, "text", std::string());
const std::string filename = json_value(chunk, "filename", std::string("tmp"));
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
const auto k_fim_file = common_tokenize(ctx, filename + "\n", false, false);
extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model));
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
} else {
// chunk separator in binary form to avoid confusing the AI
static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
static const auto k_chunk_prefix_tokens = common_tokenize(ctx, k_chunk_prefix_str, false, false);
extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
}
const auto chunk_tokens = common_tokenize(ctx, text, false, false);
extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
}
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
// TODO: current filename
static const auto k_fim_file = common_tokenize(ctx, "filename\n", false, false);
extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model));
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
}
// for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
const int n_suffix_take = std::min<int>(tokens_suffix.size(), (n_batch/4));
const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);
// fill the rest of the context with extra chunks
const int n_extra_take = std::min<int>(std::max<int>(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size());
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
tokens_suffix.resize(n_suffix_take);
tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix;
auto embd_end = spm_infill ? tokens_prefix : tokens_suffix;
if (llama_add_bos_token(model)) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
}
SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
// put the extra context before the FIM prefix
embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end());
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
embd_inp.push_back(llama_token_fim_mid(model));
return embd_inp;
}
// Format given chat. If tmpl is empty, we take the template from model metadata // Format given chat. If tmpl is empty, we take the template from model metadata
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) { inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
std::vector<common_chat_msg> chat; std::vector<common_chat_msg> chat;
@ -229,18 +471,6 @@ static size_t find_partial_stop_string(const std::string &stop, const std::strin
return std::string::npos; return std::string::npos;
} }
static bool json_is_array_of_numbers(const json & data) {
if (data.is_array()) {
for (const auto & e : data) {
if (!e.is_number()) {
return false;
}
}
return true;
}
return false;
}
// TODO: reuse llama_detokenize // TODO: reuse llama_detokenize
template <class Iter> template <class Iter>
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {