context : minor

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-01-15 10:54:21 +02:00
parent e7f2dc8bc4
commit 40b521b656
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
5 changed files with 37 additions and 47 deletions

View File

@ -8,30 +8,6 @@
#include <cstring> #include <cstring>
#include <stdexcept> #include <stdexcept>
// Fill the inp_K_shift input tensor with the per-cell position delta of every
// KV-cache cell. Downstream this tensor is consumed by the RoPE-shift graph
// (see the ggml_rope calls later in this commit that take lctx.inp_K_shift).
// NOTE(review): this free function was removed by this commit in favor of the
// llama_context::set_k_shift() member function.
void llama_set_k_shift(struct llama_context & lctx) {
// total number of cells in the self-attention KV cache
const int64_t kv_size = lctx.kv_self.size;
// the raw pointer write below is only valid for host-resident buffers
assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
int32_t * data = (int32_t *) lctx.inp_K_shift->data;
// copy each cell's accumulated shift delta into the input tensor
// (note: loop index is int while kv_size is int64_t — the replacement
// member function switches to uint32_t to match kv.size)
for (int i = 0; i < kv_size; ++i) {
data[i] = lctx.kv_self.cells[i].delta;
}
}
// Fill the inp_s_copy input tensor with the source-cell index of every
// KV-cache cell (cells[i].src). Presumably used to route recurrent-state
// copies between cells — TODO confirm against the graph code that reads
// inp_s_copy, which is outside this diff.
// NOTE(review): removed by this commit with no member-function replacement
// visible in the shown hunks.
void llama_set_s_copy(struct llama_context & lctx) {
// total number of cells in the self-attention KV cache
const int64_t kv_size = lctx.kv_self.size;
// direct pointer access below requires a host-accessible buffer
assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
int32_t * data = (int32_t *) lctx.inp_s_copy->data;
// copy each cell's source index into the input tensor
for (int i = 0; i < kv_size; ++i) {
data[i] = lctx.kv_self.cells[i].src;
}
}
// llama input // llama input
static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
@ -58,6 +34,16 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
return relative_bucket; return relative_bucket;
} }
// Member-function successor of the former free function llama_set_k_shift():
// fills the inp_K_shift input tensor with each KV-cache cell's position delta,
// taking the cache as an explicit parameter instead of reaching through lctx.
// The tensor feeds the RoPE-shift graph (passed alongside rope_factors below).
void llama_context::set_k_shift(llama_kv_cache & kv) {
// the raw pointer write below is only valid for host-resident buffers
assert(ggml_backend_buffer_is_host(inp_K_shift->buffer));
int32_t * data = (int32_t *) inp_K_shift->data;
// uint32_t index matches the type of kv.size (fixes the old int/int64_t mix)
for (uint32_t i = 0; i < kv.size; ++i) {
data[i] = kv.cells[i].delta;
}
}
void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) { void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
// //
// set input data // set input data
@ -134,7 +120,6 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
const int64_t n_seq_tokens = ubatch.n_seq_tokens; const int64_t n_seq_tokens = ubatch.n_seq_tokens;
const int64_t n_seqs = ubatch.n_seqs; const int64_t n_seqs = ubatch.n_seqs;
float * data = nullptr; float * data = nullptr;
float * data_swa = nullptr; float * data_swa = nullptr;
@ -599,6 +584,7 @@ uint32_t llama_n_ubatch(const struct llama_context * ctx) {
} }
uint32_t llama_n_seq_max(const struct llama_context * ctx) { uint32_t llama_n_seq_max(const struct llama_context * ctx) {
// TODO: add notion of n_seq_max to llama_kv_cache and use it here
return ctx->kv_self.size; return ctx->kv_self.size;
} }

View File

@ -107,13 +107,11 @@ struct llama_context {
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
void set_k_shift(llama_kv_cache & kv);
}; };
// TODO: make these methods of llama_context // TODO: make these methods of llama_context
void llama_set_k_shift(struct llama_context & lctx);
void llama_set_s_copy(struct llama_context & lctx);
void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch); void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch);
// Make sure enough space is available for outputs. // Make sure enough space is available for outputs.

View File

@ -6,6 +6,7 @@
#include "llama-model.h" #include "llama-model.h"
#include <algorithm> #include <algorithm>
#include <cassert>
#include <limits> #include <limits>
#include <map> #include <map>
#include <stdexcept> #include <stdexcept>

View File

@ -1176,7 +1176,9 @@ struct llm_build_context {
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
const int64_t n_head_kv = hparams.n_head_kv(il); const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
struct ggml_tensor * rope_factors = build_rope_factors(il); struct ggml_tensor * rope_factors = build_rope_factors(il);
struct ggml_tensor * k = struct ggml_tensor * k =
ggml_view_3d(ctx0, kv_self.k_l[il], ggml_view_3d(ctx0, kv_self.k_l[il],
n_embd_head_k, n_head_kv, n_ctx, n_embd_head_k, n_head_kv, n_ctx,
@ -1189,6 +1191,7 @@ struct llm_build_context {
// dequantize to f32 -> RoPE -> quantize back // dequantize to f32 -> RoPE -> quantize back
tmp = ggml_cast(ctx0, k, GGML_TYPE_F32); tmp = ggml_cast(ctx0, k, GGML_TYPE_F32);
cb(tmp, "K_f32", il); cb(tmp, "K_f32", il);
for (auto & backend : lctx.backends) { for (auto & backend : lctx.backends) {
// Figure out which backend KV cache belongs to // Figure out which backend KV cache belongs to
if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(kv_self.k_l[il]->buffer))) { if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(kv_self.k_l[il]->buffer))) {
@ -1200,6 +1203,7 @@ struct llm_build_context {
lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow); ext_factor, attn_factor, beta_fast, beta_slow);
cb(tmp, "K_shifted_f32", il); cb(tmp, "K_shifted_f32", il);
tmp = ggml_cpy(ctx0, tmp, k); tmp = ggml_cpy(ctx0, tmp, k);
} else { } else {
// we rotate only the first n_rot dimensions // we rotate only the first n_rot dimensions
@ -1208,6 +1212,7 @@ struct llm_build_context {
ext_factor, attn_factor, beta_fast, beta_slow); ext_factor, attn_factor, beta_fast, beta_slow);
} }
cb(tmp, "K_shifted", il); cb(tmp, "K_shifted", il);
ggml_build_forward_expand(gf, tmp); ggml_build_forward_expand(gf, tmp);
} }
@ -9201,7 +9206,7 @@ static void llama_kv_self_update_impl(llama_context & lctx) {
ggml_backend_sched_alloc_graph(lctx.sched.get(), gf); ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
llama_set_k_shift(lctx); lctx.set_k_shift(kv);
llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool); llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);