diff --git a/ggml/src/ggml-cuda/argmax.cu b/ggml/src/ggml-cuda/argmax.cu index aab04eca7..5340eedc0 100644 --- a/ggml/src/ggml-cuda/argmax.cu +++ b/ggml/src/ggml-cuda/argmax.cu @@ -1,57 +1,69 @@ -#include "common.cuh" -#include "argmax.cuh" -#include "sum.cuh" - +#include #include -static __global__ void argmax_f32( - const float * x, int32_t * dst, const int64_t ncols, const int64_t nrows) { +#include "argmax.cuh" +#include "common.cuh" +#include "sum.cuh" - int argmax_thread = 0; - const int64_t row0 = (int64_t)blockIdx.x*WARP_SIZE; +static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) { + const int64_t row = blockIdx.x; -#pragma unroll - for (int64_t row1 = 0; row1 < WARP_SIZE; ++row1) { - const int64_t row = row0 + row1; + float maxval = -FLT_MAX; + int argmax = -1; + const float * rowx = x + row * ncols; - if (row >= nrows) { - break; + for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) { + const float val = rowx[col]; + if (val > maxval) { + maxval = val; + argmax = col; } - - float maxval = -FLT_MAX; - int argmax = -1; - - for (int32_t col = threadIdx.x; col < ncols; col += WARP_SIZE) { - const float val = x[row*ncols + col]; - const int bigger = val > maxval; - const int not_bigger = bigger ^ 0x00000001; - - maxval = maxval*not_bigger + val*bigger; - argmax = argmax*not_bigger + col*bigger; - } - -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE); - const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE); - const int bigger = val > maxval; - const int not_bigger = bigger ^ 0x00000001; - - maxval = maxval*not_bigger + val*bigger; - argmax = argmax*not_bigger + col*bigger; - } - - const int store = row1 == threadIdx.x; - argmax_thread += store*argmax; } - const int row = row0 + threadIdx.x; - - if (row >= nrows) { - return; +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); + const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); + if (val > maxval) { + maxval = val; + argmax = col; + } } - dst[row] = argmax_thread; + const int n_warps = blockDim.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + if (n_warps > 1) { + constexpr int max_warps = 1024 / WARP_SIZE; + __shared__ float shared_maxval[max_warps]; + __shared__ int shared_argmax[max_warps]; + if (lane_id == 0) { + shared_maxval[warp_id] = maxval; + shared_argmax[warp_id] = argmax; + } + + __syncthreads(); + + if (warp_id == 0) { + if (lane_id < n_warps) { + maxval = shared_maxval[lane_id]; + argmax = shared_argmax[lane_id]; + } +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); + const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); + if (val > maxval) { + maxval = val; + argmax = col; + } + } + } + } + + if (warp_id == 0 && lane_id == 0) { + dst[row] = argmax; + } } void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -70,10 +82,10 @@ void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { cudaStream_t stream = ctx.stream(); - const int64_t num_blocks = (nrows + WARP_SIZE - 1) / WARP_SIZE; - - const dim3 blocks_dim(WARP_SIZE, 1, 1); + const int64_t num_blocks = nrows; + const int64_t num_threads = std::min(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE); + const dim3 blocks_dim(num_threads, 1, 1); const dim3 blocks_num(num_blocks, 1, 1); - argmax_f32<<>>(src0_d, dst_d, ne00, nrows); + argmax_f32<<>>(src0_d, dst_d, ne00); } diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index e146c691c..b0dd16066 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -180,8 +180,8 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) { return __reduce_add_sync(0xffffffff, x); #else #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - x += __shfl_xor_sync(0xffffffff, x, mask, 32); + for (int offset = 16; offset > 0; offset >>= 1) { + x += __shfl_xor_sync(0xffffffff, x, offset, 32); } return x; #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE @@ -189,17 +189,17 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) { static __device__ __forceinline__ float warp_reduce_sum(float x) { #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - x += __shfl_xor_sync(0xffffffff, x, mask, 32); + for (int offset = 16; offset > 0; offset >>= 1) { + x += __shfl_xor_sync(0xffffffff, x, offset, 32); } return x; } static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32); - a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32); + for (int offset = 16; offset > 0; offset >>= 1) { + a.x += __shfl_xor_sync(0xffffffff, a.x, offset, 32); + a.y += __shfl_xor_sync(0xffffffff, a.y, offset, 32); } return a; } @@ -209,16 +209,16 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32); + for (int offset = 16; offset > 0; offset >>= 1) { + const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32); reinterpret_cast(a.x) += __low2half(a_other); reinterpret_cast(a.y) += __high2half(a_other); } return a; #else #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32)); + for (int offset = 16; offset > 0; offset >>= 1) { + a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32)); } return a; #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) @@ -231,8 +231,8 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + for (int offset = 16; offset > 0; offset >>= 1) { + x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32)); } return x; } @@ -275,8 +275,8 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + for (int offset = 16; offset > 0; offset >>= 1) { + x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32)); } return x; #else diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 45408ce86..1702e4ce2 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -69,8 +69,8 @@ static __global__ void quantize_mmq_q8_1( // Exchange max. abs. value between vals_per_scale/4 threads. #pragma unroll - for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) { - amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE)); + for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE)); } float sum; @@ -79,8 +79,8 @@ static __global__ void quantize_mmq_q8_1( // Exchange calculate sum across vals_per_sum/4 threads. #pragma unroll - for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) { - sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE); + for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) { + sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE); } } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 719d75c70..78e7874de 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2255,6 +2255,7 @@ struct ggml_tensor * ggml_argmax( struct ggml_context * ctx, struct ggml_tensor * a) { GGML_ASSERT(ggml_is_matrix(a)); + GGML_ASSERT(a->ne[0] <= INT32_MAX); struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[1]); @@ -4138,6 +4139,7 @@ struct ggml_tensor * ggml_argsort( struct ggml_context * ctx, struct ggml_tensor * a, enum ggml_sort_order order) { + GGML_ASSERT(a->ne[0] <= INT32_MAX); struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne); ggml_set_op_params_i32(result, 0, (int32_t) order); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 37342c156..b2b570524 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1154,6 +1154,26 @@ struct test_argmax : public test_case { return out; } + void initialize_tensors(ggml_context * ctx) override { + std::random_device rd; + std::default_random_engine rng(rd()); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_F32) { + // initialize with unique values to avoid ties + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = i; + } + std::shuffle(data.begin(), data.end(), rng); + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float)); + } + } else { + init_tensor_uniform(t); + } + } + } + double max_nmse_err() override { return 0.0; } @@ -3440,6 +3460,11 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1)); test_cases.emplace_back(new test_argmax()); + test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1})); + test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1})); + test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1})); + test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1})); + test_cases.emplace_back(new test_count_equal()); for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1 @@ -3830,6 +3855,10 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, 1.0f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, 1.0f, 0.0f)); + test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1})); + test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1})); + test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1})); + for (int bs : {1, 512}) { for (ggml_type type_a : all_types) { for (ggml_type type_b : {GGML_TYPE_F32}) {