llama : pass KV cache type through API

This commit is contained in:
Georgi Gerganov 2023-12-05 15:40:23 +02:00
parent b881f630ca
commit 3ce30e07c9
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
4 changed files with 59 additions and 12 deletions

View File

@ -500,6 +500,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
params.dump_kv_cache = true; params.dump_kv_cache = true;
} else if (arg == "-nkvo" || arg == "--no-kv-offload") { } else if (arg == "-nkvo" || arg == "--no-kv-offload") {
params.no_kv_offload = true; params.no_kv_offload = true;
} else if (arg == "-ctk" || arg == "--cache-type-k") {
params.cache_type_k = argv[++i];
} else if (arg == "-ctv" || arg == "--cache-type-v") {
params.cache_type_v = argv[++i];
} else if (arg == "--multiline-input") { } else if (arg == "--multiline-input") {
params.multiline_input = true; params.multiline_input = true;
} else if (arg == "--simple-io") { } else if (arg == "--simple-io") {
@ -844,6 +848,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" verbose print of the KV cache\n"); printf(" verbose print of the KV cache\n");
printf(" -nkvo, --no-kv-offload\n"); printf(" -nkvo, --no-kv-offload\n");
printf(" disable KV offload\n"); printf(" disable KV offload\n");
printf(" -ctk TYPE, --cache-type-k TYPE\n");
printf(" KV cache data type for K (default: %s)\n", params.cache_type_k.c_str());
printf(" -ctv TYPE, --cache-type-v TYPE\n");
printf(" KV cache data type for V (default: %s)\n", params.cache_type_v.c_str());
printf(" --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n"); printf(" --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n");
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n"); printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
@ -908,6 +916,29 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
return mparams; return mparams;
} }
static ggml_type kv_cache_type_from_str(const std::string & s) {
if (s == "f16") {
return GGML_TYPE_F16;
}
if (s == "q8_0") {
return GGML_TYPE_Q8_0;
}
if (s == "q4_0") {
return GGML_TYPE_Q4_0;
}
if (s == "q4_1") {
return GGML_TYPE_Q4_1;
}
if (s == "q5_0") {
return GGML_TYPE_Q5_0;
}
if (s == "q5_1") {
return GGML_TYPE_Q5_1;
}
throw std::runtime_error("Invalid cache type: " + s);
}
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
auto cparams = llama_context_default_params(); auto cparams = llama_context_default_params();
@ -930,6 +961,9 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.yarn_orig_ctx = params.yarn_orig_ctx; cparams.yarn_orig_ctx = params.yarn_orig_ctx;
cparams.offload_kqv = !params.no_kv_offload; cparams.offload_kqv = !params.no_kv_offload;
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
return cparams; return cparams;
} }

View File

@ -125,9 +125,12 @@ struct gpt_params {
bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
bool no_kv_offload = false; // disable KV offloading bool no_kv_offload = false; // disable KV offloading
std::string cache_type_k = "f16"; // KV cache data type for the K
std::string cache_type_v = "f16"; // KV cache data type for the V
// multimodal models (see examples/llava) // multimodal models (see examples/llava)
std::string mmproj = ""; // path to multimodal projector std::string mmproj = ""; // path to multimodal projector
std::string image = ""; // path to an image file std::string image = ""; // path to an image file
}; };
bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params); bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params);

View File

@ -8580,6 +8580,8 @@ struct llama_context_params llama_context_default_params() {
/*.yarn_beta_fast =*/ 32.0f, /*.yarn_beta_fast =*/ 32.0f,
/*.yarn_beta_slow =*/ 1.0f, /*.yarn_beta_slow =*/ 1.0f,
/*.yarn_orig_ctx =*/ 0, /*.yarn_orig_ctx =*/ 0,
/*.type_k =*/ GGML_TYPE_F16,
/*.type_v =*/ GGML_TYPE_F16,
/*.mul_mat_q =*/ true, /*.mul_mat_q =*/ true,
/*.f16_kv =*/ true, /*.f16_kv =*/ true,
/*.logits_all =*/ false, /*.logits_all =*/ false,
@ -8737,31 +8739,36 @@ struct llama_context * llama_new_context_with_model(
//const ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; //const ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
// TODO: move as params const ggml_type type_k = params.type_k;
const ggml_type k_type = GGML_TYPE_Q8_0; const ggml_type type_v = params.type_v;
const ggml_type v_type = GGML_TYPE_F16;
GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(k_type) == 0); GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(type_k) == 0);
GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(v_type) == 0); GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(type_v) == 0);
// reserve memory for context buffers // reserve memory for context buffers
if (!hparams.vocab_only) { if (!hparams.vocab_only) {
if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, k_type, v_type, cparams.n_ctx, model->n_gpu_layers, cparams.offload_kqv)) { if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, type_k, type_v, cparams.n_ctx, model->n_gpu_layers, cparams.offload_kqv)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx); llama_free(ctx);
return nullptr; return nullptr;
} }
{ {
// const size_t memory_size = ggml_nbytes(ctx->kv_self.k) + ggml_nbytes(ctx->kv_self.v); size_t memory_size_k = 0;
size_t memory_size = 0; size_t memory_size_v = 0;
for (auto & k : ctx->kv_self.k_l) { for (auto & k : ctx->kv_self.k_l) {
memory_size += ggml_nbytes(k); memory_size_k += ggml_nbytes(k);
} }
for (auto & v : ctx->kv_self.v_l) { for (auto & v : ctx->kv_self.v_l) {
memory_size += ggml_nbytes(v); memory_size_v += ggml_nbytes(v);
} }
LLAMA_LOG_INFO("%s: kv self size = %7.2f MiB\n", __func__, memory_size / 1024.0 / 1024.0);
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
} }
// resized during inference // resized during inference

View File

@ -191,6 +191,9 @@ extern "C" {
float yarn_beta_slow; // YaRN high correction dim float yarn_beta_slow; // YaRN high correction dim
uint32_t yarn_orig_ctx; // YaRN original context size uint32_t yarn_orig_ctx; // YaRN original context size
ggml_type type_k; // data type for K cache
ggml_type type_v; // data type for V cache
// Keep the booleans together to avoid misalignment during copy-by-value. // Keep the booleans together to avoid misalignment during copy-by-value.
bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true) bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
bool f16_kv; // use fp16 for KV cache, fp32 otherwise bool f16_kv; // use fp16 for KV cache, fp32 otherwise