Add support for encoder-only T5 models (#8900)

* gguf-py : add T5ENCODER model architecture

* common : call llama_decode() during warmup only if the model has decoder

* convert-hf : add T5EncoderModel

* llama : add llama_model_has_decoder() API function

* llama : split build_t5() into build_t5_encoder() and build_t5_decoder()

* llama : add support for LLM_ARCH_T5ENCODER

* llama-embedding : add support for LLAMA_POOLING_TYPE_NONE

* llama-embedding : add support for encoder-only models

---------

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
This commit is contained in:
fairydreaming 2024-08-10 11:43:26 +02:00 committed by GitHub
parent 911b437f22
commit 7c3f55c100
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 702 additions and 335 deletions

View File

@ -2156,7 +2156,9 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
tmp.clear(); tmp.clear();
tmp.push_back(decoder_start_token_id); tmp.push_back(decoder_start_token_id);
} }
if (llama_model_has_decoder(model)) {
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
}
llama_kv_cache_clear(lctx); llama_kv_cache_clear(lctx);
llama_synchronize(lctx); llama_synchronize(lctx);
llama_reset_timings(lctx); llama_reset_timings(lctx);

View File

@ -3324,6 +3324,145 @@ class T5Model(Model):
return [(self.map_tensor_name(name), data_torch)] return [(self.map_tensor_name(name), data_torch)]
@Model.register("T5EncoderModel")
class T5EncoderModel(Model):
model_arch = gguf.MODEL_ARCH.T5ENCODER
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.shared_token_embeddings_found = False
def set_vocab(self):
# to avoid TypeError: Descriptors cannot be created directly
# exception when importing sentencepiece_model_pb2
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
from sentencepiece import SentencePieceProcessor
from sentencepiece import sentencepiece_model_pb2 as model
tokenizer_path = self.dir_model / 'tokenizer.model'
# many older models use spiece.model tokenizer model filename
if not tokenizer_path.is_file():
tokenizer_path = self.dir_model / 'spiece.model'
if not tokenizer_path.is_file():
raise FileNotFoundError(f"File not found: {tokenizer_path}")
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
# some models like Pile-T5 family use BPE tokenizer instead of Unigram
if sentencepiece_model.trainer_spec.model_type == 2: # BPE
# assure the tokenizer model file name is correct
assert tokenizer_path.name == 'tokenizer.model'
return self._set_vocab_sentencepiece()
else:
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
tokenizer = SentencePieceProcessor()
tokenizer.LoadFromFile(str(tokenizer_path))
vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
scores: list[float] = [-10000.0] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
for token_id in range(tokenizer.vocab_size()):
piece = tokenizer.IdToPiece(token_id)
text = piece.encode("utf-8")
score = tokenizer.GetScore(token_id)
toktype = SentencePieceTokenTypes.NORMAL
if tokenizer.IsUnknown(token_id):
toktype = SentencePieceTokenTypes.UNKNOWN
elif tokenizer.IsControl(token_id):
toktype = SentencePieceTokenTypes.CONTROL
elif tokenizer.IsUnused(token_id):
toktype = SentencePieceTokenTypes.UNUSED
elif tokenizer.IsByte(token_id):
toktype = SentencePieceTokenTypes.BYTE
tokens[token_id] = text
scores[token_id] = score
toktypes[token_id] = toktype
added_tokens_file = self.dir_model / 'added_tokens.json'
if added_tokens_file.is_file():
with open(added_tokens_file, "r", encoding="utf-8") as f:
added_tokens_json = json.load(f)
for key in added_tokens_json:
token_id = added_tokens_json[key]
if token_id >= vocab_size:
logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
continue
tokens[token_id] = key.encode("utf-8")
scores[token_id] = -1000.0
toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED
if vocab_size > len(tokens):
pad_count = vocab_size - len(tokens)
logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
for i in range(1, pad_count + 1):
tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
scores.append(-1000.0)
toktypes.append(SentencePieceTokenTypes.UNUSED)
self.gguf_writer.add_tokenizer_model("t5")
self.gguf_writer.add_tokenizer_pre("default")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
self.gguf_writer.add_add_space_prefix(add_prefix)
self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces)
if precompiled_charsmap:
self.gguf_writer.add_precompiled_charsmap(precompiled_charsmap)
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)
self.gguf_writer.add_add_bos_token(False)
self.gguf_writer.add_add_eos_token(True)
def set_gguf_parameters(self):
if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
logger.warning("Couldn't find context length in config.json, assuming default value of 512")
n_ctx = 512
self.gguf_writer.add_context_length(n_ctx)
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
self.gguf_writer.add_block_count(self.hparams["num_layers"])
self.gguf_writer.add_head_count(self.hparams["num_heads"])
self.gguf_writer.add_key_length(self.hparams["d_kv"])
self.gguf_writer.add_value_length(self.hparams["d_kv"])
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
self.gguf_writer.add_relative_attn_buckets_count(self.hparams["relative_attention_num_buckets"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
self.gguf_writer.add_file_type(self.ftype)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
# T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight",
# "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored
# in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder
# and decoder and ignore the remaining ones.
if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]:
if not self.shared_token_embeddings_found:
name = "shared.weight"
self.shared_token_embeddings_found = True
else:
logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.")
return []
return [(self.map_tensor_name(name), data_torch)]
@Model.register("JAISLMHeadModel") @Model.register("JAISLMHeadModel")
class JaisModel(Model): class JaisModel(Model):
model_arch = gguf.MODEL_ARCH.JAIS model_arch = gguf.MODEL_ARCH.JAIS

View File

@ -31,25 +31,47 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
} }
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) { static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
const struct llama_model * model = llama_get_model(ctx);
// clear previous kv_cache values (irrelevant for embeddings) // clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx); llama_kv_cache_clear(ctx);
// run model // run model
fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
// encoder-only model
if (llama_encode(ctx, batch) < 0) {
fprintf(stderr, "%s : failed to encode\n", __func__);
}
} else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
// decoder-only model
if (llama_decode(ctx, batch) < 0) { if (llama_decode(ctx, batch) < 0) {
fprintf(stderr, "%s : failed to decode\n", __func__); fprintf(stderr, "%s : failed to decode\n", __func__);
} }
}
for (int i = 0; i < batch.n_tokens; i++) { for (int i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) { if (!batch.logits[i]) {
continue; continue;
} }
// try to get sequence embeddings - supported only when pooling_type is not NONE const float * embd = nullptr;
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); int embd_pos = 0;
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
float * out = output + batch.seq_id[i][0] * n_embd; if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
// try to get token embeddings
embd = llama_get_embeddings_ith(ctx, i);
embd_pos = i;
GGML_ASSERT(embd != NULL && "failed to get token embeddings");
} else {
// try to get sequence embeddings - supported only when pooling_type is not NONE
embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
embd_pos = batch.seq_id[i][0];
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
}
float * out = output + embd_pos * n_embd;
llama_embd_normalize(embd, out, n_embd, embd_norm); llama_embd_normalize(embd, out, n_embd, embd_norm);
} }
} }
@ -93,8 +115,9 @@ int main(int argc, char ** argv) {
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__); if (llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
fprintf(stderr, "%s: error: computing embeddings in encoder-decoder models is not supported\n", __func__);
return 1; return 1;
} }
@ -153,13 +176,23 @@ int main(int argc, char ** argv) {
const int n_prompts = prompts.size(); const int n_prompts = prompts.size();
struct llama_batch batch = llama_batch_init(n_batch, 0, 1); struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
// count number of embeddings
int n_embd_count = 0;
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
for (int k = 0; k < n_prompts; k++) {
n_embd_count += inputs[k].size();
}
} else {
n_embd_count = n_prompts;
}
// allocate output // allocate output
const int n_embd = llama_n_embd(model); const int n_embd = llama_n_embd(model);
std::vector<float> embeddings(n_prompts * n_embd, 0); std::vector<float> embeddings(n_embd_count * n_embd, 0);
float * emb = embeddings.data(); float * emb = embeddings.data();
// break into batches // break into batches
int p = 0; // number of prompts processed already int e = 0; // number of embeddings already stored
int s = 0; // number of prompts in current batch int s = 0; // number of prompts in current batch
for (int k = 0; k < n_prompts; k++) { for (int k = 0; k < n_prompts; k++) {
// clamp to n_batch tokens // clamp to n_batch tokens
@ -169,11 +202,11 @@ int main(int argc, char ** argv) {
// encode if at capacity // encode if at capacity
if (batch.n_tokens + n_toks > n_batch) { if (batch.n_tokens + n_toks > n_batch) {
float * out = emb + p * n_embd; float * out = emb + e * n_embd;
batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
llama_batch_clear(batch); e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
p += s;
s = 0; s = 0;
llama_batch_clear(batch);
} }
// add to batch // add to batch
@ -182,12 +215,34 @@ int main(int argc, char ** argv) {
} }
// final batch // final batch
float * out = emb + p * n_embd; float * out = emb + e * n_embd;
batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
if (params.embd_out.empty()) { if (params.embd_out.empty()) {
// print the first part of the embeddings or for a single prompt, the full embedding
fprintf(stdout, "\n"); fprintf(stdout, "\n");
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
for (int j = 0; j < n_embd_count; j++) {
fprintf(stdout, "embedding %d: ", j);
for (int i = 0; i < std::min(3, n_embd); i++) {
if (params.embd_normalize == 0) {
fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
} else {
fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
}
}
fprintf(stdout, " ... ");
for (int i = n_embd - 3; i < n_embd; i++) {
if (params.embd_normalize == 0) {
fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
} else {
fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
}
}
fprintf(stdout, "\n");
}
} else {
// print the first part of the embeddings or for a single prompt, the full embedding
for (int j = 0; j < n_prompts; j++) { for (int j = 0; j < n_prompts; j++) {
fprintf(stdout, "embedding %d: ", j); fprintf(stdout, "embedding %d: ", j);
for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) { for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
@ -218,6 +273,7 @@ int main(int argc, char ** argv) {
} }
} }
} }
}
if (params.embd_out == "json" || params.embd_out == "json+" || params.embd_out == "array") { if (params.embd_out == "json" || params.embd_out == "json+" || params.embd_out == "array") {
const bool notArray = params.embd_out != "array"; const bool notArray = params.embd_out != "array";
@ -233,23 +289,23 @@ int main(int argc, char ** argv) {
} }
fprintf(stdout, notArray ? "]\n }" : "]"); fprintf(stdout, notArray ? "]\n }" : "]");
j++; j++;
if (j < n_prompts) fprintf(stdout, notArray ? ",\n" : ","); else break; if (j < n_embd_count) fprintf(stdout, notArray ? ",\n" : ","); else break;
} }
fprintf(stdout, notArray ? "\n ]" : "]\n"); fprintf(stdout, notArray ? "\n ]" : "]\n");
if (params.embd_out == "json+" && n_prompts > 1) { if (params.embd_out == "json+" && n_prompts > 1) {
fprintf(stdout, ",\n \"cosineSimilarity\": [\n"); fprintf(stdout, ",\n \"cosineSimilarity\": [\n");
for (int i = 0;;) { // at least two iteration (n_prompts > 1) for (int i = 0;;) { // at least two iteration (n_embd_count > 1)
fprintf(stdout, " ["); fprintf(stdout, " [");
for (int j = 0;;) { // at least two iteration (n_prompts > 1) for (int j = 0;;) { // at least two iteration (n_embd_count > 1)
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
fprintf(stdout, "%6.2f", sim); fprintf(stdout, "%6.2f", sim);
j++; j++;
if (j < n_prompts) fprintf(stdout, ", "); else break; if (j < n_embd_count) fprintf(stdout, ", "); else break;
} }
fprintf(stdout, " ]"); fprintf(stdout, " ]");
i++; i++;
if (i < n_prompts) fprintf(stdout, ",\n"); else break; if (i < n_embd_count) fprintf(stdout, ",\n"); else break;
} }
fprintf(stdout, "\n ]"); fprintf(stdout, "\n ]");
} }

View File

@ -217,6 +217,7 @@ class MODEL_ARCH(IntEnum):
CHATGLM = auto() CHATGLM = auto()
BITNET = auto() BITNET = auto()
T5 = auto() T5 = auto()
T5ENCODER = auto()
JAIS = auto() JAIS = auto()
@ -344,6 +345,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.CHATGLM: "chatglm", MODEL_ARCH.CHATGLM: "chatglm",
MODEL_ARCH.BITNET: "bitnet", MODEL_ARCH.BITNET: "bitnet",
MODEL_ARCH.T5: "t5", MODEL_ARCH.T5: "t5",
MODEL_ARCH.T5ENCODER: "t5encoder",
MODEL_ARCH.JAIS: "jais", MODEL_ARCH.JAIS: "jais",
} }
@ -1036,6 +1038,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ENC_FFN_UP, MODEL_TENSOR.ENC_FFN_UP,
MODEL_TENSOR.ENC_OUTPUT_NORM, MODEL_TENSOR.ENC_OUTPUT_NORM,
], ],
MODEL_ARCH.T5ENCODER: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ENC_ATTN_NORM,
MODEL_TENSOR.ENC_ATTN_Q,
MODEL_TENSOR.ENC_ATTN_K,
MODEL_TENSOR.ENC_ATTN_V,
MODEL_TENSOR.ENC_ATTN_OUT,
MODEL_TENSOR.ENC_ATTN_REL_B,
MODEL_TENSOR.ENC_FFN_NORM,
MODEL_TENSOR.ENC_FFN_GATE,
MODEL_TENSOR.ENC_FFN_DOWN,
MODEL_TENSOR.ENC_FFN_UP,
MODEL_TENSOR.ENC_OUTPUT_NORM,
],
MODEL_ARCH.JAIS: [ MODEL_ARCH.JAIS: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT_NORM,

View File

@ -504,6 +504,9 @@ extern "C" {
// Returns true if the model contains an encoder that requires llama_encode() call // Returns true if the model contains an encoder that requires llama_encode() call
LLAMA_API bool llama_model_has_encoder(const struct llama_model * model); LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
// Returns true if the model contains a decoder that requires llama_decode() call
LLAMA_API bool llama_model_has_decoder(const struct llama_model * model);
// For encoder-decoder models, this function returns id of the token that must be provided // For encoder-decoder models, this function returns id of the token that must be provided
// to the decoder to start generating output sequence. For other models, it returns -1. // to the decoder to start generating output sequence. For other models, it returns -1.
LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model); LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);

View File

@ -208,6 +208,7 @@ enum llm_arch {
LLM_ARCH_CHATGLM, LLM_ARCH_CHATGLM,
LLM_ARCH_BITNET, LLM_ARCH_BITNET,
LLM_ARCH_T5, LLM_ARCH_T5,
LLM_ARCH_T5ENCODER,
LLM_ARCH_JAIS, LLM_ARCH_JAIS,
LLM_ARCH_UNKNOWN, LLM_ARCH_UNKNOWN,
}; };
@ -252,6 +253,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_CHATGLM, "chatglm" },
{ LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_BITNET, "bitnet" },
{ LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5, "t5" },
{ LLM_ARCH_T5ENCODER, "t5encoder" },
{ LLM_ARCH_JAIS, "jais" }, { LLM_ARCH_JAIS, "jais" },
{ LLM_ARCH_UNKNOWN, "(unknown)" }, { LLM_ARCH_UNKNOWN, "(unknown)" },
}; };
@ -1261,6 +1263,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" }, { LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" },
}, },
}, },
{
LLM_ARCH_T5ENCODER,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
{ LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" },
{ LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" },
{ LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" },
{ LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" },
{ LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" },
{ LLM_TENSOR_ENC_ATTN_REL_B, "enc.blk.%d.attn_rel_b" },
{ LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" },
{ LLM_TENSOR_ENC_FFN_GATE, "enc.blk.%d.ffn_gate" },
{ LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" },
{ LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" },
},
},
{ {
LLM_ARCH_JAIS, LLM_ARCH_JAIS,
{ {
@ -5187,6 +5207,12 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN; default: model.type = e_model::MODEL_UNKNOWN;
} }
} break; } break;
case LLM_ARCH_T5ENCODER:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts);
model.type = e_model::MODEL_UNKNOWN;
} break;
case LLM_ARCH_JAIS: case LLM_ARCH_JAIS:
{ {
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@ -7421,6 +7447,42 @@ static bool llm_load_tensors(
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_UP, "weight", i), {n_embd, n_ff}); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_UP, "weight", i), {n_embd, n_ff});
} }
} break; } break;
case LLM_ARCH_T5ENCODER:
{
const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
// output
{
model.output_norm_enc = ml.create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd});
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (model.output == NULL) {
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
}
}
for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
auto & layer = model.layers[i];
layer.attn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd});
layer.attn_rel_b_enc = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.wq_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa});
layer.wk_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
layer.wv_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
layer.wo_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});
layer.ffn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd});
layer.ffn_gate_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_down_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd});
layer.ffn_up_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff});
}
} break;
case LLM_ARCH_JAIS: case LLM_ARCH_JAIS:
{ {
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@ -13135,7 +13197,7 @@ struct llm_build_context {
return gf; return gf;
} }
struct ggml_cgraph * build_t5() { struct ggml_cgraph * build_t5_encoder() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
// mutable variable, needed during the last layer of the computation to skip unused tokens // mutable variable, needed during the last layer of the computation to skip unused tokens
@ -13150,7 +13212,7 @@ struct llm_build_context {
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
if (lctx.is_encoding) { GGML_ASSERT(lctx.is_encoding);
struct ggml_tensor * pos_bucket_enc = llm_build_pos_bucket(false); struct ggml_tensor * pos_bucket_enc = llm_build_pos_bucket(false);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads) // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
@ -13261,7 +13323,28 @@ struct llm_build_context {
model.output_norm_enc, NULL, model.output_norm_enc, NULL,
LLM_NORM_RMS, cb, -1); LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1); cb(cur, "result_norm", -1);
} else {
ggml_build_forward_expand(gf, cur);
return gf;
}
struct ggml_cgraph * build_t5_decoder() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
// mutable variable, needed during the last layer of the computation to skip unused tokens
int32_t n_tokens = this->n_tokens;
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
GGML_ASSERT(!lctx.is_encoding);
GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first"); GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first");
struct ggml_tensor * embd_enc = llm_build_inp_embd_enc(); struct ggml_tensor * embd_enc = llm_build_inp_embd_enc();
@ -13445,7 +13528,6 @@ struct llm_build_context {
// lm_head // lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output", -1); cb(cur, "result_output", -1);
}
ggml_build_forward_expand(gf, cur); ggml_build_forward_expand(gf, cur);
@ -13898,7 +13980,15 @@ static struct ggml_cgraph * llama_build_graph(
} break; } break;
case LLM_ARCH_T5: case LLM_ARCH_T5:
{ {
result = llm.build_t5(); if (lctx.is_encoding) {
result = llm.build_t5_encoder();
} else {
result = llm.build_t5_decoder();
}
} break;
case LLM_ARCH_T5ENCODER:
{
result = llm.build_t5_encoder();
} break; } break;
case LLM_ARCH_JAIS: case LLM_ARCH_JAIS:
{ {
@ -14346,7 +14436,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
// TODO: use a per-batch flag for logits presence instead // TODO: use a per-batch flag for logits presence instead
const bool has_logits = !cparams.embeddings; const bool has_logits = !cparams.embeddings;
const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE)); const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0; const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0; const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
@ -14829,9 +14919,24 @@ static int llama_encode_internal(
ggml_cgraph * gf = llama_build_graph(lctx, batch, false); ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
// the output embeddings after the final encoder normalization // the output embeddings after the final encoder normalization
struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 1]; struct ggml_tensor * embd = nullptr;
GGML_ASSERT(strcmp(embd->name, "result_norm") == 0); // there are two cases here
if (llama_model_has_decoder(&lctx.model)) {
// first case is an encoder-decoder T5 model where embeddings are passed to decoder
embd = gf->nodes[gf->n_nodes - 1];
GGML_ASSERT(strcmp(embd->name, "result_norm") == 0 && "missing result_output tensor");
} else {
// second case is an encoder-only T5 model
if (cparams.embeddings) {
// only output embeddings if required
embd = gf->nodes[gf->n_nodes - 1];
if (strcmp(embd->name, "result_embd_pooled") != 0) {
embd = gf->nodes[gf->n_nodes - 2];
}
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
}
}
ggml_backend_sched_alloc_graph(lctx.sched, gf); ggml_backend_sched_alloc_graph(lctx.sched, gf);
@ -14844,9 +14949,7 @@ static int llama_encode_internal(
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd); ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
GGML_ASSERT(backend_embd != nullptr); GGML_ASSERT(backend_embd != nullptr);
// extract token embeddings if (llama_model_has_decoder(&lctx.model)) {
GGML_ASSERT(lctx.embd != nullptr);
lctx.embd_enc.resize(n_tokens*n_embd); lctx.embd_enc.resize(n_tokens*n_embd);
float * embd_out = lctx.embd_enc.data(); float * embd_out = lctx.embd_enc.data();
@ -14860,6 +14963,42 @@ static int llama_encode_internal(
lctx.seq_ids_enc[i].insert(seq_id); lctx.seq_ids_enc[i].insert(seq_id);
} }
} }
} else {
GGML_ASSERT(lctx.embd != nullptr);
switch (cparams.pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
// extract token embeddings
GGML_ASSERT(lctx.embd != nullptr);
float * embd_out = lctx.embd;
GGML_ASSERT(n_tokens*n_embd <= (int64_t) lctx.embd_size);
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
} break;
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
{
// extract sequence embeddings
auto & embd_seq_out = lctx.embd_seq;
embd_seq_out.clear();
for (uint32_t i = 0; i < n_tokens; i++) {
const llama_seq_id seq_id = batch.seq_id[i][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
embd_seq_out[seq_id].resize(n_embd);
ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
{
GGML_ABORT("unknown pooling type");
}
}
}
} }
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
@ -16567,6 +16706,8 @@ struct llama_context * llama_new_context_with_model(
ctx->sampling.rng = std::mt19937(params.seed); ctx->sampling.rng = std::mt19937(params.seed);
ctx->logits_all = params.logits_all; ctx->logits_all = params.logits_all;
// build worst-case graph for encoder if a model contains encoder
ctx->is_encoding = llama_model_has_encoder(model);
uint32_t kv_size = cparams.n_ctx; uint32_t kv_size = cparams.n_ctx;
ggml_type type_k = params.type_k; ggml_type type_k = params.type_k;
@ -16881,6 +17022,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_MAMBA: case LLM_ARCH_MAMBA:
case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_JINA_BERT_V2:
case LLM_ARCH_T5: case LLM_ARCH_T5:
case LLM_ARCH_T5ENCODER:
case LLM_ARCH_JAIS: case LLM_ARCH_JAIS:
return LLAMA_ROPE_TYPE_NONE; return LLAMA_ROPE_TYPE_NONE;
@ -17029,10 +17171,18 @@ struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const ch
bool llama_model_has_encoder(const struct llama_model * model) { bool llama_model_has_encoder(const struct llama_model * model) {
switch (model->arch) { switch (model->arch) {
case LLM_ARCH_T5: return true; case LLM_ARCH_T5: return true;
case LLM_ARCH_T5ENCODER: return true;
default: return false; default: return false;
} }
} }
bool llama_model_has_decoder(const struct llama_model * model) {
switch (model->arch) {
case LLM_ARCH_T5ENCODER: return false;
default: return true;
}
}
llama_token llama_model_decoder_start_token(const struct llama_model * model) { llama_token llama_model_decoder_start_token(const struct llama_model * model) {
return model->hparams.dec_start_token_id; return model->hparams.dec_start_token_id;
} }