mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 05:48:47 +01:00
fallback to CPU buffer if host buffer alloc fails (#4610)
This commit is contained in:
parent
925e5584a0
commit
708e179e85
11
ggml-cuda.cu
11
ggml-cuda.cu
@ -6729,8 +6729,7 @@ void * ggml_cuda_host_malloc(size_t size) {
|
|||||||
void * ptr = nullptr;
|
void * ptr = nullptr;
|
||||||
cudaError_t err = cudaMallocHost((void **) &ptr, size);
|
cudaError_t err = cudaMallocHost((void **) &ptr, size);
|
||||||
if (err != cudaSuccess) {
|
if (err != cudaSuccess) {
|
||||||
// The allocation error can be bypassed. A null ptr will assigned out of this function.
|
// clear the error
|
||||||
// This can fixed the OOM error in WSL.
|
|
||||||
cudaGetLastError();
|
cudaGetLastError();
|
||||||
fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
|
fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
|
||||||
size/1024.0/1024.0, cudaGetErrorString(err));
|
size/1024.0/1024.0, cudaGetErrorString(err));
|
||||||
@ -9674,12 +9673,14 @@ ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
|
|||||||
// host buffer type
|
// host buffer type
|
||||||
|
|
||||||
static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
CUDA_CHECK(cudaFreeHost(buffer->context));
|
ggml_cuda_host_free(buffer->context);
|
||||||
}
|
}
|
||||||
|
|
||||||
static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||||
void * ptr;
|
void * ptr = ggml_cuda_host_malloc(size);
|
||||||
CUDA_CHECK(cudaMallocHost(&ptr, size));
|
if (ptr == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
// FIXME: this is a hack to avoid having to implement a new buffer type
|
// FIXME: this is a hack to avoid having to implement a new buffer type
|
||||||
ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
|
ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
|
||||||
|
16
llama.cpp
16
llama.cpp
@ -1177,21 +1177,27 @@ static std::string llama_token_to_piece(const struct llama_context * ctx, llama_
|
|||||||
}
|
}
|
||||||
|
|
||||||
static ggml_backend_buffer_type_t llama_default_buffer_type(int n_gpu_layers) {
|
static ggml_backend_buffer_type_t llama_default_buffer_type(int n_gpu_layers) {
|
||||||
|
ggml_backend_buffer_type_t buft = nullptr;
|
||||||
|
|
||||||
#ifdef GGML_USE_METAL
|
#ifdef GGML_USE_METAL
|
||||||
if (n_gpu_layers > 0) {
|
if (n_gpu_layers > 0) {
|
||||||
return ggml_backend_metal_buffer_type();
|
buft = ggml_backend_metal_buffer_type();
|
||||||
}
|
}
|
||||||
#elif defined(GGML_USE_CUBLAS) && defined(LLAMA_GGML_BACKEND_CUDA_TEST)
|
#elif defined(GGML_USE_CUBLAS) && defined(LLAMA_GGML_BACKEND_CUDA_TEST)
|
||||||
if (n_gpu_layers > 0) {
|
if (n_gpu_layers > 0) {
|
||||||
return ggml_backend_cuda_buffer_type(0);
|
buft = ggml_backend_cuda_buffer_type(0);
|
||||||
}
|
}
|
||||||
#elif defined(GGML_USE_CUBLAS)
|
#elif defined(GGML_USE_CUBLAS)
|
||||||
return ggml_backend_cuda_host_buffer_type();
|
buft = ggml_backend_cuda_host_buffer_type();
|
||||||
#elif defined(GGML_USE_CPU_HBM)
|
#elif defined(GGML_USE_CPU_HBM)
|
||||||
return ggml_backend_cpu_hbm_buffer_type();
|
buft = ggml_backend_cpu_hbm_buffer_type();
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
return ggml_backend_cpu_buffer_type();
|
if (buft == nullptr) {
|
||||||
|
buft = ggml_backend_cpu_buffer_type();
|
||||||
|
}
|
||||||
|
|
||||||
|
return buft;
|
||||||
|
|
||||||
GGML_UNUSED(n_gpu_layers);
|
GGML_UNUSED(n_gpu_layers);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user