llama : batch
ggml-ci
This commit is contained in:
parent a7df0714db
commit 7035c79fb5

@@ -1 +1,306 @@
#include "llama-batch.h"

#include <cstring>
#include <algorithm>

llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
    // clear empty sequences
    // the previous ubatch is assumed to be gone,
    // so nothing should refer to values in these sequences anymore.
    for (size_t i = seq.size(); i-- > 0;) {
        if (seq[i].length == 0) {
            seq.pop_back();
        } else {
            break;
        }
    }
    ubatch_token.resize(!has_embd ? n_ubatch : 0);
    ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
    ubatch_pos.resize(n_ubatch);
    ubatch_n_seq_id.resize(n_ubatch);
    ubatch_seq_id.resize(n_ubatch);
    ubatch_output.resize(n_ubatch);
    llama_ubatch ubatch = {
        /*equal_seqs   =*/ true,
        /*n_tokens     =*/ 0,
        /*n_seq_tokens =*/ 0,
        /*n_seqs       =*/ 0,
        /*token        =*/ !has_embd ? ubatch_token.data() : nullptr,
        /*embd         =*/ has_embd ? ubatch_embd.data() : nullptr,
        /*pos          =*/ ubatch_pos.data(),
        /*n_seq_id     =*/ ubatch_n_seq_id.data(),
        /*seq_id       =*/ ubatch_seq_id.data(),
        /*output       =*/ ubatch_output.data(),
    };
    return ubatch;
}
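
For orientation, the designated initializer above implies the shape of llama_ubatch. A sketch of that layout follows; the field order comes from the comments above, but the concrete types are inferred from the member buffers and are not text from this commit:

// Sketch only (assumed types; see llama-batch.h for the authoritative definition)
struct llama_ubatch {
    bool            equal_seqs;   // all sequences in the ubatch have the same length
    uint32_t        n_tokens;     // total tokens in the ubatch
    uint32_t        n_seq_tokens; // tokens per sequence
    uint32_t        n_seqs;       // number of sequences
    llama_token  *  token;        // set when the batch carries token ids
    float        *  embd;         // set when the batch carries embeddings
    llama_pos    *  pos;
    int32_t      *  n_seq_id;
    llama_seq_id ** seq_id;
    int8_t       *  output;       // per-token "produce logits" flags
};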

void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
    GGML_ASSERT(batch != nullptr);
    GGML_ASSERT(length <= seq.length);
    // Can only add sequences of equal lengths to a batch,
    // otherwise it isn't clear to which sequence a token belongs
    GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
    GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
    // NOTE: loops are separated for cache-friendliness
    if (batch->token) {
        if (ubatch.equal_seqs) {
            for (size_t i = 0; i < length; ++i) {
                ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
            }
        } else {
            // simple split
            ubatch.token = batch->token + seq.offset;
        }
    } else {
        ubatch.token = nullptr;
    }
    if (batch->embd) {
        if (ubatch.equal_seqs) {
            for (size_t i = 0; i < length; ++i) {
                memcpy(
                        ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
                        batch->embd + (n_embd * ids[seq.offset + i]),
                        n_embd * sizeof(float)
                );
            }
        } else {
            // simple split
            ubatch.embd = batch->embd + (n_embd * seq.offset);
        }
    } else {
        ubatch.embd = nullptr;
    }
    if (ubatch.equal_seqs) {
        for (size_t i = 0; i < length; ++i) {
            ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
        }
    } else {
        // simple split
        ubatch.pos = batch->pos + seq.offset;
    }
    if (ubatch.equal_seqs) {
        ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
        if (seq.seq_id) {
            ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
        }
    } else {
        // simple split
        if (batch->n_seq_id) {
            ubatch.n_seq_id = batch->n_seq_id + seq.offset;
        } else {
            for (size_t i = 0; i < length; ++i) {
                ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
            }
        }
        if (batch->seq_id) {
            ubatch.seq_id = batch->seq_id + seq.offset;
        }
    }
    if (logits_all) {
        for (size_t i = 0; i < length; ++i) {
            ubatch.output[ubatch.n_tokens + i] = 1;
            out_ids.push_back(ids[seq.offset + i]);
        }
    } else if (batch->logits) {
        if (ubatch.equal_seqs) {
            for (size_t i = 0; i < length; ++i) {
                size_t id = ids[seq.offset + i];
                int8_t is_output = batch->logits[id];
                ubatch.output[ubatch.n_tokens + i] = is_output;
                if (is_output) { out_ids.push_back(id); }
            }
        } else {
            // simple split
            ubatch.output = batch->logits + seq.offset;
            for (size_t i = 0; i < length; ++i) {
                if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
            }
        }
    } else {
        // only get last output
        for (size_t i = 0; i < length; ++i) {
            size_t id = ids[seq.offset + i];
            int8_t is_last = id == ids.size() - 1;
            ubatch.output[ubatch.n_tokens + i] = is_last;
            if (is_last) { out_ids.push_back(id); }
        }
    }
    if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
        ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
    }
    ubatch.n_tokens += length;
    ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
    seq.offset += length;
    seq.length -= length;
    n_tokens -= length;
    GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
}

llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
    ubatch.equal_seqs = false;
    if (!seq.empty()) {
        llama_sbatch_seq & s = seq[0];
        size_t length = s.length < n_ubatch ? s.length : n_ubatch;
        GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
        add_seq_to_ubatch(ubatch, s, length);
    }
    return ubatch;
}

llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
    if (!seq.empty()) {
        size_t length = 0;
        size_t n_tokens_in_ubatch = 0;
        GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
        // smallest first, because it's easier to split this way;
        // starting from the end to pop in constant time.
        for (size_t i = seq.size(); i-- > 0;) {
            llama_sbatch_seq & s = seq[i];
            GGML_ASSERT(s.length > 0);
            if (length == 0) {
                length = s.length < n_ubatch ? s.length : n_ubatch;
            }
            add_seq_to_ubatch(ubatch, s, length);
            n_tokens_in_ubatch += length;
            // shared prompts can't be mixed with any of their sequences,
            // so it's safer to compute them in their own ubatch
            if (s.n_seq_id > 1) { break; }
            // stop when there isn't enough space for another sequence
            if (length + n_tokens_in_ubatch > n_ubatch) { break; }
        }
    }
    return ubatch;
}
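
To make the loop above concrete, a hypothetical driver (not part of this commit): with three single-id sequences of lengths 5, 3 and 3 (from_batch sorts them longest first) and an 8-token budget, the first call packs the two length-3 sequences and leaves the length-5 sequence for the next ubatch.

// Assumed usage sketch; `sbatch` was filled with from_batch(..., /*simple_split*/ false, ...)
while (sbatch.n_tokens > 0) {
    llama_ubatch ubatch = sbatch.split_equal(8);
    // first iteration in the example: n_seqs == 2, n_seq_tokens == 3, n_tokens == 6
    // ... evaluate ubatch ...
}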

llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
    if (!seq.empty()) {
        llama_sbatch_seq & s = seq[seq.size() - 1];
        size_t length = s.length < n_ubatch ? s.length : n_ubatch;
        GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
        add_seq_to_ubatch(ubatch, s, length);
    }
    return ubatch;
}
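
split_seq drains one sbatch sequence per call (the entry at the back of seq), so each returned ubatch contains tokens from a single sequence. A minimal caller loop, assuming the same setup as the sketch above and a caller-chosen n_ubatch budget:

// Assumed usage sketch, not part of this commit
while (sbatch.n_tokens > 0) {
    llama_ubatch ubatch = sbatch.split_seq(n_ubatch);
    // ubatch.n_seqs == 1 here, so sequences are never interleaved within a ubatch
}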

void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
    GGML_ASSERT(batch.n_tokens >= 0);
    this->batch = &batch;
    this->n_embd = n_embd;
    this->logits_all = logits_all;

    n_tokens = batch.n_tokens;
    ids.resize(n_tokens);
    out_ids.clear();
    // TODO: reserve out_ids and seq

    for (size_t i = 0; i < n_tokens; ++i) {
        ids[i] = i;
    }
    if (simple_split) {
        seq.resize(1);
        llama_sbatch_seq & s = seq[0];
        s.n_seq_id = 0;
        s.seq_id = nullptr;
        s.offset = 0;
        s.length = n_tokens;
        return;
    }
    std::sort(ids.begin(), ids.end(),
        [&batch](size_t a, size_t b) {
            int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
            int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
            // sort by seq_id, then by pos
            if (n_seq_a == n_seq_b) {
                if (batch.seq_id) {
                    for (int32_t i = 0; i < n_seq_a; ++i) {
                        llama_seq_id seq_id_a = batch.seq_id[a][i];
                        llama_seq_id seq_id_b = batch.seq_id[b][i];
                        // smaller seq_ids go first
                        if (seq_id_a != seq_id_b) {
                            return seq_id_a < seq_id_b;
                        }
                    }
                }
                // when all else is equal, sort by pos
                if (batch.pos) {
                    return batch.pos[a] < batch.pos[b];
                }
                // no pos, sort by id
                return a < b;
            }
            // shared prompts go first
            return n_seq_a > n_seq_b;
        }
    );
    // init seq
    llama_sbatch_seq * last_seq = nullptr;

    for (size_t i = 0; i < n_tokens; ++i) {
        const size_t bi = ids[i];
        const int32_t n_seqs = batch.n_seq_id[bi];
        llama_seq_id * seq_ids = batch.seq_id[bi];
        if (last_seq != nullptr) {
            bool same = n_seqs == last_seq->n_seq_id;
            for (int32_t j = 0; same && j < n_seqs; ++j) {
                if (seq_ids[j] != last_seq->seq_id[j]) {
                    same = false;
                }
            }
            if (same) {
                last_seq->length += 1;
                continue;
            }
        }
        llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
        seq.push_back(new_seq);
        last_seq = &seq.back();
    }
    // keep shared prompts first at the end, then sort by length descending.
    std::sort(seq.begin(), seq.end(),
        [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
            if (a.n_seq_id == b.n_seq_id) {
                return a.length > b.length;
            }
            return a.n_seq_id < b.n_seq_id;
        }
    );
}
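
A minimal end-to-end sketch of the simple-split path (assumed caller code, not from this diff; `batch`, `n_embd` and `n_ubatch` are supplied by the caller):

// Assumed usage sketch
llama_sbatch sbatch;
sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ false);
while (sbatch.n_tokens > 0) {
    // with simple_split, each ubatch is a contiguous window into the original batch
    llama_ubatch ubatch = sbatch.split_simple(n_ubatch);
    // ... build and evaluate the graph for this ubatch ...
}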

llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
    batch = in_batch;
    GGML_ASSERT(batch.n_tokens > 0);
    if (!batch.pos) {
        pos.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            pos[i] = i + p0;
        }
        batch.pos = pos.data();
    }
    if (!batch.n_seq_id) {
        n_seq_id.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            n_seq_id[i] = seq_id_0.size();
        }
        batch.n_seq_id = n_seq_id.data();
    }
    if (!batch.seq_id) {
        seq_id.resize(batch.n_tokens + 1);
        seq_id[batch.n_tokens] = NULL;
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            seq_id[i] = seq_id_0.data();
        }
        batch.seq_id = seq_id.data();
    }
    if (!batch.logits) {
        logits.resize(batch.n_tokens);
        logits[logits.size() - 1] = true;
        batch.logits = logits.data();
    }
}
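
A sketch of the allocator's intended call pattern (assumptions: `tokens`, `n_tokens` and `n_past` are caller state; llama_batch_get_one() leaves pos, seq_id and logits unset, which is exactly what the constructor fills in):

// Assumed caller-side sketch, not part of this commit
llama_batch inp_batch = llama_batch_get_one(tokens, n_tokens);
llama_batch_allocr batch_allocr(inp_batch, /* p0 */ n_past);
const llama_batch & batch = batch_allocr.batch; // fully populated view of inp_batch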

@@ -2,9 +2,8 @@
#include "llama.h"

#include <array>
#include <vector>
#include <cstring>
#include <algorithm>

// very similar to llama_batch,
// but has more metadata about sequences

@@ -58,277 +57,33 @@ struct llama_sbatch {
    std::vector<llama_seq_id *> ubatch_seq_id;
    std::vector<int8_t>         ubatch_output;

    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false) {
        // clear empty sequences
        // the previous ubatch is assumed to be gone,
        // so nothing should refer to values in these sequences anymore.
        for (size_t i = seq.size(); i-- > 0;) {
            if (seq[i].length == 0) {
                seq.pop_back();
            } else {
                break;
            }
        }
        ubatch_token.resize(!has_embd ? n_ubatch : 0);
        ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
        ubatch_pos.resize(n_ubatch);
        ubatch_n_seq_id.resize(n_ubatch);
        ubatch_seq_id.resize(n_ubatch);
        ubatch_output.resize(n_ubatch);
        llama_ubatch ubatch = {
            /*equal_seqs   =*/ true,
            /*n_tokens     =*/ 0,
            /*n_seq_tokens =*/ 0,
            /*n_seqs       =*/ 0,
            /*token        =*/ !has_embd ? ubatch_token.data() : nullptr,
            /*embd         =*/ has_embd ? ubatch_embd.data() : nullptr,
            /*pos          =*/ ubatch_pos.data(),
            /*n_seq_id     =*/ ubatch_n_seq_id.data(),
            /*seq_id       =*/ ubatch_seq_id.data(),
            /*output       =*/ ubatch_output.data(),
        };
        return ubatch;
    }
    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);

    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
        GGML_ASSERT(batch != nullptr);
        GGML_ASSERT(length <= seq.length);
        // Can only add sequences of equal lengths to a batch,
        // otherwise it isn't clear to which sequence a token belongs
        GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
        GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
        // NOTE: loops are separated for cache-friendliness
        if (batch->token) {
            if (ubatch.equal_seqs) {
                for (size_t i = 0; i < length; ++i) {
                    ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
                }
            } else {
                // simple split
                ubatch.token = batch->token + seq.offset;
            }
        } else {
            ubatch.token = nullptr;
        }
        if (batch->embd) {
            if (ubatch.equal_seqs) {
                for (size_t i = 0; i < length; ++i) {
                    memcpy(
                            ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
                            batch->embd + (n_embd * ids[seq.offset + i]),
                            n_embd * sizeof(float)
                    );
                }
            } else {
                // simple split
                ubatch.embd = batch->embd + (n_embd * seq.offset);
            }
        } else {
            ubatch.embd = nullptr;
        }
        if (ubatch.equal_seqs) {
            for (size_t i = 0; i < length; ++i) {
                ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
            }
        } else {
            // simple split
            ubatch.pos = batch->pos + seq.offset;
        }
        if (ubatch.equal_seqs) {
            ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
            if (seq.seq_id) {
                ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
            }
        } else {
            // simple split
            if (batch->n_seq_id) {
                ubatch.n_seq_id = batch->n_seq_id + seq.offset;
            } else {
                for (size_t i = 0; i < length; ++i) {
                    ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
                }
            }
            if (batch->seq_id) {
                ubatch.seq_id = batch->seq_id + seq.offset;
            }
        }
        if (logits_all) {
            for (size_t i = 0; i < length; ++i) {
                ubatch.output[ubatch.n_tokens + i] = 1;
                out_ids.push_back(ids[seq.offset + i]);
            }
        } else if (batch->logits) {
            if (ubatch.equal_seqs) {
                for (size_t i = 0; i < length; ++i) {
                    size_t id = ids[seq.offset + i];
                    int8_t is_output = batch->logits[id];
                    ubatch.output[ubatch.n_tokens + i] = is_output;
                    if (is_output) { out_ids.push_back(id); }
                }
            } else {
                // simple split
                ubatch.output = batch->logits + seq.offset;
                for (size_t i = 0; i < length; ++i) {
                    if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
                }
            }
        } else {
            // only get last output
            for (size_t i = 0; i < length; ++i) {
                size_t id = ids[seq.offset + i];
                int8_t is_last = id == ids.size() - 1;
                ubatch.output[ubatch.n_tokens + i] = is_last;
                if (is_last) { out_ids.push_back(id); }
            }
        }
        if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
            ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
        }
        ubatch.n_tokens += length;
        ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
        seq.offset += length;
        seq.length -= length;
        n_tokens -= length;
        GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
    }
    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);

    // simple split, unknown number of sequences of unequal lengths
    llama_ubatch split_simple(size_t n_ubatch) {
        n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
        llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
        ubatch.equal_seqs = false;
        if (!seq.empty()) {
            llama_sbatch_seq & s = seq[0];
            size_t length = s.length < n_ubatch ? s.length : n_ubatch;
            GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
            add_seq_to_ubatch(ubatch, s, length);
        }
        return ubatch;
    }
    llama_ubatch split_simple(size_t n_ubatch);

    // make batches of equal-length sequences
    llama_ubatch split_equal(size_t n_ubatch) {
        n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
        llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
        if (!seq.empty()) {
            size_t length = 0;
            size_t n_tokens_in_ubatch = 0;
            GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
            // smallest first, because it's easier to split this way;
            // starting from the end to pop in constant time.
            for (size_t i = seq.size(); i-- > 0;) {
                llama_sbatch_seq & s = seq[i];
                GGML_ASSERT(s.length > 0);
                if (length == 0) {
                    length = s.length < n_ubatch ? s.length : n_ubatch;
                }
                add_seq_to_ubatch(ubatch, s, length);
                n_tokens_in_ubatch += length;
                // shared prompts can't be mixed with any of their sequences,
                // so it's safer to compute them in their own ubatch
                if (s.n_seq_id > 1) { break; }
                // stop when there isn't enough space for another sequence
                if (length + n_tokens_in_ubatch > n_ubatch) { break; }
            }
        }
        return ubatch;
    }
    llama_ubatch split_equal(size_t n_ubatch);

    // sequence-wise split
    llama_ubatch split_seq(size_t n_ubatch) {
        n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
        llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
        if (!seq.empty()) {
            llama_sbatch_seq & s = seq[seq.size() - 1];
            size_t length = s.length < n_ubatch ? s.length : n_ubatch;
            GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
            add_seq_to_ubatch(ubatch, s, length);
        }
        return ubatch;
    }
    llama_ubatch split_seq(size_t n_ubatch);

    void from_batch(const llama_batch & batch, const size_t n_embd, const bool simple_split = false, const bool logits_all = false) {
        GGML_ASSERT(batch.n_tokens >= 0);
        this->batch = &batch;
        this->n_embd = n_embd;
        this->logits_all = logits_all;

        n_tokens = batch.n_tokens;
        ids.resize(n_tokens);
        out_ids.clear();
        // TODO: reserve out_ids and seq

        for (size_t i = 0; i < n_tokens; ++i) {
            ids[i] = i;
        }
        if (simple_split) {
            seq.resize(1);
            llama_sbatch_seq & s = seq[0];
            s.n_seq_id = 0;
            s.seq_id = nullptr;
            s.offset = 0;
            s.length = n_tokens;
            return;
        }
        std::sort(ids.begin(), ids.end(),
            [&batch](size_t a, size_t b) {
                int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
                int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
                // sort by seq_id, then by pos
                if (n_seq_a == n_seq_b) {
                    if (batch.seq_id) {
                        for (int32_t i = 0; i < n_seq_a; ++i) {
                            llama_seq_id seq_id_a = batch.seq_id[a][i];
                            llama_seq_id seq_id_b = batch.seq_id[b][i];
                            // smaller seq_ids go first
                            if (seq_id_a != seq_id_b) {
                                return seq_id_a < seq_id_b;
                            }
                        }
                    }
                    // when all else is equal, sort by pos
                    if (batch.pos) {
                        return batch.pos[a] < batch.pos[b];
                    }
                    // no pos, sort by id
                    return a < b;
                }
                // shared prompts go first
                return n_seq_a > n_seq_b;
            }
        );
        // init seq
        llama_sbatch_seq * last_seq = nullptr;

        for (size_t i = 0; i < n_tokens; ++i) {
            const size_t bi = ids[i];
            const int32_t n_seqs = batch.n_seq_id[bi];
            llama_seq_id * seq_ids = batch.seq_id[bi];
            if (last_seq != nullptr) {
                bool same = n_seqs == last_seq->n_seq_id;
                for (int32_t j = 0; same && j < n_seqs; ++j) {
                    if (seq_ids[j] != last_seq->seq_id[j]) {
                        same = false;
                    }
                }
                if (same) {
                    last_seq->length += 1;
                    continue;
                }
            }
            llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
            seq.push_back(new_seq);
            last_seq = &seq.back();
        }
        // keep shared prompts first at the end, then sort by length descending.
        std::sort(seq.begin(), seq.end(),
            [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
                if (a.n_seq_id == b.n_seq_id) {
                    return a.length > b.length;
                }
                return a.n_seq_id < b.n_seq_id;
            }
        );
    }
    void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
};

// temporary allocate memory for the input batch if needed
struct llama_batch_allocr {
    struct llama_batch batch;

    std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
    std::vector<llama_pos>      pos;
    std::vector<int32_t>        n_seq_id;
    std::vector<llama_seq_id *> seq_id;
    std::vector<int8_t>         logits;

    // optionally fulfill the batch returned by llama_batch_get_one
    llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
};

@@ -1,5 +1,6 @@
#include "llama-context.h"

#include <cstring>
#include <stdexcept>

// deprecated

@@ -57,13 +57,24 @@ struct llama_kv_cache {
    std::vector<ggml_context_ptr>        ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    size_t total_size() {
    size_t total_size() const {
        size_t size = 0;
        for (auto & buf : bufs) {
        for (const auto & buf : bufs) {
            size += ggml_backend_buffer_get_size(buf.get());
        }

        return size;
    }

    // TODO: better data structures to reduce the cost of this operation
    llama_pos max_pos() const {
        llama_pos max_pos = -1;
        for (const auto & cell : cells) {
            max_pos = std::max(max_pos, cell.pos);
        }

        return max_pos;
    }
};

// a structure holds information about the slot found in llama_kv_cache_find_slot

@@ -1293,57 +1293,6 @@ struct llama_model_loader {
    }
};

// temporary allocate memory for the input batch if needed
static const llama_seq_id batch_default_seq_id = 0;
struct llama_batch_allocr {
    std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
    std::vector<llama_pos>      pos;
    std::vector<int32_t>        n_seq_id;
    std::vector<llama_seq_id *> seq_id;
    std::vector<int8_t>         logits;
    struct llama_batch          batch;
    // optionally fulfill the batch returned by llama_batch_get_one
    llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) {
        batch = in_batch;
        GGML_ASSERT(batch.n_tokens > 0);
        if (!batch.pos) {
            // determine the last position in KV cache
            llama_pos last_pos = -1;
            for (const auto & cell : ctx.kv_self.cells) {
                if (cell.has_seq_id(batch_default_seq_id)) {
                    last_pos = std::max(last_pos, cell.pos);
                }
            }
            last_pos++; // next position
            pos.resize(batch.n_tokens);
            for (int32_t i = 0; i < batch.n_tokens; i++) {
                pos[i] = i+last_pos;
            }
            batch.pos = pos.data();
        }
        if (!batch.n_seq_id) {
            n_seq_id.resize(batch.n_tokens);
            for (int32_t i = 0; i < batch.n_tokens; i++) {
                n_seq_id[i] = seq_id_0.size();
            }
            batch.n_seq_id = n_seq_id.data();
        }
        if (!batch.seq_id) {
            seq_id.resize(batch.n_tokens + 1);
            seq_id[batch.n_tokens] = NULL;
            for (int32_t i = 0; i < batch.n_tokens; i++) {
                seq_id[i] = seq_id_0.data();
            }
            batch.seq_id = seq_id.data();
        }
        if (!batch.logits) {
            logits.resize(batch.n_tokens);
            logits[logits.size() - 1] = true;
            batch.logits = logits.data();
        }
    }
};

template<>
bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
    uint32_t tmp;

@@ -14005,7 +13954,8 @@ static int llama_decode_internal(
    }

    // temporary allocate memory for the input batch if needed
    llama_batch_allocr batch_allocr(lctx, inp_batch);
    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);

    const llama_batch & batch = batch_allocr.batch;
    const uint32_t n_tokens_all = batch.n_tokens;

@@ -14339,7 +14289,8 @@ static int llama_encode_internal(
    }

    // temporary allocate memory for the input batch if needed
    llama_batch_allocr batch_allocr(lctx, inp_batch);
    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);

    const llama_batch & batch = batch_allocr.batch;
    const uint32_t n_tokens = batch.n_tokens;