diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index e0ea890b1..e26260a35 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -116,6 +116,7 @@
 #include "ggml.h"
 #include "ggml-backend-impl.h"
 
+#define CC_PASCAL     600
 #define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA      700
 #define CC_OFFSET_AMD 1000000
@@ -556,11 +557,12 @@ static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
 
 struct cuda_device_capabilities {
     int     cc;                 // compute capability
+    size_t  smpb;               // max. shared memory per block
     bool    vmm;                // virtual memory support
     size_t  vmm_granularity;    // granularity of virtual memory
 };
 
-static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
+static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, 0, false, 0} };
 
 static void * g_scratch_buffer = nullptr;
 static size_t g_scratch_size = 0; // disabled by default
@@ -593,6 +595,19 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
     return a;
 }
 
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+    (void) a;
+    bad_arch();
+#else
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+}
+
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
@@ -601,6 +616,19 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
     return x;
 }
 
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+    (void) x;
+    bad_arch();
+#else
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+}
+
 static __device__ __forceinline__ float op_repeat(const float a, const float b) {
     return b;
     GGML_UNUSED(a);
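The two overloads added above use the standard XOR-butterfly warp reduction: each __shfl_xor_sync step combines lanes whose IDs differ by `mask`, so after log2(32) = 5 steps every lane holds the warp-wide result, processing two half values per instruction via half2. A minimal standalone sketch of the same pattern, not part of the patch (warp_sum_half2 and check_warp_sum are names made up for this example; compile with e.g. nvcc -arch=sm_60):

#include <cstdio>
#include <cuda_fp16.h>

// same butterfly pattern as the warp_reduce_sum(half2) added in the patch
static __device__ __forceinline__ half2 warp_sum_half2(half2 a) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
    }
    return a;
}

static __global__ void check_warp_sum(float * out) {
    // lane i contributes (i, 2*i); the warp totals are (496, 992)
    const half2 v   = __floats2half2_rn((float) threadIdx.x, 2.0f*threadIdx.x);
    const half2 sum = warp_sum_half2(v);
    if (threadIdx.x == 0) {
        out[0] = __low2float(sum);
        out[1] = __high2float(sum);
    }
}

int main() {
    float * d_out = nullptr;
    cudaMalloc(&d_out, 2*sizeof(float));
    check_warp_sum<<<1, 32>>>(d_out);
    float h_out[2];
    cudaMemcpy(h_out, d_out, 2*sizeof(float), cudaMemcpyDeviceToHost);
    printf("sum.x = %.0f (expected 496), sum.y = %.0f (expected 992)\n", h_out[0], h_out[1]);
    cudaFree(d_out);
    return 0;
}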
@@ -5385,75 +5413,233 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }
 
-static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
+static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+    const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
+    const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
+
     const int tid  = threadIdx.x;
     const int rowx = blockIdx.x;
     const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
 
-    const int block_size = blockDim.x;
+    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
 
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
+    extern __shared__ half data_soft_max_f16[];
+    half  * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
+    // (shared memory) buffer to cache values between iterations:
+    half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) (dst + rowx*ncols_data);
+    // if the buffer is larger than max. shared memory per block, use dst as temp. buffer instead
+    // in that case col_smem == col_data must be enforced to avoid race conditions
 
-    float max_val = -INFINITY;
+    half2 max_val = make_half2(-INFINITY, -INFINITY);
 
-    for (int col = tid; col < ncols; col += block_size) {
-        const int ix = rowx*ncols + col;
-        const int iy = rowy*ncols + col;
-        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
+#pragma unroll
+    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
+        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
+        const int col_smem = vals_smem ? col0 + tid : col_data;
+
+        const int ix = rowx*ncols_data + col_data;
+        const int iy = rowy*ncols_data + col_data;
+
+        half2 val;
+        if (need_check && col_data + 0 >= ncols_data) {
+            val.x = -INFINITY;
+        } else {
+            val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
+        }
+        if (need_check && col_data + WARP_SIZE >= ncols_data) {
+            val.y = -INFINITY;
+        } else {
+            val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
+        }
+        if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
+            vals[col_smem] = val;
+        }
+        max_val = __hmax2(max_val, val);
     }
 
     // find the max value in the block
     max_val = warp_reduce_max(max_val);
     if (block_size > WARP_SIZE) {
         if (warp_id == 0) {
-            buf[lane_id] = -INFINITY;
+            buf_iw[lane_id] = -INFINITY;
         }
         __syncthreads();
 
         if (lane_id == 0) {
-            buf[warp_id] = max_val;
+            buf_iw[warp_id] = __hmax(max_val.x, max_val.y);
         }
         __syncthreads();
 
-        max_val = buf[lane_id];
+        max_val = __half2half2(buf_iw[lane_id]);
         max_val = warp_reduce_max(max_val);
+    } else {
+        max_val = __half2half2(__hmax(max_val.x, max_val.y));
     }
 
-    float tmp = 0.f;
+    half2 tmp = make_half2(0.0f, 0.0f); // partial sums
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
+        const int col_smem = vals_smem ? col0 + tid : 2*col0 + 2*warp_id*WARP_SIZE + lane_id;
+
+        if (ncols_template == 0 && col_smem >= (vals_smem ? ncols_smem : ncols_data)) {
+            break;
+        }
+
+        const half2 val = h2exp(vals[col_smem] - max_val);
 
-    for (int col = tid; col < ncols; col += block_size) {
-        const int ix = rowx*ncols + col;
-        const int iy = rowy*ncols + col;
-        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
         tmp += val;
-        dst[ix] = val;
+        vals[col_smem] = val;
     }
 
     // find the sum of exps in the block
     tmp = warp_reduce_sum(tmp);
     if (block_size > WARP_SIZE) {
         if (warp_id == 0) {
-            buf[lane_id] = 0.f;
+            buf_iw[lane_id] = 0.0f;
         }
         __syncthreads();
 
         if (lane_id == 0) {
-            buf[warp_id] = tmp;
+            buf_iw[warp_id] = tmp.x + tmp.y;
         }
         __syncthreads();
 
-        tmp = buf[lane_id];
+        tmp = __half2half2(buf_iw[lane_id]);
+        tmp = warp_reduce_sum(tmp);
+    } else {
+        tmp = __half2half2(tmp.x + tmp.y);
+    }
+
+    const half2 inv_sum = make_half2(1.0f, 1.0f) / tmp;
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
+        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
+        const int col_smem = vals_smem ? col0 + tid : col_data;
+
+        const int idst = rowx*ncols_data + col_data;
+        const half2 result = vals[col_smem] * inv_sum;
+
+        if (need_check && col_data + 0 >= ncols_data) {
+            return;
+        }
+        dst[idst] = result.x;
+
+        if (need_check && col_data + WARP_SIZE >= ncols_data) {
+            return;
+        }
+
+        dst[idst + WARP_SIZE] = result.y;
+    }
+#else
+    (void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
+    bad_arch();
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template>
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
+    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
+
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    extern __shared__ float data_soft_max_f32[];
+    float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
+    // shared memory buffer to cache values between iterations:
+    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;
+
+    float max_val = -INFINITY;
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+
+        const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
+        vals[col] = val;
+        max_val = max(max_val, val);
+    }
+
+    // find the max value in the block
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf_iw[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = max_val;
+        }
+        __syncthreads();
+
+        max_val = buf_iw[lane_id];
+        max_val = warp_reduce_max(max_val);
+    }
+
+    float tmp = 0.0f; // partial sum
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const float val = expf(vals[col] - max_val);
+        tmp += val;
+        vals[col] = val;
+    }
+
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf_iw[lane_id] = 0.0f;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = tmp;
+        }
+        __syncthreads();
+
+        tmp = buf_iw[lane_id];
         tmp = warp_reduce_sum(tmp);
     }
 
-    const float inv_tmp = 1.f / tmp;
+    const float inv_sum = 1.0f / tmp;
 
-    for (int col = tid; col < ncols; col += block_size) {
-        const int i = rowx*ncols + col;
-        dst[i] *= inv_tmp;
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            return;
+        }
+
+        const int idst = rowx*ncols + col;
+        dst[idst] = vals[col] * inv_sum;
     }
 }
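Both kernels implement the usual three-pass, numerically stable softmax over each row of x[i]*scale + y[i] (row max, exponentials relative to that max, normalization). What changes relative to the removed kernel is that the scaled inputs are cached in shared memory (or in dst, when the row does not fit) instead of being re-read from global memory on every pass. As a host-side cross-check, here is a scalar reference of what one row computes; it is not part of the patch and soft_max_reference is a hypothetical name:

#include <algorithm>
#include <cmath>

static void soft_max_reference(const float * x, const float * mask, float * dst, const int ncols, const float scale) {
    // pass 1: row-wise maximum of the scaled, optionally masked inputs
    float max_val = -INFINITY;
    for (int i = 0; i < ncols; ++i) {
        max_val = std::max(max_val, x[i]*scale + (mask ? mask[i] : 0.0f));
    }

    // pass 2: exponentials shifted by the maximum, accumulated into the denominator
    float sum = 0.0f;
    for (int i = 0; i < ncols; ++i) {
        dst[i] = std::exp((x[i]*scale + (mask ? mask[i] : 0.0f)) - max_val);
        sum += dst[i];
    }

    // pass 3: normalization
    const float inv_sum = 1.0f/sum;
    for (int i = 0; i < ncols; ++i) {
        dst[i] *= inv_sum;
    }
}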
@@ -6752,12 +6938,90 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
+static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth,     1, 1);
+    const dim3 block_nums(nrows_x, 1, 1);
+    const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);
+    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+    if (shmem <= g_device_caps[g_main_device].smpb) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f16<true, 32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 64:
+                soft_max_f16<true, 64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 128:
+                soft_max_f16<true, 128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 256:
+                soft_max_f16<true, 256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 512:
+                soft_max_f16<true, 512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 1024:
+                soft_max_f16<true, 1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 2048:
+                soft_max_f16<true, 2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 4096:
+                soft_max_f16<true, 4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            default:
+                soft_max_f16<true, 0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+        }
+    } else {
+        const size_t shmem_low = WARP_SIZE*sizeof(half);
+        soft_max_f16<false, 0, 0, true><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+    }
+}
+
 static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth,     1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+    const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
+    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+    if (shmem < g_device_caps[g_main_device].smpb) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 64:
+                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 128:
+                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 256:
+                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 512:
+                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 1024:
+                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 2048:
+                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 4096:
+                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            default:
+                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+        }
+    } else {
+        const size_t shmem_low = WARP_SIZE*sizeof(float);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+    }
 }
 
 static void im2col_f32_f16_cuda(const float* x, half* dst,
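The shmem expressions above reserve one padded row of cached values plus WARP_SIZE entries of inter-warp scratch (buf_iw). Against the common 48 KiB default per-block limit, the f32 variant stops fitting at roughly 12K columns while the f16 variant reaches about twice that, which is why both launchers keep the vals_smem == false fallback that reuses dst as the cache. A quick standalone check of that arithmetic, with GGML_PAD written to match its definition in ggml.h (round x up to a multiple of the power-of-two n):

#include <cstdio>

#define WARP_SIZE 32
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main() {
    const int test_cols[] = {32, 4096, 16384, 32768};
    for (int ncols : test_cols) {
        // sizeof(short) stands in for sizeof(half) in host-only code
        const size_t shmem_f16 = (GGML_PAD(ncols, 2*WARP_SIZE) + WARP_SIZE)*sizeof(short);
        const size_t shmem_f32 = (GGML_PAD(ncols,   WARP_SIZE) + WARP_SIZE)*sizeof(float);
        printf("ncols = %5d: f16 needs %6zu B, f32 needs %6zu B of shared memory\n",
               ncols, shmem_f16, shmem_f32);
    }
    return 0;
}

At ncols = 16384 this prints 32832 B for f16 but 65664 B for f32, so on a 48 KiB device only the f32 path has to fall back at that row length.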
@@ -7072,6 +7336,7 @@ void ggml_init_cublas() {
 #else
         g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+        g_device_caps[id].smpb = prop.sharedMemPerBlock;
     }
     for (int id = 0; id < g_device_count; ++id) {
         g_tensor_split[id] /= total_vram;
@@ -8087,7 +8352,21 @@ static void ggml_cuda_op_soft_max(
     float scale = 1.0f;
     memcpy(&scale, dst->op_params, sizeof(float));
 
-    soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    const bool use_f16_soft_max = false;
+#else
+#ifdef GGML_CUDA_F16
+    const bool use_f16_soft_max = true;
+#else
+    const bool use_f16_soft_max = false;
+#endif // GGML_CUDA_F16
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
+    if (use_f16_soft_max) {
+        soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+    } else {
+        soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+    }
 
     (void) dst;
 }
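smpb is filled in once at initialization from the same cudaDeviceProp the compute capability comes from, and the two launchers compare their shared-memory requirement against it. A minimal standalone probe using the same CUDA runtime API, not part of the patch:

#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int count = 0;
    cudaGetDeviceCount(&count);
    for (int id = 0; id < count; ++id) {
        cudaDeviceProp prop;
        cudaGetDeviceProperties(&prop, id);
        // sharedMemPerBlock is the static per-block limit the launchers test against
        printf("device %d: compute capability %d.%d, sharedMemPerBlock = %zu bytes\n",
               id, prop.major, prop.minor, prop.sharedMemPerBlock);
    }
    return 0;
}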
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index b79de7a7d..7a60d7743 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -450,7 +450,7 @@ struct test_case {
 
             double err = nmse(f1.data(), f2.data(), f1.size());
             if (err > ud->max_err) {
-                printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
+                printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
                 //for (int i = 0; i < (int) f1.size(); i++) {
                 //    printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
                 //}
@@ -1449,6 +1449,7 @@ struct test_moe : public test_case {
 
 static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
     std::vector<std::unique_ptr<test_case>> test_cases;
+    std::default_random_engine rng(0);
 
     const ggml_type all_types[] = {
         GGML_TYPE_F32, GGML_TYPE_F16,
@@ -1583,7 +1584,19 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10,  1}, 5));
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 10}, 5));
 
-    test_cases.emplace_back(new test_soft_max());
+    std::uniform_int_distribution<> dist_ne1(1, 50);
+    int exponent = 1;
+    while (exponent < (1 << 17)) {
+        std::uniform_int_distribution<> dist_ne0(exponent, 2*exponent);
+
+        for (int n = 0; n < 10; ++n) {
+            int64_t ne0 = dist_ne0(rng);
+            int64_t ne1 = dist_ne1(rng);
+            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}));
+        }
+
+        exponent <<= 1;
+    }
 
     for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
         test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512)); // llama 7B
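The rewritten test loop replaces the single fixed-shape test_soft_max case with 170 randomized ones: for every power-of-two exponent up to 2^16 it draws ten ne0 values uniformly from [exponent, 2*exponent], so the fixed-size template specializations, the generic fallback, and the need_check boundary paths all get exercised. The same sampling, pulled out into a standalone program purely to illustrate which shapes are generated:

#include <cstdint>
#include <cstdio>
#include <random>

int main() {
    std::default_random_engine rng(0); // fixed seed, as in the test, for reproducibility
    std::uniform_int_distribution<> dist_ne1(1, 50);

    for (int exponent = 1; exponent < (1 << 17); exponent <<= 1) {
        std::uniform_int_distribution<> dist_ne0(exponent, 2*exponent);
        for (int n = 0; n < 10; ++n) {
            const int64_t ne0 = dist_ne0(rng);
            const int64_t ne1 = dist_ne1(rng);
            printf("soft_max case: ne0 = %6lld, ne1 = %2lld\n", (long long) ne0, (long long) ne1);
        }
    }
    return 0;
}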