mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-01 00:39:00 +01:00
WIP
This commit is contained in:
parent
e43e81a5d7
commit
0fe9cd488f
26
ggml-cuda.cu
26
ggml-cuda.cu
@ -2371,8 +2371,8 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
|
|||||||
}
|
}
|
||||||
|
|
||||||
//#define IQ3S_MULTIPLIER 746226
|
//#define IQ3S_MULTIPLIER 746226
|
||||||
//#define IQ3S_MULTIPLIER 717154
|
//#define IQ3S_MULTIPLIER 677595
|
||||||
#define IQ3S_MULTIPLIER 677595
|
#define IQ3S_MULTIPLIER 190842953LL
|
||||||
|
|
||||||
template<typename dst_t>
|
template<typename dst_t>
|
||||||
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
@ -2393,8 +2393,8 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
|
|||||||
aux32[0] = ((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
|
aux32[0] = ((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
|
||||||
aux32[1] = ((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
|
aux32[1] = ((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
|
||||||
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
||||||
aux32[0] = (((__vadd4(aux32[0], 0x01010101) >> 1) & 0x07070707) << 1) | 0x01010101;
|
aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101;
|
||||||
aux32[1] = (((__vadd4(aux32[1], 0x01010101) >> 1) & 0x07070707) << 1) | 0x01010101;
|
aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101;
|
||||||
uint32_t signs0 = __vcmpeq4(((signs & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
|
uint32_t signs0 = __vcmpeq4(((signs & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
|
||||||
uint32_t signs1 = __vcmpeq4(((signs >> 4) * 0x01010101) & 0x08040201, 0x08040201);
|
uint32_t signs1 = __vcmpeq4(((signs >> 4) * 0x01010101) & 0x08040201, 0x08040201);
|
||||||
aux32[0] = __vsub4(aux32[0] ^ signs0, signs0);
|
aux32[0] = __vsub4(aux32[0] ^ signs0, signs0);
|
||||||
@ -2404,9 +2404,7 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
|
|||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
for (int j = 0; j < 8; ++j) {
|
for (int j = 0; j < 8; ++j) {
|
||||||
//y[j] = d * (2*((grid[j]-1)/2) + 1) * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
y[j] = d * (2*((grid[j]-1)/2) + 1) * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
||||||
//y[j] = d * iq3s_values[grid[j]] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
|
||||||
y[j] = d * (2*(((grid[j]+1)/2) & 7) + 1) * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
#else
|
#else
|
||||||
@ -5216,7 +5214,6 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
|
|||||||
const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
|
const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
|
||||||
|
|
||||||
uint32_t aux32[2];
|
uint32_t aux32[2];
|
||||||
const uint8_t * grid = (const uint8_t *)aux32;
|
|
||||||
|
|
||||||
const int ib32 = iqs;
|
const int ib32 = iqs;
|
||||||
const uint8_t * qs = bq2->qs + 8*ib32;
|
const uint8_t * qs = bq2->qs + 8*ib32;
|
||||||
@ -5225,25 +5222,16 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
|
|||||||
for (int l = 0; l < 4; ++l) {
|
for (int l = 0; l < 4; ++l) {
|
||||||
aux32[0] = ((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
|
aux32[0] = ((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
|
||||||
aux32[1] = ((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
|
aux32[1] = ((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
|
||||||
aux32[0] = (((__vadd4(aux32[0], 0x01010101) >> 1) & 0x07070707) << 1) | 0x01010101;
|
aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101;
|
||||||
aux32[1] = (((__vadd4(aux32[1], 0x01010101) >> 1) & 0x07070707) << 1) | 0x01010101;
|
aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101;
|
||||||
uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
|
uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
|
||||||
uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
|
uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
|
||||||
const int grid_l = __vsub4(aux32[0] ^ signs0, signs0);
|
const int grid_l = __vsub4(aux32[0] ^ signs0, signs0);
|
||||||
const int grid_h = __vsub4(aux32[1] ^ signs1, signs1);
|
const int grid_h = __vsub4(aux32[1] ^ signs1, signs1);
|
||||||
sumi = __dp4a(grid_l, *((int *)q8+0), sumi);
|
sumi = __dp4a(grid_l, *((int *)q8+0), sumi);
|
||||||
sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
|
sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
|
||||||
//const uint32_t * grid1 = iq3xs_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
|
|
||||||
//const uint32_t * grid2 = iq3xs_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
|
|
||||||
//uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
|
|
||||||
//uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
|
|
||||||
//const int grid_l = __vsub4(grid1[0] ^ signs0, signs0);
|
|
||||||
//const int grid_h = __vsub4(grid2[0] ^ signs1, signs1);
|
|
||||||
//sumi = __dp4a(grid_l, *((int *)q8+0), sumi);
|
|
||||||
//sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
|
|
||||||
q8 += 8;
|
q8 += 8;
|
||||||
}
|
}
|
||||||
//const float d = (float)bq2->d * (0.5f + ((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds) * 0.5f;
|
|
||||||
const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds);
|
const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds);
|
||||||
return d * sumi;
|
return d * sumi;
|
||||||
#else
|
#else
|
||||||
|
@ -4124,7 +4124,8 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y
|
|||||||
//#define IQ3S_MULTIPLIER 2469109
|
//#define IQ3S_MULTIPLIER 2469109
|
||||||
//#define IQ3S_MULTIPLIER 746226
|
//#define IQ3S_MULTIPLIER 746226
|
||||||
//#define IQ3S_MULTIPLIER 717154
|
//#define IQ3S_MULTIPLIER 717154
|
||||||
#define IQ3S_MULTIPLIER 677595
|
//#define IQ3S_MULTIPLIER 677595
|
||||||
|
#define IQ3S_MULTIPLIER 190842953
|
||||||
#define IQ3S_BITS 3
|
#define IQ3S_BITS 3
|
||||||
|
|
||||||
void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int k) {
|
void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int k) {
|
||||||
@ -4148,8 +4149,7 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in
|
|||||||
aux32[0] = ((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
|
aux32[0] = ((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
|
||||||
aux32[1] = ((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
|
aux32[1] = ((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
|
||||||
for (int j = 0; j < 8; ++j) {
|
for (int j = 0; j < 8; ++j) {
|
||||||
//y[j] = db1 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
|
y[j] = db1 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
|
||||||
y[j] = db1 * (2*(((grid[j]+1)/2) & 7) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
|
|
||||||
}
|
}
|
||||||
y += 8;
|
y += 8;
|
||||||
}
|
}
|
||||||
@ -4159,8 +4159,7 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in
|
|||||||
aux32[0] = ((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
|
aux32[0] = ((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
|
||||||
aux32[1] = ((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
|
aux32[1] = ((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
|
||||||
for (int j = 0; j < 8; ++j) {
|
for (int j = 0; j < 8; ++j) {
|
||||||
//y[j] = db2 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
|
y[j] = db2 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
|
||||||
y[j] = db2 * (2*(((grid[j]+1)/2) & 7) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
|
|
||||||
}
|
}
|
||||||
y += 8;
|
y += 8;
|
||||||
}
|
}
|
||||||
@ -10121,9 +10120,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||||||
const __m256i idx_mask = _mm256_set1_epi32(256);
|
const __m256i idx_mask = _mm256_set1_epi32(256);
|
||||||
const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
|
const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
|
||||||
const __m256i idx_mult = _mm256_set1_epi32(IQ3S_MULTIPLIER);
|
const __m256i idx_mult = _mm256_set1_epi32(IQ3S_MULTIPLIER);
|
||||||
const __m256i m1 = _mm256_set1_epi32(0x01010101);
|
const __m256i m1 = _mm256_set1_epi8(1);
|
||||||
const __m256i m7 = _mm256_set1_epi32(0x07070707);
|
const __m256i m7 = _mm256_set1_epi32(0x07070707);
|
||||||
const __m256i m15 = _mm256_set1_epi32(0x0f0f0f0f);
|
const __m256i m15 = _mm256_set1_epi32(0x0f0f0f0f);
|
||||||
|
const __m256i m0 = _mm256_setzero_si256();
|
||||||
|
|
||||||
__m256 accumf = _mm256_setzero_ps();
|
__m256 accumf = _mm256_setzero_ps();
|
||||||
for (int i = 0; i < nb; ++i) {
|
for (int i = 0; i < nb; ++i) {
|
||||||
@ -10143,9 +10143,13 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||||||
const __m256i idx_h_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[ib32+1]), idx_shift), idx_mask);
|
const __m256i idx_h_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[ib32+1]), idx_shift), idx_mask);
|
||||||
const __m256i idx_32_l = _mm256_or_si256(idx_h_l, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l_16)));
|
const __m256i idx_32_l = _mm256_or_si256(idx_h_l, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l_16)));
|
||||||
const __m256i idx_32_h = _mm256_or_si256(idx_h_h, _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l_16, 1)));
|
const __m256i idx_32_h = _mm256_or_si256(idx_h_h, _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l_16, 1)));
|
||||||
const __m256i idx_l = _mm256_add_epi32(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1);
|
|
||||||
|
// v = MAX(((IQ3S_MULTIPLIER * idx) & 0x0f0f0f0f) - 1, 0)
|
||||||
|
const __m256i idx_l = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1), m0);
|
||||||
|
// v = (((v >> 1) & 0x07070707) << 1) | 0x01010101
|
||||||
const __m256i q2_1 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_l, 1), m7), 1), m1);
|
const __m256i q2_1 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_l, 1), m7), 1), m1);
|
||||||
const __m256i idx_h = _mm256_add_epi32(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1);
|
|
||||||
|
const __m256i idx_h = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1), m0);
|
||||||
const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1);
|
const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1);
|
||||||
|
|
||||||
__m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16));
|
__m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16));
|
||||||
@ -11375,10 +11379,9 @@ static void iq3xs_init_grid512(void) {
|
|||||||
const uint8_t * q4 = (const uint8_t *)&aux32;
|
const uint8_t * q4 = (const uint8_t *)&aux32;
|
||||||
for (int k = 0; k < grid_size; ++k) {
|
for (int k = 0; k < grid_size; ++k) {
|
||||||
int8_t * pos = (int8_t *)(the_grid + k);
|
int8_t * pos = (int8_t *)(the_grid + k);
|
||||||
aux32 = (IQ3S_MULTIPLIER * k) & 0x0f0f0f0f;
|
aux32 = ((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f;
|
||||||
for (int i = 0; i < 4; ++i) {
|
for (int i = 0; i < 4; ++i) {
|
||||||
//pos[i] = 2*((q4[i]-1)/2) + 1;
|
pos[i] = 2*((q4[i]-1)/2) + 1;
|
||||||
pos[i] = 2*(((q4[i]+1)/2) & kmask) + 1;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user