wip: add cancellable request

This commit is contained in:
Xuan Son Nguyen 2025-01-10 15:23:13 +01:00
parent ba8a1f9c5b
commit 15fbcb5df7

View File

@ -1581,12 +1581,20 @@ struct server_response {
}
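// This function blocks the thread until there is a response for one of the id_tasks.
// With a timeout other than -1 (in seconds), it returns nullptr if the results queue is still empty when the wait times out.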
server_task_result_ptr recv(const std::unordered_set<int> & id_tasks) {
server_task_result_ptr recv(const std::unordered_set<int> & id_tasks, int timeout = -1) {
while (true) {
std::unique_lock<std::mutex> lock(mutex_results);
condition_results.wait(lock, [&]{
return !queue_results.empty();
});
if (timeout == -1) {
condition_results.wait(lock, [&]{
return !queue_results.empty();
});
} else {
if (!condition_results.wait_for(lock, std::chrono::seconds(timeout), [&]{
return !queue_results.empty();
})) {
return nullptr;
}
}
for (int i = 0; i < (int) queue_results.size(); i++) {
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
@ -2315,10 +2323,17 @@ struct server_context {
void receive_multi_results(
const std::unordered_set<int> & id_tasks,
const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
const std::function<bool()> & check_alive_handler,
const std::function<void(json)> & error_handler) {
static constexpr int UNBLOCK_SECONDS = 1;
std::vector<server_task_result_ptr> results(id_tasks.size());
for (size_t i = 0; i < id_tasks.size(); i++) {
server_task_result_ptr result = queue_results.recv(id_tasks);
server_task_result_ptr result = queue_results.recv(id_tasks, UNBLOCK_SECONDS);
if (!check_alive_handler()) {
cancel_tasks(id_tasks);
return;
}
if (result->is_error()) {
error_handler(result->to_json());
@ -2342,10 +2357,18 @@ struct server_context {
void receive_cmpl_results_stream(
const std::unordered_set<int> & id_tasks,
const std::function<bool(server_task_result_ptr&)> & result_handler,
const std::function<bool()> & check_alive_handler,
const std::function<void(json)> & error_handler) {
size_t n_finished = 0;
static constexpr int UNBLOCK_SECONDS = 1;
while (true) {
server_task_result_ptr result = queue_results.recv(id_tasks);
server_task_result_ptr result = queue_results.recv(id_tasks, UNBLOCK_SECONDS);
// mostly useful to cancel the task during prompt processing
if (!check_alive_handler()) {
cancel_tasks(id_tasks);
return;
}
if (result->is_error()) {
error_handler(result->to_json());
@ -3686,6 +3709,8 @@ int main(int argc, char ** argv) {
}
res_ok(res, arr);
}
}, [&]() {
return true;
}, [&](const json & error_data) {
res_error(res, error_data);
});
@ -3705,6 +3730,8 @@ int main(int argc, char ** argv) {
} else {
return server_sent_event(sink, "data", res_json);
}
}, [&]() {
return true;
}, [&](const json & error_data) {
server_sent_event(sink, "error", error_data);
});