From 6b90566052611e1134c2debe987869666e59f427 Mon Sep 17 00:00:00 2001 From: Theia Vogel Date: Sat, 9 Mar 2024 20:22:37 -0800 Subject: [PATCH] control vector api and implementation --- common/common.cpp | 217 ++++++++++++++++++++++++++++++++++++++++++++++ common/common.h | 12 +++ llama.cpp | 121 ++++++++++++++++++++++++++ llama.h | 14 +++ 4 files changed, 364 insertions(+) diff --git a/common/common.cpp b/common/common.cpp index 2f38ac632..6a4ec30dd 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -562,6 +562,35 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.lora_base = argv[i]; + } else if (arg == "--control-vector") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.control_vectors.push_back(std::make_tuple(argv[i], 1.0f)); + } else if (arg == "--control-vector-scaled") { + if (++i >= argc) { + invalid_param = true; + break; + } + const char * control_vector = argv[i]; + if (++i >= argc) { + invalid_param = true; + break; + } + params.control_vectors.push_back(std::make_tuple(control_vector, std::stof(argv[i]))); + } else if (arg == "--control-vector-layer-range") { + if (++i >= argc) { + invalid_param = true; + break; + } + int32_t start = std::stoi(argv[i]); + if (++i >= argc) { + invalid_param = true; + break; + } + int32_t end = std::stoi(argv[i]); + params.control_vector_layer_range = std::make_tuple(start, end); } else if (arg == "--mmproj") { if (++i >= argc) { invalid_param = true; @@ -1087,6 +1116,12 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n"); printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n"); + printf(" --control-vector FNAME\n"); + printf(" add a control vector\n"); + printf(" --control-vector-scaled FNAME S\n"); + printf(" add a control vector with user defined scaling S\n"); + printf(" --control-vector-layer-range START END\n"); + printf(" layer range to apply the control vector(s) to, start and end inclusive\n"); printf(" -m FNAME, --model FNAME\n"); printf(" model path (default: %s)\n", params.model.c_str()); printf(" -md FNAME, --model-draft FNAME\n"); @@ -1351,6 +1386,35 @@ std::tuple llama_init_from_gpt_par return std::make_tuple(nullptr, nullptr); } + if (!params.control_vectors.empty()) { + int32_t layer_start, layer_end; + std::tie(layer_start, layer_end) = params.control_vector_layer_range; + + if (layer_start == 0) layer_start = 1; + if (layer_end == 0) layer_end = 31; + + std::vector control_vector; + int n_embd; + std::tie(control_vector, n_embd) = llama_control_vector_load(params.control_vectors); + if (n_embd == -1) { + llama_free(lctx); + llama_free_model(model); + return std::make_tuple(nullptr, nullptr); + } + + int err = llama_control_vector_apply(lctx, + control_vector.data(), + control_vector.size(), + n_embd, + layer_start, + layer_end); + if (err) { + llama_free(lctx); + llama_free_model(model); + return std::make_tuple(nullptr, nullptr); + } + } + for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) { const std::string& lora_adapter = std::get<0>(params.lora_adapter[i]); float lora_scale = std::get<1>(params.lora_adapter[i]); @@ -1867,3 +1931,156 @@ void llama_embd_normalize(const float * inp, float * out, int n) { } } +// +// Control vector utils +// + +static std::tuple, int> llama_control_vector_load_one(const std::string & path, float strength) { + int n_tensors; + size_t n_bytes = 0; + uint32_t max_direction_layer = 0; + int n_embd = -1; + + // calculate size of ctx needed for tensors, ensure tensors are f32, and find max layer + { + struct ggml_init_params meta_params = { + /* .mem_size = */ ggml_tensor_overhead() * 128 + ggml_graph_overhead(), + /* .mem_buffer = */ nullptr, + /* .no_alloc = */ true, + }; + ggml_context * meta_ctx = ggml_init(meta_params); + struct gguf_init_params meta_gguf_params = { + /* .no_alloc = */ true, + /* .ctx = */ &meta_ctx, + }; + struct gguf_context * meta_ctx_gguf = gguf_init_from_file(path.c_str(), meta_gguf_params); + if (!meta_ctx_gguf) { + fprintf(stderr, "%s: failed to load control vector from %s\n", __func__, path.c_str()); + ggml_free(meta_ctx); + return std::make_tuple(std::vector(), -1); + } + + n_tensors = gguf_get_n_tensors(meta_ctx_gguf); + for (int i = 0; i < n_tensors; i++) { + std::string name = gguf_get_tensor_name(meta_ctx_gguf, i); + + // split on '.' + size_t dotpos = name.find('.'); + if (dotpos != std::string::npos && name.substr(0, dotpos) == "direction") { + try { + uint32_t layer = std::stoi(name.substr(dotpos + 1)); + if (layer == 0) { + fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, path.c_str()); + ggml_free(meta_ctx); + gguf_free(meta_ctx_gguf); + return std::make_tuple(std::vector(), -1); + } + if (layer > max_direction_layer) { + max_direction_layer = layer; + } + } catch (...) { + fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, path.c_str()); + ggml_free(meta_ctx); + gguf_free(meta_ctx_gguf); + return std::make_tuple(std::vector(), -1); + } + } + + struct ggml_tensor * tensor_meta = ggml_get_tensor(meta_ctx, name.c_str()); + if (tensor_meta->type != GGML_TYPE_F32 || ggml_n_dims(tensor_meta) != 1) { + fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, path.c_str()); + ggml_free(meta_ctx); + gguf_free(meta_ctx_gguf); + return std::make_tuple(std::vector(), -1); + } + if (n_embd == -1) { + n_embd = ggml_nelements(tensor_meta); + } else if (ggml_nelements(tensor_meta) != n_embd) { + fprintf(stderr, "%s: direction tensor sizes mismatched in %s\n", __func__, path.c_str()); + ggml_free(meta_ctx); + gguf_free(meta_ctx_gguf); + return std::make_tuple(std::vector(), -1); + } + n_bytes += ggml_nbytes(tensor_meta); + } + ggml_free(meta_ctx); + gguf_free(meta_ctx_gguf); + } + + if (n_tensors == 0) { + fprintf(stderr, "%s: no direction tensors found in %s\n", __func__, path.c_str()); + return std::make_tuple(std::vector(), -1); + } + + // load and scale tensors into final control vector context + struct ggml_init_params ggml_params = { + /* .mem_size = */ ggml_tensor_overhead() * n_tensors + n_bytes, + /* .mem_buffer = */ nullptr, + /* .no_alloc = */ false, + }; + struct ggml_context * ctx = ggml_init(ggml_params); + + struct gguf_init_params params = { + /*.no_alloc = */ false, + /*.ctx = */ &ctx, + }; + struct gguf_context * ctx_gguf = gguf_init_from_file(path.c_str(), params); + if (!ctx_gguf) { + fprintf(stderr, "%s: failed to load control vector from %s\n", __func__, path.c_str()); + ggml_free(ctx); + return std::make_tuple(std::vector(), -1); + } + + std::vector vector; + for (uint32_t i = 1; i < max_direction_layer; i++) { + std::string name = "direction." + std::to_string(i); + ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str()); + if (tensor) { + const float * data = (const float *) tensor->data; + for (int i = 0; i < n_embd; i++) { + vector.push_back(data[i] * strength); + } + } else { + vector.insert(vector.end(), n_embd, 0.); // as a filler + } + } + + return std::make_tuple(vector, n_embd); +} + +std::tuple, int> llama_control_vector_load(const std::vector> & vectors) { + std::vector vector; + int n_embd = -1; + + for (const auto& pair : vectors) { + std::string path; + float strength; + std::tie(path, strength) = pair; + + std::vector v; + int v_n_embd; + std::tie(v, v_n_embd) = llama_control_vector_load_one(path, strength); + + if (v_n_embd == -1) { + return std::make_tuple(std::vector(), -1); + } + if (n_embd != -1 && (n_embd != v_n_embd || v.size() != vector.size())) { + fprintf(stderr, "%s: control vector in %s does not match previous vector dimensions\n", __func__, path.c_str()); + return std::make_tuple(std::vector(), -1); + } + + if (n_embd == -1) { + vector = std::move(v); + n_embd = v_n_embd; + } else { + for (size_t i = 0; i < vector.size(); i++) { + vector[i] += v[i]; + } + } + } + + if (n_embd == -1) { + fprintf(stderr, "%s: no vectors passed\n", __func__); + } + return std::make_tuple(vector, n_embd); +} diff --git a/common/common.h b/common/common.h index f8d82b871..2ea867553 100644 --- a/common/common.h +++ b/common/common.h @@ -102,6 +102,9 @@ struct gpt_params { std::vector> lora_adapter; // lora adapter path with user defined scale std::string lora_base = ""; // base model path for the lora adapter + std::vector> control_vectors; // control vector with user defined scale + std::tuple control_vector_layer_range; // layer range for control vector + int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line // (which is more convenient to use for plotting) @@ -267,3 +270,12 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40 void llama_embd_normalize(const float * inp, float * out, int n); +// +// Control vector utils +// + +// Load control vectors from a tuple of {path, strength}, scale each by strength, and add them together. +// Returns a tuple of {concatenated vector data (n_emnd x n_layer), n_embd} +// On error, returns a tuple of {empty, -1} +std::tuple, int> llama_control_vector_load( + const std::vector> & vectors); diff --git a/llama.cpp b/llama.cpp index ad7b7b7d4..91e524518 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1885,6 +1885,31 @@ struct llama_kv_cache { } }; +struct llama_control_vector { + std::vector tensors; // per layer + std::vector ctxs; + std::vector bufs; + + int32_t layer_start = 0; + int32_t layer_end = 0; + + ggml_tensor * tensor_for(int il) const { + if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) { + return nullptr; + } + return tensors[il]; + } + + ~llama_control_vector() { + for (struct ggml_context * ctx : ctxs) { + ggml_free(ctx); + } + for (ggml_backend_buffer_t buf : bufs) { + ggml_backend_buffer_free(buf); + } + } +}; + struct llama_vocab { using id = int32_t; using token = std::string; @@ -2093,6 +2118,9 @@ struct llama_context { struct ggml_tensor * inp_s_mask; // F32 [kv_size] struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch] + // control vectors + struct llama_control_vector cvec; + #ifdef GGML_USE_MPI ggml_mpi_context * ctx_mpi = NULL; #endif @@ -5772,6 +5800,12 @@ struct llm_build_context { } cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); + if (layer_dir != nullptr) { + cur = ggml_add(ctx0, cur, layer_dir); + } cb(cur, "l_out", il); // input for next layer @@ -13188,6 +13222,93 @@ int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const } } +static bool llama_control_vector_init(struct llama_control_vector & cvec, const llama_model & model) { + GGML_ASSERT(cvec.tensors.empty()); + GGML_ASSERT(cvec.ctxs.empty()); + GGML_ASSERT(cvec.bufs.empty()); + + // count layer buffer types + std::map buft_layer_count; + for (int64_t i = 0; i < model.hparams.n_layer; i++) { + buft_layer_count[model.buft_layer[i].buft]++; + } + + // allocate contexts + std::map ctx_map; + for (auto & it : buft_layer_count) { + int n_layers = it.second; + struct ggml_init_params params = { + /*.mem_size =*/ n_layers * ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context * ctx = ggml_init(params); + if (!ctx) { + LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__); + return 1; + } + ctx_map[it.first] = ctx; + } + + // make tensors + cvec.tensors.push_back(nullptr); // there's never a tensor for layer 0 + for (size_t il = 1; il < model.hparams.n_layer; il++) { + struct ggml_context * ctx = ctx_map.at(model.buft_layer[il].buft); + ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_embd); + cvec.tensors.push_back(tensor); + } + + // allocate tensors / buffers and zero + for (auto it : ctx_map) { + ggml_backend_buffer_type_t buft = it.first; + ggml_context * ctx = it.second; + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + LLAMA_LOG_ERROR("%s: failed to allocate buffer for control vector\n", __func__); + return false; + } + ggml_backend_buffer_clear(buf, 0); + cvec.ctxs.push_back(ctx); + cvec.bufs.push_back(buf); + } + + return true; +} + +int32_t llama_control_vector_apply(struct llama_context * lctx, float * data, size_t len, int n_embd, int32_t il_start, int32_t il_end) { + const llama_model & model = lctx->model; + llama_control_vector & cvec = lctx->cvec; + + if (n_embd != (int) model.hparams.n_embd) { + LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__); + return 1; + } + + if (cvec.tensors.empty()) { + if (!llama_control_vector_init(cvec, model)) { + return 1; + } + } + + cvec.layer_start = il_start; + cvec.layer_end = il_end; + + for (size_t il = 1; il < model.hparams.n_layer; il++) { + if (il >= cvec.tensors.size() || cvec.tensors[il] == nullptr) { + continue; + } + size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present + if (off + n_embd <= len) { + ggml_backend_tensor_set(cvec.tensors[il], + data + off, + 0, + n_embd * ggml_element_size(cvec.tensors[il])); + } + } + + return 0; +} + struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max) { struct llama_kv_cache_view result = { /*.n_cells = */ 0, diff --git a/llama.h b/llama.h index 446899da6..cb946b752 100644 --- a/llama.h +++ b/llama.h @@ -437,6 +437,20 @@ extern "C" { const char * path_base_model, int32_t n_threads); + // Apply a loaded control vector to a llama_context, or if data is NULL, clear + // the currently loaded vector. + // n_embd should be the size of a single layer's control, and data should point + // to an n_embd x n_layers buffer starting from layer 1. + // il_start and il_end are the layer range the vector should apply to (both inclusive) + // See llama_control_vector_load in common to load a control vector. + LLAMA_API int32_t llama_control_vector_apply( + struct llama_context * lctx, + float * data, + size_t len, + int n_embd, + int32_t il_start, + int32_t il_end); + // // KV cache //