mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 06:10:29 +01:00
llama : add inference support and model types for T5 and FLAN-T5 model families
llama : add new API functions to support encoder-decoder models: llama_encode(), llama_model_has_encoder(), llama_model_decoder_start_token() common, llama-cli : use new API functions to support encoder-decoder models convert-hf : handle shared token embeddings tensors in T5Model convert-hf : handle SentencePiece BPE tokenizer in T5Model (for Pile-T5 models) convert-hf : add MT5ForConditionalGeneration and UMT5ForConditionalGeneration to architectures supported by T5Model
This commit is contained in:
parent
6fcbf68235
commit
45681a57dd
@ -2061,7 +2061,24 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
|
||||
if (params.warmup) {
|
||||
LOG("warming up the model with an empty run\n");
|
||||
|
||||
std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
|
||||
std::vector<llama_token> tmp;
|
||||
llama_token bos = llama_token_bos(model);
|
||||
llama_token eos = llama_token_eos(model);
|
||||
// some models (e.g. T5) don't have a BOS token
|
||||
if (bos != -1) {
|
||||
tmp.push_back(bos);
|
||||
}
|
||||
tmp.push_back(eos);
|
||||
|
||||
if (llama_model_has_encoder(model)) {
|
||||
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
|
||||
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
|
||||
if (decoder_start_token_id == -1) {
|
||||
decoder_start_token_id = bos;
|
||||
}
|
||||
tmp.clear();
|
||||
tmp.push_back(decoder_start_token_id);
|
||||
}
|
||||
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
|
||||
llama_kv_cache_clear(lctx);
|
||||
llama_synchronize(lctx);
|
||||
|
@ -2775,11 +2775,17 @@ class DeepseekV2Model(Model):
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@Model.register("T5ForConditionalGeneration")
|
||||
@Model.register("T5WithLMHeadModel")
|
||||
@Model.register("T5ForConditionalGeneration")
|
||||
@Model.register("MT5ForConditionalGeneration")
|
||||
@Model.register("UMT5ForConditionalGeneration")
|
||||
class T5Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.T5
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.shared_token_embeddings_found = False
|
||||
|
||||
def set_vocab(self):
|
||||
# to avoid TypeError: Descriptors cannot be created directly
|
||||
# exception when importing sentencepiece_model_pb2
|
||||
@ -2787,17 +2793,29 @@ class T5Model(Model):
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
from sentencepiece import sentencepiece_model_pb2 as model
|
||||
|
||||
tokenizer_path = self.dir_model / 'spiece.model'
|
||||
tokenizer_path = self.dir_model / 'tokenizer.model'
|
||||
|
||||
# many older models use spiece.model tokenizer model filename
|
||||
if not tokenizer_path.is_file():
|
||||
tokenizer_path = self.dir_model / 'spiece.model'
|
||||
|
||||
if not tokenizer_path.is_file():
|
||||
raise FileNotFoundError(f"File not found: {tokenizer_path}")
|
||||
|
||||
sentencepiece_model = model.ModelProto()
|
||||
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
|
||||
|
||||
# some models like Pile-T5 family use BPE tokenizer instead of Unigram
|
||||
if sentencepiece_model.trainer_spec.model_type == 2: # BPE
|
||||
# assure the tokenizer model file name is correct
|
||||
assert tokenizer_path.name == 'tokenizer.model'
|
||||
return self._set_vocab_sentencepiece()
|
||||
else:
|
||||
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
|
||||
|
||||
add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
|
||||
remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
|
||||
precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
|
||||
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
|
||||
|
||||
tokenizer = SentencePieceProcessor()
|
||||
tokenizer.LoadFromFile(str(tokenizer_path))
|
||||
@ -2867,7 +2885,10 @@ class T5Model(Model):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_name("T5")
|
||||
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||
if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
|
||||
logger.warning("Couldn't find context length in config.json, assuming default value of 512")
|
||||
n_ctx = 512
|
||||
self.gguf_writer.add_context_length(n_ctx)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
|
||||
self.gguf_writer.add_block_count(self.hparams["num_layers"])
|
||||
@ -2883,12 +2904,17 @@ class T5Model(Model):
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
# Sometimes T5 and Flan-T5 based models contain "encoder.embed_tokens.weight" tensor or
|
||||
# "decoder.embed_tokens.weight" tensors that are duplicates of "shared.weight" tensor
|
||||
# To prevent errors caused by an unnecessary unmapped tensor, skip both of them and use only "shared.weight".
|
||||
if name == "decoder.embed_tokens.weight" or name == "encoder.embed_tokens.weight":
|
||||
logger.debug(f"Skipping tensor {name!r} in safetensors so that convert can end normally.")
|
||||
return []
|
||||
# T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight",
|
||||
# "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored
|
||||
# in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder
|
||||
# and decoder and ignore the remaining ones.
|
||||
if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]:
|
||||
if not self.shared_token_embeddings_found:
|
||||
name = "shared.weight"
|
||||
self.shared_token_embeddings_found = True
|
||||
else:
|
||||
logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.")
|
||||
return []
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
@ -255,7 +255,9 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
const bool add_bos = llama_should_add_bos_token(model);
|
||||
GGML_ASSERT(llama_add_eos_token(model) != 1);
|
||||
if (!llama_model_has_encoder(model)) {
|
||||
GGML_ASSERT(llama_add_eos_token(model) != 1);
|
||||
}
|
||||
LOG("add_bos: %d\n", add_bos);
|
||||
|
||||
std::vector<llama_token> embd_inp;
|
||||
@ -517,6 +519,23 @@ int main(int argc, char ** argv) {
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if (llama_model_has_encoder(model)) {
|
||||
int enc_input_size = embd_inp.size();
|
||||
llama_token * enc_input_buf = embd_inp.data();
|
||||
|
||||
if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) {
|
||||
LOG_TEE("%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
|
||||
if (decoder_start_token_id == -1) {
|
||||
decoder_start_token_id = llama_token_bos(model);
|
||||
}
|
||||
embd_inp.clear();
|
||||
embd_inp.push_back(decoder_start_token_id);
|
||||
}
|
||||
|
||||
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
|
||||
// predict
|
||||
if (!embd.empty()) {
|
||||
|
15
llama.h
15
llama.h
@ -483,6 +483,13 @@ extern "C" {
|
||||
// Get a llama model tensor
|
||||
LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
|
||||
|
||||
// Returns true if the model contains an encoder that requires llama_encode() call
|
||||
LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
|
||||
|
||||
// For encoder-decoder models, this function returns id of the token that must be provided
|
||||
// to the decoder to start generating output sequence. For other models, it returns -1.
|
||||
LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
|
||||
|
||||
// Returns 0 on success
|
||||
LLAMA_API uint32_t llama_model_quantize(
|
||||
const char * fname_inp,
|
||||
@ -768,6 +775,14 @@ extern "C" {
|
||||
// Frees a batch of tokens allocated with llama_batch_init()
|
||||
LLAMA_API void llama_batch_free(struct llama_batch batch);
|
||||
|
||||
// Processes a batch of tokens with the ecoder part of the encoder-decoder model.
|
||||
// Stores the encoder output internally for later use by the decoder cross-attention layers.
|
||||
// 0 - success
|
||||
// < 0 - error
|
||||
LLAMA_API int32_t llama_encode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch);
|
||||
|
||||
// Positive return values does not mean a fatal error, but rather a warning.
|
||||
// 0 - success
|
||||
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
||||
|
Loading…
Reference in New Issue
Block a user