mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 22:08:46 +01:00
Support calling mlock() on loaded model data on Linux and macOS (#453)
* Support calling mlock() on loaded model data on Linux and macOS This is enabled by a new --mlock command line option. Using mlock() disables swapping and memory compression for the model data. Doing so can be useful on systems where the model takes up a large fraction of system RAM. In my experience, macOS is quite eager to start compressing llama.cpp's memory, which then makes it halt for a few seconds while it decompresses, even with a model that uses "only" 25GB out of 32GB. Of course, this comes at the cost of forcing the system to swap or compress other processes' memory instead, so it needs to be used with care and shouldn't be enabled by default. In theory it should be possible to support this on Windows as well using VirtualLock(), but I'm not much of a Windows user. * Update llama.cpp --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
8d4a855c24
commit
563cdc391d
76
ggml.c
76
ggml.c
@ -1,5 +1,5 @@
|
|||||||
// Defines CLOCK_MONOTONIC on Linux
|
// Defines CLOCK_MONOTONIC and asprintf on Linux
|
||||||
#define _POSIX_C_SOURCE 199309L
|
#define _GNU_SOURCE
|
||||||
|
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
|
||||||
@ -10,6 +10,7 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
#include <errno.h>
|
||||||
#include <time.h>
|
#include <time.h>
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
@ -31,7 +32,6 @@
|
|||||||
#else
|
#else
|
||||||
// ref: https://github.com/ggerganov/whisper.cpp/issues/168
|
// ref: https://github.com/ggerganov/whisper.cpp/issues/168
|
||||||
#include <windows.h>
|
#include <windows.h>
|
||||||
#include <errno.h>
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
typedef volatile LONG atomic_int;
|
typedef volatile LONG atomic_int;
|
||||||
@ -83,6 +83,17 @@ typedef void* thread_ret_t;
|
|||||||
#define static_assert(cond, msg) _Static_assert(cond, msg)
|
#define static_assert(cond, msg) _Static_assert(cond, msg)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#define GGML_MLOCK_SUPPORT 0
|
||||||
|
|
||||||
|
#ifdef __has_include
|
||||||
|
#if __has_include(<sys/mman.h>)
|
||||||
|
#undef GGML_MLOCK_SUPPORT
|
||||||
|
#define GGML_MLOCK_SUPPORT 1
|
||||||
|
#include <sys/mman.h>
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
/*#define GGML_PERF*/
|
/*#define GGML_PERF*/
|
||||||
#define GGML_DEBUG 0
|
#define GGML_DEBUG 0
|
||||||
#define GGML_GELU_FP16
|
#define GGML_GELU_FP16
|
||||||
@ -2344,6 +2355,7 @@ struct ggml_context {
|
|||||||
size_t mem_size;
|
size_t mem_size;
|
||||||
void * mem_buffer;
|
void * mem_buffer;
|
||||||
bool mem_buffer_owned;
|
bool mem_buffer_owned;
|
||||||
|
bool mem_buffer_mlocked;
|
||||||
|
|
||||||
int n_objects;
|
int n_objects;
|
||||||
|
|
||||||
@ -2619,16 +2631,19 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
*ctx = (struct ggml_context) {
|
*ctx = (struct ggml_context) {
|
||||||
/*.mem_size =*/ params.mem_size,
|
/*.mem_size =*/ params.mem_size,
|
||||||
/*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
|
/*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
|
||||||
/*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
|
/*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
|
||||||
/*.n_objects =*/ 0,
|
/*.mem_buffer_mlocked =*/ false,
|
||||||
/*.objects_begin =*/ NULL,
|
/*.n_objects =*/ 0,
|
||||||
/*.objects_end =*/ NULL,
|
/*.objects_begin =*/ NULL,
|
||||||
/*.scratch =*/ { 0, 0, NULL, },
|
/*.objects_end =*/ NULL,
|
||||||
/*.scratch_save =*/ { 0, 0, NULL, },
|
/*.scratch =*/ { 0, 0, NULL, },
|
||||||
|
/*.scratch_save =*/ { 0, 0, NULL, },
|
||||||
};
|
};
|
||||||
|
|
||||||
|
GGML_ASSERT(ctx->mem_buffer != NULL); // check for allocation failure
|
||||||
|
|
||||||
ggml_assert_aligned(ctx->mem_buffer);
|
ggml_assert_aligned(ctx->mem_buffer);
|
||||||
|
|
||||||
GGML_PRINT_DEBUG("%s: context initialized\n", __func__);
|
GGML_PRINT_DEBUG("%s: context initialized\n", __func__);
|
||||||
@ -2651,6 +2666,14 @@ void ggml_free(struct ggml_context * ctx) {
|
|||||||
GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
|
GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
|
||||||
__func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
|
__func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
|
||||||
|
|
||||||
|
#if GGML_MLOCK_SUPPORT
|
||||||
|
if (ctx->mem_buffer_mlocked) {
|
||||||
|
if (munlock(ctx->mem_buffer, ctx->mem_size)) {
|
||||||
|
fprintf(stderr, "%s: failed to munlock buffer: %s\n", __func__, strerror(errno));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
if (ctx->mem_buffer_owned) {
|
if (ctx->mem_buffer_owned) {
|
||||||
free(ctx->mem_buffer);
|
free(ctx->mem_buffer);
|
||||||
}
|
}
|
||||||
@ -2679,6 +2702,37 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch)
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ggml_mlock_supported(void) {
|
||||||
|
return GGML_MLOCK_SUPPORT;
|
||||||
|
}
|
||||||
|
|
||||||
|
#if GGML_MLOCK_SUPPORT
|
||||||
|
#ifdef __APPLE__
|
||||||
|
#define MLOCK_SUGGESTION "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or\n" \
|
||||||
|
"decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l)."
|
||||||
|
#else
|
||||||
|
#define MLOCK_SUGGESTION "Try increasing RLIMIT_MLOCK (ulimit -l)."
|
||||||
|
#endif
|
||||||
|
bool ggml_mlock(struct ggml_context * ctx, char ** err_p) {
|
||||||
|
if (ctx->mem_buffer_mlocked) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (mlock(ctx->mem_buffer, ctx->mem_size)) {
|
||||||
|
int ret = asprintf(err_p, "failed to mlock %zu-byte buffer: %s\n" MLOCK_SUGGESTION,
|
||||||
|
ctx->mem_size, strerror(errno));
|
||||||
|
GGML_ASSERT(ret >= 0);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
ctx->mem_buffer_mlocked = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
#else // GGML_MLOCK_SUPPORT
|
||||||
|
bool ggml_mlock(struct ggml_context * ctx, char ** err_p) {
|
||||||
|
*err_p = strdup("can't mlock because it's not supported on this system");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
#endif // GGML_MLOCK_SUPPORT
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
struct ggml_tensor * ggml_new_tensor_impl(
|
struct ggml_tensor * ggml_new_tensor_impl(
|
||||||
|
3
ggml.h
3
ggml.h
@ -343,6 +343,9 @@ size_t ggml_used_mem(const struct ggml_context * ctx);
|
|||||||
|
|
||||||
size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
|
size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
|
||||||
|
|
||||||
|
bool ggml_mlock_supported(void);
|
||||||
|
bool ggml_mlock(struct ggml_context * ctx, char ** err_p);
|
||||||
|
|
||||||
struct ggml_tensor * ggml_new_tensor(
|
struct ggml_tensor * ggml_new_tensor(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
enum ggml_type type,
|
enum ggml_type type,
|
||||||
|
14
llama.cpp
14
llama.cpp
@ -115,6 +115,7 @@ struct llama_context_params llama_context_default_params() {
|
|||||||
/*.f16_kv =*/ false,
|
/*.f16_kv =*/ false,
|
||||||
/*.logits_all =*/ false,
|
/*.logits_all =*/ false,
|
||||||
/*.vocab_only =*/ false,
|
/*.vocab_only =*/ false,
|
||||||
|
/*.use_mlock =*/ false,
|
||||||
/*.embedding =*/ false,
|
/*.embedding =*/ false,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1428,12 +1429,23 @@ struct llama_context * llama_init_from_file(
|
|||||||
|
|
||||||
ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||||
|
|
||||||
if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only)) {
|
if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory,
|
||||||
|
params.vocab_only)) {
|
||||||
fprintf(stderr, "%s: failed to load model\n", __func__);
|
fprintf(stderr, "%s: failed to load model\n", __func__);
|
||||||
delete ctx;
|
delete ctx;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (params.use_mlock) {
|
||||||
|
char *err;
|
||||||
|
if (!ggml_mlock(ctx->model.ctx, &err)) {
|
||||||
|
fprintf(stderr, "%s\n", err);
|
||||||
|
free(err);
|
||||||
|
delete ctx;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// reserve memory for context buffers
|
// reserve memory for context buffers
|
||||||
{
|
{
|
||||||
const auto & hparams = ctx->model.hparams;
|
const auto & hparams = ctx->model.hparams;
|
||||||
|
1
llama.h
1
llama.h
@ -53,6 +53,7 @@ extern "C" {
|
|||||||
bool f16_kv; // use fp16 for KV cache
|
bool f16_kv; // use fp16 for KV cache
|
||||||
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
||||||
bool vocab_only; // only load the vocabulary, no weights
|
bool vocab_only; // only load the vocabulary, no weights
|
||||||
|
bool use_mlock; // force system to keep model in RAM
|
||||||
bool embedding; // embedding mode only
|
bool embedding; // embedding mode only
|
||||||
};
|
};
|
||||||
|
|
||||||
|
1
main.cpp
1
main.cpp
@ -199,6 +199,7 @@ int main(int argc, char ** argv) {
|
|||||||
lparams.seed = params.seed;
|
lparams.seed = params.seed;
|
||||||
lparams.f16_kv = params.memory_f16;
|
lparams.f16_kv = params.memory_f16;
|
||||||
lparams.logits_all = params.perplexity;
|
lparams.logits_all = params.perplexity;
|
||||||
|
lparams.use_mlock = params.use_mlock;
|
||||||
lparams.embedding = params.embedding;
|
lparams.embedding = params.embedding;
|
||||||
|
|
||||||
ctx = llama_init_from_file(params.model.c_str(), lparams);
|
ctx = llama_init_from_file(params.model.c_str(), lparams);
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
#include "ggml.h"
|
||||||
|
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
@ -127,6 +129,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||||||
params.instruct = true;
|
params.instruct = true;
|
||||||
} else if (arg == "--color") {
|
} else if (arg == "--color") {
|
||||||
params.use_color = true;
|
params.use_color = true;
|
||||||
|
} else if (arg == "--mlock") {
|
||||||
|
params.use_mlock = true;
|
||||||
} else if (arg == "-r" || arg == "--reverse-prompt") {
|
} else if (arg == "-r" || arg == "--reverse-prompt") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
@ -194,6 +198,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||||||
fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n");
|
fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n");
|
||||||
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
||||||
fprintf(stderr, " --perplexity compute perplexity over the prompt\n");
|
fprintf(stderr, " --perplexity compute perplexity over the prompt\n");
|
||||||
|
if (ggml_mlock_supported()) {
|
||||||
|
fprintf(stderr, " --mlock force system to keep model in RAM rather than swapping or compressing\n");
|
||||||
|
}
|
||||||
fprintf(stderr, " -m FNAME, --model FNAME\n");
|
fprintf(stderr, " -m FNAME, --model FNAME\n");
|
||||||
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
|
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
|
1
utils.h
1
utils.h
@ -46,6 +46,7 @@ struct gpt_params {
|
|||||||
bool instruct = false; // instruction mode (used for Alpaca models)
|
bool instruct = false; // instruction mode (used for Alpaca models)
|
||||||
bool ignore_eos = false; // do not stop generating after eos
|
bool ignore_eos = false; // do not stop generating after eos
|
||||||
bool perplexity = false; // compute perplexity over the prompt
|
bool perplexity = false; // compute perplexity over the prompt
|
||||||
|
bool use_mlock = false; // use mlock to keep model in memory
|
||||||
};
|
};
|
||||||
|
|
||||||
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
|
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
|
||||||
|
Loading…
Reference in New Issue
Block a user