llama.cpp/ggml/src/ggml-cuda/common.cuh
Andreas Kieslinger 750cb3e246
CUDA: rename macros to avoid conflicts with WinAPI (#10736)
* Renames NVIDIA GPU-architecture flags to avoid name clashes with WinAPI (e.g. does CC_PASCAL refer to the GPU architecture or to the WinAPI Pascal calling-convention flag?).

* Reverts erroneous rename in SYCL-code.

* Renames GGML_CUDA_MIN_CC_DP4A to GGML_CUDA_CC_DP4A.

* Renames the rest of the compute capability macros for consistency.
2024-12-10 18:23:24 +01:00

#pragma once
#include "ggml.h"
#include "ggml-cuda.h"
#include <cstdint>
#include <memory>
#if defined(GGML_USE_HIP)
#define GGML_COMMON_DECL_HIP
#define GGML_COMMON_IMPL_HIP
#else
#define GGML_COMMON_DECL_CUDA
#define GGML_COMMON_IMPL_CUDA
#if defined(GGML_USE_MUSA)
#define GGML_COMMON_DECL_MUSA
#define GGML_COMMON_IMPL_MUSA
#endif
#endif
#include "ggml-common.h"
#include <cstdio>
#include <array>
#include <cassert>
#include <cfloat>
#include <string>
#include <vector>
#if defined(GGML_USE_HIP)
#include "vendors/hip.h"
#elif defined(GGML_USE_MUSA)
#include "vendors/musa.h"
#else
#include "vendors/cuda.h"
#endif // defined(GGML_USE_HIP)
#define STRINGIZE_IMPL(...) #__VA_ARGS__
#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
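// Two-level expansion so that macro arguments are themselves expanded before
// being stringized, e.g. STRINGIZE(__CUDA_ARCH_LIST__) yields the expanded
// architecture list as a string literal rather than the literal text
// "__CUDA_ARCH_LIST__".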
#define WARP_SIZE 32
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
#define GGML_CUDA_CC_PASCAL 600
#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define GGML_CUDA_CC_VOLTA 700
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_OFFSET_AMD 1000000
// GCN/CDNA, wave size is 64
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 803) // Tonga, Fiji, Polaris, minimum for fast fp16
#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 900) // Vega56/64, minimum for fp16 dual issue
#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 906) // MI50/Radeon VII, minimum for dp4a
#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 908) // MI100, minimum for MFMA, acc registers
#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 910) // MI210, minimum acc register renaming
#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 942) // MI300
// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 1010) // RX 5000
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA
#define GGML_CUDA_CC_QY1 210
#define GGML_CUDA_CC_QY2 220
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
#define GGML_CUDA_MAX_STREAMS 8
[[noreturn]]
void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg);
#define CUDA_CHECK_GEN(err, success, error_fn)                                      \
    do {                                                                            \
        auto err_ = (err);                                                          \
        if (err_ != (success)) {                                                    \
            ggml_cuda_error(#err, __func__, __FILE__, __LINE__, error_fn(err_));    \
        }                                                                           \
    } while (0)
#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
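// Typical usage: wrap any runtime call returning cudaError_t, e.g.
//   CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDeviceToDevice, stream));
// On failure, ggml_cuda_error() reports the statement, source location, and
// error string; it is [[noreturn]], so execution does not continue past a
// failed check.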
#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
static const char * cublas_get_error_str(const cublasStatus_t err) {
    return cublasGetStatusString(err);
}
#else
static const char * cublas_get_error_str(const cublasStatus_t err) {
    switch (err) {
        case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
        case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
        case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
        case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
        case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
        case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
        case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
        case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
        case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
        default: return "unknown error";
    }
}
#endif // CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
#if !defined(GGML_USE_HIP)
static const char * cu_get_error_str(CUresult err) {
    const char * err_str;
    cuGetErrorString(err, &err_str);
    return err_str;
}
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
#endif
#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
#else
#define GGML_CUDA_ASSUME(x)
#endif // CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
#ifdef GGML_CUDA_F16
typedef half dfloat; // dequantize float
typedef half2 dfloat2;
#else
typedef float dfloat; // dequantize float
typedef float2 dfloat2;
#endif // GGML_CUDA_F16
#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#define FP16_AVAILABLE
#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
#define FAST_FP16_AVAILABLE
#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#define FP16_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define INT8_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
#define FLASH_ATTN_AVAILABLE
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
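// Note: compute capability 6.1 (consumer Pascal) is excluded from the fast-fp16
// paths below because its native fp16 arithmetic runs at a small fraction of
// fp32 throughput, making fp32 math the faster choice on those GPUs.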
static constexpr bool fast_fp16_available(const int cc) {
    return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
}

static constexpr bool fp16_mma_available(const int cc) {
    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
}

static constexpr bool int8_mma_available(const int cc) {
    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
}
[[noreturn]]
static __device__ void no_device_code(
    const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {

#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
    printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
           file_name, line, function_name, arch);
    GGML_UNUSED(arch_list);
#else
    printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
           file_name, line, function_name, arch, arch_list);
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
    __trap();

    GGML_UNUSED(no_device_code); // suppress unused function warning
}
#ifdef __CUDA_ARCH__
#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
#else
#define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
#endif // __CUDA_ARCH__
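// The warp-level reductions below use a butterfly pattern: each lane exchanges
// its value with the lane whose ID differs by the current XOR offset
// (16, 8, 4, 2, 1), so after log2(32) = 5 shuffle steps every lane in the warp
// holds the full reduction result.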
static __device__ __forceinline__ int warp_reduce_sum(int x) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    return __reduce_add_sync(0xffffffff, x);
#else
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
    }
    return x;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
}

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
    }
    return x;
}

static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, 32);
        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, 32);
    }
    return a;
}
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#ifdef FP16_AVAILABLE
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32);
        reinterpret_cast<half&>(a.x) +=  __low2half(a_other);
        reinterpret_cast<half&>(a.y) += __high2half(a_other);
    }
    return a;
#else
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32));
    }
    return a;
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#else
    NO_DEVICE_CODE;
    return a;
#endif // FP16_AVAILABLE
}

static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
    }
    return x;
}
static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
#ifdef FP16_AVAILABLE

#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
    return __float2half(fmaxf(__half2float(a), __half2float(b)));
#else
    return __hmax(a, b);
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX

#else
    NO_DEVICE_CODE;
    GGML_UNUSED(b);
    return a;
#endif // FP16_AVAILABLE
}

static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))

#if CUDART_VERSION >= CUDART_HMAX
    return __hmax2(a, b);
#else
    half2 ret;
    reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a),  __low2float(b)));
    reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
    return ret;
#endif // CUDART_VERSION >= CUDART_HMAX

#else
    GGML_UNUSED(a);
    GGML_UNUSED(b);
    NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
}

static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
    }
    return x;
#else
    GGML_UNUSED(x);
    NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
}
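// On CUDA runtimes that predate the half2 mask intrinsics, emulate __hgt2_mask:
// compare each 16-bit half of a and b and return a 32-bit mask with 0xFFFF set
// in every lane where a > b.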
#if CUDART_VERSION < CUDART_HMASK
static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
    const uint32_t mask_low  = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
    const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
    return mask_low | mask_high;
}
#endif // CUDART_VERSION < CUDART_HMASK
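// ggml_cuda_dp4a(a, b, c) returns c plus the dot product of the four signed
// 8-bit lanes packed into a and b, picking the fastest path available:
// __dp4a on NVIDIA, the sdot/sudot builtins or hand-written SDWA assembly on
// AMD, and plain scalar byte math as the portable fallback.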
static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
    c = __builtin_amdgcn_sdot4(a, b, c, false);
#elif defined(RDNA3)
    c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
#elif defined(__gfx1010__) || defined(__gfx900__)
    int tmp1;
    int tmp2;
    asm("\n \
        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
        v_add3_u32 %0, %1, %2, %0 \n \
        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
        v_add3_u32 %0, %1, %2, %0 \n \
        "
        : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
        : "v"(a), "v"(b)
    );
#else
    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
    c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
#endif
    return c;

#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)

#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
    return __dp4a(a, b, c);
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
    const int8_t * a8 = (const int8_t *) &a;
    const int8_t * b8 = (const int8_t *) &b;
    return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A

#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
}
// TODO: move to ggml-common.h
static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
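// ALiBi slope for attention head h: heads below n_head_log2 get a slope of
// m0^(h+1), the remaining heads get m1^(2*(h - n_head_log2) + 1); a
// max_bias <= 0 disables the bias entirely by returning a slope of 1.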
static __device__ __forceinline__ float get_alibi_slope(
    const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
) {
    if (max_bias <= 0.0f) {
        return 1.0f;
    }
    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

    return powf(base, exph);
}
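// Per-type constants used by the quantized kernels: qk is the number of
// weights per quantization block, qr the quantization ratio (roughly, how many
// quantized values are packed into each byte of quant data), and qi the size
// of a block's quant data in 32-bit integers (qk / (4*qr)).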
template <ggml_type type>
struct ggml_cuda_type_traits;

template<>
struct ggml_cuda_type_traits<GGML_TYPE_F16> {
    static constexpr int qk = 1;
    static constexpr int qr = 1;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
    static constexpr int qk = QK4_0;
    static constexpr int qr = QR4_0;
    static constexpr int qi = QI4_0;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_1> {
    static constexpr int qk = QK4_1;
    static constexpr int qr = QR4_1;
    static constexpr int qi = QI4_1;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_0> {
    static constexpr int qk = QK5_0;
    static constexpr int qr = QR5_0;
    static constexpr int qi = QI5_0;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> {
    static constexpr int qk = QK5_1;
    static constexpr int qr = QR5_1;
    static constexpr int qi = QI5_1;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
    static constexpr int qk = QK8_0;
    static constexpr int qr = QR8_0;
    static constexpr int qi = QI8_0;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR2_K;
    static constexpr int qi = QI2_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q3_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR3_K;
    static constexpr int qi = QI3_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR4_K;
    static constexpr int qi = QI4_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR5_K;
    static constexpr int qi = QI5_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR6_K;
    static constexpr int qi = QI6_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR2_XXS;
    static constexpr int qi = QI2_XXS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XS> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR2_XS;
    static constexpr int qi = QI2_XS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_S> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR2_S;
    static constexpr int qi = QI2_S;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_XXS> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR3_XXS;
    static constexpr int qi = QI3_XXS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR1_S;
    static constexpr int qi = QI1_S;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR1_M;
    static constexpr int qi = QI1_M;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
    static constexpr int qk = QK4_NL;
    static constexpr int qr = QR4_NL;
    static constexpr int qi = QI4_NL;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR4_XS;
    static constexpr int qi = QI4_XS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR3_S;
    static constexpr int qi = QI3_S;
};
//////////////////////
struct ggml_cuda_device_info {
    int device_count;

    struct cuda_device_info {
        int    cc;               // compute capability
        int    nsm;              // number of streaming multiprocessors
        size_t smpb;             // max. shared memory per block
        size_t smpbo;            // max. shared memory per block (with opt-in)
        bool   vmm;              // virtual memory support
        size_t vmm_granularity;  // granularity of virtual memory
        size_t total_vram;
    };

    cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};

    std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};
};
const ggml_cuda_device_info & ggml_cuda_info();
void ggml_cuda_set_device(int device);
int ggml_cuda_get_device();
struct ggml_cuda_pool {
    virtual ~ggml_cuda_pool() = default;

    virtual void * alloc(size_t size, size_t * actual_size) = 0;
    virtual void   free(void * ptr, size_t size) = 0;
};

template<typename T>
struct ggml_cuda_pool_alloc {
    ggml_cuda_pool * pool = nullptr;
    T * ptr = nullptr;
    size_t actual_size = 0;

    ggml_cuda_pool_alloc() = default;

    explicit ggml_cuda_pool_alloc(ggml_cuda_pool & pool) : pool(&pool) {
    }

    ggml_cuda_pool_alloc(ggml_cuda_pool & pool, size_t size) : pool(&pool) {
        alloc(size);
    }

    ~ggml_cuda_pool_alloc() {
        if (ptr != nullptr) {
            pool->free(ptr, actual_size);
        }
    }

    // size is in number of elements
    T * alloc(size_t size) {
        GGML_ASSERT(pool != nullptr);
        GGML_ASSERT(ptr == nullptr);
        ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
        return ptr;
    }

    T * alloc(ggml_cuda_pool & pool, size_t size) {
        this->pool = &pool;
        return alloc(size);
    }

    T * get() {
        return ptr;
    }

    ggml_cuda_pool_alloc(const ggml_cuda_pool_alloc &) = delete;
    ggml_cuda_pool_alloc(ggml_cuda_pool_alloc &&) = delete;
    ggml_cuda_pool_alloc& operator=(const ggml_cuda_pool_alloc &) = delete;
    ggml_cuda_pool_alloc& operator=(ggml_cuda_pool_alloc &&) = delete;
};
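// RAII usage sketch (hypothetical names): allocate device scratch space from a
// context's pool and have it returned automatically at scope exit:
//   ggml_cuda_pool_alloc<float> tmp(ctx.pool(), n_elements);     // elements, not bytes
//   my_kernel<<<grid, block, 0, ctx.stream()>>>(tmp.get(), n_elements);
// The pool reports the actual byte size it reserved via actual_size, which is
// what gets passed back to ggml_cuda_pool::free() in the destructor.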
// backend interface
struct ggml_tensor_extra_gpu {
    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
    cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
};
#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)
#define USE_CUDA_GRAPH
#endif
struct ggml_graph_node_properties {
    void * node_address;
    ggml_op node_op;
    int64_t ne[GGML_MAX_DIMS];
    size_t nb[GGML_MAX_DIMS];
    void * src_address[GGML_MAX_SRC];
    int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
};
struct ggml_cuda_graph {
#ifdef USE_CUDA_GRAPH
    ~ggml_cuda_graph() {
        if (instance != nullptr) {
            CUDA_CHECK(cudaGraphExecDestroy(instance));
        }
        if (graph != nullptr) {
            CUDA_CHECK(cudaGraphDestroy(graph));
        }
    }

    cudaGraph_t graph = nullptr;
    cudaGraphExec_t instance = nullptr;
    size_t num_nodes = 0;
    std::vector<cudaGraphNode_t> nodes;
    std::vector<cudaKernelNodeParams> params;
    bool disable_due_to_gpu_arch = false;
    bool disable_due_to_too_many_updates = false;
    bool disable_due_to_failed_graph_capture = false;
    int number_consecutive_updates = 0;
    std::vector<ggml_graph_node_properties> ggml_graph_properties;
    std::vector<char **> updated_kernel_arg;
#endif
};
struct ggml_backend_cuda_context {
    int device;
    std::string name;
    cudaEvent_t copy_event = nullptr;

    cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
    cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};

    std::unique_ptr<ggml_cuda_graph> cuda_graph;

    explicit ggml_backend_cuda_context(int device) :
        device(device),
        name(GGML_CUDA_NAME + std::to_string(device)) {
    }

    ~ggml_backend_cuda_context() {
        if (copy_event != nullptr) {
            CUDA_CHECK(cudaEventDestroy(copy_event));
        }
        for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
            for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
                if (streams[i][j] != nullptr) {
                    CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
                }
            }
            if (cublas_handles[i] != nullptr) {
                CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
            }
        }
    }

    cudaStream_t stream(int device, int stream) {
        if (streams[device][stream] == nullptr) {
            ggml_cuda_set_device(device);
            CUDA_CHECK(cudaStreamCreateWithFlags(&streams[device][stream], cudaStreamNonBlocking));
        }
        return streams[device][stream];
    }

    cudaStream_t stream() {
        return stream(device, 0);
    }

    cublasHandle_t cublas_handle(int device) {
        if (cublas_handles[device] == nullptr) {
            ggml_cuda_set_device(device);
            CUBLAS_CHECK(cublasCreate(&cublas_handles[device]));
            CUBLAS_CHECK(cublasSetMathMode(cublas_handles[device], CUBLAS_TF32_TENSOR_OP_MATH));
        }
        return cublas_handles[device];
    }

    cublasHandle_t cublas_handle() {
        return cublas_handle(device);
    }

    // pool
    std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];

    static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);

    ggml_cuda_pool & pool(int device) {
        if (pools[device] == nullptr) {
            pools[device] = new_pool_for_device(device);
        }
        return *pools[device];
    }

    ggml_cuda_pool & pool() {
        return pool(device);
    }
};
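// Lazy-initialization usage sketch (hypothetical call site): streams, cuBLAS
// handles, and memory pools are all created on first use for a given device:
//   ggml_backend_cuda_context ctx(/*device =*/ 0);
//   cudaStream_t s = ctx.stream();                         // non-blocking stream, created on demand
//   CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), s)); // handle also created on demand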