mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 06:39:25 +01:00
musa: enable building fat binaries, enable unified memory, and disable Flash Attention on QY1 (MTT S80) (#9526)
* mtgpu: add mp_21 support Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * mtgpu: disable flash attention on qy1 (MTT S80); disable q3_k and mul_mat_batched_cublas Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * mtgpu: enable unified memory Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * mtgpu: map cublasOperation_t to mublasOperation_t (sync code to latest) Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> --------- Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
This commit is contained in:
parent
912c331d3d
commit
c35e586ea5
2
Makefile
2
Makefile
@ -611,7 +611,7 @@ ifdef GGML_CUDA
|
|||||||
|
|
||||||
MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include
|
MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include
|
||||||
MK_LDFLAGS += -lmusa -lmublas -lmusart -lpthread -ldl -lrt -L$(CUDA_PATH)/lib -L/usr/lib64
|
MK_LDFLAGS += -lmusa -lmublas -lmusart -lpthread -ldl -lrt -L$(CUDA_PATH)/lib -L/usr/lib64
|
||||||
MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=mp_22
|
MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=mp_21 --cuda-gpu-arch=mp_22
|
||||||
else
|
else
|
||||||
ifneq ('', '$(wildcard /opt/cuda)')
|
ifneq ('', '$(wildcard /opt/cuda)')
|
||||||
CUDA_PATH ?= /opt/cuda
|
CUDA_PATH ?= /opt/cuda
|
||||||
|
@ -364,7 +364,7 @@ if (GGML_CUDA)
|
|||||||
if (GGML_MUSA)
|
if (GGML_MUSA)
|
||||||
set_source_files_properties(${GGML_SOURCES_CUDA} PROPERTIES LANGUAGE CXX)
|
set_source_files_properties(${GGML_SOURCES_CUDA} PROPERTIES LANGUAGE CXX)
|
||||||
foreach(SOURCE ${GGML_SOURCES_CUDA})
|
foreach(SOURCE ${GGML_SOURCES_CUDA})
|
||||||
set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_22")
|
set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_21 --cuda-gpu-arch=mp_22")
|
||||||
endforeach()
|
endforeach()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
@ -136,7 +136,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
|
|||||||
return res;
|
return res;
|
||||||
#else
|
#else
|
||||||
|
|
||||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
#if !defined(GGML_USE_HIPBLAS)
|
||||||
cudaError_t err;
|
cudaError_t err;
|
||||||
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
|
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
|
||||||
{
|
{
|
||||||
@ -149,7 +149,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
|
|||||||
return err;
|
return err;
|
||||||
#else
|
#else
|
||||||
return cudaMalloc(ptr, size);
|
return cudaMalloc(ptr, size);
|
||||||
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
#endif // !defined(GGML_USE_HIPBLAS)
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
@ -2830,6 +2830,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|||||||
if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
|
if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
#ifdef GGML_USE_MUSA
|
||||||
|
if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
|
||||||
|
!ggml_is_transposed(a) && !ggml_is_transposed(b)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
#endif // GGML_USE_MUSA
|
||||||
switch (a->type) {
|
switch (a->type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
@ -2853,6 +2859,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|||||||
case GGML_TYPE_IQ3_XXS:
|
case GGML_TYPE_IQ3_XXS:
|
||||||
case GGML_TYPE_IQ4_NL:
|
case GGML_TYPE_IQ4_NL:
|
||||||
case GGML_TYPE_IQ4_XS:
|
case GGML_TYPE_IQ4_XS:
|
||||||
|
#ifdef GGML_USE_MUSA
|
||||||
|
if (a->type == GGML_TYPE_Q3_K) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
#endif // GGML_USE_MUSA
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
@ -2978,6 +2989,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|||||||
case GGML_OP_RWKV_WKV:
|
case GGML_OP_RWKV_WKV:
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_FLASH_ATTN_EXT: {
|
case GGML_OP_FLASH_ATTN_EXT: {
|
||||||
|
#ifndef FLASH_ATTN_AVAILABLE
|
||||||
|
return false;
|
||||||
|
#endif
|
||||||
if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
|
if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -50,6 +50,8 @@
|
|||||||
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
|
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
|
||||||
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
|
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
|
||||||
#define CC_RDNA3 (CC_OFFSET_AMD + 1100)
|
#define CC_RDNA3 (CC_OFFSET_AMD + 1100)
|
||||||
|
#define CC_QY1 210
|
||||||
|
#define CC_QY2 220
|
||||||
|
|
||||||
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
||||||
|
|
||||||
@ -134,6 +136,10 @@ typedef float2 dfloat2;
|
|||||||
#define INT8_MMA_AVAILABLE
|
#define INT8_MMA_AVAILABLE
|
||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
|
||||||
|
|
||||||
|
#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
|
||||||
|
#define FLASH_ATTN_AVAILABLE
|
||||||
|
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
|
||||||
|
|
||||||
static constexpr bool fast_fp16_available(const int cc) {
|
static constexpr bool fast_fp16_available(const int cc) {
|
||||||
return cc >= CC_PASCAL && cc != 610;
|
return cc >= CC_PASCAL && cc != 610;
|
||||||
}
|
}
|
||||||
|
@ -44,13 +44,17 @@ static __global__ void flash_attn_tile_ext_f32(
|
|||||||
const int ne1,
|
const int ne1,
|
||||||
const int ne2,
|
const int ne2,
|
||||||
const int ne3) {
|
const int ne3) {
|
||||||
|
#ifndef FLASH_ATTN_AVAILABLE
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return;
|
||||||
|
#endif // FLASH_ATTN_AVAILABLE
|
||||||
// Skip unused kernel variants for faster compilation:
|
// Skip unused kernel variants for faster compilation:
|
||||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
// In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||||
|
|
||||||
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
||||||
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
||||||
|
2
ggml/src/ggml-cuda/vendors/musa.h
vendored
2
ggml/src/ggml-cuda/vendors/musa.h
vendored
@ -26,6 +26,7 @@
|
|||||||
#define cublasSetStream mublasSetStream
|
#define cublasSetStream mublasSetStream
|
||||||
#define cublasSgemm mublasSgemm
|
#define cublasSgemm mublasSgemm
|
||||||
#define cublasStatus_t mublasStatus_t
|
#define cublasStatus_t mublasStatus_t
|
||||||
|
#define cublasOperation_t mublasOperation_t
|
||||||
#define cublasGetStatusString mublasStatus_to_string
|
#define cublasGetStatusString mublasStatus_to_string
|
||||||
#define cudaDataType_t musaDataType_t
|
#define cudaDataType_t musaDataType_t
|
||||||
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
|
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
|
||||||
@ -56,6 +57,7 @@
|
|||||||
#define cudaLaunchHostFunc musaLaunchHostFunc
|
#define cudaLaunchHostFunc musaLaunchHostFunc
|
||||||
#define cudaMalloc musaMalloc
|
#define cudaMalloc musaMalloc
|
||||||
#define cudaMallocHost musaMallocHost
|
#define cudaMallocHost musaMallocHost
|
||||||
|
#define cudaMallocManaged musaMallocManaged
|
||||||
#define cudaMemcpy musaMemcpy
|
#define cudaMemcpy musaMemcpy
|
||||||
#define cudaMemcpyAsync musaMemcpyAsync
|
#define cudaMemcpyAsync musaMemcpyAsync
|
||||||
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
|
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
|
||||||
|
Loading…
Reference in New Issue
Block a user