llama : add inference support and model types for T5 and FLAN-T5 model families

llama : add new API functions to support encoder-decoder models: llama_encode(), llama_model_has_encoder(), llama_model_decoder_start_token()

common, llama-cli : use new API functions to support encoder-decoder models

convert-hf : handle shared token embeddings tensors in T5Model

convert-hf : handle SentencePiece BPE tokenizer in T5Model (for Pile-T5 models)

convert-hf : add MT5ForConditionalGeneration and UMT5ForConditionalGeneration to architectures supported by T5Model
This commit is contained in:
Stanisław Szymczyk 2024-06-26 15:03:01 +02:00
parent 6fcbf68235
commit 45681a57dd
5 changed files with 892 additions and 15 deletions

View File

@ -2061,7 +2061,24 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
if (params.warmup) { if (params.warmup) {
LOG("warming up the model with an empty run\n"); LOG("warming up the model with an empty run\n");
std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), }; std::vector<llama_token> tmp;
llama_token bos = llama_token_bos(model);
llama_token eos = llama_token_eos(model);
// some models (e.g. T5) don't have a BOS token
if (bos != -1) {
tmp.push_back(bos);
}
tmp.push_back(eos);
if (llama_model_has_encoder(model)) {
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == -1) {
decoder_start_token_id = bos;
}
tmp.clear();
tmp.push_back(decoder_start_token_id);
}
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
llama_kv_cache_clear(lctx); llama_kv_cache_clear(lctx);
llama_synchronize(lctx); llama_synchronize(lctx);

View File

@ -2775,11 +2775,17 @@ class DeepseekV2Model(Model):
raise ValueError(f"Unprocessed experts: {experts}") raise ValueError(f"Unprocessed experts: {experts}")
@Model.register("T5ForConditionalGeneration")
@Model.register("T5WithLMHeadModel") @Model.register("T5WithLMHeadModel")
@Model.register("T5ForConditionalGeneration")
@Model.register("MT5ForConditionalGeneration")
@Model.register("UMT5ForConditionalGeneration")
class T5Model(Model): class T5Model(Model):
model_arch = gguf.MODEL_ARCH.T5 model_arch = gguf.MODEL_ARCH.T5
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.shared_token_embeddings_found = False
def set_vocab(self): def set_vocab(self):
# to avoid TypeError: Descriptors cannot be created directly # to avoid TypeError: Descriptors cannot be created directly
# exception when importing sentencepiece_model_pb2 # exception when importing sentencepiece_model_pb2
@ -2787,6 +2793,10 @@ class T5Model(Model):
from sentencepiece import SentencePieceProcessor from sentencepiece import SentencePieceProcessor
from sentencepiece import sentencepiece_model_pb2 as model from sentencepiece import sentencepiece_model_pb2 as model
tokenizer_path = self.dir_model / 'tokenizer.model'
# many older models use spiece.model tokenizer model filename
if not tokenizer_path.is_file():
tokenizer_path = self.dir_model / 'spiece.model' tokenizer_path = self.dir_model / 'spiece.model'
if not tokenizer_path.is_file(): if not tokenizer_path.is_file():
@ -2794,10 +2804,18 @@ class T5Model(Model):
sentencepiece_model = model.ModelProto() sentencepiece_model = model.ModelProto()
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
# some models like Pile-T5 family use BPE tokenizer instead of Unigram
if sentencepiece_model.trainer_spec.model_type == 2: # BPE
# assure the tokenizer model file name is correct
assert tokenizer_path.name == 'tokenizer.model'
return self._set_vocab_sentencepiece()
else:
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
tokenizer = SentencePieceProcessor() tokenizer = SentencePieceProcessor()
tokenizer.LoadFromFile(str(tokenizer_path)) tokenizer.LoadFromFile(str(tokenizer_path))
@ -2867,7 +2885,10 @@ class T5Model(Model):
def set_gguf_parameters(self): def set_gguf_parameters(self):
self.gguf_writer.add_name("T5") self.gguf_writer.add_name("T5")
self.gguf_writer.add_context_length(self.hparams["n_positions"]) if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
logger.warning("Couldn't find context length in config.json, assuming default value of 512")
n_ctx = 512
self.gguf_writer.add_context_length(n_ctx)
self.gguf_writer.add_embedding_length(self.hparams["d_model"]) self.gguf_writer.add_embedding_length(self.hparams["d_model"])
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"]) self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
self.gguf_writer.add_block_count(self.hparams["num_layers"]) self.gguf_writer.add_block_count(self.hparams["num_layers"])
@ -2883,11 +2904,16 @@ class T5Model(Model):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused del bid # unused
# Sometimes T5 and Flan-T5 based models contain "encoder.embed_tokens.weight" tensor or # T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight",
# "decoder.embed_tokens.weight" tensors that are duplicates of "shared.weight" tensor # "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored
# To prevent errors caused by an unnecessary unmapped tensor, skip both of them and use only "shared.weight". # in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder
if name == "decoder.embed_tokens.weight" or name == "encoder.embed_tokens.weight": # and decoder and ignore the remaining ones.
logger.debug(f"Skipping tensor {name!r} in safetensors so that convert can end normally.") if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]:
if not self.shared_token_embeddings_found:
name = "shared.weight"
self.shared_token_embeddings_found = True
else:
logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.")
return [] return []
return [(self.map_tensor_name(name), data_torch)] return [(self.map_tensor_name(name), data_torch)]

View File

@ -255,7 +255,9 @@ int main(int argc, char ** argv) {
} }
const bool add_bos = llama_should_add_bos_token(model); const bool add_bos = llama_should_add_bos_token(model);
if (!llama_model_has_encoder(model)) {
GGML_ASSERT(llama_add_eos_token(model) != 1); GGML_ASSERT(llama_add_eos_token(model) != 1);
}
LOG("add_bos: %d\n", add_bos); LOG("add_bos: %d\n", add_bos);
std::vector<llama_token> embd_inp; std::vector<llama_token> embd_inp;
@ -517,6 +519,23 @@ int main(int argc, char ** argv) {
exit(1); exit(1);
} }
if (llama_model_has_encoder(model)) {
int enc_input_size = embd_inp.size();
llama_token * enc_input_buf = embd_inp.data();
if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
return 1;
}
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == -1) {
decoder_start_token_id = llama_token_bos(model);
}
embd_inp.clear();
embd_inp.push_back(decoder_start_token_id);
}
while ((n_remain != 0 && !is_antiprompt) || params.interactive) { while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict // predict
if (!embd.empty()) { if (!embd.empty()) {

806
llama.cpp

File diff suppressed because it is too large Load Diff

15
llama.h
View File

@ -483,6 +483,13 @@ extern "C" {
// Get a llama model tensor // Get a llama model tensor
LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name); LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
// Returns true if the model contains an encoder that requires llama_encode() call
LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
// For encoder-decoder models, this function returns id of the token that must be provided
// to the decoder to start generating output sequence. For other models, it returns -1.
LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
// Returns 0 on success // Returns 0 on success
LLAMA_API uint32_t llama_model_quantize( LLAMA_API uint32_t llama_model_quantize(
const char * fname_inp, const char * fname_inp,
@ -768,6 +775,14 @@ extern "C" {
// Frees a batch of tokens allocated with llama_batch_init() // Frees a batch of tokens allocated with llama_batch_init()
LLAMA_API void llama_batch_free(struct llama_batch batch); LLAMA_API void llama_batch_free(struct llama_batch batch);
// Processes a batch of tokens with the ecoder part of the encoder-decoder model.
// Stores the encoder output internally for later use by the decoder cross-attention layers.
// 0 - success
// < 0 - error
LLAMA_API int32_t llama_encode(
struct llama_context * ctx,
struct llama_batch batch);
// Positive return values does not mean a fatal error, but rather a warning. // Positive return values does not mean a fatal error, but rather a warning.
// 0 - success // 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)