llama : rename missed batch params/vars to ubatch

This commit renames the `batch` parameter to `ubatch` in the
`llama_kv_cache_find_slot`, `llm_build_inp_embd`, and
`llm_build_mamba` functions.

The motivation for this is that this should have been done as part of
Commit 19d900a7565b8f6b0a708836a57d26966cb9efe2 ("llama : rename batch
to ubatch (#9950)") but for some reason I missed these functions in
that commit and only noticed them now (sorry).
This commit is contained in:
Daniel Bevenius 2024-10-26 15:07:03 +02:00
parent cc2983d375
commit c76851eeb0

View File

@ -3591,10 +3591,10 @@ static bool llama_kv_cache_init(
// to the first cell of the slot. // to the first cell of the slot.
static bool llama_kv_cache_find_slot( static bool llama_kv_cache_find_slot(
struct llama_kv_cache & cache, struct llama_kv_cache & cache,
const struct llama_ubatch & batch) { const struct llama_ubatch & ubatch) {
const uint32_t n_tokens = batch.n_tokens; const uint32_t n_tokens = ubatch.n_tokens;
const uint32_t n_seqs = batch.n_seqs; const uint32_t n_seqs = ubatch.n_seqs;
const uint32_t n_seq_tokens = batch.n_seq_tokens; const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
if (cache.recurrent) { if (cache.recurrent) {
// For recurrent state architectures (like Mamba or RWKV), // For recurrent state architectures (like Mamba or RWKV),
@ -3602,16 +3602,16 @@ static bool llama_kv_cache_find_slot(
// A slot should be always be contiguous. // A slot should be always be contiguous.
// can only process batches with an equal number of new tokens in each sequence // can only process batches with an equal number of new tokens in each sequence
GGML_ASSERT(batch.equal_seqs); GGML_ASSERT(ubatch.equal_seqs);
int32_t min = cache.size - 1; int32_t min = cache.size - 1;
int32_t max = 0; int32_t max = 0;
// everything should fit if all seq_ids are smaller than the max // everything should fit if all seq_ids are smaller than the max
for (uint32_t s = 0; s < n_seqs; ++s) { for (uint32_t s = 0; s < n_seqs; ++s) {
const uint32_t n_seq_id = batch.n_seq_id[s]; const uint32_t n_seq_id = ubatch.n_seq_id[s];
for (uint32_t j = 0; j < n_seq_id; ++j) { for (uint32_t j = 0; j < n_seq_id; ++j) {
const llama_seq_id seq_id = batch.seq_id[s][j]; const llama_seq_id seq_id = ubatch.seq_id[s][j];
if (seq_id < 0 || (uint32_t) seq_id >= cache.size) { if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
// too big seq_id // too big seq_id
@ -3670,7 +3670,7 @@ static bool llama_kv_cache_find_slot(
// find usable cell range // find usable cell range
for (uint32_t s = 0; s < n_seqs; ++s) { for (uint32_t s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = batch.seq_id[s][0]; const llama_seq_id seq_id = ubatch.seq_id[s][0];
llama_kv_cell & seq_meta = cache.cells[seq_id]; llama_kv_cell & seq_meta = cache.cells[seq_id];
bool has_cell = false; bool has_cell = false;
if (seq_meta.tail >= 0) { if (seq_meta.tail >= 0) {
@ -3709,7 +3709,7 @@ static bool llama_kv_cache_find_slot(
// gather and re-order // gather and re-order
for (uint32_t s = 0; s < n_seqs; ++s) { for (uint32_t s = 0; s < n_seqs; ++s) {
int32_t dst_id = s + min; int32_t dst_id = s + min;
int32_t src_id = cache.cells[batch.seq_id[s][0]].tail; int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
if (dst_id != src_id) { if (dst_id != src_id) {
llama_kv_cell & dst_cell = cache.cells[dst_id]; llama_kv_cell & dst_cell = cache.cells[dst_id];
llama_kv_cell & src_cell = cache.cells[src_id]; llama_kv_cell & src_cell = cache.cells[src_id];
@ -3730,7 +3730,7 @@ static bool llama_kv_cache_find_slot(
// update the pos of the used seqs // update the pos of the used seqs
for (uint32_t s = 0; s < n_seqs; ++s) { for (uint32_t s = 0; s < n_seqs; ++s) {
const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1]; const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
int32_t cell_id = s + min; int32_t cell_id = s + min;
llama_kv_cell & cell = cache.cells[cell_id]; llama_kv_cell & cell = cache.cells[cell_id];
@ -3738,12 +3738,12 @@ static bool llama_kv_cache_find_slot(
// What should happen when the pos backtracks or skips a value? // What should happen when the pos backtracks or skips a value?
// Clearing the state mid-batch would require special-casing which isn't done. // Clearing the state mid-batch would require special-casing which isn't done.
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
__func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens); __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
} }
cell.pos = last_pos; cell.pos = last_pos;
cell.seq_id.clear(); cell.seq_id.clear();
for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) { for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
const llama_seq_id seq_id = batch.seq_id[s][j]; const llama_seq_id seq_id = ubatch.seq_id[s][j];
cell.seq_id.insert(seq_id); cell.seq_id.insert(seq_id);
cache.cells[seq_id].tail = cell_id; cache.cells[seq_id].tail = cell_id;
} }
@ -3795,10 +3795,10 @@ static bool llama_kv_cache_find_slot(
for (uint32_t s = 0; s < n_seqs; s++) { for (uint32_t s = 0; s < n_seqs; s++) {
for (uint32_t i = 0; i < n_seq_tokens; ++i) { for (uint32_t i = 0; i < n_seq_tokens; ++i) {
uint32_t k = s*n_seq_tokens + i; uint32_t k = s*n_seq_tokens + i;
cache.cells[cache.head + k].pos = batch.pos[k]; cache.cells[cache.head + k].pos = ubatch.pos[k];
for (int32_t j = 0; j < batch.n_seq_id[s]; j++) { for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]); cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
} }
} }
} }
@ -9178,21 +9178,21 @@ static struct ggml_tensor * llm_build_inp_embd(
struct ggml_context * ctx, struct ggml_context * ctx,
struct llama_context & lctx, struct llama_context & lctx,
const llama_hparams & hparams, const llama_hparams & hparams,
const llama_ubatch & batch, const llama_ubatch & ubatch,
struct ggml_tensor * tok_embd, struct ggml_tensor * tok_embd,
const llm_build_cb & cb) { const llm_build_cb & cb) {
const int64_t n_embd = hparams.n_embd; const int64_t n_embd = hparams.n_embd;
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
if (batch.token) { if (ubatch.token) {
lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens); lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ubatch.n_tokens);
cb(lctx.inp_tokens, "inp_tokens", -1); cb(lctx.inp_tokens, "inp_tokens", -1);
ggml_set_input(lctx.inp_tokens); ggml_set_input(lctx.inp_tokens);
inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens); inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
} else { } else {
lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens); lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
inpL = lctx.inp_embd; inpL = lctx.inp_embd;
ggml_set_input(lctx.inp_embd); ggml_set_input(lctx.inp_embd);
} }
@ -9766,7 +9766,7 @@ static struct ggml_tensor * llm_build_copy_mask_state(
static struct ggml_tensor * llm_build_mamba( static struct ggml_tensor * llm_build_mamba(
struct ggml_context * ctx, struct ggml_context * ctx,
struct llama_context & lctx, struct llama_context & lctx,
const llama_ubatch & batch, const llama_ubatch & ubatch,
struct ggml_cgraph * graph, struct ggml_cgraph * graph,
struct ggml_tensor * cur, struct ggml_tensor * cur,
struct ggml_tensor * state_copy, struct ggml_tensor * state_copy,
@ -9782,17 +9782,17 @@ static struct ggml_tensor * llm_build_mamba(
const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_inner = hparams.ssm_d_inner;
const int64_t d_state = hparams.ssm_d_state; const int64_t d_state = hparams.ssm_d_state;
const int64_t dt_rank = hparams.ssm_dt_rank; const int64_t dt_rank = hparams.ssm_dt_rank;
const int64_t n_seqs = batch.n_seqs; const int64_t n_seqs = ubatch.n_seqs;
// Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers) // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms; const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
// Use the same RMS norm as the final layer norm // Use the same RMS norm as the final layer norm
const float norm_rms_eps = hparams.f_norm_rms_eps; const float norm_rms_eps = hparams.f_norm_rms_eps;
const int64_t n_seq_tokens = batch.n_seq_tokens; const int64_t n_seq_tokens = ubatch.n_seq_tokens;
GGML_ASSERT(n_seqs != 0); GGML_ASSERT(n_seqs != 0);
GGML_ASSERT(batch.equal_seqs); GGML_ASSERT(ubatch.equal_seqs);
GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
struct ggml_tensor * conv_states_all = kv.k_l[il]; struct ggml_tensor * conv_states_all = kv.k_l[il];
struct ggml_tensor * ssm_states_all = kv.v_l[il]; struct ggml_tensor * ssm_states_all = kv.v_l[il];
@ -20440,10 +20440,10 @@ struct llama_data_read {
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false); llama_ubatch ubatch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
batch.n_tokens = cell_count; ubatch.n_tokens = cell_count;
batch.n_seq_tokens = cell_count; ubatch.n_seq_tokens = cell_count;
batch.n_seqs = 1; ubatch.n_seqs = 1;
for (uint32_t i = 0; i < cell_count; ++i) { for (uint32_t i = 0; i < cell_count; ++i) {
llama_pos pos; llama_pos pos;
@ -20457,11 +20457,11 @@ struct llama_data_read {
return false; return false;
} }
batch.pos[i] = pos; ubatch.pos[i] = pos;
} }
batch.n_seq_id[0] = 1; ubatch.n_seq_id[0] = 1;
batch.seq_id[0] = &dest_seq_id; ubatch.seq_id[0] = &dest_seq_id;
if (!llama_kv_cache_find_slot(kv_self, batch)) { if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
return false; return false;
} }
@ -20469,8 +20469,8 @@ struct llama_data_read {
// DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values) // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
// Assume that this is one contiguous block of cells // Assume that this is one contiguous block of cells
GGML_ASSERT(kv_self.head + cell_count <= kv_self.size); GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]); GGML_ASSERT(kv_self.cells[kv_self.head].pos == ubatch.pos[0]);
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]); GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id)); GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id)); GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
} else { } else {