cuda : improve cuda pool efficiency using virtual memory (#4606)

* cuda : improve cuda pool efficiency using virtual memory

* fix mixtral

* fix cmake build

* check for vmm support, disable for hip

ggml-ci

* fix hip build

* clarify granularity

* move all caps to g_device_caps

* refactor error checking

* add cuda_pool_alloc, refactor most pool allocations

ggml-ci

* fix hip build

* CUBLAS_TF32_TENSOR_OP_MATH is not a macro

* more hip crap

* llama : fix msvc warnings

* ggml : fix msvc warnings

* minor

* minor

* cuda : fallback to CPU on host buffer alloc fail

* Update ggml-cuda.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml-cuda.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* ensure allocations are always aligned

* act_size -> actual_size

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
slaren 2023-12-24 14:34:22 +01:00 committed by GitHub
parent 708e179e85
commit 5bf3953d7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 328 additions and 208 deletions

View File

@ -302,6 +302,8 @@ if (LLAMA_CUBLAS)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
endif()
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver)
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
# 52 == lowest CUDA 12 standard
# 60 == f16 CUDA intrinsics

View File

@ -367,17 +367,15 @@ endif # LLAMA_BLIS
ifdef LLAMA_CUBLAS
MK_CPPFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include -I/usr/local/cuda/targets/aarch64-linux/include
MK_LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib
MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib -L/usr/lib/wsl/lib
OBJS += ggml-cuda.o
MK_NVCCFLAGS = -use_fast_math
ifndef JETSON_EOL_MODULE_DETECT
MK_NVCCFLAGS += --forward-unknown-to-host-compiler
endif # JETSON_EOL_MODULE_DETECT
ifdef LLAMA_DEBUG
MK_NVCCFLAGS += -lineinfo
endif
endif # LLAMA_DEBUG
ifdef LLAMA_CUDA_NVCC
NVCC = $(LLAMA_CUDA_NVCC)
else

View File

@ -297,7 +297,7 @@ static void ggml_backend_registry_init(void) {
void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
GGML_ASSERT(ggml_backend_registry_count < GGML_MAX_BACKENDS_REG);
int id = ggml_backend_registry_count;
size_t id = ggml_backend_registry_count;
ggml_backend_registry[id] = (struct ggml_backend_reg) {
/* .name = */ {0},
@ -330,6 +330,8 @@ size_t ggml_backend_reg_find_by_name(const char * name) {
return i;
}
}
// not found
return SIZE_MAX;
}
@ -340,15 +342,15 @@ ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str)
const char * params = strchr(backend_str, ':');
char backend_name[128];
if (params == NULL) {
strcpy(backend_name, backend_str);
snprintf(backend_name, sizeof(backend_name), "%s", backend_str);
params = "";
} else {
strncpy(backend_name, backend_str, params - backend_str);
backend_name[params - backend_str] = '\0';
snprintf(backend_name, sizeof(backend_name), "%.*s", (int)(params - backend_str), backend_str);
params++;
}
size_t backend_i = ggml_backend_reg_find_by_name(backend_name);
if (backend_i == SIZE_MAX) {
fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name);
return NULL;
@ -396,18 +398,12 @@ static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
}
static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
memcpy((char *)tensor->data + offset, data, size);
GGML_UNUSED(buffer);
}
static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
memcpy(data, (const char *)tensor->data + offset, size);
GGML_UNUSED(buffer);

View File

@ -86,17 +86,28 @@
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#define __trap abort
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
#else
#include <cuda_runtime.h>
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
// CUDA 10.2 does not have these macro definitions.
#ifndef CUBLAS_TF32_TENSOR_OP_MATH
#if CUDART_VERSION < 11020
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define cublasComputeType_t cudaDataType_t
#endif
#endif // CUDART_VERSION < 11020
#endif // defined(GGML_USE_HIPBLAS)
#include "ggml-cuda.h"
@ -200,45 +211,45 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
#define CUDA_CHECK(err) \
do { \
cudaError_t err_ = (err); \
if (err_ != cudaSuccess) { \
int id; \
cudaGetDevice(&id); \
fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
cudaGetErrorString(err_)); \
fprintf(stderr, "current device: %d\n", id); \
GGML_ASSERT(!"CUDA error"); \
} \
} while (0)
#if CUDART_VERSION >= 12000
#define CUBLAS_CHECK(err) \
do { \
cublasStatus_t err_ = (err); \
if (err_ != CUBLAS_STATUS_SUCCESS) { \
int id; \
cudaGetDevice(&id); \
fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
fprintf(stderr, "current device: %d\n", id); \
GGML_ASSERT(!"cuBLAS error"); \
} \
} while (0)
static const char * cublas_get_error_str(const cublasStatus_t err) {
return cublasGetStatusString(err);
}
#else
#define CUBLAS_CHECK(err) \
do { \
cublasStatus_t err_ = (err); \
if (err_ != CUBLAS_STATUS_SUCCESS) { \
int id; \
cudaGetDevice(&id); \
fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
fprintf(stderr, "current device: %d\n", id); \
GGML_ASSERT(!"cuBLAS error"); \
} \
} while (0)
#endif // CUDART_VERSION >= 11
static const char * cublas_get_error_str(const cublasStatus_t err) {
switch (err) {
case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
default: return "unknown error";
}
}
#endif // CUDART_VERSION >= 12000
[[noreturn]]
static void ggml_cuda_error(const char * stmt, const char * func, const char * file, const int line, const char * msg) {
fprintf(stderr, "CUDA error: %s: %s\n", stmt, msg);
fprintf(stderr, " in function %s at %s:%d\n", func, file, line);
GGML_ASSERT(!"CUDA error");
}
#define CUDA_CHECK(err) do { auto err_ = (err); if (err_ != cudaSuccess) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cudaGetErrorString(err_)); } while (0)
#define CUBLAS_CHECK(err) do { auto err_ = (err); if (err_ != CUBLAS_STATUS_SUCCESS) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cublas_get_error_str(err_)); } while (0)
#if !defined(GGML_USE_HIPBLAS)
static const char * cu_get_error_str(CUresult err) {
const char * err_str;
cuGetErrorString(err, &err_str);
return err_str;
}
#define CU_CHECK(err) do { auto err_ = (err); if (err_ != CUDA_SUCCESS) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cu_get_error_str(err_)); } while (0)
#endif
#if CUDART_VERSION >= 11100
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
@ -516,9 +527,17 @@ inline cudaError_t ggml_cuda_set_device(const int device) {
static int g_device_count = -1;
static int g_main_device = 0;
static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
struct cuda_device_capabilities {
int cc; // compute capability
bool vmm; // virtual memory support
size_t vmm_granularity; // granularity of virtual memory
};
static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
static void * g_scratch_buffer = nullptr;
static size_t g_scratch_size = 0; // disabled by default
static size_t g_scratch_offset = 0;
@ -5875,7 +5894,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
int id;
CUDA_CHECK(cudaGetDevice(&id));
const int compute_capability = g_compute_capabilities[id];
const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) {
@ -5920,7 +5939,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
int id;
CUDA_CHECK(cudaGetDevice(&id));
const int compute_capability = g_compute_capabilities[id];
const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) {
@ -5965,7 +5984,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
int id;
CUDA_CHECK(cudaGetDevice(&id));
const int compute_capability = g_compute_capabilities[id];
const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) {
@ -6010,7 +6029,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
int id;
CUDA_CHECK(cudaGetDevice(&id));
const int compute_capability = g_compute_capabilities[id];
const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) {
@ -6055,7 +6074,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
int id;
CUDA_CHECK(cudaGetDevice(&id));
const int compute_capability = g_compute_capabilities[id];
const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) {
@ -6100,7 +6119,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
int id;
CUDA_CHECK(cudaGetDevice(&id));
const int compute_capability = g_compute_capabilities[id];
const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) {
@ -6147,7 +6166,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
int id;
CUDA_CHECK(cudaGetDevice(&id));
const int compute_capability = g_compute_capabilities[id];
const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) {
@ -6193,7 +6212,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
int id;
CUDA_CHECK(cudaGetDevice(&id));
const int compute_capability = g_compute_capabilities[id];
const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) {
@ -6238,7 +6257,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
int id;
CUDA_CHECK(cudaGetDevice(&id));
const int compute_capability = g_compute_capabilities[id];
const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) {
@ -6283,7 +6302,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
int id;
CUDA_CHECK(cudaGetDevice(&id));
const int compute_capability = g_compute_capabilities[id];
const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) {
@ -6543,21 +6562,24 @@ struct scoped_spin_lock {
scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
};
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
// #define DEBUG_CUDA_MALLOC
struct cuda_buffer {
void * ptr = nullptr;
size_t size = 0;
};
static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
static size_t g_cuda_pool_size[GGML_CUDA_MAX_DEVICES] = {0};
static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
scoped_spin_lock lock(g_cuda_pool_lock);
int id;
CUDA_CHECK(cudaGetDevice(&id));
#ifdef DEBUG_CUDA_MALLOC
int nnz = 0;
size_t max_size = 0, tot_size = 0;
size_t max_size = 0;
#endif
size_t best_diff = 1ull << 36;
int ibest = -1;
@ -6566,7 +6588,6 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
if (b.ptr != nullptr) {
#ifdef DEBUG_CUDA_MALLOC
++nnz;
tot_size += b.size;
if (b.size > max_size) max_size = b.size;
#endif
if (b.size >= size) {
@ -6593,19 +6614,20 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
b.size = 0;
return ptr;
}
#ifdef DEBUG_CUDA_MALLOC
fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz,
(uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024));
#endif
void * ptr;
size_t look_ahead_size = (size_t) (1.05 * size);
look_ahead_size = 256 * ((look_ahead_size + 255)/256);
CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
*actual_size = look_ahead_size;
g_cuda_pool_size[id] += look_ahead_size;
#ifdef DEBUG_CUDA_MALLOC
fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz,
(uint32_t)(max_size/1024/1024), (uint32_t)(g_cuda_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));
#endif
return ptr;
}
static void ggml_cuda_pool_free(void * ptr, size_t size) {
static void ggml_cuda_pool_free_leg(void * ptr, size_t size) {
scoped_spin_lock lock(g_cuda_pool_lock);
int id;
CUDA_CHECK(cudaGetDevice(&id));
@ -6620,8 +6642,152 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
}
fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
CUDA_CHECK(cudaFree(ptr));
g_cuda_pool_size[id] -= size;
}
#if !defined(GGML_USE_HIPBLAS)
// pool with virtual memory
static std::vector<CUmemGenericAllocationHandle> g_cuda_pool_handles[GGML_CUDA_MAX_DEVICES];
static CUdeviceptr g_cuda_pool_addr[GGML_CUDA_MAX_DEVICES] = {0};
static size_t g_cuda_pool_used[GGML_CUDA_MAX_DEVICES] = {0};
static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 36; // 64 GB
static void * ggml_cuda_pool_malloc_vmm(size_t size, size_t * actual_size) {
scoped_spin_lock lock(g_cuda_pool_lock);
int id;
CUDA_CHECK(cudaGetDevice(&id));
// round up the allocation size to the alignment to ensure that all allocations are aligned for all data types
const size_t alignment = 128;
size = alignment * ((size + alignment - 1) / alignment);
size_t avail = g_cuda_pool_size[id] - g_cuda_pool_used[id];
if (size > avail) {
// round up to the next multiple of the granularity
size_t reserve_size = size - avail;
const size_t granularity = g_device_caps[id].vmm_granularity;
reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);
GGML_ASSERT(g_cuda_pool_size[id] + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
// allocate more physical memory
CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = id;
CUmemGenericAllocationHandle handle;
CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));
// reserve virtual address space (if not already reserved)
if (g_cuda_pool_addr[id] == 0) {
CU_CHECK(cuMemAddressReserve(&g_cuda_pool_addr[id], CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
}
// map at the end of the pool
CU_CHECK(cuMemMap(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, 0, handle, 0));
// set access
CUmemAccessDesc access = {};
access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
access.location.id = id;
access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
CU_CHECK(cuMemSetAccess(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, &access, 1));
// add to the pool
g_cuda_pool_handles[id].push_back(handle);
g_cuda_pool_size[id] += reserve_size;
//printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
// id, (unsigned long long) (g_cuda_pool_size[id]/1024/1024),
// (unsigned long long) (reserve_size/1024/1024));
}
GGML_ASSERT(g_cuda_pool_addr[id] != 0);
void * ptr = (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]);
*actual_size = size;
g_cuda_pool_used[id] += size;
#ifdef DEBUG_CUDA_MALLOC
printf("cuda pool[%d]: allocated %llu bytes at %llx [%s]\n", id, (unsigned long long) size, ptr);
#endif
return ptr;
}
static void ggml_cuda_pool_free_vmm(void * ptr, size_t size) {
scoped_spin_lock lock(g_cuda_pool_lock);
int id;
CUDA_CHECK(cudaGetDevice(&id));
#ifdef DEBUG_CUDA_MALLOC
printf("cuda pool[%d]: freed %llu bytes at %llx\n", id, (unsigned long long) size, ptr);
#endif
g_cuda_pool_used[id] -= size;
// all deallocations must be in reverse order of the allocations
GGML_ASSERT(ptr == (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]));
}
static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
int id;
CUDA_CHECK(cudaGetDevice(&id));
if (g_device_caps[id].vmm) {
return ggml_cuda_pool_malloc_vmm(size, actual_size);
} else {
return ggml_cuda_pool_malloc_leg(size, actual_size);
}
}
static void ggml_cuda_pool_free(void * ptr, size_t size) {
int id;
CUDA_CHECK(cudaGetDevice(&id));
if (g_device_caps[id].vmm) {
ggml_cuda_pool_free_vmm(ptr, size);
} else {
ggml_cuda_pool_free_leg(ptr, size);
}
}
#else
#define ggml_cuda_pool_malloc ggml_cuda_pool_malloc_leg
#define ggml_cuda_pool_free ggml_cuda_pool_free_leg
#endif // !defined(GGML_USE_HIPBLAS)
template<typename T>
struct cuda_pool_alloc {
T * ptr = nullptr;
size_t actual_size = 0;
// size is in number of elements
T * alloc(size_t size) {
GGML_ASSERT(ptr == nullptr);
ptr = (T *) ggml_cuda_pool_malloc(size * sizeof(T), &this->actual_size);
return ptr;
}
cuda_pool_alloc(size_t size) {
alloc(size);
}
~cuda_pool_alloc() {
if (ptr != nullptr) {
ggml_cuda_pool_free(ptr, actual_size);
}
}
T * get() {
return ptr;
}
cuda_pool_alloc() = default;
cuda_pool_alloc(const cuda_pool_alloc &) = delete;
cuda_pool_alloc(cuda_pool_alloc &&) = delete;
cuda_pool_alloc& operator=(const cuda_pool_alloc &) = delete;
cuda_pool_alloc& operator=(cuda_pool_alloc &&) = delete;
};
static bool g_cublas_loaded = false;
bool ggml_cublas_loaded(void) {
@ -6660,16 +6826,33 @@ void ggml_init_cublas() {
#endif
fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
for (int id = 0; id < g_device_count; ++id) {
int device_vmm = 0;
#if !defined(GGML_USE_HIPBLAS)
CUdevice device;
CU_CHECK(cuDeviceGet(&device, id));
CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
if (device_vmm) {
CUmemAllocationProp alloc_prop = {};
alloc_prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
alloc_prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
alloc_prop.location.id = id;
CU_CHECK(cuMemGetAllocationGranularity(&g_device_caps[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
}
#endif // !defined(GGML_USE_HIPBLAS)
g_device_caps[id].vmm = !!device_vmm;
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
fprintf(stderr, " Device %d: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor);
fprintf(stderr, " Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
g_tensor_split[id] = total_vram;
total_vram += prop.totalGlobalMem;
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
g_compute_capabilities[id] = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
g_device_caps[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
#else
g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}
for (int id = 0; id < g_device_count; ++id) {
@ -7178,11 +7361,11 @@ static int64_t get_row_rounding(ggml_type type) {
int64_t max_compute_capability = INT_MIN;
for (int64_t id = 0; id < g_device_count; ++id) {
if (g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
if (min_compute_capability > g_compute_capabilities[id]) {
min_compute_capability = g_compute_capabilities[id];
if (min_compute_capability > g_device_caps[id].cc) {
min_compute_capability = g_device_caps[id].cc;
}
if (max_compute_capability < g_compute_capabilities[id]) {
max_compute_capability = g_compute_capabilities[id];
if (max_compute_capability < g_device_caps[id].cc) {
max_compute_capability = g_device_caps[id].cc;
}
}
}
@ -7297,8 +7480,8 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
#ifdef GGML_CUDA_F16
size_t ash;
dfloat * src1_dfloat = nullptr; // dfloat == half
cuda_pool_alloc<half> src1_dfloat_a;
half * src1_dfloat = nullptr; // dfloat == half
bool src1_convert_f16 =
src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
@ -7306,7 +7489,7 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
if (src1_convert_f16) {
src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
src1_dfloat = src1_dfloat_a.alloc(ne00);
ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00,
ne00, 1, sizeof(float), 0, 0,
ne00, 1, sizeof(half), 0, 0, stream);
@ -7354,12 +7537,6 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
break;
}
#ifdef GGML_CUDA_F16
if (src1_convert_f16) {
ggml_cuda_pool_free(src1_dfloat, ash);
}
#endif // GGML_CUDA_F16
(void) src1;
(void) dst;
(void) src1_ddq_i;
@ -7390,33 +7567,30 @@ inline void ggml_cuda_op_mul_mat_cublas(
// ldc == nrows of the matrix that cuBLAS writes into
int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
const int compute_capability = g_compute_capabilities[id];
const int compute_capability = g_device_caps[id].cc;
if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
half * src0_as_f16 = nullptr;
size_t src0_as = 0;
cuda_pool_alloc<half> src0_as_f16;
if (src0->type != GGML_TYPE_F16) {
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
GGML_ASSERT(to_fp16_cuda != nullptr);
size_t ne = row_diff*ne00;
src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
src0_as_f16.alloc(ne);
to_fp16_cuda(src0_dd_i, src0_as_f16.get(), ne, stream);
}
const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();
half * src1_as_f16 = nullptr;
size_t src1_as = 0;
cuda_pool_alloc<half> src1_as_f16;
if (src1->type != GGML_TYPE_F16) {
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
GGML_ASSERT(to_fp16_cuda != nullptr);
size_t ne = src1_ncols*ne10;
src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
src1_as_f16.alloc(ne);
to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
}
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
size_t dst_as = 0;
half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
cuda_pool_alloc<half> dst_f16(row_diff*src1_ncols);
const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;
@ -7427,34 +7601,23 @@ inline void ggml_cuda_op_mul_mat_cublas(
row_diff, src1_ncols, ne10,
&alpha_f16, src0_ptr, CUDA_R_16F, ne00,
src1_ptr, CUDA_R_16F, ne10,
&beta_f16, dst_f16, CUDA_R_16F, ldc,
&beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
CUBLAS_COMPUTE_16F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
ggml_cuda_pool_free(dst_f16, dst_as);
if (src0_as != 0) {
ggml_cuda_pool_free(src0_as_f16, src0_as);
}
if (src1_as != 0) {
ggml_cuda_pool_free(src1_as_f16, src1_as);
}
to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
}
else {
float * src0_ddq_as_f32 = nullptr;
size_t src0_as = 0;
cuda_pool_alloc<float> src0_ddq_as_f32;
if (src0->type != GGML_TYPE_F32) {
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
GGML_ASSERT(to_fp32_cuda != nullptr);
src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
src0_ddq_as_f32.alloc(row_diff*ne00);
to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
}
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
const float alpha = 1.0f;
const float beta = 0.0f;
@ -7466,10 +7629,6 @@ inline void ggml_cuda_op_mul_mat_cublas(
&alpha, src0_ddf_i, ne00,
src1_ddf_i, ne10,
&beta, dst_dd_i, ldc));
if (src0_as != 0) {
ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
}
}
(void) dst;
@ -7761,18 +7920,17 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
float * src1_ddf = nullptr;
float * dst_ddf = nullptr;
// as = actual size
size_t src0_asf = 0;
size_t src1_asf = 0;
size_t dst_asf = 0;
cuda_pool_alloc<float> src0_f;
cuda_pool_alloc<float> src1_f;
cuda_pool_alloc<float> dst_f;
ggml_cuda_set_device(g_main_device);
const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
if (src0_on_device) {
src0_ddf = (float *) src0_extra->data_device[g_main_device];
} else {
src0_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_asf);
src0_ddf = src0_f.alloc(ggml_nelements(src0));
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream));
}
@ -7780,14 +7938,14 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
if (src1_on_device) {
src1_ddf = (float *) src1_extra->data_device[g_main_device];
} else {
src1_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf);
src1_ddf = src1_f.alloc(ggml_nelements(src1));
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf, src1, 0, 0, 0, nrows1, main_stream));
}
}
if (dst_on_device) {
dst_ddf = (float *) dst_extra->data_device[g_main_device];
} else {
dst_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(dst), &dst_asf);
dst_ddf = dst_f.alloc(ggml_nelements(dst));
}
// do the computation
@ -7799,16 +7957,6 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
}
if (src0_asf > 0) {
ggml_cuda_pool_free(src0_ddf, src0_asf);
}
if (src1_asf > 0) {
ggml_cuda_pool_free(src1_ddf, src1_asf);
}
if (dst_asf > 0) {
ggml_cuda_pool_free(dst_ddf, dst_asf);
}
if (dst->backend == GGML_BACKEND_CPU) {
CUDA_CHECK(cudaDeviceSynchronize());
}
@ -8122,17 +8270,17 @@ static void ggml_cuda_op_mul_mat(
CUDA_CHECK(ggml_cuda_set_device(id));
// free buffers again when done
if (src0_as[id] > 0) {
ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
}
if (src1_asf[id] > 0) {
ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
if (dst_as[id] > 0) {
ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
}
if (src1_asq[id] > 0) {
ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
}
if (dst_as[id] > 0) {
ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
if (src1_asf[id] > 0) {
ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
}
if (src0_as[id] > 0) {
ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
}
}
@ -8385,14 +8533,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
GGML_ASSERT(to_fp16_cuda != nullptr);
size_t src1_as = 0;
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
cuda_pool_alloc<half> src1_as_f16(ne1);
to_fp16_cuda(src1_ddf, src1_as_f16.get(), ne1, main_stream);
size_t dst_as = 0;
half * dst_f16 = nullptr;
char * dst_t = nullptr;
cuda_pool_alloc<half> dst_f16;
char * dst_t;
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
cudaDataType_t cu_data_type = CUDA_R_16F;
@ -8411,8 +8556,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
const void * beta = &beta_f16;
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
dst_t = (char *) dst_f16;
dst_t = (char *) dst_f16.alloc(ne);
nbd2 /= sizeof(float) / sizeof(half);
nbd3 /= sizeof(float) / sizeof(half);
@ -8460,7 +8604,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
(const char *) src1_as_f16.get(), CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
beta, ( char *) dst_t, cu_data_type, ne01, dst->nb[2]/sizeof(float), // strideC
ne12*ne13,
cu_compute_type,
@ -8469,19 +8613,13 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
// use cublasGemmBatchedEx
const int ne23 = ne12*ne13;
const void ** ptrs_src = nullptr;
void ** ptrs_dst = nullptr;
size_t ptrs_src_s = 0;
size_t ptrs_dst_s = 0;
ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
cuda_pool_alloc<const void *> ptrs_src(2*ne23);
cuda_pool_alloc< void *> ptrs_dst(1*ne23);
dim3 block_dims(ne13, ne12);
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
src0_as_f16, src1_as_f16, dst_t,
ptrs_src, ptrs_dst,
src0_as_f16, src1_as_f16.get(), dst_t,
ptrs_src.get(), ptrs_dst.get(),
ne12, ne13,
ne23,
nb02, nb03,
@ -8493,30 +8631,19 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
CUBLAS_CHECK(
cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
beta, ( void **) (ptrs_dst + 0*ne23), cu_data_type, ne01,
alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
(const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01,
ne23,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
if (ptrs_src_s != 0) {
ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
}
if (ptrs_dst_s != 0) {
ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
}
}
#endif
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
ggml_cuda_pool_free(dst_f16, dst_as);
to_fp32_cuda(dst_f16.get(), dst_ddf, ne, main_stream);
}
ggml_cuda_pool_free(src1_as_f16, src1_as);
}
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@ -8529,8 +8656,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
int64_t min_compute_capability = INT_MAX;
for (int64_t id = 0; id < g_device_count; ++id) {
if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
min_compute_capability = g_compute_capabilities[id];
if (min_compute_capability > g_device_caps[id].cc && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
min_compute_capability = g_device_caps[id].cc;
}
}
@ -8843,12 +8970,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
}
} else {
size_t as_src1, as_dst;
char * src1_contiguous = (char *) ggml_cuda_pool_malloc(sizeof(float)*ggml_nelements(src1), &as_src1);
char * dst_contiguous = (char *) ggml_cuda_pool_malloc(sizeof(float)*ggml_nelements(dst), &as_dst);
cuda_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
cuda_pool_alloc<char> dst_contiguous(sizeof(float)*ggml_nelements(dst));
src1_row_extra.data_device[g_main_device] = src1_contiguous;
dst_row_extra.data_device[g_main_device] = dst_contiguous;
src1_row_extra.data_device[g_main_device] = src1_contiguous.get();
dst_row_extra.data_device[g_main_device] = dst_contiguous.get();
const cudaMemcpyKind src1_kind = src1->backend == GGML_BACKEND_CPU ?
cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
@ -8868,7 +8994,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
GGML_ASSERT(row_id >= 0 && row_id < n_as);
CUDA_CHECK(cudaMemcpyAsync(src1_contiguous + num_src1_rows*nb11, src1_original + i01*nb11,
CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11,
nb11, src1_kind, stream));
num_src1_rows++;
}
@ -8900,14 +9026,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
GGML_ASSERT(row_id >= 0 && row_id < n_as);
CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous + num_src1_rows*nb1,
CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1,
nb1, dst_kind, stream));
num_src1_rows++;
}
}
ggml_cuda_pool_free(src1_contiguous, as_src1);
ggml_cuda_pool_free(dst_contiguous, as_dst);
}
if (dst->backend == GGML_BACKEND_CPU) {
@ -9678,8 +9801,10 @@ static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buff
static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
void * ptr = ggml_cuda_host_malloc(size);
if (ptr == nullptr) {
return nullptr;
// fallback to cpu buffer
return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
}
// FIXME: this is a hack to avoid having to implement a new buffer type

2
ggml.c
View File

@ -19351,7 +19351,7 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data;
}
gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n);
free(data);
free((void *)data);
} else if (src->kv[i].value.arr.type == GGUF_TYPE_ARRAY) {
GGML_ASSERT(false && "nested arrays not supported");
} else {

2
ggml.h
View File

@ -255,6 +255,8 @@
#define GGML_UNREACHABLE() GGML_ASSERT(!"statement should not be reached")
#elif defined(__GNUC__)
#define GGML_UNREACHABLE() __builtin_unreachable()
#elif defined(_MSC_VER)
#define GGML_UNREACHABLE() __assume(0)
#else
#define GGML_UNREACHABLE() ((void) 0)
#endif

View File

@ -1281,7 +1281,7 @@ struct llama_hparams {
if (this->rope_finetuned != other.rope_finetuned) return true;
if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
const float EPSILON = 1e-9;
const float EPSILON = 1e-9f;
if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true;
if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true;
@ -10300,7 +10300,7 @@ int llama_token_to_piece(const struct llama_model * model, llama_token token, ch
std::string result = model->vocab.id_to_token[token].text;
llama_unescape_whitespace(result);
if (length < (int) result.length()) {
return -result.length();
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
@ -10330,7 +10330,7 @@ int llama_token_to_piece(const struct llama_model * model, llama_token token, ch
std::string result = model->vocab.id_to_token[token].text;
result = llama_decode_text(result);
if (length < (int) result.length()) {
return -result.length();
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();

View File

@ -883,9 +883,6 @@ int main(int argc, const char ** argv) {
srand(seed);
const int nargs = 1;
int64_t ne2[4];
ne2[0] = 1;
for (int ndims = 1; ndims <= 2; ++ndims) {
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);