mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 13:58:46 +01:00
llama : fix empty batch causing llama_batch_allocr to crash (#9966)
* llama : fix empty batch cause llama_batch_allocr to crash * move batch_allocr inside decode/encode_internal * fix build * add GGML_ASSERT * Apply suggestions from code review Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
19d900a756
commit
c8c07d658a
128
src/llama.cpp
128
src/llama.cpp
@ -5177,6 +5177,57 @@ struct llama_model_loader {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// temporary allocate memory for the input batch if needed
|
||||||
|
static const llama_seq_id batch_default_seq_id = 0;
|
||||||
|
struct llama_batch_allocr {
|
||||||
|
std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
|
||||||
|
std::vector<llama_pos> pos;
|
||||||
|
std::vector<int32_t> n_seq_id;
|
||||||
|
std::vector<llama_seq_id *> seq_id;
|
||||||
|
std::vector<int8_t> logits;
|
||||||
|
struct llama_batch batch;
|
||||||
|
// optionally fulfill the batch returned by llama_batch_get_one
|
||||||
|
llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) {
|
||||||
|
batch = in_batch;
|
||||||
|
GGML_ASSERT(batch.n_tokens > 0);
|
||||||
|
if (!batch.pos) {
|
||||||
|
// determine the last position in KV cache
|
||||||
|
llama_pos last_pos = -1;
|
||||||
|
for (const auto & cell : ctx.kv_self.cells) {
|
||||||
|
if (cell.has_seq_id(batch_default_seq_id)) {
|
||||||
|
last_pos = std::max(last_pos, cell.pos);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
last_pos++; // next position
|
||||||
|
pos.resize(batch.n_tokens);
|
||||||
|
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||||
|
pos[i] = i+last_pos;
|
||||||
|
}
|
||||||
|
batch.pos = pos.data();
|
||||||
|
}
|
||||||
|
if (!batch.n_seq_id) {
|
||||||
|
n_seq_id.resize(batch.n_tokens);
|
||||||
|
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||||
|
n_seq_id[i] = seq_id_0.size();
|
||||||
|
}
|
||||||
|
batch.n_seq_id = n_seq_id.data();
|
||||||
|
}
|
||||||
|
if (!batch.seq_id) {
|
||||||
|
seq_id.resize(batch.n_tokens + 1);
|
||||||
|
seq_id[batch.n_tokens] = NULL;
|
||||||
|
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||||
|
seq_id[i] = seq_id_0.data();
|
||||||
|
}
|
||||||
|
batch.seq_id = seq_id.data();
|
||||||
|
}
|
||||||
|
if (!batch.logits) {
|
||||||
|
logits.resize(batch.n_tokens);
|
||||||
|
logits[logits.size() - 1] = true;
|
||||||
|
batch.logits = logits.data();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template<>
|
template<>
|
||||||
bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
|
bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
|
||||||
uint32_t tmp;
|
uint32_t tmp;
|
||||||
@ -17095,16 +17146,20 @@ static void llama_graph_compute(
|
|||||||
//
|
//
|
||||||
static int llama_decode_internal(
|
static int llama_decode_internal(
|
||||||
llama_context & lctx,
|
llama_context & lctx,
|
||||||
llama_batch batch) {
|
llama_batch inp_batch) {
|
||||||
|
|
||||||
lctx.is_encoding = false;
|
lctx.is_encoding = false;
|
||||||
const uint32_t n_tokens_all = batch.n_tokens;
|
|
||||||
|
|
||||||
if (n_tokens_all == 0) {
|
if (inp_batch.n_tokens == 0) {
|
||||||
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// temporary allocate memory for the input batch if needed
|
||||||
|
llama_batch_allocr batch_allocr(lctx, inp_batch);
|
||||||
|
const llama_batch & batch = batch_allocr.batch;
|
||||||
|
const uint32_t n_tokens_all = batch.n_tokens;
|
||||||
|
|
||||||
const auto & model = lctx.model;
|
const auto & model = lctx.model;
|
||||||
const auto & hparams = model.hparams;
|
const auto & hparams = model.hparams;
|
||||||
const auto & cparams = lctx.cparams;
|
const auto & cparams = lctx.cparams;
|
||||||
@ -17409,17 +17464,20 @@ static int llama_decode_internal(
|
|||||||
//
|
//
|
||||||
static int llama_encode_internal(
|
static int llama_encode_internal(
|
||||||
llama_context & lctx,
|
llama_context & lctx,
|
||||||
llama_batch batch) {
|
llama_batch inp_batch) {
|
||||||
|
|
||||||
lctx.is_encoding = true;
|
lctx.is_encoding = true;
|
||||||
|
|
||||||
const uint32_t n_tokens = batch.n_tokens;
|
if (inp_batch.n_tokens == 0) {
|
||||||
|
|
||||||
if (n_tokens == 0) {
|
|
||||||
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// temporary allocate memory for the input batch if needed
|
||||||
|
llama_batch_allocr batch_allocr(lctx, inp_batch);
|
||||||
|
const llama_batch & batch = batch_allocr.batch;
|
||||||
|
const uint32_t n_tokens = batch.n_tokens;
|
||||||
|
|
||||||
const auto & model = lctx.model;
|
const auto & model = lctx.model;
|
||||||
const auto & hparams = model.hparams;
|
const auto & hparams = model.hparams;
|
||||||
const auto & cparams = lctx.cparams;
|
const auto & cparams = lctx.cparams;
|
||||||
@ -21090,61 +21148,10 @@ void llama_batch_free(struct llama_batch batch) {
|
|||||||
if (batch.logits) free(batch.logits);
|
if (batch.logits) free(batch.logits);
|
||||||
}
|
}
|
||||||
|
|
||||||
// temporary allocate memory for the input batch if needed
|
|
||||||
static const llama_seq_id batch_default_seq_id = 0;
|
|
||||||
struct llama_batch_allocr {
|
|
||||||
std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
|
|
||||||
std::vector<llama_pos> pos;
|
|
||||||
std::vector<int32_t> n_seq_id;
|
|
||||||
std::vector<llama_seq_id *> seq_id;
|
|
||||||
std::vector<int8_t> logits;
|
|
||||||
struct llama_batch batch;
|
|
||||||
// optionally fulfill the batch returned by llama_batch_get_one
|
|
||||||
llama_batch_allocr(struct llama_context * ctx, struct llama_batch in_batch) {
|
|
||||||
batch = in_batch;
|
|
||||||
if (!batch.pos) {
|
|
||||||
// determine the last position in KV cache
|
|
||||||
llama_pos last_pos = -1;
|
|
||||||
for (const auto & cell : ctx->kv_self.cells) {
|
|
||||||
if (cell.has_seq_id(batch_default_seq_id)) {
|
|
||||||
last_pos = std::max(last_pos, cell.pos);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
last_pos++; // next position
|
|
||||||
pos.resize(batch.n_tokens);
|
|
||||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
|
||||||
pos[i] = i+last_pos;
|
|
||||||
}
|
|
||||||
batch.pos = pos.data();
|
|
||||||
}
|
|
||||||
if (!batch.n_seq_id) {
|
|
||||||
n_seq_id.resize(batch.n_tokens);
|
|
||||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
|
||||||
n_seq_id[i] = seq_id_0.size();
|
|
||||||
}
|
|
||||||
batch.n_seq_id = n_seq_id.data();
|
|
||||||
}
|
|
||||||
if (!batch.seq_id) {
|
|
||||||
seq_id.resize(batch.n_tokens + 1);
|
|
||||||
seq_id[batch.n_tokens] = NULL;
|
|
||||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
|
||||||
seq_id[i] = seq_id_0.data();
|
|
||||||
}
|
|
||||||
batch.seq_id = seq_id.data();
|
|
||||||
}
|
|
||||||
if (!batch.logits) {
|
|
||||||
logits.resize(batch.n_tokens);
|
|
||||||
logits[logits.size() - 1] = true;
|
|
||||||
batch.logits = logits.data();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
int32_t llama_encode(
|
int32_t llama_encode(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
struct llama_batch batch) {
|
struct llama_batch batch) {
|
||||||
llama_batch_allocr batch_allocr(ctx, batch);
|
const int ret = llama_encode_internal(*ctx, batch);
|
||||||
const int ret = llama_encode_internal(*ctx, batch_allocr.batch);
|
|
||||||
if (ret != 0) {
|
if (ret != 0) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
|
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
|
||||||
}
|
}
|
||||||
@ -21155,8 +21162,7 @@ int32_t llama_encode(
|
|||||||
int32_t llama_decode(
|
int32_t llama_decode(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
struct llama_batch batch) {
|
struct llama_batch batch) {
|
||||||
llama_batch_allocr batch_allocr(ctx, batch);
|
const int ret = llama_decode_internal(*ctx, batch);
|
||||||
const int ret = llama_decode_internal(*ctx, batch_allocr.batch);
|
|
||||||
if (ret != 0) {
|
if (ret != 0) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
|
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user