mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 14:20:31 +01:00
server : fix smart selection of available slot (#10120)
* Fix smart selection of available slot * minor fix * replace vectors of tokens with shorthands
This commit is contained in:
parent
1804adb0cf
commit
d865d1478c
@ -725,12 +725,12 @@ struct server_context {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
server_slot * get_available_slot(const std::string & prompt) {
|
server_slot * get_available_slot(const server_task & task) {
|
||||||
server_slot * ret = nullptr;
|
server_slot * ret = nullptr;
|
||||||
|
|
||||||
// find the slot that has at least n% prompt similarity
|
// find the slot that has at least n% prompt similarity
|
||||||
if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty()) {
|
if (ret == nullptr && slot_prompt_similarity != 0.0f) {
|
||||||
int max_lcp_len = 0;
|
int max_lcs_len = 0;
|
||||||
float similarity = 0;
|
float similarity = 0;
|
||||||
|
|
||||||
for (server_slot & slot : slots) {
|
for (server_slot & slot : slots) {
|
||||||
@ -740,25 +740,25 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// skip the slot if it does not contains cached tokens
|
// skip the slot if it does not contains cached tokens
|
||||||
if (slot.prompt_tokens.empty()) {
|
if (slot.cache_tokens.empty()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
|
// length of the Longest Common Subsequence between the current slot's prompt and the input prompt
|
||||||
int lcp_len = longest_common_prefix(slot.cache_tokens, slot.prompt_tokens);
|
int lcs_len = longest_common_subsequence(slot.cache_tokens, task.prompt_tokens);
|
||||||
|
|
||||||
// fraction of the common substring length compared to the current slot's prompt length
|
// fraction of the common subsequence length compared to the current slot's prompt length
|
||||||
similarity = static_cast<float>(lcp_len) / static_cast<int>(slot.prompt_tokens.size());
|
similarity = static_cast<float>(lcs_len) / static_cast<int>(slot.cache_tokens.size());
|
||||||
|
|
||||||
// select the current slot if the criteria match
|
// select the current slot if the criteria match
|
||||||
if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
|
if (lcs_len > max_lcs_len && similarity > slot_prompt_similarity) {
|
||||||
max_lcp_len = lcp_len;
|
max_lcs_len = lcs_len;
|
||||||
ret = &slot;
|
ret = &slot;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ret != nullptr) {
|
if (ret != nullptr) {
|
||||||
SLT_DBG(*ret, "selected slot by lcp similarity, max_lcp_len = %d, similarity = %f\n", max_lcp_len, similarity);
|
SLT_DBG(*ret, "selected slot by lcs similarity, max_lcs_len = %d, similarity = %f\n", max_lcs_len, similarity);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1514,18 +1514,7 @@ struct server_context {
|
|||||||
{
|
{
|
||||||
const int id_slot = json_value(task.data, "id_slot", -1);
|
const int id_slot = json_value(task.data, "id_slot", -1);
|
||||||
|
|
||||||
server_slot * slot;
|
server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
|
||||||
|
|
||||||
if (id_slot != -1) {
|
|
||||||
slot = get_slot_by_id(id_slot);
|
|
||||||
} else {
|
|
||||||
std::string prompt;
|
|
||||||
if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
|
|
||||||
prompt = json_value(task.data, "prompt", std::string());
|
|
||||||
}
|
|
||||||
|
|
||||||
slot = get_available_slot(prompt);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (slot == nullptr) {
|
if (slot == nullptr) {
|
||||||
// if no slot is available, we defer this task for processing later
|
// if no slot is available, we defer this task for processing later
|
||||||
|
@ -439,18 +439,60 @@ static std::string gen_chatcmplid() {
|
|||||||
// other common utils
|
// other common utils
|
||||||
//
|
//
|
||||||
|
|
||||||
static size_t longest_common_prefix(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
|
static size_t longest_common_prefix(const llama_tokens & a, const llama_tokens & b) {
|
||||||
size_t i;
|
size_t i;
|
||||||
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
|
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
|
||||||
|
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
|
|
||||||
static size_t longest_common_prefix(const std::string & a, const std::string & b) {
|
static size_t longest_common_subsequence(const llama_tokens & a, const llama_tokens & b) {
|
||||||
size_t i;
|
// check for empty sequences
|
||||||
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
|
if (a.empty() || b.empty()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
return i;
|
// get the lengths of the input sequences
|
||||||
|
int a_len = a.size();
|
||||||
|
int b_len = b.size();
|
||||||
|
|
||||||
|
// initialize the maximum length of the longest common subsequence (LCS)
|
||||||
|
int max_length = 0;
|
||||||
|
|
||||||
|
// use two rows instead of a 2D matrix to optimize space
|
||||||
|
std::vector<int> prev_row(b_len + 1, 0);
|
||||||
|
std::vector<int> curr_row(b_len + 1, 0);
|
||||||
|
|
||||||
|
// iterate through the elements of a
|
||||||
|
for (int i = 1; i <= a_len; i++) {
|
||||||
|
// iterate through the elements of b
|
||||||
|
for (int j = 1; j <= b_len; j++) {
|
||||||
|
// if elements at the current positions match
|
||||||
|
if (a[i - 1] == b[j - 1]) {
|
||||||
|
// if it's the first element of either sequences, set LCS length to 1
|
||||||
|
if (i == 1 || j == 1) {
|
||||||
|
curr_row[j] = 1;
|
||||||
|
} else {
|
||||||
|
// increment LCS length by 1 compared to the previous element
|
||||||
|
curr_row[j] = prev_row[j - 1] + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// update max_length if necessary
|
||||||
|
if (curr_row[j] > max_length) {
|
||||||
|
max_length = curr_row[j];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// reset LCS length if elements don't match
|
||||||
|
curr_row[j] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// update the previous row for the next iteration
|
||||||
|
prev_row = curr_row;
|
||||||
|
}
|
||||||
|
|
||||||
|
// return the maximum length of the LCS
|
||||||
|
return max_length;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ends_with(const std::string & str, const std::string & suffix) {
|
static bool ends_with(const std::string & str, const std::string & suffix) {
|
||||||
|
Loading…
Reference in New Issue
Block a user