mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 05:17:21 +01:00
wip: add cancellable request
This commit is contained in:
parent
ba8a1f9c5b
commit
15fbcb5df7
@ -1581,12 +1581,20 @@ struct server_response {
|
||||
}
|
||||
|
||||
// This function blocks the thread until there is a response for one of the id_tasks
|
||||
server_task_result_ptr recv(const std::unordered_set<int> & id_tasks) {
|
||||
server_task_result_ptr recv(const std::unordered_set<int> & id_tasks, int timeout = -1) {
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
condition_results.wait(lock, [&]{
|
||||
return !queue_results.empty();
|
||||
});
|
||||
if (timeout == -1) {
|
||||
condition_results.wait(lock, [&]{
|
||||
return !queue_results.empty();
|
||||
});
|
||||
} else {
|
||||
if (!condition_results.wait_for(lock, std::chrono::seconds(timeout), [&]{
|
||||
return !queue_results.empty();
|
||||
})) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < (int) queue_results.size(); i++) {
|
||||
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
|
||||
@ -2315,10 +2323,17 @@ struct server_context {
|
||||
void receive_multi_results(
|
||||
const std::unordered_set<int> & id_tasks,
|
||||
const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
|
||||
const std::function<bool()> & check_alive_handler,
|
||||
const std::function<void(json)> & error_handler) {
|
||||
static constexpr int UNBLOCK_SECONDS = 1;
|
||||
std::vector<server_task_result_ptr> results(id_tasks.size());
|
||||
for (size_t i = 0; i < id_tasks.size(); i++) {
|
||||
server_task_result_ptr result = queue_results.recv(id_tasks);
|
||||
server_task_result_ptr result = queue_results.recv(id_tasks, UNBLOCK_SECONDS);
|
||||
|
||||
if (!check_alive_handler()) {
|
||||
cancel_tasks(id_tasks);
|
||||
return;
|
||||
}
|
||||
|
||||
if (result->is_error()) {
|
||||
error_handler(result->to_json());
|
||||
@ -2342,10 +2357,18 @@ struct server_context {
|
||||
void receive_cmpl_results_stream(
|
||||
const std::unordered_set<int> & id_tasks,
|
||||
const std::function<bool(server_task_result_ptr&)> & result_handler,
|
||||
const std::function<bool()> & check_alive_handler,
|
||||
const std::function<void(json)> & error_handler) {
|
||||
size_t n_finished = 0;
|
||||
static constexpr int UNBLOCK_SECONDS = 1;
|
||||
while (true) {
|
||||
server_task_result_ptr result = queue_results.recv(id_tasks);
|
||||
server_task_result_ptr result = queue_results.recv(id_tasks, UNBLOCK_SECONDS);
|
||||
|
||||
// mostly useful to cancel the task during prompt processing
|
||||
if (!check_alive_handler()) {
|
||||
cancel_tasks(id_tasks);
|
||||
return;
|
||||
}
|
||||
|
||||
if (result->is_error()) {
|
||||
error_handler(result->to_json());
|
||||
@ -3686,6 +3709,8 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
res_ok(res, arr);
|
||||
}
|
||||
}, [&]() {
|
||||
return true;
|
||||
}, [&](const json & error_data) {
|
||||
res_error(res, error_data);
|
||||
});
|
||||
@ -3705,6 +3730,8 @@ int main(int argc, char ** argv) {
|
||||
} else {
|
||||
return server_sent_event(sink, "data", res_json);
|
||||
}
|
||||
}, [&]() {
|
||||
return true;
|
||||
}, [&](const json & error_data) {
|
||||
server_sent_event(sink, "error", error_data);
|
||||
});
|
||||
|
Loading…
x
Reference in New Issue
Block a user