diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 53379253a..90b6c56ed 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -119,10 +119,10 @@ bool llama_kv_cache_init( struct llama_kv_cache_slot_info llama_kv_cache_find_slot( struct llama_kv_cache & cache, - const struct llama_ubatch & batch) { - const uint32_t n_tokens = batch.n_tokens; - const uint32_t n_seqs = batch.n_seqs; - const uint32_t n_seq_tokens = batch.n_seq_tokens; + const struct llama_ubatch & ubatch) { + const uint32_t n_tokens = ubatch.n_tokens; + const uint32_t n_seqs = ubatch.n_seqs; + const uint32_t n_seq_tokens = ubatch.n_seq_tokens; if (cache.recurrent) { // For recurrent state architectures (like Mamba or RWKV), @@ -130,16 +130,16 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( // A slot should be always be contiguous. // can only process batches with an equal number of new tokens in each sequence - GGML_ASSERT(batch.equal_seqs); + GGML_ASSERT(ubatch.equal_seqs); int32_t min = cache.size - 1; int32_t max = 0; // everything should fit if all seq_ids are smaller than the max for (uint32_t s = 0; s < n_seqs; ++s) { - const uint32_t n_seq_id = batch.n_seq_id[s]; + const uint32_t n_seq_id = ubatch.n_seq_id[s]; for (uint32_t j = 0; j < n_seq_id; ++j) { - const llama_seq_id seq_id = batch.seq_id[s][j]; + const llama_seq_id seq_id = ubatch.seq_id[s][j]; if (seq_id < 0 || (uint32_t) seq_id >= cache.size) { // too big seq_id @@ -198,7 +198,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( // find usable cell range for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = batch.seq_id[s][0]; + const llama_seq_id seq_id = ubatch.seq_id[s][0]; llama_kv_cell & seq_meta = cache.cells[seq_id]; bool has_cell = false; if (seq_meta.tail >= 0) { @@ -237,7 +237,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( // gather and re-order for (uint32_t s = 0; s < n_seqs; ++s) { int32_t dst_id = s + min; - int32_t src_id = cache.cells[batch.seq_id[s][0]].tail; + int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail; if (dst_id != src_id) { llama_kv_cell & dst_cell = cache.cells[dst_id]; llama_kv_cell & src_cell = cache.cells[src_id]; @@ -258,7 +258,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( // update the pos of the used seqs for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1]; + const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; int32_t cell_id = s + min; llama_kv_cell & cell = cache.cells[cell_id]; @@ -266,12 +266,12 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( // What should happen when the pos backtracks or skips a value? // Clearing the state mid-batch would require special-casing which isn't done. LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", - __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens); + __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); } cell.pos = last_pos; cell.seq_id.clear(); - for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) { - const llama_seq_id seq_id = batch.seq_id[s][j]; + for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; cell.seq_id.insert(seq_id); cache.cells[seq_id].tail = cell_id; } @@ -325,10 +325,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( for (uint32_t s = 0; s < n_seqs; s++) { for (uint32_t i = 0; i < n_seq_tokens; ++i) { uint32_t k = s*n_seq_tokens + i; - cache.cells[cache.head + k].pos = batch.pos[k]; + cache.cells[cache.head + k].pos = ubatch.pos[k]; - for (int32_t j = 0; j < batch.n_seq_id[s]; j++) { - cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]); + for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) { + cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]); } } } diff --git a/src/llama.cpp b/src/llama.cpp index 7337c34ce..60728e5bb 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2540,21 +2540,21 @@ static struct ggml_tensor * llm_build_inp_embd( struct ggml_context * ctx, struct llama_context & lctx, const llama_hparams & hparams, - const llama_ubatch & batch, + const llama_ubatch & ubatch, struct ggml_tensor * tok_embd, const llm_build_cb & cb) { const int64_t n_embd = hparams.n_embd; struct ggml_tensor * inpL; - if (batch.token) { - lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens); + if (ubatch.token) { + lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ubatch.n_tokens); cb(lctx.inp_tokens, "inp_tokens", -1); ggml_set_input(lctx.inp_tokens); inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens); } else { - lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens); + lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens); inpL = lctx.inp_embd; ggml_set_input(lctx.inp_embd); } @@ -3149,7 +3149,7 @@ static struct ggml_tensor * llm_build_copy_mask_state( static struct ggml_tensor * llm_build_mamba( struct ggml_context * ctx, struct llama_context & lctx, - const llama_ubatch & batch, + const llama_ubatch & ubatch, struct ggml_cgraph * graph, struct ggml_tensor * cur, struct ggml_tensor * state_copy, @@ -3165,17 +3165,17 @@ static struct ggml_tensor * llm_build_mamba( const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t dt_rank = hparams.ssm_dt_rank; - const int64_t n_seqs = batch.n_seqs; + const int64_t n_seqs = ubatch.n_seqs; // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers) const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms; // Use the same RMS norm as the final layer norm const float norm_rms_eps = hparams.f_norm_rms_eps; - const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; GGML_ASSERT(n_seqs != 0); - GGML_ASSERT(batch.equal_seqs); - GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs); + GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); struct ggml_tensor * conv_states_all = kv.k_l[il]; struct ggml_tensor * ssm_states_all = kv.v_l[il];