CUDA: non-contiguous (RMS) norm support (#11659)

* CUDA: non-contiguous (RMS) norm support

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

commit fd08255d0d (parent 3ec9fd4b77)
@@ -38,6 +38,7 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv6.cuh"
 #include "ggml-cuda/gla.cuh"
+#include "ggml.h"
 
 #include <algorithm>
 #include <array>
@@ -3139,6 +3140,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             break;
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+            return true;
         case GGML_OP_RMS_NORM_BACK:
             return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
             break;
@@ -3181,7 +3183,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
+            return true;
         case GGML_OP_GROUP_NORM:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_ARANGE:
@@ -1,12 +1,20 @@
 #include "norm.cuh"
+#include <cstdint>
 
 template <int block_size>
-static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
+static __global__ void norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows = gridDim.x;
+    const int nchannels = gridDim.y;
 
-    x += int64_t(row)*ncols;
-    dst += int64_t(row)*ncols;
+    const int row = blockIdx.x;
+    const int channel = blockIdx.y;
+    const int sample = blockIdx.z;
+    const int tid = threadIdx.x;
+
+    x += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
 
     float2 mean_var = make_float2(0.0f, 0.0f);
 
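Note: as a sanity check on the new indexing, the per-block work above corresponds to the following CPU sketch (illustrative only, not part of the commit; names such as norm_ref are made up). Each (sample, channel, row) triple is read through the element strides passed to the kernel and written to a contiguous destination:

#include <cmath>
#include <cstdint>

// Hedged CPU reference for the strided layer norm: columns are assumed
// contiguous (the wrapper asserts nb00 == type size); rows, channels and
// samples may be strided arbitrarily.
static void norm_ref(const float * x, float * dst,
                     int64_t ncols, int64_t nrows, int64_t nchannels, int64_t nsamples,
                     int64_t stride_row, int64_t stride_channel, int64_t stride_sample, float eps) {
    for (int64_t i03 = 0; i03 < nsamples; i03++) {
        for (int64_t i02 = 0; i02 < nchannels; i02++) {
            for (int64_t i01 = 0; i01 < nrows; i01++) {
                const float * xr = x   + i03*stride_sample + i02*stride_channel + i01*stride_row;
                float       * dr = dst + ((i03*nchannels + i02)*nrows + i01)*ncols;

                float mean = 0.0f;
                float var  = 0.0f;
                for (int64_t i00 = 0; i00 < ncols; i00++) {
                    mean += xr[i00];
                    var  += xr[i00]*xr[i00];
                }
                mean /= ncols;
                var   = var/ncols - mean*mean;

                const float inv_std = 1.0f/sqrtf(var + eps);
                for (int64_t i00 = 0; i00 < ncols; i00++) {
                    dr[i00] = (xr[i00] - mean)*inv_std;
                }
            }
        }
    }
}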
@@ -97,12 +105,19 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
 }
 
 template <int block_size>
-static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
+static __global__ void rms_norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows = gridDim.x;
+    const int nchannels = gridDim.y;
 
-    x += int64_t(row)*ncols;
-    dst += int64_t(row)*ncols;
+    const int row = blockIdx.x;
+    const int channel = blockIdx.y;
+    const int sample = blockIdx.z;
+    const int tid = threadIdx.x;
+
+    x += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
 
     float tmp = 0.0f; // partial sum for thread in warp
 
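Note: rms_norm_f32 uses the same strided addressing but, unlike norm_f32, it does not subtract the mean; each row is only rescaled by its root mean square. A hedged CPU sketch of the per-row math (illustrative, not the commit's code):

#include <cmath>
#include <cstdint>

// Hedged reference for one row of ncols contiguous floats:
// dst = x / sqrt(mean(x^2) + eps), with no centering step.
static void rms_norm_row_ref(const float * xr, float * dr, int64_t ncols, float eps) {
    float sumsq = 0.0f;
    for (int64_t i = 0; i < ncols; i++) {
        sumsq += xr[i]*xr[i];
    }
    const float scale = 1.0f/sqrtf(sumsq/ncols + eps);
    for (int64_t i = 0; i < ncols; i++) {
        dr[i] = scale*xr[i];
    }
}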
@@ -186,13 +201,16 @@ static __global__ void rms_norm_back_f32(
     }
 }
 
-static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+static void norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
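Note: with the new launchers the grid covers rows, channels and samples directly, one block per row, instead of a 1D grid over flattened rows. A hedged worked example with made-up dimensions:

#include <cstdio>

// Hedged sketch of the launch geometry for a tensor of shape
// {ne00, ne01, ne02, ne03} = {64, 5, 4, 3} (illustrative numbers).
int main() {
    const int ne00 = 64, ne01 = 5, ne02 = 4, ne03 = 3;
    const int warp_size = 32;                       // WARP_SIZE on NVIDIA GPUs
    const int block_x   = ne00 < 1024 ? warp_size : 1024;
    // blocks_num = (nrows, nchannels, nsamples); each block reduces ne00 columns
    printf("grid = (%d, %d, %d), block = (%d, 1, 1)\n", ne01, ne02, ne03, block_x);
    return 0;
}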
@@ -207,13 +225,16 @@ static void group_norm_f32_cuda(
     }
 }
 
-static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+static void rms_norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
@@ -229,23 +250,26 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
 
 void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);
 
-    norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
 }
 
 void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
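Note: the wrapper passes element strides, not byte strides, to the kernels: ggml stores byte strides nb[i], so they are divided by the type size after asserting that dim 0 is densely packed. A hedged standalone sketch with a hypothetical non-contiguous view (all numbers illustrative):

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ts = sizeof(float);              // ggml_type_size(GGML_TYPE_F32)
    // byte strides of a hypothetical F32 view that keeps every other row of a
    // contiguous {64, 10, 4, 3} parent, giving a logical shape of {64, 5, 4, 3}
    const int64_t nb00 = 4;
    const int64_t nb01 = 64*4*2;
    const int64_t nb02 = 64*10*4;
    const int64_t nb03 = 64*10*4*4;

    assert(nb00 == ts);                            // columns must stay contiguous
    const int64_t s01 = nb01 / ts;                 // elements between consecutive rows
    const int64_t s02 = nb02 / ts;                 // elements between channels
    const int64_t s03 = nb03 / ts;                 // elements between samples
    printf("s01=%lld s02=%lld s03=%lld\n", (long long) s01, (long long) s02, (long long) s03);
    return 0;
}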
@@ -254,8 +278,6 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -271,23 +293,26 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
     GGML_ASSERT(eps >= 0.0f);
 
-    rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
 }
 
 void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -1206,10 +1206,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_GROUP_NORM:
            return has_simdgroup_reduction;
        case GGML_OP_RMS_NORM:
-            return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
+            return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
        case GGML_OP_ARGMAX:
-        case GGML_OP_NORM:
            return true;
+        case GGML_OP_NORM:
+            return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
        case GGML_OP_ROPE:
            {
                const int mode = ((const int32_t *) op->op_params)[2];
@@ -8182,9 +8182,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
+            return true;
         case GGML_OP_NORM:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_RMS_NORM:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_ADD:
         case GGML_OP_ACC:
         case GGML_OP_MUL:
@@ -4610,7 +4610,8 @@ struct llm_build_context {
                         ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
                 cb(k_pe, "k_pe", il);
 
-                kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+                // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
+                kv_compressed = ggml_cont(ctx0, kv_compressed);
                 kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
                         model.layers[il].attn_kv_a_norm, NULL,
                         LLM_NORM_RMS, cb, il);
@@ -6464,7 +6465,8 @@ struct llm_build_context {
                         ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
                 cb(k_pe, "k_pe", il);
 
-                kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+                // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
+                kv_compressed = ggml_cont(ctx0, kv_compressed);
                 kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
                         model.layers[il].attn_kv_a_norm, NULL,
                         LLM_NORM_RMS, cb, il);
@@ -1674,21 +1674,28 @@ struct test_silu_back : public test_case {
 struct test_norm : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
-    float eps;
+    const bool v; // whether a is a non-contiguous view
+    const float eps;
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne, eps);
+        return VARS_TO_STR4(type, ne, v, eps);
     }
 
     test_norm(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {64, 5, 4, 3},
+            bool v = false,
             float eps = 1e-6f)
-        : type(type), ne(ne), eps(eps) {}
+        : type(type), ne(ne), v(v), eps(eps) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_name(a, "a");
 
+        if (v) {
+            a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view of a");
+        }
+
         ggml_tensor * out = ggml_norm(ctx, a, eps);
         ggml_set_name(out, "out");
 
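Note: the v == true case exercises exactly the kind of input the CUDA change enables: a view that halves every dimension while keeping the parent's byte strides, so it is not contiguous. A hedged standalone sketch using the ggml C API (sizes and the 16 MiB pool are arbitrary):

#include <cstdio>
#include "ggml.h"

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 5, 4, 3);

    // half of every dimension, but the parent's byte strides nb[1..3] are reused,
    // so the resulting view is non-contiguous
    struct ggml_tensor * v = ggml_view_4d(ctx, a,
            a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2,
            a->nb[1], a->nb[2], a->nb[3], 0);

    printf("a contiguous: %d, view contiguous: %d\n", ggml_is_contiguous(a), ggml_is_contiguous(v));

    ggml_free(ctx);
    return 0;
}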
@@ -1700,22 +1707,29 @@ struct test_norm : public test_case {
 struct test_rms_norm : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
-    float eps;
+    const bool v; // whether a is a non-contiguous view
+    const float eps;
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne, eps);
+        return VARS_TO_STR4(type, ne, v, eps);
     }
 
     test_rms_norm(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {64, 5, 4, 3},
+            bool v = false,
             float eps = 1e-6f)
-        : type(type), ne(ne), eps(eps) {}
+        : type(type), ne(ne), v(v), eps(eps) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_param(ctx, a);
         ggml_set_name(a, "a");
 
+        if (v) {
+            a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view of a");
+        }
+
         ggml_tensor * out = ggml_rms_norm(ctx, a, eps);
         ggml_set_name(out, "out");
 
@@ -1741,7 +1755,7 @@ struct test_rms_norm : public test_case {
 struct test_rms_norm_back : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
-    float eps;
+    const float eps;
 
     std::string vars() override {
         return VARS_TO_STR3(type, ne, eps);
@@ -2919,7 +2933,7 @@ struct test_group_norm : public test_case {
     const float eps;
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne, num_groups);
+        return VARS_TO_STR4(type, ne, num_groups, eps);
     }
 
     test_group_norm(ggml_type type = GGML_TYPE_F32,
@@ -3964,9 +3978,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_scale());
     test_cases.emplace_back(new test_silu_back());
 
-    for (float eps : {0.0f, 1e-7f, 1e-4f, 1e-1f}) {
-        test_cases.emplace_back(new test_norm    (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
-        test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+    for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
+        for (bool v : {false, true}) {
+            test_cases.emplace_back(new test_norm    (GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
+            test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
+        }
         test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
     }
 