diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index a14d0d1db..c76d00a39 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -14746,7 +14746,7 @@ static void ggml_compute_forward_pool_1d_sk_p0( const struct ggml_tensor * src = dst->src[0]; - assert(src->type == GGML_TYPE_F32); + assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16); if (params->ith != 0) { return; @@ -14759,10 +14759,8 @@ static void ggml_compute_forward_pool_1d_sk_p0( const int64_t rs = dst->ne[0]; while (cdata < data_end) { - const float * const srow = (const float *)cdata; - + const void * srow = (const void *)cdata; int j = 0; - for (int64_t i = 0; i < rs; ++i) { switch (op) { case GGML_OP_POOL_AVG: drow[i] = 0; break; @@ -14770,10 +14768,11 @@ static void ggml_compute_forward_pool_1d_sk_p0( case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } for (int ki = 0; ki < k; ++ki) { + const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]); switch (op) { - case GGML_OP_POOL_AVG: drow[i] += srow[j]; break; - case GGML_OP_POOL_MAX: if (srow[j] > drow[i]) drow[i] = srow[j]; break; - case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); + case GGML_OP_POOL_AVG: drow[i] += srow_j; break; + case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break; + case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } ++j; } @@ -14814,7 +14813,7 @@ static void ggml_compute_forward_pool_2d( const struct ggml_tensor * src = dst->src[0]; - GGML_ASSERT(src->type == GGML_TYPE_F32); + assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16); if (params->ith != 0) { return; @@ -14857,14 +14856,15 @@ static void ggml_compute_forward_pool_2d( for (int ky = 0; ky < k1; ++ky) { if (iy + ky < 0 || iy + ky >= src->ne[1]) continue; - const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky)); + const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky)); for (int kx = 0; kx < k0; ++kx) { int j = ix + kx; if (j < 0 || j >= src->ne[0]) continue; + const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]); switch (op) { - case GGML_OP_POOL_AVG: *out += srow[j]; break; - case GGML_OP_POOL_MAX: if (srow[j] > *out) *out = srow[j]; break; - case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); + case GGML_OP_POOL_AVG: *out += srow_j; break; + case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break; + case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } } }