mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 13:27:21 +01:00
embedding : evaluate prompt in batches (#2713)
This commit is contained in:
parent
1123f7fbdf
commit
519c981f8b
@ -72,23 +72,30 @@ int main(int argc, char ** argv) {
|
|||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.embedding){
|
if (embd_inp.size() > (size_t)params.n_ctx) {
|
||||||
if (embd_inp.size() > 0) {
|
fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
|
||||||
if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads)) {
|
__func__, embd_inp.size(), params.n_ctx);
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
return 1;
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const int n_embd = llama_n_embd(ctx);
|
|
||||||
const auto embeddings = llama_get_embeddings(ctx);
|
|
||||||
|
|
||||||
for (int i = 0; i < n_embd; i++) {
|
|
||||||
printf("%f ", embeddings[i]);
|
|
||||||
}
|
|
||||||
printf("\n");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
while (!embd_inp.empty()) {
|
||||||
|
int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
|
||||||
|
if (llama_eval(ctx, embd_inp.data(), n_tokens, n_past, params.n_threads)) {
|
||||||
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
n_past += n_tokens;
|
||||||
|
embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int n_embd = llama_n_embd(ctx);
|
||||||
|
const auto embeddings = llama_get_embeddings(ctx);
|
||||||
|
|
||||||
|
for (int i = 0; i < n_embd; i++) {
|
||||||
|
printf("%f ", embeddings[i]);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
|
||||||
llama_print_timings(ctx);
|
llama_print_timings(ctx);
|
||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user