mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 05:48:47 +01:00
Fix more int overflow during quant (PPL/CUDA). (#6563)
* Fix more int overflow during quant. * Fix some more int overflow in softmax. * Revert back to int64_t.
This commit is contained in:
parent
7bb36ccf91
commit
e00b4a8f81
@ -5,16 +5,16 @@
|
|||||||
|
|
||||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||||
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
|
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
|
||||||
const int64_t i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
|
const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x);
|
||||||
|
|
||||||
if (i >= k) {
|
if (i >= k) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64_t ib = i/qk; // block index
|
const int64_t ib = i/qk; // block index
|
||||||
const int iqs = (i%qk)/qr; // quant index
|
const int64_t iqs = (i%qk)/qr; // quant index
|
||||||
const int iybs = i - i%qk; // y block start index
|
const int64_t iybs = i - i%qk; // y block start index
|
||||||
const int y_offset = qr == 1 ? 1 : qk/2;
|
const int64_t y_offset = qr == 1 ? 1 : qk/2;
|
||||||
|
|
||||||
// dequantize
|
// dequantize
|
||||||
dfloat2 v;
|
dfloat2 v;
|
||||||
@ -29,7 +29,7 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
|
|||||||
#if __CUDA_ARCH__ >= CC_PASCAL
|
#if __CUDA_ARCH__ >= CC_PASCAL
|
||||||
constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
|
constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
|
||||||
|
|
||||||
const int i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
|
const int64_t i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
|
||||||
const int * x0 = ((int *) vx) + blockIdx.x * nint;
|
const int * x0 = ((int *) vx) + blockIdx.x * nint;
|
||||||
half2 * y2 = (half2 *) (y + i0);
|
half2 * y2 = (half2 *) (y + i0);
|
||||||
|
|
||||||
@ -73,9 +73,9 @@ static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t
|
|||||||
const int64_t i = blockIdx.x;
|
const int64_t i = blockIdx.x;
|
||||||
|
|
||||||
// assume 32 threads
|
// assume 32 threads
|
||||||
const int tid = threadIdx.x;
|
const int64_t tid = threadIdx.x;
|
||||||
const int il = tid/8;
|
const int64_t il = tid/8;
|
||||||
const int ir = tid%8;
|
const int64_t ir = tid%8;
|
||||||
const int64_t ib = 8*i + ir;
|
const int64_t ib = 8*i + ir;
|
||||||
if (ib >= nb32) {
|
if (ib >= nb32) {
|
||||||
return;
|
return;
|
||||||
@ -101,9 +101,9 @@ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t
|
|||||||
const int64_t i = blockIdx.x;
|
const int64_t i = blockIdx.x;
|
||||||
|
|
||||||
// assume 32 threads
|
// assume 32 threads
|
||||||
const int tid = threadIdx.x;
|
const int64_t tid = threadIdx.x;
|
||||||
const int il = tid/8;
|
const int64_t il = tid/8;
|
||||||
const int ir = tid%8;
|
const int64_t ir = tid%8;
|
||||||
const int64_t ib = 8*i + ir;
|
const int64_t ib = 8*i + ir;
|
||||||
if (ib >= nb32) {
|
if (ib >= nb32) {
|
||||||
return;
|
return;
|
||||||
@ -127,14 +127,14 @@ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t
|
|||||||
template<typename dst_t>
|
template<typename dst_t>
|
||||||
static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
|
|
||||||
const int i = blockIdx.x;
|
const int64_t i = blockIdx.x;
|
||||||
const block_q2_K * x = (const block_q2_K *) vx;
|
const block_q2_K * x = (const block_q2_K *) vx;
|
||||||
|
|
||||||
const int tid = threadIdx.x;
|
const int64_t tid = threadIdx.x;
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const int n = tid/32;
|
const int64_t n = tid/32;
|
||||||
const int l = tid - 32*n;
|
const int64_t l = tid - 32*n;
|
||||||
const int is = 8*n + l/16;
|
const int64_t is = 8*n + l/16;
|
||||||
|
|
||||||
const uint8_t q = x[i].qs[32*n + l];
|
const uint8_t q = x[i].qs[32*n + l];
|
||||||
dst_t * y = yy + i*QK_K + 128*n;
|
dst_t * y = yy + i*QK_K + 128*n;
|
||||||
@ -146,8 +146,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
|
|||||||
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
|
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
|
||||||
y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
|
y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
|
||||||
#else
|
#else
|
||||||
const int is = tid/16; // 0 or 1
|
const int64_t is = tid/16; // 0 or 1
|
||||||
const int il = tid%16; // 0...15
|
const int64_t il = tid%16; // 0...15
|
||||||
const uint8_t q = x[i].qs[il] >> (2*is);
|
const uint8_t q = x[i].qs[il] >> (2*is);
|
||||||
dst_t * y = yy + i*QK_K + 16*is + il;
|
dst_t * y = yy + i*QK_K + 16*is + il;
|
||||||
float dall = __low2half(x[i].dm);
|
float dall = __low2half(x[i].dm);
|
||||||
@ -161,19 +161,19 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
|
|||||||
template<typename dst_t>
|
template<typename dst_t>
|
||||||
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
|
|
||||||
const int i = blockIdx.x;
|
const int64_t i = blockIdx.x;
|
||||||
const block_q3_K * x = (const block_q3_K *) vx;
|
const block_q3_K * x = (const block_q3_K *) vx;
|
||||||
|
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const int r = threadIdx.x/4;
|
const int64_t r = threadIdx.x/4;
|
||||||
const int tid = r/2;
|
const int64_t tid = r/2;
|
||||||
const int is0 = r%2;
|
const int64_t is0 = r%2;
|
||||||
const int l0 = 16*is0 + 4*(threadIdx.x%4);
|
const int64_t l0 = 16*is0 + 4*(threadIdx.x%4);
|
||||||
const int n = tid / 4;
|
const int64_t n = tid / 4;
|
||||||
const int j = tid - 4*n;
|
const int64_t j = tid - 4*n;
|
||||||
|
|
||||||
uint8_t m = 1 << (4*n + j);
|
uint8_t m = 1 << (4*n + j);
|
||||||
int is = 8*n + 2*j + is0;
|
int64_t is = 8*n + 2*j + is0;
|
||||||
int shift = 2*j;
|
int shift = 2*j;
|
||||||
|
|
||||||
int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
|
int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
|
||||||
@ -189,11 +189,11 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t
|
|||||||
|
|
||||||
for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
|
for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
|
||||||
#else
|
#else
|
||||||
const int tid = threadIdx.x;
|
const int64_t tid = threadIdx.x;
|
||||||
const int is = tid/16; // 0 or 1
|
const int64_t is = tid/16; // 0 or 1
|
||||||
const int il = tid%16; // 0...15
|
const int64_t il = tid%16; // 0...15
|
||||||
const int im = il/8; // 0...1
|
const int64_t im = il/8; // 0...1
|
||||||
const int in = il%8; // 0...7
|
const int64_t in = il%8; // 0...7
|
||||||
|
|
||||||
dst_t * y = yy + i*QK_K + 16*is + il;
|
dst_t * y = yy + i*QK_K + 16*is + il;
|
||||||
|
|
||||||
@ -227,15 +227,15 @@ template<typename dst_t>
|
|||||||
static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
const block_q4_K * x = (const block_q4_K *) vx;
|
const block_q4_K * x = (const block_q4_K *) vx;
|
||||||
|
|
||||||
const int i = blockIdx.x;
|
const int64_t i = blockIdx.x;
|
||||||
|
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
// assume 32 threads
|
// assume 32 threads
|
||||||
const int tid = threadIdx.x;
|
const int64_t tid = threadIdx.x;
|
||||||
const int il = tid/8;
|
const int64_t il = tid/8;
|
||||||
const int ir = tid%8;
|
const int64_t ir = tid%8;
|
||||||
const int is = 2*il;
|
const int64_t is = 2*il;
|
||||||
const int n = 4;
|
const int64_t n = 4;
|
||||||
|
|
||||||
dst_t * y = yy + i*QK_K + 64*il + n*ir;
|
dst_t * y = yy + i*QK_K + 64*il + n*ir;
|
||||||
|
|
||||||
@ -254,7 +254,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t
|
|||||||
y[l +32] = d2 * (q[l] >> 4) - m2;
|
y[l +32] = d2 * (q[l] >> 4) - m2;
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
const int tid = threadIdx.x;
|
const int64_t tid = threadIdx.x;
|
||||||
const uint8_t * q = x[i].qs;
|
const uint8_t * q = x[i].qs;
|
||||||
dst_t * y = yy + i*QK_K;
|
dst_t * y = yy + i*QK_K;
|
||||||
const float d = (float)x[i].dm[0];
|
const float d = (float)x[i].dm[0];
|
||||||
@ -268,14 +268,14 @@ template<typename dst_t>
|
|||||||
static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
const block_q5_K * x = (const block_q5_K *) vx;
|
const block_q5_K * x = (const block_q5_K *) vx;
|
||||||
|
|
||||||
const int i = blockIdx.x;
|
const int64_t i = blockIdx.x;
|
||||||
|
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
// assume 64 threads - this is very slightly better than the one below
|
// assume 64 threads - this is very slightly better than the one below
|
||||||
const int tid = threadIdx.x;
|
const int64_t tid = threadIdx.x;
|
||||||
const int il = tid/16; // il is in 0...3
|
const int64_t il = tid/16; // il is in 0...3
|
||||||
const int ir = tid%16; // ir is in 0...15
|
const int64_t ir = tid%16; // ir is in 0...15
|
||||||
const int is = 2*il; // is is in 0...6
|
const int64_t is = 2*il; // is is in 0...6
|
||||||
|
|
||||||
dst_t * y = yy + i*QK_K + 64*il + 2*ir;
|
dst_t * y = yy + i*QK_K + 64*il + 2*ir;
|
||||||
|
|
||||||
@ -298,11 +298,11 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t
|
|||||||
y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
|
y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
|
||||||
y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
|
y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
|
||||||
#else
|
#else
|
||||||
const int tid = threadIdx.x;
|
const int64_t tid = threadIdx.x;
|
||||||
const uint8_t q = x[i].qs[tid];
|
const uint8_t q = x[i].qs[tid];
|
||||||
const int im = tid/8; // 0...3
|
const int64_t im = tid/8; // 0...3
|
||||||
const int in = tid%8; // 0...7
|
const int64_t in = tid%8; // 0...7
|
||||||
const int is = tid/16; // 0 or 1
|
const int64_t is = tid/16; // 0 or 1
|
||||||
const uint8_t h = x[i].qh[in] >> im;
|
const uint8_t h = x[i].qh[in] >> im;
|
||||||
const float d = x[i].d;
|
const float d = x[i].d;
|
||||||
dst_t * y = yy + i*QK_K + tid;
|
dst_t * y = yy + i*QK_K + tid;
|
||||||
@ -359,13 +359,13 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
|
|||||||
template<typename dst_t>
|
template<typename dst_t>
|
||||||
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
|
|
||||||
const int i = blockIdx.x;
|
const int64_t i = blockIdx.x;
|
||||||
const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
|
const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
|
||||||
|
|
||||||
const int tid = threadIdx.x;
|
const int64_t tid = threadIdx.x;
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const int il = tid/8; // 0...3
|
const int64_t il = tid/8; // 0...3
|
||||||
const int ib = tid%8; // 0...7
|
const int64_t ib = tid%8; // 0...7
|
||||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||||
const uint16_t * q2 = x[i].qs + 4*ib;
|
const uint16_t * q2 = x[i].qs + 4*ib;
|
||||||
const uint8_t * aux8 = (const uint8_t *)q2;
|
const uint8_t * aux8 = (const uint8_t *)q2;
|
||||||
@ -383,13 +383,13 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
|
|||||||
template<typename dst_t>
|
template<typename dst_t>
|
||||||
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
|
|
||||||
const int i = blockIdx.x;
|
const int64_t i = blockIdx.x;
|
||||||
const block_iq2_xs * x = (const block_iq2_xs *) vx;
|
const block_iq2_xs * x = (const block_iq2_xs *) vx;
|
||||||
|
|
||||||
const int tid = threadIdx.x;
|
const int64_t tid = threadIdx.x;
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const int il = tid/8; // 0...3
|
const int64_t il = tid/8; // 0...3
|
||||||
const int ib = tid%8; // 0...7
|
const int64_t ib = tid%8; // 0...7
|
||||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||||
const uint16_t * q2 = x[i].qs + 4*ib;
|
const uint16_t * q2 = x[i].qs + 4*ib;
|
||||||
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
|
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
|
||||||
@ -405,13 +405,13 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
|
|||||||
template<typename dst_t>
|
template<typename dst_t>
|
||||||
static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
|
|
||||||
const int i = blockIdx.x;
|
const int64_t i = blockIdx.x;
|
||||||
const block_iq2_s * x = (const block_iq2_s *) vx;
|
const block_iq2_s * x = (const block_iq2_s *) vx;
|
||||||
|
|
||||||
const int tid = threadIdx.x;
|
const int64_t tid = threadIdx.x;
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const int il = tid/8; // 0...3
|
const int64_t il = tid/8; // 0...3
|
||||||
const int ib = tid%8; // 0...7
|
const int64_t ib = tid%8; // 0...7
|
||||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||||
const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
|
const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
|
||||||
const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
|
const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
|
||||||
@ -426,13 +426,13 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_
|
|||||||
template<typename dst_t>
|
template<typename dst_t>
|
||||||
static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
|
|
||||||
const int i = blockIdx.x;
|
const int64_t i = blockIdx.x;
|
||||||
const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
|
const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
|
||||||
|
|
||||||
const int tid = threadIdx.x;
|
const int64_t tid = threadIdx.x;
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const int il = tid/8; // 0...3
|
const int64_t il = tid/8; // 0...3
|
||||||
const int ib = tid%8; // 0...7
|
const int64_t ib = tid%8; // 0...7
|
||||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||||
const uint8_t * q3 = x[i].qs + 8*ib;
|
const uint8_t * q3 = x[i].qs + 8*ib;
|
||||||
const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
|
const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
|
||||||
@ -454,13 +454,13 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
|
|||||||
template<typename dst_t>
|
template<typename dst_t>
|
||||||
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
|
|
||||||
const int i = blockIdx.x;
|
const int64_t i = blockIdx.x;
|
||||||
const block_iq3_s * x = (const block_iq3_s *) vx;
|
const block_iq3_s * x = (const block_iq3_s *) vx;
|
||||||
|
|
||||||
const int tid = threadIdx.x;
|
const int64_t tid = threadIdx.x;
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const int il = tid/8; // 0...3
|
const int64_t il = tid/8; // 0...3
|
||||||
const int ib = tid%8; // 0...7
|
const int64_t ib = tid%8; // 0...7
|
||||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||||
const uint8_t * qs = x[i].qs + 8*ib;
|
const uint8_t * qs = x[i].qs + 8*ib;
|
||||||
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
|
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
|
||||||
@ -480,13 +480,13 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
|
|||||||
template<typename dst_t>
|
template<typename dst_t>
|
||||||
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
|
|
||||||
const int i = blockIdx.x;
|
const int64_t i = blockIdx.x;
|
||||||
const block_iq1_s * x = (const block_iq1_s *) vx;
|
const block_iq1_s * x = (const block_iq1_s *) vx;
|
||||||
|
|
||||||
const int tid = threadIdx.x;
|
const int64_t tid = threadIdx.x;
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const int il = tid/8; // 0...3
|
const int64_t il = tid/8; // 0...3
|
||||||
const int ib = tid%8; // 0...7
|
const int64_t ib = tid%8; // 0...7
|
||||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||||
const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
|
const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
|
||||||
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
|
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
|
||||||
@ -506,18 +506,18 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
|
|||||||
template<typename dst_t>
|
template<typename dst_t>
|
||||||
static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
|
|
||||||
const int i = blockIdx.x;
|
const int64_t i = blockIdx.x;
|
||||||
const block_iq1_m * x = (const block_iq1_m *) vx;
|
const block_iq1_m * x = (const block_iq1_m *) vx;
|
||||||
|
|
||||||
const int tid = threadIdx.x;
|
const int64_t tid = threadIdx.x;
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const int il = tid/8; // 0...3
|
const int64_t il = tid/8; // 0...3
|
||||||
const int ib = tid%8; // 0...7
|
const int64_t ib = tid%8; // 0...7
|
||||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||||
const uint16_t * sc = (const uint16_t *)x[i].scales;
|
const uint16_t * sc = (const uint16_t *)x[i].scales;
|
||||||
iq1m_scale_t scale;
|
iq1m_scale_t scale;
|
||||||
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
||||||
const int ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
|
const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
|
||||||
const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
|
const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
|
||||||
const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
|
const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
|
||||||
uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
|
uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
|
||||||
@ -537,12 +537,12 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
|
|||||||
template<typename dst_t>
|
template<typename dst_t>
|
||||||
static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
|
|
||||||
const int i = blockIdx.x;
|
const int64_t i = blockIdx.x;
|
||||||
const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
|
const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
|
||||||
|
|
||||||
const int tid = threadIdx.x;
|
const int64_t tid = threadIdx.x;
|
||||||
const int il = tid/8; // 0...3
|
const int64_t il = tid/8; // 0...3
|
||||||
const int ib = tid%8; // 0...7
|
const int64_t ib = tid%8; // 0...7
|
||||||
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
|
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
|
||||||
const uint8_t * q4 = x[ib].qs + 4*il;
|
const uint8_t * q4 = x[ib].qs + 4*il;
|
||||||
const float d = (float)x[ib].d;
|
const float d = (float)x[ib].d;
|
||||||
@ -556,12 +556,12 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
|
|||||||
#if QK_K != 64
|
#if QK_K != 64
|
||||||
template<typename dst_t>
|
template<typename dst_t>
|
||||||
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||||
const int i = blockIdx.x;
|
const int64_t i = blockIdx.x;
|
||||||
const block_iq4_xs * x = (const block_iq4_xs *)vx;
|
const block_iq4_xs * x = (const block_iq4_xs *)vx;
|
||||||
|
|
||||||
const int tid = threadIdx.x;
|
const int64_t tid = threadIdx.x;
|
||||||
const int il = tid/8; // 0...3
|
const int64_t il = tid/8; // 0...3
|
||||||
const int ib = tid%8; // 0...7
|
const int64_t ib = tid%8; // 0...7
|
||||||
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
|
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
|
||||||
const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
|
const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
|
||||||
const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
|
const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
|
||||||
|
@ -28,7 +28,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
|
|||||||
extern __shared__ float data_soft_max_f32[];
|
extern __shared__ float data_soft_max_f32[];
|
||||||
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
|
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
|
||||||
// shared memory buffer to cache values between iterations:
|
// shared memory buffer to cache values between iterations:
|
||||||
float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;
|
float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols;
|
||||||
|
|
||||||
float max_val = -INFINITY;
|
float max_val = -INFINITY;
|
||||||
|
|
||||||
@ -40,8 +40,8 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int ix = rowx*ncols + col;
|
const int64_t ix = (int64_t)rowx*ncols + col;
|
||||||
const int iy = rowy*ncols + col;
|
const int64_t iy = (int64_t)rowy*ncols + col;
|
||||||
|
|
||||||
const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
|
const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
|
||||||
|
|
||||||
@ -109,7 +109,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int idst = rowx*ncols + col;
|
const int64_t idst = (int64_t)rowx*ncols + col;
|
||||||
dst[idst] = vals[col] * inv_sum;
|
dst[idst] = vals[col] * inv_sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user