llama : factor graph input into a function

Georgi Gerganov 2023-10-29 07:52:43 +02:00
parent 4e98897ede
commit 0dc05b8433
GPG Key ID: 449E073F9DC10735

llama.cpp

@@ -5267,6 +5267,130 @@ static struct ggml_cgraph * llm_build_mpt(
    return gf;
}
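// look up the input tensors of the graph by name, allocate them and copy in the
// corresponding data from the batch (no data is copied during the measure pass
// of the allocator)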
static void llama_build_graph_input(
         llama_context & lctx,
     const llama_batch & batch,
    struct ggml_cgraph * graph) {
    struct ggml_tensor * cur = nullptr;
    // inp_tokens
    if (batch.token) {
        cur = ggml_graph_get_tensor(graph, "inp_tokens");
        GGML_ASSERT(cur != nullptr); // required
        ggml_allocr_alloc(lctx.alloc, cur);
        if (!ggml_allocr_is_measure(lctx.alloc)) {
            const int64_t n_tokens = cur->ne[0];
            memcpy(cur->data, batch.token, n_tokens*ggml_element_size(cur));
        }
    }
    // inp_embd
    if (batch.embd) {
        cur = ggml_graph_get_tensor(graph, "inp_embd");
        GGML_ASSERT(cur != nullptr); // required
        ggml_allocr_alloc(lctx.alloc, cur);
        if (!ggml_allocr_is_measure(lctx.alloc)) {
            const int64_t n_embd   = cur->ne[0];
            const int64_t n_tokens = cur->ne[1];
            memcpy(cur->data, batch.embd, n_tokens*n_embd*ggml_element_size(cur));
        }
    }
    // TODO: make the following required based on the ARCH
    // inp_pos
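    // (positions of the tokens in the batch, copied from batch.pos; consumed e.g. by RoPE)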
cur = ggml_graph_get_tensor(graph, "inp_pos");
if (cur) {
ggml_allocr_alloc(lctx.alloc, cur);
if (!ggml_allocr_is_measure(lctx.alloc)) {
const int64_t n_tokens = cur->ne[0];
int32_t * data = (int32_t *) cur->data;
for (int i = 0; i < n_tokens; ++i) {
data[i] = batch.pos[i];
}
}
}
    // KQ_scale
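    // (attention scale factor 1/sqrt(n_embd_head) applied to the KQ scores)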
cur = ggml_graph_get_tensor(graph, "KQ_scale");
if (cur) {
ggml_allocr_alloc(lctx.alloc, cur);
if (!ggml_allocr_is_measure(lctx.alloc)) {
const int64_t n_embd_head = lctx.model.hparams.n_embd_head();
ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head)));
}
}
    // KQ_mask
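    // (causal attention mask: a KV cell is masked with -INFINITY if it does not
    //  belong to the token's sequence or if its position is after the token's)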
cur = ggml_graph_get_tensor(graph, "KQ_mask");
if (cur) {
ggml_allocr_alloc(lctx.alloc, cur);
if (!ggml_allocr_is_measure(lctx.alloc)) {
const int64_t n_kv = cur->ne[0];
const int64_t n_tokens = cur->ne[1];
float * data = (float *) cur->data;
memset(data, 0, ggml_nbytes(cur));
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
}
}
}
}
}
}
    // KQ_pos
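    // (token positions, same data as inp_pos; kept as a separate input for
    //  architectures that use the positions inside the attention computation)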
cur = ggml_graph_get_tensor(graph, "KQ_pos");
if (cur) {
ggml_allocr_alloc(lctx.alloc, cur);
if (!ggml_allocr_is_measure(lctx.alloc)) {
const int64_t n_tokens = cur->ne[0];
int32_t * data = (int32_t *) cur->data;
for (int i = 0; i < n_tokens; ++i) {
data[i] = batch.pos[i];
}
}
}
    // K_shift
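    // (accumulated position deltas of the KV cells; used to re-rope the K cache
    //  after the cache has been shifted)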
cur = ggml_graph_get_tensor(graph, "K_shift");
if (cur) {
ggml_allocr_alloc(lctx.alloc, cur);
if (!ggml_allocr_is_measure(lctx.alloc)) {
const int64_t n_ctx = cur->ne[0];
int32_t * data = (int32_t *) cur->data;
for (int i = 0; i < n_ctx; ++i) {
data[i] = lctx.kv_self.cells[i].delta;
}
}
    }
}
static struct ggml_cgraph * llama_build_graph(
llama_context & lctx,
const llama_batch & batch) {
@@ -5315,141 +5439,7 @@ static struct ggml_cgraph * llama_build_graph(
    }
    // allocate memory and set the values for the input tensors of the graph
    // inp_tokens
    if (batch.token) {
        struct ggml_tensor * cur = ggml_graph_get_tensor(result, "inp_tokens");
        GGML_ASSERT(cur != nullptr);
        ggml_allocr_alloc(lctx.alloc, cur);
        if (!ggml_allocr_is_measure(lctx.alloc)) {
            const int64_t n_tokens = cur->ne[0];
            memcpy(cur->data, batch.token, n_tokens*ggml_element_size(cur));
        }
    }
    // inp_embd
    if (batch.embd) {
        struct ggml_tensor * cur = ggml_graph_get_tensor(result, "inp_embd");
        GGML_ASSERT(cur != nullptr);
        ggml_allocr_alloc(lctx.alloc, cur);
        if (!ggml_allocr_is_measure(lctx.alloc)) {
            const int64_t n_embd   = cur->ne[0];
            const int64_t n_tokens = cur->ne[1];
            memcpy(cur->data, batch.embd, n_tokens*n_embd*ggml_element_size(cur));
        }
    }
    // inp_pos
    do {
        struct ggml_tensor * cur = ggml_graph_get_tensor(result, "inp_pos");
        if (cur == nullptr) {
            break;
        }
        ggml_allocr_alloc(lctx.alloc, cur);
        if (!ggml_allocr_is_measure(lctx.alloc)) {
            const int64_t n_tokens = cur->ne[0];
            int32_t * data = (int32_t *) cur->data;
            for (int i = 0; i < n_tokens; ++i) {
                data[i] = batch.pos[i];
            }
        }
    } while (0);
    // KQ_scale
    do {
        struct ggml_tensor * cur = ggml_graph_get_tensor(result, "KQ_scale");
        if (cur == nullptr) {
            break;
        }
        ggml_allocr_alloc(lctx.alloc, cur);
        if (!ggml_allocr_is_measure(lctx.alloc)) {
            const int64_t n_embd_head = lctx.model.hparams.n_embd_head();
            ggml_set_f32(cur, 1.0f/sqrtf(float(n_embd_head)));
        }
    } while (0);
    // KQ_mask
    do {
        struct ggml_tensor * cur = ggml_graph_get_tensor(result, "KQ_mask");
        if (cur == nullptr) {
            break;
        }
        ggml_allocr_alloc(lctx.alloc, cur);
        if (!ggml_allocr_is_measure(lctx.alloc)) {
            const int64_t n_kv     = cur->ne[0];
            const int64_t n_tokens = cur->ne[1];
            float * data = (float *) cur->data;
            memset(data, 0, ggml_nbytes(cur));
            for (int h = 0; h < 1; ++h) {
                for (int j = 0; j < n_tokens; ++j) {
                    const llama_pos    pos    = batch.pos[j];
                    const llama_seq_id seq_id = batch.seq_id[j][0];
                    for (int i = 0; i < n_kv; ++i) {
                        if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
                            data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
                        }
                    }
                }
            }
        }
    } while (0);
    // KQ_pos
    do {
        struct ggml_tensor * cur = ggml_graph_get_tensor(result, "KQ_pos");
        if (cur == nullptr) {
            break;
        }
        ggml_allocr_alloc(lctx.alloc, cur);
        if (!ggml_allocr_is_measure(lctx.alloc)) {
            const int64_t n_tokens = cur->ne[0];
            int32_t * data = (int32_t *) cur->data;
            for (int i = 0; i < n_tokens; ++i) {
                data[i] = batch.pos[i];
            }
        }
    } while (0);
    // K_shift
    do {
        struct ggml_tensor * cur = ggml_graph_get_tensor(result, "K_shift");
        if (cur == nullptr) {
            break;
        }
        ggml_allocr_alloc(lctx.alloc, cur);
        if (!ggml_allocr_is_measure(lctx.alloc)) {
            const int64_t n_ctx = cur->ne[0];
            int32_t * data = (int32_t *) cur->data;
            for (int i = 0; i < n_ctx; ++i) {
                data[i] = lctx.kv_self.cells[i].delta;
            }
        }
    } while (0);
    llama_build_graph_input(lctx, batch, result);

    // offload layers
    // TODO: this code will be obsoleted with backend v2
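For reference, the name-based lookup above relies on the llm_build_* functions tagging each input tensor with ggml_set_name(). A minimal sketch of that producer side (the ctx0 context and the n_tokens shape are illustrative assumptions, not part of this diff):

// a build function creates an input tensor and names it ...
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_set_name(inp_tokens, "inp_tokens");
// ... so that llama_build_graph_input() can later locate it in the finished graph:
struct ggml_tensor * cur = ggml_graph_get_tensor(graph, "inp_tokens");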