mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-26 12:21:40 +01:00
llama : fix session saving/loading (#3400)
* llama : fix session saving/loading * llama : temp fix for clearing "future" tokens from the KV cache * llama : fix handling of "future" tokens when loading sessions * llama : fix comments for llama_kv_cache API
This commit is contained in:
parent
48be797ffb
commit
ac2219fef3
@ -9,7 +9,7 @@ if [[ -z "${PROMPT_CACHE_FILE+x}" || -z "${CHAT_SAVE_DIR+x}" ]]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
MODEL="${MODEL:-./models/13B/ggml-model-q4_0.bin}"
|
||||
MODEL="${MODEL:-./models/llama-13b/ggml-model-q4_0.gguf}"
|
||||
PROMPT_TEMPLATE="${PROMPT_TEMPLATE:-./prompts/chat.txt}"
|
||||
USER_NAME="${USER_NAME:-User}"
|
||||
AI_NAME="${AI_NAME:-ChatLLaMa}"
|
||||
@ -61,9 +61,9 @@ fi
|
||||
|
||||
if [[ ! -e "$PROMPT_CACHE_FILE" ]]; then
|
||||
echo 'Prompt cache does not exist, building...'
|
||||
# Default batch_size to 8 here for better user feedback during initial prompt processing
|
||||
# Default batch_size to 64 here for better user feedback during initial prompt processing
|
||||
./main 2>>"$LOG" \
|
||||
--batch_size 8 \
|
||||
--batch_size 64 \
|
||||
"${OPTS[@]}" \
|
||||
--prompt-cache "$PROMPT_CACHE_FILE" \
|
||||
--file "$CUR_PROMPT_FILE" \
|
||||
@ -132,7 +132,7 @@ while read -e line; do
|
||||
# HACK get num tokens from debug message
|
||||
# TODO get both messages in one go
|
||||
if ! session_size_msg="$(tail -n30 "$LOG" | grep -oE "$SESSION_SIZE_MSG_PATTERN")" ||
|
||||
! sample_time_msg="$( tail -n10 "$LOG" | grep -oE "$SAMPLE_TIME_MSG_PATTERN")"; then
|
||||
! sample_time_msg="$(tail -n10 "$LOG" | grep -oE "$SAMPLE_TIME_MSG_PATTERN")"; then
|
||||
echo >&2 "Couldn't get number of tokens from ./main output!"
|
||||
exit 1
|
||||
fi
|
||||
|
@ -543,6 +543,9 @@ int main(int argc, char ** argv) {
|
||||
if (i > 0) {
|
||||
embd.erase(embd.begin(), embd.begin() + i);
|
||||
}
|
||||
|
||||
// remove any "future" tokens that we might have inherited from the session from the KV cache
|
||||
llama_kv_cache_tokens_rm(ctx, n_past, -1);
|
||||
}
|
||||
|
||||
// evaluate tokens in batches
|
||||
|
@ -332,7 +332,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
|
||||
llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, n_ctx);
|
||||
llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1);
|
||||
|
||||
const auto t_main_end = ggml_time_us();
|
||||
|
||||
|
@ -448,7 +448,7 @@ struct llama_server_context
|
||||
n_past = common_part(embd, prompt_tokens);
|
||||
|
||||
// since #3228 we now have to manually manage the KV cache
|
||||
llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx);
|
||||
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
|
||||
|
||||
embd = prompt_tokens;
|
||||
if (n_past == num_prompt_tokens)
|
||||
|
@ -172,7 +172,7 @@ int main(int argc, char ** argv) {
|
||||
LOG("out of drafted tokens\n");
|
||||
}
|
||||
|
||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx);
|
||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
||||
llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0));
|
||||
++n_past_dft;
|
||||
|
||||
@ -257,7 +257,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// evaluate the drafted token on the draft model
|
||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx);
|
||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, -1);
|
||||
llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0));
|
||||
++n_past_cur;
|
||||
|
||||
@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// evaluate the target model on the drafted tokens
|
||||
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx);
|
||||
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, -1);
|
||||
llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0));
|
||||
++n_past_tgt;
|
||||
|
||||
|
134
llama.cpp
134
llama.cpp
@ -1283,8 +1283,8 @@ static bool llama_kv_cache_init(
|
||||
// find an empty slot of size "n_tokens" in the cache
|
||||
// updates the cache head
|
||||
static bool llama_kv_cache_find_slot(
|
||||
struct llama_kv_cache & cache,
|
||||
const struct llama_batch & batch) {
|
||||
struct llama_kv_cache & cache,
|
||||
const struct llama_batch & batch) {
|
||||
const uint32_t n_ctx = cache.size;
|
||||
const uint32_t n_tokens = batch.n_tokens;
|
||||
|
||||
@ -1352,10 +1352,13 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0,
|
||||
}
|
||||
|
||||
static void llama_kv_cache_seq_rm(
|
||||
struct llama_kv_cache & cache,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
struct llama_kv_cache & cache,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
if (p0 < 0) p0 = 0;
|
||||
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
||||
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
|
||||
cache.cells[i].seq_id.erase(seq_id);
|
||||
@ -1367,11 +1370,14 @@ static void llama_kv_cache_seq_rm(
|
||||
}
|
||||
|
||||
static void llama_kv_cache_seq_cp(
|
||||
struct llama_kv_cache & cache,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
struct llama_kv_cache & cache,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
if (p0 < 0) p0 = 0;
|
||||
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
||||
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
|
||||
cache.cells[i].seq_id.insert(seq_id_dst);
|
||||
@ -1389,11 +1395,14 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
|
||||
}
|
||||
|
||||
static void llama_kv_cache_seq_shift(
|
||||
struct llama_kv_cache & cache,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta) {
|
||||
struct llama_kv_cache & cache,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta) {
|
||||
if (p0 < 0) p0 = 0;
|
||||
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
||||
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
|
||||
cache.cells[i].pos += delta;
|
||||
@ -7209,16 +7218,6 @@ struct llama_data_file_context : llama_data_context {
|
||||
*
|
||||
*/
|
||||
static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
|
||||
// TODO: does not support multi-sequence states
|
||||
{
|
||||
const auto & kv_self = ctx->kv_self;
|
||||
for (uint32_t i = 0; i < kv_self.head; ++i) {
|
||||
GGML_ASSERT(kv_self.cells[i].pos == (int32_t) i);
|
||||
GGML_ASSERT(kv_self.cells[i].seq_id.size() == 1);
|
||||
GGML_ASSERT(kv_self.cells[i].has_seq_id(0));
|
||||
}
|
||||
}
|
||||
|
||||
// copy rng
|
||||
{
|
||||
std::stringstream rng_ss;
|
||||
@ -7271,36 +7270,38 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
|
||||
const auto & hparams = ctx->model.hparams;
|
||||
const auto & cparams = ctx->cparams;
|
||||
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_embd = hparams.n_embd_gqa();
|
||||
const int n_ctx = cparams.n_ctx;
|
||||
const auto n_layer = hparams.n_layer;
|
||||
const auto n_embd = hparams.n_embd_gqa();
|
||||
const auto n_ctx = cparams.n_ctx;
|
||||
|
||||
const size_t kv_size = kv_self.buf.size;
|
||||
const int kv_ntok = kv_self.head;
|
||||
const size_t kv_buf_size = kv_self.buf.size;
|
||||
const uint32_t kv_head = kv_self.head;
|
||||
const uint32_t kv_size = kv_self.size;
|
||||
|
||||
data_ctx->write(&kv_size, sizeof(kv_size));
|
||||
data_ctx->write(&kv_ntok, sizeof(kv_ntok));
|
||||
data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
|
||||
data_ctx->write(&kv_head, sizeof(kv_head));
|
||||
data_ctx->write(&kv_size, sizeof(kv_size));
|
||||
|
||||
if (kv_size) {
|
||||
if (kv_buf_size) {
|
||||
const size_t elt_size = ggml_element_size(kv_self.k);
|
||||
|
||||
ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
|
||||
ggml_cgraph gf{};
|
||||
|
||||
ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
|
||||
ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
|
||||
std::vector<uint8_t> kout3d_data(ggml_nbytes(kout3d), 0);
|
||||
kout3d->data = kout3d_data.data();
|
||||
|
||||
ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
|
||||
ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
|
||||
std::vector<uint8_t> vout3d_data(ggml_nbytes(vout3d), 0);
|
||||
vout3d->data = vout3d_data.data();
|
||||
|
||||
ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
|
||||
n_embd, kv_ntok, n_layer,
|
||||
n_embd, kv_head, n_layer,
|
||||
elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
|
||||
|
||||
ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
|
||||
kv_ntok, n_embd, n_layer,
|
||||
kv_head, n_embd, n_layer,
|
||||
elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
|
||||
|
||||
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
|
||||
@ -7314,6 +7315,20 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
|
||||
data_ctx->write(kout3d_data.data(), kout3d_data.size());
|
||||
data_ctx->write(vout3d_data.data(), vout3d_data.size());
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < kv_size; ++i) {
|
||||
const auto & cell = kv_self.cells[i];
|
||||
|
||||
const llama_pos pos = cell.pos;
|
||||
const size_t seq_id_size = cell.seq_id.size();
|
||||
|
||||
data_ctx->write(&pos, sizeof(pos));
|
||||
data_ctx->write(&seq_id_size, sizeof(seq_id_size));
|
||||
|
||||
for (auto seq_id : cell.seq_id) {
|
||||
data_ctx->write(&seq_id, sizeof(seq_id));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -7385,34 +7400,36 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
|
||||
const int n_embd = hparams.n_embd_gqa();
|
||||
const int n_ctx = cparams.n_ctx;
|
||||
|
||||
size_t kv_size;
|
||||
int kv_ntok;
|
||||
size_t kv_buf_size;
|
||||
uint32_t kv_head;
|
||||
uint32_t kv_size;
|
||||
|
||||
memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
|
||||
memcpy(&kv_ntok, inp, sizeof(kv_ntok)); inp += sizeof(kv_ntok);
|
||||
memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
|
||||
memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head);
|
||||
memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
|
||||
|
||||
if (kv_size) {
|
||||
GGML_ASSERT(kv_self.buf.size == kv_size);
|
||||
if (kv_buf_size) {
|
||||
GGML_ASSERT(kv_self.buf.size == kv_buf_size);
|
||||
|
||||
const size_t elt_size = ggml_element_size(kv_self.k);
|
||||
|
||||
ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
|
||||
ggml_cgraph gf{};
|
||||
|
||||
ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
|
||||
ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
|
||||
kin3d->data = (void *) inp;
|
||||
inp += ggml_nbytes(kin3d);
|
||||
|
||||
ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
|
||||
ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
|
||||
vin3d->data = (void *) inp;
|
||||
inp += ggml_nbytes(vin3d);
|
||||
|
||||
ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
|
||||
n_embd, kv_ntok, n_layer,
|
||||
n_embd, kv_head, n_layer,
|
||||
elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
|
||||
|
||||
ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
|
||||
kv_ntok, n_embd, n_layer,
|
||||
kv_head, n_embd, n_layer,
|
||||
elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
|
||||
|
||||
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
|
||||
@ -7422,8 +7439,27 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
|
||||
ggml_free(cpy_ctx);
|
||||
}
|
||||
|
||||
ctx->kv_self.head = kv_ntok;
|
||||
ctx->kv_self.head = kv_head;
|
||||
ctx->kv_self.size = kv_size;
|
||||
|
||||
ctx->kv_self.cells.resize(kv_size);
|
||||
|
||||
for (uint32_t i = 0; i < kv_size; ++i) {
|
||||
llama_pos pos;
|
||||
size_t seq_id_size;
|
||||
|
||||
memcpy(&pos, inp, sizeof(pos)); inp += sizeof(pos);
|
||||
memcpy(&seq_id_size, inp, sizeof(seq_id_size)); inp += sizeof(seq_id_size);
|
||||
|
||||
ctx->kv_self.cells[i].pos = pos;
|
||||
|
||||
llama_seq_id seq_id;
|
||||
|
||||
for (size_t j = 0; j < seq_id_size; ++j) {
|
||||
memcpy(&seq_id, inp, sizeof(seq_id)); inp += sizeof(seq_id);
|
||||
ctx->kv_self.cells[i].seq_id.insert(seq_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const size_t nread = inp - src;
|
||||
|
10
llama.h
10
llama.h
@ -42,7 +42,7 @@
|
||||
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
||||
|
||||
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||
#define LLAMA_SESSION_VERSION 1
|
||||
#define LLAMA_SESSION_VERSION 2
|
||||
|
||||
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
|
||||
// Defined when llama.cpp is compiled with support for offloading model layers to GPU.
|
||||
@ -333,12 +333,16 @@ extern "C" {
|
||||
"avoid using this, it will be removed in the future, instead - count the tokens in user code");
|
||||
|
||||
// Remove all tokens data of cells in [c0, c1)
|
||||
// c0 < 0 : [0, c1]
|
||||
// c1 < 0 : [c0, inf)
|
||||
LLAMA_API void llama_kv_cache_tokens_rm(
|
||||
struct llama_context * ctx,
|
||||
int32_t c0,
|
||||
int32_t c1);
|
||||
|
||||
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_kv_cache_seq_rm(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
@ -347,6 +351,8 @@ extern "C" {
|
||||
|
||||
// Copy all tokens that belong to the specified sequence to another sequence
|
||||
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_kv_cache_seq_cp(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
@ -361,6 +367,8 @@ extern "C" {
|
||||
|
||||
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
// If the KV cache is RoPEd, the KV data is updated accordingly
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_kv_cache_seq_shift(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
|
Loading…
Reference in New Issue
Block a user