llama : check all graph nodes when searching for result_embd_pooled (#8956)

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
This commit is contained in:
fairydreaming 2024-08-11 10:35:26 +02:00 committed by GitHub
parent 7c5bfd57f8
commit 33309f661a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -14722,12 +14722,15 @@ static int llama_decode_internal(
res = nullptr; res = nullptr;
embd = nullptr; embd = nullptr;
} else if (cparams.embeddings) { } else if (cparams.embeddings) {
res = nullptr; // do not extract logits for embedding case res = nullptr; // do not extract logits for embedding case
embd = gf->nodes[gf->n_nodes - 1]; embd = nullptr;
if (strcmp(embd->name, "result_embd_pooled") != 0) { for (int i = gf->n_nodes - 1; i >= 0; --i) {
embd = gf->nodes[gf->n_nodes - 2]; if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
embd = gf->nodes[i];
break;
}
} }
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor"); GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
} else { } else {
embd = nullptr; // do not extract embeddings when not needed embd = nullptr; // do not extract embeddings when not needed
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor"); GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");