diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 9aede7fad..536657526 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -178,25 +178,27 @@ int main(int argc, char ** argv) { float * out = emb + p * n_embd; batch_decode(ctx, batch, out, s, n_embd); - // print the first part of the embeddings + // print the first part of the embeddings or for a single prompt, the full embedding fprintf(stdout, "\n"); for (int j = 0; j < n_prompts; j++) { fprintf(stdout, "embedding %d: ", j); - for (int i = 0; i < std::min(16, n_embd); i++) { + for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) { fprintf(stdout, "%9.6f ", emb[j * n_embd + i]); } fprintf(stdout, "\n"); } // print cosine similarity matrix - fprintf(stdout, "\n"); - printf("cosine similarity matrix:\n\n"); - for (int i = 0; i < n_prompts; i++) { - for (int j = 0; j < n_prompts; j++) { - float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); - fprintf(stdout, "%6.2f ", sim); - } + if (n_prompts > 1) { fprintf(stdout, "\n"); + printf("cosine similarity matrix:\n\n"); + for (int i = 0; i < n_prompts; i++) { + for (int j = 0; j < n_prompts; j++) { + float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); + fprintf(stdout, "%6.2f ", sim); + } + fprintf(stdout, "\n"); + } } // clean up