From 45681a57dd92c79db71f5111910c6775a967a9c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Wed, 26 Jun 2024 15:03:01 +0200 Subject: [PATCH] llama : add inference support and model types for T5 and FLAN-T5 model families llama : add new API functions to support encoder-decoder models: llama_encode(), llama_model_has_encoder(), llama_model_decoder_start_token() common, llama-cli : use new API functions to support encoder-decoder models convert-hf : handle shared token embeddings tensors in T5Model convert-hf : handle SentencePiece BPE tokenizer in T5Model (for Pile-T5 models) convert-hf : add MT5ForConditionalGeneration and UMT5ForConditionalGeneration to architectures supported by T5Model --- common/common.cpp | 19 +- convert-hf-to-gguf.py | 46 ++- examples/main/main.cpp | 21 +- llama.cpp | 806 ++++++++++++++++++++++++++++++++++++++++- llama.h | 15 + 5 files changed, 892 insertions(+), 15 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index c76d0e2c3..d74dfa7a0 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2061,7 +2061,24 @@ std::tuple llama_init_from_gpt_par if (params.warmup) { LOG("warming up the model with an empty run\n"); - std::vector tmp = { llama_token_bos(model), llama_token_eos(model), }; + std::vector tmp; + llama_token bos = llama_token_bos(model); + llama_token eos = llama_token_eos(model); + // some models (e.g. T5) don't have a BOS token + if (bos != -1) { + tmp.push_back(bos); + } + tmp.push_back(eos); + + if (llama_model_has_encoder(model)) { + llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0)); + llama_token decoder_start_token_id = llama_model_decoder_start_token(model); + if (decoder_start_token_id == -1) { + decoder_start_token_id = bos; + } + tmp.clear(); + tmp.push_back(decoder_start_token_id); + } llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); llama_kv_cache_clear(lctx); llama_synchronize(lctx); diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index c26fad930..5eb4e589e 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2775,11 +2775,17 @@ class DeepseekV2Model(Model): raise ValueError(f"Unprocessed experts: {experts}") -@Model.register("T5ForConditionalGeneration") @Model.register("T5WithLMHeadModel") +@Model.register("T5ForConditionalGeneration") +@Model.register("MT5ForConditionalGeneration") +@Model.register("UMT5ForConditionalGeneration") class T5Model(Model): model_arch = gguf.MODEL_ARCH.T5 + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.shared_token_embeddings_found = False + def set_vocab(self): # to avoid TypeError: Descriptors cannot be created directly # exception when importing sentencepiece_model_pb2 @@ -2787,17 +2793,29 @@ class T5Model(Model): from sentencepiece import SentencePieceProcessor from sentencepiece import sentencepiece_model_pb2 as model - tokenizer_path = self.dir_model / 'spiece.model' + tokenizer_path = self.dir_model / 'tokenizer.model' + + # many older models use spiece.model tokenizer model filename + if not tokenizer_path.is_file(): + tokenizer_path = self.dir_model / 'spiece.model' if not tokenizer_path.is_file(): raise FileNotFoundError(f"File not found: {tokenizer_path}") sentencepiece_model = model.ModelProto() sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) + + # some models like Pile-T5 family use BPE tokenizer instead of Unigram + if sentencepiece_model.trainer_spec.model_type == 2: # BPE + # assure the 
tokenizer model file name is correct + assert tokenizer_path.name == 'tokenizer.model' + return self._set_vocab_sentencepiece() + else: + assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM + add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap - assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM tokenizer = SentencePieceProcessor() tokenizer.LoadFromFile(str(tokenizer_path)) @@ -2867,7 +2885,10 @@ class T5Model(Model): def set_gguf_parameters(self): self.gguf_writer.add_name("T5") - self.gguf_writer.add_context_length(self.hparams["n_positions"]) + if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None: + logger.warning("Couldn't find context length in config.json, assuming default value of 512") + n_ctx = 512 + self.gguf_writer.add_context_length(n_ctx) self.gguf_writer.add_embedding_length(self.hparams["d_model"]) self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"]) self.gguf_writer.add_block_count(self.hparams["num_layers"]) @@ -2883,12 +2904,17 @@ class T5Model(Model): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused - # Sometimes T5 and Flan-T5 based models contain "encoder.embed_tokens.weight" tensor or - # "decoder.embed_tokens.weight" tensors that are duplicates of "shared.weight" tensor - # To prevent errors caused by an unnecessary unmapped tensor, skip both of them and use only "shared.weight". - if name == "decoder.embed_tokens.weight" or name == "encoder.embed_tokens.weight": - logger.debug(f"Skipping tensor {name!r} in safetensors so that convert can end normally.") - return [] + # T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight", + # "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored + # in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder + # and decoder and ignore the remaining ones. 
+ if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]: + if not self.shared_token_embeddings_found: + name = "shared.weight" + self.shared_token_embeddings_found = True + else: + logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.") + return [] return [(self.map_tensor_name(name), data_torch)] diff --git a/examples/main/main.cpp b/examples/main/main.cpp index cfaf6a6e8..db35c6eb4 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -255,7 +255,9 @@ int main(int argc, char ** argv) { } const bool add_bos = llama_should_add_bos_token(model); - GGML_ASSERT(llama_add_eos_token(model) != 1); + if (!llama_model_has_encoder(model)) { + GGML_ASSERT(llama_add_eos_token(model) != 1); + } LOG("add_bos: %d\n", add_bos); std::vector embd_inp; @@ -517,6 +519,23 @@ int main(int argc, char ** argv) { exit(1); } + if (llama_model_has_encoder(model)) { + int enc_input_size = embd_inp.size(); + llama_token * enc_input_buf = embd_inp.data(); + + if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) { + LOG_TEE("%s : failed to eval\n", __func__); + return 1; + } + + llama_token decoder_start_token_id = llama_model_decoder_start_token(model); + if (decoder_start_token_id == -1) { + decoder_start_token_id = llama_token_bos(model); + } + embd_inp.clear(); + embd_inp.push_back(decoder_start_token_id); + } + while ((n_remain != 0 && !is_antiprompt) || params.interactive) { // predict if (!embd.empty()) { diff --git a/llama.cpp b/llama.cpp index 78a21008f..efc40c17f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1999,12 +1999,18 @@ enum e_model { MODEL_17M, MODEL_22M, MODEL_33M, + MODEL_60M, MODEL_70M, + MODEL_80M, MODEL_109M, MODEL_137M, MODEL_160M, + MODEL_220M, + MODEL_250M, MODEL_335M, MODEL_410M, + MODEL_770M, + MODEL_780M, MODEL_0_5B, MODEL_1B, MODEL_1_4B, @@ -2015,6 +2021,7 @@ enum e_model { MODEL_6_9B, MODEL_7B, MODEL_8B, + MODEL_11B, MODEL_12B, MODEL_13B, MODEL_14B, @@ -2062,6 +2069,8 @@ struct llama_hparams { uint32_t n_expert = 0; uint32_t n_expert_used = 0; uint32_t n_vocab_type = 0; // for BERT-style token types + uint32_t n_rel_attn_bkts = 0; + int32_t dec_start_token_id = -1; uint32_t n_layer_dense_lead = 0; uint32_t n_lora_q = 0; @@ -2112,6 +2121,8 @@ struct llama_hparams { if (this->n_expert != other.n_expert) return true; if (this->n_expert_used != other.n_expert_used) return true; + if (this->n_rel_attn_bkts != other.n_rel_attn_bkts) return true; + if (this->dec_start_token_id != other.dec_start_token_id) return true; if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true; if (this->n_lora_q != other.n_lora_q) return true; if (this->n_lora_kv != other.n_lora_kv) return true; @@ -2215,6 +2226,8 @@ struct llama_layer { struct ggml_tensor * attn_kv_a_norm; struct ggml_tensor * attn_sub_norm; struct ggml_tensor * ffn_sub_norm; + struct ggml_tensor * cross_attn_norm; + struct ggml_tensor * enc_attn_norm; // attention struct ggml_tensor * wq; @@ -2226,6 +2239,14 @@ struct llama_layer { struct ggml_tensor * wq_b; struct ggml_tensor * wkv_a_mqa; struct ggml_tensor * wkv_b; + struct ggml_tensor * cross_wq; + struct ggml_tensor * cross_wk; + struct ggml_tensor * cross_wv; + struct ggml_tensor * cross_wo; + struct ggml_tensor * enc_wq; + struct ggml_tensor * enc_wk; + struct ggml_tensor * enc_wv; + struct ggml_tensor * enc_wo; // attention bias struct ggml_tensor * bq; @@ -2234,17 +2255,26 @@ struct llama_layer { struct ggml_tensor * bo; struct ggml_tensor * bqkv; + // 
relative position bias + struct ggml_tensor * rel_attn_b; + struct ggml_tensor * enc_rel_attn_b; + struct ggml_tensor * cross_rel_attn_b; + // normalization struct ggml_tensor * ffn_norm; struct ggml_tensor * ffn_norm_b; struct ggml_tensor * layer_out_norm; struct ggml_tensor * layer_out_norm_b; struct ggml_tensor * ffn_norm_exps; + struct ggml_tensor * enc_ffn_norm; // ff struct ggml_tensor * ffn_gate; // w1 struct ggml_tensor * ffn_down; // w2 struct ggml_tensor * ffn_up; // w3 + struct ggml_tensor * enc_ffn_gate; + struct ggml_tensor * enc_ffn_down; + struct ggml_tensor * enc_ffn_up; // ff MoE struct ggml_tensor * ffn_gate_inp; @@ -2470,6 +2500,7 @@ struct llama_model { struct ggml_tensor * output_norm_b; struct ggml_tensor * output; struct ggml_tensor * output_b; + struct ggml_tensor * enc_output_norm; std::vector layers; @@ -2598,6 +2629,12 @@ struct llama_context { // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE std::map> embd_seq; + // whether we are computing encoder output or decoder output + bool is_encoding = false; + // output of the encoder part of the encoder-decoder models + std::vector encoder_output; + std::vector > encoder_output_seq_ids; + // memory buffers used to evaluate the model std::vector buf_compute_meta; ggml_backend_sched_t sched = nullptr; @@ -2617,6 +2654,9 @@ struct llama_context { struct ggml_tensor * inp_s_copy; // I32 [kv_size] struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] + struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] + struct ggml_tensor * inp_enc_output; // F32 [n_embd, n_enc_outputs] + struct ggml_tensor * inp_cross_KQ_mask; // F32 [n_enc_outputs, n_batch] // control vectors struct llama_control_vector cvec; @@ -4220,12 +4260,18 @@ static const char * llama_model_type_name(e_model type) { case MODEL_17M: return "17M"; case MODEL_22M: return "22M"; case MODEL_33M: return "33M"; + case MODEL_60M: return "60M"; case MODEL_70M: return "70M"; + case MODEL_80M: return "80M"; case MODEL_109M: return "109M"; case MODEL_137M: return "137M"; case MODEL_160M: return "160M"; + case MODEL_220M: return "220M"; + case MODEL_250M: return "250M"; case MODEL_335M: return "335M"; case MODEL_410M: return "410M"; + case MODEL_770M: return "770M"; + case MODEL_780M: return "780M"; case MODEL_0_5B: return "0.5B"; case MODEL_1B: return "1B"; case MODEL_1_4B: return "1.4B"; @@ -4236,6 +4282,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_6_9B: return "6.9B"; case MODEL_7B: return "7B"; case MODEL_8B: return "8B"; + case MODEL_11B: return "11B"; case MODEL_12B: return "12B"; case MODEL_13B: return "13B"; case MODEL_14B: return "14B"; @@ -4831,6 +4878,37 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_T5: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + uint32_t decoder_start_token_id; + if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, decoder_start_token_id, false)) { + hparams.dec_start_token_id = decoder_start_token_id; + } + + switch (hparams.n_layer) { + case 6: model.type = e_model::MODEL_60M; break; // t5-small + case 8: model.type = e_model::MODEL_80M; break; // flan-t5-small + case 12: + switch (hparams.n_ff) { + case 3072: model.type = e_model::MODEL_220M; break; // t5-base + case 2048: model.type = e_model::MODEL_250M; break; // flan-t5-base + default: model.type = 
e_model::MODEL_UNKNOWN; + } break; + case 24: + switch (hparams.n_ff) { + case 4096: model.type = e_model::MODEL_770M; break; // t5-large + case 2816: model.type = e_model::MODEL_780M; break; // flan-t5-large + case 16384: model.type = e_model::MODEL_3B; break; // t5-3b + case 5120: model.type = e_model::MODEL_3B; break; // flan-t5-xl + case 65536: model.type = e_model::MODEL_11B; break; // t5-11b + case 10240: model.type = e_model::MODEL_11B; break; // flan-t5-xxl + default: model.type = e_model::MODEL_UNKNOWN; + } break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -6857,6 +6935,63 @@ static bool llm_load_tensors( layer.ffn_up_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}); } } break; + case LLM_ARCH_T5: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.enc_output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.enc_attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}); + layer.enc_rel_attn_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {hparams.n_head, hparams.n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.enc_wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); + layer.enc_wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.enc_wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.enc_wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); + + layer.enc_ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}); + layer.enc_ffn_gate = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.enc_ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.enc_ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}); + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}); + layer.rel_attn_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {hparams.n_head, hparams.n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo = ml.create_tensor(ctx_split, 
tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); + + layer.cross_attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM, "weight", i), {n_embd}); + // this tensor seems to be unused in HF transformers implementation + layer.cross_rel_attn_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {hparams.n_head, hparams.n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.cross_wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); + layer.cross_wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.cross_wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.cross_wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -7582,6 +7717,7 @@ struct llm_build_context { const int32_t n_tokens; const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size) const int32_t n_outputs; + const int32_t n_enc_outputs; const int32_t kv_head; // index of where we store new KV data in the cache const int32_t n_ctx_orig; @@ -7631,6 +7767,7 @@ struct llm_build_context { n_tokens (batch.n_tokens), n_kv (worst_case ? kv_self.size : kv_self.n), n_outputs (worst_case ? n_tokens : lctx.n_outputs), + n_enc_outputs (worst_case ? n_tokens : lctx.encoder_output.size() / hparams.n_embd), kv_head (worst_case ? (kv_self.recurrent ? 
0 : kv_self.size - n_tokens) : kv_self.head), n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), @@ -7661,6 +7798,9 @@ struct llm_build_context { lctx.inp_s_copy = nullptr; lctx.inp_s_mask = nullptr; lctx.inp_s_seq = nullptr; + lctx.inp_pos_bucket = nullptr; + lctx.inp_enc_output = nullptr; + lctx.inp_cross_KQ_mask = nullptr; } void free() { @@ -7903,6 +8043,55 @@ struct llm_build_context { return gf; } + struct ggml_tensor * llm_build_inp_rel_pos_bucket( + bool causal) { + + if (causal) { + lctx.inp_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens); + } else { + lctx.inp_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens); + } + + ggml_set_input(lctx.inp_pos_bucket); + cb(lctx.inp_pos_bucket, "pos_bucket", -1); + + return lctx.inp_pos_bucket; + } + + struct ggml_tensor * llm_build_rel_pos_bias( + struct ggml_tensor * pos_bucket, + struct ggml_tensor * rel_attn_b) { + + struct ggml_tensor * pos_bucket_1d = ggml_view_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1], 0); + cb(pos_bucket_1d, "pos_bucket_1d", -1); + + struct ggml_tensor * pos_bias = ggml_get_rows(ctx0, rel_attn_b, pos_bucket_1d); + cb(pos_bias, "pos_bias", -1); + + pos_bias = ggml_view_3d(ctx0, pos_bias, pos_bias->ne[0], lctx.inp_pos_bucket->ne[0], lctx.inp_pos_bucket->ne[1], ggml_element_size(pos_bias) * pos_bias->ne[0], ggml_element_size(pos_bias) * pos_bias->ne[0] * lctx.inp_pos_bucket->ne[0], 0); + cb(pos_bias, "pos_bias", -1); + + pos_bias = ggml_permute(ctx0, pos_bias, 2, 0, 1, 3); + cb(pos_bias, "pos_bias", -1); + + return pos_bias; + } + + struct ggml_tensor * llm_build_inp_enc_output() { + const int64_t n_embd = hparams.n_embd; + lctx.inp_enc_output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc_outputs); + ggml_set_input(lctx.inp_enc_output); + cb(lctx.inp_enc_output, "enc_output", -1); + return lctx.inp_enc_output; + } + + struct ggml_tensor * llm_build_inp_cross_KQ_mask() { + lctx.inp_cross_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc_outputs, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + ggml_set_input(lctx.inp_cross_KQ_mask); + cb(lctx.inp_cross_KQ_mask, "enc_mask", -1); + return lctx.inp_cross_KQ_mask; + } + struct ggml_cgraph * build_llama() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -12061,6 +12250,320 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_t5() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + if (lctx.is_encoding) { + struct ggml_tensor * enc_pos_buckets = llm_build_inp_rel_pos_bucket(false); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * enc_KQ_mask = build_inp_KQ_mask(false); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].enc_attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].enc_wq, cur); + cb(Qcur, "Qcur", il); + + struct 
ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].enc_wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].enc_wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + + struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); + + struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + cb(kq, "kq", il); + + struct ggml_tensor * rel_attn_b = model.layers[il].enc_rel_attn_b ? model.layers[il].enc_rel_attn_b : model.layers[0].enc_rel_attn_b; + struct ggml_tensor * pos_bias = llm_build_rel_pos_bias(enc_pos_buckets, rel_attn_b); + struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias); + cb(kq_b, "kq_b", il); + + kq = ggml_soft_max_ext(ctx0, kq_b, enc_KQ_mask, 1.0f, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens))); + cb(v, "v", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq); + cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); + cb(cur, "kqv_merged_cont", il); + + ggml_build_forward_expand(gf, cur); + + cur = ggml_mul_mat(ctx0, model.layers[il].enc_wo, cur); + cb(cur, "kqv_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].enc_ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + // T5 uses relu, flan-T5 uses gelu-gated + cur = llm_build_ffn(ctx0, cur, + model.layers[il].enc_ffn_up, NULL, + model.layers[il].enc_ffn_gate, NULL, + model.layers[il].enc_ffn_down, NULL, + NULL, + model.layers[il].enc_ffn_gate ? LLM_FFN_GELU : LLM_FFN_RELU, + model.layers[il].enc_ffn_gate ? 
LLM_FFN_PAR : LLM_FFN_SEQ, + cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); + if (layer_dir != nullptr) { + cur = ggml_add(ctx0, cur, layer_dir); + } + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cb(cur, "result_embd", -1); + + cur = llm_build_norm(ctx0, cur, hparams, + model.enc_output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + } else { + struct ggml_tensor * enc_output = llm_build_inp_enc_output(); + struct ggml_tensor * dec_pos_buckets = llm_build_inp_rel_pos_bucket(true); + + struct ggml_tensor * dec_KQ_mask = build_inp_KQ_mask(); + struct ggml_tensor * cross_KQ_mask = llm_build_inp_cross_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il); + + struct ggml_tensor * k = + ggml_view_3d(ctx0, kv_self.k_l[il], + n_embd_head_k, n_kv, n_head_kv, + ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), + 0); + cb(k, "k", il); + + struct ggml_tensor * v = + ggml_view_3d(ctx0, kv_self.v_l[il], + n_kv, n_embd_head_v, n_head_kv, + ggml_element_size(kv_self.v_l[il])*n_ctx, + ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v, + 0); + cb(v, "v", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + + struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + + struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + cb(kq, "kq", il); + + struct ggml_tensor * rel_attn_b = model.layers[il].rel_attn_b ? 
model.layers[il].rel_attn_b : model.layers[0].rel_attn_b; + struct ggml_tensor * pos_bias = llm_build_rel_pos_bias(dec_pos_buckets, rel_attn_b); + struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias); + cb(kq_b, "kq_b", il); + + kq = ggml_soft_max_ext(ctx0, kq_b, dec_KQ_mask, 1.0f, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); + cb(cur, "kqv_merged_cont", il); + + ggml_build_forward_expand(gf, cur); + + cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); + cb(cur, "kqv_out", il); + } + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "cross_inp", il); + + struct ggml_tensor * inpCA = cur; + + // norm + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].cross_attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "cross_attn_norm", il); + + // cross-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].cross_wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_wk, enc_output); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_wv, enc_output); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_enc_outputs); + + struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); + + struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + cb(kq, "kq", il); + + kq = ggml_soft_max_ext(ctx0, kq, cross_KQ_mask, 1.0f, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_enc_outputs))); + cb(v, "v", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_enc_outputs, n_embd_head, n_head_kv), kq); + cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); + cb(cur, "kqv_merged_cont", il); + + ggml_build_forward_expand(gf, cur); + + cur = ggml_mul_mat(ctx0, model.layers[il].cross_wo, cur); + cb(cur, "kqv_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + // T5 uses relu, flan-T5 uses gelu-gated + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + NULL, + model.layers[il].enc_ffn_gate ? LLM_FFN_GELU : LLM_FFN_RELU, + model.layers[il].enc_ffn_gate ? 
LLM_FFN_PAR : LLM_FFN_SEQ, + cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); + if (layer_dir != nullptr) { + cur = ggml_add(ctx0, cur, layer_dir); + } + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cb(cur, "result_embd", -1); + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + } + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -12288,6 +12791,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_bitnet(); } break; + case LLM_ARCH_T5: + { + result = llm.build_t5(); + } break; default: GGML_ASSERT(false); } @@ -12326,6 +12833,30 @@ static void llama_set_s_copy(llama_context & lctx) { } } +static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t num_buckets, bool bidirectional) { + // TODO move to hparams if a T5 variant appears that uses a different value + const int64_t max_distance = 128; + + if (bidirectional) { + num_buckets >>= 1; + } + + const int64_t max_exact = num_buckets >> 1; + + int32_t relative_position = x - y; + int32_t relative_bucket = 0; + if (bidirectional) { + relative_bucket += (relative_position > 0) * num_buckets; + relative_position = abs(relative_position); + } else { + relative_position = -std::min(relative_position, 0); + } + int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (num_buckets - max_exact) / log(1.0 * max_distance / max_exact)); + relative_position_if_large = std::min(relative_position_if_large, num_buckets - 1); + relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large); + return relative_bucket; +} + static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { // // set input data @@ -12391,7 +12922,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { if (lctx.inp_KQ_mask) { // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. - if (cparams.causal_attn) { + if (cparams.causal_attn && !lctx.is_encoding) { const int64_t n_kv = kv_self.n; const int64_t n_tokens = batch.n_tokens; @@ -12431,7 +12962,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } else { // when using kv cache, the mask needs to match the kv cache size const int64_t n_tokens = batch.n_tokens; - const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens; + const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? 
kv_self.n : n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); @@ -12595,6 +13126,67 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } } + + if (lctx.inp_pos_bucket) { + const int64_t n_tokens = batch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer)); + + int32_t * data = (int32_t *) lctx.inp_pos_bucket->data; + + if (!lctx.is_encoding) { + const int64_t n_kv = kv_self.n; + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_kv; ++i) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding); + } + } + } + } else { + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_tokens; ++i) { + data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(batch.pos[i], batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding); + } + } + } + } + } + + if (!lctx.is_encoding && lctx.inp_enc_output) { + ggml_backend_tensor_set(lctx.inp_enc_output, lctx.encoder_output.data(), 0, lctx.encoder_output.size() * ggml_element_size(lctx.inp_enc_output)); + } + + if (!lctx.is_encoding && lctx.inp_cross_KQ_mask) { + const int64_t n_encoder_output = lctx.encoder_output.size() / hparams.n_embd; + const int64_t n_tokens = batch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cross_KQ_mask->buffer)); + + float * data = (float *) lctx.inp_cross_KQ_mask->data; + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_encoder_output; ++i) { + float f = -INFINITY; + for (int s = 0; s < batch.n_seq_id[j]; ++s) { + const llama_seq_id seq_id = batch.seq_id[j][s]; + if (lctx.encoder_output_seq_ids[i].find(seq_id) != lctx.encoder_output_seq_ids[i].end()) { + f = 0.0f; + } + } + data[h*(n_encoder_output*n_tokens) + j*n_encoder_output + i] = f; + } + } + + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_encoder_output; ++j) { + data[h*(n_encoder_output*n_tokens) + i*n_encoder_output + j] = -INFINITY; + } + } + } + } } // Make sure enough space is available for outputs. @@ -12611,7 +13203,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) { // TODO: use a per-batch flag for logits presence instead const bool has_logits = !cparams.embeddings; - const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); + const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE)); const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0; const size_t embd_size = has_embd ? 
n_embd*n_outputs_max : 0; @@ -12703,6 +13295,7 @@ static int llama_decode_internal( llama_context & lctx, llama_batch batch_all) { // TODO: rename back to batch + lctx.is_encoding = false; const uint32_t n_tokens_all = batch_all.n_tokens; if (n_tokens_all == 0) { @@ -12998,6 +13591,191 @@ static int llama_decode_internal( return 0; } +// encode a batch of tokens by evaluating the encoder part of the transformer +// +// - lctx: llama context +// - batch: batch to evaluate +// +// return 0 on success +// return positive int on warning +// return negative int on error +// +static int llama_encode_internal( + llama_context & lctx, + llama_batch batch_all) { // TODO: rename back to batch + + lctx.is_encoding = true; + const uint32_t n_tokens_all = batch_all.n_tokens; + + if (n_tokens_all == 0) { + LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__); + return -1; + } + + const auto & model = lctx.model; + const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; + + GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT + + GGML_ASSERT(n_tokens_all <= cparams.n_batch); + + GGML_ASSERT(cparams.n_ubatch >= n_tokens_all && "encoder requires n_ubatch >= n_tokens"); + + if (lctx.t_compute_start_us == 0) { + lctx.t_compute_start_us = ggml_time_us(); + } + lctx.n_queued_tokens += n_tokens_all; + + const int64_t n_embd = hparams.n_embd; + + uint32_t n_outputs = 0; + uint32_t n_outputs_prev = 0; + + const auto n_ubatch = cparams.n_ubatch; + + std::vector pos; + std::vector n_seq_id; + std::vector seq_id_arr; + std::vector> seq_id; + + n_outputs = n_tokens_all; + + // reserve output buffer + if (llama_output_reserve(lctx, n_outputs) < n_outputs) { + LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs); + return -2; + }; + + for (uint32_t i = 0; i < n_outputs; ++i) { + lctx.output_ids[i] = i; + } + + lctx.inp_enc_output = NULL; + + for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) { + const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token); + llama_batch u_batch = { + /* .n_tokens = */ (int32_t) n_tokens, + /* .token = */ batch_all.token ? batch_all.token + cur_token : nullptr, + /* .embd = */ batch_all.embd ? batch_all.embd + cur_token*n_embd : nullptr, + /* .pos = */ batch_all.pos ? batch_all.pos + cur_token : nullptr, + /* .n_seq_id = */ batch_all.n_seq_id ? batch_all.n_seq_id + cur_token : nullptr, + /* .seq_id = */ batch_all.seq_id ? batch_all.seq_id + cur_token : nullptr, + /* .logits = */ batch_all.logits ? batch_all.logits + cur_token : nullptr, + /* .all_pos_0 = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1, + /* .all_pos_1 = */ batch_all.all_pos_1, + /* .all_seq_id = */ batch_all.all_seq_id, + }; + + // count the outputs in this u_batch + { + int32_t n_outputs_new = 0; + + n_outputs_new = n_tokens; + + // needs to happen before the graph is built + lctx.n_outputs = n_outputs_new; + } + + int n_threads = n_tokens == 1 ? 
cparams.n_threads : cparams.n_threads_batch; + GGML_ASSERT(n_threads > 0); + + // helpers for smoother batch API transition + // after deprecating the llama_eval calls, these will be removed + if (u_batch.pos == nullptr) { + pos.resize(n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + pos[i] = u_batch.all_pos_0 + i*u_batch.all_pos_1; + } + + u_batch.pos = pos.data(); + } + + if (u_batch.seq_id == nullptr) { + n_seq_id.resize(n_tokens); + seq_id.resize(n_tokens); + seq_id_arr.resize(n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + n_seq_id[i] = 1; + seq_id[i].resize(1); + seq_id[i][0] = u_batch.all_seq_id; + seq_id_arr[i] = seq_id[i].data(); + } + + u_batch.n_seq_id = n_seq_id.data(); + u_batch.seq_id = seq_id_arr.data(); + } + + ggml_backend_sched_reset(lctx.sched); + ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); + + ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false); + + // the output is always the last tensor in the graph + struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2]; + + // token or sequence embeddings + embd = gf->nodes[gf->n_nodes - 1]; + + GGML_ASSERT(strcmp(embd->name, "result_norm") == 0); + + // for big prompts, if BLAS is enabled, it is better to use only one thread + // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance + // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well + // we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering + // with the BLAS calls. need a better solution + // MoE Special Case: This logic applies when hparams.n_expert == 0, i.e. the model is NOT an MoE model. When an MoE is + // being processed then Accelerate/BLAS will not be involved, so capping would limit performance. + if (n_tokens >= 32 && hparams.n_expert == 0 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) { + n_threads = std::min(4, n_threads); + } + + ggml_backend_sched_alloc_graph(lctx.sched, gf); + + llama_set_inputs(lctx, u_batch); + + llama_graph_compute(lctx, gf, n_threads); + + // extract embeddings + if (embd) { + ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd); + GGML_ASSERT(backend_embd != nullptr); + + // extract token embeddings + GGML_ASSERT(lctx.embd != nullptr); + const int32_t n_outputs_new = lctx.n_outputs; + lctx.encoder_output.resize((n_outputs_prev + n_outputs_new)*n_embd); + float * embd_out = lctx.encoder_output.data() + n_outputs_prev*n_embd; + + if (n_outputs_new) { + GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs); + GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_embd <= (int64_t) lctx.embd_size); + ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float)); + } + + // extract output embeddings mask + lctx.encoder_output_seq_ids.resize(n_outputs_prev + n_outputs_new); + for (int i = 0; i < n_outputs_new; i++) { + for (int s = 0; s < u_batch.n_seq_id[i]; s++) { + llama_seq_id seq_id = u_batch.seq_id[i][s]; + lctx.encoder_output_seq_ids[i].insert(seq_id); + } + } + } + n_outputs_prev += lctx.n_outputs; + } + + // set to total number of outputs in the batch, for use in llama_get_logits_ith + lctx.n_outputs = n_outputs; + + // Reset state for the next token before backend sync, to allow the CPU activities in the reset to + // overlap with device computation. 
+ ggml_backend_sched_reset(lctx.sched); + + return 0; +} + // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { @@ -17646,6 +18424,17 @@ struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const ch return it->second; } +bool llama_model_has_encoder(const struct llama_model * model) { + switch (model->arch) { + case LLM_ARCH_T5: return true; + default: return false; + } +} + +llama_token llama_model_decoder_start_token(const struct llama_model * model) { + return model->hparams.dec_start_token_id; +} + uint32_t llama_model_quantize( const char * fname_inp, const char * fname_out, @@ -18992,6 +19781,17 @@ void llama_batch_free(struct llama_batch batch) { if (batch.logits) free(batch.logits); } +int32_t llama_encode( + struct llama_context * ctx, + struct llama_batch batch) { + const int ret = llama_encode_internal(*ctx, batch); + if (ret < 0) { + LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); + } + + return ret; +} + int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch) { diff --git a/llama.h b/llama.h index 88eecb0ed..8e606b366 100644 --- a/llama.h +++ b/llama.h @@ -483,6 +483,13 @@ extern "C" { // Get a llama model tensor LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name); + // Returns true if the model contains an encoder that requires a llama_encode() call + LLAMA_API bool llama_model_has_encoder(const struct llama_model * model); + + // For encoder-decoder models, this function returns the id of the token that must be provided + // to the decoder to start generating the output sequence. For other models, it returns -1. + LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model); + // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( const char * fname_inp, @@ -768,6 +775,14 @@ extern "C" { // Frees a batch of tokens allocated with llama_batch_init() LLAMA_API void llama_batch_free(struct llama_batch batch); + // Processes a batch of tokens with the encoder part of the encoder-decoder model. + // Stores the encoder output internally for later use by the decoder cross-attention layers. + // 0 - success + // < 0 - error + LLAMA_API int32_t llama_encode( + struct llama_context * ctx, + struct llama_batch batch); + // Positive return values does not mean a fatal error, but rather a warning. // 0 - success // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
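
For readers who want to try the new API end to end, below is a minimal sketch (not part of the patch) of how llama_model_has_encoder(), llama_encode(), llama_model_decoder_start_token() and llama_decode() are meant to be combined for a T5-style model, mirroring the flow added to examples/main/main.cpp above. The helper name t5_generate_sketch, the hand-rolled greedy argmax, and the omission of tokenization, sampling parameters and error reporting are all simplifications for illustration, not part of the actual change.

#include "llama.h"

#include <vector>

// Sketch of the encoder-decoder flow enabled by this patch (illustrative only):
// encode the prompt once, seed the decoder with the decoder start token, then
// decode auto-regressively; cross-attention reads the encoder output that
// llama_encode() stored inside the context.
static void t5_generate_sketch(llama_model * model, llama_context * ctx,
                               std::vector<llama_token> enc_input, int max_new_tokens) {
    if (llama_model_has_encoder(model)) {
        // run the encoder part over the whole tokenized prompt
        if (llama_encode(ctx, llama_batch_get_one(enc_input.data(), (int32_t) enc_input.size(), 0, 0)) != 0) {
            return; // encoding failed
        }
    }

    // models without a decoder start token id fall back to BOS
    llama_token dec_start = llama_model_decoder_start_token(model);
    if (dec_start == -1) {
        dec_start = llama_token_bos(model);
    }

    std::vector<llama_token> dec_tokens = { dec_start };
    llama_pos n_past = 0;

    for (int i = 0; i < max_new_tokens; ++i) {
        // decode one token at a time, as in examples/main
        if (llama_decode(ctx, llama_batch_get_one(dec_tokens.data(), (int32_t) dec_tokens.size(), n_past, 0)) != 0) {
            return; // decoding failed
        }
        n_past += (llama_pos) dec_tokens.size();

        // greedy sampling by hand, just to keep the sketch short
        const float * logits  = llama_get_logits(ctx);
        const int     n_vocab = llama_n_vocab(model);

        llama_token next = 0;
        for (llama_token t = 1; t < n_vocab; ++t) {
            if (logits[t] > logits[next]) {
                next = t;
            }
        }

        if (next == llama_token_eos(model)) {
            break;
        }
        dec_tokens = { next };
    }
}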