mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-27 12:33:06 +01:00
llama : apply suggestions
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
7b7db0bbee
commit
1fb5d4fdee
@ -3018,7 +3018,7 @@ struct llama_sbatch {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
std::sort(ids.begin(), ids.end(),
|
std::sort(ids.begin(), ids.end(),
|
||||||
[batch](size_t a, size_t b) {
|
[&batch](size_t a, size_t b) {
|
||||||
int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
|
int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
|
||||||
int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
|
int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
|
||||||
// sort by seq_id, then by pos
|
// sort by seq_id, then by pos
|
||||||
@ -3050,7 +3050,6 @@ struct llama_sbatch {
|
|||||||
if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) {
|
if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) {
|
||||||
for (size_t i = 0; i < n_tokens; ++i) {
|
for (size_t i = 0; i < n_tokens; ++i) {
|
||||||
const size_t bi = ids[i];
|
const size_t bi = ids[i];
|
||||||
const size_t s_len = seq.size();
|
|
||||||
const int32_t n_seqs = batch.n_seq_id[bi];
|
const int32_t n_seqs = batch.n_seq_id[bi];
|
||||||
llama_seq_id * seq_ids = batch.seq_id[bi];
|
llama_seq_id * seq_ids = batch.seq_id[bi];
|
||||||
if (last_seq != nullptr) {
|
if (last_seq != nullptr) {
|
||||||
@ -3067,7 +3066,7 @@ struct llama_sbatch {
|
|||||||
}
|
}
|
||||||
llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id};
|
llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id};
|
||||||
seq.push_back(new_seq);
|
seq.push_back(new_seq);
|
||||||
last_seq = &seq[s_len];
|
last_seq = &seq.back();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id};
|
llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id};
|
||||||
@ -15089,8 +15088,8 @@ static int llama_decode_internal(
|
|||||||
|
|
||||||
while (lctx.sbatch.n_tokens > 0) {
|
while (lctx.sbatch.n_tokens > 0) {
|
||||||
// For now, only use equal splits for recurrent model architectures
|
// For now, only use equal splits for recurrent model architectures
|
||||||
llama_ubatch u_batch = kv_self.recurrent ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch);
|
llama_ubatch ubatch = kv_self.recurrent ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch);
|
||||||
const uint32_t n_tokens = u_batch.n_tokens;
|
const uint32_t n_tokens = ubatch.n_tokens;
|
||||||
|
|
||||||
// count the outputs in this u_batch
|
// count the outputs in this u_batch
|
||||||
{
|
{
|
||||||
@ -15099,9 +15098,9 @@ static int llama_decode_internal(
|
|||||||
if (n_outputs == n_tokens_all) {
|
if (n_outputs == n_tokens_all) {
|
||||||
n_outputs_new = n_tokens;
|
n_outputs_new = n_tokens;
|
||||||
} else {
|
} else {
|
||||||
GGML_ASSERT(u_batch.output);
|
GGML_ASSERT(ubatch.output);
|
||||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||||
n_outputs_new += (int32_t) (u_batch.output[i] != 0);
|
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -15122,7 +15121,7 @@ static int llama_decode_internal(
|
|||||||
kv_self.head = 0;
|
kv_self.head = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
|
if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -15141,7 +15140,7 @@ static int llama_decode_internal(
|
|||||||
ggml_backend_sched_reset(lctx.sched);
|
ggml_backend_sched_reset(lctx.sched);
|
||||||
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
|
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
|
||||||
|
|
||||||
ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
|
ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
|
||||||
|
|
||||||
// the output is always the last tensor in the graph
|
// the output is always the last tensor in the graph
|
||||||
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
||||||
@ -15166,7 +15165,7 @@ static int llama_decode_internal(
|
|||||||
|
|
||||||
ggml_backend_sched_alloc_graph(lctx.sched, gf);
|
ggml_backend_sched_alloc_graph(lctx.sched, gf);
|
||||||
|
|
||||||
llama_set_inputs(lctx, u_batch);
|
llama_set_inputs(lctx, ubatch);
|
||||||
|
|
||||||
llama_graph_compute(lctx, gf, n_threads);
|
llama_graph_compute(lctx, gf, n_threads);
|
||||||
|
|
||||||
@ -15229,7 +15228,7 @@ static int llama_decode_internal(
|
|||||||
embd_seq_out.clear();
|
embd_seq_out.clear();
|
||||||
|
|
||||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||||
const llama_seq_id seq_id = u_batch.seq_id[i][0];
|
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
||||||
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user