From b54afce9f42652e59c81e978e7cd5f9861a7d4e9 Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Sat, 9 Mar 2024 13:03:46 -0600 Subject: [PATCH] mostly style fixes; fix KQ_mask comment --- examples/gritlm/gritlm.cpp | 68 ++++++++++++++++++++------------------ llama.cpp | 2 +- 2 files changed, 36 insertions(+), 34 deletions(-) diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 2f98d45f4..9b75d4f82 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -4,7 +4,9 @@ #include #include -static float dot_product(const std::vector& v1, const std::vector& v2) { +// #define GRIT_DEBUG + +static float dot_product(const std::vector & v1, const std::vector & v2) { float dot = 0.0f; for (uint64_t i = 0; i < v1.size(); ++i) { dot += v1[i] * v2[i]; @@ -12,22 +14,22 @@ static float dot_product(const std::vector& v1, const std::vector& return dot; } -static float norm(const std::vector& v) { +static float norm(const std::vector & v) { return std::sqrt(dot_product(v, v)); } -static float cosine_similarity(const std::vector& v1, const std::vector& v2) { +static float cosine_similarity(const std::vector & v1, const std::vector & v2) { return dot_product(v1, v2) / (norm(v1) * norm(v2)); } -static void normalize(const std::vector& in, float* out) { +static void normalize(const std::vector & in, float * out) { float inorm = norm(in); for (uint64_t i = 0; i < in.size(); i++) { out[i] = in[i] / inorm; } } -static std::vector> encode(llama_context* ctx, const std::vector& sentences, const std::string& instruction) { +static std::vector> encode(llama_context * ctx, const std::vector & sentences, const std::string & instruction) { auto result = std::vector>{}; auto mdl = llama_get_model(ctx); @@ -40,21 +42,21 @@ static std::vector> encode(llama_context* ctx, const std::vec std::vector inputs = llama_tokenize(mdl, input_string, true, false); auto n_toks = (int32_t)inputs.size(); - // testing with and without EOS - unexpected embeddings in both cases - GritLM seems to have EOS = "" - // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L116 + // GritLM seems to have embed EOS = "" + // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18 // inputs.push_back(llama_token_eos(mdl)); // we want to ignore instruction tokens for mean pooling std::vector inputs_instruct = llama_tokenize(mdl, instruction, true, false); auto n_inst = (int32_t)inputs_instruct.size(); - /* +#ifdef GRIT_DEBUG // debug tokens - should be matching as referenced in the GritLM sample std::for_each(inputs.begin(), inputs.end(), [&ctx](llama_token t) { std::printf("[%u:%s]", t, llama_token_to_piece(ctx, t).c_str()); }); std::printf("\n"); - */ +#endif // add input to batch (this increments n_tokens) for (int32_t j = 0; j < n_toks; j++) { @@ -75,7 +77,7 @@ static std::vector> encode(llama_context* ctx, const std::vec // sum up all token embeddings for (int32_t k = n_inst; k < n_toks; k++) { - float* emb = llama_get_embeddings_ith(ctx, k); + float * emb = llama_get_embeddings_ith(ctx, k); for (uint64_t j = 0; j < n_embd; j++) { emb_unorm[j] += emb[j]; } @@ -91,24 +93,24 @@ static std::vector> encode(llama_context* ctx, const std::vec normalize(emb_unorm, emb_norm.data()); result.push_back(emb_norm); - /* +#ifdef GRIT_DEBUG // print out emb_norm std::printf("embedding %ld: ", i); for (uint64_t j = 0; j < n_embd; j++) { std::printf("%.5f ", emb_norm[j]); } std::printf("\n\n"); - */ +#endif } llama_batch_free(batch); return result; } -static std::string aggregate_pieces(const std::vector& pieces) { +static std::string aggregate_pieces(const std::vector & pieces) { // calculate total length required size_t length = 0; - for (const auto& str : pieces) { + for (const auto & str : pieces) { length += str.size(); } @@ -117,17 +119,18 @@ static std::string aggregate_pieces(const std::vector& pieces) { result.reserve(length); // append pieces - for (const auto& str : pieces) { + for (const auto & str : pieces) { result += str; } return result; } -static std::string generate(llama_context* ctx, const std::string& prompt, bool stream) { +static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) { std::vector pieces; - const llama_model* mdl = llama_get_model(ctx); + const llama_model * mdl = llama_get_model(ctx); + llama_token eos_token = llama_token_eos(mdl); llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); std::vector inputs = llama_tokenize(mdl, prompt, false, true); @@ -135,25 +138,24 @@ static std::string generate(llama_context* ctx, const std::string& prompt, bool while (true) { llama_batch_clear(bat); - - for (auto i = 0; i < inputs.size(); i++) + for (auto i = 0; i < inputs.size(); i++) { llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == inputs.size() - 1); - + } inputs.clear(); llama_decode(ctx, bat); - auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1); auto candidates = std::vector(llama_n_vocab(mdl)); - for (auto token = 0; token < candidates.size(); token++) + for (auto token = 0; token < candidates.size(); token++) { candidates[token] = llama_token_data{ token, logits[token], 0.0f }; - + } auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false }; llama_token token = llama_sample_token_greedy(ctx, &candidates_p); - if (token == llama_token_eos(mdl)) + if (token == eos_token) { break; + } std::string piece = llama_token_to_piece(ctx, token); if (stream) { @@ -169,11 +171,11 @@ static std::string generate(llama_context* ctx, const std::string& prompt, bool return aggregate_pieces(pieces); } -static std::string gritlm_instruction(const std::string& instruction) { +static std::string gritlm_instruction(const std::string & instruction) { return !instruction.empty() ? "<|user|>\n" + instruction + "\n<|embed|>\n" : "<|embed|>\n"; } -int main(int argc, char* argv[]) +int main(int argc, char * argv[]) { gpt_params params; if (!gpt_params_parse(argc, argv, params)) { @@ -185,17 +187,17 @@ int main(int argc, char* argv[]) llama_backend_init(); - llama_model* mdl = llama_load_model_from_file(params.model.c_str(), mparams); + llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams); // create new context - set to embedding mode - llama_context* embd_ctx = llama_new_context_with_model(mdl, cparams); + llama_context * embd_ctx = llama_new_context_with_model(mdl, cparams); llama_set_embeddings(embd_ctx, true); // create new context - default mode is causal - llama_context* causal_ctx = llama_new_context_with_model(mdl, cparams); + llama_context * causal_ctx = llama_new_context_with_model(mdl, cparams); - // ### Embedding/Representation ### samples taken from here: - // https://github.com/ContextualAI/gritlm?tab=readme-ov-file#basic + // samples taken from here: https://github.com/ContextualAI/gritlm#basic + // Embedding/Representation { std::string instruction = "Given a scientific paper title, retrieve the paper's abstract"; @@ -224,8 +226,8 @@ int main(int argc, char* argv[]) std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1); } - // ### Generation ### - // # GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction + // Generation + // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction { const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n"; std::string response = generate(causal_ctx, prompt, true); diff --git a/llama.cpp b/llama.cpp index 79171c749..991e1e673 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8061,7 +8061,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } } else { - // with causal attention, the mask needs to match the kv cache size + // for models using the kv cache, the mask needs to match the kv cache size const int64_t n_tokens = batch.n_tokens; const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;