mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-02-05 16:10:42 +01:00)
llama : avoid ggml_cast, use F32 query
commit f9ca5dcbe8 (parent 40ea8cd1ac)
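In practice the commit changes the llama.cpp call site as follows (fragment excerpted from the llama.cpp hunk further down): the F32 query tensor is passed straight to ggml_flash_attn_ext, and the accumulation precision is selected on the resulting node instead of casting the query beforehand.

    // before: the query had to be cast to F16 first
    // cur = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale);

    // after: pass the F32 query directly and set the precision on the result node
    cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);
    ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT);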
ggml-metal.m

@@ -2177,7 +2177,7 @@ static bool ggml_metal_graph_compute(
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 GGML_ASSERT(ne00 % 4 == 0);
-                GGML_ASSERT(src0->type == GGML_TYPE_F16);
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);

                 struct ggml_tensor * src2 = gf->nodes[i]->src[2];
                 struct ggml_tensor * src3 = gf->nodes[i]->src[3];
@@ -2254,7 +2254,7 @@ static bool ggml_metal_graph_compute(
                 [encoder setBytes:&scale length:sizeof( float) atIndex:27];

                 // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
-                const int64_t nsg = ne01 < 4 ? 4 : 2; // simdgroups per threadgroup (a.k.a. warps)
+                const int64_t nsg = ne01 < 4 ? 12 : 2; // simdgroups per threadgroup (a.k.a. warps)

                 const int64_t nqptg = 8;  // queries per threadgroup !! sync with kernel template arguments !!
                 const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values)
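For context on the nsg change above: a simdgroup on Apple GPUs is 32 threads, so raising nsg from 4 to 12 for small batches (ne01 < 4) triples the number of threads cooperating per threadgroup. A rough sketch of the resulting dispatch arithmetic, with hypothetical batch sizes (the actual threadgroup shape is set elsewhere in ggml-metal.m):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t simd_width = 32;                 // threads per simdgroup on Apple GPUs
        const int64_t ne01_small = 2, ne01_large = 64; // hypothetical numbers of query rows

        const int64_t nsg_small = ne01_small < 4 ? 12 : 2; // new value for small batches
        const int64_t nsg_large = ne01_large < 4 ? 12 : 2;

        printf("small batch: %lld simdgroups = %lld threads/threadgroup\n",
               (long long) nsg_small, (long long) (nsg_small*simd_width)); // 12 -> 384
        printf("large batch: %lld simdgroups = %lld threads/threadgroup\n",
               (long long) nsg_large, (long long) (nsg_large*simd_width)); // 2 -> 64
        return 0;
    }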
ggml-metal.metal

@@ -2054,8 +2054,9 @@ kernel void kernel_flash_attn_ext_f16(
     for (int64_t i = 0; i < L4; ++i) {
         // load heads from Q to shared memory
         for (int64_t j = sgitg; j < Q; j += nsg) {
+            device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
             if (iq1 + j < ne01) {
-                pq4[j*T4 + N4*i + tiisg] = ((device const half4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + tiisg];
+                pq4[j*T4 + N4*i + tiisg] = (half4) q4[N4*i + tiisg];
             } else {
                 pq4[j*T4 + N4*i + tiisg] = 0.0h;
             }
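The kernel keeps its shared-memory query tile in half precision; with an F32 query it now reads float4 values and narrows them on the fly with (half4), and rows past ne01 are zero-padded so the nqptg-row tile is always full (the else branch). A plain-C sketch of that tile-fill pattern, with hypothetical Q_TILE/HEAD_DIM constants and plain float storage standing in for the half-precision shared memory:

    #include <stdint.h>
    #include <string.h>

    #define Q_TILE   8   // queries per threadgroup (nqptg in the Metal code above)
    #define HEAD_DIM 64  // hypothetical head size

    // copy up to n_rows query rows into a fixed-size tile, zero-padding the rest
    static void fill_q_tile(float dst[Q_TILE][HEAD_DIM], const float * q, int64_t n_rows) {
        for (int64_t j = 0; j < Q_TILE; ++j) {
            if (j < n_rows) {
                memcpy(dst[j], q + j*HEAD_DIM, HEAD_DIM*sizeof(float));
            } else {
                memset(dst[j], 0, HEAD_DIM*sizeof(float)); // same role as pq4[...] = 0.0h above
            }
        }
    }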
ggml.c (31 changed lines)
@@ -4178,6 +4178,8 @@ struct ggml_tensor * ggml_mul_mat(
 void ggml_mul_mat_set_prec(
         struct ggml_tensor * a,
         enum ggml_prec prec) {
+    GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
+
     const int32_t prec_i32 = (int32_t) prec;

     ggml_set_op_params_i32(a, 0, prec_i32);
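The new assert guards against calling the precision setter on the wrong kind of node. A minimal usage sketch, assuming the kq name and wrapper function are illustrative (llama.cpp forces F32 accumulation for the K*Q matmul in a similar way):

    #include "ggml.h"

    // sketch: build a K*Q matmul and request F32 accumulation for it
    static struct ggml_tensor * build_kq_f32(struct ggml_context * ctx,
                                             struct ggml_tensor * k,
                                             struct ggml_tensor * q) {
        struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
        ggml_mul_mat_set_prec(kq, GGML_PREC_F32); // the new assert fires if kq is not a GGML_OP_MUL_MAT node
        return kq;
    }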
@@ -5781,6 +5783,16 @@ struct ggml_tensor * ggml_flash_attn_ext(
     return result;
 }

+void ggml_flash_attn_ext_set_prec(
+        struct ggml_tensor * a,
+        enum ggml_prec prec) {
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int32_t prec_i32 = (int32_t) prec;
+
+    ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
+}
+
 // ggml_flash_ff

 struct ggml_tensor * ggml_flash_ff(
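The precision lands in op_params slot 1 because slot 0 already holds the attention scale (the "scale is on first pos" comment). Conceptually the setter is an int32 write into the node's op_params array, which the compute code later reads back (see the switch on dst->op_params[1] further down). A rough sketch of that round trip, with the tensor reduced to an illustrative stand-in struct rather than the real ggml_tensor:

    #include <stdint.h>
    #include <string.h>
    #include <assert.h>

    enum prec { PREC_DEFAULT = 0, PREC_F32 = 1 }; // mirrors enum ggml_prec

    struct node { int32_t op_params[16]; };       // stand-in for the op_params array of ggml_tensor

    int main(void) {
        struct node n = {0};
        const float scale = 0.125f;

        memcpy(&n.op_params[0], &scale, sizeof(scale)); // slot 0: attention scale, set by ggml_flash_attn_ext
        n.op_params[1] = (int32_t) PREC_DEFAULT;        // slot 1: precision, set by the new helper

        assert(n.op_params[1] == PREC_DEFAULT);         // what ggml_compute_forward_flash_attn_ext switches on
        return 0;
    }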
@@ -13347,7 +13359,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(ne2 == N);
     GGML_ASSERT(P >= 0);

-    GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nbq0 == sizeof(float));
     GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));

@@ -13408,6 +13420,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         float M = -INFINITY;

         float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
+        ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
         ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);

         memset(V16, 0, D*sizeof(ggml_fp16_t));
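Each thread's scratch area is 2*D floats plus cache-line padding. The new Q16 pointer reuses the first D floats of that area as D fp16 values, which fit in half the bytes, and V16 starts at V32 + D, so Q16 and V16 never overlap. A small byte-layout sketch under those assumptions (sizes only, no ggml types):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t D = 128;                // hypothetical head size
        const size_t f32 = sizeof(float);     // 4 bytes
        const size_t f16 = sizeof(uint16_t);  // ggml_fp16_t is 2 bytes

        // per-thread scratch: 2*D floats (ignoring the cache-line padding term)
        printf("scratch bytes      : %zu\n", 2*D*f32);              // 1024
        printf("Q16 needs          : %zu (aliases V32)\n", D*f16);  // 256, fits in the first D floats
        printf("V16 offset (V32+D) : %zu bytes\n", D*f32);          // 512
        printf("V16 needs          : %zu\n", D*f16);                // 256
        return 0;
    }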
@@ -13433,10 +13446,19 @@ static void ggml_compute_forward_flash_attn_ext_f16(

                 float s;

+                // convert Q to F16 in V32
+                {
+                    const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+
+                    for (int64_t d = 0; d < D; ++d) {
+                        Q16[d] = GGML_FP32_TO_FP16(pq[d]);
+                    }
+                }
+
                 ggml_vec_dot_f16(D,
                         &s,
                         (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
-                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+                        Q16);

                 s = s*scale + mv;

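On the CPU path, K is stored in F16 and ggml_vec_dot_f16 expects both operands in F16, so the F32 query row is first narrowed into the Q16 scratch buffer and that copy is passed to the dot product instead of reinterpreting q->data. A standalone sketch of the same convert-then-dot pattern, using _Float16 as a stand-in for ggml_fp16_t (assumes a compiler with _Float16 support; ggml itself uses GGML_FP32_TO_FP16 and ggml_vec_dot_f16):

    #include <stdint.h>
    #include <stdio.h>

    // F16 x F16 dot product with F32 accumulation
    static float dot_f16(int64_t n, const _Float16 * a, const _Float16 * b) {
        float s = 0.0f;
        for (int64_t i = 0; i < n; ++i) {
            s += (float) a[i] * (float) b[i];
        }
        return s;
    }

    int main(void) {
        enum { D = 4 };
        const float    q32[D] = { 0.5f, -1.0f, 2.0f, 0.25f };  // F32 query row
        const _Float16 k16[D] = { 1.0f, 1.0f, 1.0f, 1.0f };    // F16 key row
        _Float16 q16[D];

        for (int64_t d = 0; d < D; ++d) {
            q16[d] = (_Float16) q32[d]; // same role as Q16[d] = GGML_FP32_TO_FP16(pq[d])
        }

        printf("s = %f\n", dot_f16(D, k16, q16)); // 1.75
        return 0;
    }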
@@ -13488,13 +13510,14 @@ static void ggml_compute_forward_flash_attn_ext(
         const struct ggml_tensor * v,
         const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
-    switch (q->type) {
+    switch (dst->op_params[1]) {
-        case GGML_TYPE_F16:
+        case GGML_PREC_DEFAULT:
             {
                 ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
             } break;
         default:
             {
+                // TODO: implement F32 precision
                 GGML_ASSERT(false);
             } break;
     }
ggml.h (4 changed lines)
@@ -1633,6 +1633,10 @@ extern "C" {
             struct ggml_tensor * mask,
             float scale);

+    GGML_API void ggml_flash_attn_ext_set_prec(
+            struct ggml_tensor * a,
+            enum ggml_prec prec);
+
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
             struct ggml_context * ctx,
             struct ggml_tensor * q,
llama.cpp

@@ -4368,7 +4368,8 @@ static struct ggml_tensor * llm_build_kqv(
                 0);
         cb(v, "v", il);

-        cur = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale);
+        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);
+        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT);

         //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
         //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
         //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
tests/test-backend-ops.cpp

@@ -1386,26 +1386,24 @@ struct test_leaky_relu : public test_case {

 // GGML_OP_FLASH_ATTN_EXT
 struct test_flash_attn_ext : public test_case {
-    const ggml_type typeq;
     const int64_t hs; // head size
     const int64_t nh; // num heads
     const int64_t kv; // kv size
     const int64_t nb; // batch size

     std::string vars() override {
-        return VARS_TO_STR5(typeq, hs, nh, kv, nb);
+        return VARS_TO_STR4(hs, nh, kv, nb);
     }

     double max_nmse_err() override {
         return 5e-5;
     }

-    test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16,
-            int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
-        : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {}
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
+        : hs(hs), nh(nh), kv(kv), nb(nb) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nb, nh, 1);
+        ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
         ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
         ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
         ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1);
@@ -1680,9 +1678,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_pad());
     test_cases.emplace_back(new test_leaky_relu());

-    test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 8));
-    test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 7));
-    test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 1));
+    test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 8));
+    test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 7));
+    test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 1));

 #if !defined(__SANITIZE_THREAD__)
     // FIXME: these tests use too much memory with thread sanitizer