ggml : add backend registry / device interfaces to BLAS backend (#9752)

* ggml : add backend registry / device interfaces to BLAS backend
* fix mmap usage when using host buffers

commit 6374743747 (parent f1af42fa8c)
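
As a rough usage sketch (not part of the commit), this is how a caller can drive the new generic `ggml_backend_set_n_threads` entry point exposed through the registry; the helper name `set_n_threads_all` is illustrative, but every API it calls appears in the diff below:

    // Sketch: set the thread count on every backend that exposes the generic
    // "ggml_backend_set_n_threads" entry point through its registry.
    #include "ggml-backend.h"
    #include <vector>

    static void set_n_threads_all(const std::vector<ggml_backend_t> & backends, int n_threads) {
        for (ggml_backend_t backend : backends) {
            ggml_backend_dev_t dev = ggml_backend_get_device(backend);
            ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
            if (!reg) {
                continue;
            }
            auto set_n_threads_fn = (ggml_backend_set_n_threads_t)
                ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
            if (set_n_threads_fn) {
                set_n_threads_fn(backend, n_threads);
            }
        }
    }

This is the same pattern the commit introduces in `llama_graph_compute` and in llama-bench, replacing the backend-specific `ggml_backend_cpu_set_n_threads` / `ggml_backend_blas_set_n_threads` calls.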
@@ -170,6 +170,7 @@ extern "C" {
 
     // Functions that may be obtained using ggml_backend_reg_get_proc_address
     typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(const float *);
+    typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t, int);
 
     //
     // Backend registry
@@ -17,6 +17,8 @@ GGML_API bool ggml_backend_is_blas(ggml_backend_t backend);
 // for openblas and blis, this will also set the number of threads used for blas operations
 GGML_API void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);
 
+GGML_API ggml_backend_reg_t ggml_backend_blas_reg(void);
+
 #ifdef __cplusplus
 }
@@ -190,22 +190,24 @@ if (GGML_BLAS)
         # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268
         find_package(PkgConfig REQUIRED)
         if (${GGML_BLAS_VENDOR} MATCHES "Generic")
-            pkg_check_modules(DepBLAS REQUIRED blas)
+            pkg_check_modules(DepBLAS blas)
         elseif (${GGML_BLAS_VENDOR} MATCHES "OpenBLAS")
             # As of openblas v0.3.22, the 64-bit is named openblas64.pc
             pkg_check_modules(DepBLAS openblas64)
             if (NOT DepBLAS_FOUND)
-                pkg_check_modules(DepBLAS REQUIRED openblas)
+                pkg_check_modules(DepBLAS openblas)
             endif()
         elseif (${GGML_BLAS_VENDOR} MATCHES "FLAME")
-            pkg_check_modules(DepBLAS REQUIRED blis)
+            add_compile_definitions(GGML_BLAS_USE_BLIS)
+            pkg_check_modules(DepBLAS blis)
         elseif (${GGML_BLAS_VENDOR} MATCHES "ATLAS")
-            pkg_check_modules(DepBLAS REQUIRED blas-atlas)
+            pkg_check_modules(DepBLAS blas-atlas)
         elseif (${GGML_BLAS_VENDOR} MATCHES "FlexiBLAS")
-            pkg_check_modules(DepBLAS REQUIRED flexiblas_api)
+            pkg_check_modules(DepBLAS flexiblas_api)
        elseif (${GGML_BLAS_VENDOR} MATCHES "Intel")
+            add_compile_definitions(GGML_BLAS_USE_MKL)
             # all Intel* libraries share the same include path
-            pkg_check_modules(DepBLAS REQUIRED mkl-sdl)
+            pkg_check_modules(DepBLAS mkl-sdl)
         elseif (${GGML_BLAS_VENDOR} MATCHES "NVHPC")
             # this doesn't provide pkg-config
             # suggest to assign BLAS_INCLUDE_DIRS on your own
@@ -88,6 +88,7 @@ extern "C" {
 
        void (*free)(ggml_backend_t backend);
 
+        // Will be moved to the device interface
        // buffer allocation
        ggml_backend_buffer_type_t (*get_default_buffer_type)(ggml_backend_t backend);
 
@@ -112,17 +113,9 @@ extern "C" {
 
-        // These functions are being moved to the device interface
-        // check if the backend can compute an operation
+        // IMPORTANT: these functions have been moved to the device interface and will be removed from the backend interface
+        // new backends should implement the device interface instead
        bool (*supports_op) (ggml_backend_t backend, const struct ggml_tensor * op);
 
-        // check if the backend can use tensors allocated in a buffer type
        bool (*supports_buft)(ggml_backend_t backend, ggml_backend_buffer_type_t buft);
 
-        // check if the backend wants to run an operation, even if the weights are allocated in a CPU buffer
-        // these should be expensive operations with large batch sizes that may benefit from running on this backend
-        // even if the weight has to be copied from the CPU temporarily
        bool (*offload_op) (ggml_backend_t backend, const struct ggml_tensor * op);
 
        // (optional) event synchronization
@@ -184,9 +177,8 @@ extern "C" {
        // check if the backend can use tensors allocated in a buffer type
        bool (*supports_buft)(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft);
 
-        // check if the backend wants to run an operation, even if the weights are allocated in a CPU buffer
-        // these should be expensive operations with large batch sizes that may benefit from running on this backend
-        // even if the weight has to be copied from the CPU temporarily
+        // (optional) check if the backend wants to run an operation, even if the weights are allocated in an incompatible buffer
+        // these should be expensive operations that may benefit from running on this backend instead of the CPU backend
        bool (*offload_op)(ggml_backend_dev_t dev, const struct ggml_tensor * op);
 
        // (optional) event synchronization
@@ -500,7 +500,11 @@ bool ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buff
 }
 
 bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op) {
-    return device->iface.offload_op(device, op);
+    if (device->iface.offload_op != NULL) {
+        return device->iface.offload_op(device, op);
+    }
+
+    return false;
 }
 
 // Backend (reg)
@@ -534,6 +538,10 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
 #include "ggml-metal.h"
 #endif
 
+#ifdef GGML_USE_BLAS
+#include "ggml-blas.h"
+#endif
+
 struct ggml_backend_registry {
     std::vector<ggml_backend_reg_t> backends;
     std::vector<ggml_backend_dev_t> devices;
@@ -545,10 +553,13 @@ struct ggml_backend_registry {
 #ifdef GGML_USE_METAL
         register_backend(ggml_backend_metal_reg());
 #endif
 
-        register_backend(ggml_backend_cpu_reg());
+#ifdef GGML_USE_BLAS
+        register_backend(ggml_backend_blas_reg());
+#endif
 
         // TODO: sycl, vulkan, kompute, cann
 
+        register_backend(ggml_backend_cpu_reg());
     }
 
     void register_backend(ggml_backend_reg_t reg) {
@@ -1229,16 +1240,22 @@ static ggml_backend_dev_t ggml_backend_cpu_reg_get_device(ggml_backend_reg_t reg
     };
 
     return &ggml_backend_cpu_device;
+}
+
+static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (strcmp(name, "ggml_backend_set_n_threads") == 0) {
+        return (void *)ggml_backend_cpu_set_n_threads;
+    }
+    return NULL;
 
     GGML_UNUSED(reg);
-    GGML_UNUSED(index);
 }
 
 static const struct ggml_backend_reg_i ggml_backend_cpu_reg_i = {
     /* .get_name         = */ ggml_backend_cpu_reg_get_name,
     /* .get_device_count = */ ggml_backend_cpu_reg_get_device_count,
     /* .get_device       = */ ggml_backend_cpu_reg_get_device,
-    /* .get_proc_address = */ NULL,
+    /* .get_proc_address = */ ggml_backend_cpu_get_proc_address,
 };
 
 ggml_backend_reg_t ggml_backend_cpu_reg(void) {
@@ -4,6 +4,7 @@
 
 #include <future>
 #include <vector>
+#include <cstring>
 
 #if defined(GGML_USE_ACCELERATE)
 #   include <Accelerate/Accelerate.h>
@@ -26,30 +27,6 @@ struct ggml_backend_blas_context {
 #endif
 };
 
-// helper function to determine if it is better to use BLAS or not
-// for large matrices, BLAS is faster
-static bool ggml_backend_blas_use_blas(const struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    const int64_t ne10 = src1->ne[0];
-
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-
-    // TODO: find the optimal values for these
-    if (ggml_is_contiguous(src0) &&
-        ggml_is_contiguous(src1) &&
-        src1->type == GGML_TYPE_F32 &&
-        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
-
-        /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
-        return true;
-    }
-
-    return false;
-}
-
 static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
@@ -235,7 +212,7 @@ static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct g
 
 // backend interface
 
-static const char * ggml_backend_blas_name(ggml_backend_t backend) {
+static const char * ggml_backend_blas_get_name(ggml_backend_t backend) {
     return "BLAS";
 
     GGML_UNUSED(backend);
@@ -285,29 +262,8 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend,
     GGML_UNUSED(backend);
 }
 
-static bool ggml_backend_blas_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    const struct ggml_tensor * src0 = op->src[0];
-    const struct ggml_tensor * src1 = op->src[1];
-
-    return (op->op == GGML_OP_MUL_MAT && ggml_backend_blas_use_blas(op)) ||
-           (op->op == GGML_OP_OUT_PROD && op->src[0]->type == GGML_TYPE_F32 &&
-            op->src[1]->type == GGML_TYPE_F32 &&
-            ggml_is_matrix(src0) &&
-            ggml_is_matrix(src1) &&
-            ggml_is_contiguous(src0) &&
-            (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));
-
-    GGML_UNUSED(backend);
-}
-
-static bool ggml_backend_blas_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-    return ggml_backend_buft_is_host(buft);
-
-    GGML_UNUSED(backend);
-}
-
 static struct ggml_backend_i blas_backend_i = {
-    /* .get_name                = */ ggml_backend_blas_name,
+    /* .get_name                = */ ggml_backend_blas_get_name,
     /* .free                    = */ ggml_backend_blas_free,
     /* .get_default_buffer_type = */ ggml_backend_blas_get_default_buffer_type,
     /* .set_tensor_async        = */ NULL,
@@ -319,8 +275,8 @@ static struct ggml_backend_i blas_backend_i = {
     /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_blas_graph_compute,
-    /* .supports_op             = */ ggml_backend_blas_supports_op,
-    /* .supports_buft           = */ ggml_backend_blas_supports_buft,
+    /* .supports_op             = */ NULL,
+    /* .supports_buft           = */ NULL,
     /* .offload_op              = */ NULL,
     /* .event_record            = */ NULL,
     /* .event_wait              = */ NULL,
@@ -337,7 +293,7 @@ ggml_backend_t ggml_backend_blas_init(void) {
     ggml_backend_t backend = new ggml_backend {
         /* .guid      = */ ggml_backend_blas_guid(),
         /* .interface = */ blas_backend_i,
-        /* .device    = */ nullptr,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
         /* .context   = */ ctx,
     };
 
@@ -364,3 +320,203 @@ void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads)
     ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context;
     ctx->n_threads = n_threads;
 }
+
+// device interface
+
+static const char * ggml_backend_blas_device_get_name(ggml_backend_dev_t dev) {
+    return "BLAS";
+
+    GGML_UNUSED(dev);
+}
+
+static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t dev) {
+#if defined(GGML_USE_ACCELERATE)
+    return "Accelerate";
+#elif defined(GGML_BLAS_USE_MKL)
+    return "MKL";
+#elif defined(GGML_BLAS_USE_BLIS)
+    return "BLIS";
+#elif defined(GGML_BLAS_USE_NVPL)
+    return "NVPL";
+#elif defined(OPENBLAS_VERSION)
+    return "OpenBLAS";
+#else
+    return "BLAS";
+#endif
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    // TODO
+    *free = 0;
+    *total = 0;
+
+    GGML_UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_blas_device_get_type(ggml_backend_dev_t dev) {
+    return GGML_BACKEND_DEVICE_TYPE_CPU;
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_blas_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_blas_device_get_name(dev);
+    props->description = ggml_backend_blas_device_get_description(dev);
+    props->type        = ggml_backend_blas_device_get_type(dev);
+    ggml_backend_blas_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* .async                = */ false,
+        /* .host_buffer          = */ false,
+        /* .buffer_from_host_ptr = */ true,
+        /* .events               = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_blas_device_init(ggml_backend_dev_t dev, const char * params) {
+    return ggml_backend_blas_init();
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_blas_device_get_buffer_type(ggml_backend_dev_t dev) {
+    return ggml_backend_cpu_buffer_type();
+
+    GGML_UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_blas_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    return ggml_backend_cpu_buffer_from_ptr(ptr, size);
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(max_tensor_size);
+}
+
+static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+
+    switch (op->op) {
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            return true;
+
+        case GGML_OP_MUL_MAT:
+        {
+            // BLAS usually is only faster for large matrices
+            const struct ggml_tensor * src0 = op->src[0];
+            const struct ggml_tensor * src1 = op->src[1];
+
+            const int64_t ne10 = src1->ne[0];
+
+            const int64_t ne0 = op->ne[0];
+            const int64_t ne1 = op->ne[1];
+
+            // TODO: find the optimal value
+            const int64_t min_batch = 32;
+
+            return (ggml_is_contiguous(src0) &&
+                    ggml_is_contiguous(src1) &&
+                    src1->type == GGML_TYPE_F32 &&
+                    (ne0 >= min_batch && ne1 >= min_batch && ne10 >= min_batch));
+        }
+
+        case GGML_OP_OUT_PROD:
+            return (op->src[0]->type == GGML_TYPE_F32 &&
+                    op->src[1]->type == GGML_TYPE_F32 &&
+                    ggml_is_matrix(src0) &&
+                    ggml_is_matrix(src1) &&
+                    ggml_is_contiguous(src0) &&
+                    (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));
+
+        default:
+            return false;
+
+    }
+
+    GGML_UNUSED(dev);
+}
+
+static bool ggml_backend_blas_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    return ggml_backend_buft_is_host(buft);
+
+    GGML_UNUSED(dev);
+}
+
+static const struct ggml_backend_device_i ggml_backend_blas_device_i = {
+    /* .get_name             = */ ggml_backend_blas_device_get_name,
+    /* .get_description      = */ ggml_backend_blas_device_get_description,
+    /* .get_memory           = */ ggml_backend_blas_device_get_memory,
+    /* .get_type             = */ ggml_backend_blas_device_get_type,
+    /* .get_props            = */ ggml_backend_blas_device_get_props,
+    /* .init_backend         = */ ggml_backend_blas_device_init,
+    /* .get_buffer_type      = */ ggml_backend_blas_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_blas_device_buffer_from_ptr,
+    /* .supports_op          = */ ggml_backend_blas_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_blas_device_supports_buft,
+    /* .offload_op           = */ NULL,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// backend reg interface
+
+static const char * ggml_backend_blas_reg_get_name(ggml_backend_reg_t reg) {
+    return "BLAS";
+
+    GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_blas_reg_get_device_count(ggml_backend_reg_t reg) {
+    return 1;
+
+    GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_blas_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    GGML_ASSERT(index == 0);
+
+    static ggml_backend_device ggml_backend_blas_device = {
+        /* .iface   = */ ggml_backend_blas_device_i,
+        /* .reg     = */ reg,
+        /* .context = */ nullptr,
+    };
+
+    return &ggml_backend_blas_device;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(index);
+}
+
+static void * ggml_backend_blas_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
+        return (void *)ggml_backend_blas_set_n_threads;
+    }
+    return NULL;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(name);
+}
+
+static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = {
+    /* .get_name         = */ ggml_backend_blas_reg_get_name,
+    /* .get_device_count = */ ggml_backend_blas_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_blas_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_blas_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_blas_reg(void) {
+    static struct ggml_backend_reg ggml_backend_blas_reg = {
+        /* .iface   = */ ggml_backend_blas_reg_i,
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_blas_reg;
+}
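
With the registration above, the BLAS backend can also be constructed through the generic device API instead of calling ggml_backend_blas_init() directly. A minimal sketch (not part of the commit; assumes a build where the BLAS backend is registered):

    // Sketch: create the BLAS backend through the new registry/device path
    // (equivalent to calling ggml_backend_blas_init() directly).
    ggml_backend_reg_t reg     = ggml_backend_blas_reg();
    ggml_backend_dev_t dev     = ggml_backend_reg_dev_get(reg, 0); // the single "BLAS" device
    ggml_backend_t     backend = ggml_backend_dev_init(dev, /*params =*/ nullptr);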
@@ -22,10 +22,6 @@
 #  include "ggml-cann.h"
 #endif
 
-#ifdef GGML_USE_BLAS
-#  include "ggml-blas.h"
-#endif
-
 // TODO: replace with ggml API call
 #define QK_K 256
 
@@ -3288,9 +3284,8 @@ struct llama_context {
     std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
 
     std::vector<ggml_backend_t> backends;
-#ifdef GGML_USE_BLAS
-    ggml_backend_t backend_blas = nullptr;
-#endif
+    std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
 
     ggml_backend_t backend_cpu = nullptr;
 
     ggml_threadpool_t threadpool = nullptr;
@@ -8908,7 +8903,8 @@ static bool llm_load_tensors(
             bufs.reserve(n_max_backend_buffer);
 
             // check if this backend device supports buffer_from_host_ptr
-            ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft);
+            // when using a host buffer as the CPU backend buffer, use the CPU device to prioritize using buffer_from_host_ptr over the host buffer
+            ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft == llama_default_buffer_type_cpu(model, true) ? ggml_backend_cpu_buffer_type() : buft);
             bool buffer_from_host_ptr_supported = false;
             if (dev) {
                 ggml_backend_dev_props props;
@@ -17048,17 +17044,19 @@ static void llama_graph_compute(
                     int   n_threads,
         ggml_threadpool * threadpool) {
     if (lctx.backend_cpu != nullptr) {
-        ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
         ggml_backend_cpu_set_threadpool(lctx.backend_cpu, threadpool);
         ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data);
     }
-#ifdef GGML_USE_BLAS
-    if (lctx.backend_blas != nullptr) {
-        ggml_backend_blas_set_n_threads(lctx.backend_blas, n_threads);
-    }
-#endif
 
-    ggml_backend_sched_graph_compute_async(lctx.sched, gf);
+    // set the number of threads for all the backends
+    for (const auto & set_n_threads_fn : lctx.set_n_threads_fns) {
+        set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
+    }
+
+    auto err = ggml_backend_sched_graph_compute_async(lctx.sched, gf);
+    if (err != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, err);
+    }
 
     // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
 }
@@ -19110,9 +19108,16 @@ struct llama_model * llama_load_model_from_file(
     // TODO: rework API to give user more control over device selection
     for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
         ggml_backend_dev_t dev = ggml_backend_dev_get(i);
-        // skip the CPU backend since it is handled separately
-        if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU_FULL) {
-            model->devices.push_back(dev);
+        switch (ggml_backend_dev_type(dev)) {
+            case GGML_BACKEND_DEVICE_TYPE_CPU:
+            case GGML_BACKEND_DEVICE_TYPE_CPU_FULL:
+                // skip CPU backends since they are handled separately
+                break;
+
+            case GGML_BACKEND_DEVICE_TYPE_GPU:
+            case GGML_BACKEND_DEVICE_TYPE_GPU_FULL:
+                model->devices.push_back(dev);
+                break;
         }
     }
 
@@ -19407,14 +19412,19 @@ struct llama_context * llama_new_context_with_model(
        }
 #endif
 
-#ifdef GGML_USE_BLAS
-    ctx->backend_blas = ggml_backend_blas_init();
-    if (ctx->backend_blas == nullptr) {
-        LLAMA_LOG_WARN("%s: failed to initialize BLAS backend\n", __func__);
-    } else {
-        ctx->backends.push_back(ctx->backend_blas);
+    // add other backends (such as BLAS)
+    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+        if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) {
+            ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+            if (backend == nullptr) {
+                LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+                llama_free(ctx);
+                return nullptr;
+            }
+            ctx->backends.push_back(backend);
+        }
     }
-#endif
 
     ctx->backend_cpu = ggml_backend_cpu_init();
     if (ctx->backend_cpu == nullptr) {
@@ -19424,6 +19434,18 @@ struct llama_context * llama_new_context_with_model(
     }
     ctx->backends.push_back(ctx->backend_cpu);
 
+    // create a list of the set_n_threads functions in the backends
+    for (auto * backend : ctx->backends) {
+        ggml_backend_dev_t dev = ggml_backend_get_device(backend);
+        ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+        if (reg) {
+            auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+            if (ggml_backend_set_n_threads_fn) {
+                ctx->set_n_threads_fns.emplace_back(backend, ggml_backend_set_n_threads_fn);
+            }
+        }
+    }
+
     if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
         LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
         llama_free(ctx);
@@ -3820,9 +3820,11 @@ int main(int argc, char ** argv) {
                 continue;
             }
 
-            if (ggml_backend_is_cpu(backend)) {
+            ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
+            auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+            if (ggml_backend_set_n_threads_fn) {
                 // TODO: better value for n_threads
-                ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency() / 2);
+                ggml_backend_set_n_threads_fn(backend, std::thread::hardware_concurrency());
             }
 
             printf(" Device description: %s\n", ggml_backend_dev_description(dev));