From 4ffc7a17d4e80c5f3f905139cb570ed9b6934fcb Mon Sep 17 00:00:00 2001 From: Niall Coates <1349685+Niall-@users.noreply.github.com> Date: Tue, 6 Feb 2024 08:16:23 +0000 Subject: [PATCH] server : various fixes for the prompt field in /completion (#5300) server : fix deadlock when prompt array contains strings and numbers server : removed an unnecessary generation when generating multi-prompts server : removed an unnecessary assert --- examples/server/server.cpp | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 8000fee5c..fc7e723a1 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1163,13 +1163,30 @@ struct llama_server_context task.multitask_id = multitask_id; // when a completion task's prompt array is not a singleton, we split it into multiple requests - if (task.data.count("prompt") && task.data.at("prompt").size() > 1) - { - split_multiprompt_task(task_id, task); - } - // otherwise, it's a single-prompt task, we actually queue it - queue_tasks.post(task); + // if there's numbers in the prompt array it will be treated as an array of tokens + if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) { + bool numbers = false; + for (const auto& e : task.data.at("prompt")) { + if (e.is_number()) { + numbers = true; + break; + } + } + + // NOTE: split_multiprompt_task() does not handle a mix of strings and numbers, + // it will completely stall the server. I don't know where the bug for this is. + // + // if there are numbers, it needs to be treated like a single prompt, + // queue_tasks handles a mix of strings and numbers just fine. + if (numbers) { + queue_tasks.post(task); + } else { + split_multiprompt_task(task_id, task); + } + } else { + queue_tasks.post(task); + } } // for multiple images processing @@ -1251,7 +1268,10 @@ struct llama_server_context void split_multiprompt_task(int multitask_id, task_server& multiprompt_task) { int prompt_count = multiprompt_task.data.at("prompt").size(); - assert(prompt_count > 1); + if (prompt_count <= 1) { + send_error(multiprompt_task, "error while handling multiple prompts"); + return; + } // generate all the ID for subtask std::vector subtask_ids(prompt_count);