mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 21:37:19 +01:00
Direct I/O and Transparent HugePages
--direct-io for bypassing page cache (and using THP on Linux) Up to 3-6x faster uncached loading, fewer pageouts, no page cache pollution.
This commit is contained in:
parent
917dc8cfa6
commit
1b17ed7ab6
@ -1072,6 +1072,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||
params.use_mmap = false;
|
||||
return true;
|
||||
}
|
||||
if (arg == "--direct-io") {
|
||||
params.use_direct_io = true;
|
||||
return true;
|
||||
}
|
||||
if (arg == "--numa") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
@ -1544,6 +1548,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
||||
if (llama_supports_mmap()) {
|
||||
printf(" --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
|
||||
}
|
||||
if (llama_supports_direct_io()) {
|
||||
printf(" --direct-io use direct I/O (potentially faster uncached loading, fewer pageouts, no page cache pollution)\n");
|
||||
}
|
||||
printf(" --numa TYPE attempt optimizations that help on some NUMA systems\n");
|
||||
printf(" - distribute: spread execution evenly over all nodes\n");
|
||||
printf(" - isolate: only spawn threads on CPUs on the node that execution started on\n");
|
||||
@ -1844,6 +1851,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
|
||||
mparams.split_mode = params.split_mode;
|
||||
mparams.tensor_split = params.tensor_split;
|
||||
mparams.use_mmap = params.use_mmap;
|
||||
mparams.use_direct_io = params.use_direct_io;
|
||||
mparams.use_mlock = params.use_mlock;
|
||||
mparams.check_tensors = params.check_tensors;
|
||||
if (params.kv_overrides.empty()) {
|
||||
@ -2706,6 +2714,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
|
||||
fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict);
|
||||
fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs);
|
||||
fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false");
|
||||
fprintf(stream, "direct-io: %s # default: false\n", params.use_direct_io ? "true" : "false");
|
||||
fprintf(stream, "penalize_nl: %s # default: false\n", sparams.penalize_nl ? "true" : "false");
|
||||
fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
|
||||
fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
|
||||
|
@ -160,6 +160,7 @@ struct gpt_params {
|
||||
bool instruct = false; // instruction mode (used for Alpaca models)
|
||||
bool logits_all = false; // return logits for all tokens in the batch
|
||||
bool use_mmap = true; // use mmap for faster loads
|
||||
bool use_direct_io = false; // use direct I/O
|
||||
bool use_mlock = false; // use mlock to keep model in memory
|
||||
bool verbose_prompt = false; // print prompt tokens before generation
|
||||
bool display_prompt = true; // print prompt before generation
|
||||
|
@ -38,6 +38,7 @@ options:
|
||||
-nkvo, --no-kv-offload <0|1> (default: 0)
|
||||
-fa, --flash-attn <0|1> (default: 0)
|
||||
-mmp, --mmap <0|1> (default: 1)
|
||||
-dio, --direct-io <0|1> (default: 0)
|
||||
--numa <distribute|isolate|numactl> (default: disabled)
|
||||
-embd, --embeddings <0|1> (default: 0)
|
||||
-ts, --tensor-split <ts0/ts1/..> (default: 0)
|
||||
|
@ -184,6 +184,7 @@ struct cmd_params {
|
||||
std::vector<bool> flash_attn;
|
||||
std::vector<std::vector<float>> tensor_split;
|
||||
std::vector<bool> use_mmap;
|
||||
std::vector<bool> use_direct_io;
|
||||
std::vector<bool> embeddings;
|
||||
ggml_numa_strategy numa;
|
||||
int reps;
|
||||
@ -208,6 +209,7 @@ static const cmd_params cmd_params_defaults = {
|
||||
/* flash_attn */ {false},
|
||||
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
|
||||
/* use_mmap */ {true},
|
||||
/* use_direct_io */ {false},
|
||||
/* embeddings */ {false},
|
||||
/* numa */ GGML_NUMA_STRATEGY_DISABLED,
|
||||
/* reps */ 5,
|
||||
@ -235,6 +237,7 @@ static void print_usage(int /* argc */, char ** argv) {
|
||||
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
|
||||
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
|
||||
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
|
||||
printf(" -dio, --direct-io <0|1> (default: %s)\n", join(cmd_params_defaults.use_direct_io, ",").c_str());
|
||||
printf(" --numa <distribute|isolate|numactl> (default: disabled)\n");
|
||||
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
|
||||
printf(" -ts, --tensor-split <ts0/ts1/..> (default: 0)\n");
|
||||
@ -444,6 +447,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||
}
|
||||
auto p = split<bool>(argv[i], split_delim);
|
||||
params.use_mmap.insert(params.use_mmap.end(), p.begin(), p.end());
|
||||
} else if (arg == "-dio" || arg == "--direct-io") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
auto p = split<bool>(argv[i], split_delim);
|
||||
params.use_direct_io.insert(params.use_direct_io.end(), p.begin(), p.end());
|
||||
} else if (arg == "-embd" || arg == "--embeddings") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
@ -525,6 +535,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||
if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; }
|
||||
if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
|
||||
if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
|
||||
if (params.use_direct_io.empty()){ params.use_direct_io = cmd_params_defaults.use_direct_io; }
|
||||
if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }
|
||||
if (params.n_threads.empty()) { params.n_threads = cmd_params_defaults.n_threads; }
|
||||
|
||||
@ -547,6 +558,7 @@ struct cmd_params_instance {
|
||||
bool flash_attn;
|
||||
std::vector<float> tensor_split;
|
||||
bool use_mmap;
|
||||
bool use_direct_io;
|
||||
bool embeddings;
|
||||
|
||||
llama_model_params to_llama_mparams() const {
|
||||
@ -557,6 +569,7 @@ struct cmd_params_instance {
|
||||
mparams.main_gpu = main_gpu;
|
||||
mparams.tensor_split = tensor_split.data();
|
||||
mparams.use_mmap = use_mmap;
|
||||
mparams.use_direct_io = use_direct_io;
|
||||
|
||||
return mparams;
|
||||
}
|
||||
@ -567,6 +580,7 @@ struct cmd_params_instance {
|
||||
split_mode == other.split_mode &&
|
||||
main_gpu == other.main_gpu &&
|
||||
use_mmap == other.use_mmap &&
|
||||
use_direct_io == other.use_direct_io &&
|
||||
tensor_split == other.tensor_split;
|
||||
}
|
||||
|
||||
@ -596,6 +610,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
for (const auto & mg : params.main_gpu)
|
||||
for (const auto & ts : params.tensor_split)
|
||||
for (const auto & mmp : params.use_mmap)
|
||||
for (const auto & dio : params.use_direct_io)
|
||||
for (const auto & embd : params.embeddings)
|
||||
for (const auto & nb : params.n_batch)
|
||||
for (const auto & nub : params.n_ubatch)
|
||||
@ -624,6 +639,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
/* .flash_attn = */ fa,
|
||||
/* .tensor_split = */ ts,
|
||||
/* .use_mmap = */ mmp,
|
||||
/* .use_direct_io= */ dio,
|
||||
/* .embeddings = */ embd,
|
||||
};
|
||||
instances.push_back(instance);
|
||||
@ -649,6 +665,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
/* .flash_attn = */ fa,
|
||||
/* .tensor_split = */ ts,
|
||||
/* .use_mmap = */ mmp,
|
||||
/* .use_direct_io= */ dio,
|
||||
/* .embeddings = */ embd,
|
||||
};
|
||||
instances.push_back(instance);
|
||||
@ -674,6 +691,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
/* .flash_attn = */ fa,
|
||||
/* .tensor_split = */ ts,
|
||||
/* .use_mmap = */ mmp,
|
||||
/* .use_direct_io= */ dio,
|
||||
/* .embeddings = */ embd,
|
||||
};
|
||||
instances.push_back(instance);
|
||||
@ -712,6 +730,7 @@ struct test {
|
||||
bool flash_attn;
|
||||
std::vector<float> tensor_split;
|
||||
bool use_mmap;
|
||||
bool use_direct_io;
|
||||
bool embeddings;
|
||||
int n_prompt;
|
||||
int n_gen;
|
||||
@ -737,6 +756,7 @@ struct test {
|
||||
flash_attn = inst.flash_attn;
|
||||
tensor_split = inst.tensor_split;
|
||||
use_mmap = inst.use_mmap;
|
||||
use_direct_io = inst.use_direct_io;
|
||||
embeddings = inst.embeddings;
|
||||
n_prompt = inst.n_prompt;
|
||||
n_gen = inst.n_gen;
|
||||
@ -810,7 +830,7 @@ struct test {
|
||||
"n_threads", "type_k", "type_v",
|
||||
"n_gpu_layers", "split_mode",
|
||||
"main_gpu", "no_kv_offload", "flash_attn",
|
||||
"tensor_split", "use_mmap", "embeddings",
|
||||
"tensor_split", "use_mmap", "use_direct_io", "embeddings",
|
||||
"n_prompt", "n_gen", "test_time",
|
||||
"avg_ns", "stddev_ns",
|
||||
"avg_ts", "stddev_ts"
|
||||
@ -831,7 +851,7 @@ struct test {
|
||||
}
|
||||
if (field == "cuda" || field == "opencl" || field == "vulkan" || field == "kompute" || field == "metal" ||
|
||||
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
|
||||
field == "flash_attn" || field == "use_mmap" || field == "embeddings") {
|
||||
field == "flash_attn" || field == "use_mmap" || field == "use_direct_io" || field == "embeddings") {
|
||||
return BOOL;
|
||||
}
|
||||
if (field == "avg_ts" || field == "stddev_ts") {
|
||||
@ -866,7 +886,7 @@ struct test {
|
||||
std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
|
||||
std::to_string(n_gpu_layers), split_mode_str(split_mode),
|
||||
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
|
||||
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
|
||||
tensor_split_str, std::to_string(use_mmap), std::to_string(use_direct_io), std::to_string(embeddings),
|
||||
std::to_string(n_prompt), std::to_string(n_gen), test_time,
|
||||
std::to_string(avg_ns()), std::to_string(stdev_ns()),
|
||||
std::to_string(avg_ts()), std::to_string(stdev_ts())
|
||||
@ -1042,6 +1062,9 @@ struct markdown_printer : public printer {
|
||||
if (field == "use_mmap") {
|
||||
return "mmap";
|
||||
}
|
||||
if (field == "use_direct_io") {
|
||||
return "direct_io";
|
||||
}
|
||||
if (field == "embeddings") {
|
||||
return "embd";
|
||||
}
|
||||
@ -1094,6 +1117,9 @@ struct markdown_printer : public printer {
|
||||
if (params.use_mmap.size() > 1 || params.use_mmap != cmd_params_defaults.use_mmap) {
|
||||
fields.emplace_back("use_mmap");
|
||||
}
|
||||
if (params.use_direct_io.size() > 1 || params.use_direct_io != cmd_params_defaults.use_direct_io) {
|
||||
fields.emplace_back("use_direct_io");
|
||||
}
|
||||
if (params.embeddings.size() > 1 || params.embeddings != cmd_params_defaults.embeddings) {
|
||||
fields.emplace_back("embeddings");
|
||||
}
|
||||
|
@ -282,6 +282,10 @@ These options help improve the performance and memory usage of the LLaMA models.
|
||||
|
||||
- `--no-mmap`: Do not memory-map the model. By default, models are mapped into memory, which allows the system to load only the necessary parts of the model as needed. However, if the model is larger than your total amount of RAM or if your system is low on available memory, using mmap might increase the risk of pageouts, negatively impacting performance. Disabling mmap results in slower load times but may reduce pageouts if you're not using `--mlock`. Note that if the model is larger than the total amount of RAM, turning off mmap would prevent the model from loading at all.
|
||||
|
||||
### Direct I/O
|
||||
|
||||
- `--direct-io`: Use direct I/O. Potentially faster uncached loading, fewer pageouts, no page cache pollution. You may benefit from this option if you load a model for the first time (or after some time), load several different models consecutively, or simply want to keep the page cache clean. The faster your storage device is, the greater the gain you can expect. The effect may be greater on Linux due to Transparent HugePage support.
|
||||
|
||||
### NUMA support
|
||||
|
||||
- `--numa distribute`: Pin an equal proportion of the threads to the cores on each NUMA node. This will spread the load amongst all cores on the system, utilitizing all memory channels at the expense of potentially requiring memory to travel over the slow links between nodes.
|
||||
|
@ -34,6 +34,7 @@ The project is under active development, and we are [looking for feedback and co
|
||||
- `-ub N`, `--ubatch-size N`: Physical maximum batch size. Default: `512`
|
||||
- `--mlock`: Lock the model in memory, preventing it from being swapped out when memory-mapped.
|
||||
- `--no-mmap`: Do not memory-map the model. By default, models are mapped into memory, which allows the system to load only the necessary parts of the model as needed.
|
||||
- `--direct-io`: Use direct I/O. Potentially faster uncached loading, fewer pageouts, no page cache pollution.
|
||||
- `--numa STRATEGY`: Attempt one of the below optimization strategies that may help on some NUMA systems
|
||||
- `--numa distribute`: Spread execution evenly over all nodes
|
||||
- `--numa isolate`: Only spawn threads on CPUs on the node that execution started on
|
||||
|
@ -2352,6 +2352,9 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
|
||||
if (llama_supports_mmap()) {
|
||||
printf(" --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
|
||||
}
|
||||
if (llama_supports_direct_io()) {
|
||||
printf(" --direct-io use direct I/O (potentially faster uncached loading, fewer pageouts, no page cache pollution)\n");
|
||||
}
|
||||
printf(" --numa TYPE attempt optimizations that help on some NUMA systems\n");
|
||||
printf(" - distribute: spread execution evenly over all nodes\n");
|
||||
printf(" - isolate: only spawn threads on CPUs on the node that execution started on\n");
|
||||
@ -2754,6 +2757,8 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
|
||||
params.use_mlock = true;
|
||||
} else if (arg == "--no-mmap") {
|
||||
params.use_mmap = false;
|
||||
} else if (arg == "--direct-io") {
|
||||
params.use_direct_io = true;
|
||||
} else if (arg == "--numa") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
|
268
llama.cpp
268
llama.cpp
@ -1197,11 +1197,13 @@ struct no_init {
|
||||
};
|
||||
|
||||
struct llama_file {
|
||||
char * name;
|
||||
// use FILE * so we don't have to re-open the file to mmap
|
||||
FILE * fp;
|
||||
size_t size;
|
||||
|
||||
llama_file(const char * fname, const char * mode) {
|
||||
name = strdup(fname);
|
||||
fp = ggml_fopen(fname, mode);
|
||||
if (fp == NULL) {
|
||||
throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
|
||||
@ -1265,10 +1267,105 @@ struct llama_file {
|
||||
write_raw(&val, sizeof(val));
|
||||
}
|
||||
|
||||
size_t read_direct(void * ptr, size_t len, size_t offset) const {
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(__APPLE__)
|
||||
int page_size = sysconf(_SC_PAGESIZE);
|
||||
GGML_ASSERT((uintptr_t) ptr % page_size == 0);
|
||||
GGML_ASSERT(len % page_size == 0);
|
||||
GGML_ASSERT(offset % page_size == 0);
|
||||
#ifdef __APPLE__
|
||||
int fd = open(name, O_RDONLY);
|
||||
if (fd == -1) {
|
||||
throw std::runtime_error(format("failed to open %s: %s", name, strerror(errno)));
|
||||
}
|
||||
if (fcntl(fd, F_NOCACHE, 1) == -1) {
|
||||
throw std::runtime_error(format("failed to enable direct I/O: %s", strerror(errno)));
|
||||
}
|
||||
#else
|
||||
int fd = open(name, O_RDONLY | O_DIRECT);
|
||||
if (fd == -1) {
|
||||
throw std::runtime_error(format("failed to open %s for direct I/O: %s", name, strerror(errno)));
|
||||
}
|
||||
#endif
|
||||
size_t bytes_read = 0;
|
||||
while (len > 0) {
|
||||
ssize_t count = pread(fd, ptr, std::min(len, (size_t) INT_MAX & ~(page_size - 1)), offset);
|
||||
if (count == -1) {
|
||||
throw std::runtime_error(format("direct read error: %s", strerror(errno)));
|
||||
}
|
||||
if (count == 0) { // EOF
|
||||
break;
|
||||
}
|
||||
ptr = (char *) ptr + count;
|
||||
offset += count;
|
||||
len -= count;
|
||||
bytes_read += count;
|
||||
}
|
||||
|
||||
close(fd);
|
||||
|
||||
return bytes_read;
|
||||
}
|
||||
|
||||
static constexpr bool DIRECT_IO_SUPPORTED = true;
|
||||
#elif defined(_WIN32)
|
||||
SYSTEM_INFO siSysInfo;
|
||||
GetSystemInfo(&siSysInfo);
|
||||
DWORD dwPageSize = siSysInfo.dwPageSize;
|
||||
GGML_ASSERT((uintptr_t) ptr % dwPageSize == 0);
|
||||
GGML_ASSERT(len % dwPageSize == 0);
|
||||
GGML_ASSERT(offset % dwPageSize == 0);
|
||||
|
||||
HANDLE hFile = ReOpenFile((HANDLE) _get_osfhandle(_fileno(fp)), GENERIC_READ, FILE_SHARE_READ, FILE_FLAG_NO_BUFFERING);
|
||||
if (hFile == INVALID_HANDLE_VALUE) {
|
||||
throw std::runtime_error(format("failed to open %s: %s", name, llama_format_win_err(GetLastError()).c_str()));
|
||||
}
|
||||
|
||||
size_t bytes_read = 0;
|
||||
while (len > 0) {
|
||||
OVERLAPPED oOverlap = {0};
|
||||
oOverlap.OffsetHigh = offset >> 32;
|
||||
oOverlap.Offset = offset;
|
||||
DWORD nBytesToRead = std::min(len, (size_t) 0xFFFFFFFF & ~(dwPageSize - 1));
|
||||
DWORD count = 0;
|
||||
if (!ReadFile(hFile, ptr, nBytesToRead, &count, &oOverlap)) {
|
||||
if (GetLastError() == ERROR_HANDLE_EOF) {
|
||||
bytes_read += count;
|
||||
break;
|
||||
}
|
||||
throw std::runtime_error(format("direct read error: %s", llama_format_win_err(GetLastError()).c_str()));
|
||||
}
|
||||
bytes_read += count;
|
||||
if (count < nBytesToRead) { // EOF
|
||||
break;
|
||||
}
|
||||
ptr = (char *) ptr + count;
|
||||
offset += count;
|
||||
len -= count;
|
||||
}
|
||||
|
||||
CloseHandle(hFile);
|
||||
|
||||
return bytes_read;
|
||||
}
|
||||
|
||||
static constexpr bool DIRECT_IO_SUPPORTED = true;
|
||||
#else
|
||||
GGML_UNUSED(ptr);
|
||||
GGML_UNUSED(len);
|
||||
GGML_UNUSED(offset);
|
||||
|
||||
throw std::runtime_error("direct I/O not supported");
|
||||
}
|
||||
|
||||
static constexpr bool DIRECT_IO_SUPPORTED = false;
|
||||
#endif
|
||||
|
||||
~llama_file() {
|
||||
if (fp) {
|
||||
std::fclose(fp);
|
||||
}
|
||||
free(name);
|
||||
}
|
||||
};
|
||||
using llama_files = std::vector<std::unique_ptr<llama_file>>;
|
||||
@ -1279,6 +1376,23 @@ struct llama_mmap {
|
||||
|
||||
llama_mmap(const llama_mmap &) = delete;
|
||||
|
||||
static void align_to_next_page(size_t * ptr, size_t page_size) {
|
||||
size_t offset_in_page = *ptr & (page_size - 1);
|
||||
size_t offset_to_page = offset_in_page == 0 ? 0 : page_size - offset_in_page;
|
||||
*ptr += offset_to_page;
|
||||
}
|
||||
|
||||
static void align_to_previous_page(size_t * ptr, size_t page_size) {
|
||||
*ptr = *ptr & ~(page_size - 1);
|
||||
}
|
||||
|
||||
virtual void populate(size_t first, size_t last) const {
|
||||
GGML_UNUSED(first);
|
||||
GGML_UNUSED(last);
|
||||
|
||||
// either already populated or populated dynamically
|
||||
}
|
||||
|
||||
#ifdef _POSIX_MAPPED_FILES
|
||||
static constexpr bool SUPPORTED = true;
|
||||
|
||||
@ -1324,26 +1438,16 @@ struct llama_mmap {
|
||||
mapped_fragments.emplace_back(0, file->size);
|
||||
}
|
||||
|
||||
static void align_range(size_t * first, size_t * last, size_t page_size) {
|
||||
// align first to the next page
|
||||
size_t offset_in_page = *first & (page_size - 1);
|
||||
size_t offset_to_page = offset_in_page == 0 ? 0 : page_size - offset_in_page;
|
||||
*first += offset_to_page;
|
||||
|
||||
// align last to the previous page
|
||||
*last = *last & ~(page_size - 1);
|
||||
|
||||
if (*last <= *first) {
|
||||
*last = *first;
|
||||
}
|
||||
}
|
||||
|
||||
// partially unmap the file in the range [first, last)
|
||||
void unmap_fragment(size_t first, size_t last) {
|
||||
// note: this function must not be called multiple times with overlapping ranges
|
||||
// otherwise, there is a risk of invalidating addresses that have been repurposed for other mappings
|
||||
int page_size = sysconf(_SC_PAGESIZE);
|
||||
align_range(&first, &last, page_size);
|
||||
align_to_next_page(&first, page_size);
|
||||
align_to_previous_page(&last, page_size);
|
||||
if (last <= first) {
|
||||
last = first;
|
||||
}
|
||||
size_t len = last - first;
|
||||
|
||||
if (len == 0) {
|
||||
@ -1384,7 +1488,7 @@ struct llama_mmap {
|
||||
mapped_fragments = std::move(new_mapped_fragments);
|
||||
}
|
||||
|
||||
~llama_mmap() {
|
||||
virtual ~llama_mmap() {
|
||||
for (const auto & frag : mapped_fragments) {
|
||||
if (munmap((char *) addr + frag.first, frag.second - frag.first)) {
|
||||
LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
|
||||
@ -1447,7 +1551,7 @@ struct llama_mmap {
|
||||
GGML_UNUSED(last);
|
||||
}
|
||||
|
||||
~llama_mmap() {
|
||||
virtual ~llama_mmap() {
|
||||
if (!UnmapViewOfFile(addr)) {
|
||||
LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
@ -1470,8 +1574,99 @@ struct llama_mmap {
|
||||
|
||||
throw std::runtime_error("mmap not supported");
|
||||
}
|
||||
|
||||
virtual ~llama_mmap() = default;
|
||||
#endif
|
||||
|
||||
protected:
|
||||
llama_mmap() {}
|
||||
};
|
||||
|
||||
struct llama_anonymous_mmap : llama_mmap {
|
||||
llama_file * file;
|
||||
|
||||
llama_anonymous_mmap(const llama_anonymous_mmap &) = delete;
|
||||
|
||||
#ifdef _POSIX_MAPPED_FILES
|
||||
#ifndef MAP_ANONYMOUS
|
||||
#define MAP_ANONYMOUS MAP_ANON
|
||||
#endif
|
||||
llama_anonymous_mmap(struct llama_file * file) {
|
||||
this->file = file;
|
||||
size = file->size;
|
||||
addr = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_ANONYMOUS | MAP_PRIVATE, -1, 0);
|
||||
if (addr == MAP_FAILED) { // NOLINT
|
||||
throw std::runtime_error(format("mmap(.., MAP_ANONYMOUS) failed: %s", strerror(errno)));
|
||||
}
|
||||
#ifdef __linux__
|
||||
// THP is enabled by default for anonymous memory mappings on madvise
|
||||
if (madvise(addr, size, MADV_HUGEPAGE)) {
|
||||
LLAMA_LOG_WARN("warning: madvise(.., MADV_HUGEPAGE) failed: %s\n", strerror(errno));
|
||||
}
|
||||
#endif
|
||||
mapped_fragments.emplace_back(0, size);
|
||||
}
|
||||
|
||||
void populate(size_t first, size_t last) const {
|
||||
int page_size = sysconf(_SC_PAGESIZE);
|
||||
|
||||
align_to_previous_page(&first, page_size);
|
||||
align_to_next_page(&last, page_size);
|
||||
|
||||
size_t bytes_read = file->read_direct((char *) addr + first, last - first, first);
|
||||
if (bytes_read != std::min(last, file->size) - first) {
|
||||
throw std::runtime_error("unexpectedly reached end of file");
|
||||
}
|
||||
}
|
||||
#elif defined(_WIN32)
|
||||
llama_anonymous_mmap(struct llama_file * file) {
|
||||
this->file = file;
|
||||
size = file->size;
|
||||
|
||||
HANDLE hMapping = CreateFileMapping(INVALID_HANDLE_VALUE, NULL, PAGE_READWRITE, size >> 32, size, NULL);
|
||||
if (hMapping == NULL) {
|
||||
throw std::runtime_error(format("CreateFileMapping failed: %s", llama_format_win_err(GetLastError()).c_str()));
|
||||
}
|
||||
|
||||
addr = MapViewOfFile(hMapping, FILE_MAP_ALL_ACCESS, 0, 0, size);
|
||||
DWORD dwError = GetLastError();
|
||||
|
||||
CloseHandle(hMapping);
|
||||
|
||||
if (addr == NULL) {
|
||||
throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(dwError).c_str()));
|
||||
}
|
||||
}
|
||||
|
||||
void populate(size_t first, size_t last) const {
|
||||
SYSTEM_INFO siSysInfo;
|
||||
GetSystemInfo(&siSysInfo);
|
||||
DWORD dwPageSize = siSysInfo.dwPageSize;
|
||||
|
||||
align_to_previous_page(&first, dwPageSize);
|
||||
align_to_next_page(&last, dwPageSize);
|
||||
|
||||
size_t bytes_read = file->read_direct((char *) addr + first, last - first, first);
|
||||
if (bytes_read != std::min(last, file->size) - first) {
|
||||
throw std::runtime_error("unexpectedly reached end of file");
|
||||
}
|
||||
}
|
||||
#else
|
||||
llama_anonymous_mmap(struct llama_file * file) {
|
||||
GGML_UNUSED(file);
|
||||
|
||||
throw std::runtime_error("mmap not supported");
|
||||
}
|
||||
|
||||
void populate(size_t first, size_t last) const {
|
||||
GGML_UNUSED(first);
|
||||
GGML_UNUSED(last);
|
||||
|
||||
throw std::runtime_error("mmap not supported");
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
using llama_mmaps = std::vector<std::unique_ptr<llama_mmap>>;
|
||||
|
||||
// Represents some region of memory being locked using mlock or VirtualLock;
|
||||
@ -3020,6 +3215,7 @@ struct llama_model_loader {
|
||||
size_t n_bytes = 0;
|
||||
|
||||
bool use_mmap = false;
|
||||
bool use_direct_io = false;
|
||||
bool check_tensors;
|
||||
|
||||
llama_files files;
|
||||
@ -3054,7 +3250,7 @@ struct llama_model_loader {
|
||||
std::string arch_name;
|
||||
LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
|
||||
|
||||
llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) {
|
||||
llama_model_loader(const std::string & fname, bool use_mmap, bool use_direct_io, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) {
|
||||
int trace = 0;
|
||||
if (getenv("LLAMA_TRACE")) {
|
||||
trace = atoi(getenv("LLAMA_TRACE"));
|
||||
@ -3267,7 +3463,15 @@ struct llama_model_loader {
|
||||
use_mmap = false;
|
||||
}
|
||||
|
||||
this->use_mmap = use_mmap;
|
||||
if (!llama_file::DIRECT_IO_SUPPORTED && use_direct_io) {
|
||||
LLAMA_LOG_WARN("%s: direct I/O is not supported on this platform\n", __func__);
|
||||
use_direct_io = false;
|
||||
}
|
||||
|
||||
// either file or anonymous mappings
|
||||
this->use_mmap = use_mmap || use_direct_io;
|
||||
this->use_direct_io = use_direct_io;
|
||||
|
||||
this->check_tensors = check_tensors;
|
||||
}
|
||||
|
||||
@ -3463,12 +3667,12 @@ struct llama_model_loader {
|
||||
}
|
||||
}
|
||||
|
||||
void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr) {
|
||||
void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr, bool anonymous = false) {
|
||||
if (use_mmap) {
|
||||
mappings.reserve(files.size());
|
||||
mmaps_used.reserve(files.size());
|
||||
for (const auto & file : files) {
|
||||
std::unique_ptr<llama_mmap> mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, ggml_is_numa()));
|
||||
std::unique_ptr<llama_mmap> mapping(anonymous ? new llama_anonymous_mmap(file.get()) : new llama_mmap(file.get(), prefetch ? -1 : 0, ggml_is_numa()));
|
||||
mmaps_used.emplace_back(mapping->size, 0);
|
||||
if (mlock_mmaps) {
|
||||
std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());
|
||||
@ -3546,6 +3750,15 @@ struct llama_model_loader {
|
||||
void * progress_callback_user_data) {
|
||||
GGML_ASSERT(size_data != 0 && "call init_mappings() first");
|
||||
|
||||
if (use_mmap) {
|
||||
for (uint32_t idx = 0; idx < files.size(); idx++) {
|
||||
void * addr = nullptr;
|
||||
size_t first, last;
|
||||
get_mapping_range(&first, &last, &addr, idx, ctx);
|
||||
mappings.at(idx)->populate(first, last);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<no_init<uint8_t>> read_buf;
|
||||
std::vector<std::future<std::pair<ggml_tensor *, bool>>> validation_result;
|
||||
|
||||
@ -6005,7 +6218,7 @@ static bool llm_load_tensors(
|
||||
|
||||
ml.done_getting_tensors();
|
||||
|
||||
ml.init_mappings(true, use_mlock ? &model.mlock_mmaps : nullptr);
|
||||
ml.init_mappings(true, use_mlock ? &model.mlock_mmaps : nullptr, /* anonymous */ ml.use_direct_io);
|
||||
model.mappings.reserve(ml.mappings.size());
|
||||
|
||||
// create the backend buffers
|
||||
@ -6148,7 +6361,7 @@ static bool llm_load_tensors(
|
||||
// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
|
||||
static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
|
||||
try {
|
||||
llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides);
|
||||
llama_model_loader ml(fname, params.use_mmap, params.use_direct_io, params.check_tensors, params.kv_overrides);
|
||||
|
||||
model.hparams.vocab_only = params.vocab_only;
|
||||
|
||||
@ -14536,7 +14749,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||
auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
|
||||
kv_overrides = v->data();
|
||||
}
|
||||
llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides);
|
||||
llama_model_loader ml(fname_inp, use_mmap, /* use_direct_io */ false, /*check_tensors*/ true, kv_overrides);
|
||||
ml.init_mappings(false); // no prefetching
|
||||
|
||||
llama_model model;
|
||||
@ -14899,7 +15112,7 @@ static int llama_apply_lora_from_file_internal(
|
||||
std::unique_ptr<llama_model_loader> ml;
|
||||
if (path_base_model) {
|
||||
LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
|
||||
ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*check_tensors*/ false, /*kv_overrides*/ nullptr));
|
||||
ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*use_direct_io*/ false, /*check_tensors*/ false, /*kv_overrides*/ nullptr));
|
||||
ml->init_mappings(/*prefetch*/ false); // no prefetching
|
||||
}
|
||||
|
||||
@ -15158,6 +15371,7 @@ struct llama_model_params llama_model_default_params() {
|
||||
/*.kv_overrides =*/ nullptr,
|
||||
/*.vocab_only =*/ false,
|
||||
/*.use_mmap =*/ true,
|
||||
/*.use_direct_io =*/ false,
|
||||
/*.use_mlock =*/ false,
|
||||
/*.check_tensors =*/ false,
|
||||
};
|
||||
@ -15246,6 +15460,10 @@ bool llama_supports_mlock(void) {
|
||||
return llama_mlock::SUPPORTED;
|
||||
}
|
||||
|
||||
bool llama_supports_direct_io(void) {
|
||||
return llama_file::DIRECT_IO_SUPPORTED;
|
||||
}
|
||||
|
||||
bool llama_supports_gpu_offload(void) {
|
||||
#if defined(GGML_USE_CUDA) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || \
|
||||
defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE) || defined(GGML_USE_RPC)
|
||||
|
2
llama.h
2
llama.h
@ -260,6 +260,7 @@ extern "C" {
|
||||
// Keep the booleans together to avoid misalignment during copy-by-value.
|
||||
bool vocab_only; // only load the vocabulary, no weights
|
||||
bool use_mmap; // use mmap if possible
|
||||
bool use_direct_io; // use direct I/O if possible
|
||||
bool use_mlock; // force system to keep model in RAM
|
||||
bool check_tensors; // validate model tensor data
|
||||
};
|
||||
@ -409,6 +410,7 @@ extern "C" {
|
||||
LLAMA_API size_t llama_max_devices(void);
|
||||
|
||||
LLAMA_API bool llama_supports_mmap (void);
|
||||
LLAMA_API bool llama_supports_direct_io (void);
|
||||
LLAMA_API bool llama_supports_mlock (void);
|
||||
LLAMA_API bool llama_supports_gpu_offload(void);
|
||||
|
||||
|
@ -12,7 +12,7 @@ logger = logging.getLogger("run-with-preset")
|
||||
|
||||
CLI_ARGS_MAIN_PERPLEXITY = [
|
||||
"batch-size", "cfg-negative-prompt", "cfg-scale", "chunks", "color", "ctx-size", "escape",
|
||||
"export", "file", "frequency-penalty", "grammar", "grammar-file", "hellaswag",
|
||||
"direct-io", "export", "file", "frequency-penalty", "grammar", "grammar-file", "hellaswag",
|
||||
"hellaswag-tasks", "ignore-eos", "in-prefix", "in-prefix-bos", "in-suffix", "instruct",
|
||||
"interactive", "interactive-first", "keep", "logdir", "logit-bias", "lora", "lora-base",
|
||||
"low-vram", "main-gpu", "memory-f32", "mirostat", "mirostat-ent", "mirostat-lr", "mlock",
|
||||
@ -30,7 +30,7 @@ CLI_ARGS_LLAMA_BENCH = [
|
||||
]
|
||||
|
||||
CLI_ARGS_SERVER = [
|
||||
"alias", "batch-size", "ctx-size", "embedding", "host", "memory-f32", "lora", "lora-base",
|
||||
"alias", "batch-size", "ctx-size", "direct-io", "embedding", "host", "memory-f32", "lora", "lora-base",
|
||||
"low-vram", "main-gpu", "mlock", "model", "n-gpu-layers", "n-probs", "no-mmap", "no-mul-mat-q",
|
||||
"numa", "path", "port", "rope-freq-base", "timeout", "rope-freq-scale", "tensor-split",
|
||||
"threads", "verbose"
|
||||
|
Loading…
x
Reference in New Issue
Block a user