add input embeddings handling

This commit is contained in:
Meng Zhang 2023-09-15 14:47:04 +08:00
parent ab13d071e1
commit 8bc76a225d

329
llama.cpp
View File

@ -3424,6 +3424,331 @@ static struct ggml_cgraph * llm_build_falcon(
return gf;
}
static struct ggml_cgraph * llm_build_starcoder(
llama_context & lctx,
const llama_token * tokens,
const float * embd,
int n_tokens,
int n_past) {
GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT
const int N = n_tokens;
const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & kv_self = lctx.kv_self;
GGML_ASSERT(!!kv_self.ctx);
const int64_t n_embd = hparams.n_embd;
const int64_t n_layer = hparams.n_layer;
const int64_t n_ctx = hparams.n_ctx;
const int64_t n_head = hparams.n_head;
const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_head = hparams.n_embd_head();
const int64_t n_embd_gqa = hparams.n_embd_gqa();
GGML_ASSERT(n_embd_head == hparams.n_rot);
const float freq_base = hparams.rope_freq_base;
const float freq_scale = hparams.rope_freq_scale;
const float norm_eps = hparams.f_norm_eps;
const int n_gpu_layers = model.n_gpu_layers;
auto & buf_compute = lctx.buf_compute;
struct ggml_init_params params = {
/*.mem_size =*/ buf_compute.size,
/*.mem_buffer =*/ buf_compute.data,
/*.no_alloc =*/ false,
};
params.no_alloc = true;
struct ggml_context * ctx0 = ggml_init(params);
ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * cur;
struct ggml_tensor * token;
struct ggml_tensor * position;
struct ggml_tensor * inpL;
if (tokens) {
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
ggml_allocr_alloc(lctx.alloc, inp_tokens);
if (!ggml_allocr_is_measure(lctx.alloc)) {
memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
}
ggml_set_name(inp_tokens, "inp_tokens");
token = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
} else {
#ifdef GGML_USE_MPI
GGML_ASSERT(false && "not implemented");
#endif
token = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
ggml_allocr_alloc(lctx.alloc, token);
if (!ggml_allocr_is_measure(lctx.alloc)) {
memcpy(token->data, embd, N * n_embd * ggml_element_size(inpL));
}
}
{
// Compute position embeddings.
struct ggml_tensor * inp_positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
ggml_allocr_alloc(lctx.alloc, inp_positions);
if (!ggml_allocr_is_measure(lctx.alloc)) {
for (int i = 0; i < N; ++i) {
((int32_t *) inp_positions->data)[i] = n_past + i;
}
}
ggml_set_name(inp_positions, "inp_positions");
position = ggml_get_rows(ctx0, model.pos_embeddings, inp_positions);
}
inpL = ggml_add(ctx0, token, position);
const int i_gpu_start = n_layer - n_gpu_layers;
(void) i_gpu_start;
// offload functions set the tensor output backend to GPU
// tensors are GPU-accelerated if any input or the output has been offloaded
//
// with the low VRAM option VRAM scratch is disabled in llama_load_model_internal
// in that case ggml_cuda_assign_buffers has no effect
offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
offload_func_t offload_func_kq = llama_nop;
offload_func_t offload_func_v = llama_nop;
#ifdef GGML_USE_CUBLAS
if (n_gpu_layers > n_layer) {
offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
}
if (n_gpu_layers > n_layer + 1) {
offload_func_v = ggml_cuda_assign_buffers_no_alloc;
}
if (n_gpu_layers > n_layer + 2) {
offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
}
#endif // GGML_USE_CUBLAS
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_allocr_alloc(lctx.alloc, KQ_scale);
if (!ggml_allocr_is_measure(lctx.alloc)) {
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
}
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * attn_norm;
offload_func_t offload_func = llama_nop;
#ifdef GGML_USE_CUBLAS
if (il >= i_gpu_start) {
offload_func = ggml_cuda_assign_buffers_no_alloc;
}
#endif // GGML_USE_CUBLAS
// self-attention
// TODO: refactor into common function (shared with LLaMA)
{
attn_norm = ggml_norm(ctx0, inpL, norm_eps);
offload_func(attn_norm);
attn_norm = ggml_add(ctx0,
ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm),
model.layers[il].attn_norm_b);
offload_func(attn_norm->src[0]);
offload_func(attn_norm);
if (model.layers[il].attn_norm_2) { // Falcon-40B
cur = ggml_norm(ctx0, inpL, norm_eps);
offload_func(cur);
cur = ggml_add(ctx0,
ggml_mul(ctx0, cur, model.layers[il].attn_norm_2),
model.layers[il].attn_norm_2_b);
offload_func(cur->src[0]);
offload_func(cur);
} else { // Falcon 7B
cur = attn_norm;
}
// compute QKV
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
offload_func_kq(cur);
// Note that the strides for Kcur, Vcur are set up so that the
// resulting views are misaligned with the tensor's storage
// (by applying the K/V offset we shift the tensor's original
// view to stick out behind the viewed QKV tensor's allocated
// memory, so to say). This is ok because no actual accesses
// happen to that out-of-range memory, but it can require some
// trickery when trying to accurately dump these views for
// debugging.
const size_t wsize = ggml_type_size(cur->type);
// TODO: these 2 ggml_conts are technically not needed, but we add them until CUDA support for
// non-contiguous views is added for the rope operator
struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_3d(
ctx0, cur, n_embd_head, n_head, N,
wsize * n_embd_head,
wsize * n_embd_head * (n_head + 2 * n_head_kv),
0));
offload_func_kq(tmpq);
struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_3d(
ctx0, cur, n_embd_head, n_head_kv, N,
wsize * n_embd_head,
wsize * n_embd_head * (n_head + 2 * n_head_kv),
wsize * n_embd_head * n_head));
offload_func_kq(tmpk);
struct ggml_tensor * tmpv = ggml_view_3d(
ctx0, cur, n_embd_head, n_head_kv, N,
wsize * n_embd_head,
wsize * n_embd_head * (n_head + 2 * n_head_kv),
wsize * n_embd_head * (n_head + n_head_kv));
offload_func_v(tmpv);
// using mode = 2 for neox mode
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
offload_func_kq(Qcur);
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
offload_func_kq(Kcur);
{
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N));
offload_func_v(Vcur);
offload_func_v(Vcur->src[0]->src[0]);
ggml_set_name(Vcur, "Vcur");
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
offload_func_kq(k);
ggml_set_name(k, "k");
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
( n_ctx)*ggml_element_size(kv_self.v),
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
offload_func_v(v);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
}
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
offload_func_kq(Q);
ggml_set_name(Q, "Q");
struct ggml_tensor * K =
ggml_view_3d(ctx0, kv_self.k,
n_embd_head, n_past + N, n_head_kv,
ggml_element_size(kv_self.k)*n_embd_gqa,
ggml_element_size(kv_self.k)*n_embd_head,
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
offload_func_kq(K);
ggml_set_name(K, "K");
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
offload_func_kq(KQ);
ggml_set_name(KQ, "KQ");
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
offload_func_kq(KQ_scaled);
ggml_set_name(KQ_scaled, "KQ_scaled");
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
offload_func_kq(KQ_masked);
ggml_set_name(KQ_masked, "KQ_masked");
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
offload_func_v(KQ_soft_max);
ggml_set_name(KQ_soft_max, "KQ_soft_max");
struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v,
n_past + N, n_embd_head, n_head_kv,
ggml_element_size(kv_self.v)*n_ctx,
ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
offload_func_v(V);
ggml_set_name(V, "V");
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
offload_func_v(KQV);
ggml_set_name(KQV, "KQV");
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
offload_func_v(KQV_merged);
ggml_set_name(KQV_merged, "KQV_merged");
cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
offload_func_v(cur);
ggml_set_name(cur, "KQV_merged_contiguous");
cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
offload_func(cur);
ggml_set_name(cur, "result_wo");
}
struct ggml_tensor * attn_out = cur;
// feed forward
{
struct ggml_tensor * inpFF = attn_norm;
cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF);
offload_func(cur);
cur = ggml_gelu(ctx0, cur);
offload_func(cur);
cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
offload_func(cur);
}
cur = ggml_add(ctx0, cur, attn_out);
offload_func(cur);
cur = ggml_add(ctx0, cur, inpL);
offload_func(cur);
// input for next layer
inpL = cur;
}
cur = inpL;
// norm
{
cur = ggml_norm(ctx0, cur, norm_eps);
offload_func_nr(cur);
cur = ggml_add(ctx0,
ggml_mul(ctx0, cur, model.output_norm),
model.output_norm_b);
ggml_set_name(cur, "result_norm");
}
cur = ggml_mul_mat(ctx0, model.output, cur);
ggml_set_name(cur, "result_output");
ggml_build_forward_expand(gf, cur);
ggml_free(ctx0);
return gf;
}
static struct ggml_cgraph * llama_build_graph(
llama_context & lctx,
const llama_token * tokens,
@ -3447,6 +3772,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm_build_falcon(lctx, tokens, embd, n_tokens, n_past);
} break;
case LLM_ARCH_STARCODER:
{
result = llm_build_starcoder(lctx, tokens, embd, n_tokens, n_past);
} break;
default:
GGML_ASSERT(false);
};