ggml : ggml_rope now takes a vector with positions instead of n_past

This commit is contained in:
Georgi Gerganov 2023-09-17 21:12:51 +03:00
parent 3b4bab6a38
commit 1fb033fd85
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
9 changed files with 270 additions and 131 deletions

View File

@ -556,6 +556,14 @@ struct ggml_tensor * forward(
struct ggml_tensor * kc = kv_self.k; struct ggml_tensor * kc = kv_self.k;
struct ggml_tensor * vc = kv_self.v; struct ggml_tensor * vc = kv_self.v;
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
{
int * data = (int *) KQ_pos->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}
// inpL shape [n_embd,N,1,1] // inpL shape [n_embd,N,1,1]
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
@ -583,8 +591,8 @@ struct ggml_tensor * forward(
// wk shape [n_embd, n_embd, 1, 1] // wk shape [n_embd, n_embd, 1, 1]
// Qcur shape [n_embd/n_head, n_head, N, 1] // Qcur shape [n_embd/n_head, n_head, N, 1]
// Kcur shape [n_embd/n_head, n_head, N, 1] // Kcur shape [n_embd/n_head, n_head, N, 1]
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0); struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0);
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0); struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0);
// store key and value to memory // store key and value to memory
{ {
@ -810,9 +818,18 @@ struct ggml_tensor * forward_batch(
struct ggml_tensor * kc = kv_self.k; struct ggml_tensor * kc = kv_self.k;
struct ggml_tensor * vc = kv_self.v; struct ggml_tensor * vc = kv_self.v;
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
{
int * data = (int *) KQ_pos->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}
// inpL shape [n_embd,N*n_batch,1] // inpL shape [n_embd,N*n_batch,1]
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
assert_shape_2d(inpL, n_embd, N*n_batch); assert_shape_2d(inpL, n_embd, N*n_batch);
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL; struct ggml_tensor * inpSA = inpL;
@ -840,8 +857,8 @@ struct ggml_tensor * forward_batch(
// wk shape [n_embd, n_embd, 1, 1] // wk shape [n_embd, n_embd, 1, 1]
// Qcur shape [n_embd/n_head, n_head, N, n_batch] // Qcur shape [n_embd/n_head, n_head, N, n_batch]
// Kcur shape [n_embd/n_head, n_head, N, n_batch] // Kcur shape [n_embd/n_head, n_head, N, n_batch]
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0); struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0);
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0); struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0);
assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch); assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch); assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);
@ -1100,6 +1117,14 @@ struct ggml_tensor * forward_lora(
struct ggml_tensor * kc = kv_self.k; struct ggml_tensor * kc = kv_self.k;
struct ggml_tensor * vc = kv_self.v; struct ggml_tensor * vc = kv_self.v;
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
{
int * data = (int *) KQ_pos->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}
// inpL shape [n_embd,N,1,1] // inpL shape [n_embd,N,1,1]
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
@ -1133,7 +1158,7 @@ struct ggml_tensor * forward_lora(
model->layers[il].wqb, model->layers[il].wqb,
cur)), cur)),
n_embd/n_head, n_head, N), n_embd/n_head, n_head, N),
n_past, n_rot, 0, 0); KQ_pos, n_rot, 0, 0);
struct ggml_tensor * Kcur = ggml_rope(ctx0, struct ggml_tensor * Kcur = ggml_rope(ctx0,
ggml_reshape_3d(ctx0, ggml_reshape_3d(ctx0,
ggml_mul_mat(ctx0, ggml_mul_mat(ctx0,
@ -1142,7 +1167,7 @@ struct ggml_tensor * forward_lora(
model->layers[il].wkb, model->layers[il].wkb,
cur)), cur)),
n_embd/n_head, n_head, N), n_embd/n_head, n_head, N),
n_past, n_rot, 0, 0); KQ_pos, n_rot, 0, 0);
// store key and value to memory // store key and value to memory
{ {

View File

@ -679,15 +679,23 @@ struct ggml_tensor * llama_build_train_graphs(
} }
}; };
// KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
{
int * data = (int *) KQ_pos->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}
// rope has so much parameters that we make a custom function for it // rope has so much parameters that we make a custom function for it
auto rope = [ctx, n_rot, n_ctx, rope_freq_base, rope_freq_scale] auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
(struct ggml_tensor * t) -> struct ggml_tensor * { (struct ggml_tensor * t) -> struct ggml_tensor * {
// not capturing these, to silcence warnings // not capturing these, to silcence warnings
const int n_past = 0;
const int rope_mode = 0; const int rope_mode = 0;
return ggml_rope_custom(ctx, return ggml_rope_custom(ctx,
t, n_past, n_rot, rope_mode, n_ctx, t, KQ_pos, n_rot, rope_mode, n_ctx,
rope_freq_base, rope_freq_scale); rope_freq_base, rope_freq_scale);
}; };

View File

@ -1210,7 +1210,9 @@ void ggml_metal_graph_compute(
} break; } break;
case GGML_OP_ROPE: case GGML_OP_ROPE:
{ {
const int n_past = ((int32_t *) dst->op_params)[0]; GGML_ASSERT(ne10 == ne02);
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1]; const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2]; const int mode = ((int32_t *) dst->op_params)[2];
@ -1221,28 +1223,29 @@ void ggml_metal_graph_compute(
[encoder setComputePipelineState:ctx->pipeline_rope]; [encoder setComputePipelineState:ctx->pipeline_rope];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&n_past length:sizeof( int) atIndex:18]; [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19]; //[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
[encoder setBytes:&mode length:sizeof( int) atIndex:20]; [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
[encoder setBytes:&freq_base length:sizeof(float) atIndex:21]; [encoder setBytes:&mode length:sizeof( int) atIndex:21];
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:22]; [encoder setBytes:&freq_base length:sizeof(float) atIndex:22];
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
} break; } break;

View File

@ -854,29 +854,30 @@ kernel void kernel_alibi_f32(
} }
kernel void kernel_rope( kernel void kernel_rope(
device const void * src0, device const void * src0,
device float * dst, device const int32_t * src1,
constant int64_t & ne00, device float * dst,
constant int64_t & ne01, constant int64_t & ne00,
constant int64_t & ne02, constant int64_t & ne01,
constant int64_t & ne03, constant int64_t & ne02,
constant uint64_t & nb00, constant int64_t & ne03,
constant uint64_t & nb01, constant uint64_t & nb00,
constant uint64_t & nb02, constant uint64_t & nb01,
constant uint64_t & nb03, constant uint64_t & nb02,
constant int64_t & ne0, constant uint64_t & nb03,
constant int64_t & ne1, constant int64_t & ne0,
constant int64_t & ne2, constant int64_t & ne1,
constant int64_t & ne3, constant int64_t & ne2,
constant uint64_t & nb0, constant int64_t & ne3,
constant uint64_t & nb1, constant uint64_t & nb0,
constant uint64_t & nb2, constant uint64_t & nb1,
constant uint64_t & nb3, constant uint64_t & nb2,
constant int & n_past, constant uint64_t & nb3,
constant int & n_dims, constant int & n_past,
constant int & mode, constant int & n_dims,
constant float & freq_base, constant int & mode,
constant float & freq_scale, constant float & freq_base,
constant float & freq_scale,
uint tiitg[[thread_index_in_threadgroup]], uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]], uint3 tptg[[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) { uint3 tgpig[[threadgroup_position_in_grid]]) {
@ -886,7 +887,9 @@ kernel void kernel_rope(
const bool is_neox = mode & 2; const bool is_neox = mode & 2;
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); device const int32_t * pos = src1;
const int64_t p = pos[i2];
const float theta_0 = freq_scale * (float)p; const float theta_0 = freq_scale * (float)p;
const float inv_ndims = -1.f/n_dims; const float inv_ndims = -1.f/n_dims;
@ -1320,8 +1323,8 @@ kernel void kernel_mul_mat_q3_K_f32(
float yl[32]; float yl[32];
const uint16_t kmask1 = 0x3030; //const uint16_t kmask1 = 0x3030;
const uint16_t kmask2 = 0x0f0f; //const uint16_t kmask2 = 0x0f0f;
const int tid = tiisg/4; const int tid = tiisg/4;
const int ix = tiisg%4; const int ix = tiisg%4;

113
ggml.c
View File

@ -6968,7 +6968,7 @@ struct ggml_tensor * ggml_soft_max_back_inplace(
static struct ggml_tensor * ggml_rope_impl( static struct ggml_tensor * ggml_rope_impl(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
int n_past, struct ggml_tensor * b,
int n_dims, int n_dims,
int mode, int mode,
int n_ctx, int n_ctx,
@ -6977,6 +6977,10 @@ static struct ggml_tensor * ggml_rope_impl(
float xpos_base, float xpos_base,
bool xpos_down, bool xpos_down,
bool inplace) { bool inplace) {
GGML_ASSERT(ggml_is_vector(b));
GGML_ASSERT(b->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[2] == b->ne[0]);
bool is_node = false; bool is_node = false;
if (a->grad) { if (a->grad) {
@ -6985,7 +6989,7 @@ static struct ggml_tensor * ggml_rope_impl(
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
int32_t params[8] = { n_past, n_dims, mode, n_ctx }; int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
memcpy(params + 4, &freq_base, sizeof(float)); memcpy(params + 4, &freq_base, sizeof(float));
memcpy(params + 5, &freq_scale, sizeof(float)); memcpy(params + 5, &freq_scale, sizeof(float));
memcpy(params + 6, &xpos_base, sizeof(float)); memcpy(params + 6, &xpos_base, sizeof(float));
@ -6995,6 +6999,7 @@ static struct ggml_tensor * ggml_rope_impl(
result->op = GGML_OP_ROPE; result->op = GGML_OP_ROPE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a; result->src[0] = a;
result->src[1] = b;
return result; return result;
} }
@ -7002,55 +7007,55 @@ static struct ggml_tensor * ggml_rope_impl(
struct ggml_tensor * ggml_rope( struct ggml_tensor * ggml_rope(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
int n_past, struct ggml_tensor * b,
int n_dims, int n_dims,
int mode, int mode,
int n_ctx) { int n_ctx) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false); return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
} }
struct ggml_tensor * ggml_rope_inplace( struct ggml_tensor * ggml_rope_inplace(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
int n_past, struct ggml_tensor * b,
int n_dims, int n_dims,
int mode, int mode,
int n_ctx) { int n_ctx) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true); return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
} }
struct ggml_tensor * ggml_rope_custom( struct ggml_tensor * ggml_rope_custom(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
int n_past, struct ggml_tensor * b,
int n_dims, int n_dims,
int mode, int mode,
int n_ctx, int n_ctx,
float freq_base, float freq_base,
float freq_scale) { float freq_scale) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false); return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
} }
struct ggml_tensor * ggml_rope_custom_inplace( struct ggml_tensor * ggml_rope_custom_inplace(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
int n_past, struct ggml_tensor * b,
int n_dims, int n_dims,
int mode, int mode,
int n_ctx, int n_ctx,
float freq_base, float freq_base,
float freq_scale) { float freq_scale) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true); return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
} }
struct ggml_tensor * ggml_rope_xpos_inplace( struct ggml_tensor * ggml_rope_xpos_inplace(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
int n_past, struct ggml_tensor * b,
int n_dims, int n_dims,
float base, float base,
bool down) { bool down) {
return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true); return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
} }
// ggml_rope_back // ggml_rope_back
@ -7058,7 +7063,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
struct ggml_tensor * ggml_rope_back( struct ggml_tensor * ggml_rope_back(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
int n_past, struct ggml_tensor * b,
int n_dims, int n_dims,
int mode, int mode,
int n_ctx, int n_ctx,
@ -7066,7 +7071,10 @@ struct ggml_tensor * ggml_rope_back(
float freq_scale, float freq_scale,
float xpos_base, float xpos_base,
bool xpos_down) { bool xpos_down) {
GGML_ASSERT(n_past >= 0); GGML_ASSERT(ggml_is_vector(b));
GGML_ASSERT(b->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[2] == b->ne[0]);
GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet"); GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
bool is_node = false; bool is_node = false;
@ -7077,7 +7085,7 @@ struct ggml_tensor * ggml_rope_back(
struct ggml_tensor * result = ggml_dup_tensor(ctx, a); struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
int32_t params[8] = { n_past, n_dims, mode, n_ctx }; int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
memcpy(params + 4, &freq_base, sizeof(float)); memcpy(params + 4, &freq_base, sizeof(float));
memcpy(params + 5, &freq_scale, sizeof(float)); memcpy(params + 5, &freq_scale, sizeof(float));
memcpy(params + 6, &xpos_base, sizeof(float)); memcpy(params + 6, &xpos_base, sizeof(float));
@ -7087,6 +7095,7 @@ struct ggml_tensor * ggml_rope_back(
result->op = GGML_OP_ROPE_BACK; result->op = GGML_OP_ROPE_BACK;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a; result->src[0] = a;
result->src[1] = b;
return result; return result;
} }
@ -12620,8 +12629,8 @@ static void ggml_compute_forward_clamp(
static void ggml_compute_forward_rope_f32( static void ggml_compute_forward_rope_f32(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return; return;
} }
@ -12631,9 +12640,9 @@ static void ggml_compute_forward_rope_f32(
// these two only relevant for xPos RoPE: // these two only relevant for xPos RoPE:
float xpos_base; float xpos_base;
bool xpos_down; bool xpos_down;
const int n_past = ((int32_t *) dst->op_params)[0]; //const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1]; const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2]; const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx = ((int32_t *) dst->op_params)[3]; const int n_ctx = ((int32_t *) dst->op_params)[3];
@ -12669,14 +12678,14 @@ static void ggml_compute_forward_rope_f32(
const float theta_scale = powf(freq_base, -2.0f/n_dims); const float theta_scale = powf(freq_base, -2.0f/n_dims);
const bool is_skip = mode & 1;
const bool is_neox = mode & 2; const bool is_neox = mode & 2;
const bool is_glm = mode & 4; const bool is_glm = mode & 4;
const bool is_diff = mode & 8; // TODO: temporary
const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = (is_skip ? n_past : 0); i2 < ne2; i2++) { for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = is_diff ? n_past : is_skip ? i2 : n_past + i2; const int64_t p = pos[i2];
for (int64_t i1 = 0; i1 < ne1; i1++) { for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue; if (ir++ < ir0) continue;
if (ir > ir1) break; if (ir > ir1) break;
@ -12713,7 +12722,7 @@ static void ggml_compute_forward_rope_f32(
const float cos_theta = cosf(theta); const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta); const float sin_theta = sinf(theta);
// zeta scaling for xPos only: // zeta scaling for xPos only:
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f; float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
if (xpos_down) zeta = 1.0f / zeta; if (xpos_down) zeta = 1.0f / zeta;
theta *= theta_scale; theta *= theta_scale;
@ -12758,8 +12767,8 @@ static void ggml_compute_forward_rope_f32(
static void ggml_compute_forward_rope_f16( static void ggml_compute_forward_rope_f16(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return; return;
} }
@ -12767,15 +12776,13 @@ static void ggml_compute_forward_rope_f16(
float freq_base; float freq_base;
float freq_scale; float freq_scale;
const int n_past = ((int32_t *) dst->op_params)[0]; //const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1]; const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2]; const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx = ((int32_t *) dst->op_params)[3]; const int n_ctx = ((int32_t *) dst->op_params)[3];
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
assert(n_past >= 0);
GGML_TENSOR_UNARY_OP_LOCALS; GGML_TENSOR_UNARY_OP_LOCALS;
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
@ -12806,9 +12813,11 @@ static void ggml_compute_forward_rope_f16(
const bool is_neox = mode & 2; const bool is_neox = mode & 2;
const bool is_glm = mode & 4; const bool is_glm = mode & 4;
const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); const int64_t p = pos[i2];
for (int64_t i1 = 0; i1 < ne1; i1++) { for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue; if (ir++ < ir0) continue;
if (ir > ir1) break; if (ir > ir1) break;
@ -12887,15 +12896,16 @@ static void ggml_compute_forward_rope_f16(
static void ggml_compute_forward_rope( static void ggml_compute_forward_rope(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F16: case GGML_TYPE_F16:
{ {
ggml_compute_forward_rope_f16(params, src0, dst); ggml_compute_forward_rope_f16(params, src0, src1, dst);
} break; } break;
case GGML_TYPE_F32: case GGML_TYPE_F32:
{ {
ggml_compute_forward_rope_f32(params, src0, dst); ggml_compute_forward_rope_f32(params, src0, src1, dst);
} break; } break;
default: default:
{ {
@ -12909,6 +12919,7 @@ static void ggml_compute_forward_rope(
static void ggml_compute_forward_rope_back_f32( static void ggml_compute_forward_rope_back_f32(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@ -12926,7 +12937,7 @@ static void ggml_compute_forward_rope_back_f32(
float xpos_base; float xpos_base;
bool xpos_down; bool xpos_down;
const int n_past = ((int32_t *) dst->op_params)[0]; //const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1]; const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2]; const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx); const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx);
@ -12935,8 +12946,6 @@ static void ggml_compute_forward_rope_back_f32(
memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool)); memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
assert(n_past >= 0);
GGML_TENSOR_UNARY_OP_LOCALS; GGML_TENSOR_UNARY_OP_LOCALS;
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
@ -12963,9 +12972,11 @@ static void ggml_compute_forward_rope_back_f32(
const bool is_neox = mode & 2; const bool is_neox = mode & 2;
const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); const int64_t p = pos[i2];
for (int64_t i1 = 0; i1 < ne1; i1++) { for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue; if (ir++ < ir0) continue;
if (ir > ir1) break; if (ir > ir1) break;
@ -12977,7 +12988,7 @@ static void ggml_compute_forward_rope_back_f32(
const float cos_theta = cosf(theta); const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta); const float sin_theta = sinf(theta);
// zeta scaling for xPos only: // zeta scaling for xPos only:
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f; float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
if (xpos_down) zeta = 1.0f / zeta; if (xpos_down) zeta = 1.0f / zeta;
theta *= theta_scale; theta *= theta_scale;
@ -13020,6 +13031,7 @@ static void ggml_compute_forward_rope_back_f32(
static void ggml_compute_forward_rope_back_f16( static void ggml_compute_forward_rope_back_f16(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@ -13030,12 +13042,10 @@ static void ggml_compute_forward_rope_back_f16(
// dx = rope_back(dy, src1) // dx = rope_back(dy, src1)
// src0 is dy, src1 contains options // src0 is dy, src1 contains options
const int n_past = ((int32_t *) dst->op_params)[0]; //const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1]; const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2]; const int mode = ((int32_t *) dst->op_params)[2];
assert(n_past >= 0);
GGML_TENSOR_UNARY_OP_LOCALS; GGML_TENSOR_UNARY_OP_LOCALS;
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
@ -13062,9 +13072,11 @@ static void ggml_compute_forward_rope_back_f16(
const bool is_neox = mode & 2; const bool is_neox = mode & 2;
const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); const int64_t p = pos[i2];
for (int64_t i1 = 0; i1 < ne1; i1++) { for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue; if (ir++ < ir0) continue;
if (ir > ir1) break; if (ir > ir1) break;
@ -13116,15 +13128,16 @@ static void ggml_compute_forward_rope_back_f16(
static void ggml_compute_forward_rope_back( static void ggml_compute_forward_rope_back(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F16: case GGML_TYPE_F16:
{ {
ggml_compute_forward_rope_back_f16(params, src0, dst); ggml_compute_forward_rope_back_f16(params, src0, src1, dst);
} break; } break;
case GGML_TYPE_F32: case GGML_TYPE_F32:
{ {
ggml_compute_forward_rope_back_f32(params, src0, dst); ggml_compute_forward_rope_back_f32(params, src0, src1, dst);
} break; } break;
default: default:
{ {
@ -15861,11 +15874,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break; } break;
case GGML_OP_ROPE: case GGML_OP_ROPE:
{ {
ggml_compute_forward_rope(params, tensor->src[0], tensor); ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor);
} break; } break;
case GGML_OP_ROPE_BACK: case GGML_OP_ROPE_BACK:
{ {
ggml_compute_forward_rope_back(params, tensor->src[0], tensor); ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor);
} break; } break;
case GGML_OP_ALIBI: case GGML_OP_ALIBI:
{ {
@ -16503,7 +16516,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{ {
// necessary for llama // necessary for llama
if (src0->grad) { if (src0->grad) {
const int n_past = ((int32_t *) tensor->op_params)[0]; //const int n_past = ((int32_t *) tensor->op_params)[0];
const int n_dims = ((int32_t *) tensor->op_params)[1]; const int n_dims = ((int32_t *) tensor->op_params)[1];
const int mode = ((int32_t *) tensor->op_params)[2]; const int mode = ((int32_t *) tensor->op_params)[2];
const int n_ctx = ((int32_t *) tensor->op_params)[3]; const int n_ctx = ((int32_t *) tensor->op_params)[3];
@ -16520,7 +16533,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad, src0->grad,
ggml_rope_back(ctx, ggml_rope_back(ctx,
tensor->grad, tensor->grad,
n_past, src1,
n_dims, n_dims,
mode, mode,
n_ctx, n_ctx,
@ -16534,7 +16547,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
case GGML_OP_ROPE_BACK: case GGML_OP_ROPE_BACK:
{ {
if (src0->grad) { if (src0->grad) {
const int n_past = ((int32_t *) tensor->op_params)[0]; //const int n_past = ((int32_t *) tensor->op_params)[0];
const int n_dims = ((int32_t *) tensor->op_params)[1]; const int n_dims = ((int32_t *) tensor->op_params)[1];
const int mode = ((int32_t *) tensor->op_params)[2]; const int mode = ((int32_t *) tensor->op_params)[2];
const int n_ctx = ((int32_t *) tensor->op_params)[3]; const int n_ctx = ((int32_t *) tensor->op_params)[3];
@ -16551,7 +16564,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad, src0->grad,
ggml_rope_impl(ctx, ggml_rope_impl(ctx,
tensor->grad, tensor->grad,
n_past, src1,
n_dims, n_dims,
mode, mode,
n_ctx, n_ctx,

17
ggml.h
View File

@ -1219,14 +1219,15 @@ extern "C" {
struct ggml_tensor * b); struct ggml_tensor * b);
// rotary position embedding // rotary position embedding
// if mode & 1 == 1, skip n_past elements // if mode & 1 == 1, skip n_past elements (DEPRECATED)
// if mode & 2 == 1, GPT-NeoX style // if mode & 2 == 1, GPT-NeoX style
// if mode & 4 == 1, ChatGLM style // if mode & 4 == 1, ChatGLM style
// TODO: avoid creating a new tensor every time //
// b is an int32 vector with size a->ne[2], it contains the positions
GGML_API struct ggml_tensor * ggml_rope( GGML_API struct ggml_tensor * ggml_rope(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
int n_past, struct ggml_tensor * b,
int n_dims, int n_dims,
int mode, int mode,
int n_ctx); int n_ctx);
@ -1235,7 +1236,7 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_rope_inplace( GGML_API struct ggml_tensor * ggml_rope_inplace(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
int n_past, struct ggml_tensor * b,
int n_dims, int n_dims,
int mode, int mode,
int n_ctx); int n_ctx);
@ -1244,7 +1245,7 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_rope_custom( GGML_API struct ggml_tensor * ggml_rope_custom(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
int n_past, struct ggml_tensor * b,
int n_dims, int n_dims,
int mode, int mode,
int n_ctx, int n_ctx,
@ -1255,7 +1256,7 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_rope_custom_inplace( GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
int n_past, struct ggml_tensor * b,
int n_dims, int n_dims,
int mode, int mode,
int n_ctx, int n_ctx,
@ -1266,7 +1267,7 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_rope_xpos_inplace( GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
int n_past, struct ggml_tensor * b,
int n_dims, int n_dims,
float base, float base,
bool down); bool down);
@ -1276,7 +1277,7 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_rope_back( GGML_API struct ggml_tensor * ggml_rope_back(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
int n_past, struct ggml_tensor * b,
int n_dims, int n_dims,
int mode, int mode,
int n_ctx, int n_ctx,

View File

@ -2428,6 +2428,16 @@ static struct ggml_cgraph * llm_build_llama(
} }
} }
// KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
ggml_allocr_alloc(lctx.alloc, KQ_pos);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) KQ_pos->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
ggml_format_name(inpL, "layer_inp_%d", il); ggml_format_name(inpL, "layer_inp_%d", il);
@ -2464,11 +2474,11 @@ static struct ggml_cgraph * llm_build_llama(
offload_func_kq(tmpq); offload_func_kq(tmpq);
ggml_set_name(tmpq, "tmpq"); ggml_set_name(tmpq, "tmpq");
struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(Kcur); offload_func_kq(Kcur);
ggml_set_name(Kcur, "Kcur"); ggml_set_name(Kcur, "Kcur");
struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(Qcur); offload_func_kq(Qcur);
ggml_set_name(Qcur, "Qcur"); ggml_set_name(Qcur, "Qcur");
@ -2754,6 +2764,7 @@ static struct ggml_cgraph * llm_build_baichaun(
} }
#endif // GGML_USE_CUBLAS #endif // GGML_USE_CUBLAS
// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_allocr_alloc(lctx.alloc, KQ_scale); ggml_allocr_alloc(lctx.alloc, KQ_scale);
if (!ggml_allocr_is_measure(lctx.alloc)) { if (!ggml_allocr_is_measure(lctx.alloc)) {
@ -2761,6 +2772,32 @@ static struct ggml_cgraph * llm_build_baichaun(
} }
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
// KQ_mask
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, N, 1);
ggml_allocr_alloc(lctx.alloc, KQ_mask);
if (!ggml_allocr_is_measure(lctx.alloc)) {
float * data = (float *) KQ_mask->data;
memset(data, 0, ggml_nbytes(KQ_mask));
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < N; ++j) {
for (int i = n_past + j + 1; i < n_past + N; ++i) {
data[h*(n_past + N)*N + j*(n_past + N) + i] = -INFINITY;
}
}
}
}
// KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
ggml_allocr_alloc(lctx.alloc, KQ_pos);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) KQ_pos->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
ggml_format_name(inpL, "layer_inp_%d", il); ggml_format_name(inpL, "layer_inp_%d", il);
@ -2801,11 +2838,11 @@ static struct ggml_cgraph * llm_build_baichaun(
struct ggml_tensor * Qcur; struct ggml_tensor * Qcur;
switch (model.type) { switch (model.type) {
case MODEL_7B: case MODEL_7B:
Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
break; break;
case MODEL_13B: case MODEL_13B:
Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N); Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N);
Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N); Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N);
break; break;
default: default:
@ -2874,12 +2911,14 @@ static struct ggml_cgraph * llm_build_baichaun(
switch (model.type) { switch (model.type) {
case MODEL_7B: case MODEL_7B:
KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask);
//KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
break; break;
case MODEL_13B: case MODEL_13B:
KQ_scaled_alibi =ggml_alibi(ctx0, KQ_scaled, n_past, n_head, 8); KQ_scaled_alibi =ggml_alibi(ctx0, KQ_scaled, n_past, n_head, 8);
ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi"); ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi");
KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past); KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask);
//KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past);
break; break;
default: default:
GGML_ASSERT(false); GGML_ASSERT(false);
@ -3114,6 +3153,7 @@ static struct ggml_cgraph * llm_build_falcon(
} }
#endif // GGML_USE_CUBLAS #endif // GGML_USE_CUBLAS
// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_allocr_alloc(lctx.alloc, KQ_scale); ggml_allocr_alloc(lctx.alloc, KQ_scale);
if (!ggml_allocr_is_measure(lctx.alloc)) { if (!ggml_allocr_is_measure(lctx.alloc)) {
@ -3121,6 +3161,32 @@ static struct ggml_cgraph * llm_build_falcon(
} }
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
// KQ_mask
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, N, 1);
ggml_allocr_alloc(lctx.alloc, KQ_mask);
if (!ggml_allocr_is_measure(lctx.alloc)) {
float * data = (float *) KQ_mask->data;
memset(data, 0, ggml_nbytes(KQ_mask));
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < N; ++j) {
for (int i = n_past + j + 1; i < n_past + N; ++i) {
data[h*(n_past + N)*N + j*(n_past + N) + i] = -INFINITY;
}
}
}
}
// KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
ggml_allocr_alloc(lctx.alloc, KQ_pos);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) KQ_pos->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * attn_norm; struct ggml_tensor * attn_norm;
@ -3197,9 +3263,9 @@ static struct ggml_cgraph * llm_build_falcon(
offload_func_v(tmpv); offload_func_v(tmpv);
// using mode = 2 for neox mode // using mode = 2 for neox mode
struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale); struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, tmpq, KQ_pos, n_embd_head, 2, 0, freq_base, freq_scale);
offload_func_kq(Qcur); offload_func_kq(Qcur);
struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale); struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, tmpk, KQ_pos, n_embd_head, 2, 0, freq_base, freq_scale);
offload_func_kq(Kcur); offload_func_kq(Kcur);
{ {

View File

@ -1404,6 +1404,11 @@ int main(int argc, const char ** argv) {
for (int n_past = 1; n_past < ne2[2]; ++n_past) { for (int n_past = 1; n_past < ne2[2]; ++n_past) {
x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]);
for (int i = 0; i < ne2[2]; ++i) {
((int32_t *) p->data)[i] = n_past + i;
}
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
const bool skip_past = (mode & 1); const bool skip_past = (mode & 1);
@ -1415,7 +1420,7 @@ int main(int argc, const char ** argv) {
continue; continue;
} }
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0)); struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode, 0));
GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode); GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY); check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
@ -1438,6 +1443,11 @@ int main(int argc, const char ** argv) {
for (int n_past = 1; n_past < ne2[2]; ++n_past) { for (int n_past = 1; n_past < ne2[2]; ++n_past) {
x[0] = get_random_tensor_f16(ctx0, ndims, ne2, -1.0f, 1.0f); x[0] = get_random_tensor_f16(ctx0, ndims, ne2, -1.0f, 1.0f);
struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]);
for (int i = 0; i < ne2[2]; ++i) {
((int32_t *) p->data)[i] = n_past + i;
}
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
const bool skip_past = (mode & 1); const bool skip_past = (mode & 1);
@ -1449,7 +1459,7 @@ int main(int argc, const char ** argv) {
continue; continue;
} }
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0)); struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode, 0));
GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode); GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY); check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);

View File

@ -144,7 +144,17 @@ int main(int /*argc*/, const char ** /*argv*/) {
const int64_t ne[4] = { 2*n_rot, 32, 73, 1 }; const int64_t ne[4] = { 2*n_rot, 32, 73, 1 };
const int n_past_0 = 100; const int n_past_0 = 100;
const int n_past_1 = 33; const int n_past_2 = 33;
struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
for (int i = 0; i < ne[2]; ++i) {
((int32_t *) p0->data)[i] = n_past_0 + i;
((int32_t *) p1->data)[i] = n_past_2 - n_past_0;
((int32_t *) p2->data)[i] = n_past_2 + i;
}
// test mode 0, 2, 4 (standard, GPT-NeoX, GLM) // test mode 0, 2, 4 (standard, GPT-NeoX, GLM)
const int mode = m == 0 ? 0 : m == 1 ? 2 : 4; const int mode = m == 0 ? 0 : m == 1 ? 2 : 4;
@ -152,12 +162,12 @@ int main(int /*argc*/, const char ** /*argv*/) {
x = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); x = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
// 100, 101, 102, ..., 172 // 100, 101, 102, ..., 172
struct ggml_tensor * r0 = ggml_rope(ctx0, x, n_past_0, n_rot, mode, 1024); struct ggml_tensor * r0 = ggml_rope(ctx0, x, p0, n_rot, mode, 1024);
// -67, -67, -67, ..., -67 // -67, -67, -67, ..., -67
struct ggml_tensor * r1 = ggml_rope(ctx0, r0, n_past_1 - n_past_0, n_rot, mode + 8, 1024); // diff mode struct ggml_tensor * r1 = ggml_rope(ctx0, r0, p1, n_rot, mode, 1024); // "context swap", i.e. forget n_past_0 - n_past_2 tokens
// 33, 34, 35, ..., 105 // 33, 34, 35, ..., 105
struct ggml_tensor * r2 = ggml_rope(ctx0, x, n_past_1, n_rot, mode, 1024); struct ggml_tensor * r2 = ggml_rope(ctx0, x, p2, n_rot, mode, 1024);
ggml_cgraph * gf = ggml_new_graph(ctx0); ggml_cgraph * gf = ggml_new_graph(ctx0);