diff --git a/src/llama.cpp b/src/llama.cpp index fc7600944..712c12f54 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -14756,7 +14756,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } } - for (int i = 0; i < n_seqs; ++i) { + for (int i = 0; i < n_tokens; ++i) { if (last_row[i] >= 0) { data[i] = last_row[i]; } @@ -14942,6 +14942,43 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) { return n_outputs_max; } +// make the outputs have the same order they had in the user-provided batch +static void llama_output_reorder(struct llama_context * ctx) { + std::vector & out_ids = ctx->sbatch.out_ids; + if (!out_ids.empty()) { + uint32_t n_vocab = ctx->model.hparams.n_vocab; + uint32_t n_embd = ctx->model.hparams.n_embd; + int32_t n_outputs = ctx->n_outputs; + GGML_ASSERT((size_t) n_outputs == out_ids.size()); + // TODO: is there something more efficient which also minimizes swaps? + // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) + for (int32_t i = 0; i < n_outputs - 1; ++i) { + int32_t j_min = i; + for (int32_t j = i + 1; j < n_outputs; ++j) { + if (out_ids[j] < out_ids[j_min]) { + j_min = j; + } + } + if (j_min == i) { continue; } + std::swap(out_ids[i], out_ids[j_min]); + if (ctx->logits_size > 0) { + for (uint32_t k = 0; k < n_vocab; k++) { + std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]); + } + } + if (ctx->embd_size > 0) { + for (uint32_t k = 0; k < n_embd; k++) { + std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]); + } + } + } + std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1); + for (int32_t i = 0; i < n_outputs; ++i) { + ctx->output_ids[out_ids[i]] = i; + } + out_ids.clear(); + } +} static void llama_graph_compute( llama_context & lctx, @@ -15180,8 +15217,8 @@ static int llama_decode_internal( auto & embd_seq_out = lctx.embd_seq; embd_seq_out.clear(); - for (uint32_t i = 0; i < n_tokens; i++) { - const llama_seq_id seq_id = ubatch.seq_id[i][0]; + for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { continue; } @@ -15631,44 +15668,6 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { } } -// make the outputs have the same order they had in the user-provided batch -static void llama_reorder_outputs(struct llama_context * ctx) { - std::vector & out_ids = ctx->sbatch.out_ids; - if (!out_ids.empty()) { - uint32_t n_vocab = ctx->model.hparams.n_vocab; - uint32_t n_embd = ctx->model.hparams.n_embd; - int32_t n_outputs = ctx->n_outputs; - GGML_ASSERT((size_t) n_outputs == out_ids.size()); - // TODO: is there something more efficient which also minimizes swaps? - // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) - for (int32_t i = 0; i < n_outputs - 1; ++i) { - int32_t j_min = i; - for (int32_t j = i + 1; j < n_outputs; ++j) { - if (out_ids[j] < out_ids[j_min]) { - j_min = j; - } - } - if (j_min == i) { continue; } - std::swap(out_ids[i], out_ids[j_min]); - if (ctx->logits_size > 0) { - for (uint32_t k = 0; k < n_vocab; k++) { - std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]); - } - } - if (ctx->embd_size > 0) { - for (uint32_t k = 0; k < n_embd; k++) { - std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]); - } - } - } - std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1); - for (int32_t i = 0; i < n_outputs; ++i) { - ctx->output_ids[out_ids[i]] = i; - } - out_ids.clear(); - } -} - // // quantization // @@ -17855,7 +17854,7 @@ struct llama_data_write { } void write_output_ids(struct llama_context * ctx) { - llama_reorder_outputs(ctx); + llama_output_reorder(ctx); const uint32_t n_outputs = ctx->n_outputs; @@ -18891,7 +18890,7 @@ float * llama_get_logits(struct llama_context * ctx) { // reorder logits for backward compatibility // TODO: maybe deprecate this - llama_reorder_outputs(ctx); + llama_output_reorder(ctx); return ctx->logits; } @@ -18939,7 +18938,7 @@ float * llama_get_embeddings(struct llama_context * ctx) { // reorder embeddings for backward compatibility // TODO: maybe deprecate this - llama_reorder_outputs(ctx); + llama_output_reorder(ctx); return ctx->embd; }