mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-27 12:33:06 +01:00
ggml, llama : avoid heavy V transpose + improvements (#775)
ggml : - added ggml_view_3d() - ggml_view_tensor() now inherits the stride too - reimplement ggml_cpy() to account for dst stride - no longer require tensor->data to be memory aligned llama : - compute RoPE on 32-bit tensors (should be more accurate) - store RoPE-ed K in the KV cache - store transposed V in the KV cache (significant speed-up) - avoid unnecessary Q copy
This commit is contained in:
parent
3416298929
commit
986b6ce9f9
281
ggml.c
281
ggml.c
@ -3219,7 +3219,8 @@ struct ggml_tensor * ggml_new_tensor_impl(
|
|||||||
/*.pad =*/ { 0 },
|
/*.pad =*/ { 0 },
|
||||||
};
|
};
|
||||||
|
|
||||||
ggml_assert_aligned(result->data);
|
// TODO: this should not be needed as long as we don't rely on aligned SIMD loads
|
||||||
|
//ggml_assert_aligned(result->data);
|
||||||
|
|
||||||
for (int i = 0; i < n_dims; i++) {
|
for (int i = 0; i < n_dims; i++) {
|
||||||
result->ne[i] = ne[i];
|
result->ne[i] = ne[i];
|
||||||
@ -3620,7 +3621,14 @@ float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
|
|||||||
struct ggml_tensor * ggml_view_tensor(
|
struct ggml_tensor * ggml_view_tensor(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
const struct ggml_tensor * src) {
|
const struct ggml_tensor * src) {
|
||||||
return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data);
|
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data);
|
||||||
|
|
||||||
|
result->nb[0] = src->nb[0];
|
||||||
|
result->nb[1] = src->nb[1];
|
||||||
|
result->nb[2] = src->nb[2];
|
||||||
|
result->nb[3] = src->nb[3];
|
||||||
|
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
@ -4510,6 +4518,37 @@ struct ggml_tensor * ggml_view_2d(
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ggml_view_3d
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_view_3d(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
int64_t ne0,
|
||||||
|
int64_t ne1,
|
||||||
|
int64_t ne2,
|
||||||
|
size_t nb1,
|
||||||
|
size_t nb2,
|
||||||
|
size_t offset) {
|
||||||
|
if (a->grad) {
|
||||||
|
GGML_ASSERT(false); // gradient propagation is not supported
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, 1 };
|
||||||
|
|
||||||
|
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset);
|
||||||
|
|
||||||
|
result->nb[1] = nb1;
|
||||||
|
result->nb[2] = nb2;
|
||||||
|
result->nb[3] = result->nb[2]*ne2;
|
||||||
|
|
||||||
|
result->op = GGML_OP_VIEW;
|
||||||
|
result->grad = NULL;
|
||||||
|
result->src0 = a;
|
||||||
|
result->src1 = NULL; // TODO: maybe store the offset here?
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_permute
|
// ggml_permute
|
||||||
|
|
||||||
struct ggml_tensor * ggml_permute(
|
struct ggml_tensor * ggml_permute(
|
||||||
@ -4845,7 +4884,6 @@ static void ggml_compute_forward_dup_f16(
|
|||||||
const struct ggml_tensor * src0,
|
const struct ggml_tensor * src0,
|
||||||
struct ggml_tensor * dst) {
|
struct ggml_tensor * dst) {
|
||||||
GGML_ASSERT(params->ith == 0);
|
GGML_ASSERT(params->ith == 0);
|
||||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
|
||||||
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
||||||
|
|
||||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||||
@ -4862,40 +4900,90 @@ static void ggml_compute_forward_dup_f16(
|
|||||||
const size_t nb02 = src0->nb[2];
|
const size_t nb02 = src0->nb[2];
|
||||||
const size_t nb03 = src0->nb[3];
|
const size_t nb03 = src0->nb[3];
|
||||||
|
|
||||||
if (ggml_is_contiguous(src0) && src0->type == dst->type) {
|
const size_t nb0 = dst->nb[0];
|
||||||
|
const size_t nb1 = dst->nb[1];
|
||||||
|
const size_t nb2 = dst->nb[2];
|
||||||
|
const size_t nb3 = dst->nb[3];
|
||||||
|
|
||||||
|
if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
|
||||||
memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
|
memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (src0->nb[0] == sizeof(ggml_fp16_t)) {
|
if (src0->type == dst->type &&
|
||||||
if (dst->type == GGML_TYPE_F16) {
|
src0->ne[0] == dst->ne[0] &&
|
||||||
size_t id = 0;
|
src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
|
||||||
|
// copy by rows
|
||||||
const size_t rs = ne00*nb00;
|
const size_t rs = ne00*nb00;
|
||||||
|
|
||||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
||||||
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
memcpy(
|
||||||
char * dst_ptr = (char *) dst->data + id*rs;
|
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
|
||||||
|
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
|
||||||
|
rs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
memcpy(dst_ptr, src0_ptr, rs);
|
// TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
|
||||||
|
|
||||||
id++;
|
// dst counters
|
||||||
|
int64_t i10 = 0;
|
||||||
|
int64_t i11 = 0;
|
||||||
|
int64_t i12 = 0;
|
||||||
|
int64_t i13 = 0;
|
||||||
|
|
||||||
|
if (dst->type == GGML_TYPE_F16) {
|
||||||
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
|
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
||||||
|
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||||
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||||
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||||
|
|
||||||
|
memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
|
||||||
|
|
||||||
|
if (++i10 == ne00) {
|
||||||
|
i10 = 0;
|
||||||
|
if (++i11 == ne01) {
|
||||||
|
i11 = 0;
|
||||||
|
if (++i12 == ne02) {
|
||||||
|
i12 = 0;
|
||||||
|
if (++i13 == ne03) {
|
||||||
|
i13 = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (dst->type == GGML_TYPE_F32) {
|
} else if (dst->type == GGML_TYPE_F32) {
|
||||||
size_t id = 0;
|
|
||||||
float * dst_ptr = (float *) dst->data;
|
|
||||||
|
|
||||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
||||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||||
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||||
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||||
|
|
||||||
dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
|
*(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
|
||||||
id++;
|
|
||||||
|
if (++i10 == ne00) {
|
||||||
|
i10 = 0;
|
||||||
|
if (++i11 == ne01) {
|
||||||
|
i11 = 0;
|
||||||
|
if (++i12 == ne02) {
|
||||||
|
i12 = 0;
|
||||||
|
if (++i13 == ne03) {
|
||||||
|
i13 = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -4903,45 +4991,6 @@ static void ggml_compute_forward_dup_f16(
|
|||||||
} else {
|
} else {
|
||||||
GGML_ASSERT(false); // TODO: implement
|
GGML_ASSERT(false); // TODO: implement
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
//printf("%s: this is not optimal - fix me\n", __func__);
|
|
||||||
|
|
||||||
if (dst->type == GGML_TYPE_F32) {
|
|
||||||
size_t id = 0;
|
|
||||||
float * dst_ptr = (float *) dst->data;
|
|
||||||
|
|
||||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
||||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
||||||
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
|
||||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
||||||
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
||||||
|
|
||||||
dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
|
|
||||||
id++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if (dst->type == GGML_TYPE_F16) {
|
|
||||||
size_t id = 0;
|
|
||||||
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
|
||||||
|
|
||||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
||||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
||||||
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
|
||||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
||||||
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
||||||
|
|
||||||
dst_ptr[id] = *src0_ptr;
|
|
||||||
id++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
GGML_ASSERT(false); // TODO: implement
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_compute_forward_dup_f32(
|
static void ggml_compute_forward_dup_f32(
|
||||||
@ -4949,7 +4998,6 @@ static void ggml_compute_forward_dup_f32(
|
|||||||
const struct ggml_tensor * src0,
|
const struct ggml_tensor * src0,
|
||||||
struct ggml_tensor * dst) {
|
struct ggml_tensor * dst) {
|
||||||
GGML_ASSERT(params->ith == 0);
|
GGML_ASSERT(params->ith == 0);
|
||||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
|
||||||
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
||||||
|
|
||||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||||
@ -4966,40 +5014,70 @@ static void ggml_compute_forward_dup_f32(
|
|||||||
const size_t nb02 = src0->nb[2];
|
const size_t nb02 = src0->nb[2];
|
||||||
const size_t nb03 = src0->nb[3];
|
const size_t nb03 = src0->nb[3];
|
||||||
|
|
||||||
if (ggml_is_contiguous(src0) && src0->type == dst->type) {
|
const size_t nb0 = dst->nb[0];
|
||||||
|
const size_t nb1 = dst->nb[1];
|
||||||
|
const size_t nb2 = dst->nb[2];
|
||||||
|
const size_t nb3 = dst->nb[3];
|
||||||
|
|
||||||
|
if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
|
||||||
memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
|
memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (src0->nb[0] == sizeof(float)) {
|
// dst counters
|
||||||
|
int64_t i10 = 0;
|
||||||
|
int64_t i11 = 0;
|
||||||
|
int64_t i12 = 0;
|
||||||
|
int64_t i13 = 0;
|
||||||
|
|
||||||
if (dst->type == GGML_TYPE_F32) {
|
if (dst->type == GGML_TYPE_F32) {
|
||||||
size_t id = 0;
|
|
||||||
const size_t rs = ne00*nb00;
|
|
||||||
|
|
||||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
||||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
||||||
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
|
||||||
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
|
||||||
char * dst_ptr = (char *) dst->data + id*rs;
|
|
||||||
|
|
||||||
memcpy(dst_ptr, src0_ptr, rs);
|
|
||||||
|
|
||||||
id++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if (dst->type == GGML_TYPE_F16) {
|
|
||||||
size_t id = 0;
|
|
||||||
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
|
||||||
|
|
||||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
||||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||||
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||||
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||||
|
|
||||||
dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
|
memcpy(dst_ptr, src0_ptr, sizeof(float));
|
||||||
id++;
|
|
||||||
|
if (++i10 == dst->ne[0]) {
|
||||||
|
i10 = 0;
|
||||||
|
if (++i11 == dst->ne[1]) {
|
||||||
|
i11 = 0;
|
||||||
|
if (++i12 == dst->ne[2]) {
|
||||||
|
i12 = 0;
|
||||||
|
if (++i13 == dst->ne[3]) {
|
||||||
|
i13 = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (dst->type == GGML_TYPE_F16) {
|
||||||
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
|
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
||||||
|
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||||
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||||
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||||
|
|
||||||
|
*(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
|
||||||
|
|
||||||
|
if (++i10 == dst->ne[0]) {
|
||||||
|
i10 = 0;
|
||||||
|
if (++i11 == dst->ne[1]) {
|
||||||
|
i11 = 0;
|
||||||
|
if (++i12 == dst->ne[2]) {
|
||||||
|
i12 = 0;
|
||||||
|
if (++i13 == dst->ne[3]) {
|
||||||
|
i13 = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -5007,45 +5085,6 @@ static void ggml_compute_forward_dup_f32(
|
|||||||
} else {
|
} else {
|
||||||
GGML_ASSERT(false); // TODO: implement
|
GGML_ASSERT(false); // TODO: implement
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
//printf("%s: this is not optimal - fix me\n", __func__);
|
|
||||||
|
|
||||||
if (dst->type == GGML_TYPE_F32) {
|
|
||||||
size_t id = 0;
|
|
||||||
float * dst_ptr = (float *) dst->data;
|
|
||||||
|
|
||||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
||||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
||||||
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
|
||||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
||||||
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
||||||
|
|
||||||
dst_ptr[id] = *src0_ptr;
|
|
||||||
id++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if (dst->type == GGML_TYPE_F16) {
|
|
||||||
size_t id = 0;
|
|
||||||
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
|
||||||
|
|
||||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
||||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
||||||
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
|
||||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
||||||
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
||||||
|
|
||||||
dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
|
|
||||||
id++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
GGML_ASSERT(false); // TODO: implement
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_compute_forward_dup(
|
static void ggml_compute_forward_dup(
|
||||||
|
10
ggml.h
10
ggml.h
@ -558,6 +558,16 @@ struct ggml_tensor * ggml_view_2d(
|
|||||||
size_t nb1, // row stride in bytes
|
size_t nb1, // row stride in bytes
|
||||||
size_t offset);
|
size_t offset);
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_view_3d(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
int64_t ne0,
|
||||||
|
int64_t ne1,
|
||||||
|
int64_t ne2,
|
||||||
|
size_t nb1, // row stride in bytes
|
||||||
|
size_t nb2, // slice stride in bytes
|
||||||
|
size_t offset);
|
||||||
|
|
||||||
struct ggml_tensor * ggml_permute(
|
struct ggml_tensor * ggml_permute(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
|
61
llama.cpp
61
llama.cpp
@ -810,37 +810,35 @@ static bool llama_eval_internal(
|
|||||||
|
|
||||||
// self-attention
|
// self-attention
|
||||||
{
|
{
|
||||||
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
// compute Q and K and RoPE them
|
||||||
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
|
||||||
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
|
||||||
|
|
||||||
// store key and value to memory
|
// store key and value to memory
|
||||||
if (N >= 1) {
|
{
|
||||||
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
|
// compute the transposed [N, n_embd] V matrix
|
||||||
struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_embd, (ggml_element_size(kv_self.v)*n_embd)*(il*n_ctx + n_past));
|
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N));
|
||||||
|
|
||||||
|
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
|
||||||
|
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
|
||||||
|
( n_ctx)*ggml_element_size(kv_self.v),
|
||||||
|
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
|
||||||
|
|
||||||
|
// important: storing RoPE-ed version of K in the KV cache!
|
||||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
|
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
|
||||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
|
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
|
|
||||||
struct ggml_tensor * Q =
|
struct ggml_tensor * Q =
|
||||||
ggml_permute(ctx0,
|
ggml_permute(ctx0,
|
||||||
ggml_rope(ctx0,
|
|
||||||
ggml_cpy(ctx0,
|
|
||||||
Qcur,
|
Qcur,
|
||||||
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
|
|
||||||
n_past, n_rot, 0),
|
|
||||||
0, 2, 1, 3);
|
0, 2, 1, 3);
|
||||||
|
|
||||||
// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
|
|
||||||
struct ggml_tensor * K =
|
struct ggml_tensor * K =
|
||||||
ggml_permute(ctx0,
|
ggml_permute(ctx0,
|
||||||
ggml_rope(ctx0,
|
|
||||||
ggml_reshape_3d(ctx0,
|
ggml_reshape_3d(ctx0,
|
||||||
ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
|
ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
|
||||||
n_embd/n_head, n_head, n_past + N),
|
n_embd/n_head, n_head, n_past + N),
|
||||||
n_past, n_rot, 1),
|
|
||||||
0, 2, 1, 3);
|
0, 2, 1, 3);
|
||||||
|
|
||||||
// K * Q
|
// K * Q
|
||||||
@ -858,18 +856,23 @@ static bool llama_eval_internal(
|
|||||||
// KQ = soft_max(KQ_masked)
|
// KQ = soft_max(KQ_masked)
|
||||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
|
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
|
||||||
|
|
||||||
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
|
// split cached V into n_head heads
|
||||||
struct ggml_tensor * V_trans =
|
struct ggml_tensor * V =
|
||||||
ggml_cpy(ctx0,
|
ggml_view_3d(ctx0, kv_self.v,
|
||||||
ggml_permute(ctx0,
|
n_past + N, n_embd/n_head, n_head,
|
||||||
ggml_reshape_3d(ctx0,
|
n_ctx*ggml_element_size(kv_self.v),
|
||||||
ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.v)*n_embd),
|
n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head,
|
||||||
n_embd/n_head, n_head, n_past + N),
|
il*n_ctx*ggml_element_size(kv_self.v)*n_embd);
|
||||||
1, 2, 0, 3),
|
|
||||||
ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
|
|
||||||
|
|
||||||
// KQV = transpose(V) * KQ_soft_max
|
#if 1
|
||||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
|
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
||||||
|
#else
|
||||||
|
// make V contiguous in memory to speed up the matmul, however we waste time on the copy
|
||||||
|
// on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
|
||||||
|
// is there a better way?
|
||||||
|
struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
|
||||||
|
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
|
||||||
|
#endif
|
||||||
|
|
||||||
// KQV_merged = KQV.permute(0, 2, 1, 3)
|
// KQV_merged = KQV.permute(0, 2, 1, 3)
|
||||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||||
@ -955,9 +958,13 @@ static bool llama_eval_internal(
|
|||||||
ggml_build_forward_expand(&gf, inpL);
|
ggml_build_forward_expand(&gf, inpL);
|
||||||
ggml_graph_compute (ctx0, &gf);
|
ggml_graph_compute (ctx0, &gf);
|
||||||
|
|
||||||
//if (n_past%100 == 0) {
|
// print timing information per ggml operation (for debugging purposes)
|
||||||
|
// requires GGML_PERF to be defined
|
||||||
//ggml_graph_print(&gf);
|
//ggml_graph_print(&gf);
|
||||||
// ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
|
|
||||||
|
// plot the computation graph in dot format (for debugging purposes)
|
||||||
|
//if (n_past%100 == 0) {
|
||||||
|
// ggml_graph_dump_dot(&gf, NULL, "llama.dot");
|
||||||
//}
|
//}
|
||||||
|
|
||||||
//embd_w.resize(n_vocab*N);
|
//embd_w.resize(n_vocab*N);
|
||||||
|
Loading…
Reference in New Issue
Block a user