diff --git a/.github/workflows/server.yml b/.github/workflows/server.yml index 04e3fc0c1..f9aeefaa8 100644 --- a/.github/workflows/server.yml +++ b/.github/workflows/server.yml @@ -58,7 +58,8 @@ jobs: cmake \ python3-pip \ wget \ - psmisc + psmisc \ + language-pack-en - name: Build id: cmake_build diff --git a/Makefile b/Makefile index 4f26c0463..efce10bb8 100644 --- a/Makefile +++ b/Makefile @@ -724,10 +724,9 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) -server: examples/server/server.cpp examples/server/oai.hpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h examples/llava/llava.h examples/llava/llava.cpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS) +server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) - $(CXX) $(CXXFLAGS) -c examples/llava/clip.cpp -o $(call GET_OBJ_FILE, examples/llava/clip.cpp) -Wno-cast-qual - $(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h %.hpp $< examples/llava/clip.cpp,$^) $(call GET_OBJ_FILE, $<) $(call GET_OBJ_FILE, examples/llava/clip.cpp) -o $@ $(LDFLAGS) $(LWINSOCK2) + $(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2) gguf: examples/gguf/gguf.cpp ggml.o $(OBJS) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) diff --git a/examples/server-embd.py b/examples/server-embd.py index c5c4ea87b..118e04271 100644 --- a/examples/server-embd.py +++ b/examples/server-embd.py @@ -13,7 +13,7 @@ async def main(): model_url = "http://127.0.0.1:6900" responses: list[requests.Response] = await asyncio.gather(*[requests_post_async( url= f"{model_url}/embedding", - json= {"content": str(i)*1024} + json= {"content": str(0)*1024} ) for i in range(n)]) for response in responses: diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index cc13b2d63..c21eba634 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -1,12 +1,12 @@ set(TARGET server) option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON) include_directories(${CMAKE_CURRENT_SOURCE_DIR}) -add_executable(${TARGET} server.cpp oai.hpp utils.hpp json.hpp httplib.h) +add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h) install(TARGETS ${TARGET} RUNTIME) target_compile_definitions(${TARGET} PRIVATE SERVER_VERBOSE=$ ) -target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT}) if (WIN32) TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32) endif() diff --git a/examples/server/README.md b/examples/server/README.md index 21da7a0a0..591f748f8 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -436,7 +436,7 @@ Notice that each `probs` is an array of length `n_probs`. "next_token": { "has_next_token": true, "n_remain": -1, - "num_tokens_predicted": 0, + "n_decoded": 0, "stopped_eos": false, "stopped_limit": false, "stopped_word": false, diff --git a/examples/server/oai.hpp b/examples/server/oai.hpp deleted file mode 100644 index ff4ad6994..000000000 --- a/examples/server/oai.hpp +++ /dev/null @@ -1,225 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "json.hpp" -#include "utils.hpp" - -#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613" - -using json = nlohmann::json; - -inline static json oaicompat_completion_params_parse( - const struct llama_model * model, - const json &body, /* openai api json semantics */ - const std::string &chat_template) -{ - json llama_params; - - llama_params["__oaicompat"] = true; - - // Map OpenAI parameters to llama.cpp parameters - // - // For parameters that are defined by the OpenAI documentation (e.g. - // temperature), we explicitly specify OpenAI's intended default; we - // need to do that because sometimes OpenAI disagrees with llama.cpp - // - // https://platform.openai.com/docs/api-reference/chat/create - llama_sampling_params default_sparams; - llama_params["model"] = json_value(body, "model", std::string("unknown")); - llama_params["prompt"] = format_chat(model, chat_template, body["messages"]); - llama_params["cache_prompt"] = json_value(body, "cache_prompt", false); - llama_params["temperature"] = json_value(body, "temperature", 0.0); - llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k); - llama_params["top_p"] = json_value(body, "top_p", 1.0); - llama_params["n_predict"] = json_value(body, "max_tokens", -1); - llama_params["logit_bias"] = json_value(body, "logit_bias",json::object()); - llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0); - llama_params["presence_penalty"] = json_value(body, "presence_penalty", 0.0); - llama_params["seed"] = json_value(body, "seed", LLAMA_DEFAULT_SEED); - llama_params["stream"] = json_value(body, "stream", false); - llama_params["mirostat"] = json_value(body, "mirostat", default_sparams.mirostat); - llama_params["mirostat_tau"] = json_value(body, "mirostat_tau", default_sparams.mirostat_tau); - llama_params["mirostat_eta"] = json_value(body, "mirostat_eta", default_sparams.mirostat_eta); - llama_params["penalize_nl"] = json_value(body, "penalize_nl", default_sparams.penalize_nl); - llama_params["typical_p"] = json_value(body, "typical_p", default_sparams.typical_p); - llama_params["repeat_last_n"] = json_value(body, "repeat_last_n", default_sparams.penalty_last_n); - llama_params["ignore_eos"] = json_value(body, "ignore_eos", false); - llama_params["tfs_z"] = json_value(body, "tfs_z", default_sparams.tfs_z); - - if (body.count("grammar") != 0) { - llama_params["grammar"] = json_value(body, "grammar", json::object()); - } - - // Handle 'stop' field - if (body.contains("stop") && body["stop"].is_string()) { - llama_params["stop"] = json::array({body["stop"].get()}); - } else { - llama_params["stop"] = json_value(body, "stop", json::array()); - } - - // Ensure there is ChatML-specific end sequence among stop words - llama_params["stop"].push_back("<|im_end|>"); - - return llama_params; -} - -inline static json format_final_response_oaicompat(const json &request, const task_result &response, bool streaming = false) -{ - json result = response.result_json; - - bool stopped_word = result.count("stopped_word") != 0; - bool stopped_eos = json_value(result, "stopped_eos", false); - int num_tokens_predicted = json_value(result, "tokens_predicted", 0); - int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); - std::string content = json_value(result, "content", std::string("")); - - std::string finish_reason = "length"; - if (stopped_word || stopped_eos) { - finish_reason = "stop"; - } - - json choices = - streaming ? json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}}}) - : json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"message", json{{"content", content}, - {"role", "assistant"}}}}}); - - std::time_t t = std::time(0); - - json res = - json{{"choices", choices}, - {"created", t}, - {"model", - json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", streaming ? "chat.completion.chunk" : "chat.completion"}, - {"usage", - json{{"completion_tokens", num_tokens_predicted}, - {"prompt_tokens", num_prompt_tokens}, - {"total_tokens", num_tokens_predicted + num_prompt_tokens}}}, - {"id", gen_chatcmplid()}}; - - if (server_verbose) { - res["__verbose"] = result; - } - - if (result.contains("completion_probabilities")) { - res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array()); - } - - return res; -} - -// return value is vector as there is one case where we might need to generate two responses -inline static std::vector format_partial_response_oaicompat(const task_result &response) { - json result = response.result_json; - - if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) { - return std::vector({response.result_json}); - } - - bool first = json_value(result, "oaicompat_token_ctr", 0) == 0; - std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); - - bool stopped_word = json_value(result, "stopped_word", false); - bool stopped_eos = json_value(result, "stopped_eos", false); - bool stopped_limit = json_value(result, "stopped_limit", false); - std::string content = json_value(result, "content", std::string("")); - - std::string finish_reason; - if (stopped_word || stopped_eos) { - finish_reason = "stop"; - } - if (stopped_limit) { - finish_reason = "length"; - } - - std::time_t t = std::time(0); - - json choices; - - if (!finish_reason.empty()) { - choices = json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}}}); - } else { - if (first) { - if (content.empty()) { - choices = json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{{"role", "assistant"}}}}}); - } else { - // We have to send this as two updates to conform to openai behavior - json initial_ret = json{{"choices", json::array({json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"role", "assistant"} - }}}})}, - {"created", t}, - {"id", gen_chatcmplid()}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - - json second_ret = json{ - {"choices", json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"content", content}}} - }})}, - {"created", t}, - {"id", gen_chatcmplid()}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - - return std::vector({initial_ret, second_ret}); - } - } else { - // Some idiosyncrasy in task processing logic makes several trailing calls - // with empty content, we ignore these at the calee site. - if (content.empty()) { - return std::vector({json::object()}); - } - - choices = json::array({json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", - json{ - {"content", content}, - }}, - }}); - } - } - - json ret = json{{"choices", choices}, - {"created", t}, - {"id", gen_chatcmplid()}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - - return std::vector({ret}); -} - -inline static json format_embeddings_response_oaicompat(const json &request, const json &embeddings) -{ - json res = - json{ - {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", "list"}, - {"usage", - json{{"prompt_tokens", 0}, - {"total_tokens", 0}}}, - {"data", embeddings} - }; - return res; -} - diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 8fe5e0b19..3bdbde954 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1,13 +1,8 @@ +#include "utils.hpp" + #include "common.h" #include "llama.h" #include "grammar-parser.h" -#include "utils.hpp" -#include "oai.hpp" - -#include "../llava/clip.h" -#include "../llava/llava.h" - -#include "stb_image.h" #ifndef NDEBUG // crash the server in debug mode, otherwise send an http 500 error @@ -24,46 +19,76 @@ #include "completion.js.hpp" #include "json-schema-to-grammar.mjs.hpp" -#include -#include +#include #include #include -#include +#include +#include +#include +#include #include using json = nlohmann::json; -struct server_params { - std::string hostname = "127.0.0.1"; - std::vector api_keys; - std::string public_path = "examples/server/public"; - std::string chat_template = ""; - int32_t port = 8080; - int32_t read_timeout = 600; - int32_t write_timeout = 600; - bool slots_endpoint = true; - bool metrics_endpoint = false; - int n_threads_http = -1; -}; - bool server_verbose = false; bool server_log_json = true; enum stop_type { - STOP_FULL, - STOP_PARTIAL, + STOP_TYPE_FULL, + STOP_TYPE_PARTIAL, }; -// TODO: can become bool if we can't find use of more states enum slot_state { - IDLE, - PROCESSING, + SLOT_STATE_IDLE, + SLOT_STATE_PROCESSING, }; enum slot_command { - NONE, - LOAD_PROMPT, - RELEASE, + SLOT_COMMAND_NONE, + SLOT_COMMAND_LOAD_PROMPT, + SLOT_COMMAND_RELEASE, +}; + +enum server_state { + SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet + SERVER_STATE_READY, // Server is ready and model is loaded + SERVER_STATE_ERROR // An error occurred, load_model failed +}; + +enum server_task_type { + SERVER_TASK_TYPE_COMPLETION, + SERVER_TASK_TYPE_CANCEL, + SERVER_TASK_TYPE_NEXT_RESPONSE, + SERVER_TASK_TYPE_METRICS +}; + +struct server_task { + int id = -1; // to be filled by server_queue + int id_multi = -1; + int id_target = -1; + + server_task_type type; + json data; + + bool infill = false; + bool embedding = false; +}; + +struct server_task_result { + int id = -1; + int id_multi = -1; + + json data; + + bool stop; + bool error; +}; + +struct server_task_multi { + int id = -1; + + std::set subtasks_remaining; + std::vector results; }; struct slot_params { @@ -80,26 +105,32 @@ struct slot_params { json input_suffix; }; -struct slot_image { - int32_t id; +struct server_params { + int32_t port = 8080; + int32_t read_timeout = 600; + int32_t write_timeout = 600; + int32_t n_threads_http = -1; - bool request_encode_image = false; - float * image_embedding = nullptr; - int32_t image_tokens = 0; + std::string hostname = "127.0.0.1"; + std::string public_path = "examples/server/public"; + std::string chat_template = ""; + std::string system_prompt = ""; - clip_image_u8 * img_data; + std::vector api_keys; - std::string prefix_prompt; // before of this image + bool slots_endpoint = true; + bool metrics_endpoint = false; }; struct server_slot { int id; - int task_id = -1; + int id_task = -1; + int id_multi = -1; struct slot_params params; - slot_state state = IDLE; - slot_command command = NONE; + slot_state state = SLOT_STATE_IDLE; + slot_command command = SLOT_COMMAND_NONE; // used to determine the slot that has been used the longest int64_t t_last_used = -1; @@ -116,27 +147,31 @@ struct server_slot { int32_t n_prompt_tokens_processed = 0; json prompt; + + // when a task is submitted, we first tokenize the prompt and store it here + std::vector prompt_tokens; + std::string generated_text; - llama_token sampled; std::vector cache_tokens; std::vector generated_token_probs; - bool infill = false; - bool embedding = false; + bool infill = false; + bool embedding = false; bool has_next_token = true; - bool truncated = false; - bool stopped_eos = false; - bool stopped_word = false; - bool stopped_limit = false; + bool truncated = false; + bool stopped_eos = false; + bool stopped_word = false; + bool stopped_limit = false; bool oaicompat = false; - std::string oaicompat_model; + std::string oaicompat_model; std::string stopping_word; // sampling + llama_token sampled; struct llama_sampling_params sparams; - llama_sampling_context *ctx_sampling = nullptr; + llama_sampling_context * ctx_sampling = nullptr; int32_t ga_i = 0; // group-attention state int32_t ga_n = 1; // group-attention factor @@ -144,48 +179,32 @@ struct server_slot { int32_t n_past_se = 0; // self-extend - // multimodal - std::vector images; - // stats size_t n_sent_text = 0; // number of sent text character size_t n_sent_token_probs = 0; int64_t t_start_process_prompt; - int64_t t_start_genereration; + int64_t t_start_generation; double t_prompt_processing; // ms double t_token_generation; // ms - // multitasks - int multitask_id = -1; - void reset() { - n_prompt_tokens = 0; - generated_text = ""; - truncated = false; - stopped_eos = false; - stopped_word = false; - stopped_limit = false; - stopping_word = ""; - n_past = 0; - n_sent_text = 0; - n_sent_token_probs = 0; - infill = false; - ga_i = 0; - n_past_se = 0; + n_prompt_tokens = 0; + generated_text = ""; + truncated = false; + stopped_eos = false; + stopped_word = false; + stopped_limit = false; + stopping_word = ""; + n_past = 0; + n_sent_text = 0; + n_sent_token_probs = 0; + infill = false; + ga_i = 0; + n_past_se = 0; generated_token_probs.clear(); - - for (slot_image & img : images) { - free(img.image_embedding); - if (img.img_data) { - clip_image_u8_free(img.img_data); - } - img.prefix_prompt = ""; - } - - images.clear(); } bool has_budget(gpt_params &global_params) { @@ -205,32 +224,29 @@ struct server_slot { } bool available() const { - return state == IDLE && command == NONE; + return state == SLOT_STATE_IDLE && command == SLOT_COMMAND_NONE; } bool is_processing() const { - return (state == IDLE && command == LOAD_PROMPT) || state == PROCESSING; + return (state == SLOT_STATE_IDLE && command == SLOT_COMMAND_LOAD_PROMPT) || state == SLOT_STATE_PROCESSING; } - void add_token_string(const completion_token_output &token) { - if (command == RELEASE) { + void add_token_string(const completion_token_output & token) { + if (command == SLOT_COMMAND_RELEASE) { return; } - cache_tokens.push_back(token.tok); generated_token_probs.push_back(token); } void release() { - if (state == PROCESSING) - { - t_token_generation = (ggml_time_us() - t_start_genereration) / 1e3; - command = RELEASE; + if (state == SLOT_STATE_PROCESSING) { + t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; + command = SLOT_COMMAND_RELEASE; } } - json get_formated_timings() { - return json - { + json get_formated_timings() const { + return json { {"prompt_n", n_prompt_tokens_processed}, {"prompt_ms", t_prompt_processing}, {"prompt_per_token_ms", t_prompt_processing / n_prompt_tokens_processed}, @@ -243,16 +259,47 @@ struct server_slot { }; } + size_t find_stopping_strings(const std::string & text, const size_t last_token_size, const stop_type type) { + size_t stop_pos = std::string::npos; + + for (const std::string & word : params.antiprompt) { + size_t pos; + + if (type == STOP_TYPE_FULL) { + const size_t tmp = word.size() + last_token_size; + const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; + + pos = text.find(word, from_pos); + } else { + pos = find_partial_stop_string(word, text); + } + + if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { + if (type == STOP_TYPE_FULL) { + stopped_word = true; + stopping_word = word; + has_next_token = false; + } + stop_pos = pos; + } + } + + return stop_pos; + } + void print_timings() const { - char buffer[512]; + char buffer[512]; + double t_token = t_prompt_processing / n_prompt_tokens_processed; double n_tokens_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - sprintf(buffer, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)", + + snprintf(buffer, 512, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)", t_prompt_processing, n_prompt_tokens_processed, t_token, n_tokens_second); + LOG_INFO(buffer, { - {"slot_id", id}, - {"task_id", task_id}, + {"id_slot", id}, + {"id_task", id_task}, {"t_prompt_processing", t_prompt_processing}, {"n_prompt_tokens_processed", n_prompt_tokens_processed}, {"t_token", t_token}, @@ -261,22 +308,25 @@ struct server_slot { t_token = t_token_generation / n_decoded; n_tokens_second = 1e3 / t_token_generation * n_decoded; - sprintf(buffer, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)", + + snprintf(buffer, 512, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)", t_token_generation, n_decoded, t_token, n_tokens_second); + LOG_INFO(buffer, { - {"slot_id", id}, - {"task_id", task_id}, + {"id_slot", id}, + {"id_task", id_task}, {"t_token_generation", t_token_generation}, {"n_decoded", n_decoded}, {"t_token", t_token}, {"n_tokens_second", n_tokens_second}, }); - sprintf(buffer, " total time = %10.2f ms", t_prompt_processing + t_token_generation); + snprintf(buffer, 512, " total time = %10.2f ms", t_prompt_processing + t_token_generation); + LOG_INFO(buffer, { - {"slot_id", id}, - {"task_id", task_id}, + {"id_slot", id}, + {"id_task", id_task}, {"t_prompt_processing", t_prompt_processing}, {"t_token_generation", t_token_generation}, {"t_total", t_prompt_processing + t_token_generation}, @@ -291,9 +341,8 @@ struct server_metrics { uint64_t n_prompt_tokens_processed = 0; uint64_t t_prompt_processing = 0; - uint64_t n_tokens_predicted = 0; - uint64_t t_tokens_generation = 0; - + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; void on_prompt_eval(const server_slot &slot) { n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; @@ -315,23 +364,261 @@ struct server_metrics { } }; -struct llama_server_context -{ - llama_model *model = nullptr; - llama_context *ctx = nullptr; +struct server_queue { + int id = 0; + bool running; - clip_ctx *clp_ctx = nullptr; + // queues + std::vector queue_tasks; + std::vector queue_tasks_deferred; + + std::vector queue_multitasks; + + std::mutex mutex_tasks; + std::condition_variable condition_tasks; + + // callback functions + std::function callback_new_task; + std::function callback_finish_multitask; + std::function callback_run_slots; + + // Add a new task to the end of the queue + int post(server_task task) { + std::unique_lock lock(mutex_tasks); + if (task.id == -1) { + task.id = id++; + LOG_VERBOSE("new task id", {{"new_id", task.id}}); + } + queue_tasks.push_back(std::move(task)); + condition_tasks.notify_one(); + return task.id; + } + + // Add a new task, but defer until one slot is available + void defer(server_task task) { + std::unique_lock lock(mutex_tasks); + queue_tasks_deferred.push_back(std::move(task)); + } + + // Get the next id for creating anew task + int get_new_id() { + std::unique_lock lock(mutex_tasks); + int new_id = id++; + LOG_VERBOSE("new task id", {{"new_id", new_id}}); + return new_id; + } + + // Register function to process a new task + void on_new_task(std::function callback) { + callback_new_task = std::move(callback); + } + + // Register function to process a multitask when it is finished + void on_finish_multitask(std::function callback) { + callback_finish_multitask = std::move(callback); + } + + // Register the function to be called when all slots data is ready to be processed + void on_run_slots(std::function callback) { + callback_run_slots = std::move(callback); + } + + // Call when the state of one slot is changed + void notify_slot_changed() { + // move deferred tasks back to main loop + std::unique_lock lock(mutex_tasks); + for (auto & task : queue_tasks_deferred) { + queue_tasks.push_back(std::move(task)); + } + queue_tasks_deferred.clear(); + } + + // end the start_loop routine + void terminate() { + std::unique_lock lock(mutex_tasks); + running = false; + condition_tasks.notify_all(); + } + + /** + * Main loop consists of these steps: + * - Wait until a new task arrives + * - Process the task (i.e. maybe copy data into slot) + * - Check if multitask is finished + * - Run all slots + */ + void start_loop() { + running = true; + + while (true) { + LOG_VERBOSE("new task may arrive", {}); + + while (true) { + std::unique_lock lock(mutex_tasks); + if (queue_tasks.empty()) { + lock.unlock(); + break; + } + server_task task = queue_tasks.front(); + queue_tasks.erase(queue_tasks.begin()); + lock.unlock(); + LOG_VERBOSE("callback_new_task", {{"id_task", task.id}}); + callback_new_task(task); + } + + LOG_VERBOSE("update_multitasks", {}); + + // check if we have any finished multitasks + auto queue_iterator = queue_multitasks.begin(); + while (queue_iterator != queue_multitasks.end()) { + if (queue_iterator->subtasks_remaining.empty()) { + // all subtasks done == multitask is done + server_task_multi current_multitask = *queue_iterator; + callback_finish_multitask(current_multitask); + // remove this multitask + queue_iterator = queue_multitasks.erase(queue_iterator); + } else { + ++queue_iterator; + } + } + + // all tasks in the current loop is processed, slots data is now ready + LOG_VERBOSE("callback_run_slots", {}); + + callback_run_slots(); + + LOG_VERBOSE("wait for new task", {}); + { + std::unique_lock lock(mutex_tasks); + if (queue_tasks.empty()) { + if (!running) { + LOG_VERBOSE("ending start_loop", {}); + return; + } + condition_tasks.wait(lock, [&]{ + return (!queue_tasks.empty() || !running); + }); + } + } + } + } + + // + // functions to manage multitasks + // + + // add a multitask by specifying the id of all subtask (subtask is a server_task) + void add_multitask(int id_multi, std::vector & sub_ids) { + std::lock_guard lock(mutex_tasks); + server_task_multi multi; + multi.id = id_multi; + std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end())); + queue_multitasks.push_back(multi); + } + + // updatethe remaining subtasks, while appending results to multitask + void update_multitask(int id_multi, int id_sub, server_task_result & result) { + std::lock_guard lock(mutex_tasks); + for (auto & multitask : queue_multitasks) { + if (multitask.id == id_multi) { + multitask.subtasks_remaining.erase(id_sub); + multitask.results.push_back(result); + } + } + } +}; + +struct server_response { + typedef std::function callback_multitask_t; + callback_multitask_t callback_update_multitask; + + // for keeping track of all tasks waiting for the result + std::set waiting_task_ids; + + // the main result queue + std::vector queue_results; + + std::mutex mutex_results; + std::condition_variable condition_results; + + // add the id_task to the list of tasks waiting for response + void add_waiting_task_id(int id_task) { + LOG_VERBOSE("waiting for task id", {{"id_task", id_task}}); + + std::unique_lock lock(mutex_results); + waiting_task_ids.insert(id_task); + } + + // when the request is finished, we can remove task associated with it + void remove_waiting_task_id(int id_task) { + LOG_VERBOSE("remove waiting for task id", {{"id_task", id_task}}); + + std::unique_lock lock(mutex_results); + waiting_task_ids.erase(id_task); + } + + // This function blocks the thread until there is a response for this id_task + server_task_result recv(int id_task) { + while (true) { + std::unique_lock lock(mutex_results); + condition_results.wait(lock, [&]{ + return !queue_results.empty(); + }); + + for (int i = 0; i < (int) queue_results.size(); i++) { + if (queue_results[i].id == id_task) { + assert(queue_results[i].id_multi == -1); + server_task_result res = queue_results[i]; + queue_results.erase(queue_results.begin() + i); + return res; + } + } + } + + // should never reach here + } + + // Register the function to update multitask + void on_multitask_update(callback_multitask_t callback) { + callback_update_multitask = std::move(callback); + } + + // Send a new result to a waiting id_task + void send(server_task_result result) { + LOG_VERBOSE("send new result", {{"id_task", result.id}}); + + std::unique_lock lock(mutex_results); + for (const auto & id_task : waiting_task_ids) { + // LOG_TEE("waiting task id %i \n", id_task); + // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result + if (result.id_multi == id_task) { + LOG_VERBOSE("callback_update_multitask", {{"id_task", id_task}}); + callback_update_multitask(id_task, result.id, result); + continue; + } + + if (result.id == id_task) { + LOG_VERBOSE("queue_results.push_back", {{"id_task", id_task}}); + queue_results.push_back(result); + condition_results.notify_all(); + return; + } + } + } +}; + +struct server_context { + llama_model * model = nullptr; + llama_context * ctx = nullptr; gpt_params params; llama_batch batch; - bool multimodal = false; - bool clean_kv_cache = true; - bool all_slots_are_idle = false; - bool add_bos_token = true; + bool clean_kv_cache = true; + bool add_bos_token = true; - int32_t n_ctx; // total context for all clients / slots + int32_t n_ctx; // total context for all clients / slots // system prompt bool system_need_update = false; @@ -346,60 +633,32 @@ struct llama_server_context std::vector slots; json default_generation_settings_for_props; - llama_server_queue queue_tasks; - llama_server_response queue_results; + server_queue queue_tasks; + server_response queue_results; server_metrics metrics; - ~llama_server_context() - { - if (ctx) - { + ~server_context() { + if (ctx) { llama_free(ctx); ctx = nullptr; } - if (model) - { + + if (model) { llama_free_model(model); model = nullptr; } } - bool load_model(const gpt_params ¶ms_) - { + bool load_model(const gpt_params & params_) { params = params_; - if (!params.mmproj.empty()) { - multimodal = true; - LOG_INFO("Multi Modal Mode Enabled", {}); - clp_ctx = clip_model_load(params.mmproj.c_str(), /*verbosity=*/ 1); - if(clp_ctx == nullptr) { - LOG_ERROR("unable to load clip model", {{"model", params.mmproj}}); - return false; - } - - if (params.n_ctx < 2048) { // request larger context for the image embedding - params.n_ctx = 2048; - } - } std::tie(model, ctx) = llama_init_from_gpt_params(params); - if (model == nullptr) - { + if (model == nullptr) { LOG_ERROR("unable to load model", {{"model", params.model}}); return false; } - if (multimodal) { - const int n_embd_clip = clip_n_mmproj_embd(clp_ctx); - const int n_embd_llm = llama_n_embd(model); - if (n_embd_clip != n_embd_llm) { - LOG_TEE("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_embd_clip, n_embd_llm); - llama_free(ctx); - llama_free_model(model); - return false; - } - } - n_ctx = llama_n_ctx(ctx); add_bos_token = llama_should_add_bos_token(model); @@ -407,25 +666,19 @@ struct llama_server_context return true; } - void validate_model_chat_template(server_params & sparams) { + bool validate_model_chat_template() const { llama_chat_message chat[] = {{"user", "test"}}; - std::vector buf(1); - int res = llama_chat_apply_template(model, nullptr, chat, 1, true, buf.data(), buf.size()); - if (res < 0) { - LOG_ERROR("The chat template comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {}); - sparams.chat_template = "chatml"; - } + + const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0); + + return res > 0; } void initialize() { - // create slots - all_slots_are_idle = true; - const int32_t n_ctx_slot = n_ctx / params.n_parallel; LOG_INFO("initializing slots", {{"n_slots", params.n_parallel}}); - for (int i = 0; i < params.n_parallel; i++) - { + for (int i = 0; i < params.n_parallel; i++) { server_slot slot; slot.id = i; @@ -433,7 +686,7 @@ struct llama_server_context slot.n_predict = params.n_predict; LOG_INFO("new slot", { - {"slot_id", slot.id}, + {"id_slot", slot.id}, {"n_ctx_slot", slot.n_ctx} }); @@ -447,9 +700,9 @@ struct llama_server_context //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT LOG_INFO("slot self-extend", { - {"slot_id", slot.id}, - {"ga_n", ga_n}, - {"ga_w", ga_w} + {"id_slot", slot.id}, + {"ga_n", ga_n}, + {"ga_w", ga_w} }); } @@ -468,8 +721,7 @@ struct llama_server_context batch = llama_batch_init(n_ctx, 0, params.n_parallel); } - std::vector tokenize(const json & json_prompt, bool add_bos) const - { + std::vector tokenize(const json & json_prompt, bool add_bos) const { // TODO: currently, we tokenize using special tokens by default // this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216) // but it's better compared to completely ignoring ChatML and other chat templates @@ -479,38 +731,30 @@ struct llama_server_context // or the first element of the json_prompt array is a string. std::vector prompt_tokens; - if (json_prompt.is_array()) - { + if (json_prompt.is_array()) { bool first = true; - for (const auto& p : json_prompt) - { - if (p.is_string()) - { + for (const auto & p : json_prompt) { + if (p.is_string()) { auto s = p.template get(); + std::vector p; - if (first) - { + if (first) { p = ::llama_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL); first = false; - } - else - { + } else { p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL); } + prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); - } - else - { - if (first) - { + } else { + if (first) { first = false; } + prompt_tokens.push_back(p.template get()); } } - } - else - { + } else { auto s = json_prompt.template get(); prompt_tokens = ::llama_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL); } @@ -518,19 +762,18 @@ struct llama_server_context return prompt_tokens; } - server_slot* get_slot(int id) { + server_slot * get_slot(int id) { int64_t t_last = ggml_time_us(); - server_slot *last_used = nullptr; - for (server_slot & slot : slots) - { - if (slot.id == id && slot.available()) - { + server_slot * last_used = nullptr; + + for (server_slot & slot : slots) { + if (slot.id == id && slot.available()) { return &slot; } - if (slot.available() && slot.t_last_used < t_last) - { + // among all available slots, find the one that has been least recently used + if (slot.available() && slot.t_last_used < t_last) { last_used = &slot; t_last = slot.t_last_used; } @@ -539,295 +782,204 @@ struct llama_server_context return last_used; } - bool launch_slot_with_data(server_slot* &slot, json data) { + bool launch_slot_with_data(server_slot & slot, json data) const { slot_params default_params; llama_sampling_params default_sparams; if (data.count("__oaicompat") != 0) { - slot->oaicompat = true; - slot->oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); + slot.oaicompat = true; + slot.oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); } else { - slot->oaicompat = false; - slot->oaicompat_model = ""; + slot.oaicompat = false; + slot.oaicompat_model = ""; } - slot->params.stream = json_value(data, "stream", false); - slot->params.cache_prompt = json_value(data, "cache_prompt", false); - slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict); - slot->sparams.top_k = json_value(data, "top_k", default_sparams.top_k); - slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p); - slot->sparams.min_p = json_value(data, "min_p", default_sparams.min_p); - slot->sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); - slot->sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p); - slot->sparams.temp = json_value(data, "temperature", default_sparams.temp); - slot->sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range); - slot->sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent); - slot->sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); - slot->sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); - slot->sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); - slot->sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); - slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); - slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); - slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); - slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); - slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep); - slot->params.seed = json_value(data, "seed", default_params.seed); - slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar); - slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); - slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); + slot.params.stream = json_value(data, "stream", false); + slot.params.cache_prompt = json_value(data, "cache_prompt", false); + slot.params.n_predict = json_value(data, "n_predict", default_params.n_predict); + slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k); + slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p); + slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p); + slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); + slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p); + slot.sparams.temp = json_value(data, "temperature", default_sparams.temp); + slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range); + slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent); + slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); + slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); + slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); + slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); + slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); + slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); + slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); + slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); + slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep); + slot.params.seed = json_value(data, "seed", default_params.seed); + slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar); + slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); + slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); - if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) { + if (slot.params.cache_prompt && slot.ga_n != 1) { + LOG_WARNING("cache_prompt is not supported with group-attention", {}); + slot.params.cache_prompt = false; + } + + if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { // Might be better to reject the request with a 400 ? LOG_WARNING("Max tokens to predict exceeds server configuration", { - {"params.n_predict", slot->params.n_predict}, - {"slot.n_predict", slot->n_predict}, + {"params.n_predict", slot.params.n_predict}, + {"slot.n_predict", slot.n_predict}, }); - slot->params.n_predict = slot->n_predict; + slot.params.n_predict = slot.n_predict; } // infill - if (data.count("input_prefix") != 0) - { - slot->params.input_prefix = data["input_prefix"]; - } - else - { - slot->params.input_prefix = ""; - } + slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix); + slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix); + slot.prompt = json_value(data, "prompt", std::string("")); - if (data.count("input_suffix") != 0) + // penalize user-provided tokens { - slot->params.input_suffix = data["input_suffix"]; - } - else - { - slot->params.input_suffix = ""; - } + slot.sparams.penalty_prompt_tokens.clear(); + slot.sparams.use_penalty_prompt_tokens = false; - if (data.count("prompt") != 0) - { - slot->prompt = data["prompt"]; - } - else - { - slot->prompt = ""; - } + const auto & penalty_prompt = data.find("penalty_prompt"); - slot->sparams.penalty_prompt_tokens.clear(); - slot->sparams.use_penalty_prompt_tokens = false; - const auto &penalty_prompt = data.find("penalty_prompt"); - if (penalty_prompt != data.end()) - { - if (penalty_prompt->is_string()) - { - const auto penalty_prompt_string = penalty_prompt->get(); - auto penalty_tokens = llama_tokenize(model, penalty_prompt_string, false); - slot->sparams.penalty_prompt_tokens.swap(penalty_tokens); - if (slot->params.n_predict > 0) - { - slot->sparams.penalty_prompt_tokens.reserve(slot->sparams.penalty_prompt_tokens.size() + slot->params.n_predict); - } - slot->sparams.use_penalty_prompt_tokens = true; - } - else if (penalty_prompt->is_array()) - { - const auto n_tokens = penalty_prompt->size(); - slot->sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot->params.n_predict)); - const int n_vocab = llama_n_vocab(model); - for (const auto &penalty_token : *penalty_prompt) - { - if (penalty_token.is_number_integer()) - { - const auto tok = penalty_token.get(); - if (tok >= 0 && tok < n_vocab) - { - slot->sparams.penalty_prompt_tokens.push_back(tok); - } + if (penalty_prompt != data.end()) { + if (penalty_prompt->is_string()) { + const auto penalty_prompt_string = penalty_prompt->get(); + slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false); + + if (slot.params.n_predict > 0) { + slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict); } - } - slot->sparams.use_penalty_prompt_tokens = true; - } - } + slot.sparams.use_penalty_prompt_tokens = true; - slot->sparams.logit_bias.clear(); - - if (json_value(data, "ignore_eos", false)) - { - slot->sparams.logit_bias[llama_token_eos(model)] = -INFINITY; - } - - const auto &logit_bias = data.find("logit_bias"); - if (logit_bias != data.end() && logit_bias->is_array()) - { - const int n_vocab = llama_n_vocab(model); - for (const auto &el : *logit_bias) - { - if (el.is_array() && el.size() == 2) - { - float bias; - if (el[1].is_number()) - { - bias = el[1].get(); - } - else if (el[1].is_boolean() && !el[1].get()) - { - bias = -INFINITY; - } - else - { - continue; - } - - if (el[0].is_number_integer()) - { - llama_token tok = el[0].get(); - if (tok >= 0 && tok < n_vocab) - { - slot->sparams.logit_bias[tok] = bias; - } - } - else if (el[0].is_string()) - { - auto toks = llama_tokenize(model, el[0].get(), false); - for (auto tok : toks) - { - slot->sparams.logit_bias[tok] = bias; - } - } - } - } - } - - slot->params.antiprompt.clear(); - - const auto &stop = data.find("stop"); - if (stop != data.end() && stop->is_array()) - { - for (const auto &word : *stop) - { - if (!word.empty()) - { - slot->params.antiprompt.push_back(word); - } - } - } - - const auto &samplers_sequence = data.find("samplers"); - if (samplers_sequence != data.end() && samplers_sequence->is_array()) - { - std::vector sampler_names; - for (const auto &sampler_name : *samplers_sequence) - { - if (sampler_name.is_string()) - { - sampler_names.emplace_back(sampler_name); - } - } - slot->sparams.samplers_sequence = sampler_types_from_names(sampler_names, false); - } - else - { - slot->sparams.samplers_sequence = default_sparams.samplers_sequence; - } - - if (multimodal) - { - const auto &images_data = data.find("image_data"); - if (images_data != data.end() && images_data->is_array()) - { - for (const auto &img : *images_data) - { - const std::vector image_buffer = base64_decode(img["data"].get()); - - slot_image img_sl; - img_sl.id = img.count("id") != 0 ? img["id"].get() : slot->images.size(); - img_sl.img_data = clip_image_u8_init(); - if (!clip_image_load_from_bytes(image_buffer.data(), image_buffer.size(), img_sl.img_data)) - { - LOG_ERROR("failed to load image", { - {"slot_id", slot->id}, - {"img_sl_id", img_sl.id} - }); - return false; - } - LOG_VERBOSE("image loaded", { - {"slot_id", slot->id}, - {"img_sl_id", img_sl.id} + LOG_VERBOSE("penalty_prompt_tokens", { + {"id_slot", slot.id}, + {"tokens", slot.sparams.penalty_prompt_tokens}, }); - img_sl.request_encode_image = true; - slot->images.push_back(img_sl); } - // process prompt - // example: system prompt [img-102] user [img-103] describe [img-134] -> [{id: 102, prefix: 'system prompt '}, {id: 103, prefix: ' user '}, {id: 134, prefix: ' describe '}]} - if (slot->images.size() > 0 && !slot->prompt.is_array()) - { - std::string prompt = slot->prompt.get(); - size_t pos = 0, begin_prefix = 0; - std::string pattern = "[img-"; - while ((pos = prompt.find(pattern, pos)) != std::string::npos) { - size_t end_prefix = pos; - pos += pattern.length(); - size_t end_pos = prompt.find(']', pos); - if (end_pos != std::string::npos) - { - std::string image_id = prompt.substr(pos, end_pos - pos); - try - { - int img_id = std::stoi(image_id); - bool found = false; - for (slot_image &img : slot->images) - { - if (img.id == img_id) { - found = true; - img.prefix_prompt = prompt.substr(begin_prefix, end_prefix - begin_prefix); - begin_prefix = end_pos + 1; - break; - } - } - if (!found) { - LOG_TEE("ERROR: Image with id: %i, not found.\n", img_id); - slot->images.clear(); - return false; - } - } catch (const std::invalid_argument& e) { - LOG_TEE("Invalid image number id in prompt\n"); - slot->images.clear(); - return false; + else if (penalty_prompt->is_array()) { + const auto n_tokens = penalty_prompt->size(); + slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict)); + + const int n_vocab = llama_n_vocab(model); + for (const auto & penalty_token : *penalty_prompt) { + if (penalty_token.is_number_integer()) { + const auto tok = penalty_token.get(); + if (tok >= 0 && tok < n_vocab) { + slot.sparams.penalty_prompt_tokens.push_back(tok); } } } - slot->prompt = ""; - slot->params.input_suffix = prompt.substr(begin_prefix); - slot->params.cache_prompt = false; // multimodal doesn't support cache prompt + slot.sparams.use_penalty_prompt_tokens = true; + + LOG_VERBOSE("penalty_prompt_tokens", { + {"id_slot", slot.id}, + {"tokens", slot.sparams.penalty_prompt_tokens}, + }); } } } - if (slot->ctx_sampling != nullptr) { - llama_sampling_free(slot->ctx_sampling); - } - slot->ctx_sampling = llama_sampling_init(slot->sparams); - llama_set_rng_seed(ctx, slot->params.seed); - slot->command = LOAD_PROMPT; + slot.sparams.logit_bias.clear(); - all_slots_are_idle = false; + if (json_value(data, "ignore_eos", false)) { + slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; + } + + const auto & logit_bias = data.find("logit_bias"); + if (logit_bias != data.end() && logit_bias->is_array()) { + const int n_vocab = llama_n_vocab(model); + for (const auto & el : *logit_bias) { + if (el.is_array() && el.size() == 2) { + float bias; + if (el[1].is_number()) { + bias = el[1].get(); + } else if (el[1].is_boolean() && !el[1].get()) { + bias = -INFINITY; + } else { + continue; + } + + if (el[0].is_number_integer()) { + llama_token tok = el[0].get(); + if (tok >= 0 && tok < n_vocab) { + slot.sparams.logit_bias[tok] = bias; + } + } else if (el[0].is_string()) { + auto toks = llama_tokenize(model, el[0].get(), false); + for (auto tok : toks) { + slot.sparams.logit_bias[tok] = bias; + } + } + } + } + } + } + + { + slot.params.antiprompt.clear(); + + const auto & stop = data.find("stop"); + if (stop != data.end() && stop->is_array()) { + for (const auto & word : *stop) { + if (!word.empty()) { + slot.params.antiprompt.push_back(word); + } + } + } + } + + { + const auto & samplers_sequence = data.find("samplers"); + if (samplers_sequence != data.end() && samplers_sequence->is_array()) { + std::vector sampler_names; + for (const auto & sampler_name : *samplers_sequence) { + if (sampler_name.is_string()) { + sampler_names.emplace_back(sampler_name); + } + } + slot.sparams.samplers_sequence = sampler_types_from_names(sampler_names, false); + } else { + slot.sparams.samplers_sequence = default_sparams.samplers_sequence; + } + } + + { + if (slot.ctx_sampling != nullptr) { + llama_sampling_free(slot.ctx_sampling); + } + slot.ctx_sampling = llama_sampling_init(slot.sparams); + llama_set_rng_seed(ctx, slot.params.seed); + } + + slot.command = SLOT_COMMAND_LOAD_PROMPT; + slot.prompt_tokens.clear(); LOG_INFO("slot is processing task", { - {"slot_id", slot->id}, - {"task_id", slot->task_id}, + {"id_slot", slot.id}, + {"id_task", slot.id_task}, }); return true; } void kv_cache_clear() { + LOG_VERBOSE("clearing KV cache", {}); + // clear the entire KV cache llama_kv_cache_clear(ctx); clean_kv_cache = false; } void system_prompt_update() { + LOG_VERBOSE("system prompt update", { + {"system_prompt", system_prompt}, + }); + kv_cache_clear(); system_tokens.clear(); @@ -836,13 +988,11 @@ struct llama_server_context llama_batch_clear(batch); - for (int i = 0; i < (int)system_tokens.size(); ++i) - { + for (int i = 0; i < (int)system_tokens.size(); ++i) { llama_batch_add(batch, system_tokens[i], i, { 0 }, false); } - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += params.n_batch) - { + for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += params.n_batch) { const int32_t n_tokens = std::min(params.n_batch, (int32_t) (batch.n_tokens - i)); llama_batch batch_view = { n_tokens, @@ -854,78 +1004,42 @@ struct llama_server_context batch.logits + i, 0, 0, 0, // unused }; - if (llama_decode(ctx, batch_view) != 0) - { + + if (llama_decode(ctx, batch_view) != 0) { LOG_TEE("%s: llama_decode() failed\n", __func__); return; } } // assign the system KV cache to all parallel sequences - for (int32_t i = 1; i < params.n_parallel; ++i) - { + for (int32_t i = 1; i < params.n_parallel; ++i) { llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size()); } } - LOG_TEE("system prompt updated\n"); system_need_update = false; } - void system_prompt_notify() { + void system_prompt_set(const json & sys_props) { + system_prompt = sys_props.value("prompt", ""); + name_user = sys_props.value("anti_prompt", ""); + name_assistant = sys_props.value("assistant_name", ""); + + LOG_VERBOSE("system prompt process", { + {"system_prompt", system_prompt}, + {"name_user", name_user}, + {"name_assistant", name_assistant}, + }); + // release all slots - for (server_slot &slot : slots) - { + for (server_slot & slot : slots) { slot.release(); } system_need_update = true; } - void system_prompt_process(const json &sys_props) { - system_prompt = sys_props.value("prompt", ""); - name_user = sys_props.value("anti_prompt", ""); - name_assistant = sys_props.value("assistant_name", ""); - - - system_prompt_notify(); - } - - static size_t find_stopping_strings(const std::string &text, const size_t last_token_size, - const stop_type type, server_slot &slot) - { - size_t stop_pos = std::string::npos; - - for (const std::string &word : slot.params.antiprompt) - { - size_t pos; - if (type == STOP_FULL) - { - const size_t tmp = word.size() + last_token_size; - const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; - pos = text.find(word, from_pos); - } - else - { - pos = find_partial_stop_string(word, text); - } - if (pos != std::string::npos && - (stop_pos == std::string::npos || pos < stop_pos)) - { - if (type == STOP_FULL) - { - slot.stopped_word = true; - slot.stopping_word = word; - slot.has_next_token = false; - } - stop_pos = pos; - } - } - - return stop_pos; - } - - bool process_token(completion_token_output &result, server_slot &slot) { + bool process_token(completion_token_output & result, server_slot & slot) { // remember which tokens were sampled - used for repetition penalties during sampling const std::string token_str = llama_token_to_piece(ctx, result.tok); slot.sampled = result.tok; @@ -934,34 +1048,26 @@ struct llama_server_context slot.generated_text += token_str; slot.has_next_token = true; - if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) - { + if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) { // we can change penalty_prompt_tokens because it is always created from scratch each request slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok); } // check if there is incomplete UTF-8 character at the end bool incomplete = false; - for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) - { + for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) { unsigned char c = slot.generated_text[slot.generated_text.size() - i]; - if ((c & 0xC0) == 0x80) - { + if ((c & 0xC0) == 0x80) { // continuation byte: 10xxxxxx continue; } - if ((c & 0xE0) == 0xC0) - { + if ((c & 0xE0) == 0xC0) { // 2-byte character: 110xxxxx ... incomplete = i < 2; - } - else if ((c & 0xF0) == 0xE0) - { + } else if ((c & 0xF0) == 0xE0) { // 3-byte character: 1110xxxx ... incomplete = i < 3; - } - else if ((c & 0xF8) == 0xF0) - { + } else if ((c & 0xF8) == 0xF0) { // 4-byte character: 11110xxx ... incomplete = i < 4; } @@ -969,206 +1075,181 @@ struct llama_server_context break; } - if (!incomplete) - { + if (!incomplete) { size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); + const std::string str_test = slot.generated_text.substr(pos); bool is_stop_full = false; - size_t stop_pos = find_stopping_strings(str_test, token_str.size(), STOP_FULL, slot); - if (stop_pos != std::string::npos) - { + + size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_FULL); + if (stop_pos != std::string::npos) { is_stop_full = true; slot.generated_text.erase( slot.generated_text.begin() + pos + stop_pos, slot.generated_text.end()); pos = std::min(slot.n_sent_text, slot.generated_text.size()); - } - else - { + } else { is_stop_full = false; - stop_pos = find_stopping_strings(str_test, token_str.size(), STOP_PARTIAL, slot); + stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_PARTIAL); } // check if there is any token to predict - if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) - { + if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) { // no send the stop word in the response result.text_to_send = slot.generated_text.substr(pos, std::string::npos); slot.n_sent_text += result.text_to_send.size(); // add the token to slot queue and cache } + slot.add_token_string(result); - if (slot.params.stream) - { + if (slot.params.stream) { send_partial_response(slot, result); } } - if (incomplete) - { + if (incomplete) { slot.has_next_token = true; } // check the limits - if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) - { - slot.stopped_limit = true; + if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) { + slot.stopped_limit = true; slot.has_next_token = false; + + LOG_VERBOSE("stopped by limit", { + {"id_slot", slot.id}, + {"n_decoded", slot.n_decoded}, + {"n_predict", slot.params.n_predict}, + }); } - if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model)) - { - slot.stopped_eos = true; + if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model)) { + slot.stopped_eos = true; slot.has_next_token = false; + LOG_VERBOSE("eos token found", {}); } LOG_VERBOSE("next token", { - {"token", result.tok}, - {"token_text", tokens_to_output_formatted_string(ctx, result.tok)}, - {"has_next_token", slot.has_next_token}, - {"n_remain", slot.n_remaining}, - {"num_tokens_predicted", slot.n_decoded}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - }); + {"token", result.tok}, + {"token_text", tokens_to_output_formatted_string(ctx, result.tok)}, + {"has_next_token", slot.has_next_token}, + {"n_remain", slot.n_remaining}, + {"n_decoded", slot.n_decoded}, + {"stopped_eos", slot.stopped_eos}, + {"stopped_word", slot.stopped_word}, + {"stopped_limit", slot.stopped_limit}, + {"stopping_word", slot.stopping_word}, + }); return slot.has_next_token; // continue } - bool process_images(server_slot &slot) const - { - for (slot_image &img : slot.images) - { - if (!img.request_encode_image) - { - continue; - } - - if (!llava_image_embed_make_with_clip_img(clp_ctx, params.n_threads, img.img_data, &img.image_embedding, &img.image_tokens)) { - LOG_TEE("Error processing the given image"); - return false; - } - - - img.request_encode_image = false; - } - - return slot.images.size() > 0; - } - - void send_error(task_server& task, const std::string &error) - { - LOG_TEE("task %i - error: %s\n", task.id, error.c_str()); - task_result res; - res.id = task.id; - res.multitask_id = task.multitask_id; - res.stop = false; - res.error = true; - res.result_json = { { "content", error } }; - queue_results.send(res); - } - - json get_formated_generation(server_slot &slot) - { + json get_formated_generation(const server_slot & slot) const { const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model)); - const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && - eos_bias->second < 0.0f && std::isinf(eos_bias->second); + const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); + std::vector samplers_sequence; - for (const auto &sampler_type : slot.sparams.samplers_sequence) - { + samplers_sequence.reserve(slot.sparams.samplers_sequence.size()); + for (const auto & sampler_type : slot.sparams.samplers_sequence) { samplers_sequence.emplace_back(sampler_type_to_name_string(sampler_type)); } return json { - {"n_ctx", slot.n_ctx}, - {"n_predict", slot.n_predict}, - {"model", params.model_alias}, - {"seed", slot.params.seed}, - {"temperature", slot.sparams.temp}, - {"dynatemp_range", slot.sparams.dynatemp_range}, - {"dynatemp_exponent", slot.sparams.dynatemp_exponent}, - {"top_k", slot.sparams.top_k}, - {"top_p", slot.sparams.top_p}, - {"min_p", slot.sparams.min_p}, - {"tfs_z", slot.sparams.tfs_z}, - {"typical_p", slot.sparams.typical_p}, - {"repeat_last_n", slot.sparams.penalty_last_n}, - {"repeat_penalty", slot.sparams.penalty_repeat}, - {"presence_penalty", slot.sparams.penalty_present}, - {"frequency_penalty", slot.sparams.penalty_freq}, - {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens}, + {"n_ctx", slot.n_ctx}, + {"n_predict", slot.n_predict}, + {"model", params.model_alias}, + {"seed", slot.params.seed}, + {"temperature", slot.sparams.temp}, + {"dynatemp_range", slot.sparams.dynatemp_range}, + {"dynatemp_exponent", slot.sparams.dynatemp_exponent}, + {"top_k", slot.sparams.top_k}, + {"top_p", slot.sparams.top_p}, + {"min_p", slot.sparams.min_p}, + {"tfs_z", slot.sparams.tfs_z}, + {"typical_p", slot.sparams.typical_p}, + {"repeat_last_n", slot.sparams.penalty_last_n}, + {"repeat_penalty", slot.sparams.penalty_repeat}, + {"presence_penalty", slot.sparams.penalty_present}, + {"frequency_penalty", slot.sparams.penalty_freq}, + {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens}, {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens}, - {"mirostat", slot.sparams.mirostat}, - {"mirostat_tau", slot.sparams.mirostat_tau}, - {"mirostat_eta", slot.sparams.mirostat_eta}, - {"penalize_nl", slot.sparams.penalize_nl}, - {"stop", slot.params.antiprompt}, - {"n_predict", slot.params.n_predict}, - {"n_keep", params.n_keep}, - {"ignore_eos", ignore_eos}, - {"stream", slot.params.stream}, - {"logit_bias", slot.sparams.logit_bias}, - {"n_probs", slot.sparams.n_probs}, - {"min_keep", slot.sparams.min_keep}, - {"grammar", slot.sparams.grammar}, - {"samplers", samplers_sequence} + {"mirostat", slot.sparams.mirostat}, + {"mirostat_tau", slot.sparams.mirostat_tau}, + {"mirostat_eta", slot.sparams.mirostat_eta}, + {"penalize_nl", slot.sparams.penalize_nl}, + {"stop", slot.params.antiprompt}, + {"n_predict", slot.params.n_predict}, + {"n_keep", params.n_keep}, + {"ignore_eos", ignore_eos}, + {"stream", slot.params.stream}, + {"logit_bias", slot.sparams.logit_bias}, + {"n_probs", slot.sparams.n_probs}, + {"min_keep", slot.sparams.min_keep}, + {"grammar", slot.sparams.grammar}, + {"samplers", samplers_sequence} }; } - void send_partial_response(server_slot &slot, completion_token_output tkn) - { - task_result res; - res.id = slot.task_id; - res.multitask_id = slot.multitask_id; - res.error = false; - res.stop = false; + void send_error(const server_task & task, const std::string & error) { + LOG_TEE("task %i - error: %s\n", task.id, error.c_str()); - res.result_json = json - { + server_task_result res; + res.id = task.id; + res.id_multi = task.id_multi; + res.stop = false; + res.error = true; + res.data = { { "content", error } }; + + queue_results.send(res); + } + + void send_partial_response(server_slot & slot, completion_token_output tkn) { + server_task_result res; + res.id = slot.id_task; + res.id_multi = slot.id_multi; + res.error = false; + res.stop = false; + res.data = json { {"content", tkn.text_to_send}, {"stop", false}, - {"slot_id", slot.id}, - {"multimodal", multimodal} + {"id_slot", slot.id}, + {"multimodal", false} }; - if (slot.sparams.n_probs > 0) - { - std::vector probs_output = {}; + if (slot.sparams.n_probs > 0) { const std::vector to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false); - size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size()); - size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size()); - if (probs_pos < probs_stop_pos) - { - probs_output = std::vector(slot.generated_token_probs.begin() + probs_pos, slot.generated_token_probs.begin() + probs_stop_pos); + const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size()); + const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size()); + + std::vector probs_output; + if (probs_pos < probs_stop_pos) { + probs_output = std::vector( + slot.generated_token_probs.begin() + probs_pos, + slot.generated_token_probs.begin() + probs_stop_pos); } slot.n_sent_token_probs = probs_stop_pos; - res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs_output); + + res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output); } - if (slot.oaicompat) - { - res.result_json["oaicompat_token_ctr"] = slot.n_decoded; - res.result_json["model"] = slot.oaicompat_model; + if (slot.oaicompat) { + res.data["oaicompat_token_ctr"] = slot.n_decoded; + res.data["model"] = slot.oaicompat_model; } queue_results.send(res); } - void send_final_response(server_slot &slot) - { - task_result res; - res.id = slot.task_id; - res.multitask_id = slot.multitask_id; - res.error = false; - res.stop = true; - - res.result_json = json - { + void send_final_response(const server_slot & slot) { + server_task_result res; + res.id = slot.id_task; + res.id_multi = slot.id_multi; + res.error = false; + res.stop = true; + res.data = json { {"content", !slot.params.stream ? slot.generated_text : ""}, - {"slot_id", slot.id}, + {"id_slot", slot.id}, {"stop", true}, {"model", params.model_alias}, {"tokens_predicted", slot.n_decoded}, @@ -1184,96 +1265,87 @@ struct llama_server_context {"timings", slot.get_formated_timings()} }; - if (slot.sparams.n_probs > 0) - { - std::vector probs = {}; - if (!slot.params.stream && slot.stopped_word) - { + if (slot.sparams.n_probs > 0) { + std::vector probs; + if (!slot.params.stream && slot.stopped_word) { const std::vector stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false); - probs = std::vector(slot.generated_token_probs.begin(), slot.generated_token_probs.end() - stop_word_toks.size()); - } - else - { + probs = std::vector( - slot.generated_token_probs.begin(), - slot.generated_token_probs.end()); + slot.generated_token_probs.begin(), + slot.generated_token_probs.end() - stop_word_toks.size()); + } else { + probs = std::vector( + slot.generated_token_probs.begin(), + slot.generated_token_probs.end()); } - res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs); + + res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs); } - if (slot.oaicompat) - { - res.result_json["oaicompat_token_ctr"] = slot.n_decoded; - res.result_json["model"] = slot.oaicompat_model; + if (slot.oaicompat) { + res.data["oaicompat_token_ctr"] = slot.n_decoded; + res.data["model"] = slot.oaicompat_model; } queue_results.send(res); } - void send_embedding(server_slot & slot, const llama_batch & batch) - { - task_result res; - res.id = slot.task_id; - res.multitask_id = slot.multitask_id; - res.error = false; - res.stop = true; + void send_embedding(const server_slot & slot, const llama_batch & batch) { + server_task_result res; + res.id = slot.id_task; + res.id_multi = slot.id_multi; + res.error = false; + res.stop = true; const int n_embd = llama_n_embd(model); - if (!params.embedding) - { - LOG_WARNING("embedding disabled", {{"params.embedding", params.embedding}}); - res.result_json = json - { - {"embedding", std::vector(n_embd, 0.0f)}, + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; + } + + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } + + if (embd == NULL) { + LOG_ERROR("failed to get embeddings", { + {"token", batch.token [i]}, + {"seq_id", batch.seq_id[i][0]} + }); + + res.data = json { + {"embedding", std::vector(n_embd, 0.0f)}, + }; + + continue; + } + + res.data = json { + {"embedding", std::vector(embd, embd + n_embd)}, }; } - else - { - for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { - continue; - } - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); - if (embd == NULL) { - LOG_ERROR("failed to get embeddings for token", {{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}}); - res.result_json = json - { - {"embedding", std::vector(n_embd, 0.0f)}, - }; - continue; - } - } - - res.result_json = json - { - {"embedding", std::vector(embd, embd + n_embd)}, - }; - } - } queue_results.send(res); } - void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id) - { - task_server task; - task.id = task_id; - task.target_id = 0; - task.data = std::move(data); - task.infill_mode = infill; - task.embedding_mode = embedding; - task.type = TASK_TYPE_COMPLETION; - task.multitask_id = multitask_id; + void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding) { + server_task task; + task.id = id_task; + task.id_multi = id_multi; + task.id_target = 0; + task.data = std::move(data); + task.infill = infill; + task.embedding = embedding; + task.type = SERVER_TASK_TYPE_COMPLETION; // when a completion task's prompt array is not a singleton, we split it into multiple requests // otherwise, it's a single-prompt task, we actually queue it // if there's numbers in the prompt array it will be treated as an array of tokens if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) { bool numbers = false; - for (const auto& e : task.data.at("prompt")) { + for (const auto & e : task.data.at("prompt")) { if (e.is_number()) { numbers = true; break; @@ -1288,106 +1360,23 @@ struct llama_server_context if (numbers) { queue_tasks.post(task); } else { - split_multiprompt_task(task_id, task); + split_multiprompt_task(id_task, task); } } else { - // an empty prompt can make slot become buggy - if (task.data.contains("prompt") && task.data["prompt"].is_string() && task.data["prompt"].get().empty()) { - task.data["prompt"] = " "; // add a space so that we have one token - } queue_tasks.post(task); } } - // for multiple images processing - bool ingest_images(server_slot &slot, int n_batch) - { - int image_idx = 0; + void request_cancel(int id_task) { + server_task task; + task.type = SERVER_TASK_TYPE_CANCEL; + task.id_target = id_task; - while (image_idx < (int) slot.images.size()) - { - slot_image &img = slot.images[image_idx]; - - // process prefix prompt - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) - { - const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - 0, 0, 0, // unused - }; - if (llama_decode(ctx, batch_view)) - { - LOG_TEE("%s : failed to eval\n", __func__); - return false; - } - } - - // process image with llm - for (int i = 0; i < img.image_tokens; i += n_batch) - { - int n_eval = img.image_tokens - i; - if (n_eval > n_batch) - { - n_eval = n_batch; - } - - const int n_embd = llama_n_embd(model); - llama_batch batch_img = { - n_eval, - nullptr, - (img.image_embedding + i * n_embd), - nullptr, - nullptr, - nullptr, - nullptr, - slot.n_past, - 1, 0 - }; - if (llama_decode(ctx, batch_img)) - { - LOG_TEE("%s : failed to eval image\n", __func__); - return false; - } - slot.n_past += n_eval; - } - image_idx++; - - llama_batch_clear(batch); - - // append prefix of next image - const auto json_prompt = (image_idx >= (int) slot.images.size()) ? - slot.params.input_suffix : // no more images, then process suffix prompt - (json)(slot.images[image_idx].prefix_prompt); - - std::vector append_tokens = tokenize(json_prompt, false); // has next image - for (int i = 0; i < (int) append_tokens.size(); ++i) - { - llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id }, true); - slot.n_past += 1; - } - } - - return true; - } - - void request_cancel(int task_id) - { - task_server task; - task.type = TASK_TYPE_CANCEL; - task.target_id = task_id; queue_tasks.post(task); } - void split_multiprompt_task(int multitask_id, task_server& multiprompt_task) - { - int prompt_count = multiprompt_task.data.at("prompt").size(); + void split_multiprompt_task(int id_multi, const server_task & multiprompt_task) { + const int prompt_count = multiprompt_task.data.at("prompt").size(); if (prompt_count <= 1) { send_error(multiprompt_task, "error while handling multiple prompts"); return; @@ -1395,127 +1384,121 @@ struct llama_server_context // generate all the ID for subtask std::vector subtask_ids(prompt_count); - for (int i = 0; i < prompt_count; i++) - { + for (int i = 0; i < prompt_count; i++) { subtask_ids[i] = queue_tasks.get_new_id(); } // queue up the multitask so we can track its subtask progression - queue_tasks.add_multitask(multitask_id, subtask_ids); + queue_tasks.add_multitask(id_multi, subtask_ids); // add subtasks - for (int i = 0; i < prompt_count; i++) - { + for (int i = 0; i < prompt_count; i++) { json subtask_data = multiprompt_task.data; subtask_data["prompt"] = subtask_data["prompt"][i]; // subtasks inherit everything else (infill mode, embedding mode, etc.) - request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id); + request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill, multiprompt_task.embedding); } } - void process_single_task(task_server& task) - { - switch (task.type) - { - case TASK_TYPE_COMPLETION: { - server_slot *slot = get_slot(json_value(task.data, "slot_id", -1)); - if (slot == nullptr) + void process_single_task(const server_task & task) { + switch (task.type) { + case SERVER_TASK_TYPE_COMPLETION: { - // if no slot is available, we defer this task for processing later - LOG_VERBOSE("no slot is available", {{"task_id", task.id}}); - queue_tasks.defer(task); - break; - } - - if (task.data.contains("system_prompt")) - { - if (!all_slots_are_idle) { - send_error(task, "system prompt can only be updated when all slots are idle"); + server_slot * slot = get_slot(json_value(task.data, "id_slot", -1)); + if (slot == nullptr) { + // if no slot is available, we defer this task for processing later + LOG_VERBOSE("no slot is available", {{"id_task", task.id}}); + queue_tasks.defer(task); break; } - system_prompt_process(task.data["system_prompt"]); - // reset cache_tokens for all slots - for (server_slot &slot : slots) - { - slot.cache_tokens.clear(); - slot.n_past = 0; - slot.n_past_se = 0; + if (task.data.contains("system_prompt")) { + system_prompt_set(task.data["system_prompt"]); + + for (server_slot & slot : slots) { + slot.n_past = 0; + slot.n_past_se = 0; + } } - } - slot->reset(); + slot->reset(); - slot->infill = task.infill_mode; - slot->embedding = task.embedding_mode; - slot->task_id = task.id; - slot->multitask_id = task.multitask_id; + slot->id_task = task.id; + slot->id_multi = task.id_multi; + slot->infill = task.infill; + slot->embedding = task.embedding; - if (!launch_slot_with_data(slot, task.data)) - { - // send error result - send_error(task, "internal_error"); - break; - } - } break; - case TASK_TYPE_CANCEL: { // release slot linked with the task id - for (auto & slot : slots) - { - if (slot.task_id == task.target_id) - { - slot.release(); + if (!launch_slot_with_data(*slot, task.data)) { + // send error result + send_error(task, "internal_error"); break; } - } - } break; - case TASK_TYPE_NEXT_RESPONSE: { - // do nothing - } break; - case TASK_TYPE_METRICS: { - json slots_data = json::array(); - int n_idle_slots = 0; - int n_processing_slots = 0; - - for (server_slot &slot: slots) { - json slot_data = get_formated_generation(slot); - slot_data["id"] = slot.id; - slot_data["task_id"] = slot.task_id; - slot_data["state"] = slot.state; - slot_data["prompt"] = slot.prompt; - slot_data["next_token"] = { - {"has_next_token", slot.has_next_token}, - {"n_remain", slot.n_remaining}, - {"num_tokens_predicted", slot.n_decoded}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - }; - if (slot_data["state"] == IDLE) { - n_idle_slots++; - } else { - n_processing_slots++; + } break; + case SERVER_TASK_TYPE_CANCEL: + { + // release slot linked with the task id + for (auto & slot : slots) { + if (slot.id_task == task.id_target) { + slot.release(); + break; + } } - slots_data.push_back(slot_data); - } - LOG_INFO("slot data", { - {"task_id", task.id}, - {"n_idle_slots", n_idle_slots}, - {"n_processing_slots", n_processing_slots} - }); - LOG_VERBOSE("slot data", { - {"task_id", task.id}, - {"n_idle_slots", n_idle_slots}, - {"n_processing_slots", n_processing_slots}, - {"slots", slots_data} - }); - task_result res; - res.id = task.id; - res.multitask_id = task.multitask_id; - res.stop = true; - res.error = false; - res.result_json = { + } break; + case SERVER_TASK_TYPE_NEXT_RESPONSE: + { + // do nothing + } break; + case SERVER_TASK_TYPE_METRICS: + { + json slots_data = json::array(); + + int n_idle_slots = 0; + int n_processing_slots = 0; + + for (server_slot & slot : slots) { + json slot_data = get_formated_generation(slot); + slot_data["id"] = slot.id; + slot_data["id_task"] = slot.id_task; + slot_data["state"] = slot.state; + slot_data["prompt"] = slot.prompt; + slot_data["next_token"] = { + {"has_next_token", slot.has_next_token}, + {"n_remain", slot.n_remaining}, + {"n_decoded", slot.n_decoded}, + {"stopped_eos", slot.stopped_eos}, + {"stopped_word", slot.stopped_word}, + {"stopped_limit", slot.stopped_limit}, + {"stopping_word", slot.stopping_word}, + }; + + if (slot_data["state"] == SLOT_STATE_IDLE) { + n_idle_slots++; + } else { + n_processing_slots++; + } + + slots_data.push_back(slot_data); + } + LOG_INFO("slot data", { + {"id_task", task.id}, + {"n_idle_slots", n_idle_slots}, + {"n_processing_slots", n_processing_slots} + }); + + LOG_VERBOSE("slot data", { + {"id_task", task.id}, + {"n_idle_slots", n_idle_slots}, + {"n_processing_slots", n_processing_slots}, + {"slots", slots_data} + }); + + server_task_result res; + res.id = task.id; + res.id_multi = task.id_multi; + res.stop = true; + res.error = false; + res.data = { { "idle", n_idle_slots }, { "processing", n_processing_slots }, { "deferred", queue_tasks.queue_tasks_deferred.size() }, @@ -1532,71 +1515,104 @@ struct llama_server_context { "kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)}, { "slots", slots_data }, - }; - metrics.reset_bucket(); - queue_results.send(res); - } break; + }; + + metrics.reset_bucket(); + queue_results.send(res); + } break; } } - void on_finish_multitask(task_multi& multitask) - { + void on_finish_multitask(const server_task_multi & multitask) { // all subtasks done == multitask is done - task_result result; - result.id = multitask.id; - result.stop = true; + server_task_result result; + result.id = multitask.id; + result.stop = true; result.error = false; // collect json results into one json result std::vector result_jsons; - for (auto& subres : multitask.results) - { - result_jsons.push_back(subres.result_json); + for (const auto & subres : multitask.results) { + result_jsons.push_back(subres.data); result.error = result.error && subres.error; } - result.result_json = json{ { "results", result_jsons } }; + result.data = json { + { "results", result_jsons } + }; + queue_results.send(result); } bool update_slots() { - if (system_need_update) - { - LOG_INFO("updating system prompt", {}); + if (system_need_update) { system_prompt_update(); } - llama_batch_clear(batch); + // release slots + for (auto & slot : slots) { + if (slot.command == SLOT_COMMAND_RELEASE) { + slot.state = SLOT_STATE_IDLE; + slot.command = SLOT_COMMAND_NONE; + slot.t_last_used = ggml_time_us(); - if (all_slots_are_idle) - { - if (system_prompt.empty() && clean_kv_cache) - { - LOG_INFO("all slots are idle and system prompt is empty, clear the KV cache", {}); - kv_cache_clear(); + LOG_INFO("slot released", { + {"id_slot", slot.id}, + {"id_task", slot.id_task}, + {"n_ctx", n_ctx}, + {"n_past", slot.n_past}, + {"n_system_tokens", system_tokens.size()}, + {"n_cache_tokens", slot.cache_tokens.size()}, + {"truncated", slot.truncated} + }); + + queue_tasks.notify_slot_changed(); } - return true; } - LOG_VERBOSE("posting NEXT_RESPONSE", {}); - task_server task; - task.type = TASK_TYPE_NEXT_RESPONSE; - task.target_id = -1; - queue_tasks.post(task); - - for (server_slot &slot : slots) + // check if all slots are idle { - if (slot.ga_n == 1) - { - if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx) - { + bool all_idle = true; + + for (auto & slot : slots) { + if (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_NONE) { + all_idle = false; + break; + } + } + + if (all_idle) { + LOG_INFO("all slots are idle", {}); + if (system_prompt.empty() && clean_kv_cache) { + kv_cache_clear(); + } + + return true; + } + } + + { + LOG_VERBOSE("posting NEXT_RESPONSE", {}); + + server_task task; + task.type = SERVER_TASK_TYPE_NEXT_RESPONSE; + task.id_target = -1; + + queue_tasks.post(task); + } + + // apply context-shift if needed + // TODO: simplify and improve + for (server_slot & slot : slots) { + if (slot.ga_n == 1) { + if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) { // Shift context const int n_keep = slot.params.n_keep + add_bos_token; const int n_left = (int) system_tokens.size() + slot.n_past - n_keep; const int n_discard = n_left / 2; LOG_INFO("slot context shift", { - {"slot_id", slot.id}, - {"task_id", slot.task_id}, + {"id_slot", slot.id}, + {"id_task", slot.id_task}, {"n_keep", n_keep}, {"n_left", n_left}, {"n_discard", n_discard}, @@ -1605,15 +1621,17 @@ struct llama_server_context {"n_system_tokens", system_tokens.size()}, {"n_cache_tokens", slot.cache_tokens.size()} }); + llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard); llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); - for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) - { - slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; - } + if (slot.params.cache_prompt) { + for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { + slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; + } - slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); + slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); + } slot.n_past -= n_discard; @@ -1622,33 +1640,12 @@ struct llama_server_context } } - // decode any currently ongoing sequences - LOG_VERBOSE("decoding ongoing sequences", {}); - for (auto & slot : slots) - { - // release the slot - if (slot.command == RELEASE) - { - slot.state = IDLE; - slot.command = NONE; - slot.t_last_used = ggml_time_us(); + // start populating the batch for this iteration + llama_batch_clear(batch); - LOG_INFO("slot released", { - {"slot_id", slot.id}, - {"task_id", slot.task_id}, - {"n_ctx", n_ctx}, - {"n_past", slot.n_past}, - {"n_system_tokens", system_tokens.size()}, - {"n_cache_tokens", slot.cache_tokens.size()}, - {"truncated", slot.truncated} - }); - queue_tasks.notify_slot_changed(); - - continue; - } - - if (slot.state == IDLE) - { + // frist, add sampled tokens from any ongoing sequences + for (auto & slot : slots) { + if (slot.state == SLOT_STATE_IDLE) { continue; } @@ -1659,193 +1656,184 @@ struct llama_server_context // TODO: we always have to take into account the "system_tokens" // this is not great and needs to be improved somehow llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true); + slot.n_past += 1; + + if (slot.params.cache_prompt) { + slot.cache_tokens.push_back(slot.sampled); + } + + LOG_VERBOSE("slot decode token", { + {"id_slot", slot.id}, + {"id_task", slot.id_task}, + {"n_ctx", n_ctx}, + {"n_past", slot.n_past}, + {"n_system_tokens", system_tokens.size()}, + {"n_cache_tokens", slot.cache_tokens.size()}, + {"truncated", slot.truncated} + }); } // process in chunks of params.n_batch int32_t n_batch = params.n_batch; - // assign workload to the slots - if (params.cont_batching || batch.n_tokens == 0) - { - for (auto & slot : slots) - { - const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get().empty()) || !slot.images.empty(); + // next, batch any pending prompts without exceeding n_batch + if (params.cont_batching || batch.n_tokens == 0) { + for (auto & slot : slots) { + const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get().empty()); // empty prompt passed -> release the slot and send empty response // note: infill mode allows empty prompt - if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt && !slot.infill) - { + if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT && !has_prompt && !slot.infill) { + slot.state = SLOT_STATE_PROCESSING; + slot.command = SLOT_COMMAND_NONE; slot.release(); slot.print_timings(); send_final_response(slot); continue; } - // need process the prompt - if (slot.state == IDLE && slot.command == LOAD_PROMPT) - { - slot.state = PROCESSING; - slot.command = NONE; - std::vector prompt_tokens; - slot.t_start_process_prompt = ggml_time_us(); - slot.t_start_genereration = 0; + // this slot still has a prompt to be processed + if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) { + auto & prompt_tokens = slot.prompt_tokens; - if (slot.infill) - { - bool suff_rm_leading_spc = true; - if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) - { - params.input_suffix.erase(0, 1); - suff_rm_leading_spc = false; - } - auto prefix_tokens = tokenize(slot.params.input_prefix, false); - auto suffix_tokens = tokenize(slot.params.input_suffix, false); - - const int space_token = 29871; // TODO: this should not be hardcoded - if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) { - suffix_tokens.erase(suffix_tokens.begin()); - } - - prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model)); - prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS - prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model)); - prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); - prefix_tokens.push_back(llama_token_middle(model)); - prompt_tokens = prefix_tokens; - } - else - { - prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt - } - - slot.n_prompt_tokens = prompt_tokens.size(); - - if (slot.params.n_keep < 0) - { - slot.params.n_keep = slot.n_prompt_tokens; - } - slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); - - // if input prompt is too big, truncate it, if group attention self-extend is disabled - if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) - { - const int n_left = slot.n_ctx - slot.params.n_keep; - const int n_block_size = n_left / 2; - const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; - - std::vector new_tokens( - prompt_tokens.begin(), - prompt_tokens.begin() + slot.params.n_keep); - new_tokens.insert( - new_tokens.end(), - prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, - prompt_tokens.end()); - - LOG_VERBOSE("input truncated", { - {"n_ctx", slot.n_ctx}, - {"n_keep", slot.params.n_keep}, - {"n_left", n_left}, - {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, + // we haven't tokenized the prompt yet - do it now: + if (prompt_tokens.empty()) { + LOG_VERBOSE("tokenizing prompt", { + {"id_slot", slot.id}, + {"id_task", slot.id_task} }); - slot.truncated = true; - prompt_tokens = new_tokens; - slot.n_prompt_tokens = prompt_tokens.size(); - GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); - } + slot.t_start_process_prompt = ggml_time_us(); + slot.t_start_generation = 0; - if (!slot.params.cache_prompt) - { - llama_sampling_reset(slot.ctx_sampling); - - slot.n_past = 0; - slot.n_past_se = 0; - slot.ga_i = 0; - slot.n_prompt_tokens_processed = slot.n_prompt_tokens; - } - else - { - // push the prompt into the sampling context (do not apply grammar) - for (auto &token : prompt_tokens) - { - llama_sampling_accept(slot.ctx_sampling, ctx, token, false); - } - - slot.n_past = common_part(slot.cache_tokens, prompt_tokens); - - // the last token of the cache is not in the KV cache until the next call to llama_decode - // (it was sampled, pushed into the "cache_tokens", but not yet put in the context) - if (slot.n_past > 0 && slot.n_past == (int32_t) slot.cache_tokens.size()) - { - slot.n_past -= 1; - } - - slot.n_prompt_tokens_processed = slot.n_prompt_tokens - slot.n_past; - - if (slot.ga_n != 1) - { - int ga_i = 0; - int32_t ga_n = slot.ga_n; - int32_t ga_w = slot.ga_w; - int32_t slot_npast = 0; - for (int k = 0; k < slot.n_past; ++k) - { - while (slot_npast >= ga_i + ga_w) { - const int bd = (ga_w/ga_n)*(ga_n - 1); - slot_npast -= bd; - ga_i += ga_w/ga_n; - } - slot_npast++; + if (slot.infill) { + bool suff_rm_leading_spc = true; + if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) { + params.input_suffix.erase(0, 1); + suff_rm_leading_spc = false; } - slot.n_past_se = slot_npast; - slot.ga_i = ga_i; + + auto prefix_tokens = tokenize(slot.params.input_prefix, false); + auto suffix_tokens = tokenize(slot.params.input_suffix, false); + + const int space_token = 29871; // TODO: this should not be hardcoded + if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) { + suffix_tokens.erase(suffix_tokens.begin()); + } + + prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model)); + prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS + prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model)); + prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); + prefix_tokens.push_back(llama_token_middle(model)); + prompt_tokens = prefix_tokens; + } else { + prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt } - LOG_INFO("slot progression", { - { "slot_id", slot.id }, - { "task_id", slot.task_id }, - { "n_past", slot.n_past }, - { "n_past_se", slot.n_past_se }, - { "ga_i", slot.ga_i }, - { "n_prompt_tokens_processed", slot.n_prompt_tokens_processed } - }); + slot.n_past = 0; + slot.n_prompt_tokens = prompt_tokens.size(); + + if (slot.embedding) { + // this prompt is too large to process - discard it + if (slot.n_prompt_tokens > n_batch) { + slot.state = SLOT_STATE_PROCESSING; + slot.command = SLOT_COMMAND_NONE; + slot.release(); + slot.print_timings(); + send_final_response(slot); + continue; + } + } else { + if (slot.params.n_keep < 0) { + slot.params.n_keep = slot.n_prompt_tokens; + } + slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); + + // if input prompt is too big, truncate it (if group attention self-extend is disabled) + if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) { + const int n_left = slot.n_ctx - slot.params.n_keep; + + const int n_block_size = n_left / 2; + const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; + + std::vector new_tokens( + prompt_tokens.begin(), + prompt_tokens.begin() + slot.params.n_keep); + + new_tokens.insert( + new_tokens.end(), + prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, + prompt_tokens.end()); + + prompt_tokens = std::move(new_tokens); + + slot.truncated = true; + slot.n_prompt_tokens = prompt_tokens.size(); + + LOG_VERBOSE("input truncated", { + {"n_ctx", slot.n_ctx}, + {"n_keep", slot.params.n_keep}, + {"n_left", n_left}, + {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())}, + }); + + GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); + } + + llama_sampling_reset(slot.ctx_sampling); + + if (!slot.params.cache_prompt) { + slot.n_past_se = 0; + slot.ga_i = 0; + } else { + GGML_ASSERT(slot.ga_n == 1); + + // reuse any previously computed tokens that are common with the new prompt + slot.n_past = common_part(slot.cache_tokens, prompt_tokens); + + // remove the non-common part from the cache + slot.cache_tokens.resize(slot.n_past); + + // push the prompt into the sampling context (do not apply grammar) + for (int i = 0; i < slot.n_past; ++i) { + llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false); + } + } + } + + if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) { + // we have to evaluate at least 1 token to generate logits. + LOG_INFO("we have to evaluate at least 1 token to generate logits", { + { "id_slot", slot.id }, + { "id_task", slot.id_task } + }); + + slot.n_past--; + if (slot.ga_i > 0) { + slot.n_past_se--; + } + } + + slot.n_prompt_tokens_processed = 0; } - slot.cache_tokens = prompt_tokens; - - if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) - { - // we have to evaluate at least 1 token to generate logits. - LOG_INFO("we have to evaluate at least 1 token to generate logits", { - { "slot_id", slot.id }, - { "task_id", slot.task_id } - }); - slot.n_past--; - if (slot.ga_i > 0) - { - slot.n_past_se--; + if (slot.embedding) { + // cannot fit the prompt in the current batch - will try next iter + if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { + continue; } } - int p0 = (int) system_tokens.size() + slot.n_past; - LOG_INFO("kv cache rm [p0, end)", { - { "slot_id", slot.id }, - { "task_id", slot.task_id }, - { "p0", p0 } - }); + const int p0 = (int) system_tokens.size() + slot.n_past; llama_kv_cache_seq_rm(ctx, slot.id, p0, -1); - LOG_VERBOSE("prompt ingested", { - {"n_past", slot.n_past}, - {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)}, - {"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())}, - }); - - const bool has_images = process_images(slot); - - // process the prefix of first image - std::vector prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens; + LOG_INFO("kv cache rm [p0, end)", { + { "id_slot", slot.id }, + { "id_task", slot.id_task }, + { "p0", p0 } + }); int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past; @@ -1853,61 +1841,82 @@ struct llama_server_context int32_t ga_n = slot.ga_n; int32_t ga_w = slot.ga_w; - for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past) - { - if (slot.ga_n != 1) - { + // add prompt tokens for processing in the current batch + // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow + for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch; ++slot.n_past) { + if (slot.ga_n != 1) { while (slot_npast >= ga_i + ga_w) { const int bd = (ga_w/ga_n)*(ga_n - 1); slot_npast -= bd; ga_i += ga_w/ga_n; } } - llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false); + + llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false); + + if (slot.params.cache_prompt) { + slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); + } + + slot.n_prompt_tokens_processed++; slot_npast++; } - if (has_images && !ingest_images(slot, n_batch)) - { - LOG_ERROR("failed processing images", { - {"slot_id", slot.id}, - {"task_id", slot.task_id}, - }); - // FIXME @phymbert: to be properly tested - // early returning without changing the slot state will block the slot for ever - // no one at the moment is checking the return value - return false; - } + LOG_VERBOSE("prompt processing progress", { + {"id_slot", slot.id}, + {"n_past", slot.n_past}, + {"n_ctx", n_ctx}, + {"n_tokens", batch.n_tokens}, + {"progress", (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens}, + }); - // extract the logits only for the last token - if (batch.n_tokens > 0) - { + // entire prompt has been processed - start decoding new tokens + if (slot.n_past == slot.n_prompt_tokens) { + slot.state = SLOT_STATE_PROCESSING; + slot.command = SLOT_COMMAND_NONE; + + GGML_ASSERT(batch.n_tokens > 0); + + // extract the logits only for the last token batch.logits[batch.n_tokens - 1] = true; - } - slot.n_decoded = 0; - slot.i_batch = batch.n_tokens - 1; + slot.n_decoded = 0; + slot.i_batch = batch.n_tokens - 1; + + LOG_VERBOSE("prompt done", { + {"id_slot", slot.id}, + {"n_past", slot.n_past}, + {"n_ctx", n_ctx}, + {"n_tokens", batch.n_tokens}, + }); + } + } + + if (batch.n_tokens >= n_batch) { + break; } } } - if (batch.n_tokens == 0) - { - all_slots_are_idle = true; + if (batch.n_tokens == 0) { + LOG_VERBOSE("no tokens to decode", {}); + return true; } - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) - { + LOG_VERBOSE("decoding batch", { + {"n_tokens", batch.n_tokens}, + }); + + // process the created batch of tokens + for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); - for (auto & slot : slots) - { - if (slot.ga_n != 1) - { + for (auto & slot : slots) { + if (slot.ga_n != 1) { // context extension via Self-Extend - while (slot.n_past_se >= slot.ga_i + slot.ga_w) - { + // TODO: simplify and/or abstract this + while (slot.n_past_se >= slot.ga_i + slot.ga_w) { const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w; const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1); const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w; @@ -1918,8 +1927,8 @@ struct llama_server_context LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd); llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd); - llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n); - llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd); + llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n); + llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd); slot.n_past_se -= bd; @@ -1927,12 +1936,12 @@ struct llama_server_context LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i); } + slot.n_past_se += n_tokens; } } - llama_batch batch_view = - { + llama_batch batch_view = { n_tokens, batch.token + i, nullptr, @@ -1945,10 +1954,8 @@ struct llama_server_context const int ret = llama_decode(ctx, batch_view); - if (ret != 0) - { - if (n_batch == 1 || ret < 0) - { + if (ret != 0) { + if (n_batch == 1 || ret < 0) { // if you get here, it means the KV cache is full - try increasing it via the context size LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret); return false; @@ -1959,19 +1966,17 @@ struct llama_server_context // retry with half the batch size to try to find a free slot in the KV cache n_batch /= 2; i -= n_batch; + continue; } - for (auto & slot : slots) - { - if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) - { + for (auto & slot : slots) { + if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { continue; } // prompt evaluated for embedding - if (slot.embedding) - { + if (slot.embedding) { send_embedding(slot, batch_view); slot.release(); slot.i_batch = -1; @@ -1984,10 +1989,9 @@ struct llama_server_context llama_sampling_accept(slot.ctx_sampling, ctx, id, true); slot.n_decoded += 1; - if (slot.n_decoded == 1) - { - slot.t_start_genereration = ggml_time_us(); - slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3; + if (slot.n_decoded == 1) { + slot.t_start_generation = ggml_time_us(); + slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; metrics.on_prompt_eval(slot); } @@ -1995,19 +1999,19 @@ struct llama_server_context result.tok = id; const int32_t n_probs = slot.sparams.n_probs; - if (slot.sparams.temp <= 0 && n_probs > 0) - { + if (slot.sparams.temp <= 0 && n_probs > 0) { // for llama_sample_token_greedy we need to sort candidates llama_sample_softmax(ctx, &cur_p); } - for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i) - { - result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p}); + for (size_t i = 0; i < std::min(cur_p.size, (size_t) n_probs); ++i) { + result.probs.push_back({ + cur_p.data[i].id, + cur_p.data[i].p + }); } - if (!process_token(result, slot)) - { + if (!process_token(result, slot)) { slot.release(); slot.print_timings(); send_final_response(slot); @@ -2019,24 +2023,23 @@ struct llama_server_context } LOG_VERBOSE("slots updated", {}); + return true; } - json model_meta() { - return json{ - {"vocab_type", llama_vocab_type(model)}, - {"n_vocab", llama_n_vocab(model)}, - {"n_ctx_train", llama_n_ctx_train(model)}, - {"n_embd", llama_n_embd(model)}, - {"n_params", llama_model_n_params(model)}, - {"size", llama_model_size(model)}, + json model_meta() const { + return json { + {"vocab_type", llama_vocab_type (model)}, + {"n_vocab", llama_n_vocab (model)}, + {"n_ctx_train", llama_n_ctx_train (model)}, + {"n_embd", llama_n_embd (model)}, + {"n_params", llama_model_n_params(model)}, + {"size", llama_model_size (model)}, }; } }; -static void server_print_usage(const char *argv0, const gpt_params ¶ms, - const server_params &sparams) -{ +static void server_print_usage(const char * argv0, const gpt_params & params, const server_params & sparams) { printf("usage: %s [options]\n", argv0); printf("\n"); printf("options:\n"); @@ -2054,17 +2057,14 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n"); printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow); printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast); - printf(" --pooling {none,mean,cls}\n"); - printf(" pooling type for embeddings, use model default if unspecified\n"); + printf(" --pooling {none,mean,cls} pooling type for embeddings, use model default if unspecified\n"); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); - if (llama_supports_mlock()) - { + if (llama_supports_mlock()) { printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n"); } - if (llama_supports_mmap()) - { + if (llama_supports_mmap()) { printf(" --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n"); } printf(" --numa TYPE attempt optimizations that help on some NUMA systems\n"); @@ -2096,7 +2096,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n"); printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n"); printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); - printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled"); + printf(" --embeddings enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled"); printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel); printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); printf(" -spf FNAME, --system-prompt-file FNAME\n"); @@ -2105,7 +2105,6 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" KV cache data type for K (default: f16)\n"); printf(" -ctv TYPE, --cache-type-v TYPE\n"); printf(" KV cache data type for V (default: f16)\n"); - printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA.\n"); printf(" --log-format log output format: json or text (default: json)\n"); printf(" --log-disable disables logging to a file.\n"); printf(" --slots-endpoint-disable disables slots monitoring endpoint.\n"); @@ -2123,57 +2122,41 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf("\n"); } -static void server_params_parse(int argc, char **argv, server_params &sparams, - gpt_params ¶ms, llama_server_context& llama) -{ - gpt_params default_params; +static void server_params_parse(int argc, char ** argv, server_params & sparams, gpt_params & params) { + gpt_params default_params; server_params default_sparams; + std::string arg; bool invalid_param = false; - for (int i = 1; i < argc; i++) - { + for (int i = 1; i < argc; i++) { arg = argv[i]; - if (arg == "--port") - { - if (++i >= argc) - { + if (arg == "--port") { + if (++i >= argc) { invalid_param = true; break; } sparams.port = std::stoi(argv[i]); - } - else if (arg == "--host") - { - if (++i >= argc) - { + } else if (arg == "--host") { + if (++i >= argc) { invalid_param = true; break; } sparams.hostname = argv[i]; - } - else if (arg == "--path") - { - if (++i >= argc) - { + } else if (arg == "--path") { + if (++i >= argc) { invalid_param = true; break; } sparams.public_path = argv[i]; - } - else if (arg == "--api-key") - { - if (++i >= argc) - { + } else if (arg == "--api-key") { + if (++i >= argc) { invalid_param = true; break; } sparams.api_keys.emplace_back(argv[i]); - } - else if (arg == "--api-key-file") - { - if (++i >= argc) - { + } else if (arg == "--api-key-file") { + if (++i >= argc) { invalid_param = true; break; } @@ -2190,53 +2173,36 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } } key_file.close(); - } - else if (arg == "--timeout" || arg == "-to") - { - if (++i >= argc) - { + } else if (arg == "--timeout" || arg == "-to") { + if (++i >= argc) { invalid_param = true; break; } sparams.read_timeout = std::stoi(argv[i]); sparams.write_timeout = std::stoi(argv[i]); - } - else if (arg == "-m" || arg == "--model") - { - if (++i >= argc) - { + } else if (arg == "-m" || arg == "--model") { + if (++i >= argc) { invalid_param = true; break; } params.model = argv[i]; - } - else if (arg == "-a" || arg == "--alias") - { - if (++i >= argc) - { + } else if (arg == "-a" || arg == "--alias") { + if (++i >= argc) { invalid_param = true; break; } params.model_alias = argv[i]; - } - else if (arg == "-h" || arg == "--help") - { + } else if (arg == "-h" || arg == "--help") { server_print_usage(argv[0], default_params, default_sparams); exit(0); - } - else if (arg == "-c" || arg == "--ctx-size" || arg == "--ctx_size") - { - if (++i >= argc) - { + } else if (arg == "-c" || arg == "--ctx-size" || arg == "--ctx_size") { + if (++i >= argc) { invalid_param = true; break; } params.n_ctx = std::stoi(argv[i]); - } - else if (arg == "--rope-scaling") - { - if (++i >= argc) - { + } else if (arg == "--rope-scaling") { + if (++i >= argc) { invalid_param = true; break; } @@ -2245,59 +2211,44 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; } else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; } else { invalid_param = true; break; } - } - else if (arg == "--rope-freq-base") - { - if (++i >= argc) - { + } else if (arg == "--rope-freq-base") { + if (++i >= argc) { invalid_param = true; break; } params.rope_freq_base = std::stof(argv[i]); - } - else if (arg == "--rope-freq-scale") - { - if (++i >= argc) - { + } else if (arg == "--rope-freq-scale") { + if (++i >= argc) { invalid_param = true; break; } params.rope_freq_scale = std::stof(argv[i]); - } - else if (arg == "--yarn-ext-factor") - { + } else if (arg == "--yarn-ext-factor") { if (++i >= argc) { invalid_param = true; break; } params.yarn_ext_factor = std::stof(argv[i]); } - else if (arg == "--yarn-attn-factor") - { + else if (arg == "--yarn-attn-factor") { if (++i >= argc) { invalid_param = true; break; } params.yarn_attn_factor = std::stof(argv[i]); - } - else if (arg == "--yarn-beta-fast") - { + } else if (arg == "--yarn-beta-fast") { if (++i >= argc) { invalid_param = true; break; } params.yarn_beta_fast = std::stof(argv[i]); - } - else if (arg == "--yarn-beta-slow") - { + } else if (arg == "--yarn-beta-slow") { if (++i >= argc) { invalid_param = true; break; } params.yarn_beta_slow = std::stof(argv[i]); - } - else if (arg == "--pooling") - { + } else if (arg == "--pooling") { if (++i >= argc) { invalid_param = true; break; @@ -2307,108 +2258,79 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } else { invalid_param = true; break; } - } - else if (arg == "--threads" || arg == "-t") - { + } else if (arg == "--threads" || arg == "-t") { if (++i >= argc) { invalid_param = true; break; } params.n_threads = std::stoi(argv[i]); - } - else if (arg == "--grp-attn-n" || arg == "-gan") - { + } else if (arg == "--grp-attn-n" || arg == "-gan") { if (++i >= argc) { invalid_param = true; break; } params.grp_attn_n = std::stoi(argv[i]); - } - else if (arg == "--grp-attn-w" || arg == "-gaw") - { - if (++i >= argc) - { + } else if (arg == "--grp-attn-w" || arg == "-gaw") { + if (++i >= argc) { invalid_param = true; break; } params.grp_attn_w = std::stoi(argv[i]); - } - else if (arg == "--threads-batch" || arg == "-tb") - { - if (++i >= argc) - { + } else if (arg == "--threads-batch" || arg == "-tb") { + if (++i >= argc) { invalid_param = true; break; } params.n_threads_batch = std::stoi(argv[i]); - } - else if (arg == "--threads-http") - { - if (++i >= argc) - { + } else if (arg == "--threads-http") { + if (++i >= argc) { invalid_param = true; break; } sparams.n_threads_http = std::stoi(argv[i]); - } - else if (arg == "-b" || arg == "--batch-size") - { - if (++i >= argc) - { + } else if (arg == "-b" || arg == "--batch-size") { + if (++i >= argc) { invalid_param = true; break; } params.n_batch = std::stoi(argv[i]); - } - else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") - { - if (++i >= argc) - { + } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") { + if (++i >= argc) { invalid_param = true; break; } if (llama_supports_gpu_offload()) { params.n_gpu_layers = std::stoi(argv[i]); } else { - LOG_WARNING("Not compiled with GPU offload support, --n-gpu-layers option will be ignored. " - "See main README.md for information on enabling GPU BLAS support", - {{"n_gpu_layers", params.n_gpu_layers}}); + LOG_WARNING( + "Not compiled with GPU offload support, --n-gpu-layers option will be ignored. " + "See main README.md for information on enabling GPU BLAS support", + {{"n_gpu_layers", params.n_gpu_layers}}); } - } - else if (arg == "--split-mode" || arg == "-sm") - { + } else if (arg == "--split-mode" || arg == "-sm") { if (++i >= argc) { invalid_param = true; break; } std::string arg_next = argv[i]; - if (arg_next == "none") - { + if (arg_next == "none") { params.split_mode = LLAMA_SPLIT_MODE_NONE; - } - else if (arg_next == "layer") - { + } else if (arg_next == "layer") { params.split_mode = LLAMA_SPLIT_MODE_LAYER; - } - else if (arg_next == "row") - { + } else if (arg_next == "row") { params.split_mode = LLAMA_SPLIT_MODE_ROW; - } - else { + } else { invalid_param = true; break; } #ifndef GGML_USE_CUBLAS fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Setting the split mode has no effect.\n"); #endif // GGML_USE_CUBLAS - } - else if (arg == "--tensor-split" || arg == "-ts") - { - if (++i >= argc) - { + } else if (arg == "--tensor-split" || arg == "-ts") { + if (++i >= argc) { invalid_param = true; break; } @@ -2421,25 +2343,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, std::vector split_arg{it, {}}; GGML_ASSERT(split_arg.size() <= llama_max_devices()); - for (size_t i_device = 0; i_device < llama_max_devices(); ++i_device) - { - if (i_device < split_arg.size()) - { + for (size_t i_device = 0; i_device < llama_max_devices(); ++i_device) { + if (i_device < split_arg.size()) { params.tensor_split[i_device] = std::stof(split_arg[i_device]); - } - else - { + } else { params.tensor_split[i_device] = 0.0f; } } #else LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n", {}); #endif // GGML_USE_CUBLAS - } - else if (arg == "--main-gpu" || arg == "-mg") - { - if (++i >= argc) - { + } else if (arg == "--main-gpu" || arg == "-mg") { + if (++i >= argc) { invalid_param = true; break; } @@ -2448,98 +2363,70 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, #else LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.", {}); #endif - } - else if (arg == "--lora") - { - if (++i >= argc) - { + } else if (arg == "--lora") { + if (++i >= argc) { invalid_param = true; break; } params.lora_adapter.emplace_back(argv[i], 1.0f); params.use_mmap = false; - } - else if (arg == "--lora-scaled") - { - if (++i >= argc) - { + } else if (arg == "--lora-scaled") { + if (++i >= argc) { invalid_param = true; break; } const char * lora_adapter = argv[i]; - if (++i >= argc) - { + if (++i >= argc) { invalid_param = true; break; } params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i])); params.use_mmap = false; - } - else if (arg == "--lora-base") - { - if (++i >= argc) - { + } else if (arg == "--lora-base") { + if (++i >= argc) { invalid_param = true; break; } params.lora_base = argv[i]; - } - else if (arg == "-v" || arg == "--verbose") - { + } else if (arg == "-v" || arg == "--verbose") { #if SERVER_VERBOSE != 1 LOG_WARNING("server.cpp is not built with verbose logging.", {}); #else server_verbose = true; #endif - } - else if (arg == "--mlock") - { + } else if (arg == "--mlock") { params.use_mlock = true; - } - else if (arg == "--no-mmap") - { + } else if (arg == "--no-mmap") { params.use_mmap = false; - } - else if (arg == "--numa") { + } else if (arg == "--numa") { if (++i >= argc) { invalid_param = true; break; } else { std::string value(argv[i]); /**/ if (value == "distribute" || value == "" ) { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; } - else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; } - else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; } + else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; } + else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; } else { invalid_param = true; break; } } - } - else if (arg == "--embedding") - { + } else if (arg == "--embedding" || arg == "--embeddings") { params.embedding = true; - } - else if (arg == "-cb" || arg == "--cont-batching") - { + } else if (arg == "-cb" || arg == "--cont-batching") { params.cont_batching = true; - } - else if (arg == "-np" || arg == "--parallel") - { - if (++i >= argc) - { + } else if (arg == "-np" || arg == "--parallel") { + if (++i >= argc) { invalid_param = true; break; } params.n_parallel = std::stoi(argv[i]); - } else if (arg == "-n" || arg == "--n-predict") - { - if (++i >= argc) - { + } else if (arg == "-n" || arg == "--n-predict") { + if (++i >= argc) { invalid_param = true; break; } params.n_predict = std::stoi(argv[i]); - } else if (arg == "-spf" || arg == "--system-prompt-file") - { - if (++i >= argc) - { + } else if (arg == "-spf" || arg == "--system-prompt-file") { + if (++i >= argc) { invalid_param = true; break; } @@ -2549,67 +2436,39 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, invalid_param = true; break; } - std::string systm_content; + std::string system_prompt; std::copy( std::istreambuf_iterator(file), std::istreambuf_iterator(), - std::back_inserter(systm_content) + std::back_inserter(system_prompt) ); - llama.system_prompt_process(json::parse(systm_content)); - } - else if (arg == "-ctk" || arg == "--cache-type-k") { + sparams.system_prompt = system_prompt; + } else if (arg == "-ctk" || arg == "--cache-type-k") { params.cache_type_k = argv[++i]; - } - else if (arg == "-ctv" || arg == "--cache-type-v") { + } else if (arg == "-ctv" || arg == "--cache-type-v") { params.cache_type_v = argv[++i]; - } - else if(arg == "--mmproj") - { - if (++i >= argc) - { + } else if (arg == "--log-format") { + if (++i >= argc) { invalid_param = true; break; } - params.mmproj = argv[i]; - } - else if (arg == "--log-format") - { - if (++i >= argc) - { - invalid_param = true; - break; - } - if (std::strcmp(argv[i], "json") == 0) - { + if (std::strcmp(argv[i], "json") == 0) { server_log_json = true; - } - else if (std::strcmp(argv[i], "text") == 0) - { + } else if (std::strcmp(argv[i], "text") == 0) { server_log_json = false; - } - else - { + } else { invalid_param = true; break; } - } - else if (arg == "--log-disable") - { + } else if (arg == "--log-disable") { log_set_target(stdout); LOG_INFO("logging to file is disabled.", {}); - } - else if (arg == "--slots-endpoint-disable") - { + } else if (arg == "--slots-endpoint-disable") { sparams.slots_endpoint = false; - } - else if (arg == "--metrics") - { + } else if (arg == "--metrics") { sparams.metrics_endpoint = true; - } - else if (arg == "--chat-template") - { - if (++i >= argc) - { + } else if (arg == "--chat-template") { + if (++i >= argc) { invalid_param = true; break; } @@ -2620,9 +2479,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, break; } sparams.chat_template = argv[i]; - } - else if (arg == "--override-kv") - { + } else if (arg == "--override-kv") { if (++i >= argc) { invalid_param = true; break; @@ -2633,6 +2490,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, invalid_param = true; break; } + struct llama_model_kv_override kvo; std::strncpy(kvo.key, argv[i], sep - argv[i]); kvo.key[sep - argv[i]] = 0; @@ -2663,67 +2521,28 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, break; } params.kv_overrides.push_back(kvo); - } - else - { + } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); server_print_usage(argv[0], default_params, default_sparams); exit(1); } } + if (!params.kv_overrides.empty()) { params.kv_overrides.emplace_back(); params.kv_overrides.back().key[0] = 0; } - if (invalid_param) - { + if (invalid_param) { fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str()); server_print_usage(argv[0], default_params, default_sparams); exit(1); } } -/* llama.cpp completion api semantics */ -static json format_partial_response( - llama_server_context &llama, server_slot *slot, const std::string &content, const std::vector &probs -) { - json res = json - { - {"content", content }, - {"stop", false}, - {"slot_id", slot->id }, - {"multimodal", llama.multimodal } - }; - - if (slot->sparams.n_probs > 0) - { - res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs); - } - - return res; -} - -static json format_tokenizer_response(const std::vector &tokens) -{ - return json { - {"tokens", tokens} - }; -} - -static json format_detokenized_response(std::string content) -{ - return json { - {"content", content} - }; -} - - -static void log_server_request(const httplib::Request &req, const httplib::Response &res) -{ +static void log_server_request(const httplib::Request & req, const httplib::Response & res) { // skip GH copilot requests when using default port - if (req.path == "/v1/health" || req.path == "/v1/completions") - { + if (req.path == "/v1/health" || req.path == "/v1/completions") { return; } @@ -2742,24 +2561,9 @@ static void log_server_request(const httplib::Request &req, const httplib::Respo }); } -static void append_to_generated_text_from_generated_token_probs(llama_server_context &llama, server_slot *slot) -{ - auto & gtps = slot->generated_token_probs; - auto translator = token_translator{llama.ctx}; - auto add_strlen = [=](size_t sum, const completion_token_output & cto) { return sum + translator(cto).size(); }; - const size_t len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen); - if (slot->generated_text.capacity() < slot->generated_text.size() + len) - { - slot->generated_text.reserve(slot->generated_text.size() + len); - } - for (const completion_token_output & cto : gtps) - { - slot->generated_text += translator(cto); - } -} - std::function shutdown_handler; std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; + inline void signal_handler(int signal) { if (is_terminating.test_and_set()) { // in case it hangs, we can force terminate the server by hitting Ctrl+C twice @@ -2767,40 +2571,45 @@ inline void signal_handler(int signal) { fprintf(stderr, "Received second interrupt, terminating immediately.\n"); exit(1); } + shutdown_handler(signal); } -int main(int argc, char **argv) -{ +int main(int argc, char ** argv) { #if SERVER_VERBOSE != 1 log_disable(); #endif // own arguments required by this example - gpt_params params; + gpt_params params; server_params sparams; // struct that contains llama context and inference - llama_server_context llama; + server_context ctx_server; - server_params_parse(argc, argv, sparams, params, llama); + server_params_parse(argc, argv, sparams, params); - if (params.model_alias == "unknown") - { + if (!sparams.system_prompt.empty()) { + ctx_server.system_prompt_set(json::parse(sparams.system_prompt)); + } + + if (params.model_alias == "unknown") { params.model_alias = params.model; } llama_backend_init(); llama_numa_init(params.numa); - LOG_INFO("build info", {{"build", LLAMA_BUILD_NUMBER}, - {"commit", LLAMA_COMMIT}}); + LOG_INFO("build info", { + {"build", LLAMA_BUILD_NUMBER}, + {"commit", LLAMA_COMMIT} + }); LOG_INFO("system info", { - {"n_threads", params.n_threads}, - {"n_threads_batch", params.n_threads_batch}, - {"total_threads", std::thread::hardware_concurrency()}, - {"system_info", llama_print_system_info()}, - }); + {"n_threads", params.n_threads}, + {"n_threads_batch", params.n_threads_batch}, + {"total_threads", std::thread::hardware_concurrency()}, + {"system_info", llama_print_system_info()}, + }); httplib::Server svr; @@ -2809,152 +2618,163 @@ int main(int argc, char **argv) svr.set_default_headers({{"Server", "llama.cpp"}}); // CORS preflight - svr.Options(R"(.*)", [](const httplib::Request &req, httplib::Response &res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + svr.Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Credentials", "true"); - res.set_header("Access-Control-Allow-Methods", "POST"); - res.set_header("Access-Control-Allow-Headers", "*"); + res.set_header("Access-Control-Allow-Methods", "POST"); + res.set_header("Access-Control-Allow-Headers", "*"); }); - svr.Get("/health", [&](const httplib::Request& req, httplib::Response& res) { + svr.Get("/health", [&](const httplib::Request & req, httplib::Response & res) { server_state current_state = state.load(); - switch(current_state) { - case SERVER_STATE_READY: { - // request slots data using task queue - task_server task; - task.id = llama.queue_tasks.get_new_id(); - task.type = TASK_TYPE_METRICS; - task.target_id = -1; + switch (current_state) { + case SERVER_STATE_READY: + { + // request slots data using task queue + server_task task; + task.id = ctx_server.queue_tasks.get_new_id(); + task.type = SERVER_TASK_TYPE_METRICS; + task.id_target = -1; - llama.queue_results.add_waiting_task_id(task.id); - llama.queue_tasks.post(task); + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(task); - // get the result - task_result result = llama.queue_results.recv(task.id); - llama.queue_results.remove_waiting_task_id(task.id); + // get the result + server_task_result result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); - int n_idle_slots = result.result_json["idle"]; - int n_processing_slots = result.result_json["processing"]; + const int n_idle_slots = result.data["idle"]; + const int n_processing_slots = result.data["processing"]; - json health = { + json health = { {"status", "ok"}, {"slots_idle", n_idle_slots}, - {"slots_processing", n_processing_slots}}; - res.status = 200; // HTTP OK - if (sparams.slots_endpoint && req.has_param("include_slots")) { - health["slots"] = result.result_json["slots"]; - } + {"slots_processing", n_processing_slots} + }; - if (n_idle_slots == 0) { - health["status"] = "no slot available"; - if (req.has_param("fail_on_no_slot")) { - res.status = 503; // HTTP Service Unavailable + res.status = 200; // HTTP OK + if (sparams.slots_endpoint && req.has_param("include_slots")) { + health["slots"] = result.data["slots"]; } + + if (n_idle_slots == 0) { + health["status"] = "no slot available"; + if (req.has_param("fail_on_no_slot")) { + res.status = 503; // HTTP Service Unavailable + } + } + + res.set_content(health.dump(), "application/json"); + break; } - res.set_content(health.dump(), "application/json"); - break; - } case SERVER_STATE_LOADING_MODEL: - res.set_content(R"({"status": "loading model"})", "application/json"); - res.status = 503; // HTTP Service Unavailable - break; + { + res.set_content(R"({"status": "loading model"})", "application/json"); + res.status = 503; // HTTP Service Unavailable + } break; case SERVER_STATE_ERROR: - res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json"); - res.status = 500; // HTTP Internal Server Error - break; + { + res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json"); + res.status = 500; // HTTP Internal Server Error + } break; } }); if (sparams.slots_endpoint) { - svr.Get("/slots", [&](const httplib::Request&, httplib::Response& res) { + svr.Get("/slots", [&](const httplib::Request &, httplib::Response & res) { // request slots data using task queue - task_server task; - task.id = llama.queue_tasks.get_new_id(); - task.type = TASK_TYPE_METRICS; - task.target_id = -1; + server_task task; + task.id = ctx_server.queue_tasks.get_new_id(); + task.id_multi = -1; + task.id_target = -1; + task.type = SERVER_TASK_TYPE_METRICS; - llama.queue_results.add_waiting_task_id(task.id); - llama.queue_tasks.post(task); + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(task); // get the result - task_result result = llama.queue_results.recv(task.id); - llama.queue_results.remove_waiting_task_id(task.id); + server_task_result result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); - res.set_content(result.result_json["slots"].dump(), "application/json"); + res.set_content(result.data["slots"].dump(), "application/json"); res.status = 200; // HTTP OK }); } if (sparams.metrics_endpoint) { - svr.Get("/metrics", [&](const httplib::Request&, httplib::Response& res) { + svr.Get("/metrics", [&](const httplib::Request &, httplib::Response & res) { // request slots data using task queue - task_server task; - task.id = llama.queue_tasks.get_new_id(); - task.type = TASK_TYPE_METRICS; - task.target_id = -1; + server_task task; + task.id = ctx_server.queue_tasks.get_new_id(); + task.id_multi = -1; + task.id_target = -1; + task.type = SERVER_TASK_TYPE_METRICS; - llama.queue_results.add_waiting_task_id(task.id); - llama.queue_tasks.post(task); + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(task); // get the result - task_result result = llama.queue_results.recv(task.id); - llama.queue_results.remove_waiting_task_id(task.id); + server_task_result result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); - json data = result.result_json; + json data = result.data; - uint64_t n_prompt_tokens_processed = data["n_prompt_tokens_processed"]; - uint64_t t_prompt_processing = data["t_prompt_processing"]; + const uint64_t n_prompt_tokens_processed = data["n_prompt_tokens_processed"]; + const uint64_t t_prompt_processing = data["t_prompt_processing"]; - uint64_t n_tokens_predicted = data["n_tokens_predicted"]; - uint64_t t_tokens_generation = data["t_tokens_generation"]; + const uint64_t n_tokens_predicted = data["n_tokens_predicted"]; + const uint64_t t_tokens_generation = data["t_tokens_generation"]; - int32_t kv_cache_used_cells = data["kv_cache_used_cells"]; + const int32_t kv_cache_used_cells = data["kv_cache_used_cells"]; // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names json all_metrics_def = json { - {"counter", {{ - {"name", "prompt_tokens_total"}, - {"help", "Number of prompt tokens processed."}, - {"value", data["n_prompt_tokens_processed_total"]} - }, { - {"name", "tokens_predicted_total"}, - {"help", "Number of generation tokens processed."}, - {"value", data["n_tokens_predicted_total"]} - }}}, - {"gauge", {{ - {"name", "prompt_tokens_seconds"}, - {"help", "Average prompt throughput in tokens/s."}, - {"value", n_prompt_tokens_processed ? 1e3 / t_prompt_processing * n_prompt_tokens_processed : 0} - },{ - {"name", "predicted_tokens_seconds"}, - {"help", "Average generation throughput in tokens/s."}, - {"value", n_tokens_predicted ? 1e3 / t_tokens_generation * n_tokens_predicted : 0} - },{ - {"name", "kv_cache_usage_ratio"}, - {"help", "KV-cache usage. 1 means 100 percent usage."}, - {"value", 1. * kv_cache_used_cells / params.n_ctx} - },{ - {"name", "kv_cache_tokens"}, - {"help", "KV-cache tokens."}, - {"value", data["kv_cache_tokens_count"]} - },{ - {"name", "requests_processing"}, - {"help", "Number of request processing."}, - {"value", data["processing"]} - },{ - {"name", "requests_deferred"}, - {"help", "Number of request deferred."}, - {"value", data["deferred"]} - }}} + {"counter", {{ + {"name", "prompt_tokens_total"}, + {"help", "Number of prompt tokens processed."}, + {"value", data["n_prompt_tokens_processed_total"]} + }, { + {"name", "tokens_predicted_total"}, + {"help", "Number of generation tokens processed."}, + {"value", data["n_tokens_predicted_total"]} + }}}, + {"gauge", {{ + {"name", "prompt_tokens_seconds"}, + {"help", "Average prompt throughput in tokens/s."}, + {"value", n_prompt_tokens_processed ? 1e3 / t_prompt_processing * n_prompt_tokens_processed : 0} + },{ + {"name", "predicted_tokens_seconds"}, + {"help", "Average generation throughput in tokens/s."}, + {"value", n_tokens_predicted ? 1e3 / t_tokens_generation * n_tokens_predicted : 0} + },{ + {"name", "kv_cache_usage_ratio"}, + {"help", "KV-cache usage. 1 means 100 percent usage."}, + {"value", 1. * kv_cache_used_cells / params.n_ctx} + },{ + {"name", "kv_cache_tokens"}, + {"help", "KV-cache tokens."}, + {"value", data["kv_cache_tokens_count"]} + },{ + {"name", "requests_processing"}, + {"help", "Number of request processing."}, + {"value", data["processing"]} + },{ + {"name", "requests_deferred"}, + {"help", "Number of request deferred."}, + {"value", data["deferred"]} + }}} }; std::stringstream prometheus; - for (const auto& el : all_metrics_def.items()) { - const auto& type = el.key(); - const auto& metrics_def = el.value(); - for (const auto& metric_def : metrics_def) { - std::string name = metric_def["name"]; - std::string help = metric_def["help"]; + + for (const auto & el : all_metrics_def.items()) { + const auto & type = el.key(); + const auto & metrics_def = el.value(); + + for (const auto & metric_def : metrics_def) { + const std::string name = metric_def["name"]; + const std::string help = metric_def["help"]; + auto value = json_value(metric_def, "value", 0); prometheus << "# HELP llamacpp:" << name << " " << help << "\n" << "# TYPE llamacpp:" << name << " " << type << "\n" @@ -2969,49 +2789,39 @@ int main(int argc, char **argv) svr.set_logger(log_server_request); - svr.set_exception_handler([](const httplib::Request &, httplib::Response &res, std::exception_ptr ep) - { - const char fmt[] = "500 Internal Server Error\n%s"; - char buf[BUFSIZ]; - try - { - std::rethrow_exception(std::move(ep)); - } - catch (std::exception &e) - { - snprintf(buf, sizeof(buf), fmt, e.what()); - } - catch (...) - { - snprintf(buf, sizeof(buf), fmt, "Unknown Exception"); - } - res.set_content(buf, "text/plain; charset=utf-8"); - res.status = 500; - }); + svr.set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) { + const char fmt[] = "500 Internal Server Error\n%s"; - svr.set_error_handler([](const httplib::Request &, httplib::Response &res) - { - if (res.status == 401) - { - res.set_content("Unauthorized", "text/plain; charset=utf-8"); - } - if (res.status == 400) - { - res.set_content("Invalid request", "text/plain; charset=utf-8"); - } - else if (res.status == 404) - { - res.set_content("File Not Found", "text/plain; charset=utf-8"); - res.status = 404; - } - }); + char buf[BUFSIZ]; + try { + std::rethrow_exception(std::move(ep)); + } catch (std::exception &e) { + snprintf(buf, sizeof(buf), fmt, e.what()); + } catch (...) { + snprintf(buf, sizeof(buf), fmt, "Unknown Exception"); + } + + res.set_content(buf, "text/plain; charset=utf-8"); + res.status = 500; + }); + + svr.set_error_handler([](const httplib::Request &, httplib::Response & res) { + if (res.status == 401) { + res.set_content("Unauthorized", "text/plain; charset=utf-8"); + } + if (res.status == 400) { + res.set_content("Invalid request", "text/plain; charset=utf-8"); + } + if (res.status == 404) { + res.set_content("File Not Found", "text/plain; charset=utf-8"); + } + }); // set timeouts and change hostname and port svr.set_read_timeout (sparams.read_timeout); svr.set_write_timeout(sparams.write_timeout); - if (!svr.bind_to_port(sparams.hostname, sparams.port)) - { + if (!svr.bind_to_port(sparams.hostname, sparams.port)) { fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port); return 1; } @@ -3020,8 +2830,9 @@ int main(int argc, char **argv) svr.set_base_dir(sparams.public_path); std::unordered_map log_data; + log_data["hostname"] = sparams.hostname; - log_data["port"] = std::to_string(sparams.port); + log_data["port"] = std::to_string(sparams.port); if (sparams.api_keys.size() == 1) { log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4); @@ -3030,20 +2841,23 @@ int main(int argc, char **argv) } // load the model - if (!llama.load_model(params)) - { + if (!ctx_server.load_model(params)) { state.store(SERVER_STATE_ERROR); return 1; } else { - llama.initialize(); + ctx_server.initialize(); state.store(SERVER_STATE_READY); - LOG_INFO("model loaded", {}); } - const auto model_meta = llama.model_meta(); + + LOG_INFO("model loaded", {}); + + const auto model_meta = ctx_server.model_meta(); if (sparams.chat_template.empty()) { // custom chat template is not supplied - // check if the template comes with the model is supported by us - llama.validate_model_chat_template(sparams); + if (!ctx_server.validate_model_chat_template()) { + LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {}); + sparams.chat_template = "chatml"; + } } // Middleware for API key validation @@ -3055,6 +2869,7 @@ int main(int argc, char **argv) // Check for API key in the header auto auth_header = req.get_header_value("Authorization"); + std::string prefix = "Bearer "; if (auth_header.substr(0, prefix.size()) == prefix) { std::string received_api_key = auth_header.substr(prefix.size()); @@ -3073,179 +2888,173 @@ int main(int argc, char **argv) }; // this is only called if no index.html is found in the public --path - svr.Get("/", [](const httplib::Request &, httplib::Response &res) - { - res.set_content(reinterpret_cast(&index_html), index_html_len, "text/html; charset=utf-8"); - return false; - }); + svr.Get("/", [](const httplib::Request &, httplib::Response & res) { + res.set_content(reinterpret_cast(&index_html), index_html_len, "text/html; charset=utf-8"); + return false; + }); // this is only called if no index.js is found in the public --path - svr.Get("/index.js", [](const httplib::Request &, httplib::Response &res) - { - res.set_content(reinterpret_cast(&index_js), index_js_len, "text/javascript; charset=utf-8"); - return false; - }); + svr.Get("/index.js", [](const httplib::Request &, httplib::Response & res) { + res.set_content(reinterpret_cast(&index_js), index_js_len, "text/javascript; charset=utf-8"); + return false; + }); // this is only called if no index.html is found in the public --path - svr.Get("/completion.js", [](const httplib::Request &, httplib::Response &res) - { - res.set_content(reinterpret_cast(&completion_js), completion_js_len, "application/javascript; charset=utf-8"); - return false; - }); + svr.Get("/completion.js", [](const httplib::Request &, httplib::Response & res) { + res.set_content(reinterpret_cast(&completion_js), completion_js_len, "application/javascript; charset=utf-8"); + return false; + }); // this is only called if no index.html is found in the public --path - svr.Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response &res) - { - res.set_content(reinterpret_cast(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript; charset=utf-8"); - return false; - }); + svr.Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response & res) { + res.set_content(reinterpret_cast(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript; charset=utf-8"); + return false; + }); - svr.Get("/props", [&llama](const httplib::Request & req, httplib::Response &res) - { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - json data = { - { "user_name", llama.name_user.c_str() }, - { "assistant_name", llama.name_assistant.c_str() }, - { "default_generation_settings", llama.default_generation_settings_for_props }, - { "total_slots", llama.params.n_parallel } - }; - res.set_content(data.dump(), "application/json; charset=utf-8"); - }); + svr.Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + json data = { + { "user_name", ctx_server.name_user.c_str() }, + { "assistant_name", ctx_server.name_assistant.c_str() }, + { "default_generation_settings", ctx_server.default_generation_settings_for_props }, + { "total_slots", ctx_server.params.n_parallel } + }; - svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) - { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - if (!validate_api_key(req, res)) { - return; - } - json data = json::parse(req.body); - const int task_id = llama.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(task_id); - llama.request_completion(task_id, data, false, false, -1); - if (!json_value(data, "stream", false)) { - std::string completion_text; - task_result result = llama.queue_results.recv(task_id); - if (!result.error && result.stop) { - res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); - } - else - { - res.status = 404; - res.set_content(result.result_json["content"], "text/plain; charset=utf-8"); - } - llama.queue_results.remove_waiting_task_id(task_id); - } else { - const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) - { - while (true) - { - task_result result = llama.queue_results.recv(task_id); - if (!result.error) { - const std::string str = - "data: " + - result.result_json.dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; - LOG_VERBOSE("data stream", { - { "to_send", str } - }); - if (!sink.write(str.c_str(), str.size())) - { - llama.queue_results.remove_waiting_task_id(task_id); - return false; - } - if (result.stop) { - break; - } - } else { - const std::string str = - "error: " + - result.result_json.dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; - LOG_VERBOSE("data stream", { - { "to_send", str } - }); - if (!sink.write(str.c_str(), str.size())) - { - llama.queue_results.remove_waiting_task_id(task_id); - return false; - } - break; - } - } + res.set_content(data.dump(), "application/json; charset=utf-8"); + }); - llama.queue_results.remove_waiting_task_id(task_id); - sink.done(); - return true; - }; - - auto on_complete = [task_id, &llama] (bool) - { - // cancel - llama.request_cancel(task_id); - llama.queue_results.remove_waiting_task_id(task_id); - }; - - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } - }); - - svr.Get("/v1/models", [¶ms, &model_meta](const httplib::Request& req, httplib::Response& res) - { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - std::time_t t = std::time(0); - - json models = { - {"object", "list"}, - {"data", { - { - {"id", params.model_alias}, - {"object", "model"}, - {"created", t}, - {"owned_by", "llamacpp"}, - {"meta", model_meta} - }, - }} - }; - - res.set_content(models.dump(), "application/json; charset=utf-8"); - }); - - const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res) - { + svr.Post("/completion", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; } - json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template); - const int task_id = llama.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(task_id); - llama.request_completion(task_id, data, false, false, -1); + json data = json::parse(req.body); + + const int id_task = ctx_server.queue_tasks.get_new_id(); + + ctx_server.queue_results.add_waiting_task_id(id_task); + ctx_server.request_completion(id_task, -1, data, false, false); if (!json_value(data, "stream", false)) { - std::string completion_text; - task_result result = llama.queue_results.recv(task_id); - + server_task_result result = ctx_server.queue_results.recv(id_task); if (!result.error && result.stop) { - json oaicompat_result = format_final_response_oaicompat(data, result); - - res.set_content(oaicompat_result.dump(-1, ' ', false, - json::error_handler_t::replace), - "application/json; charset=utf-8"); + res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); } else { res.status = 500; - res.set_content(result.result_json["content"], "text/plain; charset=utf-8"); + res.set_content(result.data["content"], "text/plain; charset=utf-8"); } - llama.queue_results.remove_waiting_task_id(task_id); - } else { - const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) { - while (true) { - task_result llama_result = llama.queue_results.recv(task_id); - if (!llama_result.error) { - std::vector result_array = format_partial_response_oaicompat( llama_result); - for (auto it = result_array.begin(); it != result_array.end(); ++it) - { + ctx_server.queue_results.remove_waiting_task_id(id_task); + } else { + const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) { + while (true) { + server_task_result result = ctx_server.queue_results.recv(id_task); + if (!result.error) { + const std::string str = + "data: " + + result.data.dump(-1, ' ', false, json::error_handler_t::replace) + + "\n\n"; + + LOG_VERBOSE("data stream", { + { "to_send", str } + }); + + if (!sink.write(str.c_str(), str.size())) { + ctx_server.queue_results.remove_waiting_task_id(id_task); + return false; + } + + if (result.stop) { + break; + } + } else { + const std::string str = + "error: " + + result.data.dump(-1, ' ', false, json::error_handler_t::replace) + + "\n\n"; + + LOG_VERBOSE("data stream", { + { "to_send", str } + }); + + if (!sink.write(str.c_str(), str.size())) { + ctx_server.queue_results.remove_waiting_task_id(id_task); + return false; + } + + break; + } + } + + ctx_server.queue_results.remove_waiting_task_id(id_task); + sink.done(); + + return true; + }; + + auto on_complete = [id_task, &ctx_server] (bool) { + // cancel + ctx_server.request_cancel(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + }; + + res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); + } + }); + + svr.Get("/v1/models", [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + + json models = { + {"object", "list"}, + {"data", { + { + {"id", params.model_alias}, + {"object", "model"}, + {"created", std::time(0)}, + {"owned_by", "llamacpp"}, + {"meta", model_meta} + }, + }} + }; + + res.set_content(models.dump(), "application/json; charset=utf-8"); + }); + + const auto chat_completions = [&ctx_server, &validate_api_key, &sparams](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + if (!validate_api_key(req, res)) { + return; + } + + json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template); + + const int id_task = ctx_server.queue_tasks.get_new_id(); + + ctx_server.queue_results.add_waiting_task_id(id_task); + ctx_server.request_completion(id_task, -1, data, false, false); + + if (!json_value(data, "stream", false)) { + server_task_result result = ctx_server.queue_results.recv(id_task); + + if (!result.error && result.stop) { + json result_oai = format_final_response_oaicompat(data, result.data); + + res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); + } else { + res.status = 500; + res.set_content(result.data["content"], "text/plain; charset=utf-8"); + } + ctx_server.queue_results.remove_waiting_task_id(id_task); + } else { + const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) { + while (true) { + server_task_result result = ctx_server.queue_results.recv(id_task); + if (!result.error) { + std::vector result_array = format_partial_response_oaicompat(result.data); + + for (auto it = result_array.begin(); it != result_array.end(); ++it) { if (!it->empty()) { const std::string str = "data: " + @@ -3253,251 +3062,235 @@ int main(int argc, char **argv) "\n\n"; LOG_VERBOSE("data stream", {{"to_send", str}}); if (!sink.write(str.c_str(), str.size())) { - llama.queue_results.remove_waiting_task_id(task_id); + ctx_server.queue_results.remove_waiting_task_id(id_task); return false; } } } - if (llama_result.stop) { + if (result.stop) { break; } } else { const std::string str = "error: " + - llama_result.result_json.dump(-1, ' ', false, - json::error_handler_t::replace) + + result.data.dump(-1, ' ', false, json::error_handler_t::replace) + "\n\n"; LOG_VERBOSE("data stream", {{"to_send", str}}); if (!sink.write(str.c_str(), str.size())) { - llama.queue_results.remove_waiting_task_id(task_id); + ctx_server.queue_results.remove_waiting_task_id(id_task); return false; } break; } } sink.done(); - llama.queue_results.remove_waiting_task_id(task_id); + ctx_server.queue_results.remove_waiting_task_id(id_task); return true; }; - auto on_complete = [task_id, &llama](bool) { + auto on_complete = [id_task, &ctx_server](bool) { // cancel request - llama.request_cancel(task_id); - llama.queue_results.remove_waiting_task_id(task_id); + ctx_server.request_cancel(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); }; res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); } }; - svr.Post("/chat/completions", chat_completions); + svr.Post("/chat/completions", chat_completions); svr.Post("/v1/chat/completions", chat_completions); - svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) - { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - if (!validate_api_key(req, res)) { - return; - } - json data = json::parse(req.body); - const int task_id = llama.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(task_id); - llama.request_completion(task_id, data, true, false, -1); - if (!json_value(data, "stream", false)) { - std::string completion_text; - task_result result = llama.queue_results.recv(task_id); - if (!result.error && result.stop) - { - res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); - } - else - { - res.status = 404; - res.set_content(result.result_json["content"], "text/plain; charset=utf-8"); - } - llama.queue_results.remove_waiting_task_id(task_id); - } else { - const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) { - while (true) - { - task_result result = llama.queue_results.recv(task_id); - if (!result.error) { - const std::string str = - "data: " + - result.result_json.dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; - LOG_VERBOSE("data stream", { - { "to_send", str } - }); - if (!sink.write(str.c_str(), str.size())) - { - llama.queue_results.remove_waiting_task_id(task_id); - return false; - } - if (result.stop) - { - break; - } - } - else - { - break; - } - } - - llama.queue_results.remove_waiting_task_id(task_id); - sink.done(); - return true; - }; - - auto on_complete = [task_id, &llama] (bool) - { - // cancel - llama.request_cancel(task_id); - }; - - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } - }); - - svr.Options(R"(/.*)", [](const httplib::Request &, httplib::Response &res) - { return res.set_content("", "application/json; charset=utf-8"); }); - - svr.Post("/tokenize", [&llama](const httplib::Request &req, httplib::Response &res) - { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - const json body = json::parse(req.body); - std::vector tokens; - if (body.count("content") != 0) - { - tokens = llama.tokenize(body["content"], false); - } - const json data = format_tokenizer_response(tokens); - return res.set_content(data.dump(), "application/json; charset=utf-8"); - }); - - svr.Post("/detokenize", [&llama](const httplib::Request &req, httplib::Response &res) - { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - const json body = json::parse(req.body); - std::string content; - if (body.count("tokens") != 0) - { - const std::vector tokens = body["tokens"]; - content = tokens_to_str(llama.ctx, tokens.cbegin(), tokens.cend()); - } - - const json data = format_detokenized_response(content); - return res.set_content(data.dump(), "application/json; charset=utf-8"); - }); - - svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res) - { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - const json body = json::parse(req.body); - json prompt; - if (body.count("content") != 0) - { - prompt = body["content"]; - } - else - { - prompt = ""; - } - - json image_data; - if (body.count("image_data") != 0) { - image_data = body["image_data"]; - } - else - { - image_data = ""; - } - - // create and queue the task - const int task_id = llama.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(task_id); - llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1); - - // get the result - task_result result = llama.queue_results.recv(task_id); - llama.queue_results.remove_waiting_task_id(task_id); - - // send the result - return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); - }); - - svr.Post("/v1/embeddings", [&llama](const httplib::Request &req, httplib::Response &res) - { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - const json body = json::parse(req.body); - - json prompt; - if (body.count("input") != 0) - { - prompt = body["input"]; - // batch - if(prompt.is_array()) { - json data = json::array(); - int i = 0; - for (const json &elem : prompt) { - const int task_id = llama.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(task_id); - llama.request_completion(task_id, { {"prompt", elem}, { "n_predict", 0} }, false, true, -1); - - // get the result - task_result result = llama.queue_results.recv(task_id); - llama.queue_results.remove_waiting_task_id(task_id); - - json embedding = json{ - {"embedding", json_value(result.result_json, "embedding", json::array())}, - {"index", i++}, - {"object", "embedding"} - }; - data.push_back(embedding); - } - json result = format_embeddings_response_oaicompat(body, data); - return res.set_content(result.dump(), "application/json; charset=utf-8"); - } - } - else - { - prompt = ""; - } - - // create and queue the task - const int task_id = llama.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(task_id); - llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}}, false, true, -1); - - // get the result - task_result result = llama.queue_results.recv(task_id); - llama.queue_results.remove_waiting_task_id(task_id); - - json data = json::array({json{ - {"embedding", json_value(result.result_json, "embedding", json::array())}, - {"index", 0}, - {"object", "embedding"} - }} - ); - - json root = format_embeddings_response_oaicompat(body, data); - - // send the result - return res.set_content(root.dump(), "application/json; charset=utf-8"); - }); - - // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!? - // "Bus error: 10" - this is on macOS, it does not crash on Linux - //std::thread t2([&]() - /*{ - bool running = true; - while (running) - { - running = llama.update_slots(); + svr.Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + if (!validate_api_key(req, res)) { + return; } - }*/ - //); + + json data = json::parse(req.body); + + const int id_task = ctx_server.queue_tasks.get_new_id(); + + ctx_server.queue_results.add_waiting_task_id(id_task); + ctx_server.request_completion(id_task, -1, data, true, false); + + if (!json_value(data, "stream", false)) { + server_task_result result = ctx_server.queue_results.recv(id_task); + if (!result.error && result.stop) { + res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); + } else { + res.status = 404; + res.set_content(result.data["content"], "text/plain; charset=utf-8"); + } + + ctx_server.queue_results.remove_waiting_task_id(id_task); + } else { + const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) { + while (true) { + server_task_result result = ctx_server.queue_results.recv(id_task); + if (!result.error) { + const std::string str = + "data: " + + result.data.dump(-1, ' ', false, json::error_handler_t::replace) + + "\n\n"; + + LOG_VERBOSE("data stream", { + { "to_send", str } + }); + + if (!sink.write(str.c_str(), str.size())) { + ctx_server.queue_results.remove_waiting_task_id(id_task); + return false; + } + + if (result.stop) { + break; + } + } else { + break; + } + } + + ctx_server.queue_results.remove_waiting_task_id(id_task); + sink.done(); + + return true; + }; + + auto on_complete = [id_task, &ctx_server] (bool) { + ctx_server.request_cancel(id_task); + }; + + res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); + } + }); + + svr.Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) { + return res.set_content("", "application/json; charset=utf-8"); + }); + + svr.Post("/tokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + const json body = json::parse(req.body); + + std::vector tokens; + if (body.count("content") != 0) { + tokens = ctx_server.tokenize(body["content"], false); + } + const json data = format_tokenizer_response(tokens); + return res.set_content(data.dump(), "application/json; charset=utf-8"); + }); + + svr.Post("/detokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + const json body = json::parse(req.body); + + std::string content; + if (body.count("tokens") != 0) { + const std::vector tokens = body["tokens"]; + content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend()); + } + + const json data = format_detokenized_response(content); + return res.set_content(data.dump(), "application/json; charset=utf-8"); + }); + + svr.Post("/embedding", [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + if (!params.embedding) { + res.status = 501; + res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8"); + return; + } + + const json body = json::parse(req.body); + + json prompt; + if (body.count("content") != 0) { + prompt = body["content"]; + } else { + prompt = ""; + } + + // create and queue the task + const int id_task = ctx_server.queue_tasks.get_new_id(); + + ctx_server.queue_results.add_waiting_task_id(id_task); + ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0} }, false, true); + + // get the result + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + // send the result + return res.set_content(result.data.dump(), "application/json; charset=utf-8"); + }); + + svr.Post("/v1/embeddings", [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + if (!params.embedding) { + res.status = 501; + res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8"); + return; + } + + const json body = json::parse(req.body); + + json prompt; + if (body.count("input") != 0) { + prompt = body["input"]; + if (prompt.is_array()) { + json data = json::array(); + + int i = 0; + for (const json & elem : prompt) { + const int id_task = ctx_server.queue_tasks.get_new_id(); + + ctx_server.queue_results.add_waiting_task_id(id_task); + ctx_server.request_completion(id_task, -1, { {"prompt", elem}, { "n_predict", 0} }, false, true); + + // get the result + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + json embedding = json{ + {"embedding", json_value(result.data, "embedding", json::array())}, + {"index", i++}, + {"object", "embedding"} + }; + + data.push_back(embedding); + } + + json result = format_embeddings_response_oaicompat(body, data); + + return res.set_content(result.dump(), "application/json; charset=utf-8"); + } + } else { + prompt = ""; + } + + // create and queue the task + const int id_task = ctx_server.queue_tasks.get_new_id(); + + ctx_server.queue_results.add_waiting_task_id(id_task); + ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true); + + // get the result + server_task_result result = ctx_server.queue_results.recv(id_task); + ctx_server.queue_results.remove_waiting_task_id(id_task); + + json data = json::array({json{ + {"embedding", json_value(result.data, "embedding", json::array())}, + {"index", 0}, + {"object", "embedding"} + }} + ); + + json root = format_embeddings_response_oaicompat(body, data); + + return res.set_content(root.dump(), "application/json; charset=utf-8"); + }); if (sparams.n_threads_http < 1) { // +2 threads for monitoring endpoints @@ -3507,34 +3300,33 @@ int main(int argc, char **argv) svr.new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); }; LOG_INFO("HTTP server listening", log_data); + // run the HTTP server in a thread - see comment below - std::thread t([&]() - { - if (!svr.listen_after_bind()) - { - state.store(SERVER_STATE_ERROR); - return 1; - } + std::thread t([&]() { + if (!svr.listen_after_bind()) { + state.store(SERVER_STATE_ERROR); + return 1; + } - return 0; - }); + return 0; + }); - llama.queue_tasks.on_new_task(std::bind( - &llama_server_context::process_single_task, &llama, std::placeholders::_1)); - llama.queue_tasks.on_finish_multitask(std::bind( - &llama_server_context::on_finish_multitask, &llama, std::placeholders::_1)); - llama.queue_tasks.on_run_slots(std::bind( - &llama_server_context::update_slots, &llama)); - llama.queue_results.on_multitask_update(std::bind( - &llama_server_queue::update_multitask, - &llama.queue_tasks, + ctx_server.queue_tasks.on_new_task(std::bind( + &server_context::process_single_task, &ctx_server, std::placeholders::_1)); + ctx_server.queue_tasks.on_finish_multitask(std::bind( + &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1)); + ctx_server.queue_tasks.on_run_slots(std::bind( + &server_context::update_slots, &ctx_server)); + ctx_server.queue_results.on_multitask_update(std::bind( + &server_queue::update_multitask, + &ctx_server.queue_tasks, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3 )); shutdown_handler = [&](int) { - llama.queue_tasks.terminate(); + ctx_server.queue_tasks.terminate(); }; #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) @@ -3549,10 +3341,13 @@ int main(int argc, char **argv) }; SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); #endif - llama.queue_tasks.start_loop(); + + ctx_server.queue_tasks.start_loop(); + svr.stop(); t.join(); llama_backend_free(); + return 0; } diff --git a/examples/server/tests/features/embeddings.feature b/examples/server/tests/features/embeddings.feature new file mode 100644 index 000000000..b47661e94 --- /dev/null +++ b/examples/server/tests/features/embeddings.feature @@ -0,0 +1,94 @@ +@llama.cpp +@embeddings +Feature: llama.cpp server + + Background: Server startup + Given a server listening on localhost:8080 + And a model file bert-bge-small/ggml-model-f16.gguf from HF repo ggml-org/models + And a model alias bert-bge-small + And 42 as server seed + And 2 slots + And 1024 as batch size + And 2048 KV cache size + And embeddings extraction + Then the server is starting + Then the server is healthy + + Scenario: Embedding + When embeddings are computed for: + """ + What is the capital of Bulgaria ? + """ + Then embeddings are generated + + Scenario: OAI Embeddings compatibility + Given a model bert-bge-small + When an OAI compatible embeddings computation request for: + """ + What is the capital of Spain ? + """ + Then embeddings are generated + + Scenario: OAI Embeddings compatibility with multiple inputs + Given a model bert-bge-small + Given a prompt: + """ + In which country Paris is located ? + """ + And a prompt: + """ + Is Madrid the capital of Spain ? + """ + When an OAI compatible embeddings computation request for multiple inputs + Then embeddings are generated + + Scenario: Multi users embeddings + Given a prompt: + """ + Write a very long story about AI. + """ + And a prompt: + """ + Write another very long music lyrics. + """ + And a prompt: + """ + Write a very long poem. + """ + And a prompt: + """ + Write a very long joke. + """ + Given concurrent embedding requests + Then the server is busy + Then the server is idle + Then all embeddings are generated + + Scenario: Multi users OAI compatibility embeddings + Given a prompt: + """ + In which country Paris is located ? + """ + And a prompt: + """ + Is Madrid the capital of Spain ? + """ + And a prompt: + """ + What is the biggest US city ? + """ + And a prompt: + """ + What is the capital of Bulgaria ? + """ + And a model bert-bge-small + Given concurrent OAI embedding requests + Then the server is busy + Then the server is idle + Then all embeddings are generated + + Scenario: All embeddings should be the same + Given 10 fixed prompts + And a model bert-bge-small + Given concurrent OAI embedding requests + Then all embeddings are the same diff --git a/examples/server/tests/features/parallel.feature b/examples/server/tests/features/parallel.feature index 86cdf7282..066698c8e 100644 --- a/examples/server/tests/features/parallel.feature +++ b/examples/server/tests/features/parallel.feature @@ -9,7 +9,6 @@ Feature: Parallel And 512 as batch size And 64 KV cache size And 2 slots - And embeddings extraction And continuous batching Then the server is starting Then the server is healthy @@ -99,48 +98,3 @@ Feature: Parallel Then the server is busy Then the server is idle Then all prompts are predicted - - Scenario: Multi users embeddings - Given a prompt: - """ - Write a very long story about AI. - """ - And a prompt: - """ - Write another very long music lyrics. - """ - And a prompt: - """ - Write a very long poem. - """ - And a prompt: - """ - Write a very long joke. - """ - Given concurrent embedding requests - Then the server is busy - Then the server is idle - Then all embeddings are generated - - Scenario: Multi users OAI compatibility embeddings - Given a prompt: - """ - In which country Paris is located ? - """ - And a prompt: - """ - Is Madrid the capital of Spain ? - """ - And a prompt: - """ - What is the biggest US city ? - """ - And a prompt: - """ - What is the capital of Bulgaria ? - """ - And a model tinyllama-2 - Given concurrent OAI embedding requests - Then the server is busy - Then the server is idle - Then all embeddings are generated diff --git a/examples/server/tests/features/server.feature b/examples/server/tests/features/server.feature index 7c977bcce..f3b758c79 100644 --- a/examples/server/tests/features/server.feature +++ b/examples/server/tests/features/server.feature @@ -49,34 +49,6 @@ Feature: llama.cpp server | llama-2 | Book | What is the best book | 8 | (Mom\|what)+ | 8 | disabled | | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64 | (thanks\|happy\|bird)+ | 32 | enabled | - Scenario: Embedding - When embeddings are computed for: - """ - What is the capital of Bulgaria ? - """ - Then embeddings are generated - - Scenario: OAI Embeddings compatibility - Given a model tinyllama-2 - When an OAI compatible embeddings computation request for: - """ - What is the capital of Spain ? - """ - Then embeddings are generated - - Scenario: OAI Embeddings compatibility with multiple inputs - Given a model tinyllama-2 - Given a prompt: - """ - In which country Paris is located ? - """ - And a prompt: - """ - Is Madrid the capital of Spain ? - """ - When an OAI compatible embeddings computation request for multiple inputs - Then embeddings are generated - Scenario: Tokenize / Detokenize When tokenizing: """ diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 319527802..a0b2ffdfe 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -10,6 +10,7 @@ from contextlib import closing from re import RegexFlag import aiohttp +import numpy as np import openai from behave import step from behave.api.async_step import async_run_until_complete @@ -24,6 +25,9 @@ def step_server_config(context, server_fqdn, server_port): if 'PORT' in os.environ: context.server_port = int(os.environ['PORT']) print(f"$PORT set, overriding server port with to {context.server_port}") + if 'FQDN' in os.environ: + context.server_fqdn = os.environ['FQDN'] + print(f"$FQDN set, overriding server fqdn with to {context.server_fqdn}") context.base_url = f'http://{context.server_fqdn}:{context.server_port}' @@ -34,6 +38,7 @@ def step_server_config(context, server_fqdn, server_port): context.n_ga_w = None context.n_gpu_layer = None context.n_predict = None + context.n_prompts = 0 context.n_server_predict = None context.n_slots = None context.prompt_prefix = None @@ -202,6 +207,7 @@ def step_n_tokens_predicted(context, predicted_n): @step(u'a user prompt {user_prompt}') def step_user_prompt(context, user_prompt): context.prompts.append(user_prompt) + context.n_prompts = len(context.prompts) @step(u'a system prompt {system_prompt}') @@ -290,6 +296,12 @@ def step_prompt_passkey(context): context.prompt_passkey = context.text +@step(u'{n_prompts:d} fixed prompts') +def step_fixed_prompts(context, n_prompts): + context.prompts.extend([str(0)*(context.n_batch if context.n_batch is not None else 512) for i in range(n_prompts)]) + context.n_prompts = n_prompts + + @step(u'a "{passkey}" passkey challenge prompt with the passkey inserted every {i_pos:d} junk') def step_prompt_passkey(context, passkey, i_pos): prompt = "" @@ -301,6 +313,7 @@ def step_prompt_passkey(context, passkey, i_pos): passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m" print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n") context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix) + context.n_prompts = len(context.prompts) @step(u'an OAI compatible chat completions request with {api_error} api error') @@ -341,11 +354,13 @@ async def step_oai_chat_completions(context, api_error): @step(u'a prompt') def step_a_prompt(context): context.prompts.append(context.text) + context.n_prompts = len(context.prompts) @step(u'a prompt {prompt}') def step_a_prompt_prompt(context, prompt): context.prompts.append(prompt) + context.n_prompts = len(context.prompts) @step(u'concurrent completion requests') @@ -430,25 +445,47 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None): @step(u'embeddings are computed for') @async_run_until_complete async def step_compute_embedding(context): + context.n_prompts = 1 context.embeddings = await request_embedding(context.text, base_url=context.base_url) +@step(u'all embeddings are the same') +@async_run_until_complete +async def step_all_embeddings_are_the_same(context): + n_embedding_requests = await gather_tasks_results(context) + assert n_embedding_requests > 0 + embeddings = [] + for i in range(n_embedding_requests): + embedding = context.tasks_result.pop().pop() + embeddings.append(embedding) + assert_embeddings(embedding) + n = len(embeddings) + for i in range(n-1): + for j in range(i+1, n): + embedding1 = np.array(embeddings[i]) + embedding2 = np.array(embeddings[j]) + if context.debug: + print(f"embedding1: {embedding1[-8:]}\n") + print(f"embedding2: {embedding2[-8:]}\n") + similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2)) + msg = f"Similarity between {i} and {j}: {similarity:.10f}" + if context.debug: + print(f"{msg}\n") + assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg + @step(u'embeddings are generated') def step_assert_embeddings(context): - if len(context.prompts) == 0: - assert_embeddings(context.embeddings) - else: - assert len(context.embeddings) == len(context.prompts), (f"unexpected response:\n" - f"context.prompts={context.prompts}\n" - f"context.embeddings={context.embeddings}") - for embedding in context.embeddings: - context.prompts.pop() - assert_embeddings(embedding) + assert context.n_prompts == len(context.embeddings), (f"unexpected response:\n" + f"context.n_prompts={context.n_prompts}\n" + f"context.embeddings={context.embeddings}") + for embedding in context.embeddings: + assert_embeddings(embedding) @step(u'an OAI compatible embeddings computation request for') @async_run_until_complete async def step_oai_compute_embeddings(context): + context.n_prompts = 1 context.embeddings = await request_oai_embeddings(context.text, base_url=context.base_url, user_api_key=context.user_api_key, @@ -462,6 +499,7 @@ async def step_oai_compute_embeddings_multiple_inputs(context): base_url=context.base_url, user_api_key=context.user_api_key, model=context.model) + context.prompts.clear() @step(u'concurrent embedding requests') @@ -488,9 +526,9 @@ async def step_concurrent_oai_embedding_requests(context): @async_run_until_complete() async def all_embeddings_are_generated(context): n_embedding_requests = await gather_tasks_results(context) - assert n_embedding_requests > 0 + assert n_embedding_requests == context.n_prompts for i in range(n_embedding_requests): - assert_embeddings(context.tasks_result.pop()) + assert_embeddings(context.tasks_result.pop().pop()) @step(u'tokenizing') @@ -588,11 +626,11 @@ def step_supported_models(context, i_model, param, preposition, param_value): async def concurrent_requests(context, f_completion, *args, **kwargs): - n_prompts = len(context.prompts) + context.n_prompts = len(context.prompts) if context.debug: - print(f"starting {n_prompts} concurrent completion requests...") - assert n_prompts > 0 - for prompt_no in range(n_prompts): + print(f"starting {context.n_prompts} concurrent completion requests...") + assert context.n_prompts > 0 + for prompt_no in range(context.n_prompts): shifted_args = [context.prompts.pop(), *args] context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs))) await asyncio.sleep(0.1) @@ -765,7 +803,7 @@ async def request_embedding(content, base_url=None): }) as response: assert response.status == 200 response_json = await response.json() - return response_json['embedding'] + return [response_json['embedding']] async def request_oai_embeddings(input, @@ -775,6 +813,7 @@ async def request_oai_embeddings(input, user_api_key = user_api_key if user_api_key is not None else 'nope' if async_client: origin = 'llama.cpp' + headers=[] if user_api_key is not None: headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin} async with aiohttp.ClientSession() as session: @@ -783,14 +822,21 @@ async def request_oai_embeddings(input, "input": input, "model": model, }, - headers=headers) as response: + headers=headers, + timeout=3600) as response: assert response.status == 200, f"received status code not expected: {response.status}" assert response.headers['Access-Control-Allow-Origin'] == origin assert response.headers['Content-Type'] == "application/json; charset=utf-8" response_json = await response.json() assert response_json['model'] == model, f"invalid model received: {response_json['model']}" assert response_json['object'] == 'list' - return response_json['data'] + if isinstance(input, collections.abc.Sequence): + embeddings = [] + for an_oai_embeddings in response_json['data']: + embeddings.append(an_oai_embeddings['embedding']) + else: + embeddings = [response_json['data']['embedding']] + return embeddings else: openai.api_key = user_api_key openai.api_base = f'{base_url}/v1' @@ -804,7 +850,7 @@ async def request_oai_embeddings(input, for an_oai_embeddings in oai_embeddings.data: embeddings.append(an_oai_embeddings.embedding) else: - embeddings = oai_embeddings.data.embedding + embeddings = [oai_embeddings.data.embedding] return embeddings @@ -899,6 +945,8 @@ def assert_embeddings(embeddings): assert len(embeddings) > 0 embeddings_computed = False for emb in embeddings: + if not isinstance(emb, float): + assert False, f"Bad embeddings: {embeddings}" if emb != 0: embeddings_computed = True assert embeddings_computed, f"Embeddings: {embeddings}" diff --git a/examples/server/tests/requirements.txt b/examples/server/tests/requirements.txt index 5d4210164..2e4f42ad2 100644 --- a/examples/server/tests/requirements.txt +++ b/examples/server/tests/requirements.txt @@ -1,5 +1,6 @@ aiohttp~=3.9.3 behave~=1.2.6 huggingface_hub~=0.20.3 +numpy~=1.24.4 openai~=0.25.0 prometheus-client~=0.20.0 diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index b6e49d8b9..df0a27782 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -1,15 +1,16 @@ #pragma once -#include -#include -#include -#include -#include -#include +#include "llama.h" +#include "common.h" #include "json.hpp" -#include "../llava/clip.h" +#include +#include +#include +#include + +#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613" using json = nlohmann::json; @@ -37,83 +38,35 @@ extern bool server_log_json; #define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__) #define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) -enum server_state { - SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet - SERVER_STATE_READY, // Server is ready and model is loaded - SERVER_STATE_ERROR // An error occurred, load_model failed -}; - -enum task_type { - TASK_TYPE_COMPLETION, - TASK_TYPE_CANCEL, - TASK_TYPE_NEXT_RESPONSE, - TASK_TYPE_METRICS -}; - -struct task_server { - int id = -1; // to be filled by llama_server_queue - int target_id; - task_type type; - json data; - bool infill_mode = false; - bool embedding_mode = false; - int multitask_id = -1; -}; - -struct task_result { - int id; - int multitask_id = -1; - bool stop; - bool error; - json result_json; -}; - -struct task_multi { - int id; - std::set subtasks_remaining{}; - std::vector results{}; -}; - -// completion token output with probabilities -struct completion_token_output { - struct token_prob - { - llama_token tok; - float prob; - }; - - std::vector probs; - llama_token tok; - std::string text_to_send; -}; - -struct token_translator { - llama_context * ctx; - std::string operator()(llama_token tok) const { return llama_token_to_piece(ctx, tok); } - std::string operator()(const completion_token_output &cto) const { return (*this)(cto.tok); } -}; +template +static T json_value(const json &body, const std::string &key, const T &default_value) { + // Fallback null to default value + return body.contains(key) && !body.at(key).is_null() + ? body.value(key, default_value) + : default_value; +} static inline void server_log(const char *level, const char *function, int line, const char *message, const nlohmann::ordered_json &extra) { std::stringstream ss_tid; ss_tid << std::this_thread::get_id(); json log = nlohmann::ordered_json{ - {"tid", ss_tid.str()}, + {"tid", ss_tid.str()}, {"timestamp", time(nullptr)}, }; if (server_log_json) { - log.merge_patch( - { - {"level", level}, - {"function", function}, - {"line", line}, - {"msg", message}, - }); + log.merge_patch( { + {"level", level}, + {"function", function}, + {"line", line}, + {"msg", message}, + }); + if (!extra.empty()) { log.merge_patch(extra); } - std::cout << log.dump(-1, ' ', false, json::error_handler_t::replace) << "\n" << std::flush; + printf("%s\n", log.dump(-1, ' ', false, json::error_handler_t::replace).c_str()); } else { char buf[1024]; snprintf(buf, 1024, "%4s [%24s] %s", level, function, message); @@ -136,22 +89,13 @@ static inline void server_log(const char *level, const char *function, int line, } // -// server utils +// chat template utils // -template -static T json_value(const json &body, const std::string &key, const T &default_value) { - // Fallback null to default value - return body.contains(key) && !body.at(key).is_null() - ? body.value(key, default_value) - : default_value; -} - // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid inline bool verify_custom_template(const std::string & tmpl) { llama_chat_message chat[] = {{"user", "test"}}; - std::vector buf(1); - int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, buf.data(), buf.size()); + int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0); return res >= 0; } @@ -163,7 +107,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri std::vector chat(messages.size()); for (size_t i = 0; i < messages.size(); ++i) { - auto &curr_msg = messages[i]; + const auto & curr_msg = messages[i]; str[i*2 + 0] = json_value(curr_msg, "role", std::string("")); str[i*2 + 1] = json_value(curr_msg, "content", std::string("")); alloc_size += str[i*2 + 1].length(); @@ -183,261 +127,13 @@ inline std::string format_chat(const struct llama_model * model, const std::stri res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size()); } - std::string formatted_chat(buf.data(), res); + const std::string formatted_chat(buf.data(), res); + LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}}); return formatted_chat; } -// -// work queue utils -// - -struct llama_server_queue { - int id = 0; - std::mutex mutex_tasks; - bool running; - // queues - std::vector queue_tasks; - std::vector queue_tasks_deferred; - std::vector queue_multitasks; - std::condition_variable condition_tasks; - // callback functions - std::function callback_new_task; - std::function callback_finish_multitask; - std::function callback_run_slots; - - // Add a new task to the end of the queue - int post(task_server task) { - std::unique_lock lock(mutex_tasks); - if (task.id == -1) { - task.id = id++; - LOG_VERBOSE("new task id", {{"new_id", task.id}}); - } - queue_tasks.push_back(std::move(task)); - condition_tasks.notify_one(); - return task.id; - } - - // Add a new task, but defer until one slot is available - void defer(task_server task) { - std::unique_lock lock(mutex_tasks); - queue_tasks_deferred.push_back(std::move(task)); - } - - // Get the next id for creating anew task - int get_new_id() { - std::unique_lock lock(mutex_tasks); - int new_id = id++; - LOG_VERBOSE("new task id", {{"new_id", new_id}}); - return new_id; - } - - // Register function to process a new task - void on_new_task(std::function callback) { - callback_new_task = callback; - } - - // Register function to process a multitask when it is finished - void on_finish_multitask(std::function callback) { - callback_finish_multitask = callback; - } - - // Register the function to be called when all slots data is ready to be processed - void on_run_slots(std::function callback) { - callback_run_slots = callback; - } - - // Call when the state of one slot is changed - void notify_slot_changed() { - // move deferred tasks back to main loop - std::unique_lock lock(mutex_tasks); - for (auto & task : queue_tasks_deferred) { - queue_tasks.push_back(std::move(task)); - } - queue_tasks_deferred.clear(); - } - - // end the start_loop routine - void terminate() { - { - std::unique_lock lock(mutex_tasks); - running = false; - } - condition_tasks.notify_all(); - } - - /** - * Main loop consists of these steps: - * - Wait until a new task arrives - * - Process the task (i.e. maybe copy data into slot) - * - Check if multitask is finished - * - Run all slots - */ - void start_loop() { - running = true; - while (true) { - LOG_VERBOSE("new task may arrive", {}); - { - while (true) - { - std::unique_lock lock(mutex_tasks); - if (queue_tasks.empty()) { - lock.unlock(); - break; - } - task_server task = queue_tasks.front(); - queue_tasks.erase(queue_tasks.begin()); - lock.unlock(); - LOG_VERBOSE("callback_new_task", {{"task_id", task.id}}); - callback_new_task(task); - } - LOG_VERBOSE("update_multitasks", {}); - // check if we have any finished multitasks - auto queue_iterator = queue_multitasks.begin(); - while (queue_iterator != queue_multitasks.end()) - { - if (queue_iterator->subtasks_remaining.empty()) - { - // all subtasks done == multitask is done - task_multi current_multitask = *queue_iterator; - callback_finish_multitask(current_multitask); - // remove this multitask - queue_iterator = queue_multitasks.erase(queue_iterator); - } - else - { - ++queue_iterator; - } - } - // all tasks in the current loop is processed, slots data is now ready - LOG_VERBOSE("callback_run_slots", {}); - callback_run_slots(); - } - LOG_VERBOSE("wait for new task", {}); - // wait for new task - { - std::unique_lock lock(mutex_tasks); - if (queue_tasks.empty()) { - if (!running) { - LOG_VERBOSE("ending start_loop", {}); - return; - } - condition_tasks.wait(lock, [&]{ - return (!queue_tasks.empty() || !running); - }); - } - } - } - } - - // - // functions to manage multitasks - // - - // add a multitask by specifying the id of all subtask (subtask is a task_server) - void add_multitask(int multitask_id, std::vector& sub_ids) - { - std::lock_guard lock(mutex_tasks); - task_multi multi; - multi.id = multitask_id; - std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end())); - queue_multitasks.push_back(multi); - } - - // updatethe remaining subtasks, while appending results to multitask - void update_multitask(int multitask_id, int subtask_id, task_result& result) - { - std::lock_guard lock(mutex_tasks); - for (auto& multitask : queue_multitasks) - { - if (multitask.id == multitask_id) - { - multitask.subtasks_remaining.erase(subtask_id); - multitask.results.push_back(result); - } - } - } -}; - -struct llama_server_response { - typedef std::function callback_multitask_t; - callback_multitask_t callback_update_multitask; - // for keeping track of all tasks waiting for the result - std::set waiting_task_ids; - // the main result queue - std::vector queue_results; - std::mutex mutex_results; - std::condition_variable condition_results; - - // add the task_id to the list of tasks waiting for response - void add_waiting_task_id(int task_id) { - LOG_VERBOSE("waiting for task id", {{"task_id", task_id}}); - std::unique_lock lock(mutex_results); - waiting_task_ids.insert(task_id); - } - - // when the request is finished, we can remove task associated with it - void remove_waiting_task_id(int task_id) { - LOG_VERBOSE("remove waiting for task id", {{"task_id", task_id}}); - std::unique_lock lock(mutex_results); - waiting_task_ids.erase(task_id); - } - - // This function blocks the thread until there is a response for this task_id - task_result recv(int task_id) { - while (true) - { - std::unique_lock lock(mutex_results); - condition_results.wait(lock, [&]{ - return !queue_results.empty(); - }); - - for (int i = 0; i < (int) queue_results.size(); i++) - { - if (queue_results[i].id == task_id) - { - assert(queue_results[i].multitask_id == -1); - task_result res = queue_results[i]; - queue_results.erase(queue_results.begin() + i); - return res; - } - } - } - - // should never reach here - } - - // Register the function to update multitask - void on_multitask_update(callback_multitask_t callback) { - callback_update_multitask = callback; - } - - // Send a new result to a waiting task_id - void send(task_result result) { - std::unique_lock lock(mutex_results); - LOG_VERBOSE("send new result", {{"task_id", result.id}}); - for (auto& task_id : waiting_task_ids) { - // LOG_TEE("waiting task id %i \n", task_id); - // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result - if (result.multitask_id == task_id) - { - LOG_VERBOSE("callback_update_multitask", {{"task_id", task_id}}); - callback_update_multitask(task_id, result.id, result); - continue; - } - - if (result.id == task_id) - { - LOG_VERBOSE("queue_results.push_back", {{"task_id", task_id}}); - queue_results.push_back(result); - condition_results.notify_all(); - return; - } - } - } -}; - // // base64 utils (TODO: move to common in the future) // @@ -447,13 +143,11 @@ static const std::string base64_chars = "abcdefghijklmnopqrstuvwxyz" "0123456789+/"; -static inline bool is_base64(uint8_t c) -{ +static inline bool is_base64(uint8_t c) { return (isalnum(c) || (c == '+') || (c == '/')); } -static inline std::vector base64_decode(const std::string & encoded_string) -{ +static inline std::vector base64_decode(const std::string & encoded_string) { int i = 0; int j = 0; int in_ = 0; @@ -465,13 +159,10 @@ static inline std::vector base64_decode(const std::string & encoded_str std::vector ret; - while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) - { + while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { char_array_4[i++] = encoded_string[in_]; in_++; - if (i == 4) - { - for (i = 0; i <4; i++) - { + if (i == 4) { + for (i = 0; i < 4; i++) { char_array_4[i] = base64_chars.find(char_array_4[i]); } @@ -479,23 +170,20 @@ static inline std::vector base64_decode(const std::string & encoded_str char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - for (i = 0; (i < 3); i++) - { + for (i = 0; (i < 3); i++) { ret.push_back(char_array_3[i]); } + i = 0; } } - if (i) - { - for (j = i; j <4; j++) - { + if (i) { + for (j = i; j < 4; j++) { char_array_4[j] = 0; } - for (j = 0; j <4; j++) - { + for (j = 0; j < 4; j++) { char_array_4[j] = base64_chars.find(char_array_4[j]); } @@ -503,8 +191,7 @@ static inline std::vector base64_decode(const std::string & encoded_str char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - for (j = 0; (j < i - 1); j++) - { + for (j = 0; j < i - 1; j++) { ret.push_back(char_array_3[j]); } } @@ -516,8 +203,7 @@ static inline std::vector base64_decode(const std::string & encoded_str // random string / id // -static std::string random_string() -{ +static std::string random_string() { static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); std::random_device rd; @@ -532,10 +218,10 @@ static std::string random_string() return result; } -static std::string gen_chatcmplid() -{ +static std::string gen_chatcmplid() { std::stringstream chatcmplid; chatcmplid << "chatcmpl-" << random_string(); + return chatcmplid.str(); } @@ -543,91 +229,316 @@ static std::string gen_chatcmplid() // other common utils // -static size_t common_part(const std::vector &a, const std::vector &b) -{ +static size_t common_part(const std::vector & a, const std::vector & b) { size_t i; - for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) - { - } + for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} + return i; } -static bool ends_with(const std::string &str, const std::string &suffix) -{ - return str.size() >= suffix.size() && - 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); +static bool ends_with(const std::string & str, const std::string & suffix) { + return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); } -static size_t find_partial_stop_string(const std::string &stop, - const std::string &text) -{ - if (!text.empty() && !stop.empty()) - { +static size_t find_partial_stop_string(const std::string &stop, const std::string &text) { + if (!text.empty() && !stop.empty()) { const char text_last_char = text.back(); - for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) - { - if (stop[char_index] == text_last_char) - { + for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) { + if (stop[char_index] == text_last_char) { const std::string current_partial = stop.substr(0, char_index + 1); - if (ends_with(text, current_partial)) - { + if (ends_with(text, current_partial)) { return text.size() - char_index - 1; } } } } + return std::string::npos; } // TODO: reuse llama_detokenize template -static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end) -{ +static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { std::string ret; - for (; begin != end; ++begin) - { + for (; begin != end; ++begin) { ret += llama_token_to_piece(ctx, *begin); } + return ret; } // format incomplete utf-8 multibyte character for output -static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token) -{ +static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) { std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token); + // if the size is 1 and first bit is 1, meaning it's a partial character // (size > 1 meaning it's already a known token) - if (out.size() == 1 && (out[0] & 0x80) == 0x80) - { + if (out.size() == 1 && (out[0] & 0x80) == 0x80) { std::stringstream ss; ss << std::hex << (out[0] & 0xff); std::string res(ss.str()); out = "byte: \\x" + res; } + return out; } +struct completion_token_output { + llama_token tok; + std::string text_to_send; + + struct token_prob { + llama_token tok; + float prob; + }; + + std::vector probs; +}; + // convert a vector of completion_token_output to json -static json probs_vector_to_json(const llama_context *ctx, const std::vector &probs) -{ +static json probs_vector_to_json(const llama_context * ctx, const std::vector & probs) { json out = json::array(); - for (const auto &prob : probs) - { + + for (const auto & prob : probs) { json probs_for_token = json::array(); - for (const auto &p : prob.probs) - { - std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok); - probs_for_token.push_back(json - { + + for (const auto & p : prob.probs) { + const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok); + probs_for_token.push_back(json { {"tok_str", tok_str}, {"prob", p.prob}, }); } - std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok); - out.push_back(json{ + + const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok); + out.push_back(json { {"content", tok_str}, {"probs", probs_for_token}, }); } + return out; } + +// +// OAI utils +// + +static json oaicompat_completion_params_parse( + const struct llama_model * model, + const json & body, /* openai api json semantics */ + const std::string & chat_template) { + json llama_params; + + llama_params["__oaicompat"] = true; + + // Map OpenAI parameters to llama.cpp parameters + // + // For parameters that are defined by the OpenAI documentation (e.g. + // temperature), we explicitly specify OpenAI's intended default; we + // need to do that because sometimes OpenAI disagrees with llama.cpp + // + // https://platform.openai.com/docs/api-reference/chat/create + llama_sampling_params default_sparams; + llama_params["model"] = json_value(body, "model", std::string("unknown")); + llama_params["prompt"] = format_chat(model, chat_template, body["messages"]); + llama_params["cache_prompt"] = json_value(body, "cache_prompt", false); + llama_params["temperature"] = json_value(body, "temperature", 0.0); + llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k); + llama_params["top_p"] = json_value(body, "top_p", 1.0); + llama_params["n_predict"] = json_value(body, "max_tokens", -1); + llama_params["logit_bias"] = json_value(body, "logit_bias", json::object()); + llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0); + llama_params["presence_penalty"] = json_value(body, "presence_penalty", 0.0); + llama_params["seed"] = json_value(body, "seed", LLAMA_DEFAULT_SEED); + llama_params["stream"] = json_value(body, "stream", false); + llama_params["mirostat"] = json_value(body, "mirostat", default_sparams.mirostat); + llama_params["mirostat_tau"] = json_value(body, "mirostat_tau", default_sparams.mirostat_tau); + llama_params["mirostat_eta"] = json_value(body, "mirostat_eta", default_sparams.mirostat_eta); + llama_params["penalize_nl"] = json_value(body, "penalize_nl", default_sparams.penalize_nl); + llama_params["typical_p"] = json_value(body, "typical_p", default_sparams.typical_p); + llama_params["repeat_last_n"] = json_value(body, "repeat_last_n", default_sparams.penalty_last_n); + llama_params["ignore_eos"] = json_value(body, "ignore_eos", false); + llama_params["tfs_z"] = json_value(body, "tfs_z", default_sparams.tfs_z); + + if (body.count("grammar") != 0) { + llama_params["grammar"] = json_value(body, "grammar", json::object()); + } + + // Handle 'stop' field + if (body.contains("stop") && body["stop"].is_string()) { + llama_params["stop"] = json::array({body["stop"].get()}); + } else { + llama_params["stop"] = json_value(body, "stop", json::array()); + } + + // Ensure there is ChatML-specific end sequence among stop words + llama_params["stop"].push_back("<|im_end|>"); + + return llama_params; +} + +static json format_final_response_oaicompat(const json & request, json result, bool streaming = false) { + bool stopped_word = result.count("stopped_word") != 0; + bool stopped_eos = json_value(result, "stopped_eos", false); + int num_tokens_predicted = json_value(result, "tokens_predicted", 0); + int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); + std::string content = json_value(result, "content", std::string("")); + + std::string finish_reason = "length"; + if (stopped_word || stopped_eos) { + finish_reason = "stop"; + } + + json choices = + streaming ? json::array({json{{"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()}}}) + : json::array({json{{"finish_reason", finish_reason}, + {"index", 0}, + {"message", json{{"content", content}, + {"role", "assistant"}}}}}); + + std::time_t t = std::time(0); + + json res = json { + {"choices", choices}, + {"created", t}, + {"model", + json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", streaming ? "chat.completion.chunk" : "chat.completion"}, + {"usage", json { + {"completion_tokens", num_tokens_predicted}, + {"prompt_tokens", num_prompt_tokens}, + {"total_tokens", num_tokens_predicted + num_prompt_tokens} + }}, + {"id", gen_chatcmplid()} + }; + + if (server_verbose) { + res["__verbose"] = result; + } + + if (result.contains("completion_probabilities")) { + res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array()); + } + + return res; +} + +// return value is vector as there is one case where we might need to generate two responses +static std::vector format_partial_response_oaicompat(json result) { + if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) { + return std::vector({result}); + } + + bool first = json_value(result, "oaicompat_token_ctr", 0) == 0; + std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); + + bool stopped_word = json_value(result, "stopped_word", false); + bool stopped_eos = json_value(result, "stopped_eos", false); + bool stopped_limit = json_value(result, "stopped_limit", false); + std::string content = json_value(result, "content", std::string("")); + + std::string finish_reason; + if (stopped_word || stopped_eos) { + finish_reason = "stop"; + } + if (stopped_limit) { + finish_reason = "length"; + } + + std::time_t t = std::time(0); + + json choices; + + if (!finish_reason.empty()) { + choices = json::array({json{{"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()}}}); + } else { + if (first) { + if (content.empty()) { + choices = json::array({json{{"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{{"role", "assistant"}}}}}); + } else { + // We have to send this as two updates to conform to openai behavior + json initial_ret = json{{"choices", json::array({json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{ + {"role", "assistant"} + }}}})}, + {"created", t}, + {"id", gen_chatcmplid()}, + {"model", modelname}, + {"object", "chat.completion.chunk"}}; + + json second_ret = json{ + {"choices", json::array({json{{"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{ + {"content", content}}} + }})}, + {"created", t}, + {"id", gen_chatcmplid()}, + {"model", modelname}, + {"object", "chat.completion.chunk"}}; + + return std::vector({initial_ret, second_ret}); + } + } else { + // Some idiosyncrasy in task processing logic makes several trailing calls + // with empty content, we ignore these at the calee site. + if (content.empty()) { + return std::vector({json::object()}); + } + + choices = json::array({json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", + json{ + {"content", content}, + }}, + }}); + } + } + + json ret = json { + {"choices", choices}, + {"created", t}, + {"id", gen_chatcmplid()}, + {"model", modelname}, + {"object", "chat.completion.chunk"} + }; + + return std::vector({ret}); +} + +static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) { + json res = json { + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json { + {"prompt_tokens", 0}, + {"total_tokens", 0} + }}, + {"data", embeddings} + }; + + return res; +} + +static json format_tokenizer_response(const std::vector & tokens) { + return json { + {"tokens", tokens} + }; +} + +static json format_detokenized_response(const std::string & content) { + return json { + {"content", content} + }; +} diff --git a/llama.cpp b/llama.cpp index b27aa2728..478099648 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13541,18 +13541,22 @@ LLAMA_API int32_t llama_chat_apply_template( curr_tmpl = std::string(model_template.data(), model_template.size()); } } + // format the chat to string std::vector chat_vec; chat_vec.resize(n_msg); for (size_t i = 0; i < n_msg; i++) { chat_vec[i] = &chat[i]; } + std::string formatted_chat; int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass); if (res < 0) { return res; } - strncpy(buf, formatted_chat.c_str(), length); + if (buf && length > 0) { + strncpy(buf, formatted_chat.c_str(), length); + } return res; }