diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 37fdd10cb..ff721ea43 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -544,7 +544,7 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong #define QR3_XS 8 #define QI3_XS (QK_K / (4*QR3_XS)) -#define IQ3S_BLOCK_SIZE 16 +#define IQ3S_BLOCK_SIZE 32 typedef struct { half d; uint8_t qs[QK_K/4]; @@ -5237,7 +5237,11 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( const int ib32 = iqs; const uint8_t * qs = bq2->qs + 8*ib32; const int8_t * q8 = bq8_1[ib32].qs; +#if IQ3S_BLOCK_SIZE == 32 int sumi = 0; +#else + int sumi[2] = {0, 0}; +#endif for (int l = 0; l < 4; ++l) { #ifdef IQ3S_SLOW_MULT aux32[0] = ((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; @@ -5252,12 +5256,23 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); const int grid_l = __vsub4(aux32[0] ^ signs0, signs0); const int grid_h = __vsub4(aux32[1] ^ signs1, signs1); +#if IQ3S_BLOCK_SIZE == 32 sumi = __dp4a(grid_l, *((int *)q8+0), sumi); sumi = __dp4a(grid_h, *((int *)q8+1), sumi); +#else + sumi[l/2] = __dp4a(grid_l, *((int *)q8+0), sumi[l/2]); + sumi[l/2] = __dp4a(grid_h, *((int *)q8+1), sumi[l/2]); +#endif q8 += 8; } +#if IQ3S_BLOCK_SIZE == 32 const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds); return d * sumi; +#else + int ls1 = 1 + 2*(bq2->scales[ib32] & 0xf); + int ls2 = 1 + 2*(bq2->scales[ib32] >> 4); + return (float)bq2->d * __low2float(bq8_1[ib32].ds) * (ls1 * sumi[0] + ls2 * sumi[1]); +#endif #else assert(false); return 0.f; diff --git a/ggml-quants.c b/ggml-quants.c index cfa36b310..8d2cae527 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -10037,6 +10037,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(by); UNUSED(bs); + GGML_ASSERT(IQ3S_BLOCK_SIZE == 32 && "IQ3S_BLOCK_SIZE != 32 is not implemented"); + const block_iq3_s * restrict x = vx; const block_q8_K * restrict y = vy;