mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 06:39:25 +01:00
server : refactor slot input data, move tokenizer to HTTP thread (#10023)
* server : refactor slot input data, move tokenizer to HTTP thread * move prompt_tokens.empty() check * fix incorrect if branch * fix infinite generation loop * bring back infill validation * add infill test * try fixing format_infill * fix test * remove redundant code * rename completion to inference * update docs * use llama_tokens everywhere
This commit is contained in:
parent
40f2555797
commit
958367bf53
@ -319,6 +319,18 @@ node index.js
|
|||||||
- The prompt is a string or an array with the first element given as a string
|
- The prompt is a string or an array with the first element given as a string
|
||||||
- The model's `tokenizer.ggml.add_bos_token` metadata is `true`
|
- The model's `tokenizer.ggml.add_bos_token` metadata is `true`
|
||||||
|
|
||||||
|
These input shapes and data type are allowed for `prompt`:
|
||||||
|
|
||||||
|
- Single string: `"string"`
|
||||||
|
- Single sequence of tokens: `[12, 34, 56]`
|
||||||
|
- Mixed tokens and strings: `[12, 34, "string", 56, 78]`
|
||||||
|
|
||||||
|
Multiple prompts are also supported. In this case, the completion result will be an array.
|
||||||
|
|
||||||
|
- Only strings: `["string1", "string2"]`
|
||||||
|
- Strings and sequences of tokens: `["string1", [12, 34, 56]]`
|
||||||
|
- Mixed types: `[[12, 34, "string", 56, 78], [12, 34, 56], "string"]`
|
||||||
|
|
||||||
`temperature`: Adjust the randomness of the generated text. Default: `0.8`
|
`temperature`: Adjust the randomness of the generated text. Default: `0.8`
|
||||||
|
|
||||||
`dynatemp_range`: Dynamic temperature range. The final temperature will be in the range of `[temperature - dynatemp_range; temperature + dynatemp_range]` Default: `0.0`, which is disabled.
|
`dynatemp_range`: Dynamic temperature range. The final temperature will be in the range of `[temperature - dynatemp_range; temperature + dynatemp_range]` Default: `0.0`, which is disabled.
|
||||||
|
@ -43,21 +43,6 @@
|
|||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
|
||||||
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
|
||||||
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
|
||||||
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
|
||||||
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
|
||||||
|
|
||||||
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
||||||
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
||||||
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
||||||
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
||||||
|
|
||||||
#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
||||||
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
||||||
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
||||||
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
||||||
|
|
||||||
using json = nlohmann::ordered_json;
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
enum stop_type {
|
enum stop_type {
|
||||||
@ -68,6 +53,7 @@ enum stop_type {
|
|||||||
// state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
|
// state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
|
||||||
enum slot_state {
|
enum slot_state {
|
||||||
SLOT_STATE_IDLE,
|
SLOT_STATE_IDLE,
|
||||||
|
SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
|
||||||
SLOT_STATE_PROCESSING_PROMPT,
|
SLOT_STATE_PROCESSING_PROMPT,
|
||||||
SLOT_STATE_DONE_PROMPT,
|
SLOT_STATE_DONE_PROMPT,
|
||||||
SLOT_STATE_GENERATING,
|
SLOT_STATE_GENERATING,
|
||||||
@ -79,7 +65,7 @@ enum server_state {
|
|||||||
};
|
};
|
||||||
|
|
||||||
enum server_task_type {
|
enum server_task_type {
|
||||||
SERVER_TASK_TYPE_COMPLETION,
|
SERVER_TASK_TYPE_INFERENCE,
|
||||||
SERVER_TASK_TYPE_CANCEL,
|
SERVER_TASK_TYPE_CANCEL,
|
||||||
SERVER_TASK_TYPE_NEXT_RESPONSE,
|
SERVER_TASK_TYPE_NEXT_RESPONSE,
|
||||||
SERVER_TASK_TYPE_METRICS,
|
SERVER_TASK_TYPE_METRICS,
|
||||||
@ -89,21 +75,22 @@ enum server_task_type {
|
|||||||
SERVER_TASK_TYPE_SET_LORA,
|
SERVER_TASK_TYPE_SET_LORA,
|
||||||
};
|
};
|
||||||
|
|
||||||
enum server_task_cmpl_type {
|
enum server_task_inf_type {
|
||||||
SERVER_TASK_CMPL_TYPE_NORMAL,
|
SERVER_TASK_INF_TYPE_COMPLETION,
|
||||||
SERVER_TASK_CMPL_TYPE_EMBEDDING,
|
SERVER_TASK_INF_TYPE_EMBEDDING,
|
||||||
SERVER_TASK_CMPL_TYPE_RERANK,
|
SERVER_TASK_INF_TYPE_RERANK,
|
||||||
SERVER_TASK_CMPL_TYPE_INFILL,
|
SERVER_TASK_INF_TYPE_INFILL,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_task {
|
struct server_task {
|
||||||
int id = -1; // to be filled by server_queue
|
int id = -1; // to be filled by server_queue
|
||||||
int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
|
int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
|
||||||
|
|
||||||
|
llama_tokens prompt_tokens;
|
||||||
server_task_type type;
|
server_task_type type;
|
||||||
json data;
|
json data;
|
||||||
|
|
||||||
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
|
server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
|
||||||
|
|
||||||
// utility function
|
// utility function
|
||||||
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
|
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
|
||||||
@ -161,26 +148,20 @@ struct server_slot {
|
|||||||
int32_t i_batch = -1;
|
int32_t i_batch = -1;
|
||||||
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
|
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
|
||||||
|
|
||||||
|
// n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
|
||||||
int32_t n_prompt_tokens = 0;
|
int32_t n_prompt_tokens = 0;
|
||||||
int32_t n_prompt_tokens_processed = 0;
|
int32_t n_prompt_tokens_processed = 0;
|
||||||
|
|
||||||
json prompt; // can be either a string, array of strings or array of token ids
|
// input prompt tokens
|
||||||
|
llama_tokens prompt_tokens;
|
||||||
json input_prefix;
|
|
||||||
json input_suffix;
|
|
||||||
json input_extra;
|
|
||||||
|
|
||||||
// when a task is submitted, we first tokenize the prompt and store it here
|
|
||||||
std::vector<llama_token> prompt_tokens;
|
|
||||||
std::vector<llama_token> extra_tokens;
|
|
||||||
|
|
||||||
size_t last_nl_pos = 0;
|
size_t last_nl_pos = 0;
|
||||||
|
|
||||||
std::string generated_text;
|
std::string generated_text;
|
||||||
std::vector<llama_token> cache_tokens;
|
llama_tokens cache_tokens;
|
||||||
std::vector<completion_token_output> generated_token_probs;
|
std::vector<completion_token_output> generated_token_probs;
|
||||||
|
|
||||||
server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
|
server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
|
||||||
|
|
||||||
bool has_next_token = true;
|
bool has_next_token = true;
|
||||||
bool has_new_line = false;
|
bool has_new_line = false;
|
||||||
@ -229,7 +210,7 @@ struct server_slot {
|
|||||||
n_past = 0;
|
n_past = 0;
|
||||||
n_sent_text = 0;
|
n_sent_text = 0;
|
||||||
n_sent_token_probs = 0;
|
n_sent_token_probs = 0;
|
||||||
cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
|
inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
|
||||||
|
|
||||||
generated_token_probs.clear();
|
generated_token_probs.clear();
|
||||||
}
|
}
|
||||||
@ -734,42 +715,6 @@ struct server_context {
|
|||||||
metrics.init();
|
metrics.init();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<llama_token> tokenize(const json & json_prompt, bool add_special, bool parse_special) const {
|
|
||||||
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
|
|
||||||
// or the first element of the json_prompt array is a string.
|
|
||||||
std::vector<llama_token> prompt_tokens;
|
|
||||||
|
|
||||||
if (json_prompt.is_array()) {
|
|
||||||
bool first = true;
|
|
||||||
for (const auto & p : json_prompt) {
|
|
||||||
if (p.is_string()) {
|
|
||||||
auto s = p.template get<std::string>();
|
|
||||||
|
|
||||||
std::vector<llama_token> p;
|
|
||||||
if (first) {
|
|
||||||
p = common_tokenize(ctx, s, add_special, parse_special);
|
|
||||||
first = false;
|
|
||||||
} else {
|
|
||||||
p = common_tokenize(ctx, s, false, parse_special);
|
|
||||||
}
|
|
||||||
|
|
||||||
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
|
|
||||||
} else {
|
|
||||||
if (first) {
|
|
||||||
first = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
prompt_tokens.push_back(p.template get<llama_token>());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
auto s = json_prompt.template get<std::string>();
|
|
||||||
prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
|
|
||||||
}
|
|
||||||
|
|
||||||
return prompt_tokens;
|
|
||||||
}
|
|
||||||
|
|
||||||
server_slot * get_slot_by_id(int id) {
|
server_slot * get_slot_by_id(int id) {
|
||||||
for (server_slot & slot : slots) {
|
for (server_slot & slot : slots) {
|
||||||
if (slot.id == id) {
|
if (slot.id == id) {
|
||||||
@ -794,22 +739,16 @@ struct server_context {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// skip the slot if it does not contains prompt
|
// skip the slot if it does not contains cached tokens
|
||||||
if (!slot.prompt.is_string()) {
|
if (slot.prompt_tokens.empty()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// current slot's prompt
|
|
||||||
std::string slot_prompt = slot.prompt.get<std::string>();
|
|
||||||
|
|
||||||
// length of the current slot's prompt
|
|
||||||
int slot_prompt_len = slot_prompt.size();
|
|
||||||
|
|
||||||
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
|
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
|
||||||
int lcp_len = longest_common_prefix(slot_prompt, prompt);
|
int lcp_len = longest_common_prefix(slot.cache_tokens, slot.prompt_tokens);
|
||||||
|
|
||||||
// fraction of the common substring length compared to the current slot's prompt length
|
// fraction of the common substring length compared to the current slot's prompt length
|
||||||
similarity = static_cast<float>(lcp_len) / slot_prompt_len;
|
similarity = static_cast<float>(lcp_len) / static_cast<int>(slot.prompt_tokens.size());
|
||||||
|
|
||||||
// select the current slot if the criteria match
|
// select the current slot if the criteria match
|
||||||
if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
|
if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
|
||||||
@ -914,57 +853,6 @@ struct server_context {
|
|||||||
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
|
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
|
||||||
}
|
}
|
||||||
|
|
||||||
// infill
|
|
||||||
slot.input_prefix = json_value(data, "input_prefix", json());
|
|
||||||
slot.input_suffix = json_value(data, "input_suffix", json());
|
|
||||||
slot.input_extra = json_value(data, "input_extra", json());
|
|
||||||
|
|
||||||
SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.input_extra.size());
|
|
||||||
for (const auto & chunk : slot.input_extra) {
|
|
||||||
// { "text": string, "filename": string }
|
|
||||||
if (!chunk.contains("text") || !chunk["text"].is_string()) {
|
|
||||||
send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// filename is optional
|
|
||||||
if (chunk.contains("filename") && !chunk["filename"].is_string()) {
|
|
||||||
send_error(task, "extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
SLT_DBG(slot, "extra_context chunk in file '%s':\n%s\n", chunk.value("filename", "").c_str(), chunk.value("text", "").c_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// get prompt
|
|
||||||
{
|
|
||||||
const auto & prompt = data.find("prompt");
|
|
||||||
if (prompt == data.end()) {
|
|
||||||
send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((prompt->is_string()) ||
|
|
||||||
(prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
|
|
||||||
(prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
|
|
||||||
slot.prompt = *prompt;
|
|
||||||
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
|
|
||||||
slot.prompt = prompt->at(0);
|
|
||||||
} else if (prompt->is_array() && prompt->size() > 1) {
|
|
||||||
// array of strings
|
|
||||||
for (const auto & el : *prompt) {
|
|
||||||
if (!el.is_string()) {
|
|
||||||
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
slot.prompt = *prompt;
|
|
||||||
} else {
|
|
||||||
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
{
|
||||||
slot.sparams.logit_bias.clear();
|
slot.sparams.logit_bias.clear();
|
||||||
|
|
||||||
@ -1044,8 +932,7 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.state = SLOT_STATE_PROCESSING_PROMPT;
|
slot.state = SLOT_STATE_STARTED;
|
||||||
slot.prompt_tokens.clear();
|
|
||||||
|
|
||||||
SLT_INF(slot, "%s", "processing task\n");
|
SLT_INF(slot, "%s", "processing task\n");
|
||||||
|
|
||||||
@ -1297,7 +1184,7 @@ struct server_context {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if (slot.sparams.n_probs > 0) {
|
if (slot.sparams.n_probs > 0) {
|
||||||
const std::vector<llama_token> to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
|
const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
|
||||||
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
|
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
|
||||||
const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
|
const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
|
||||||
|
|
||||||
@ -1333,7 +1220,7 @@ struct server_context {
|
|||||||
{"tokens_predicted", slot.n_decoded},
|
{"tokens_predicted", slot.n_decoded},
|
||||||
{"tokens_evaluated", slot.n_prompt_tokens},
|
{"tokens_evaluated", slot.n_prompt_tokens},
|
||||||
{"generation_settings", get_formated_generation(slot)},
|
{"generation_settings", get_formated_generation(slot)},
|
||||||
{"prompt", slot.prompt},
|
{"prompt", common_detokenize(ctx, slot.prompt_tokens)},
|
||||||
{"has_new_line", slot.has_new_line},
|
{"has_new_line", slot.has_new_line},
|
||||||
{"truncated", slot.truncated},
|
{"truncated", slot.truncated},
|
||||||
{"stopped_eos", slot.stopped_eos},
|
{"stopped_eos", slot.stopped_eos},
|
||||||
@ -1348,7 +1235,7 @@ struct server_context {
|
|||||||
if (slot.sparams.n_probs > 0) {
|
if (slot.sparams.n_probs > 0) {
|
||||||
std::vector<completion_token_output> probs;
|
std::vector<completion_token_output> probs;
|
||||||
if (!slot.params.stream && slot.stopped_word) {
|
if (!slot.params.stream && slot.stopped_word) {
|
||||||
const std::vector<llama_token> stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
|
const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
|
||||||
|
|
||||||
size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
|
size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
|
||||||
probs = std::vector<completion_token_output>(
|
probs = std::vector<completion_token_output>(
|
||||||
@ -1457,19 +1344,17 @@ struct server_context {
|
|||||||
// Functions to create new task(s) and receive result(s)
|
// Functions to create new task(s) and receive result(s)
|
||||||
//
|
//
|
||||||
|
|
||||||
std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
|
// break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
|
||||||
|
std::vector<server_task> create_tasks_inference(json data, server_task_inf_type inf_type) {
|
||||||
std::vector<server_task> tasks;
|
std::vector<server_task> tasks;
|
||||||
auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
|
auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
|
||||||
|
SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
|
||||||
server_task task;
|
server_task task;
|
||||||
task.id = queue_tasks.get_new_id();
|
task.id = queue_tasks.get_new_id();
|
||||||
task.cmpl_type = cmpl_type;
|
task.inf_type = inf_type;
|
||||||
task.type = SERVER_TASK_TYPE_COMPLETION;
|
task.type = SERVER_TASK_TYPE_INFERENCE;
|
||||||
if (replace_prompt) {
|
|
||||||
task.data = task_data;
|
task.data = task_data;
|
||||||
task.data["prompt"] = std::move(prompt);
|
task.prompt_tokens = std::move(prompt_tokens);
|
||||||
} else {
|
|
||||||
task.data = std::move(task_data);
|
|
||||||
}
|
|
||||||
tasks.push_back(std::move(task));
|
tasks.push_back(std::move(task));
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1478,42 +1363,50 @@ struct server_context {
|
|||||||
throw std::runtime_error(error_msg);
|
throw std::runtime_error(error_msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
json prompt = data.at("prompt");
|
// because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
|
||||||
|
bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL;
|
||||||
// if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task
|
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
|
||||||
if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
|
switch (inf_type) {
|
||||||
data["index"] = 0;
|
case SERVER_TASK_INF_TYPE_RERANK:
|
||||||
create_task(data, false, nullptr);
|
{
|
||||||
} else if (prompt.is_array()) {
|
|
||||||
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
|
|
||||||
std::vector<json> prompts = prompt;
|
|
||||||
if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
|
||||||
// prompts[0] is the question
|
// prompts[0] is the question
|
||||||
// the rest are the answers/documents
|
// the rest are the answers/documents
|
||||||
SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
|
GGML_ASSERT(tokenized_prompts.size() > 1);
|
||||||
for (size_t i = 1; i < prompts.size(); i++) {
|
SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) tokenized_prompts.size() - 1);
|
||||||
json qd;
|
for (size_t i = 1; i < tokenized_prompts.size(); i++) {
|
||||||
qd.push_back(prompts[0]);
|
|
||||||
qd.push_back(prompts[i]);
|
|
||||||
data["index"] = i - 1;
|
data["index"] = i - 1;
|
||||||
create_task(data, true, qd);
|
auto tokens = format_rerank(model, tokenized_prompts[0], tokenized_prompts[i]);
|
||||||
|
create_task(data, tokens);
|
||||||
}
|
}
|
||||||
} else {
|
} break;
|
||||||
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
|
case SERVER_TASK_INF_TYPE_INFILL:
|
||||||
for (size_t i = 0; i < prompts.size(); i++) {
|
{
|
||||||
const auto & e = prompts[i];
|
SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
|
||||||
if (e.is_string() || json_is_array_of_numbers(e)) {
|
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||||
data["index"] = i;
|
data["index"] = i;
|
||||||
create_task(data, true, e);
|
auto tokens = format_infill(
|
||||||
} else {
|
ctx,
|
||||||
throw std::runtime_error(error_msg);
|
data.at("input_prefix"),
|
||||||
|
data.at("input_suffix"),
|
||||||
|
data.at("input_extra"),
|
||||||
|
params.n_batch,
|
||||||
|
params.n_predict,
|
||||||
|
slots[0].n_ctx, // TODO: there should be a better way
|
||||||
|
params.spm_infill,
|
||||||
|
tokenized_prompts[i]
|
||||||
|
);
|
||||||
|
create_task(data, tokens);
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
|
||||||
|
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||||
|
data["index"] = i;
|
||||||
|
create_task(data, tokenized_prompts[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// invalid case
|
|
||||||
throw std::runtime_error(error_msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
return tasks;
|
return tasks;
|
||||||
}
|
}
|
||||||
@ -1534,7 +1427,7 @@ struct server_context {
|
|||||||
queue_tasks.post(cancel_tasks, true);
|
queue_tasks.post(cancel_tasks, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
// receive the results from task(s) created by create_tasks_cmpl
|
// receive the results from task(s) created by create_tasks_inference
|
||||||
void receive_cmpl_results(
|
void receive_cmpl_results(
|
||||||
const std::unordered_set<int> & id_tasks,
|
const std::unordered_set<int> & id_tasks,
|
||||||
const std::function<void(std::vector<server_task_result>&)> & result_handler,
|
const std::function<void(std::vector<server_task_result>&)> & result_handler,
|
||||||
@ -1558,7 +1451,7 @@ struct server_context {
|
|||||||
result_handler(results);
|
result_handler(results);
|
||||||
}
|
}
|
||||||
|
|
||||||
// receive the results from task(s) created by create_tasks_cmpl, in stream mode
|
// receive the results from task(s) created by create_tasks_inference, in stream mode
|
||||||
void receive_cmpl_results_stream(
|
void receive_cmpl_results_stream(
|
||||||
const std::unordered_set<int> & id_tasks, const
|
const std::unordered_set<int> & id_tasks, const
|
||||||
std::function<bool(server_task_result&)> & result_handler, const
|
std::function<bool(server_task_result&)> & result_handler, const
|
||||||
@ -1591,7 +1484,7 @@ struct server_context {
|
|||||||
|
|
||||||
void process_single_task(const server_task & task) {
|
void process_single_task(const server_task & task) {
|
||||||
switch (task.type) {
|
switch (task.type) {
|
||||||
case SERVER_TASK_TYPE_COMPLETION:
|
case SERVER_TASK_TYPE_INFERENCE:
|
||||||
{
|
{
|
||||||
const int id_slot = json_value(task.data, "id_slot", -1);
|
const int id_slot = json_value(task.data, "id_slot", -1);
|
||||||
|
|
||||||
@ -1624,8 +1517,9 @@ struct server_context {
|
|||||||
slot->reset();
|
slot->reset();
|
||||||
|
|
||||||
slot->id_task = task.id;
|
slot->id_task = task.id;
|
||||||
slot->cmpl_type = task.cmpl_type;
|
slot->inf_type = task.inf_type;
|
||||||
slot->index = json_value(task.data, "index", 0);
|
slot->index = json_value(task.data, "index", 0);
|
||||||
|
slot->prompt_tokens = std::move(task.prompt_tokens);
|
||||||
|
|
||||||
if (!launch_slot_with_task(*slot, task)) {
|
if (!launch_slot_with_task(*slot, task)) {
|
||||||
SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
|
SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
|
||||||
@ -1658,7 +1552,7 @@ struct server_context {
|
|||||||
slot_data["id"] = slot.id;
|
slot_data["id"] = slot.id;
|
||||||
slot_data["id_task"] = slot.id_task;
|
slot_data["id_task"] = slot.id_task;
|
||||||
slot_data["state"] = slot.state;
|
slot_data["state"] = slot.state;
|
||||||
slot_data["prompt"] = slot.prompt;
|
slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens);
|
||||||
slot_data["next_token"] = {
|
slot_data["next_token"] = {
|
||||||
{"has_next_token", slot.has_next_token},
|
{"has_next_token", slot.has_next_token},
|
||||||
{"has_new_line", slot.has_new_line},
|
{"has_new_line", slot.has_new_line},
|
||||||
@ -1785,9 +1679,6 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
slot->cache_tokens.resize(token_count);
|
slot->cache_tokens.resize(token_count);
|
||||||
|
|
||||||
// TODO: maybe detokenize the slot->cache_tokens instead?
|
|
||||||
slot->prompt = string_format("[restored %d tokens from file]", (int) token_count);
|
|
||||||
|
|
||||||
const int64_t t_end = ggml_time_us();
|
const int64_t t_end = ggml_time_us();
|
||||||
const double t_restore_ms = (t_end - t_start) / 1000.0;
|
const double t_restore_ms = (t_end - t_start) / 1000.0;
|
||||||
|
|
||||||
@ -1954,142 +1845,18 @@ struct server_context {
|
|||||||
if (params.cont_batching || batch.n_tokens == 0) {
|
if (params.cont_batching || batch.n_tokens == 0) {
|
||||||
for (auto & slot : slots) {
|
for (auto & slot : slots) {
|
||||||
// this slot still has a prompt to be processed
|
// this slot still has a prompt to be processed
|
||||||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
|
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
|
||||||
auto & prompt_tokens = slot.prompt_tokens;
|
auto & prompt_tokens = slot.prompt_tokens;
|
||||||
|
|
||||||
// we haven't tokenized the prompt yet - do it now:
|
// TODO: maybe move branch to outside of this loop in the future
|
||||||
if (prompt_tokens.empty()) {
|
if (slot.state == SLOT_STATE_STARTED) {
|
||||||
SLT_INF(slot, "tokenizing prompt, len = %d\n", (int) slot.prompt.size());
|
|
||||||
|
|
||||||
slot.t_start_process_prompt = ggml_time_us();
|
slot.t_start_process_prompt = ggml_time_us();
|
||||||
slot.t_start_generation = 0;
|
slot.t_start_generation = 0;
|
||||||
|
|
||||||
switch (slot.cmpl_type) {
|
|
||||||
case SERVER_TASK_CMPL_TYPE_NORMAL:
|
|
||||||
case SERVER_TASK_CMPL_TYPE_EMBEDDING:
|
|
||||||
{
|
|
||||||
prompt_tokens = tokenize(slot.prompt, llama_add_bos_token(model), true);
|
|
||||||
} break;
|
|
||||||
case SERVER_TASK_CMPL_TYPE_RERANK:
|
|
||||||
{
|
|
||||||
// require slot.prompt to be array of 2 strings
|
|
||||||
if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
|
|
||||||
SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
|
|
||||||
slot.release();
|
|
||||||
send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// prompt: [BOS]query[EOS][SEP]doc[EOS]
|
|
||||||
prompt_tokens.clear();
|
|
||||||
prompt_tokens.push_back(llama_token_bos(model));
|
|
||||||
{
|
|
||||||
const auto part = tokenize(slot.prompt[0], false, false);
|
|
||||||
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
|
|
||||||
}
|
|
||||||
prompt_tokens.push_back(llama_token_eos(model));
|
|
||||||
prompt_tokens.push_back(llama_token_sep(model));
|
|
||||||
{
|
|
||||||
const auto part = tokenize(slot.prompt[1], false, false);
|
|
||||||
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
|
|
||||||
}
|
|
||||||
prompt_tokens.push_back(llama_token_eos(model));
|
|
||||||
} break;
|
|
||||||
case SERVER_TASK_CMPL_TYPE_INFILL:
|
|
||||||
{
|
|
||||||
// TODO: optimize this block by reducing memory allocations and movement
|
|
||||||
|
|
||||||
// use FIM repo-level pattern:
|
|
||||||
// ref: https://arxiv.org/pdf/2409.12186
|
|
||||||
//
|
|
||||||
// [FIM_REP]myproject
|
|
||||||
// [FIM_SEP]filename0
|
|
||||||
// extra chunk 0
|
|
||||||
// [FIM_SEP]filename1
|
|
||||||
// extra chunk 1
|
|
||||||
// ...
|
|
||||||
// [FIM_SEP]filename
|
|
||||||
// [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
|
|
||||||
//
|
|
||||||
auto tokens_prefix = tokenize(slot.input_prefix, false, false);
|
|
||||||
auto tokens_suffix = tokenize(slot.input_suffix, false, false);
|
|
||||||
auto tokens_prompt = tokenize(slot.prompt, false, false);
|
|
||||||
|
|
||||||
slot.extra_tokens.clear();
|
|
||||||
if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
|
|
||||||
static const auto k_fim_repo = tokenize("myproject\n", false, false);
|
|
||||||
|
|
||||||
slot.extra_tokens.push_back(llama_token_fim_rep(model));
|
|
||||||
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const auto & chunk : slot.input_extra) {
|
|
||||||
// { "text": string, "filename": string }
|
|
||||||
const std::string text = chunk.value("text", "");
|
|
||||||
const std::string filename = chunk.value("filename", "tmp");
|
|
||||||
|
|
||||||
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
|
|
||||||
const auto k_fim_file = tokenize(filename + "\n", false, false);
|
|
||||||
|
|
||||||
slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
|
|
||||||
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
|
|
||||||
} else {
|
|
||||||
// chunk separator in binary form to avoid confusing the AI
|
|
||||||
static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
|
|
||||||
static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false);
|
|
||||||
|
|
||||||
slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto chunk_tokens = tokenize(text, false, false);
|
|
||||||
slot.extra_tokens.insert(slot.extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
|
|
||||||
// TODO: current filename
|
|
||||||
static const auto k_fim_file = tokenize("filename\n", false, false);
|
|
||||||
|
|
||||||
slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
|
|
||||||
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
// for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
|
|
||||||
const int n_suffix_take = std::min<int>(tokens_suffix.size(), (n_batch/4));
|
|
||||||
const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);
|
|
||||||
|
|
||||||
// fill the rest of the context with extra chunks
|
|
||||||
const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
|
|
||||||
|
|
||||||
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
|
|
||||||
tokens_suffix.resize(n_suffix_take);
|
|
||||||
|
|
||||||
tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
|
|
||||||
tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
|
|
||||||
tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
|
|
||||||
|
|
||||||
auto embd_inp = params.spm_infill ? tokens_suffix : tokens_prefix;
|
|
||||||
auto embd_end = params.spm_infill ? tokens_prefix : tokens_suffix;
|
|
||||||
|
|
||||||
if (llama_add_bos_token(model)) {
|
|
||||||
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
|
|
||||||
}
|
|
||||||
|
|
||||||
SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size());
|
|
||||||
|
|
||||||
// put the extra context before the FIM prefix
|
|
||||||
embd_inp.insert(embd_inp.begin(), slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end());
|
|
||||||
|
|
||||||
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
|
|
||||||
embd_inp.push_back(llama_token_fim_mid(model));
|
|
||||||
|
|
||||||
prompt_tokens = std::move(embd_inp);
|
|
||||||
} break;
|
|
||||||
}
|
|
||||||
|
|
||||||
slot.n_past = 0;
|
slot.n_past = 0;
|
||||||
slot.n_prompt_tokens = prompt_tokens.size();
|
slot.n_prompt_tokens = prompt_tokens.size();
|
||||||
|
slot.state = SLOT_STATE_PROCESSING_PROMPT;
|
||||||
|
|
||||||
SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
|
SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
|
||||||
|
|
||||||
// print prompt tokens (for debugging)
|
// print prompt tokens (for debugging)
|
||||||
if (1) {
|
if (1) {
|
||||||
@ -2114,7 +1881,7 @@ struct server_context {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
|
||||||
// this prompt is too large to process - discard it
|
// this prompt is too large to process - discard it
|
||||||
if (slot.n_prompt_tokens > n_ubatch) {
|
if (slot.n_prompt_tokens > n_ubatch) {
|
||||||
slot.release();
|
slot.release();
|
||||||
@ -2144,7 +1911,7 @@ struct server_context {
|
|||||||
const int n_block_size = n_left / 2;
|
const int n_block_size = n_left / 2;
|
||||||
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
|
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
|
||||||
|
|
||||||
std::vector<llama_token> new_tokens(
|
llama_tokens new_tokens(
|
||||||
prompt_tokens.begin(),
|
prompt_tokens.begin(),
|
||||||
prompt_tokens.begin() + slot.params.n_keep);
|
prompt_tokens.begin() + slot.params.n_keep);
|
||||||
|
|
||||||
@ -2225,7 +1992,7 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// non-causal tasks require to fit the entire prompt in the physical batch
|
// non-causal tasks require to fit the entire prompt in the physical batch
|
||||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
|
||||||
// cannot fit the prompt in the current batch - will try next iter
|
// cannot fit the prompt in the current batch - will try next iter
|
||||||
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
|
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
|
||||||
continue;
|
continue;
|
||||||
@ -2234,8 +2001,8 @@ struct server_context {
|
|||||||
|
|
||||||
// check that we are in the right batch_type, if not defer the slot
|
// check that we are in the right batch_type, if not defer the slot
|
||||||
const bool slot_type =
|
const bool slot_type =
|
||||||
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
|
slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
|
||||||
slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
|
slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;
|
||||||
|
|
||||||
if (batch_type == -1) {
|
if (batch_type == -1) {
|
||||||
batch_type = slot_type;
|
batch_type = slot_type;
|
||||||
@ -2353,7 +2120,7 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (slot.state == SLOT_STATE_DONE_PROMPT) {
|
if (slot.state == SLOT_STATE_DONE_PROMPT) {
|
||||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
|
if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
|
||||||
// prompt evaluated for embedding
|
// prompt evaluated for embedding
|
||||||
send_embedding(slot, batch_view);
|
send_embedding(slot, batch_view);
|
||||||
slot.release();
|
slot.release();
|
||||||
@ -2361,7 +2128,7 @@ struct server_context {
|
|||||||
continue; // continue loop of slots
|
continue; // continue loop of slots
|
||||||
}
|
}
|
||||||
|
|
||||||
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
|
if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
|
||||||
send_rerank(slot, batch_view);
|
send_rerank(slot, batch_view);
|
||||||
slot.release();
|
slot.release();
|
||||||
slot.i_batch = -1;
|
slot.i_batch = -1;
|
||||||
@ -2915,13 +2682,13 @@ int main(int argc, char ** argv) {
|
|||||||
res_ok(res, {{ "success", true }});
|
res_ok(res, {{ "success", true }});
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
|
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
|
||||||
if (ctx_server.params.embedding || ctx_server.params.reranking) {
|
if (ctx_server.params.embedding || ctx_server.params.reranking) {
|
||||||
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
|
std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
|
||||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||||
ctx_server.queue_tasks.post(tasks);
|
ctx_server.queue_tasks.post(tasks);
|
||||||
|
|
||||||
@ -2967,10 +2734,11 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
||||||
json data = json::parse(req.body);
|
json data = json::parse(req.body);
|
||||||
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
|
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
||||||
|
// check model compatibility
|
||||||
std::string err;
|
std::string err;
|
||||||
if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
|
if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
|
||||||
err += "prefix token is missing. ";
|
err += "prefix token is missing. ";
|
||||||
@ -2981,14 +2749,42 @@ int main(int argc, char ** argv) {
|
|||||||
if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
|
if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
|
||||||
err += "middle token is missing. ";
|
err += "middle token is missing. ";
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!err.empty()) {
|
if (!err.empty()) {
|
||||||
res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
|
res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
json data = json::parse(req.body);
|
json data = json::parse(req.body);
|
||||||
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
|
|
||||||
|
// validate input
|
||||||
|
if (!data.contains("input_prefix")) {
|
||||||
|
res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!data.contains("input_suffix")) {
|
||||||
|
res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
|
||||||
|
res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
json input_extra = json_value(data, "input_extra", json::array());
|
||||||
|
for (const auto & chunk : input_extra) {
|
||||||
|
// { "text": string, "filename": string }
|
||||||
|
if (!chunk.contains("text") || !chunk.at("text").is_string()) {
|
||||||
|
res_error(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// filename is optional
|
||||||
|
if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
|
||||||
|
res_error(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
data["input_extra"] = input_extra; // default to empty array if it's not exist
|
||||||
|
|
||||||
|
return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: maybe merge this function with "handle_completions_generic"
|
// TODO: maybe merge this function with "handle_completions_generic"
|
||||||
@ -3000,7 +2796,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
||||||
|
|
||||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
|
std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
|
||||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||||
ctx_server.queue_tasks.post(tasks);
|
ctx_server.queue_tasks.post(tasks);
|
||||||
|
|
||||||
@ -3073,7 +2869,7 @@ int main(int argc, char ** argv) {
|
|||||||
const bool add_special = json_value(body, "add_special", false);
|
const bool add_special = json_value(body, "add_special", false);
|
||||||
const bool with_pieces = json_value(body, "with_pieces", false);
|
const bool with_pieces = json_value(body, "with_pieces", false);
|
||||||
|
|
||||||
std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special, true);
|
llama_tokens tokens = tokenize_mixed(ctx_server.ctx, body.at("content"), add_special, true);
|
||||||
|
|
||||||
if (with_pieces) {
|
if (with_pieces) {
|
||||||
for (const auto& token : tokens) {
|
for (const auto& token : tokens) {
|
||||||
@ -3110,7 +2906,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
std::string content;
|
std::string content;
|
||||||
if (body.count("tokens") != 0) {
|
if (body.count("tokens") != 0) {
|
||||||
const std::vector<llama_token> tokens = body.at("tokens");
|
const llama_tokens tokens = body.at("tokens");
|
||||||
content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
|
content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3144,7 +2940,7 @@ int main(int argc, char ** argv) {
|
|||||||
json responses = json::array();
|
json responses = json::array();
|
||||||
bool error = false;
|
bool error = false;
|
||||||
{
|
{
|
||||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
|
std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING);
|
||||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||||
ctx_server.queue_tasks.post(tasks);
|
ctx_server.queue_tasks.post(tasks);
|
||||||
|
|
||||||
@ -3221,7 +3017,7 @@ int main(int argc, char ** argv) {
|
|||||||
json responses = json::array();
|
json responses = json::array();
|
||||||
bool error = false;
|
bool error = false;
|
||||||
{
|
{
|
||||||
std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
|
std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK);
|
||||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||||
ctx_server.queue_tasks.post(tasks);
|
ctx_server.queue_tasks.post(tasks);
|
||||||
|
|
||||||
|
36
examples/server/tests/features/infill.feature
Normal file
36
examples/server/tests/features/infill.feature
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
@llama.cpp
|
||||||
|
@infill
|
||||||
|
Feature: llama.cpp server
|
||||||
|
|
||||||
|
# The current model is made by adding FIM tokens to the existing stories260K
|
||||||
|
# We may want to use a better model in the future, maybe something like SmolLM 360M
|
||||||
|
|
||||||
|
Background: Server startup
|
||||||
|
Given a server listening on localhost:8080
|
||||||
|
And a model file tinyllamas/stories260K-infill.gguf from HF repo ggml-org/models
|
||||||
|
And a model file test-model-infill.gguf
|
||||||
|
And a model alias tinyllama-infill
|
||||||
|
And 42 as server seed
|
||||||
|
And 1024 as batch size
|
||||||
|
And 1024 as ubatch size
|
||||||
|
And 2048 KV cache size
|
||||||
|
And 64 max tokens to predict
|
||||||
|
And 0.0 temperature
|
||||||
|
Then the server is starting
|
||||||
|
Then the server is healthy
|
||||||
|
|
||||||
|
Scenario: Infill without input_extra
|
||||||
|
Given a prompt "Complete this"
|
||||||
|
And an infill input extra none none
|
||||||
|
And an infill input prefix "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_"
|
||||||
|
And an infill input suffix "}\n"
|
||||||
|
And an infill request with no api error
|
||||||
|
Then 64 tokens are predicted matching One|day|she|saw|big|scary|bird
|
||||||
|
|
||||||
|
Scenario: Infill with input_extra
|
||||||
|
Given a prompt "Complete this"
|
||||||
|
And an infill input extra "llama.h" "LLAMA_API int32_t llama_n_threads();\n"
|
||||||
|
And an infill input prefix "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_"
|
||||||
|
And an infill input suffix "}\n"
|
||||||
|
And an infill request with no api error
|
||||||
|
Then 64 tokens are predicted matching cuts|Jimmy|mom|came|into|the|room"
|
@ -80,6 +80,11 @@ def step_server_config(context, server_fqdn: str, server_port: str):
|
|||||||
context.lora_file = None
|
context.lora_file = None
|
||||||
context.disable_ctx_shift = False
|
context.disable_ctx_shift = False
|
||||||
|
|
||||||
|
# infill
|
||||||
|
context.infill_input_extra = None
|
||||||
|
context.infill_input_suffix = ''
|
||||||
|
context.infill_input_prefix = ''
|
||||||
|
|
||||||
context.tasks_result = []
|
context.tasks_result = []
|
||||||
context.concurrent_tasks = []
|
context.concurrent_tasks = []
|
||||||
context.prompts = []
|
context.prompts = []
|
||||||
@ -291,6 +296,28 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
|
|||||||
assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"
|
assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"
|
||||||
|
|
||||||
|
|
||||||
|
@step('an infill request with {api_error} api error')
|
||||||
|
@async_run_until_complete
|
||||||
|
async def step_request_completion(context, api_error: Literal['raised'] | str):
|
||||||
|
if api_error != 'no':
|
||||||
|
raise ValueError(f'api_error={api_error} is not yet implemented')
|
||||||
|
payload = {
|
||||||
|
"prompt": context.prompts[0],
|
||||||
|
"input_suffix": context.infill_input_suffix,
|
||||||
|
"input_prefix": context.infill_input_prefix,
|
||||||
|
"n_predict": context.n_predict,
|
||||||
|
"seed": context.seed,
|
||||||
|
"temperature": context.temperature,
|
||||||
|
}
|
||||||
|
if context.infill_input_extra is not None:
|
||||||
|
payload['input_extra'] = context.infill_input_extra
|
||||||
|
async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
|
||||||
|
async with session.post(f'{context.base_url}/infill',
|
||||||
|
json=payload) as response:
|
||||||
|
assert response.status == 200
|
||||||
|
context.tasks_result = [await response.json()]
|
||||||
|
|
||||||
|
|
||||||
@step('{predicted_n:d} tokens are predicted matching {re_content}')
|
@step('{predicted_n:d} tokens are predicted matching {re_content}')
|
||||||
def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
|
def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
|
||||||
context.completion = context.tasks_result.pop()
|
context.completion = context.tasks_result.pop()
|
||||||
@ -539,6 +566,25 @@ def step_a_prompt_prompt(context, prompt):
|
|||||||
context.n_prompts = len(context.prompts)
|
context.n_prompts = len(context.prompts)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: allow this to be repeated
|
||||||
|
@step('an infill input extra {filename} {text}')
|
||||||
|
def step_infill_input_extra(context, filename, text):
|
||||||
|
if filename == 'none':
|
||||||
|
context.infill_input_extra = None
|
||||||
|
else:
|
||||||
|
context.infill_input_extra = [{'filename': filename, 'text': text}]
|
||||||
|
|
||||||
|
|
||||||
|
@step('an infill input suffix {text}')
|
||||||
|
def step_infill_input_suffix(context, text):
|
||||||
|
context.infill_input_suffix = text
|
||||||
|
|
||||||
|
|
||||||
|
@step('an infill input prefix {text}')
|
||||||
|
def step_infill_input_prefix(context, text):
|
||||||
|
context.infill_input_prefix = text
|
||||||
|
|
||||||
|
|
||||||
@step('{num_prompts:d} prompts {prompt} with seed {seed:d}')
|
@step('{num_prompts:d} prompts {prompt} with seed {seed:d}')
|
||||||
def step_many_prompts(context, num_prompts, prompt, seed):
|
def step_many_prompts(context, num_prompts, prompt, seed):
|
||||||
if context.seed is None:
|
if context.seed is None:
|
||||||
|
@ -24,6 +24,22 @@
|
|||||||
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
|
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
|
||||||
|
|
||||||
using json = nlohmann::ordered_json;
|
using json = nlohmann::ordered_json;
|
||||||
|
using llama_tokens = std::vector<llama_token>;
|
||||||
|
|
||||||
|
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||||
|
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||||
|
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||||
|
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
||||||
|
|
||||||
|
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||||
|
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||||
|
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||||
|
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||||
|
|
||||||
|
#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||||
|
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||||
|
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||||
|
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||||
|
|
||||||
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
|
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
|
||||||
enum error_type {
|
enum error_type {
|
||||||
@ -52,9 +68,235 @@ static T json_value(const json & body, const std::string & key, const T & defaul
|
|||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// chat template utils
|
// tokenizer and input processing utils
|
||||||
//
|
//
|
||||||
|
|
||||||
|
static bool json_is_array_of_numbers(const json & data) {
|
||||||
|
if (data.is_array()) {
|
||||||
|
for (const auto & e : data) {
|
||||||
|
if (!e.is_number_integer()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// is array having BOTH numbers & strings?
|
||||||
|
static bool json_is_array_of_mixed_numbers_strings(const json & data) {
|
||||||
|
bool seen_string = false;
|
||||||
|
bool seen_number = false;
|
||||||
|
if (data.is_array()) {
|
||||||
|
for (const auto & e : data) {
|
||||||
|
seen_string |= e.is_string();
|
||||||
|
seen_number |= e.is_number_integer();
|
||||||
|
if (seen_number && seen_string) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* this handles 2 cases:
|
||||||
|
* - only string, example: "string"
|
||||||
|
* - mixed string and tokens, example: [12, 34, "string", 56, 78]
|
||||||
|
*/
|
||||||
|
static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
|
||||||
|
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
|
||||||
|
// or the first element of the json_prompt array is a string.
|
||||||
|
llama_tokens prompt_tokens;
|
||||||
|
|
||||||
|
if (json_prompt.is_array()) {
|
||||||
|
bool first = true;
|
||||||
|
for (const auto & p : json_prompt) {
|
||||||
|
if (p.is_string()) {
|
||||||
|
auto s = p.template get<std::string>();
|
||||||
|
|
||||||
|
llama_tokens p;
|
||||||
|
if (first) {
|
||||||
|
p = common_tokenize(ctx, s, add_special, parse_special);
|
||||||
|
first = false;
|
||||||
|
} else {
|
||||||
|
p = common_tokenize(ctx, s, false, parse_special);
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
|
||||||
|
} else {
|
||||||
|
if (first) {
|
||||||
|
first = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt_tokens.push_back(p.template get<llama_token>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto s = json_prompt.template get<std::string>();
|
||||||
|
prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
|
||||||
|
}
|
||||||
|
|
||||||
|
return prompt_tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* break the input "prompt" object into multiple prompt if needed, then tokenize them
|
||||||
|
* this supports these cases:
|
||||||
|
* - "prompt": "string"
|
||||||
|
* - "prompt": [12, 34, 56]
|
||||||
|
* - "prompt": [12, 34, "string", 56, 78]
|
||||||
|
* and multiple prompts (multi-tasks):
|
||||||
|
* - "prompt": ["string1", "string2"]
|
||||||
|
* - "prompt": ["string1", [12, 34, 56]]
|
||||||
|
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
|
||||||
|
*/
|
||||||
|
static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
|
||||||
|
std::vector<llama_tokens> result;
|
||||||
|
if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
|
||||||
|
// string or mixed
|
||||||
|
result.push_back(tokenize_mixed(ctx, json_prompt, add_special, parse_special));
|
||||||
|
} else if (json_is_array_of_numbers(json_prompt)) {
|
||||||
|
// array of tokens
|
||||||
|
result.push_back(json_prompt.get<llama_tokens>());
|
||||||
|
} else if (json_prompt.is_array()) {
|
||||||
|
// array of prompts
|
||||||
|
result.reserve(json_prompt.size());
|
||||||
|
for (const auto & p : json_prompt) {
|
||||||
|
if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) {
|
||||||
|
result.push_back(tokenize_mixed(ctx, p, add_special, parse_special));
|
||||||
|
} else if (json_is_array_of_numbers(p)) {
|
||||||
|
// array of tokens
|
||||||
|
result.push_back(p.get<llama_tokens>());
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// template utils
|
||||||
|
//
|
||||||
|
|
||||||
|
// format rerank task: [BOS]query[EOS][SEP]doc[EOS]
|
||||||
|
static llama_tokens format_rerank(const struct llama_model * model, const llama_tokens & query, const llama_tokens & doc) {
|
||||||
|
llama_tokens result;
|
||||||
|
result.reserve(doc.size() + query.size() + 4);
|
||||||
|
result.push_back(llama_token_bos(model));
|
||||||
|
result.insert(result.end(), query.begin(), query.end());
|
||||||
|
result.push_back(llama_token_eos(model));
|
||||||
|
result.push_back(llama_token_sep(model));
|
||||||
|
result.insert(result.end(), doc.begin(), doc.end());
|
||||||
|
result.push_back(llama_token_eos(model));
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// format infill task
|
||||||
|
static llama_tokens format_infill(
|
||||||
|
const llama_context * ctx,
|
||||||
|
const json & input_prefix,
|
||||||
|
const json & input_suffix,
|
||||||
|
const json & input_extra,
|
||||||
|
const int n_batch,
|
||||||
|
const int n_predict,
|
||||||
|
const int n_ctx,
|
||||||
|
const bool spm_infill,
|
||||||
|
const llama_tokens & tokens_prompt
|
||||||
|
) {
|
||||||
|
// TODO: optimize this block by reducing memory allocations and movement
|
||||||
|
|
||||||
|
// use FIM repo-level pattern:
|
||||||
|
// ref: https://arxiv.org/pdf/2409.12186
|
||||||
|
//
|
||||||
|
// [FIM_REP]myproject
|
||||||
|
// [FIM_SEP]filename0
|
||||||
|
// extra chunk 0
|
||||||
|
// [FIM_SEP]filename1
|
||||||
|
// extra chunk 1
|
||||||
|
// ...
|
||||||
|
// [FIM_SEP]filename
|
||||||
|
// [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
|
||||||
|
//
|
||||||
|
llama_tokens extra_tokens;
|
||||||
|
extra_tokens.reserve(n_ctx);
|
||||||
|
|
||||||
|
auto model = llama_get_model(ctx);
|
||||||
|
auto tokens_prefix = tokenize_mixed(ctx, input_prefix, false, false);
|
||||||
|
auto tokens_suffix = tokenize_mixed(ctx, input_suffix, false, false);
|
||||||
|
|
||||||
|
if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
|
||||||
|
// TODO: make project name an input
|
||||||
|
static const auto k_fim_repo = common_tokenize(ctx, "myproject\n", false, false);
|
||||||
|
|
||||||
|
extra_tokens.push_back(llama_token_fim_rep(model));
|
||||||
|
extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
|
||||||
|
}
|
||||||
|
for (const auto & chunk : input_extra) {
|
||||||
|
// { "text": string, "filename": string }
|
||||||
|
const std::string text = json_value(chunk, "text", std::string());
|
||||||
|
const std::string filename = json_value(chunk, "filename", std::string("tmp"));
|
||||||
|
|
||||||
|
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
|
||||||
|
const auto k_fim_file = common_tokenize(ctx, filename + "\n", false, false);
|
||||||
|
|
||||||
|
extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model));
|
||||||
|
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
|
||||||
|
} else {
|
||||||
|
// chunk separator in binary form to avoid confusing the AI
|
||||||
|
static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
|
||||||
|
static const auto k_chunk_prefix_tokens = common_tokenize(ctx, k_chunk_prefix_str, false, false);
|
||||||
|
|
||||||
|
extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto chunk_tokens = common_tokenize(ctx, text, false, false);
|
||||||
|
extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
|
||||||
|
// TODO: current filename
|
||||||
|
static const auto k_fim_file = common_tokenize(ctx, "filename\n", false, false);
|
||||||
|
|
||||||
|
extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model));
|
||||||
|
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
// for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
|
||||||
|
const int n_suffix_take = std::min<int>(tokens_suffix.size(), (n_batch/4));
|
||||||
|
const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);
|
||||||
|
|
||||||
|
// fill the rest of the context with extra chunks
|
||||||
|
const int n_extra_take = std::min<int>(std::max<int>(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size());
|
||||||
|
|
||||||
|
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
|
||||||
|
tokens_suffix.resize(n_suffix_take);
|
||||||
|
|
||||||
|
tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
|
||||||
|
tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
|
||||||
|
tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
|
||||||
|
|
||||||
|
auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix;
|
||||||
|
auto embd_end = spm_infill ? tokens_prefix : tokens_suffix;
|
||||||
|
|
||||||
|
if (llama_add_bos_token(model)) {
|
||||||
|
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
|
||||||
|
}
|
||||||
|
|
||||||
|
SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
|
||||||
|
|
||||||
|
// put the extra context before the FIM prefix
|
||||||
|
embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end());
|
||||||
|
|
||||||
|
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
|
||||||
|
embd_inp.push_back(llama_token_fim_mid(model));
|
||||||
|
|
||||||
|
return embd_inp;
|
||||||
|
}
|
||||||
|
|
||||||
// Format given chat. If tmpl is empty, we take the template from model metadata
|
// Format given chat. If tmpl is empty, we take the template from model metadata
|
||||||
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
|
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
|
||||||
std::vector<common_chat_msg> chat;
|
std::vector<common_chat_msg> chat;
|
||||||
@ -229,18 +471,6 @@ static size_t find_partial_stop_string(const std::string &stop, const std::strin
|
|||||||
return std::string::npos;
|
return std::string::npos;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool json_is_array_of_numbers(const json & data) {
|
|
||||||
if (data.is_array()) {
|
|
||||||
for (const auto & e : data) {
|
|
||||||
if (!e.is_number()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: reuse llama_detokenize
|
// TODO: reuse llama_detokenize
|
||||||
template <class Iter>
|
template <class Iter>
|
||||||
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
|
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
|
||||||
|
Loading…
Reference in New Issue
Block a user