From fd08255d0dea6625596c0367ee0a11d195f36762 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Tue, 4 Feb 2025 22:21:42 +0100
Subject: [PATCH] CUDA: non-contiguous (RMS) norm support (#11659)

* CUDA: non-contiguous (RMS) norm support

---------

Co-authored-by: Georgi Gerganov
---
 ggml/src/ggml-cuda/ggml-cuda.cu      |  4 ++
 ggml/src/ggml-cuda/norm.cu           | 89 ++++++++++++++++++----------
 ggml/src/ggml-metal/ggml-metal.m     |  5 +-
 ggml/src/ggml-vulkan/ggml-vulkan.cpp |  2 +
 src/llama.cpp                        |  6 +-
 tests/test-backend-ops.cpp           | 38 ++++++++----
 6 files changed, 97 insertions(+), 47 deletions(-)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index bda10aec1..70a598099 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -38,6 +38,7 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv6.cuh"
 #include "ggml-cuda/gla.cuh"
+#include "ggml.h"
 
 #include <algorithm>
 #include <array>
@@ -3139,6 +3140,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             break;
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+            return true;
         case GGML_OP_RMS_NORM_BACK:
             return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
             break;
@@ -3181,7 +3183,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
+            return true;
         case GGML_OP_GROUP_NORM:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_ARANGE:
diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu
index d991ec972..f127616ed 100644
--- a/ggml/src/ggml-cuda/norm.cu
+++ b/ggml/src/ggml-cuda/norm.cu
@@ -1,12 +1,20 @@
 #include "norm.cuh"
+#include <cstdint>
 
 template <int block_size>
-static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
+static __global__ void norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;
 
-    x   += int64_t(row)*ncols;
-    dst += int64_t(row)*ncols;
+    const int row       = blockIdx.x;
+    const int channel   = blockIdx.y;
+    const int sample    = blockIdx.z;
+    const int tid       = threadIdx.x;
+
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
 
     float2 mean_var = make_float2(0.0f, 0.0f);
 
@@ -97,12 +105,19 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
     }
 }
 
 template <int block_size>
-static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
+static __global__ void rms_norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;
 
-    x   += int64_t(row)*ncols;
-    dst += int64_t(row)*ncols;
+    const int row       = blockIdx.x;
+    const int channel   = blockIdx.y;
+    const int sample    = blockIdx.z;
+    const int tid       = threadIdx.x;
+
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
 
     float tmp = 0.0f; // partial sum for thread in warp
@@ -186,13 +201,16 @@ static __global__ void rms_norm_back_f32(
     }
 }
 
-static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+static void norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
@@ -207,13 +225,16 @@ static void group_norm_f32_cuda(
     }
 }
 
-static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+static void rms_norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
@@ -229,23 +250,26 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
 
 void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
     GGML_ASSERT(eps >= 0.0f);
 
-    norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
 }
 
 void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -254,8 +278,6 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -271,23 +293,26 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
     GGML_ASSERT(eps >= 0.0f);
 
-    rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
 }
 
 void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 9605914ff..0a264be37 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -1206,10 +1206,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_GROUP_NORM:
             return has_simdgroup_reduction;
         case GGML_OP_RMS_NORM:
-            return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
+            return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ARGMAX:
-        case GGML_OP_NORM:
             return true;
+        case GGML_OP_NORM:
+            return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_ROPE:
             {
                 const int mode = ((const int32_t *) op->op_params)[2];
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 9ca3959ab..48ac489a6 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -8182,9 +8182,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
+            return true;
         case GGML_OP_NORM:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_RMS_NORM:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_ADD:
         case GGML_OP_ACC:
         case GGML_OP_MUL:
diff --git a/src/llama.cpp b/src/llama.cpp
index 5760017e0..aae3c69b5 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -4610,7 +4610,8 @@ struct llm_build_context {
                 ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
         cb(k_pe, "k_pe", il);
 
-        kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+        // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
+        kv_compressed = ggml_cont(ctx0, kv_compressed);
         kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
                 model.layers[il].attn_kv_a_norm, NULL,
                 LLM_NORM_RMS, cb, il);
@@ -6464,7 +6465,8 @@ struct llm_build_context {
                 ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
         cb(k_pe, "k_pe", il);
 
-        kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+        // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
+        kv_compressed = ggml_cont(ctx0, kv_compressed);
         kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
                 model.layers[il].attn_kv_a_norm, NULL,
                 LLM_NORM_RMS, cb, il);
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 4c5c4dd9c..1bfd41254 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1674,21 +1674,28 @@ struct test_silu_back : public test_case {
 struct test_norm : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
-    float eps;
+    const bool v; // whether a is a non-contiguous view
+    const float eps;
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne, eps);
+        return VARS_TO_STR4(type, ne, v, eps);
     }
 
     test_norm(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {64, 5, 4, 3},
+            bool v = false,
             float eps = 1e-6f)
-        : type(type), ne(ne), eps(eps) {}
+        : type(type), ne(ne), v(v), eps(eps) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_name(a, "a");
 
+        if (v) {
+            a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view of a");
+        }
+
         ggml_tensor * out = ggml_norm(ctx, a, eps);
         ggml_set_name(out, "out");
 
@@ -1700,22 +1707,29 @@ struct test_norm : public test_case {
 struct test_rms_norm : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
-    float eps;
+    const bool v; // whether a is a non-contiguous view
+    const float eps;
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne, eps);
+        return VARS_TO_STR4(type, ne, v, eps);
     }
 
     test_rms_norm(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {64, 5, 4, 3},
+            bool v = false,
            float eps = 1e-6f)
-        : type(type), ne(ne), eps(eps) {}
+        : type(type), ne(ne), v(v), eps(eps) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_param(ctx, a);
         ggml_set_name(a, "a");
 
+        if (v) {
+            a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view of a");
+        }
+
         ggml_tensor * out = ggml_rms_norm(ctx, a, eps);
         ggml_set_name(out, "out");
 
@@ -1741,7 +1755,7 @@ struct test_rms_norm : public test_case {
 struct test_rms_norm_back : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
-    float eps;
+    const float eps;
 
     std::string vars() override {
         return VARS_TO_STR3(type, ne, eps);
@@ -2919,7 +2933,7 @@ struct test_group_norm : public test_case {
     const float eps;
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne, num_groups);
+        return VARS_TO_STR4(type, ne, num_groups, eps);
     }
 
     test_group_norm(ggml_type type = GGML_TYPE_F32,
@@ -3964,9 +3978,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_scale());
     test_cases.emplace_back(new test_silu_back());
 
-    for (float eps : {0.0f, 1e-7f, 1e-4f, 1e-1f}) {
-        test_cases.emplace_back(new test_norm    (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
-        test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+    for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
+        for (bool v : {false, true}) {
+            test_cases.emplace_back(new test_norm    (GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
+            test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
+        }
         test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
     }
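
For reference, the core of the change above is that the (RMS) norm kernels now address a possibly non-contiguous source through per-dimension element strides (the byte strides nb01/nb02/nb03 divided by the element size, with nb00 asserted to equal the element size) and launch one block per (row, channel, sample), while the destination stays contiguous. The standalone CUDA sketch below illustrates only that indexing scheme; it is not part of the patch, and the kernel name copy_rows_strided is made up for illustration.

// Illustrative sketch (not from the patch): strided-source, contiguous-destination
// addressing as used by norm_f32/rms_norm_f32 above.
#include <cstdint>

__global__ void copy_rows_strided(
        const float * x, float * dst, const int ncols,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample) {
    const int nrows     = gridDim.x; // grid is (nrows, nchannels, nsamples)
    const int nchannels = gridDim.y;

    const int row     = blockIdx.x;
    const int channel = blockIdx.y;
    const int sample  = blockIdx.z;

    // Source: arbitrary element strides per dimension (innermost dim assumed contiguous).
    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
    // Destination: dense row-major layout, same expression as in the kernels above.
    dst += ((sample*nchannels + channel)*nrows + row)*ncols;

    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        dst[col] = x[col];
    }
}

On the host side the element strides are recovered as nb01/nb02/nb03 divided by ggml_type_size(src0->type), which is what ggml_cuda_op_norm and ggml_cuda_op_rms_norm do in the hunks above before launching over a (ne01, ne02, ne03) grid.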