iq3_s_mult: also CUDA

This commit is contained in:
Iwan Kawrakow 2024-03-03 19:12:05 +02:00
parent e5e72562c5
commit f2c2bd6b26
2 changed files with 18 additions and 1 deletions

View File

@ -544,7 +544,7 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong
#define QR3_XS 8
#define QI3_XS (QK_K / (4*QR3_XS))
#define IQ3S_BLOCK_SIZE 16
#define IQ3S_BLOCK_SIZE 32
typedef struct {
half d;
uint8_t qs[QK_K/4];
@ -5237,7 +5237,11 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
const int ib32 = iqs;
const uint8_t * qs = bq2->qs + 8*ib32;
const int8_t * q8 = bq8_1[ib32].qs;
#if IQ3S_BLOCK_SIZE == 32
int sumi = 0;
#else
int sumi[2] = {0, 0};
#endif
for (int l = 0; l < 4; ++l) {
#ifdef IQ3S_SLOW_MULT
aux32[0] = ((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
@ -5252,12 +5256,23 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
const int grid_l = __vsub4(aux32[0] ^ signs0, signs0);
const int grid_h = __vsub4(aux32[1] ^ signs1, signs1);
#if IQ3S_BLOCK_SIZE == 32
sumi = __dp4a(grid_l, *((int *)q8+0), sumi);
sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
#else
sumi[l/2] = __dp4a(grid_l, *((int *)q8+0), sumi[l/2]);
sumi[l/2] = __dp4a(grid_h, *((int *)q8+1), sumi[l/2]);
#endif
q8 += 8;
}
#if IQ3S_BLOCK_SIZE == 32
const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds);
return d * sumi;
#else
int ls1 = 1 + 2*(bq2->scales[ib32] & 0xf);
int ls2 = 1 + 2*(bq2->scales[ib32] >> 4);
return (float)bq2->d * __low2float(bq8_1[ib32].ds) * (ls1 * sumi[0] + ls2 * sumi[1]);
#endif
#else
assert(false);
return 0.f;

View File

@ -10037,6 +10037,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
UNUSED(by);
UNUSED(bs);
GGML_ASSERT(IQ3S_BLOCK_SIZE == 32 && "IQ3S_BLOCK_SIZE != 32 is not implemented");
const block_iq3_s * restrict x = vx;
const block_q8_K * restrict y = vy;