llama : offload K shift tensors

This commit is contained in:
Georgi Gerganov 2023-12-03 17:43:04 +02:00
parent 986b3da76a
commit f3dbfb9f60
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -3543,8 +3543,8 @@ static void llm_build_k_shift(
GGML_ASSERT(n_embd_head % n_rot == 0); GGML_ASSERT(n_embd_head % n_rot == 0);
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx); struct ggml_tensor * K_shift_host = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx);
cb(K_shift, "K_shift", -1); cb(K_shift_host, "K_shift_host", -1);
int rope_type = 0; int rope_type = 0;
@ -3555,6 +3555,10 @@ static void llm_build_k_shift(
} }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
// offloaded mirrors
struct ggml_tensor * K_shift = ggml_view_tensor(ctx, K_shift_host);
cb(K_shift, "K_shift", il);
struct ggml_tensor * tmp = struct ggml_tensor * tmp =
// we rotate only the first n_rot dimensions // we rotate only the first n_rot dimensions
ggml_rope_custom_inplace(ctx, ggml_rope_custom_inplace(ctx,
@ -5196,6 +5200,8 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
{ "inp_pos_host", OFFLOAD_FUNC_NOP }, // this is often used for KQ ops (e.g. rope) { "inp_pos_host", OFFLOAD_FUNC_NOP }, // this is often used for KQ ops (e.g. rope)
{ "KQ_scale_host", OFFLOAD_FUNC_NOP }, { "KQ_scale_host", OFFLOAD_FUNC_NOP },
{ "KQ_mask_host", OFFLOAD_FUNC_NOP }, { "KQ_mask_host", OFFLOAD_FUNC_NOP },
{ "K_shift_host", OFFLOAD_FUNC_NOP },
{ "inp_pos", OFFLOAD_FUNC }, // these are offloaded versions of the tensors { "inp_pos", OFFLOAD_FUNC }, // these are offloaded versions of the tensors
{ "KQ_scale", OFFLOAD_FUNC }, { "KQ_scale", OFFLOAD_FUNC },
{ "KQ_mask", OFFLOAD_FUNC }, { "KQ_mask", OFFLOAD_FUNC },
@ -5389,7 +5395,7 @@ static struct ggml_cgraph * llama_build_graph(
alloc_inp_KQ_mask = true; alloc_inp_KQ_mask = true;
} }
if (!alloc_inp_K_shift && strcmp(name, "K_shift") == 0) { if (!alloc_inp_K_shift && strcmp(name, "K_shift_host") == 0) {
ggml_allocr_alloc(lctx.alloc, cur); ggml_allocr_alloc(lctx.alloc, cur);
if (!ggml_allocr_is_measure(lctx.alloc)) { if (!ggml_allocr_is_measure(lctx.alloc)) {