mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-26 03:12:23 +01:00
server : improve "prompt" handling (#7847)
This commit is contained in:
parent
1f0dabda8d
commit
d9da0e4986
@ -147,7 +147,7 @@ struct server_slot {
|
|||||||
int32_t n_prompt_tokens = 0;
|
int32_t n_prompt_tokens = 0;
|
||||||
int32_t n_prompt_tokens_processed = 0;
|
int32_t n_prompt_tokens_processed = 0;
|
||||||
|
|
||||||
json prompt;
|
std::string prompt;
|
||||||
|
|
||||||
// when a task is submitted, we first tokenize the prompt and store it here
|
// when a task is submitted, we first tokenize the prompt and store it here
|
||||||
std::vector<llama_token> prompt_tokens;
|
std::vector<llama_token> prompt_tokens;
|
||||||
@ -822,13 +822,8 @@ struct server_context {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// skip the slot if it does not contains prompt
|
|
||||||
if (!slot.prompt.is_string()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// current slot's prompt
|
// current slot's prompt
|
||||||
std::string slot_prompt = slot.prompt.get<std::string>();
|
std::string slot_prompt = slot.prompt;
|
||||||
|
|
||||||
// length of the current slot's prompt
|
// length of the current slot's prompt
|
||||||
int slot_prompt_len = slot_prompt.size();
|
int slot_prompt_len = slot_prompt.size();
|
||||||
@ -958,13 +953,16 @@ struct server_context {
|
|||||||
if (!task.infill) {
|
if (!task.infill) {
|
||||||
const auto & prompt = data.find("prompt");
|
const auto & prompt = data.find("prompt");
|
||||||
if (prompt == data.end()) {
|
if (prompt == data.end()) {
|
||||||
send_error(task, "Either \"prompt\" or \"messages\" must be provided", ERROR_TYPE_INVALID_REQUEST);
|
send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
|
||||||
return false;
|
return false;
|
||||||
} else {
|
|
||||||
slot.prompt = *prompt;
|
|
||||||
}
|
}
|
||||||
if (slot.prompt.is_array() && slot.prompt.size() == 0) {
|
|
||||||
send_error(task, "\"prompt\" cannot be an empty array", ERROR_TYPE_INVALID_REQUEST);
|
if (prompt->is_string()) {
|
||||||
|
slot.prompt = prompt->get<std::string>();
|
||||||
|
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) {
|
||||||
|
slot.prompt = prompt->at(0).get<std::string>();
|
||||||
|
} else {
|
||||||
|
send_error(task, "\"prompt\" must be a string or an array of strings", ERROR_TYPE_INVALID_REQUEST);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1582,14 +1580,18 @@ struct server_context {
|
|||||||
switch (task.type) {
|
switch (task.type) {
|
||||||
case SERVER_TASK_TYPE_COMPLETION:
|
case SERVER_TASK_TYPE_COMPLETION:
|
||||||
{
|
{
|
||||||
int id_slot = json_value(task.data, "id_slot", -1);
|
const int id_slot = json_value(task.data, "id_slot", -1);
|
||||||
std::string prompt = json_value(task.data, "prompt", std::string());
|
|
||||||
|
|
||||||
server_slot * slot;
|
server_slot * slot;
|
||||||
|
|
||||||
if (id_slot != -1) {
|
if (id_slot != -1) {
|
||||||
slot = get_slot_by_id(id_slot);
|
slot = get_slot_by_id(id_slot);
|
||||||
} else {
|
} else {
|
||||||
|
std::string prompt;
|
||||||
|
if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
|
||||||
|
json_value(task.data, "prompt", std::string());
|
||||||
|
}
|
||||||
|
|
||||||
slot = get_available_slot(prompt);
|
slot = get_available_slot(prompt);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user