mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-25 19:08:44 +01:00
llama : save and restore kv cache for single seq id (#6341)
* llama : save and restore kv cache for single seq id * remove trailing whitespace * respond error in case there's no space in the kv cache * add kv seq save restore to test case * add --slot-save-path arg to enable save restore and restrict save location * Returning 0 for some cases, instead of asserting. * cleanup error cases * rename sequence state functions * rename state get set functions * add previous function names back in with DEPRECATED notice * update doc * adjust endpoints to preferred style * fix restoring zero cell count * handle seq rm return value * unused param * keep in the size check * fix return types * add server test case for slot save restore * cleanup * add cake * cleanup style * add special * removing a whole sequence never fails * move sequence state file functionality from server to llama to match session api and add version tags * catch exceptions on save as well * error log messages * check types for stricter restore * update server doc * readme : update API changes date * strict filename validation * move include, reject bom as well * also reject empty filename * reject whitespace and trailing dot --------- Co-authored-by: Martin Evans <martindevans@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
87fb5b4234
commit
beea6e1b16
@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
|
|||||||
|
|
||||||
### Recent API changes
|
### Recent API changes
|
||||||
|
|
||||||
|
- [2024 Apr 4] State and session file functions reorganized under `llama_state_*` https://github.com/ggerganov/llama.cpp/pull/6341
|
||||||
- [2024 Mar 26] Logits and embeddings API updated for compactness https://github.com/ggerganov/llama.cpp/pull/6122
|
- [2024 Mar 26] Logits and embeddings API updated for compactness https://github.com/ggerganov/llama.cpp/pull/6122
|
||||||
- [2024 Mar 13] Add `llama_synchronize()` + `llama_context_params.n_ubatch` https://github.com/ggerganov/llama.cpp/pull/6017
|
- [2024 Mar 13] Add `llama_synchronize()` + `llama_context_params.n_ubatch` https://github.com/ggerganov/llama.cpp/pull/6017
|
||||||
- [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_seq_max()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328
|
- [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_seq_max()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <cinttypes>
|
#include <cinttypes>
|
||||||
|
#include <codecvt>
|
||||||
|
|
||||||
#if defined(__APPLE__) && defined(__MACH__)
|
#if defined(__APPLE__) && defined(__MACH__)
|
||||||
#include <sys/types.h>
|
#include <sys/types.h>
|
||||||
@ -27,7 +28,6 @@
|
|||||||
#ifndef NOMINMAX
|
#ifndef NOMINMAX
|
||||||
# define NOMINMAX
|
# define NOMINMAX
|
||||||
#endif
|
#endif
|
||||||
#include <codecvt>
|
|
||||||
#include <locale>
|
#include <locale>
|
||||||
#include <windows.h>
|
#include <windows.h>
|
||||||
#include <fcntl.h>
|
#include <fcntl.h>
|
||||||
@ -1500,6 +1500,77 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
|
|||||||
GGML_UNREACHABLE();
|
GGML_UNREACHABLE();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate if a filename is safe to use
|
||||||
|
// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
|
||||||
|
bool validate_file_name(const std::string & filename) {
|
||||||
|
if (!filename.length()) {
|
||||||
|
// Empty filename invalid
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (filename.length() > 255) {
|
||||||
|
// Limit at common largest possible filename on Linux filesystems
|
||||||
|
// to avoid unnecessary further validation
|
||||||
|
// (On systems with smaller limits it will be caught by the OS)
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::u32string filename_utf32;
|
||||||
|
try {
|
||||||
|
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
|
||||||
|
filename_utf32 = converter.from_bytes(filename);
|
||||||
|
|
||||||
|
// If the reverse conversion mismatches, it means overlong UTF-8 sequences were used,
|
||||||
|
// or invalid encodings were encountered. Reject such attempts
|
||||||
|
std::string filename_reencoded = converter.to_bytes(filename_utf32);
|
||||||
|
if (filename_reencoded != filename) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} catch (const std::exception &) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for forbidden codepoints:
|
||||||
|
// - Control characters
|
||||||
|
// - Unicode equivalents of illegal characters
|
||||||
|
// - UTF-16 surrogate pairs
|
||||||
|
// - UTF-8 replacement character
|
||||||
|
// - Byte order mark (BOM)
|
||||||
|
// - Illegal characters: / \ : * ? " < > |
|
||||||
|
for (char32_t c : filename_utf32) {
|
||||||
|
if (c <= 0x1F // Control characters (C0)
|
||||||
|
|| c == 0x7F // Control characters (DEL)
|
||||||
|
|| (c >= 0x80 && c <= 0x9F) // Control characters (C1)
|
||||||
|
|| c == 0xFF0E // Fullwidth Full Stop (period equivalent)
|
||||||
|
|| c == 0x2215 // Division Slash (forward slash equivalent)
|
||||||
|
|| c == 0x2216 // Set Minus (backslash equivalent)
|
||||||
|
|| (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
|
||||||
|
|| c == 0xFFFD // Replacement Character (UTF-8)
|
||||||
|
|| c == 0xFEFF // Byte Order Mark (BOM)
|
||||||
|
|| c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters
|
||||||
|
|| c == '?' || c == '"' || c == '<' || c == '>' || c == '|') {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
|
||||||
|
// Unicode and other whitespace is not affected, only 0x20 space
|
||||||
|
if (filename.front() == ' ' || filename.back() == ' ' || filename.back() == '.') {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reject any ".." (currently stricter than necessary, it should be fine to just check for == ".." instead)
|
||||||
|
if (filename.find("..") != std::string::npos) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reject "."
|
||||||
|
if (filename == ".") {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// String utils
|
// String utils
|
||||||
//
|
//
|
||||||
|
@ -179,6 +179,8 @@ std::string gpt_random_prompt(std::mt19937 & rng);
|
|||||||
|
|
||||||
void process_escapes(std::string& input);
|
void process_escapes(std::string& input);
|
||||||
|
|
||||||
|
bool validate_file_name(const std::string & filename);
|
||||||
|
|
||||||
//
|
//
|
||||||
// String utils
|
// String utils
|
||||||
//
|
//
|
||||||
|
@ -235,7 +235,7 @@ int main(int argc, char ** argv) {
|
|||||||
// The file exists and is not empty
|
// The file exists and is not empty
|
||||||
session_tokens.resize(n_ctx);
|
session_tokens.resize(n_ctx);
|
||||||
size_t n_token_count_out = 0;
|
size_t n_token_count_out = 0;
|
||||||
if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
|
if (!llama_state_load_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
|
||||||
LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
|
LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
@ -693,7 +693,7 @@ int main(int argc, char ** argv) {
|
|||||||
// optionally save the session on first sample (for faster prompt loading next time)
|
// optionally save the session on first sample (for faster prompt loading next time)
|
||||||
if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
|
if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
|
||||||
need_to_save_session = false;
|
need_to_save_session = false;
|
||||||
llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
|
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
|
||||||
|
|
||||||
LOG("saved session to %s\n", path_session.c_str());
|
LOG("saved session to %s\n", path_session.c_str());
|
||||||
}
|
}
|
||||||
@ -935,7 +935,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) {
|
if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) {
|
||||||
LOG_TEE("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
|
LOG_TEE("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
|
||||||
llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
|
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_print_timings(ctx);
|
llama_print_timings(ctx);
|
||||||
|
@ -24,6 +24,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
std::string result0;
|
std::string result0;
|
||||||
std::string result1;
|
std::string result1;
|
||||||
|
std::string result2;
|
||||||
|
|
||||||
// init
|
// init
|
||||||
llama_model * model;
|
llama_model * model;
|
||||||
@ -44,8 +45,8 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
// save state (rng, logits, embedding and kv_cache) to file
|
// save state (rng, logits, embedding and kv_cache) to file
|
||||||
{
|
{
|
||||||
std::vector<uint8_t> state_mem(llama_get_state_size(ctx));
|
std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
|
||||||
const size_t written = llama_copy_state_data(ctx, state_mem.data());
|
const size_t written = llama_state_get_data(ctx, state_mem.data());
|
||||||
|
|
||||||
FILE *fp_write = fopen("dump_state.bin", "wb");
|
FILE *fp_write = fopen("dump_state.bin", "wb");
|
||||||
fwrite(state_mem.data(), 1, written, fp_write);
|
fwrite(state_mem.data(), 1, written, fp_write);
|
||||||
@ -97,13 +98,13 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
// load state (rng, logits, embedding and kv_cache) from file
|
// load state (rng, logits, embedding and kv_cache) from file
|
||||||
{
|
{
|
||||||
std::vector<uint8_t> state_mem(llama_get_state_size(ctx2));
|
std::vector<uint8_t> state_mem(llama_state_get_size(ctx2));
|
||||||
|
|
||||||
FILE * fp_read = fopen("dump_state.bin", "rb");
|
FILE * fp_read = fopen("dump_state.bin", "rb");
|
||||||
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
||||||
fclose(fp_read);
|
fclose(fp_read);
|
||||||
|
|
||||||
if (read != llama_set_state_data(ctx2, state_mem.data())) {
|
if (read != llama_state_set_data(ctx2, state_mem.data())) {
|
||||||
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
||||||
llama_free(ctx2);
|
llama_free(ctx2);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
@ -141,16 +142,104 @@ int main(int argc, char ** argv) {
|
|||||||
n_past += 1;
|
n_past += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
printf("\n");
|
printf("\n\n");
|
||||||
|
|
||||||
llama_free(ctx2);
|
llama_free(ctx2);
|
||||||
llama_free_model(model);
|
|
||||||
|
|
||||||
if (result0 != result1) {
|
if (result0 != result1) {
|
||||||
fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__);
|
fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// make new context
|
||||||
|
auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
|
||||||
|
|
||||||
|
printf("\nsingle seq run: %s", params.prompt.c_str());
|
||||||
|
|
||||||
|
// load state (rng, logits, embedding and kv_cache) from file
|
||||||
|
{
|
||||||
|
std::vector<uint8_t> state_mem(llama_state_get_size(ctx3));
|
||||||
|
|
||||||
|
FILE * fp_read = fopen("dump_state.bin", "rb");
|
||||||
|
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
||||||
|
fclose(fp_read);
|
||||||
|
|
||||||
|
if (read != llama_state_set_data(ctx3, state_mem.data())) {
|
||||||
|
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
||||||
|
llama_free(ctx3);
|
||||||
|
llama_free_model(model);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
// restore state (last tokens)
|
||||||
|
n_past = n_past_saved;
|
||||||
|
|
||||||
|
// save seq 0 and load into seq 1
|
||||||
|
{
|
||||||
|
// save kv of seq 0
|
||||||
|
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
|
||||||
|
const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
|
||||||
|
if (ncopy != seq_store.size()) {
|
||||||
|
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
|
||||||
|
llama_free(ctx3);
|
||||||
|
llama_free_model(model);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
|
||||||
|
|
||||||
|
// erase whole kv
|
||||||
|
llama_kv_cache_clear(ctx3);
|
||||||
|
fprintf(stderr, "%s : kv cache cleared\n", __func__);
|
||||||
|
|
||||||
|
// restore kv into seq 1
|
||||||
|
const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
|
||||||
|
if (nset != seq_store.size()) {
|
||||||
|
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
|
||||||
|
llama_free(ctx3);
|
||||||
|
llama_free_model(model);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset);
|
||||||
|
}
|
||||||
|
|
||||||
|
// third run with seq 1 instead of 0
|
||||||
|
for (auto i = 0; i < params.n_predict; i++) {
|
||||||
|
auto * logits = llama_get_logits(ctx3);
|
||||||
|
auto n_vocab = llama_n_vocab(model);
|
||||||
|
std::vector<llama_token_data> candidates;
|
||||||
|
candidates.reserve(n_vocab);
|
||||||
|
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||||
|
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||||
|
}
|
||||||
|
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||||
|
auto next_token = llama_sample_token(ctx3, &candidates_p);
|
||||||
|
auto next_token_str = llama_token_to_piece(ctx3, next_token);
|
||||||
|
|
||||||
|
printf("%s", next_token_str.c_str());
|
||||||
|
result2 += next_token_str;
|
||||||
|
|
||||||
|
if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) {
|
||||||
|
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||||
|
llama_free(ctx3);
|
||||||
|
llama_free_model(model);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
n_past += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("\n");
|
||||||
|
|
||||||
|
llama_free(ctx3);
|
||||||
|
llama_free_model(model);
|
||||||
|
|
||||||
|
if (result0 != result2) {
|
||||||
|
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
fprintf(stderr, "\n%s : success\n", __func__);
|
fprintf(stderr, "\n%s : success\n", __func__);
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
|
@ -57,6 +57,7 @@ page cache before using this. See https://github.com/ggerganov/llama.cpp/issues/
|
|||||||
- `-n N, --n-predict N`: Set the maximum tokens to predict. Default: `-1`
|
- `-n N, --n-predict N`: Set the maximum tokens to predict. Default: `-1`
|
||||||
- `--slots-endpoint-disable`: To disable slots state monitoring endpoint. Slots state may contain user data, prompts included.
|
- `--slots-endpoint-disable`: To disable slots state monitoring endpoint. Slots state may contain user data, prompts included.
|
||||||
- `--metrics`: enable prometheus `/metrics` compatible endpoint. Default: disabled
|
- `--metrics`: enable prometheus `/metrics` compatible endpoint. Default: disabled
|
||||||
|
- `--slot-save-path PATH`: Specifies the path where the state of slots (the prompt cache) can be stored. If not provided, the slot management endpoints will be disabled.
|
||||||
- `--chat-template JINJA_TEMPLATE`: Set custom jinja chat template. This parameter accepts a string, not a file name. Default: template taken from model's metadata. We only support [some pre-defined templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template)
|
- `--chat-template JINJA_TEMPLATE`: Set custom jinja chat template. This parameter accepts a string, not a file name. Default: template taken from model's metadata. We only support [some pre-defined templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template)
|
||||||
- `--log-disable`: Output logs to stdout only, not to `llama.log`. Default: enabled
|
- `--log-disable`: Output logs to stdout only, not to `llama.log`. Default: enabled
|
||||||
- `--log-format FORMAT`: Define the log output to FORMAT: json or text Default: `json`
|
- `--log-format FORMAT`: Define the log output to FORMAT: json or text Default: `json`
|
||||||
@ -517,6 +518,57 @@ Available metrics:
|
|||||||
- `llamacpp:requests_processing`: Number of requests processing.
|
- `llamacpp:requests_processing`: Number of requests processing.
|
||||||
- `llamacpp:requests_deferred`: Number of requests deferred.
|
- `llamacpp:requests_deferred`: Number of requests deferred.
|
||||||
|
|
||||||
|
- **POST** `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file.
|
||||||
|
|
||||||
|
*Options:*
|
||||||
|
|
||||||
|
`filename`: Name of the file to save the slot's prompt cache. The file will be saved in the directory specified by the `--slot-save-path` server parameter.
|
||||||
|
|
||||||
|
### Result JSON
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id_slot": 0,
|
||||||
|
"filename": "slot_save_file.bin",
|
||||||
|
"n_saved": 1745,
|
||||||
|
"n_written": 14309796,
|
||||||
|
"timings": {
|
||||||
|
"save_ms": 49.865
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
- **POST** `/slots/{id_slot}?action=restore`: Restore the prompt cache of the specified slot from a file.
|
||||||
|
|
||||||
|
*Options:*
|
||||||
|
|
||||||
|
`filename`: Name of the file to restore the slot's prompt cache from. The file should be located in the directory specified by the `--slot-save-path` server parameter.
|
||||||
|
|
||||||
|
### Result JSON
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id_slot": 0,
|
||||||
|
"filename": "slot_save_file.bin",
|
||||||
|
"n_restored": 1745,
|
||||||
|
"n_read": 14309796,
|
||||||
|
"timings": {
|
||||||
|
"restore_ms": 42.937
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
- **POST** `/slots/{id_slot}?action=erase`: Erase the prompt cache of the specified slot.
|
||||||
|
|
||||||
|
### Result JSON
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id_slot": 0,
|
||||||
|
"n_erased": 1745
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
## More examples
|
## More examples
|
||||||
|
|
||||||
### Change system prompt on runtime
|
### Change system prompt on runtime
|
||||||
|
@ -61,7 +61,10 @@ enum server_task_type {
|
|||||||
SERVER_TASK_TYPE_COMPLETION,
|
SERVER_TASK_TYPE_COMPLETION,
|
||||||
SERVER_TASK_TYPE_CANCEL,
|
SERVER_TASK_TYPE_CANCEL,
|
||||||
SERVER_TASK_TYPE_NEXT_RESPONSE,
|
SERVER_TASK_TYPE_NEXT_RESPONSE,
|
||||||
SERVER_TASK_TYPE_METRICS
|
SERVER_TASK_TYPE_METRICS,
|
||||||
|
SERVER_TASK_TYPE_SLOT_SAVE,
|
||||||
|
SERVER_TASK_TYPE_SLOT_RESTORE,
|
||||||
|
SERVER_TASK_TYPE_SLOT_ERASE,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_task {
|
struct server_task {
|
||||||
@ -128,6 +131,7 @@ struct server_params {
|
|||||||
|
|
||||||
bool slots_endpoint = true;
|
bool slots_endpoint = true;
|
||||||
bool metrics_endpoint = false;
|
bool metrics_endpoint = false;
|
||||||
|
std::string slot_save_path;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_slot {
|
struct server_slot {
|
||||||
@ -1612,6 +1616,107 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
queue_results.send(res);
|
queue_results.send(res);
|
||||||
} break;
|
} break;
|
||||||
|
case SERVER_TASK_TYPE_SLOT_SAVE:
|
||||||
|
{
|
||||||
|
int id_slot = task.data["id_slot"];
|
||||||
|
server_slot * slot = get_slot(id_slot);
|
||||||
|
if (slot == nullptr) {
|
||||||
|
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t token_count = slot->cache_tokens.size();
|
||||||
|
const int64_t t_start = ggml_time_us();
|
||||||
|
|
||||||
|
std::string filename = task.data["filename"];
|
||||||
|
std::string filepath = task.data["filepath"];
|
||||||
|
|
||||||
|
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);
|
||||||
|
|
||||||
|
const int64_t t_end = ggml_time_us();
|
||||||
|
const double t_save_ms = (t_end - t_start) / 1000.0;
|
||||||
|
|
||||||
|
server_task_result result;
|
||||||
|
result.id = task.id;
|
||||||
|
result.stop = true;
|
||||||
|
result.error = false;
|
||||||
|
result.data = json {
|
||||||
|
{ "id_slot", id_slot },
|
||||||
|
{ "filename", filename },
|
||||||
|
{ "n_saved", token_count }, // tokens saved
|
||||||
|
{ "n_written", nwrite }, // bytes written
|
||||||
|
{ "timings", {
|
||||||
|
{ "save_ms", t_save_ms }
|
||||||
|
} }
|
||||||
|
};
|
||||||
|
queue_results.send(result);
|
||||||
|
} break;
|
||||||
|
case SERVER_TASK_TYPE_SLOT_RESTORE:
|
||||||
|
{
|
||||||
|
int id_slot = task.data["id_slot"];
|
||||||
|
server_slot * slot = get_slot(id_slot);
|
||||||
|
if (slot == nullptr) {
|
||||||
|
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t t_start = ggml_time_us();
|
||||||
|
|
||||||
|
std::string filename = task.data["filename"];
|
||||||
|
std::string filepath = task.data["filepath"];
|
||||||
|
|
||||||
|
slot->cache_tokens.resize(slot->n_ctx);
|
||||||
|
size_t token_count = 0;
|
||||||
|
size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
|
||||||
|
if (nread == 0) {
|
||||||
|
slot->cache_tokens.resize(0);
|
||||||
|
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
slot->cache_tokens.resize(token_count);
|
||||||
|
|
||||||
|
const int64_t t_end = ggml_time_us();
|
||||||
|
const double t_restore_ms = (t_end - t_start) / 1000.0;
|
||||||
|
|
||||||
|
server_task_result result;
|
||||||
|
result.id = task.id;
|
||||||
|
result.stop = true;
|
||||||
|
result.error = false;
|
||||||
|
result.data = json {
|
||||||
|
{ "id_slot", id_slot },
|
||||||
|
{ "filename", filename },
|
||||||
|
{ "n_restored", token_count }, // tokens restored
|
||||||
|
{ "n_read", nread }, // bytes read
|
||||||
|
{ "timings", {
|
||||||
|
{ "restore_ms", t_restore_ms }
|
||||||
|
} }
|
||||||
|
};
|
||||||
|
queue_results.send(result);
|
||||||
|
} break;
|
||||||
|
case SERVER_TASK_TYPE_SLOT_ERASE:
|
||||||
|
{
|
||||||
|
int id_slot = task.data["id_slot"];
|
||||||
|
server_slot * slot = get_slot(id_slot);
|
||||||
|
if (slot == nullptr) {
|
||||||
|
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Erase token cache
|
||||||
|
const size_t n_erased = slot->cache_tokens.size();
|
||||||
|
llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
|
||||||
|
slot->cache_tokens.clear();
|
||||||
|
|
||||||
|
server_task_result result;
|
||||||
|
result.id = task.id;
|
||||||
|
result.stop = true;
|
||||||
|
result.error = false;
|
||||||
|
result.data = json {
|
||||||
|
{ "id_slot", id_slot },
|
||||||
|
{ "n_erased", n_erased }
|
||||||
|
};
|
||||||
|
queue_results.send(result);
|
||||||
|
} break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2249,6 +2354,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
|
|||||||
printf(" --log-disable disables logging to a file.\n");
|
printf(" --log-disable disables logging to a file.\n");
|
||||||
printf(" --slots-endpoint-disable disables slots monitoring endpoint.\n");
|
printf(" --slots-endpoint-disable disables slots monitoring endpoint.\n");
|
||||||
printf(" --metrics enable prometheus compatible metrics endpoint (default: %s).\n", sparams.metrics_endpoint ? "enabled" : "disabled");
|
printf(" --metrics enable prometheus compatible metrics endpoint (default: %s).\n", sparams.metrics_endpoint ? "enabled" : "disabled");
|
||||||
|
printf(" --slot-save-path PATH path to save slot kv cache (default: disabled)\n");
|
||||||
printf("\n");
|
printf("\n");
|
||||||
printf(" -n, --n-predict maximum tokens to predict (default: %d)\n", params.n_predict);
|
printf(" -n, --n-predict maximum tokens to predict (default: %d)\n", params.n_predict);
|
||||||
printf(" --override-kv KEY=TYPE:VALUE\n");
|
printf(" --override-kv KEY=TYPE:VALUE\n");
|
||||||
@ -2657,6 +2763,16 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
|
|||||||
sparams.slots_endpoint = false;
|
sparams.slots_endpoint = false;
|
||||||
} else if (arg == "--metrics") {
|
} else if (arg == "--metrics") {
|
||||||
sparams.metrics_endpoint = true;
|
sparams.metrics_endpoint = true;
|
||||||
|
} else if (arg == "--slot-save-path") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
sparams.slot_save_path = argv[i];
|
||||||
|
// if doesn't end with DIRECTORY_SEPARATOR, add it
|
||||||
|
if (!sparams.slot_save_path.empty() && sparams.slot_save_path[sparams.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) {
|
||||||
|
sparams.slot_save_path += DIRECTORY_SEPARATOR;
|
||||||
|
}
|
||||||
} else if (arg == "--chat-template") {
|
} else if (arg == "--chat-template") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
@ -3159,6 +3275,112 @@ int main(int argc, char ** argv) {
|
|||||||
res.status = 200; // HTTP OK
|
res.status = 200; // HTTP OK
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const auto handle_slots_save = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) {
|
||||||
|
json request_data = json::parse(req.body);
|
||||||
|
std::string filename = request_data["filename"];
|
||||||
|
if (!validate_file_name(filename)) {
|
||||||
|
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
std::string filepath = sparams.slot_save_path + filename;
|
||||||
|
|
||||||
|
server_task task;
|
||||||
|
task.type = SERVER_TASK_TYPE_SLOT_SAVE;
|
||||||
|
task.data = {
|
||||||
|
{ "id_slot", id_slot },
|
||||||
|
{ "filename", filename },
|
||||||
|
{ "filepath", filepath }
|
||||||
|
};
|
||||||
|
|
||||||
|
const int id_task = ctx_server.queue_tasks.post(task);
|
||||||
|
ctx_server.queue_results.add_waiting_task_id(id_task);
|
||||||
|
|
||||||
|
server_task_result result = ctx_server.queue_results.recv(id_task);
|
||||||
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||||
|
|
||||||
|
if (result.error) {
|
||||||
|
res_error(res, result.data);
|
||||||
|
} else {
|
||||||
|
res.set_content(result.data.dump(), "application/json");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto handle_slots_restore = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) {
|
||||||
|
json request_data = json::parse(req.body);
|
||||||
|
std::string filename = request_data["filename"];
|
||||||
|
if (!validate_file_name(filename)) {
|
||||||
|
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
std::string filepath = sparams.slot_save_path + filename;
|
||||||
|
|
||||||
|
server_task task;
|
||||||
|
task.type = SERVER_TASK_TYPE_SLOT_RESTORE;
|
||||||
|
task.data = {
|
||||||
|
{ "id_slot", id_slot },
|
||||||
|
{ "filename", filename },
|
||||||
|
{ "filepath", filepath }
|
||||||
|
};
|
||||||
|
|
||||||
|
const int id_task = ctx_server.queue_tasks.post(task);
|
||||||
|
ctx_server.queue_results.add_waiting_task_id(id_task);
|
||||||
|
|
||||||
|
server_task_result result = ctx_server.queue_results.recv(id_task);
|
||||||
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||||
|
|
||||||
|
if (result.error) {
|
||||||
|
res_error(res, result.data);
|
||||||
|
} else {
|
||||||
|
res.set_content(result.data.dump(), "application/json");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto handle_slots_erase = [&ctx_server, &res_error](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
|
||||||
|
server_task task;
|
||||||
|
task.type = SERVER_TASK_TYPE_SLOT_ERASE;
|
||||||
|
task.data = {
|
||||||
|
{ "id_slot", id_slot },
|
||||||
|
};
|
||||||
|
|
||||||
|
const int id_task = ctx_server.queue_tasks.post(task);
|
||||||
|
ctx_server.queue_results.add_waiting_task_id(id_task);
|
||||||
|
|
||||||
|
server_task_result result = ctx_server.queue_results.recv(id_task);
|
||||||
|
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||||
|
|
||||||
|
if (result.error) {
|
||||||
|
res_error(res, result.data);
|
||||||
|
} else {
|
||||||
|
res.set_content(result.data.dump(), "application/json");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
|
||||||
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
|
|
||||||
|
std::string id_slot_str = req.path_params.at("id_slot");
|
||||||
|
int id_slot;
|
||||||
|
|
||||||
|
try {
|
||||||
|
id_slot = std::stoi(id_slot_str);
|
||||||
|
} catch (const std::exception &) {
|
||||||
|
res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string action = req.get_param_value("action");
|
||||||
|
|
||||||
|
if (action == "save") {
|
||||||
|
handle_slots_save(req, res, id_slot);
|
||||||
|
} else if (action == "restore") {
|
||||||
|
handle_slots_restore(req, res, id_slot);
|
||||||
|
} else if (action == "erase") {
|
||||||
|
handle_slots_erase(req, res, id_slot);
|
||||||
|
} else {
|
||||||
|
res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
json data = {
|
json data = {
|
||||||
@ -3521,6 +3743,10 @@ int main(int argc, char ** argv) {
|
|||||||
svr->Post("/v1/embeddings", handle_embeddings);
|
svr->Post("/v1/embeddings", handle_embeddings);
|
||||||
svr->Post("/tokenize", handle_tokenize);
|
svr->Post("/tokenize", handle_tokenize);
|
||||||
svr->Post("/detokenize", handle_detokenize);
|
svr->Post("/detokenize", handle_detokenize);
|
||||||
|
if (!sparams.slot_save_path.empty()) {
|
||||||
|
// only enable slot endpoints if slot_save_path is set
|
||||||
|
svr->Post("/slots/:id_slot", handle_slots_action);
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// Start the server
|
// Start the server
|
||||||
|
58
examples/server/tests/features/slotsave.feature
Normal file
58
examples/server/tests/features/slotsave.feature
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
@llama.cpp
|
||||||
|
@slotsave
|
||||||
|
Feature: llama.cpp server slot management
|
||||||
|
|
||||||
|
Background: Server startup
|
||||||
|
Given a server listening on localhost:8080
|
||||||
|
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
|
||||||
|
And prompt caching is enabled
|
||||||
|
And 2 slots
|
||||||
|
And . as slot save path
|
||||||
|
And 2048 KV cache size
|
||||||
|
And 42 as server seed
|
||||||
|
And 24 max tokens to predict
|
||||||
|
Then the server is starting
|
||||||
|
Then the server is healthy
|
||||||
|
|
||||||
|
Scenario: Save and Restore Slot
|
||||||
|
# First prompt in slot 1 should be fully processed
|
||||||
|
Given a user prompt "What is the capital of France?"
|
||||||
|
And using slot id 1
|
||||||
|
And a completion request with no api error
|
||||||
|
Then 24 tokens are predicted matching (Lily|cake)
|
||||||
|
And 22 prompt tokens are processed
|
||||||
|
When the slot 1 is saved with filename "slot1.bin"
|
||||||
|
Then the server responds with status code 200
|
||||||
|
# Since we have cache, this should only process the last tokens
|
||||||
|
Given a user prompt "What is the capital of Germany?"
|
||||||
|
And a completion request with no api error
|
||||||
|
Then 24 tokens are predicted matching (Thank|special)
|
||||||
|
And 7 prompt tokens are processed
|
||||||
|
# Loading the original cache into slot 0,
|
||||||
|
# we should only be processing 1 prompt token and get the same output
|
||||||
|
When the slot 0 is restored with filename "slot1.bin"
|
||||||
|
Then the server responds with status code 200
|
||||||
|
Given a user prompt "What is the capital of France?"
|
||||||
|
And using slot id 0
|
||||||
|
And a completion request with no api error
|
||||||
|
Then 24 tokens are predicted matching (Lily|cake)
|
||||||
|
And 1 prompt tokens are processed
|
||||||
|
# For verification that slot 1 was not corrupted during slot 0 load, same thing
|
||||||
|
Given a user prompt "What is the capital of Germany?"
|
||||||
|
And using slot id 1
|
||||||
|
And a completion request with no api error
|
||||||
|
Then 24 tokens are predicted matching (Thank|special)
|
||||||
|
And 1 prompt tokens are processed
|
||||||
|
|
||||||
|
Scenario: Erase Slot
|
||||||
|
Given a user prompt "What is the capital of France?"
|
||||||
|
And using slot id 1
|
||||||
|
And a completion request with no api error
|
||||||
|
Then 24 tokens are predicted matching (Lily|cake)
|
||||||
|
And 22 prompt tokens are processed
|
||||||
|
When the slot 1 is erased
|
||||||
|
Then the server responds with status code 200
|
||||||
|
Given a user prompt "What is the capital of France?"
|
||||||
|
And a completion request with no api error
|
||||||
|
Then 24 tokens are predicted matching (Lily|cake)
|
||||||
|
And 22 prompt tokens are processed
|
@ -49,6 +49,9 @@ def step_server_config(context, server_fqdn, server_port):
|
|||||||
context.n_predict = None
|
context.n_predict = None
|
||||||
context.n_prompts = 0
|
context.n_prompts = 0
|
||||||
context.n_server_predict = None
|
context.n_server_predict = None
|
||||||
|
context.slot_save_path = None
|
||||||
|
context.id_slot = None
|
||||||
|
context.cache_prompt = None
|
||||||
context.n_slots = None
|
context.n_slots = None
|
||||||
context.prompt_prefix = None
|
context.prompt_prefix = None
|
||||||
context.prompt_suffix = None
|
context.prompt_suffix = None
|
||||||
@ -119,6 +122,21 @@ def step_server_n_predict(context, n_predict):
|
|||||||
context.n_server_predict = n_predict
|
context.n_server_predict = n_predict
|
||||||
|
|
||||||
|
|
||||||
|
@step('{slot_save_path} as slot save path')
|
||||||
|
def step_slot_save_path(context, slot_save_path):
|
||||||
|
context.slot_save_path = slot_save_path
|
||||||
|
|
||||||
|
|
||||||
|
@step('using slot id {id_slot:d}')
|
||||||
|
def step_id_slot(context, id_slot):
|
||||||
|
context.id_slot = id_slot
|
||||||
|
|
||||||
|
|
||||||
|
@step('prompt caching is enabled')
|
||||||
|
def step_enable_prompt_cache(context):
|
||||||
|
context.cache_prompt = True
|
||||||
|
|
||||||
|
|
||||||
@step('continuous batching')
|
@step('continuous batching')
|
||||||
def step_server_continuous_batching(context):
|
def step_server_continuous_batching(context):
|
||||||
context.server_continuous_batching = True
|
context.server_continuous_batching = True
|
||||||
@ -212,6 +230,8 @@ async def step_request_completion(context, api_error):
|
|||||||
context.base_url,
|
context.base_url,
|
||||||
debug=context.debug,
|
debug=context.debug,
|
||||||
n_predict=context.n_predict,
|
n_predict=context.n_predict,
|
||||||
|
cache_prompt=context.cache_prompt,
|
||||||
|
id_slot=context.id_slot,
|
||||||
seed=await completions_seed(context),
|
seed=await completions_seed(context),
|
||||||
expect_api_error=expect_api_error,
|
expect_api_error=expect_api_error,
|
||||||
user_api_key=context.user_api_key)
|
user_api_key=context.user_api_key)
|
||||||
@ -711,12 +731,48 @@ async def concurrent_requests(context, f_completion, *args, **kwargs):
|
|||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
|
||||||
|
@step('the slot {slot_id:d} is saved with filename "{filename}"')
|
||||||
|
@async_run_until_complete
|
||||||
|
async def step_save_slot(context, slot_id, filename):
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(f'{context.base_url}/slots/{slot_id}?action=save',
|
||||||
|
json={"filename": filename},
|
||||||
|
headers={"Content-Type": "application/json"}) as response:
|
||||||
|
context.response = response
|
||||||
|
|
||||||
|
|
||||||
|
@step('the slot {slot_id:d} is restored with filename "{filename}"')
|
||||||
|
@async_run_until_complete
|
||||||
|
async def step_restore_slot(context, slot_id, filename):
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(f'{context.base_url}/slots/{slot_id}?action=restore',
|
||||||
|
json={"filename": filename},
|
||||||
|
headers={"Content-Type": "application/json"}) as response:
|
||||||
|
context.response = response
|
||||||
|
|
||||||
|
|
||||||
|
@step('the slot {slot_id:d} is erased')
|
||||||
|
@async_run_until_complete
|
||||||
|
async def step_erase_slot(context, slot_id):
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(f'{context.base_url}/slots/{slot_id}?action=erase',
|
||||||
|
headers={"Content-Type": "application/json"}) as response:
|
||||||
|
context.response = response
|
||||||
|
|
||||||
|
|
||||||
|
@step('the server responds with status code {status_code:d}')
|
||||||
|
def step_server_responds_with_status_code(context, status_code):
|
||||||
|
assert context.response.status == status_code
|
||||||
|
|
||||||
|
|
||||||
async def request_completion(prompt,
|
async def request_completion(prompt,
|
||||||
base_url,
|
base_url,
|
||||||
debug=False,
|
debug=False,
|
||||||
prompt_prefix=None,
|
prompt_prefix=None,
|
||||||
prompt_suffix=None,
|
prompt_suffix=None,
|
||||||
n_predict=None,
|
n_predict=None,
|
||||||
|
cache_prompt=False,
|
||||||
|
id_slot=None,
|
||||||
seed=None,
|
seed=None,
|
||||||
expect_api_error=None,
|
expect_api_error=None,
|
||||||
user_api_key=None):
|
user_api_key=None):
|
||||||
@ -738,6 +794,8 @@ async def request_completion(prompt,
|
|||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"input_suffix": prompt_suffix,
|
"input_suffix": prompt_suffix,
|
||||||
"n_predict": n_predict if n_predict is not None else -1,
|
"n_predict": n_predict if n_predict is not None else -1,
|
||||||
|
"cache_prompt": cache_prompt,
|
||||||
|
"id_slot": id_slot,
|
||||||
"seed": seed if seed is not None else 42
|
"seed": seed if seed is not None else 42
|
||||||
},
|
},
|
||||||
headers=headers,
|
headers=headers,
|
||||||
@ -1104,6 +1162,8 @@ def start_server_background(context):
|
|||||||
server_args.extend(['--parallel', context.n_slots])
|
server_args.extend(['--parallel', context.n_slots])
|
||||||
if context.n_server_predict:
|
if context.n_server_predict:
|
||||||
server_args.extend(['--n-predict', context.n_server_predict])
|
server_args.extend(['--n-predict', context.n_server_predict])
|
||||||
|
if context.slot_save_path:
|
||||||
|
server_args.extend(['--slot-save-path', context.slot_save_path])
|
||||||
if context.server_api_key:
|
if context.server_api_key:
|
||||||
server_args.extend(['--api-key', context.server_api_key])
|
server_args.extend(['--api-key', context.server_api_key])
|
||||||
if context.n_ga:
|
if context.n_ga:
|
||||||
|
463
llama.cpp
463
llama.cpp
@ -14907,9 +14907,33 @@ void llama_kv_cache_update(struct llama_context * ctx) {
|
|||||||
llama_kv_cache_update_internal(*ctx);
|
llama_kv_cache_update_internal(*ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// deprecated
|
||||||
|
size_t llama_get_state_size(const struct llama_context * ctx) {
|
||||||
|
return llama_state_get_size(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
// deprecated
|
||||||
|
size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
|
||||||
|
return llama_state_get_data(ctx, dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
// deprecated
|
||||||
|
size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
|
||||||
|
return llama_state_set_data(ctx, src);
|
||||||
|
}
|
||||||
|
|
||||||
|
// deprecated
|
||||||
|
bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
||||||
|
return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
|
||||||
|
}
|
||||||
|
|
||||||
|
// deprecated
|
||||||
|
bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
|
||||||
|
return llama_state_save_file(ctx, path_session, tokens, n_token_count);
|
||||||
|
}
|
||||||
|
|
||||||
// Returns the *maximum* size of the state
|
// Returns the *maximum* size of the state
|
||||||
size_t llama_get_state_size(const struct llama_context * ctx) {
|
size_t llama_state_get_size(const struct llama_context * ctx) {
|
||||||
const auto & cparams = ctx->cparams;
|
const auto & cparams = ctx->cparams;
|
||||||
const auto & hparams = ctx->model.hparams;
|
const auto & hparams = ctx->model.hparams;
|
||||||
|
|
||||||
@ -14997,15 +15021,15 @@ struct llama_data_file_context : llama_data_context {
|
|||||||
* file context:
|
* file context:
|
||||||
* llama_file file("/path", "wb");
|
* llama_file file("/path", "wb");
|
||||||
* llama_data_file_context data_ctx(&file);
|
* llama_data_file_context data_ctx(&file);
|
||||||
* llama_copy_state_data(ctx, &data_ctx);
|
* llama_state_get_data(ctx, &data_ctx);
|
||||||
*
|
*
|
||||||
* buffer context:
|
* buffer context:
|
||||||
* std::vector<uint8_t> buf(max_size, 0);
|
* std::vector<uint8_t> buf(max_size, 0);
|
||||||
* llama_data_buffer_context data_ctx(&buf.data());
|
* llama_data_buffer_context data_ctx(&buf.data());
|
||||||
* llama_copy_state_data(ctx, &data_ctx);
|
* llama_state_get_data(ctx, &data_ctx);
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
|
static void llama_state_get_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
|
||||||
// copy rng
|
// copy rng
|
||||||
{
|
{
|
||||||
std::ostringstream rng_ss;
|
std::ostringstream rng_ss;
|
||||||
@ -15149,15 +15173,15 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
|
size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst) {
|
||||||
llama_data_buffer_context data_ctx(dst);
|
llama_data_buffer_context data_ctx(dst);
|
||||||
llama_copy_state_data_internal(ctx, &data_ctx);
|
llama_state_get_data_internal(ctx, &data_ctx);
|
||||||
|
|
||||||
return data_ctx.get_size_written();
|
return data_ctx.get_size_written();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sets the state reading from the specified source address
|
// Sets the state reading from the specified source address
|
||||||
size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
|
size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
|
||||||
const uint8_t * inp = src;
|
const uint8_t * inp = src;
|
||||||
|
|
||||||
// set rng
|
// set rng
|
||||||
@ -15309,14 +15333,14 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const size_t nread = inp - src;
|
const size_t nread = inp - src;
|
||||||
const size_t max_size = llama_get_state_size(ctx);
|
const size_t max_size = llama_state_get_size(ctx);
|
||||||
|
|
||||||
GGML_ASSERT(nread <= max_size);
|
GGML_ASSERT(nread <= max_size);
|
||||||
|
|
||||||
return nread;
|
return nread;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
||||||
llama_file file(path_session, "rb");
|
llama_file file(path_session, "rb");
|
||||||
|
|
||||||
// sanity checks
|
// sanity checks
|
||||||
@ -15354,7 +15378,7 @@ static bool llama_load_session_file_internal(struct llama_context * ctx, const c
|
|||||||
// restore the context state
|
// restore the context state
|
||||||
{
|
{
|
||||||
const size_t n_state_size_cur = file.size - file.tell();
|
const size_t n_state_size_cur = file.size - file.tell();
|
||||||
const size_t n_state_size_max = llama_get_state_size(ctx);
|
const size_t n_state_size_max = llama_state_get_size(ctx);
|
||||||
|
|
||||||
if (n_state_size_cur > n_state_size_max) {
|
if (n_state_size_cur > n_state_size_max) {
|
||||||
LLAMA_LOG_ERROR("%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
|
LLAMA_LOG_ERROR("%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
|
||||||
@ -15364,22 +15388,22 @@ static bool llama_load_session_file_internal(struct llama_context * ctx, const c
|
|||||||
std::vector<uint8_t> state_data(n_state_size_max);
|
std::vector<uint8_t> state_data(n_state_size_max);
|
||||||
file.read_raw(state_data.data(), n_state_size_cur);
|
file.read_raw(state_data.data(), n_state_size_cur);
|
||||||
|
|
||||||
llama_set_state_data(ctx, state_data.data());
|
llama_state_set_data(ctx, state_data.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
||||||
try {
|
try {
|
||||||
return llama_load_session_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
|
return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
|
||||||
} catch (const std::exception & err) {
|
} catch (const std::exception & err) {
|
||||||
LLAMA_LOG_ERROR("error loading session file: %s\n", err.what());
|
LLAMA_LOG_ERROR("error loading session file: %s\n", err.what());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
|
static bool llama_state_save_file_internal(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
|
||||||
llama_file file(path_session, "wb");
|
llama_file file(path_session, "wb");
|
||||||
|
|
||||||
file.write_u32(LLAMA_SESSION_MAGIC);
|
file.write_u32(LLAMA_SESSION_MAGIC);
|
||||||
@ -15393,11 +15417,420 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
|
|||||||
|
|
||||||
// save the context state using stream saving
|
// save the context state using stream saving
|
||||||
llama_data_file_context data_ctx(&file);
|
llama_data_file_context data_ctx(&file);
|
||||||
llama_copy_state_data_internal(ctx, &data_ctx);
|
llama_state_get_data_internal(ctx, &data_ctx);
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
|
||||||
|
try {
|
||||||
|
return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count);
|
||||||
|
} catch (const std::exception & err) {
|
||||||
|
LLAMA_LOG_ERROR("error saving session file: %s\n", err.what());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) {
|
||||||
|
// save the size of size_t as a uint32_t for safety check
|
||||||
|
const size_t size_t_size_size = sizeof(uint32_t);
|
||||||
|
|
||||||
|
// other values
|
||||||
|
const size_t s_cell_count_size = sizeof(uint32_t);
|
||||||
|
const size_t s_layer_count_size = sizeof(uint32_t);
|
||||||
|
const size_t n_embd_v_gqa_size = sizeof(uint32_t);
|
||||||
|
|
||||||
|
size_t s_cell_count = 0;
|
||||||
|
size_t s_cell_data_size = 0;
|
||||||
|
const auto & kv_self = ctx->kv_self;
|
||||||
|
const auto & hparams = ctx->model.hparams;
|
||||||
|
|
||||||
|
const uint32_t n_layer = hparams.n_layer;
|
||||||
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
|
||||||
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < kv_self.size; ++i) {
|
||||||
|
const auto & cell = kv_self.cells[i];
|
||||||
|
if (cell.seq_id.count(seq_id) > 0) {
|
||||||
|
++s_cell_count;
|
||||||
|
s_cell_data_size += sizeof(llama_pos);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int il = 0; il < (int)n_layer; ++il) {
|
||||||
|
// types of keys and values
|
||||||
|
s_cell_data_size += sizeof(int32_t) * 2;
|
||||||
|
// k_size_row and v_size_el values of layer
|
||||||
|
s_cell_data_size += sizeof(size_t) * 2;
|
||||||
|
|
||||||
|
// keys
|
||||||
|
const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
|
||||||
|
s_cell_data_size += k_size_row * s_cell_count;
|
||||||
|
|
||||||
|
// values (transposed)
|
||||||
|
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
|
||||||
|
s_cell_data_size += v_size_el * s_cell_count * n_embd_v_gqa;
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t s_total = (
|
||||||
|
size_t_size_size +
|
||||||
|
s_cell_count_size +
|
||||||
|
s_layer_count_size +
|
||||||
|
n_embd_v_gqa_size +
|
||||||
|
s_cell_data_size
|
||||||
|
);
|
||||||
|
|
||||||
|
return s_total;
|
||||||
|
}
|
||||||
|
|
||||||
|
static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_context & data_ctx, llama_seq_id seq_id) {
|
||||||
|
const auto & kv_self = ctx->kv_self;
|
||||||
|
GGML_ASSERT(!kv_self.recurrent); // not implemented
|
||||||
|
|
||||||
|
// Save the size of size_t as a uint32_t for safety check
|
||||||
|
const uint32_t size_t_size = sizeof(size_t);
|
||||||
|
data_ctx.write(&size_t_size, sizeof(size_t_size));
|
||||||
|
|
||||||
|
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
|
||||||
|
uint32_t cell_count = 0;
|
||||||
|
|
||||||
|
// Count the number of cells with the specified seq_id
|
||||||
|
// Find all the ranges of cells with this seq id
|
||||||
|
{
|
||||||
|
uint32_t cell_range_begin = kv_self.size;
|
||||||
|
for (uint32_t i = 0; i < kv_self.size; ++i) {
|
||||||
|
const auto & cell = kv_self.cells[i];
|
||||||
|
if (cell.has_seq_id(seq_id)) {
|
||||||
|
++cell_count;
|
||||||
|
if (cell_range_begin == kv_self.size) {
|
||||||
|
cell_range_begin = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
if (cell_range_begin != kv_self.size) {
|
||||||
|
cell_ranges.push_back({ cell_range_begin, i });
|
||||||
|
cell_range_begin = kv_self.size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (cell_range_begin != kv_self.size) {
|
||||||
|
cell_ranges.push_back({ cell_range_begin, kv_self.size });
|
||||||
|
}
|
||||||
|
|
||||||
|
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
|
||||||
|
uint32_t cell_count_check = 0;
|
||||||
|
for (const auto & range : cell_ranges) {
|
||||||
|
cell_count_check += range.second - range.first;
|
||||||
|
}
|
||||||
|
GGML_ASSERT(cell_count == cell_count_check);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the cell count
|
||||||
|
data_ctx.write(&cell_count, sizeof(cell_count));
|
||||||
|
|
||||||
|
const auto & hparams = ctx->model.hparams;
|
||||||
|
const uint32_t n_layer = hparams.n_layer;
|
||||||
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
|
||||||
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
|
||||||
|
|
||||||
|
// Write the layer count
|
||||||
|
data_ctx.write(&n_layer, sizeof(n_layer));
|
||||||
|
|
||||||
|
// Write n_embd_v_gqa
|
||||||
|
data_ctx.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
|
||||||
|
|
||||||
|
// Iterate the ranges and write all the pos (this is the token position in the prompt)
|
||||||
|
for (const auto & range : cell_ranges) {
|
||||||
|
for (uint32_t i = range.first; i < range.second; ++i) {
|
||||||
|
const auto & cell = kv_self.cells[i];
|
||||||
|
data_ctx.write(&cell.pos, sizeof(cell.pos));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterate and write all the keys first, each row is a cell
|
||||||
|
// Get whole range at a time
|
||||||
|
std::vector<uint8_t> tmp_buf;
|
||||||
|
for (int il = 0; il < (int)n_layer; ++il) {
|
||||||
|
// Write key type
|
||||||
|
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
|
||||||
|
data_ctx.write(&k_type_i, sizeof(k_type_i));
|
||||||
|
|
||||||
|
// Write row size of key
|
||||||
|
const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
|
||||||
|
data_ctx.write(&k_size_row, sizeof(k_size_row));
|
||||||
|
|
||||||
|
// Read each range of cells of k_size length each into tmp_buf and write out
|
||||||
|
for (const auto & range : cell_ranges) {
|
||||||
|
const size_t range_size = range.second - range.first;
|
||||||
|
tmp_buf.resize(range_size * k_size_row);
|
||||||
|
ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), range.first * k_size_row, range_size * k_size_row);
|
||||||
|
data_ctx.write(tmp_buf.data(), tmp_buf.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// For the values, they are transposed, so we also need the element size and get the element ranges from each row
|
||||||
|
const uint32_t kv_size = kv_self.size;
|
||||||
|
for (int il = 0; il < (int)n_layer; ++il) {
|
||||||
|
// Write value type
|
||||||
|
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
||||||
|
data_ctx.write(&v_type_i, sizeof(v_type_i));
|
||||||
|
|
||||||
|
// Write element size
|
||||||
|
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
|
||||||
|
data_ctx.write(&v_size_el, sizeof(v_size_el));
|
||||||
|
|
||||||
|
// For each row, we get the element values of each cell
|
||||||
|
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||||
|
// Read each range of cells of v_size_el length each into tmp_buf and write out
|
||||||
|
for (const auto & range : cell_ranges) {
|
||||||
|
const size_t range_size = range.second - range.first;
|
||||||
|
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
|
||||||
|
tmp_buf.resize(range_size * v_size_el);
|
||||||
|
ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
|
||||||
|
data_ctx.write(tmp_buf.data(), tmp_buf.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return data_ctx.get_size_written();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t llama_state_seq_get_data(struct llama_context* ctx, uint8_t* dst, llama_seq_id seq_id) {
|
||||||
|
llama_data_buffer_context data_ctx(dst);
|
||||||
|
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) {
|
||||||
|
auto & kv_self = ctx->kv_self;
|
||||||
|
GGML_ASSERT(!kv_self.recurrent); // not implemented
|
||||||
|
|
||||||
|
// Wipe the slot
|
||||||
|
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
|
||||||
|
|
||||||
|
const uint8_t * inp = src;
|
||||||
|
|
||||||
|
// Read size of size_t
|
||||||
|
uint32_t size_t_size;
|
||||||
|
memcpy(&size_t_size, inp, sizeof(size_t_size));
|
||||||
|
inp += sizeof(size_t_size);
|
||||||
|
if (size_t_size != sizeof(size_t)) {
|
||||||
|
LLAMA_LOG_ERROR("%s: size_t size mismatch\n", __func__);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the cell count
|
||||||
|
uint32_t cell_count;
|
||||||
|
memcpy(&cell_count, inp, sizeof(cell_count));
|
||||||
|
inp += sizeof(cell_count);
|
||||||
|
|
||||||
|
// Read the layer count
|
||||||
|
uint32_t n_layer_ref;
|
||||||
|
memcpy(&n_layer_ref, inp, sizeof(n_layer_ref));
|
||||||
|
inp += sizeof(n_layer_ref);
|
||||||
|
|
||||||
|
// Read n_embd_v_gqa
|
||||||
|
uint32_t n_embd_v_gqa_ref;
|
||||||
|
memcpy(&n_embd_v_gqa_ref, inp, sizeof(n_embd_v_gqa_ref));
|
||||||
|
inp += sizeof(n_embd_v_gqa_ref);
|
||||||
|
|
||||||
|
// Sanity check model compatibility
|
||||||
|
const auto & hparams = ctx->model.hparams;
|
||||||
|
const uint32_t n_layer = hparams.n_layer;
|
||||||
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
|
||||||
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
|
||||||
|
if (n_layer != n_layer_ref) {
|
||||||
|
LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
|
||||||
|
LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate the new cells for the slot
|
||||||
|
if (cell_count) {
|
||||||
|
llama_batch batch = llama_batch_init(cell_count, 0, 1);
|
||||||
|
batch.n_tokens = cell_count;
|
||||||
|
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||||
|
llama_pos pos;
|
||||||
|
memcpy(&pos, inp, sizeof(pos));
|
||||||
|
inp += sizeof(pos);
|
||||||
|
|
||||||
|
batch.pos[i] = pos;
|
||||||
|
batch.n_seq_id[i] = 1;
|
||||||
|
batch.seq_id[i][0] = dest_seq_id;
|
||||||
|
}
|
||||||
|
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
||||||
|
llama_batch_free(batch);
|
||||||
|
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
||||||
|
// Assume that this is one contiguous block of cells
|
||||||
|
GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
|
||||||
|
GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
|
||||||
|
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
|
||||||
|
GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
|
||||||
|
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
llama_batch_free(batch);
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t kv_size = kv_self.size;
|
||||||
|
const uint32_t kv_head = kv_self.head;
|
||||||
|
|
||||||
|
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous blo
|
||||||
|
for (int il = 0; il < (int)n_layer; ++il) {
|
||||||
|
// Read type of key
|
||||||
|
int32_t k_type_i_ref;
|
||||||
|
memcpy(&k_type_i_ref, inp, sizeof(k_type_i_ref));
|
||||||
|
inp += sizeof(k_type_i_ref);
|
||||||
|
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
|
||||||
|
if (k_type_i != k_type_i_ref) {
|
||||||
|
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
|
||||||
|
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read row size of key
|
||||||
|
size_t k_size_row_ref;
|
||||||
|
memcpy(&k_size_row_ref, inp, sizeof(k_size_row_ref));
|
||||||
|
inp += sizeof(k_size_row_ref);
|
||||||
|
const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
|
||||||
|
if (k_size_row != k_size_row_ref) {
|
||||||
|
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
|
||||||
|
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, k_size_row_ref, il);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cell_count) {
|
||||||
|
// Read and set the keys for the whole cell range
|
||||||
|
ggml_backend_tensor_set(kv_self.k_l[il], inp, kv_head * k_size_row, cell_count * k_size_row);
|
||||||
|
inp += cell_count * k_size_row;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// For each layer, read the values for each cell (transposed)
|
||||||
|
for (int il = 0; il < (int)n_layer; ++il) {
|
||||||
|
// Read type of value
|
||||||
|
int32_t v_type_i_ref;
|
||||||
|
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
|
||||||
|
inp += sizeof(v_type_i_ref);
|
||||||
|
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
||||||
|
if (v_type_i != v_type_i_ref) {
|
||||||
|
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
|
||||||
|
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read element size of value
|
||||||
|
size_t v_size_el_ref;
|
||||||
|
memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref));
|
||||||
|
inp += sizeof(v_size_el_ref);
|
||||||
|
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
|
||||||
|
if (v_size_el != v_size_el_ref) {
|
||||||
|
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
|
||||||
|
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cell_count) {
|
||||||
|
// For each row in the transposed matrix, read the values for the whole cell range
|
||||||
|
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||||
|
const size_t dst_offset = (kv_head + j * kv_size) * v_size_el;
|
||||||
|
ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el);
|
||||||
|
inp += cell_count * v_size_el;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t nread = inp - src;
|
||||||
|
return nread;
|
||||||
|
}
|
||||||
|
|
||||||
|
static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
|
||||||
|
llama_file file(filepath, "wb");
|
||||||
|
|
||||||
|
file.write_u32(LLAMA_STATE_SEQ_MAGIC);
|
||||||
|
file.write_u32(LLAMA_STATE_SEQ_VERSION);
|
||||||
|
|
||||||
|
// save the prompt
|
||||||
|
file.write_u32((uint32_t)n_token_count);
|
||||||
|
file.write_raw(tokens, sizeof(llama_token) * n_token_count);
|
||||||
|
|
||||||
|
// save the context state using stream saving
|
||||||
|
llama_data_file_context data_ctx(&file);
|
||||||
|
llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
|
||||||
|
|
||||||
|
const size_t res = file.tell();
|
||||||
|
GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written());
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
||||||
|
llama_file file(filepath, "rb");
|
||||||
|
|
||||||
|
// version checks
|
||||||
|
{
|
||||||
|
const uint32_t magic = file.read_u32();
|
||||||
|
const uint32_t version = file.read_u32();
|
||||||
|
|
||||||
|
if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
|
||||||
|
LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// load the prompt
|
||||||
|
{
|
||||||
|
const uint32_t n_token_count = file.read_u32();
|
||||||
|
|
||||||
|
if (n_token_count > n_token_capacity) {
|
||||||
|
LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
|
||||||
|
*n_token_count_out = n_token_count;
|
||||||
|
}
|
||||||
|
|
||||||
|
// restore the context state
|
||||||
|
{
|
||||||
|
const size_t state_size = file.size - file.tell();
|
||||||
|
std::vector<uint8_t> state_data(state_size);
|
||||||
|
file.read_raw(state_data.data(), state_size);
|
||||||
|
const size_t nread = llama_state_seq_set_data(ctx, state_data.data(), dest_seq_id);
|
||||||
|
if (!nread) {
|
||||||
|
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
GGML_ASSERT(nread <= state_size);
|
||||||
|
GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
|
||||||
|
}
|
||||||
|
|
||||||
|
return file.tell();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
|
||||||
|
try {
|
||||||
|
return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count);
|
||||||
|
} catch (const std::exception & err) {
|
||||||
|
LLAMA_LOG_ERROR("error saving sequence state file: %s\n", err.what());
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
||||||
|
try {
|
||||||
|
return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out);
|
||||||
|
} catch (const std::exception & err) {
|
||||||
|
LLAMA_LOG_ERROR("error loading sequence state file: %s\n", err.what());
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) {
|
void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) {
|
||||||
ctx->cparams.n_threads = n_threads;
|
ctx->cparams.n_threads = n_threads;
|
||||||
ctx->cparams.n_threads_batch = n_threads_batch;
|
ctx->cparams.n_threads_batch = n_threads_batch;
|
||||||
|
73
llama.h
73
llama.h
@ -37,10 +37,14 @@
|
|||||||
|
|
||||||
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
|
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
|
||||||
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
||||||
|
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
|
||||||
|
|
||||||
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||||
#define LLAMA_SESSION_VERSION 5
|
#define LLAMA_SESSION_VERSION 5
|
||||||
|
|
||||||
|
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
|
||||||
|
#define LLAMA_STATE_SEQ_VERSION 1
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
@ -523,6 +527,7 @@ extern "C" {
|
|||||||
struct llama_context * ctx);
|
struct llama_context * ctx);
|
||||||
|
|
||||||
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||||
|
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
|
||||||
// seq_id < 0 : match any sequence
|
// seq_id < 0 : match any sequence
|
||||||
// p0 < 0 : [0, p1]
|
// p0 < 0 : [0, p1]
|
||||||
// p1 < 0 : [p0, inf)
|
// p1 < 0 : [p0, inf)
|
||||||
@ -594,34 +599,92 @@ extern "C" {
|
|||||||
|
|
||||||
// Returns the maximum size in bytes of the state (rng, logits, embedding
|
// Returns the maximum size in bytes of the state (rng, logits, embedding
|
||||||
// and kv_cache) - will often be smaller after compacting tokens
|
// and kv_cache) - will often be smaller after compacting tokens
|
||||||
LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
|
LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx);
|
||||||
|
LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx),
|
||||||
|
"use llama_state_get_size instead");
|
||||||
|
|
||||||
// Copies the state to the specified destination address.
|
// Copies the state to the specified destination address.
|
||||||
// Destination needs to have allocated enough memory.
|
// Destination needs to have allocated enough memory.
|
||||||
// Returns the number of bytes copied
|
// Returns the number of bytes copied
|
||||||
LLAMA_API size_t llama_copy_state_data(
|
LLAMA_API size_t llama_state_get_data(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
uint8_t * dst);
|
uint8_t * dst);
|
||||||
|
LLAMA_API DEPRECATED(size_t llama_copy_state_data(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
uint8_t * dst),
|
||||||
|
"use llama_state_get_data instead");
|
||||||
|
|
||||||
// Set the state reading from the specified address
|
// Set the state reading from the specified address
|
||||||
// Returns the number of bytes read
|
// Returns the number of bytes read
|
||||||
LLAMA_API size_t llama_set_state_data(
|
LLAMA_API size_t llama_state_set_data(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
const uint8_t * src);
|
const uint8_t * src);
|
||||||
|
LLAMA_API DEPRECATED(size_t llama_set_state_data(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
const uint8_t * src),
|
||||||
|
"use llama_state_set_data instead");
|
||||||
|
|
||||||
// Save/load session file
|
// Save/load session file
|
||||||
LLAMA_API bool llama_load_session_file(
|
LLAMA_API bool llama_state_load_file(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
const char * path_session,
|
const char * path_session,
|
||||||
llama_token * tokens_out,
|
llama_token * tokens_out,
|
||||||
size_t n_token_capacity,
|
size_t n_token_capacity,
|
||||||
size_t * n_token_count_out);
|
size_t * n_token_count_out);
|
||||||
|
LLAMA_API DEPRECATED(bool llama_load_session_file(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
const char * path_session,
|
||||||
|
llama_token * tokens_out,
|
||||||
|
size_t n_token_capacity,
|
||||||
|
size_t * n_token_count_out),
|
||||||
|
"use llama_state_load_file instead");
|
||||||
|
|
||||||
LLAMA_API bool llama_save_session_file(
|
LLAMA_API bool llama_state_save_file(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
const char * path_session,
|
const char * path_session,
|
||||||
const llama_token * tokens,
|
const llama_token * tokens,
|
||||||
size_t n_token_count);
|
size_t n_token_count);
|
||||||
|
LLAMA_API DEPRECATED(bool llama_save_session_file(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
const char * path_session,
|
||||||
|
const llama_token * tokens,
|
||||||
|
size_t n_token_count),
|
||||||
|
"use llama_state_save_file instead");
|
||||||
|
|
||||||
|
// Get the exact size needed to copy the KV cache of a single sequence
|
||||||
|
LLAMA_API size_t llama_state_seq_get_size(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
|
// Copy the KV cache of a single sequence into the specified buffer
|
||||||
|
LLAMA_API size_t llama_state_seq_get_data(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
uint8_t * dst,
|
||||||
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
|
// Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
|
||||||
|
// Returns:
|
||||||
|
// - Positive: Ok
|
||||||
|
// - Zero: Failed to load
|
||||||
|
LLAMA_API size_t llama_state_seq_set_data(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
const uint8_t * src,
|
||||||
|
llama_seq_id dest_seq_id);
|
||||||
|
|
||||||
|
LLAMA_API size_t llama_state_seq_save_file(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
const char * filepath,
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
const llama_token * tokens,
|
||||||
|
size_t n_token_count);
|
||||||
|
|
||||||
|
LLAMA_API size_t llama_state_seq_load_file(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
const char * filepath,
|
||||||
|
llama_seq_id dest_seq_id,
|
||||||
|
llama_token * tokens_out,
|
||||||
|
size_t n_token_capacity,
|
||||||
|
size_t * n_token_count_out);
|
||||||
|
|
||||||
//
|
//
|
||||||
// Decoding
|
// Decoding
|
||||||
|
Loading…
Reference in New Issue
Block a user