Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-24 13:28:50 +01:00)

sync : ggml (SD ops, tests, kernels) (#4444)
* sync : ggml (SD ops, tests, kernels) ggml-ci
* cuda : restore im2col ggml-ci
* metal : fix accuracy of dequantization kernels ggml-ci
* cuda : restore correct im2col ggml-ci
* metal : try to fix moe test by reducing expert size ggml-ci
* cuda : fix bin bcast when src1 and dst have different types ggml-ci
---------
Co-authored-by: slaren <slarengh@gmail.com>
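The last item in the message corresponds to the bin_bcast_cuda change in the ggml-cuda.cu hunks below, where the destination strides are now divided by sizeof(dst_t) instead of sizeof(src1_t). The following standalone sketch is illustrative only and not part of the commit; the kernel name, toy sizes and host code are assumptions. It shows why the element stride of an f16 destination has to be derived from the destination's own element size:

// Standalone sketch, illustrative only (not part of the commit): the kernel name,
// toy sizes and host code are assumptions. It mirrors the bin_bcast fix in
// ggml-cuda.cu, where destination strides are computed from the destination
// element size (sizeof(dst_t)), which matters once src1 is f32 but dst is f16.
#include <cstdio>
#include <cuda_fp16.h>

// add two f32 rows and store the result into an f16 destination with row stride s1
__global__ void add_rows_f32_f16(const float * src0, const float * src1, half * dst,
                                 int ne0, int ne1, size_t s1) {
    const int i0 = blockIdx.x*blockDim.x + threadIdx.x;
    const int i1 = blockIdx.y;
    if (i0 >= ne0 || i1 >= ne1) {
        return;
    }
    dst[i1*s1 + i0] = __float2half(src0[i1*ne0 + i0] + src1[i1*ne0 + i0]);
}

int main() {
    const int ne0 = 8, ne1 = 2;

    const size_t nb1 = ne0*sizeof(half);    // dst row stride in bytes (contiguous f16 rows)
    const size_t s1  = nb1/sizeof(half);    // correct: divide by the dst element size
    // const size_t s1 = nb1/sizeof(float); // wrong: halves the stride, because dst is f16, not f32

    float * src0; float * src1; half * dst;
    cudaMallocManaged(&src0, ne0*ne1*sizeof(float));
    cudaMallocManaged(&src1, ne0*ne1*sizeof(float));
    cudaMallocManaged(&dst,  ne0*ne1*sizeof(half));
    for (int i = 0; i < ne0*ne1; ++i) { src0[i] = (float) i; src1[i] = 0.5f; }

    const int block = 256; // same ceil-division launch pattern as the *_cuda helpers in the diff
    const dim3 grid((ne0 + block - 1)/block, ne1);
    add_rows_f32_f16<<<grid, block>>>(src0, src1, dst, ne0, ne1, s1);
    cudaDeviceSynchronize();

    printf("row 0 starts at %.1f, row 1 starts at %.1f\n", __half2float(dst[0]), __half2float(dst[s1]));

    cudaFree(src0); cudaFree(src1); cudaFree(dst);
    return 0;
}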
This commit is contained in:
parent 70f806b821
commit 4d98d9a656

ggml-cuda.cu (477 lines changed)
@ -439,6 +439,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_

#define CUDA_GELU_BLOCK_SIZE 256
#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_TANH_BLOCK_SIZE 256
#define CUDA_RELU_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_CPY_BLOCK_SIZE 32
@ -451,6 +452,11 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define CUDA_QUANTIZE_BLOCK_SIZE 256
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
#define CUDA_GET_ROWS_BLOCK_SIZE 256
#define CUDA_UPSCALE_BLOCK_SIZE 256
#define CUDA_CONCAT_BLOCK_SIZE 256
#define CUDA_PAD_BLOCK_SIZE 256
#define CUDA_ACC_BLOCK_SIZE 256
#define CUDA_IM2COL_BLOCK_SIZE 256

// dmmv = dequantize_mul_mat_vec
#ifndef GGML_CUDA_DMMV_X
@ -612,6 +618,24 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
}

static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
        const int ne10, const int ne11, const int ne12,
        const int nb1, const int nb2, int offset) {
    const int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i >= ne) {
        return;
    }
    int src1_idx = i - offset;
    int oz = src1_idx / nb2;
    int oy = (src1_idx - (oz * nb2)) / nb1;
    int ox = src1_idx % nb1;
    if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
        dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
    } else {
        dst[i] = x[i];
    }
}

static __global__ void gelu_f32(const float * x, float * dst, const int k) {
    const float GELU_COEF_A = 0.044715f;
    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@ -634,6 +658,23 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
    dst[i] = x[i] / (1.0f + expf(-x[i]));
}

static __global__ void gelu_quick_f32(const float *x, float *dst, int k) {
    const float GELU_QUICK_COEF = -1.702f;
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
}

static __global__ void tanh_f32(const float *x, float *dst, int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    dst[i] = tanhf(x[i]);
}

static __global__ void relu_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

@ -643,6 +684,14 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
    dst[i] = fmaxf(x[i], 0);
}

static __global__ void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope;
}

static __global__ void sqr_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

@ -688,6 +737,132 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
    }
|
||||
}
|
||||
|
||||
static __global__ void concat_f32(const float *x,const float *y, float *dst, const int ne0, const int ne02) {
|
||||
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (nidx >= ne0) {
|
||||
return;
|
||||
}
|
||||
// operation
|
||||
int offset_dst =
|
||||
nidx +
|
||||
blockIdx.y * ne0 +
|
||||
blockIdx.z * ne0 * gridDim.y;
|
||||
if (blockIdx.z < ne02) { // src0
|
||||
int offset_src =
|
||||
nidx +
|
||||
blockIdx.y * ne0 +
|
||||
blockIdx.z * ne0 * gridDim.y;
|
||||
dst[offset_dst] = x[offset_src];
|
||||
} else {
|
||||
int offset_src =
|
||||
nidx +
|
||||
blockIdx.y * ne0 +
|
||||
(blockIdx.z - ne02) * ne0 * gridDim.y;
|
||||
dst[offset_dst] = y[offset_src];
|
||||
}
|
||||
}
|
||||
|
||||
static __global__ void upscale_f32(const float *x, float *dst, const int ne00, const int nb02, const int scale_factor) {
|
||||
int ne0 = ne00 * scale_factor;
|
||||
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (nidx >= ne0) {
|
||||
return;
|
||||
}
|
||||
// operation
|
||||
int i00 = nidx / scale_factor;
|
||||
int i01 = blockIdx.y / scale_factor;
|
||||
int offset_src =
|
||||
i00 +
|
||||
i01 * ne00 +
|
||||
blockIdx.z * nb02;
|
||||
int offset_dst =
|
||||
nidx +
|
||||
blockIdx.y * ne0 +
|
||||
blockIdx.z * ne0 * gridDim.y;
|
||||
dst[offset_dst] = x[offset_src];
|
||||
}
|
||||
|
||||
static __global__ void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02) {
|
||||
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (nidx >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// operation
|
||||
int offset_dst =
|
||||
nidx +
|
||||
blockIdx.y * ne0 +
|
||||
blockIdx.z * ne0 * gridDim.y;
|
||||
if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02) {
|
||||
int offset_src =
|
||||
nidx +
|
||||
blockIdx.y * ne00 +
|
||||
blockIdx.z * ne00 * ne01;
|
||||
dst[offset_dst] = x[offset_src];
|
||||
} else {
|
||||
dst[offset_dst] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
template <int block_size>
|
||||
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
|
||||
int start = blockIdx.x * group_size;
|
||||
int end = start + group_size;
|
||||
|
||||
start += threadIdx.x;
|
||||
|
||||
if (end >= ne_elements) {
|
||||
end = ne_elements;
|
||||
}
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int j = start; j < end; j += block_size) {
|
||||
tmp += x[j];
|
||||
}
|
||||
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
if (block_size > WARP_SIZE) {
|
||||
__shared__ float s_sum[32];
|
||||
int warp_id = threadIdx.x / WARP_SIZE;
|
||||
int lane_id = threadIdx.x % WARP_SIZE;
|
||||
if (lane_id == 0) {
|
||||
s_sum[warp_id] = tmp;
|
||||
}
|
||||
__syncthreads();
|
||||
tmp = s_sum[lane_id];
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
}
|
||||
|
||||
float mean = tmp / group_size;
|
||||
tmp = 0.0f;
|
||||
|
||||
for (int j = start; j < end; j += block_size) {
|
||||
float xi = x[j] - mean;
|
||||
dst[j] = xi;
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
if (block_size > WARP_SIZE) {
|
||||
__shared__ float s_sum[32];
|
||||
int warp_id = threadIdx.x / WARP_SIZE;
|
||||
int lane_id = threadIdx.x % WARP_SIZE;
|
||||
if (lane_id == 0) {
|
||||
s_sum[warp_id] = tmp;
|
||||
}
|
||||
__syncthreads();
|
||||
tmp = s_sum[lane_id];
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
}
|
||||
|
||||
float variance = tmp / group_size;
|
||||
float scale = rsqrtf(variance + eps);
|
||||
for (int j = start; j < end; j += block_size) {
|
||||
dst[j] *= scale;
|
||||
}
|
||||
}
|
||||
|
||||
template <int block_size>
|
||||
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
@ -5071,19 +5246,30 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,

static __global__ void im2col_f32_f16(
        const float * x, half * dst,
        int ofs0, int ofs1, int IW, int IH, int CHW,
        int offset_delta, int IW, int IH, int OW, int KW, int KH, int pelements, int CHW,
        int s0, int s1, int p0, int p1, int d0, int d1) {
    const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
    const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
    const int i = threadIdx.x + blockIdx.x * blockDim.x;
    if (i >= pelements) {
        return;
    }

    const int ksize = OW * (KH > 1 ? KW : 1);
    const int kx = i / ksize;
    const int kd = kx * ksize;
    const int ky = (i - kd) / OW;
    const int ix = i % OW;

    const int iiw = ix * s0 + kx * d0 - p0;
    const int iih = blockIdx.y * s1 + ky * d1 - p1;

    const int offset_dst =
        (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
        (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
        (blockIdx.y * OW + ix) * CHW +
        (blockIdx.z * (KW * KH) + ky * KW + kx);

    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
        dst[offset_dst] = __float2half(0.0f);
    } else {
        const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
        const int offset_src = blockIdx.z * offset_delta;
        dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
    }
}

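The reworked kernel above replaces the old mapping, where kernel taps were addressed via threadIdx.y/threadIdx.z, with a flat index i over the OW*KW*KH elements of one output row. The small host-side loop below is illustrative only (not part of the commit; the toy sizes are assumptions) and reproduces how i is decoded into kernel column kx, kernel row ky and output column ix:

// Illustrative only (not part of the commit): decode the flat index i the same
// way im2col_f32_f16 does, for toy sizes OW=3, KW=2, KH=2.
#include <cstdio>

int main() {
    const int OW = 3, KW = 2, KH = 2;
    const int pelements = OW * KW * KH;
    for (int i = 0; i < pelements; ++i) {
        const int ksize = OW * (KH > 1 ? KW : 1);
        const int kx = i / ksize;           // kernel column
        const int kd = kx * ksize;
        const int ky = (i - kd) / OW;       // kernel row
        const int ix = i % OW;              // output column
        printf("i=%2d -> kx=%d ky=%d ix=%d\n", i, kx, ky, ix);
    }
    return 0;
}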
@ -5220,10 +5406,10 @@ struct bin_bcast_cuda {
            size_t nb12 = cnb1[2];
            size_t nb13 = cnb1[3];

            size_t s0 = nb0 / sizeof(src1_t);
            size_t s1 = nb1 / sizeof(src1_t);
            size_t s2 = nb2 / sizeof(src1_t);
            size_t s3 = nb3 / sizeof(src1_t);
            size_t s0 = nb0 / sizeof(dst_t);
            size_t s1 = nb1 / sizeof(dst_t);
            size_t s2 = nb2 / sizeof(dst_t);
            size_t s3 = nb3 / sizeof(dst_t);

            size_t s10 = nb10 / sizeof(src1_t);
            size_t s11 = nb11 / sizeof(src1_t);
@ -5269,6 +5455,13 @@ struct bin_bcast_cuda {
    }
};

static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
        const int ne10, const int ne11, const int ne12,
        const int nb1, const int nb2, const int offset, cudaStream_t stream) {
    int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
    acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset);
}

static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
|
||||
gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
@ -5279,11 +5472,26 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
|
||||
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
|
||||
gelu_quick_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
|
||||
tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
|
||||
relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
|
||||
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
|
||||
}
|
||||
|
||||
static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
|
||||
sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
@ -5300,6 +5508,38 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
|
||||
}
|
||||
}
|
||||
|
||||
static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) {
|
||||
static const float eps = 1e-6f;
|
||||
if (group_size < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
|
||||
}
|
||||
}
|
||||
|
||||
static void concat_f32_cuda(const float * x, const float * y, float * dst, const int ne0, int ne1, int ne2, int ne02, cudaStream_t stream) {
|
||||
int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
|
||||
dim3 gridDim(num_blocks, ne1, ne2);
|
||||
concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
|
||||
}
|
||||
|
||||
static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int scale_factor, cudaStream_t stream) {
|
||||
int ne0 = (ne00 * scale_factor);
|
||||
int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
|
||||
dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02);
|
||||
upscale_f32<<<gridDim, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(x, dst, ne00, ne00 * ne01, scale_factor);
|
||||
}
|
||||
|
||||
static void pad_f32_cuda(const float * x, float * dst,
|
||||
const int ne00, const int ne01, const int ne02,
|
||||
const int ne0, const int ne1, const int ne2, cudaStream_t stream) {
|
||||
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
|
||||
dim3 gridDim(num_blocks, ne1, ne2);
|
||||
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02);
|
||||
}
|
||||
|
||||
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
if (ncols < 1024) {
|
||||
@ -6263,12 +6503,13 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
}

static void im2col_f32_f16_cuda(const float* x, half* dst,
        int OH, int IW, int IH, int OW, int IC,
        int KH, int KW, int N, int ofs0, int ofs1,
        int IW, int IH, int OW, int OH, int KW, int KH, int IC,
        int offset_delta,
        int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
    dim3 block_nums(IC, OH, OW);
    dim3 block_dims(N, KH, KW);
    im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
    const int parallel_elements = OW * KW * KH;
    const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
    dim3 block_nums(num_blocks, OH, IC);
    im2col_f32_f16<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
}

// buffer pool for cuda
|
||||
@ -6615,6 +6856,25 @@ inline void ggml_cuda_op_add(
|
||||
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_acc(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
|
||||
|
||||
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
|
||||
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
|
||||
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
|
||||
int offset = dst->op_params[3] / 4; // offset in bytes
|
||||
|
||||
acc_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream);
|
||||
|
||||
(void) dst;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_mul(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
@ -6657,6 +6917,34 @@ inline void ggml_cuda_op_silu(
|
||||
(void) src1_dd;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_gelu_quick(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
gelu_quick_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
(void) src1_dd;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_tanh(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
tanh_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
(void) src1_dd;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_relu(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
@ -6671,6 +6959,23 @@ inline void ggml_cuda_op_relu(
|
||||
(void) src1_dd;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_leaky_relu(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
float negative_slope;
|
||||
memcpy(&negative_slope, dst->op_params, sizeof(float));
|
||||
|
||||
leaky_relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream);
|
||||
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
(void) src1_dd;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_sqr(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
@ -6705,6 +7010,71 @@ inline void ggml_cuda_op_norm(
|
||||
(void) src1_dd;
|
||||
}
|
||||
|
||||
|
||||
inline void ggml_cuda_op_group_norm(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
int num_groups = dst->op_params[0];
|
||||
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
|
||||
group_norm_f32_cuda(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream);
|
||||
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
(void) src1_dd;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_concat(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
|
||||
concat_f32_cuda(src0_dd + i3 * (src0->nb[3] / 4), src1_dd + i3 * (src1->nb[3] / 4), dst_dd + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], main_stream);
|
||||
}
|
||||
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_upscale(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
|
||||
|
||||
const int scale_factor = dst->op_params[0];
|
||||
|
||||
upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
|
||||
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_pad(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
|
||||
|
||||
pad_f32_cuda(src0_dd, dst_dd,
|
||||
src0->ne[0], src0->ne[1], src0->ne[2],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
|
||||
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_rms_norm(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
@ -7219,7 +7589,6 @@ inline void ggml_cuda_op_im2col(

    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;

    const int64_t N = src1->ne[is_2D ? 3 : 2];
    const int64_t IC = src1->ne[is_2D ? 2 : 1];
    const int64_t IH = is_2D ? src1->ne[1] : 1;
    const int64_t IW = src1->ne[0];
@ -7230,17 +7599,15 @@
    const int64_t OH = is_2D ? dst->ne[2] : 1;
    const int64_t OW = dst->ne[1];

    const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
    const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
    const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32

    im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
        OH, IW, IH, OW, IC, KH, KW, N,
        ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
    im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);

    (void) src0;
    (void) src0_dd;
}

inline void ggml_cuda_op_sum_rows(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
@ -7789,6 +8156,10 @@ static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, gg
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
|
||||
}
|
||||
|
||||
static void ggml_cuda_acc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_acc);
|
||||
}
|
||||
|
||||
static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
|
||||
}
|
||||
@ -7805,10 +8176,22 @@ static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, g
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
|
||||
}
|
||||
|
||||
static void ggml_cuda_gelu_quick(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu_quick);
|
||||
}
|
||||
|
||||
static void ggml_cuda_tanh(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_tanh);
|
||||
}
|
||||
|
||||
static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
|
||||
}
|
||||
|
||||
static void ggml_cuda_leaky_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_leaky_relu);
|
||||
}
|
||||
|
||||
static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
|
||||
}
|
||||
@ -7817,6 +8200,22 @@ static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, g
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
|
||||
}
|
||||
|
||||
static void ggml_cuda_group_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_group_norm);
|
||||
}
|
||||
|
||||
static void ggml_cuda_concat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_concat);
|
||||
}
|
||||
|
||||
static void ggml_cuda_upscale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_upscale);
|
||||
}
|
||||
|
||||
static void ggml_cuda_pad(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pad);
|
||||
}
|
||||
|
||||
static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
|
||||
}
|
||||
@ -8809,6 +9208,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
||||
case GGML_OP_ADD:
|
||||
func = ggml_cuda_add;
|
||||
break;
|
||||
case GGML_OP_ACC:
|
||||
func = ggml_cuda_acc;
|
||||
break;
|
||||
case GGML_OP_MUL:
|
||||
func = ggml_cuda_mul;
|
||||
break;
|
||||
@ -8823,6 +9225,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
||||
case GGML_UNARY_OP_SILU:
|
||||
func = ggml_cuda_silu;
|
||||
break;
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
func = ggml_cuda_gelu_quick;
|
||||
break;
|
||||
case GGML_UNARY_OP_TANH:
|
||||
func = ggml_cuda_tanh;
|
||||
break;
|
||||
case GGML_UNARY_OP_RELU:
|
||||
func = ggml_cuda_relu;
|
||||
break;
|
||||
@ -8833,6 +9241,21 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
||||
case GGML_OP_NORM:
|
||||
func = ggml_cuda_norm;
|
||||
break;
|
||||
case GGML_OP_GROUP_NORM:
|
||||
func = ggml_cuda_group_norm;
|
||||
break;
|
||||
case GGML_OP_CONCAT:
|
||||
func = ggml_cuda_concat;
|
||||
break;
|
||||
case GGML_OP_UPSCALE:
|
||||
func = ggml_cuda_upscale;
|
||||
break;
|
||||
case GGML_OP_PAD:
|
||||
func = ggml_cuda_pad;
|
||||
break;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
func = ggml_cuda_leaky_relu;
|
||||
break;
|
||||
case GGML_OP_RMS_NORM:
|
||||
func = ggml_cuda_rms_norm;
|
||||
break;
|
||||
@ -8855,9 +9278,6 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
||||
func = ggml_cuda_sqr;
|
||||
break;
|
||||
case GGML_OP_CLAMP:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cuda_clamp;
|
||||
break;
|
||||
case GGML_OP_CPY:
|
||||
@ -8866,6 +9286,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
||||
case GGML_OP_CONT:
|
||||
func = ggml_cuda_dup;
|
||||
break;
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_VIEW:
|
||||
case GGML_OP_PERMUTE:
|
||||
@ -9285,6 +9706,8 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
|
||||
case GGML_UNARY_OP_GELU:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_RELU:
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
@ -9369,6 +9792,12 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_CONCAT:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
ggml-metal.m (265 lines changed)
@ -66,9 +66,11 @@ struct ggml_metal_context {
|
||||
GGML_METAL_DECL_KERNEL(div_row);
|
||||
GGML_METAL_DECL_KERNEL(scale);
|
||||
GGML_METAL_DECL_KERNEL(scale_4);
|
||||
GGML_METAL_DECL_KERNEL(silu);
|
||||
GGML_METAL_DECL_KERNEL(tanh);
|
||||
GGML_METAL_DECL_KERNEL(relu);
|
||||
GGML_METAL_DECL_KERNEL(gelu);
|
||||
GGML_METAL_DECL_KERNEL(gelu_quick);
|
||||
GGML_METAL_DECL_KERNEL(silu);
|
||||
GGML_METAL_DECL_KERNEL(soft_max);
|
||||
GGML_METAL_DECL_KERNEL(soft_max_4);
|
||||
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
||||
@ -86,6 +88,7 @@ struct ggml_metal_context {
|
||||
GGML_METAL_DECL_KERNEL(get_rows_q5_K);
|
||||
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
||||
GGML_METAL_DECL_KERNEL(rms_norm);
|
||||
GGML_METAL_DECL_KERNEL(group_norm);
|
||||
GGML_METAL_DECL_KERNEL(norm);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
|
||||
@ -145,8 +148,11 @@ struct ggml_metal_context {
|
||||
GGML_METAL_DECL_KERNEL(rope_f16);
|
||||
GGML_METAL_DECL_KERNEL(alibi_f32);
|
||||
GGML_METAL_DECL_KERNEL(im2col_f16);
|
||||
GGML_METAL_DECL_KERNEL(upscale_f32);
|
||||
GGML_METAL_DECL_KERNEL(pad_f32);
|
||||
GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
|
||||
GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
|
||||
GGML_METAL_DECL_KERNEL(leaky_relu_f32);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
|
||||
@ -334,9 +340,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(div_row);
|
||||
GGML_METAL_ADD_KERNEL(scale);
|
||||
GGML_METAL_ADD_KERNEL(scale_4);
|
||||
GGML_METAL_ADD_KERNEL(silu);
|
||||
GGML_METAL_ADD_KERNEL(tanh);
|
||||
GGML_METAL_ADD_KERNEL(relu);
|
||||
GGML_METAL_ADD_KERNEL(gelu);
|
||||
GGML_METAL_ADD_KERNEL(gelu_quick);
|
||||
GGML_METAL_ADD_KERNEL(silu);
|
||||
GGML_METAL_ADD_KERNEL(soft_max);
|
||||
GGML_METAL_ADD_KERNEL(soft_max_4);
|
||||
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
||||
@ -354,6 +362,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(get_rows_q5_K);
|
||||
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
||||
GGML_METAL_ADD_KERNEL(rms_norm);
|
||||
GGML_METAL_ADD_KERNEL(group_norm);
|
||||
GGML_METAL_ADD_KERNEL(norm);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
|
||||
@ -415,8 +424,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(rope_f16);
|
||||
GGML_METAL_ADD_KERNEL(alibi_f32);
|
||||
GGML_METAL_ADD_KERNEL(im2col_f16);
|
||||
GGML_METAL_ADD_KERNEL(upscale_f32);
|
||||
GGML_METAL_ADD_KERNEL(pad_f32);
|
||||
GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
|
||||
GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
|
||||
GGML_METAL_ADD_KERNEL(leaky_relu_f32);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
|
||||
@ -450,9 +462,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
GGML_METAL_DEL_KERNEL(div_row);
|
||||
GGML_METAL_DEL_KERNEL(scale);
|
||||
GGML_METAL_DEL_KERNEL(scale_4);
|
||||
GGML_METAL_DEL_KERNEL(silu);
|
||||
GGML_METAL_DEL_KERNEL(tanh);
|
||||
GGML_METAL_DEL_KERNEL(relu);
|
||||
GGML_METAL_DEL_KERNEL(gelu);
|
||||
GGML_METAL_DEL_KERNEL(gelu_quick);
|
||||
GGML_METAL_DEL_KERNEL(silu);
|
||||
GGML_METAL_DEL_KERNEL(soft_max);
|
||||
GGML_METAL_DEL_KERNEL(soft_max_4);
|
||||
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
||||
@ -470,6 +484,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
|
||||
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
||||
GGML_METAL_DEL_KERNEL(rms_norm);
|
||||
GGML_METAL_DEL_KERNEL(group_norm);
|
||||
GGML_METAL_DEL_KERNEL(norm);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
|
||||
@ -531,8 +546,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
GGML_METAL_DEL_KERNEL(rope_f16);
|
||||
GGML_METAL_DEL_KERNEL(alibi_f32);
|
||||
GGML_METAL_DEL_KERNEL(im2col_f16);
|
||||
GGML_METAL_DEL_KERNEL(upscale_f32);
|
||||
GGML_METAL_DEL_KERNEL(pad_f32);
|
||||
GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
|
||||
GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
|
||||
GGML_METAL_DEL_KERNEL(leaky_relu_f32);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
|
||||
@ -843,9 +861,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
||||
switch (op->op) {
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_RELU:
|
||||
case GGML_UNARY_OP_GELU:
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
@ -853,11 +873,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_VIEW:
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_GET_ROWS:
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_CONCAT:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
case GGML_OP_SCALE:
|
||||
@ -865,11 +885,15 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_ALIBI:
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return true;
|
||||
@ -902,8 +926,9 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
||||
};
|
||||
}
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_GET_ROWS:
|
||||
{
|
||||
return op->ne[0] % 4 == 0;
|
||||
return op->ne[3] == 1;
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
@ -979,7 +1004,10 @@ void ggml_metal_graph_compute(
|
||||
} break;
|
||||
}
|
||||
|
||||
GGML_ASSERT(ggml_metal_supports_op(dst));
|
||||
if (!ggml_metal_supports_op(dst)) {
|
||||
GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
|
||||
GGML_ASSERT(!"unsupported op");
|
||||
}
|
||||
|
||||
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
||||
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
||||
@ -1076,6 +1104,8 @@ void ggml_metal_graph_compute(
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
{
|
||||
const size_t offs = 0;
|
||||
|
||||
bool bcast_row = false;
|
||||
|
||||
int64_t nb = ne00;
|
||||
@ -1134,7 +1164,8 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
||||
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
||||
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
||||
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
|
||||
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
||||
[encoder setBytes:&nb length:sizeof(nb) atIndex:28];
|
||||
|
||||
if (bcast_row) {
|
||||
const int64_t n = ggml_nelements(dst)/4;
|
||||
@ -1146,6 +1177,86 @@ void ggml_metal_graph_compute(
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_ACC:
|
||||
{
|
||||
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dstt == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
|
||||
const size_t pnb1 = ((int32_t *) dst->op_params)[0];
|
||||
const size_t pnb2 = ((int32_t *) dst->op_params)[1];
|
||||
const size_t pnb3 = ((int32_t *) dst->op_params)[2];
|
||||
const size_t offs = ((int32_t *) dst->op_params)[3];
|
||||
|
||||
const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
|
||||
|
||||
if (!inplace) {
|
||||
// run a separate kernel to cpy src->dst
|
||||
// not sure how to avoid this
|
||||
// TODO: make a simpler cpy_bytes kernel
|
||||
|
||||
const int nth = MIN(1024, ne00);
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
||||
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
||||
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
||||
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
||||
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
||||
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
||||
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
||||
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
||||
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
||||
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
}
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_add];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
||||
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
|
||||
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
|
||||
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
|
||||
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
||||
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
||||
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
||||
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
||||
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
||||
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
||||
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
||||
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
|
||||
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
|
||||
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
|
||||
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
||||
|
||||
const int nth = MIN(1024, ne0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SCALE:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
@ -1169,16 +1280,15 @@ void ggml_metal_graph_compute(
|
||||
} break;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(gf->nodes[i])) {
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
{
|
||||
[encoder setComputePipelineState:ctx->pipeline_silu];
|
||||
[encoder setComputePipelineState:ctx->pipeline_tanh];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
GGML_ASSERT(n % 4 == 0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_UNARY_OP_RELU:
|
||||
{
|
||||
@ -1199,6 +1309,28 @@ void ggml_metal_graph_compute(
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
GGML_ASSERT(n % 4 == 0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
{
|
||||
[encoder setComputePipelineState:ctx->pipeline_gelu_quick];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
GGML_ASSERT(n % 4 == 0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_UNARY_OP_SILU:
|
||||
{
|
||||
[encoder setComputePipelineState:ctx->pipeline_silu];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
GGML_ASSERT(n % 4 == 0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
default:
|
||||
@ -1837,6 +1969,38 @@ void ggml_metal_graph_compute(
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_GROUP_NORM:
|
||||
{
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
|
||||
//float eps;
|
||||
//memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
const float eps = 1e-6f; // TODO: temporarily hardcoded
|
||||
|
||||
const int32_t n_groups = ((int32_t *) dst->op_params)[0];
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
//while (nth < ne00/4 && nth < 1024) {
|
||||
// nth *= 2;
|
||||
//}
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_group_norm];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
||||
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
|
||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
|
||||
[encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
|
||||
[encoder setBytes:&eps length:sizeof( float) atIndex:9];
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_NORM:
|
||||
{
|
||||
float eps;
|
||||
@ -2006,6 +2170,65 @@ void ggml_metal_graph_compute(
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
||||
} break;
|
||||
case GGML_OP_UPSCALE:
|
||||
{
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
|
||||
const int sf = dst->op_params[0];
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_upscale_f32];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
||||
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
||||
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
||||
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
||||
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
||||
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
||||
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
||||
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
||||
[encoder setBytes:&sf length:sizeof(sf) atIndex:18];
|
||||
|
||||
const int nth = MIN(1024, ne0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_PAD:
|
||||
{
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_pad_f32];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
||||
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
||||
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
||||
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
||||
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
||||
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
||||
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
||||
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
||||
|
||||
const int nth = MIN(1024, ne0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_ARGSORT:
|
||||
{
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
@ -2027,6 +2250,22 @@ void ggml_metal_graph_compute(
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
{
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
|
||||
float slope;
|
||||
memcpy(&slope, dst->op_params, sizeof(float));
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&slope length:sizeof(slope) atIndex:2];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
ggml-metal.metal (286 lines changed)
@ -79,6 +79,7 @@ kernel void kernel_add(
|
||||
constant int64_t & nb1,
|
||||
constant int64_t & nb2,
|
||||
constant int64_t & nb3,
|
||||
constant int64_t & offs,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
@ -90,9 +91,9 @@ kernel void kernel_add(
|
||||
const int64_t i12 = i02 % ne12;
|
||||
const int64_t i11 = i01 % ne11;
|
||||
|
||||
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
||||
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
|
||||
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
||||
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
||||
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
||||
const int i10 = i0 % ne10;
|
||||
@ -204,7 +205,7 @@ kernel void kernel_add_row(
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
device float4 * dst,
|
||||
constant int64_t & nb [[buffer(27)]],
|
||||
constant int64_t & nb [[buffer(28)]],
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
||||
}
|
||||
@ -213,7 +214,7 @@ kernel void kernel_mul_row(
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
device float4 * dst,
|
||||
constant int64_t & nb [[buffer(27)]],
|
||||
constant int64_t & nb [[buffer(28)]],
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
||||
}
|
||||
@ -222,7 +223,7 @@ kernel void kernel_div_row(
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
device float4 * dst,
|
||||
constant int64_t & nb [[buffer(27)]],
|
||||
constant int64_t & nb [[buffer(28)]],
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
||||
}
|
||||
@ -243,6 +244,47 @@ kernel void kernel_scale_4(
|
||||
dst[tpig] = src0[tpig] * scale;
|
||||
}
|
||||
|
||||
kernel void kernel_relu(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = max(0.0f, src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_tanh(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float & x = src0[tpig];
|
||||
dst[tpig] = precise::tanh(x);
|
||||
}
|
||||
|
||||
constant float GELU_COEF_A = 0.044715f;
|
||||
constant float GELU_QUICK_COEF = -1.702f;
|
||||
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
||||
kernel void kernel_gelu(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float4 & x = src0[tpig];
|
||||
|
||||
// BEWARE !!!
|
||||
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
|
||||
// This was observed with Falcon 7B and 40B models
|
||||
//
|
||||
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
}
|
||||
|
||||
kernel void kernel_gelu_quick(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float4 & x = src0[tpig];
|
||||
|
||||
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
|
||||
}
|
||||
|
||||
kernel void kernel_silu(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
@ -251,13 +293,6 @@ kernel void kernel_silu(
|
||||
dst[tpig] = x / (1.0f + exp(-x));
|
||||
}
|
||||
|
||||
kernel void kernel_relu(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = max(0.0f, src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_sqr(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
@ -313,22 +348,6 @@ kernel void kernel_sum_rows(
|
||||
dst_row[0] = row_sum;
|
||||
}
|
||||
|
||||
constant float GELU_COEF_A = 0.044715f;
|
||||
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
||||
kernel void kernel_gelu(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float4 & x = src0[tpig];
|
||||
|
||||
// BEWARE !!!
|
||||
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
|
||||
// This was observed with Falcon 7B and 40B models
|
||||
//
|
||||
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
}
|
||||
|
||||
kernel void kernel_soft_max(
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
@ -650,6 +669,94 @@ kernel void kernel_rms_norm(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_group_norm(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int32_t & n_groups,
|
||||
constant float & eps,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t ne = ne00*ne01*ne02;
|
||||
const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
|
||||
|
||||
int start = tgpig * gs;
|
||||
int end = start + gs;
|
||||
|
||||
start += tpitg;
|
||||
|
||||
if (end >= ne) {
|
||||
end = ne;
|
||||
}
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int j = start; j < end; j += ntg) {
|
||||
tmp += src0[j];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
tmp = simd_sum(tmp);
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = tmp;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
tmp = buf[tiisg];
|
||||
tmp = simd_sum(tmp);
|
||||
}
|
||||
|
||||
const float mean = tmp / gs;
|
||||
tmp = 0.0f;
|
||||
|
||||
for (int j = start; j < end; j += ntg) {
|
||||
float xi = src0[j] - mean;
|
||||
dst[j] = xi;
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
tmp = simd_sum(tmp);
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = tmp;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
tmp = buf[tiisg];
|
||||
tmp = simd_sum(tmp);
|
||||
}
|
||||
|
||||
const float variance = tmp / gs;
|
||||
const float scale = 1.0f/sqrt(variance + eps);
|
||||
for (int j = start; j < end; j += ntg) {
|
||||
dst[j] *= scale;
|
||||
}
|
||||
}
|
||||
|
||||
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
||||
// il indicates where the q4 quants begin (0 or QK4_0/4)
|
||||
// we assume that the yl's have been multiplied with the appropriate scale factor
|
@ -1656,6 +1763,97 @@ kernel void kernel_im2col_f16(
}
}

kernel void kernel_upscale_f32(
device const char * src0,
device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant int32_t & sf,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {

const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;

const int64_t i03 = i3;
const int64_t i02 = i2;
const int64_t i01 = i1/sf;

device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);

for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
dst_ptr[i0] = src0_ptr[i0/sf];
}
}
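
kernel_upscale_f32 implements nearest-neighbour upscaling by an integer factor sf: every output coordinate simply maps back to the source coordinate divided by sf. The same mapping on the CPU, for a single contiguous 2D plane, might look like this sketch (upscale_nearest_ref is an illustrative name only):

// Nearest-neighbour upscale of a contiguous w x h float plane by an integer
// factor sf (illustrative reference, not part of the patch).
static void upscale_nearest_ref(const float * src, float * dst,
                                int64_t w, int64_t h, int sf) {
    const int64_t W = w*sf;
    const int64_t H = h*sf;
    for (int64_t y = 0; y < H; ++y) {
        for (int64_t x = 0; x < W; ++x) {
            dst[y*W + x] = src[(y/sf)*w + (x/sf)];
        }
    }
}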

kernel void kernel_pad_f32(
device const char * src0,
device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {

const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;

const int64_t i03 = i3;
const int64_t i02 = i2;
const int64_t i01 = i1;

device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);

if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
if (i0 < ne00) {
dst_ptr[i0] = src0_ptr[i0];
} else {
dst_ptr[i0] = 0.0f;
}
}

return;
}

for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
dst_ptr[i0] = 0.0f;
}
}

// bitonic sort implementation following the CUDA kernels as reference
typedef void (argsort_t)(
device const float * x,
@ -1708,6 +1906,14 @@ kernel void kernel_argsort_f32_i32(
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
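
The argsort kernels referenced above are based on a bitonic sorting network, which suits GPUs because every compare-and-swap stage is data independent: one thread per element, with a barrier between stages. A minimal CPU sketch of the idea for a single row whose length is a power of two (bitonic_argsort_ref is a made-up name; the real Metal/CUDA kernels parallelize the inner loop across threads):

// Bitonic argsort of one row of n floats, n must be a power of two
// (illustrative reference of the technique, not part of the patch).
static void bitonic_argsort_ref(const float * x, int * idx, int n, int ascending) {
    for (int i = 0; i < n; ++i) {
        idx[i] = i;
    }
    for (int k = 2; k <= n; k *= 2) {          // size of the bitonic sequences
        for (int j = k/2; j > 0; j /= 2) {     // comparison distance
            for (int i = 0; i < n; ++i) {
                const int l = i ^ j;
                if (l <= i) {
                    continue;
                }
                const int dir  = ((i & k) == 0) ? ascending : !ascending;
                const int swap = dir ? (x[idx[i]] > x[idx[l]]) : (x[idx[i]] < x[idx[l]]);
                if (swap) {
                    const int t = idx[i]; idx[i] = idx[l]; idx[l] = t;
                }
            }
        }
    }
}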

kernel void kernel_leaky_relu_f32(
device const float * src0,
device float * dst,
constant float & slope,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
}

kernel void kernel_cpy_f16_f16(
device const half * src0,
device half * dst,
@ -3315,10 +3521,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg

template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
const half d = xb->d;
const half min = xb->dmin;
const float d = xb->d;
const float min = xb->dmin;
device const uint8_t * q = (device const uint8_t *)xb->qs;
half dl, ml;
float dl, ml;
uint8_t sc = xb->scales[il];

#if QK_K == 256
@ -3388,10 +3594,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
q = q + (il/4) * 32 + 16 * (il&1);
il = il & 3;
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
const half d = il < 2 ? xb->d : xb->d / 16.h;
const half min = xb->dmin;
const half dl = d * sc[0];
const half ml = min * sc[1];
const float d = il < 2 ? xb->d : xb->d / 16.h;
const float min = xb->dmin;
const float dl = d * sc[0];
const float ml = min * sc[1];
#else
q = q + 16 * (il&1);
device const uint8_t * s = xb->scales;
@ -3418,13 +3624,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
uint8_t ul = 1 << (il/2);
il = il & 3;
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
const half d = il < 2 ? xb->d : xb->d / 16.h;
const half min = xb->dmin;
const half dl = d * sc[0];
const half ml = min * sc[1];
const float d = il < 2 ? xb->d : xb->d / 16.h;
const float min = xb->dmin;
const float dl = d * sc[0];
const float ml = min * sc[1];

const ushort mask = il<2 ? 0x0F : 0xF0;
const half qh_val = il<2 ? 16.h : 256.h;
const float qh_val = il<2 ? 16.f : 256.f;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
}

181
ggml.c
181
ggml.c
@ -1395,7 +1395,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) {
inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
inline static void ggml_vec_leaky_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.1f*x[i]; }
inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }

static const float GELU_COEF_A = 0.044715f;
static const float GELU_QUICK_COEF = -1.702f;
@ -1623,7 +1623,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"POOL_1D",
"POOL_2D",
"UPSCALE",
"PAD",
"ARGSORT",
"LEAKY_RELU",

"FLASH_ATTN",
"FLASH_FF",
@ -1650,7 +1652,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};

static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@ -1707,7 +1709,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"pool_1d(x)",
"pool_2d(x)",
"upscale(x)",
"pad(x)",
"argsort(x)",
"leaky_relu(x)",

"flash_attn(x)",
"flash_ff(x)",
@ -1734,7 +1738,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};

static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@ -1750,10 +1754,9 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"GELU",
"GELU_QUICK",
"SILU",
"LEAKY",
};

static_assert(GGML_UNARY_OP_COUNT == 11, "GGML_UNARY_OP_COUNT != 11");
static_assert(GGML_UNARY_OP_COUNT == 10, "GGML_UNARY_OP_COUNT != 10");


static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@ -3830,12 +3833,25 @@ struct ggml_tensor * ggml_relu_inplace(
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
}

// ggml_leaky
// ggml_leaky_relu

struct ggml_tensor * ggml_leaky(
struct ggml_tensor * ggml_leaky_relu(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary(ctx, a, GGML_UNARY_OP_LEAKY);
struct ggml_tensor * a, float negative_slope, bool inplace) {
bool is_node = false;

if (!inplace && (a->grad)) {
is_node = true;
}

struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));

result->op = GGML_OP_LEAKY_RELU;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;

return result;
}
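
Compared to the old GGML_UNARY_OP_LEAKY path with its hard-coded 0.1f slope, the new ggml_leaky_relu op takes the negative slope (stored in op_params) and an inplace flag explicitly. A minimal usage sketch, assuming an already initialized ggml_context named ctx (the tensor shape is an arbitrary example):

// Sketch: apply leaky ReLU with slope 0.1 to an f32 tensor (illustrative only).
struct ggml_tensor * a   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
struct ggml_tensor * out = ggml_leaky_relu(ctx, a, 0.1f, /*inplace =*/ false);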

// ggml_gelu
@ -4022,8 +4038,9 @@ static struct ggml_tensor * ggml_group_norm_impl(

struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

result->op = GGML_OP_GROUP_NORM;
result->op_params[0] = n_groups;

result->op = GGML_OP_GROUP_NORM;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = NULL; // TODO: maybe store epsilon here?
@ -5523,6 +5540,30 @@ static struct ggml_tensor * ggml_upscale_impl(
return result;
}

struct ggml_tensor * ggml_pad(
struct ggml_context * ctx,
struct ggml_tensor * a,
int p0, int p1, int p2, int p3) {
bool is_node = false;

if (a->grad) {
GGML_ASSERT(false); // TODO: implement backward
is_node = true;
}

struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
a->ne[0] + p0,
a->ne[1] + p1,
a->ne[2] + p2,
a->ne[3] + p3);

result->op = GGML_OP_PAD;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;

return result;
}
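
As the constructor above shows, ggml_pad only grows the tensor: the result has shape {ne0 + p0, ne1 + p1, ne2 + p2, ne3 + p3} and the forward pass fills the new region with zeros. A short usage sketch, again assuming a valid ggml_context named ctx (the sizes are arbitrary):

// Sketch: pad a 512x512 f32 tensor with one extra column and one extra row of zeros.
struct ggml_tensor * a   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 512, 512);
struct ggml_tensor * out = ggml_pad(ctx, a, 1, 1, 0, 0);  // out->ne = {513, 513, 1, 1}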

struct ggml_tensor * ggml_upscale(
struct ggml_context * ctx,
struct ggml_tensor * a,
@ -7718,8 +7759,10 @@ static void ggml_compute_forward_mul_f32(
const int ith = params->ith;
const int nth = params->nth;

// TODO: OpenCL kernel support broadcast
#ifdef GGML_USE_CLBLAST
if (src1->backend == GGML_BACKEND_GPU) {
GGML_ASSERT(ggml_are_same_shape(src0, src1));
if (ith == 0) {
ggml_cl_mul(src0, src1, dst);
}
@ -8985,10 +9028,9 @@ static void ggml_compute_forward_silu(
} break;
}
}
// ggml_compute_forward_leaky_relu

// ggml_compute_forward_leaky

static void ggml_compute_forward_leaky_f32(
static void ggml_compute_forward_leaky_relu_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
@ -9002,24 +9044,27 @@ static void ggml_compute_forward_leaky_f32(
const int n = ggml_nrows(src0);
const int nc = src0->ne[0];

float negative_slope;
memcpy(&negative_slope, dst->op_params, sizeof(float));

assert(dst->nb[0] == sizeof(float));
assert(src0->nb[0] == sizeof(float));

for (int i = 0; i < n; i++) {
ggml_vec_leaky_f32(nc,
ggml_vec_leaky_relu_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])));
(float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
}
}

static void ggml_compute_forward_leaky(
static void ggml_compute_forward_leaky_relu(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_leaky_f32(params, src0, dst);
ggml_compute_forward_leaky_relu_f32(params, src0, dst);
} break;
default:
{
@ -12158,6 +12203,7 @@ static void ggml_compute_forward_upscale_f32(
GGML_ASSERT(src0->nb[0] == sizeof(float));

const int ith = params->ith;
const int nth = params->nth;

GGML_TENSOR_UNARY_OP_LOCALS

@ -12165,16 +12211,17 @@ static void ggml_compute_forward_upscale_f32(

// TODO: optimize

for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = ith; i02 < ne02; i02++) {
for (int m = 0; m < dst->ne[1]; m++) {
int i01 = m / scale_factor;
for (int n = 0; n < dst->ne[0]; n++) {
int i00 = n / scale_factor;
for (int64_t i3 = 0; i3 < ne3; i3++) {
const int64_t i03 = i3;
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
const int64_t i02 = i2;
for (int64_t i1 = 0; i1 < ne1; i1++) {
const int64_t i01 = i1 / scale_factor;
for (int64_t i0 = 0; i0 < ne0; i0++) {
const int64_t i00 = i0 / scale_factor;

const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

float * y = (float *)((char *) dst->data + n * dst->nb[0] + m * dst->nb[1] + i02 * dst->nb[2] + i03 * dst->nb[3]);
float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);

*y = *x;
}
@ -12199,6 +12246,64 @@ static void ggml_compute_forward_upscale(
}
}

// ggml_compute_forward_pad

static void ggml_compute_forward_pad_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {

if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}

GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT( dst->nb[0] == sizeof(float));

const int ith = params->ith;
const int nth = params->nth;

GGML_TENSOR_UNARY_OP_LOCALS

float * dst_ptr = (float *) dst->data;

// TODO: optimize

for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
for (int64_t i0 = 0; i0 < ne0; ++i0) {
for (int64_t i3 = 0; i3 < ne3; ++i3) {
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;

const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);

if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
dst_ptr[dst_idx] = *src_ptr;
} else {
dst_ptr[dst_idx] = 0;
}
}
}
}
}
}

static void ggml_compute_forward_pad(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_pad_f32(params, src0, dst);
} break;
default:
{
GGML_ASSERT(false);
} break;
}
}

// ggml_compute_forward_argsort

static void ggml_compute_forward_argsort_f32(
@ -13406,10 +13511,6 @@ static void ggml_compute_forward_unary(
{
ggml_compute_forward_silu(params, src0, dst);
} break;
case GGML_UNARY_OP_LEAKY:
{
ggml_compute_forward_leaky(params, src0, dst);
} break;
default:
{
GGML_ASSERT(false);
@ -14191,10 +14292,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_upscale(params, tensor->src[0], tensor);
} break;
case GGML_OP_PAD:
{
ggml_compute_forward_pad(params, tensor->src[0], tensor);
} break;
case GGML_OP_ARGSORT:
{
ggml_compute_forward_argsort(params, tensor->src[0], tensor);
} break;
case GGML_OP_LEAKY_RELU:
{
ggml_compute_forward_leaky_relu(params, tensor->src[0], tensor);
} break;
case GGML_OP_FLASH_ATTN:
{
const int32_t t = ggml_get_op_params_i32(tensor, 0);
@ -15187,10 +15296,18 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_PAD:
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_ARGSORT:
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_LEAKY_RELU:
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_FLASH_ATTN:
{
struct ggml_tensor * flash_grad = NULL;
@ -15796,6 +15913,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_ARGMAX:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_LEAKY_RELU:
{
n_tasks = 1;
} break;
@ -15808,7 +15926,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_LEAKY:
{
n_tasks = 1;
} break;
@ -15927,6 +16044,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
{
n_tasks = n_threads;
} break;
case GGML_OP_PAD:
{
n_tasks = n_threads;
} break;
case GGML_OP_ARGSORT:
{
n_tasks = n_threads;

20
ggml.h
20
ggml.h
@ -423,7 +423,9 @@ extern "C" {
GGML_OP_POOL_1D,
GGML_OP_POOL_2D,
GGML_OP_UPSCALE, // nearest interpolate
GGML_OP_PAD,
GGML_OP_ARGSORT,
GGML_OP_LEAKY_RELU,

GGML_OP_FLASH_ATTN,
GGML_OP_FLASH_FF,
@ -463,7 +465,6 @@ extern "C" {
GGML_UNARY_OP_GELU,
GGML_UNARY_OP_GELU_QUICK,
GGML_UNARY_OP_SILU,
GGML_UNARY_OP_LEAKY,

GGML_UNARY_OP_COUNT,
};
@ -793,6 +794,9 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);

// dst = a
// view(dst, nb1, nb2, nb3, offset) += b
// return dst
GGML_API struct ggml_tensor * ggml_acc(
struct ggml_context * ctx,
struct ggml_tensor * a,
@ -957,15 +961,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_leaky(
GGML_API struct ggml_tensor * ggml_leaky_relu(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * a, float negative_slope, bool inplace);

GGML_API struct ggml_tensor * ggml_relu_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);

// TODO: double-check this computation is correct
GGML_API struct ggml_tensor * ggml_gelu(
struct ggml_context * ctx,
struct ggml_tensor * a);
@ -1551,6 +1554,15 @@ extern "C" {
struct ggml_tensor * a,
int scale_factor);

// pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
GGML_API struct ggml_tensor * ggml_pad(
struct ggml_context * ctx,
struct ggml_tensor * a,
int p0,
int p1,
int p2,
int p3);

// sort rows
enum ggml_sort_order {
GGML_SORT_ASC,

@ -234,6 +234,11 @@ static bool ggml_is_view_op(enum ggml_op op) {
return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;
}

enum test_mode {
MODE_TEST,
MODE_PERF,
};

struct test_case {
virtual ~test_case() {}

@ -268,7 +273,58 @@ struct test_case {
return size;
}

ggml_cgraph * gf = nullptr;

static const int sentinel_size = 1024;

test_mode mode;

std::vector<ggml_tensor *> sentinels;

void add_sentinel(ggml_context * ctx) {
if (mode == MODE_PERF) {
return;
}
ggml_tensor * sentinel = ::ggml_new_tensor_1d(ctx, GGML_TYPE_F32, sentinel_size);
ggml_format_name(sentinel, "sent_%zu", sentinels.size());
sentinels.push_back(sentinel);
}

// hijack ggml_new_tensor to add sentinels after each tensor to check for overflows in the backend

ggml_tensor * ggml_new_tensor(ggml_context * ctx, ggml_type type, int n_dims, const int64_t * ne) {
ggml_tensor * t = ::ggml_new_tensor(ctx, type, n_dims, ne);
add_sentinel(ctx);
return t;
}

ggml_tensor * ggml_new_tensor_1d(ggml_context * ctx, ggml_type type, int64_t ne0) {
ggml_tensor * t = ::ggml_new_tensor_1d(ctx, type, ne0);
add_sentinel(ctx);
return t;
}

ggml_tensor * ggml_new_tensor_2d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1) {
ggml_tensor * t = ::ggml_new_tensor_2d(ctx, type, ne0, ne1);
add_sentinel(ctx);
return t;
}

ggml_tensor * ggml_new_tensor_3d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2) {
ggml_tensor * t = ::ggml_new_tensor_3d(ctx, type, ne0, ne1, ne2);
add_sentinel(ctx);
return t;
}

ggml_tensor * ggml_new_tensor_4d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
ggml_tensor * t = ::ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
add_sentinel(ctx);
return t;
}

bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name) {
mode = MODE_TEST;

ggml_init_params params = {
/* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
/* .mem_base = */ NULL,
@ -276,6 +332,11 @@ struct test_case {
};
ggml_context * ctx = ggml_init(params);

gf = ggml_new_graph(ctx);

// pre-graph sentinel
add_sentinel(ctx);

ggml_tensor * out = build_graph(ctx);

if (op_name != nullptr && op_desc(out) != op_name) {
@ -296,13 +357,20 @@ struct test_case {
}
}

// post-graph sentinel
add_sentinel(ctx);

// allocate
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1);

// build graph
ggml_cgraph * gf = ggml_new_graph(ctx);
ggml_build_forward_expand(gf, out);

// add sentinels as graph nodes so that they are checked in the callback
for (ggml_tensor * sentinel : sentinels) {
gf->nodes[gf->n_nodes++] = sentinel;
}

// randomize tensors
initialize_tensors(ctx);

@ -318,9 +386,24 @@ struct test_case {
};

auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool {
callback_userdata * ud = (callback_userdata *) user_data;

if (t1->op == GGML_OP_NONE) {
// sentinels must be unchanged
std::vector<uint8_t> t1_data(ggml_nbytes(t1));
std::vector<uint8_t> t2_data(ggml_nbytes(t2));
ggml_backend_tensor_get(t1, t1_data.data(), 0, ggml_nbytes(t1));
ggml_backend_tensor_get(t2, t2_data.data(), 0, ggml_nbytes(t2));

if (memcmp(t1_data.data(), t2_data.data(), ggml_nbytes(t1)) != 0) {
printf("sentinel mismatch: %s ", t1->name);
ud->ok = false;
return true;
}
}

std::vector<float> f1 = tensor_to_float(t1);
std::vector<float> f2 = tensor_to_float(t2);
callback_userdata * ud = (callback_userdata *) user_data;

for (size_t i = 0; i < f1.size(); i++) {
// check for nans
@ -349,9 +432,10 @@ struct test_case {
if (err > ud->max_err) {
printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
//for (int i = 0; i < f1.size(); i++) {
// printf("(%f, %f) ", f1[i], f2[i]);
// printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
//}
//printf("\n");
//exit(1);
ud->ok = false;
}
return true;
@ -375,6 +459,8 @@ struct test_case {
}

bool eval_perf(ggml_backend_t backend, const char * op_name) {
mode = MODE_PERF;

static const size_t graph_nodes = 8192;

ggml_init_params params = {
@ -1135,6 +1221,118 @@ struct test_sum_rows : public test_case {
}
};

// GGML_OP_UPSCALE
struct test_upscale : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const int32_t scale_factor;

std::string vars() override {
return VARS_TO_STR3(type, ne, scale_factor);
}

test_upscale(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {512, 512, 3, 1},
int32_t scale_factor = 2)
: type(type), ne(ne), scale_factor(scale_factor) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_tensor * out = ggml_upscale(ctx, a, scale_factor);
return out;
}
};

// GGML_OP_GROUP_NORM
struct test_group_norm : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const int32_t num_groups;

std::string vars() override {
return VARS_TO_STR3(type, ne, num_groups);
}

test_group_norm(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 64, 320, 1},
int32_t num_groups = 32)
: type(type), ne(ne), num_groups(num_groups) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_tensor * out = ggml_group_norm(ctx, a, num_groups);
return out;
}
};

// GGML_OP_ACC
struct test_acc : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne_a;
const std::array<int64_t, 4> ne_b;

std::string vars() override {
return VARS_TO_STR3(type, ne_a, ne_b);
}

test_acc(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne_a = {1024, 577, 1, 1},
std::array<int64_t, 4> ne_b = {1024, 576, 1, 1})
: type(type), ne_a(ne_a), ne_b(ne_b) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], b->nb[1]);
return out;
}
};

// GGML_OP_PAD
struct test_pad : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne_a;
const int pad_0;
const int pad_1;

std::string vars() override {
return VARS_TO_STR4(type, ne_a, pad_0, pad_1);
}

test_pad(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne_a = {512, 512, 1, 1},
int pad_0 = 1, int pad_1 = 1)
: type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_tensor * out = ggml_pad(ctx, a, pad_0, pad_1, 0, 0);
return out;
}
};

// GGML_OP_LEAKY_RELU
struct test_leaky_relu : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne_a;
const float negative_slope;

std::string vars() override {
return VARS_TO_STR3(type, ne_a, negative_slope);
}

test_leaky_relu(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne_a = {10, 10, 10, 10},
float negative_slope = 0.1f)
: type(type), ne_a(ne_a), negative_slope(negative_slope) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_tensor * out = ggml_leaky_relu(ctx, a, negative_slope, true);
return out;
}
};

// Mixtral MOE
struct test_moe : public test_case {
const int n_experts;
@ -1219,11 +1417,6 @@ struct test_moe : public test_case {
}
};

enum test_mode {
MODE_TEST,
MODE_PERF,
};

static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
std::vector<std::unique_ptr<test_case>> test_cases;

@ -1372,12 +1565,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
}

test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, {10, 10, 10, 10}));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, {2, 1, 1, 1}));
test_cases.emplace_back(new test_sum_rows());
test_cases.emplace_back(new test_upscale());
test_cases.emplace_back(new test_group_norm());
test_cases.emplace_back(new test_acc());
test_cases.emplace_back(new test_pad());
test_cases.emplace_back(new test_leaky_relu());

#if !defined(__SANITIZE_THREAD__)
// FIXME: these tests use too much memory with thread sanitizer
test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 14336));
test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 8*1024));
//test_cases.emplace_back(new test_moe(8, 2, 8, 4096, 14336));
#endif

