mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 13:28:50 +01:00
CUDA: faster softmax via shared memory + fp16 math (#4742)
This commit is contained in:
parent
1fc2f265ff
commit
8f900abfc0
333
ggml-cuda.cu
333
ggml-cuda.cu
@ -116,6 +116,7 @@
|
|||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
#include "ggml-backend-impl.h"
|
#include "ggml-backend-impl.h"
|
||||||
|
|
||||||
|
#define CC_PASCAL 600
|
||||||
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
|
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
|
||||||
#define CC_VOLTA 700
|
#define CC_VOLTA 700
|
||||||
#define CC_OFFSET_AMD 1000000
|
#define CC_OFFSET_AMD 1000000
|
||||||
@ -556,11 +557,12 @@ static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
|
|||||||
|
|
||||||
// Properties recorded per CUDA device; g_device_caps holds one instance for each of GGML_CUDA_MAX_DEVICES slots.
struct cuda_device_capabilities {
|
struct cuda_device_capabilities {
|
||||||
int cc; // compute capability
|
int cc; // compute capability
|
||||||
|
size_t smpb; // max. shared memory per block, filled from cudaDeviceProp::sharedMemPerBlock in ggml_init_cublas
|
||||||
bool vmm; // virtual memory support
|
bool vmm; // virtual memory support
|
||||||
size_t vmm_granularity; // granularity of virtual memory
|
size_t vmm_granularity; // granularity of virtual memory
|
||||||
};
|
};
|
||||||
|
|
||||||
static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
|
static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, 0, false, 0} };
|
||||||
|
|
||||||
static void * g_scratch_buffer = nullptr;
|
static void * g_scratch_buffer = nullptr;
|
||||||
static size_t g_scratch_size = 0; // disabled by default
|
static size_t g_scratch_size = 0; // disabled by default
|
||||||
@ -593,6 +595,19 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
|
|||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sums a half2 value across all 32 lanes of a warp with an XOR-shuffle butterfly.
// Both halves are reduced independently; every lane ends up holding the full sums.
// Only available on NVIDIA devices with compute capability >= CC_PASCAL; any other
// target calls bad_arch() instead.
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#pragma unroll
    for (int offset = 16; offset >= 1; offset /= 2) {
        // exchange with the lane whose index differs in exactly the `offset` bit
        const half2 other = __shfl_xor_sync(0xffffffff, a, offset, 32);
        a = __hadd2(a, other);
    }
    return a;
#else
    (void) a;
    bad_arch();
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float warp_reduce_max(float x) {
|
static __device__ __forceinline__ float warp_reduce_max(float x) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
@ -601,6 +616,19 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Computes the element-wise maximum of a half2 value across all 32 lanes of a warp
// with an XOR-shuffle butterfly; every lane ends up holding the warp-wide maxima.
// Only available on NVIDIA devices with compute capability >= CC_PASCAL; any other
// target calls bad_arch() instead.
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#pragma unroll
    for (int offset = 16; offset >= 1; offset /= 2) {
        // exchange with the lane whose index differs in exactly the `offset` bit
        const half2 other = __shfl_xor_sync(0xffffffff, x, offset, 32);
        x = __hmax2(x, other);
    }
    return x;
#else
    (void) x;
    bad_arch();
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
|
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
|
||||||
return b;
|
return b;
|
||||||
GGML_UNUSED(a);
|
GGML_UNUSED(a);
|
||||||
@ -5385,75 +5413,233 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
|
|||||||
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
|
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
|
||||||
}
|
}
|
||||||
|
|
||||||
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
|
// Row-wise softmax of (x*scale + y) evaluated with packed fp16 (half2) arithmetic.
//
// Layout: one thread block per row of x (blockIdx.x selects the row). Each thread handles
// the column pair (col_data, col_data + WARP_SIZE) packed into a single half2.
// The mask y is optional (may be null) and is broadcast over rows via rowx % nrows_y.
//
// Dynamic shared memory: the first WARP_SIZE halves are the inter-warp buffer buf_iw (one slot
// per warp, so blockDim.x must not exceed WARP_SIZE*WARP_SIZE). When vals_smem is true the
// buffer is followed by GGML_PAD(ncols, 2*WARP_SIZE)/2 half2 values caching the row.
//
// Template parameters:
//   vals_smem            cache intermediate values in shared memory (true) or reuse dst as scratch (false)
//   ncols_template       compile-time column count, 0 to use ncols_par
//   block_size_template  compile-time block size, 0 to use blockDim.x
//   need_check           enable the bounds checks on column indices
//
// Requires a non-HIP build and compute capability >= CC_PASCAL; otherwise bad_arch() is called.
template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
    const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
    const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;

    const int tid  = threadIdx.x;
    const int rowx = blockIdx.x;
    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension

    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;

    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;

    // row base pointers; the offsets are computed in 64 bit so that rowx*ncols cannot overflow int
    const float * x_row   = x + (size_t) rowx*ncols_data;
    const float * y_row   = y ? y + (size_t) rowy*ncols_data : nullptr;
    float       * dst_row = dst + (size_t) rowx*ncols_data;

    extern __shared__ half data_soft_max_f16[];
    half  * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
    // (shared memory) buffer to cache values between iterations:
    half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) dst_row;
    // if the buffer is larger than max. shared memory per block, use dst as temp. buffer instead
    // in that case col_smem == col_data must be enforced to avoid race conditions

    half2 max_val = make_half2(-INFINITY, -INFINITY);

#pragma unroll
    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
        const int col_smem = vals_smem ? col0 + tid : col_data;

        half2 val;
        if (need_check && col_data + 0 >= ncols_data) {
            val.x = -INFINITY; // padding: does not affect the max and becomes 0 after exp
        } else {
            val.x = x_row[col_data + 0]*scale + (y_row ? y_row[col_data + 0] : 0.0f);
        }
        if (need_check && col_data + WARP_SIZE >= ncols_data) {
            val.y = -INFINITY;
        } else {
            val.y = x_row[col_data + WARP_SIZE]*scale + (y_row ? y_row[col_data + WARP_SIZE] : 0.0f);
        }
        if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
            vals[col_smem] = val;
        }
        max_val = __hmax2(max_val, val);
    }

    // find the max value in the block
    max_val = warp_reduce_max(max_val);
    if (block_size > WARP_SIZE) {
        if (warp_id == 0) {
            buf_iw[lane_id] = -INFINITY;
        }
        __syncthreads();

        if (lane_id == 0) {
            buf_iw[warp_id] = __hmax(max_val.x, max_val.y);
        }
        __syncthreads();

        max_val = __half2half2(buf_iw[lane_id]);
        max_val = warp_reduce_max(max_val);
    } else {
        max_val = __half2half2(__hmax(max_val.x, max_val.y));
    }

    half2 tmp = make_half2(0.0f, 0.0f); // partial sums

#pragma unroll
    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
        const int col_smem = vals_smem ? col0 + tid : 2*col0 + 2*warp_id*WARP_SIZE + lane_id;

        if (ncols_template == 0 && col_smem >= (vals_smem ? ncols_smem : ncols_data)) {
            break;
        }

        const half2 val = h2exp(vals[col_smem] - max_val);

        tmp += val;
        vals[col_smem] = val;
    }

    // find the sum of exps in the block
    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        // buf_iw still holds the per-warp maxima: wait until every warp has read them
        // before warp 0 overwrites the buffer, otherwise slower warps may read zeros
        __syncthreads();
        if (warp_id == 0) {
            buf_iw[lane_id] = 0.0f;
        }
        __syncthreads();

        if (lane_id == 0) {
            buf_iw[warp_id] = tmp.x + tmp.y;
        }
        __syncthreads();

        tmp = __half2half2(buf_iw[lane_id]);
        tmp = warp_reduce_sum(tmp);
    } else {
        tmp = __half2half2(tmp.x + tmp.y);
    }

    const half2 inv_sum = make_half2(1.0f, 1.0f) / tmp;

#pragma unroll
    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
        const int col_smem = vals_smem ? col0 + tid : col_data;

        // check the bounds BEFORE touching vals: for tail threads col_smem can lie beyond
        // the cached row (shared memory) or beyond the dst row (global memory)
        if (need_check && col_data + 0 >= ncols_data) {
            return;
        }

        const half2 result = vals[col_smem] * inv_sum;
        dst_row[col_data] = result.x;

        if (need_check && col_data + WARP_SIZE >= ncols_data) {
            return;
        }

        dst_row[col_data + WARP_SIZE] = result.y;
    }
#else
    (void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
    bad_arch();
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
|
||||||
|
|
||||||
|
// Row-wise softmax of (x*scale + y) in fp32.
//
// Layout: one thread block per row of x (blockIdx.x selects the row); threads stride over the
// columns of that row. The mask y is optional (may be null) and is broadcast over rows via
// rowx % nrows_y.
//
// Dynamic shared memory: the first WARP_SIZE floats are the inter-warp buffer buf_iw (one slot
// per warp, so blockDim.x must not exceed WARP_SIZE*WARP_SIZE). When vals_smem is true the
// buffer is followed by at least ncols floats caching the row; otherwise dst is used as scratch.
//
// Template parameters:
//   vals_smem            cache intermediate values in shared memory (true) or in dst (false)
//   ncols_template       compile-time column count, 0 to use ncols_par (enables the bounds checks)
//   block_size_template  compile-time block size, 0 to use blockDim.x
template <bool vals_smem, int ncols_template, int block_size_template>
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;

    const int tid  = threadIdx.x;
    const int rowx = blockIdx.x;
    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension

    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;

    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;

    // row base pointers; the offsets are computed in 64 bit so that rowx*ncols cannot overflow int
    const float * x_row   = x + (size_t) rowx*ncols;
    const float * y_row   = y ? y + (size_t) rowy*ncols : nullptr;
    float       * dst_row = dst + (size_t) rowx*ncols;

    extern __shared__ float data_soft_max_f32[];
    float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
    // shared memory buffer to cache values between iterations:
    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst_row;

    float max_val = -INFINITY;

#pragma unroll
    for (int col0 = 0; col0 < ncols; col0 += block_size) {
        const int col = col0 + tid;

        if (ncols_template == 0 && col >= ncols) {
            break;
        }

        const float val = x_row[col]*scale + (y_row ? y_row[col] : 0.0f);
        vals[col] = val;
        max_val = max(max_val, val);
    }

    // find the max value in the block
    max_val = warp_reduce_max(max_val);
    if (block_size > WARP_SIZE) {
        if (warp_id == 0) {
            buf_iw[lane_id] = -INFINITY;
        }
        __syncthreads();

        if (lane_id == 0) {
            buf_iw[warp_id] = max_val;
        }
        __syncthreads();

        max_val = buf_iw[lane_id];
        max_val = warp_reduce_max(max_val);
    }

    float tmp = 0.0f; // partial sum

#pragma unroll
    for (int col0 = 0; col0 < ncols; col0 += block_size) {
        const int col = col0 + tid;

        if (ncols_template == 0 && col >= ncols) {
            break;
        }

        const float val = expf(vals[col] - max_val);
        tmp += val;
        vals[col] = val;
    }

    // find the sum of exps in the block
    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        // buf_iw still holds the per-warp maxima: wait until every warp has read them
        // before warp 0 overwrites the buffer, otherwise slower warps may read zeros
        __syncthreads();
        if (warp_id == 0) {
            buf_iw[lane_id] = 0.0f;
        }
        __syncthreads();

        if (lane_id == 0) {
            buf_iw[warp_id] = tmp;
        }
        __syncthreads();

        tmp = buf_iw[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    const float inv_sum = 1.0f / tmp;

#pragma unroll
    for (int col0 = 0; col0 < ncols; col0 += block_size) {
        const int col = col0 + tid;

        if (ncols_template == 0 && col >= ncols) {
            return;
        }

        dst_row[col] = vals[col] * inv_sum;
    }
}
|
||||||
|
|
||||||
@ -6752,12 +6938,90 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
|
|||||||
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
|
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Launches soft_max_f16 on `stream` with one block per row of x.
// When the row cache plus the inter-warp buffer fits into the main device's shared memory per
// block, a shared-memory instantiation is used (specialized for power-of-two row lengths from
// 32 to 4096); otherwise the kernel falls back to using dst as its scratch buffer.
static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");

    // every thread handles two columns: double the block size until it covers ncols_x/2 or hits the cap
    int n_threads = WARP_SIZE;
    for (; n_threads < ncols_x/2 && n_threads < CUDA_SOFT_MAX_BLOCK_SIZE; n_threads *= 2) {}

    const dim3 threads(n_threads, 1, 1);
    const dim3 blocks(nrows_x, 1, 1);

    // padded row cache + WARP_SIZE slots for inter-warp communication
    const size_t smem_full = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);

    if (smem_full > g_device_caps[g_main_device].smpb) {
        // row does not fit into shared memory: only the inter-warp buffer is allocated
        const size_t smem_iw = WARP_SIZE*sizeof(half);
        soft_max_f16<false, 0, 0, true><<<blocks, threads, smem_iw, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
        return;
    }

    switch (ncols_x) {
        case 32:
            soft_max_f16<true, 32, 32, true><<<blocks, threads, smem_full, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 64:
            soft_max_f16<true, 64, 32, false><<<blocks, threads, smem_full, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 128:
            soft_max_f16<true, 128, 64, false><<<blocks, threads, smem_full, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 256:
            soft_max_f16<true, 256, 128, false><<<blocks, threads, smem_full, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 512:
            soft_max_f16<true, 512, 256, false><<<blocks, threads, smem_full, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 1024:
            soft_max_f16<true, 1024, 512, false><<<blocks, threads, smem_full, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 2048:
            soft_max_f16<true, 2048, 1024, false><<<blocks, threads, smem_full, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 4096:
            soft_max_f16<true, 4096, 1024, false><<<blocks, threads, smem_full, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        default:
            // generic row length: sizes are read at runtime and bounds checks are enabled
            soft_max_f16<true, 0, 0, true><<<blocks, threads, smem_full, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
    }
}
|
||||||
|
|
||||||
// Launches soft_max_f32 on `stream` with one block per row of x.
// When the row cache plus the inter-warp buffer is smaller than the main device's shared memory
// per block, a shared-memory instantiation is used (specialized for power-of-two row lengths
// from 32 to 4096); otherwise the kernel falls back to using dst as its scratch buffer.
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");

    // double the block size until it covers the row or hits the cap
    int n_threads = WARP_SIZE;
    for (; n_threads < ncols_x && n_threads < CUDA_SOFT_MAX_BLOCK_SIZE; n_threads *= 2) {}

    const dim3 threads(n_threads, 1, 1);
    const dim3 blocks(nrows_x, 1, 1);

    // padded row cache + WARP_SIZE slots for inter-warp communication
    const size_t smem_full = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);

    if (smem_full >= g_device_caps[g_main_device].smpb) {
        // row does not fit into shared memory: only the inter-warp buffer is allocated
        const size_t smem_iw = WARP_SIZE*sizeof(float);
        soft_max_f32<false, 0, 0><<<blocks, threads, smem_iw, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
        return;
    }

    switch (ncols_x) {
        case 32:
            soft_max_f32<true, 32, 32><<<blocks, threads, smem_full, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 64:
            soft_max_f32<true, 64, 64><<<blocks, threads, smem_full, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 128:
            soft_max_f32<true, 128, 128><<<blocks, threads, smem_full, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 256:
            soft_max_f32<true, 256, 256><<<blocks, threads, smem_full, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 512:
            soft_max_f32<true, 512, 512><<<blocks, threads, smem_full, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 1024:
            soft_max_f32<true, 1024, 1024><<<blocks, threads, smem_full, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 2048:
            soft_max_f32<true, 2048, 1024><<<blocks, threads, smem_full, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        case 4096:
            soft_max_f32<true, 4096, 1024><<<blocks, threads, smem_full, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
        default:
            // generic row length: sizes are read at runtime and bounds checks are enabled
            soft_max_f32<true, 0, 0><<<blocks, threads, smem_full, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
            break;
    }
}
|
||||||
|
|
||||||
static void im2col_f32_f16_cuda(const float* x, half* dst,
|
static void im2col_f32_f16_cuda(const float* x, half* dst,
|
||||||
@ -7072,6 +7336,7 @@ void ggml_init_cublas() {
|
|||||||
#else
|
#else
|
||||||
g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
|
g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
|
||||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
g_device_caps[id].smpb = prop.sharedMemPerBlock;
|
||||||
}
|
}
|
||||||
for (int id = 0; id < g_device_count; ++id) {
|
for (int id = 0; id < g_device_count; ++id) {
|
||||||
g_tensor_split[id] /= total_vram;
|
g_tensor_split[id] /= total_vram;
|
||||||
@ -8087,7 +8352,21 @@ static void ggml_cuda_op_soft_max(
|
|||||||
float scale = 1.0f;
|
float scale = 1.0f;
|
||||||
memcpy(&scale, dst->op_params, sizeof(float));
|
memcpy(&scale, dst->op_params, sizeof(float));
|
||||||
|
|
||||||
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
const bool use_f16_soft_max = false;
|
||||||
|
#else
|
||||||
|
#ifdef GGML_CUDA_F16
|
||||||
|
const bool use_f16_soft_max = true;
|
||||||
|
#else
|
||||||
|
const bool use_f16_soft_max = false;
|
||||||
|
#endif // GGML_CUDA_F16
|
||||||
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
|
||||||
|
if (use_f16_soft_max) {
|
||||||
|
soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
|
||||||
|
} else {
|
||||||
|
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
|
||||||
|
}
|
||||||
|
|
||||||
(void) dst;
|
(void) dst;
|
||||||
}
|
}
|
||||||
|
@ -450,7 +450,7 @@ struct test_case {
|
|||||||
|
|
||||||
double err = nmse(f1.data(), f2.data(), f1.size());
|
double err = nmse(f1.data(), f2.data(), f1.size());
|
||||||
if (err > ud->max_err) {
|
if (err > ud->max_err) {
|
||||||
printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
|
printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
|
||||||
//for (int i = 0; i < (int) f1.size(); i++) {
|
//for (int i = 0; i < (int) f1.size(); i++) {
|
||||||
// printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
|
// printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
|
||||||
//}
|
//}
|
||||||
@ -1449,6 +1449,7 @@ struct test_moe : public test_case {
|
|||||||
|
|
||||||
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
|
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
|
||||||
std::vector<std::unique_ptr<test_case>> test_cases;
|
std::vector<std::unique_ptr<test_case>> test_cases;
|
||||||
|
std::default_random_engine rng(0);
|
||||||
|
|
||||||
const ggml_type all_types[] = {
|
const ggml_type all_types[] = {
|
||||||
GGML_TYPE_F32, GGML_TYPE_F16,
|
GGML_TYPE_F32, GGML_TYPE_F16,
|
||||||
@ -1583,7 +1584,19 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||||||
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 1}, 5));
|
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 1}, 5));
|
||||||
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 10}, 5));
|
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 10}, 5));
|
||||||
|
|
||||||
test_cases.emplace_back(new test_soft_max());
|
std::uniform_int_distribution<> dist_ne1(1, 50);
|
||||||
|
int exponent = 1;
|
||||||
|
while (exponent < (1 << 17)) {
|
||||||
|
std::uniform_int_distribution<> dist_ne0(exponent, 2*exponent);
|
||||||
|
|
||||||
|
for (int n = 0; n < 10; ++n) {
|
||||||
|
int64_t ne0 = dist_ne0(rng);
|
||||||
|
int64_t ne1 = dist_ne1(rng);
|
||||||
|
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
exponent <<= 1;
|
||||||
|
}
|
||||||
|
|
||||||
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||||
test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B
|
test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B
|
||||||
|
Loading…
Reference in New Issue
Block a user