musa: enable building fat binaries, enable unified memory, and disable Flash Attention on QY1 (MTT S80) (#9526)

* mtgpu: add mp_21 support

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* mtgpu: disable flash attention on qy1 (MTT S80); disable q3_k and mul_mat_batched_cublas

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* mtgpu: enable unified memory

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* mtgpu: map cublasOperation_t to mublasOperation_t (sync code to latest)

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

---------

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
commit c35e586ea5 (parent 912c331d3d)
Author: R0CKSTAR
Date: 2024-09-22 22:55:49 +08:00 (committed by GitHub)
6 changed files with 31 additions and 5 deletions

diff --git a/Makefile b/Makefile

@@ -611,7 +611,7 @@ ifdef GGML_CUDA
     MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include
     MK_LDFLAGS += -lmusa -lmublas -lmusart -lpthread -ldl -lrt -L$(CUDA_PATH)/lib -L/usr/lib64
-    MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=mp_22
+    MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=mp_21 --cuda-gpu-arch=mp_22
 else
 ifneq ('', '$(wildcard /opt/cuda)')
     CUDA_PATH ?= /opt/cuda

diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt

@@ -364,7 +364,7 @@ if (GGML_CUDA)
     if (GGML_MUSA)
         set_source_files_properties(${GGML_SOURCES_CUDA} PROPERTIES LANGUAGE CXX)
         foreach(SOURCE ${GGML_SOURCES_CUDA})
-            set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_22")
+            set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_21 --cuda-gpu-arch=mp_22")
         endforeach()
     endif()
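
Note: passing both --cuda-gpu-arch values is what makes the build a fat binary: the compiler embeds one device-code image per architecture and the driver loads whichever matches the installed GPU, so a single binary now covers both the MTT S80 (QY1, mp_21) and QY2-class (mp_22) cards. A minimal sketch of telling the two apart at run time (hypothetical standalone program; it assumes the MUSA runtime mirrors the CUDA runtime API, as the vendors/musa.h aliases at the end of this commit suggest):

    #include <cstdio>
    #include <musa_runtime.h> // assumption: CUDA-style MUSA runtime header

    int main() {
        musaDeviceProp prop;
        if (musaGetDeviceProperties(&prop, 0) != musaSuccess) {
            std::fprintf(stderr, "no MUSA device found\n");
            return 1;
        }
        // ggml encodes compute capability as 100*major + 10*minor,
        // which is how CC_QY1 == 210 further below should be read.
        const int cc = 100*prop.major + 10*prop.minor;
        std::printf("compute capability %d -> %s image\n",
                    cc, cc <= 210 ? "mp_21 (QY1 / MTT S80)" : "mp_22 (QY2)");
        return 0;
    }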

diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu

@@ -136,7 +136,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
     return res;
 #else
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIPBLAS)
     cudaError_t err;
     if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
     {
@@ -149,7 +149,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
     return err;
 #else
     return cudaMalloc(ptr, size);
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_HIPBLAS)
 #endif
 }
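
Note: dropping GGML_USE_MUSA from the guard lets MUSA builds take the same unified-memory branch as plain CUDA builds. Paraphrased as a sketch (not a verbatim copy of ggml_cuda_device_malloc):

    #include <cstdlib> // getenv; plus the CUDA runtime header, or its MUSA
                       // equivalents via the vendors/musa.h aliases below

    static cudaError_t device_malloc_sketch(void ** ptr, size_t size) {
        if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
            // Managed memory: pages migrate between host and device on
            // demand, letting models larger than VRAM run at reduced speed.
            return cudaMallocManaged(ptr, size);
        }
        return cudaMalloc(ptr, size); // default: plain device allocation
    }

Under GGML_USE_MUSA, cudaMallocManaged resolves to musaMallocManaged through the alias added to vendors/musa.h at the end of this commit.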
@@ -2830,6 +2830,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
                 return false;
             }
+#ifdef GGML_USE_MUSA
+            if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
+                !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
+                return false;
+            }
+#endif // GGML_USE_MUSA
             switch (a->type) {
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
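
Note: ne[2] and ne[3] are ggml's batch dimensions, so this new check matches exactly the batched F16 matmuls that would otherwise be routed to the mul_mat_batched_cublas path the commit message disables; returning false from supports_op makes the scheduler plan the op elsewhere instead. The condition, factored into a named predicate purely for readability (a sketch, not code from the patch):

    #include "ggml.h"

    // Sketch: the MUSA opt-out condition above as a standalone helper.
    static bool would_use_batched_cublas(const ggml_tensor * a, const ggml_tensor * b) {
        return b->type == GGML_TYPE_F16  // F16 right-hand operand
            && b->ne[2]*b->ne[3] > 1     // more than one matrix in the batch
            && !ggml_is_transposed(a)
            && !ggml_is_transposed(b);
    }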
@@ -2853,6 +2859,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 case GGML_TYPE_IQ3_XXS:
                 case GGML_TYPE_IQ4_NL:
                 case GGML_TYPE_IQ4_XS:
+#ifdef GGML_USE_MUSA
+                    if (a->type == GGML_TYPE_Q3_K) {
+                        return false;
+                    }
+#endif // GGML_USE_MUSA
                     return true;
                 default:
                     return false;
@@ -2978,6 +2989,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_RWKV_WKV:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
+#ifndef FLASH_ATTN_AVAILABLE
+            return false;
+#endif
             if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
                 return true;
             }

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh

@@ -50,6 +50,8 @@
 #define CC_RDNA1 (CC_OFFSET_AMD + 1010)
 #define CC_RDNA2 (CC_OFFSET_AMD + 1030)
 #define CC_RDNA3 (CC_OFFSET_AMD + 1100)
+#define CC_QY1 210
+#define CC_QY2 220

 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
@@ -134,6 +136,10 @@ typedef float2 dfloat2;
 #define INT8_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING

+#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+#define FLASH_ATTN_AVAILABLE
+#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+
 static constexpr bool fast_fp16_available(const int cc) {
     return cc >= CC_PASCAL && cc != 610;
 }
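
Note: the guard is purely compile-time. __MUSA_ARCH__ reflects the architecture of the device pass being compiled, so in the fat binary enabled above the mp_21 image is built without flash attention while the mp_22 image keeps it. A small sketch of consuming the flag from C++ (hypothetical helper, in the spirit of the fast_fp16_available() function next to it):

    // Sketch, not in the patch: surface the preprocessor flag to constexpr code.
    static constexpr bool flash_attn_available_sketch() {
    #ifdef FLASH_ATTN_AVAILABLE
        return true;   // e.g. the mp_22 pass, or any non-MUSA build
    #else
        return false;  // e.g. the mp_21 pass: QY1 (MTT S80)
    #endif
    }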

diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu

@@ -44,6 +44,10 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
+#ifndef FLASH_ATTN_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // FLASH_ATTN_AVAILABLE
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
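
Note: the early-out turns the kernel into a stub wherever the feature is unavailable: the fat binary still links a symbol for every variant, but on QY1 the body reduces to NO_DEVICE_CODE (ggml's marker for device code that must never execute), and the host-side supports_op change above keeps the stub from being launched at all. The general pattern, sketched with a hypothetical kernel:

    // Sketch (hypothetical kernel, not from the patch): compile a stub where
    // the feature is unsupported so the build links; supports_op blocks launch.
    static __global__ void feature_kernel_sketch(float * dst) {
    #ifndef FLASH_ATTN_AVAILABLE
        NO_DEVICE_CODE; // ggml macro: flags execution on an unsupported arch
        return;
    #else
        dst[threadIdx.x] = 0.0f; // real work would go here
    #endif
    }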

diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h

@@ -26,6 +26,7 @@
 #define cublasSetStream mublasSetStream
 #define cublasSgemm mublasSgemm
 #define cublasStatus_t mublasStatus_t
+#define cublasOperation_t mublasOperation_t
 #define cublasGetStatusString mublasStatus_to_string
 #define cudaDataType_t musaDataType_t
 #define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
@@ -56,6 +57,7 @@
 #define cudaLaunchHostFunc musaLaunchHostFunc
 #define cudaMalloc musaMalloc
 #define cudaMallocHost musaMallocHost
+#define cudaMallocManaged musaMallocManaged
 #define cudaMemcpy musaMemcpy
 #define cudaMemcpyAsync musaMemcpyAsync
 #define cudaMemcpyPeerAsync musaMemcpyPeerAsync
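
Note: this alias layer is what keeps ggml's CUDA sources vendor-neutral: under GGML_USE_MUSA the preprocessor rewrites cuda*/cublas* spellings to their musa*/mublas* counterparts, so the same translation units compile against musart and mublas unchanged. A sketch of code that the two additions above enable (it assumes the header also aliases cudaSuccess and cudaFree, which this diff does not show):

    #include <cstddef>

    // Sketch: CUDA-spelled code retargeted to MUSA by the alias layer.
    static bool unified_alloc_smoke_test(std::size_t size) {
        void * buf = nullptr;
        if (cudaMallocManaged(&buf, size) != cudaSuccess) { // -> musaMallocManaged
            return false;
        }
        cublasOperation_t trans{}; // -> mublasOperation_t at GEMM call sites
        (void) trans;
        cudaFree(buf); // -> musaFree (assumed alias)
        return true;
    }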