llama : avoid ggml_cast, use F32 query

Author:  Georgi Gerganov
Date:    2024-01-25 17:46:07 +02:00
Commit:  f9ca5dcbe8 (parent 40ea8cd1ac)
GPG:     No known key found for this signature in database (Key ID: 449E073F9DC10735)

6 changed files with 44 additions and 17 deletions

ggml-metal.m

@@ -2177,7 +2177,7 @@ static bool ggml_metal_graph_compute(
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 GGML_ASSERT(ne00 % 4 == 0);
-                GGML_ASSERT(src0->type == GGML_TYPE_F16);
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);

                 struct ggml_tensor * src2 = gf->nodes[i]->src[2];
                 struct ggml_tensor * src3 = gf->nodes[i]->src[3];
@@ -2254,7 +2254,7 @@ static bool ggml_metal_graph_compute(
                 [encoder setBytes:&scale length:sizeof( float) atIndex:27];

                 // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
-                const int64_t nsg = ne01 < 4 ? 4 : 2;  // simdgroups per threadgroup (a.k.a. warps)
+                const int64_t nsg = ne01 < 4 ? 12 : 2; // simdgroups per threadgroup (a.k.a. warps)

                 const int64_t nqptg = 8;  // queries per threadgroup !! sync with kernel template arguments !!
                 const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values)

ggml-metal.metal

@@ -2054,8 +2054,9 @@ kernel void kernel_flash_attn_ext_f16(
     for (int64_t i = 0; i < L4; ++i) {
         // load heads from Q to shared memory
         for (int64_t j = sgitg; j < Q; j += nsg) {
+            device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
             if (iq1 + j < ne01) {
-                pq4[j*T4 + N4*i + tiisg] = ((device const half4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + tiisg];
+                pq4[j*T4 + N4*i + tiisg] = (half4) q4[N4*i + tiisg];
             } else {
                 pq4[j*T4 + N4*i + tiisg] = 0.0h;
             }
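
For readers who do not work in Metal: the change above reads the query row as F32 from global memory and narrows it to F16 only when staging it into the shared tile, zero-padding rows past ne01. A minimal plain-C sketch of that idea, using the public ggml conversion helper; the names stage_q_row, tile, T and q_row are illustrative, not part of the kernel:

    #include <stdint.h>
    #include "ggml.h" // ggml_fp16_t, ggml_fp32_to_fp16()

    // Illustrative helper (not in the codebase): stage one F32 query row into an
    // F16 tile with row stride T, writing zeros for rows past the end of the batch.
    static void stage_q_row(ggml_fp16_t * tile, int64_t T, int64_t j,
                            const float * q_row, int64_t D,
                            int64_t iq1, int64_t ne01) {
        for (int64_t d = 0; d < D; ++d) {
            tile[j*T + d] = (iq1 + j < ne01) ? ggml_fp32_to_fp16(q_row[d])
                                             : ggml_fp32_to_fp16(0.0f);
        }
    }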

ggml.c

@@ -4178,6 +4178,8 @@ struct ggml_tensor * ggml_mul_mat(
 void ggml_mul_mat_set_prec(
         struct ggml_tensor * a,
         enum ggml_prec       prec) {
+    GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
+
     const int32_t prec_i32 = (int32_t) prec;

     ggml_set_op_params_i32(a, 0, prec_i32);
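
A hedged usage sketch of the setter touched above (ctx, k and q are assumed to already exist; mul_mat_f32_prec is an illustrative name; the new assert simply rejects tensors whose op is not GGML_OP_MUL_MAT):

    #include "ggml.h"

    // Sketch: ggml_mul_mat_set_prec must be called on the mul_mat result itself.
    static struct ggml_tensor * mul_mat_f32_prec(struct ggml_context * ctx,
                                                 struct ggml_tensor * k, struct ggml_tensor * q) {
        struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); // kq->op == GGML_OP_MUL_MAT
        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);          // passes the new assert
        //ggml_mul_mat_set_prec(q, GGML_PREC_F32);         // would trip GGML_ASSERT(a->op == GGML_OP_MUL_MAT)
        return kq;
    }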
@@ -5781,6 +5783,16 @@ struct ggml_tensor * ggml_flash_attn_ext(
     return result;
 }

+void ggml_flash_attn_ext_set_prec(
+        struct ggml_tensor * a,
+        enum ggml_prec       prec) {
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int32_t prec_i32 = (int32_t) prec;
+
+    ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
+}
+
 // ggml_flash_ff

 struct ggml_tensor * ggml_flash_ff(
@@ -13347,7 +13359,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(ne2 == N);
     GGML_ASSERT(P >= 0);

-    GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nbq0 == sizeof(float));
     GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
@@ -13408,6 +13420,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         float M = -INFINITY;

         float       * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
+        ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
         ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);

         memset(V16, 0, D*sizeof(ggml_fp16_t));
@@ -13433,10 +13446,19 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 float s;

+                // convert Q to F16 in V32
+                {
+                    const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+
+                    for (int64_t d = 0; d < D; ++d) {
+                        Q16[d] = GGML_FP32_TO_FP16(pq[d]);
+                    }
+                }
+
                 ggml_vec_dot_f16(D,
                         &s,
                         (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
-                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+                        Q16);

                 s = s*scale + mv;
@@ -13488,13 +13510,14 @@ static void ggml_compute_forward_flash_attn_ext(
         const struct ggml_tensor * v,
         const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
-    switch (q->type) {
-        case GGML_TYPE_F16:
+    switch (dst->op_params[1]) {
+        case GGML_PREC_DEFAULT:
             {
                 ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
             } break;
         default:
             {
+                // TODO: implement F32 precision
                 GGML_ASSERT(false);
             } break;
     }
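
A hedged note on the op_params layout implied by the "scale is on first pos" comment and the dispatch above: slot 0 holds the attention scale (stored as float bits), slot 1 the ggml_prec value written by ggml_flash_attn_ext_set_prec. A short sketch of reading both back on the compute side; read_flash_attn_ext_params is an illustrative name, not a function in the codebase:

    #include <string.h>
    #include "ggml.h"

    // Sketch: recover the parameters stored in a GGML_OP_FLASH_ATTN_EXT node.
    static void read_flash_attn_ext_params(const struct ggml_tensor * dst,
                                           float * scale, enum ggml_prec * prec) {
        memcpy(scale, &dst->op_params[0], sizeof(float)); // op_params[0]: scale (float bits)
        *prec = (enum ggml_prec) dst->op_params[1];       // op_params[1]: precision selector
    }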

ggml.h

@@ -1633,6 +1633,10 @@ extern "C" {
             struct ggml_tensor  * mask,
             float                 scale);

+    GGML_API void ggml_flash_attn_ext_set_prec(
+            struct ggml_tensor * a,
+            enum ggml_prec       prec);
+
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
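
Taken together with the llama.cpp and test changes below, a hedged sketch of how the new declaration is meant to be used; ctx and the dimensions (hs, nh, kv, nb, mirroring test_flash_attn_ext) are assumed, and build_attn_sketch is an illustrative name:

    #include <math.h>
    #include "ggml.h"

    // Sketch: the query stays F32 (no ggml_cast), K/V stay F16, and the precision
    // of the fused op is selected afterwards via ggml_flash_attn_ext_set_prec.
    static struct ggml_tensor * build_attn_sketch(struct ggml_context * ctx,
                                                  int64_t hs, int64_t nh, int64_t kv, int64_t nb) {
        struct ggml_tensor * q    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
        struct ggml_tensor * k    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
        struct ggml_tensor * v    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
        struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1);

        struct ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf((float) hs));
        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT); // select the F16 compute path
        return cur;
    }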

llama.cpp

@@ -4368,7 +4368,8 @@ static struct ggml_tensor * llm_build_kqv(
                 0);
         cb(v, "v", il);

-        cur = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale);
+        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);
+        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT);

         //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
         //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
         //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]);

tests/test-backend-ops.cpp

@@ -1386,26 +1386,24 @@ struct test_leaky_relu : public test_case {

 // GGML_OP_FLASH_ATTN_EXT
 struct test_flash_attn_ext : public test_case {
-    const ggml_type typeq;
     const int64_t hs; // head size
     const int64_t nh; // num heads
     const int64_t kv; // kv size
     const int64_t nb; // batch size

     std::string vars() override {
-        return VARS_TO_STR5(typeq, hs, nh, kv, nb);
+        return VARS_TO_STR4(hs, nh, kv, nb);
     }

     double max_nmse_err() override {
         return 5e-5;
     }

-    test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16,
-            int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
-        : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {}
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
+        : hs(hs), nh(nh), kv(kv), nb(nb) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nb, nh, 1);
+        ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
         ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
         ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
         ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1);
@@ -1680,9 +1678,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_pad());
     test_cases.emplace_back(new test_leaky_relu());

-    test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 8));
-    test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 7));
-    test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 1));
+    test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 8));
+    test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 7));
+    test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 1));

 #if !defined(__SANITIZE_THREAD__)
     // FIXME: these tests use too much memory with thread sanitizer