From 45aab64e93d11aa08ba1c722fc0c65c1ff458364 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 9 Jan 2025 16:44:49 +0200 Subject: [PATCH] hparams : move vocab params to llama_vocab (#11159) ggml-ci --- src/llama-context.cpp | 9 +++++---- src/llama-hparams.h | 2 -- src/llama-model.cpp | 21 +++++++++------------ src/llama-vocab.cpp | 8 ++++++++ src/llama-vocab.h | 2 +- src/llama.cpp | 8 ++------ 6 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 4b195eaca..e20482516 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -469,11 +469,12 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) { size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) { const auto & cparams = lctx.cparams; const auto & hparams = lctx.model.hparams; + const auto & vocab = lctx.model.vocab; const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max); const auto n_batch = cparams.n_batch; - const auto n_vocab = hparams.n_vocab; + const auto n_vocab = vocab.n_vocab(); const auto n_embd = hparams.n_embd; // TODO: use a per-batch flag for logits presence instead @@ -540,7 +541,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) { void llama_output_reorder(struct llama_context & ctx) { std::vector & out_ids = ctx.sbatch.out_ids; if (!out_ids.empty()) { - const uint32_t n_vocab = ctx.model.hparams.n_vocab; + const uint32_t n_vocab = ctx.model.vocab.n_vocab(); const uint32_t n_embd = ctx.model.hparams.n_embd; const int32_t n_outputs = ctx.n_outputs; @@ -724,7 +725,7 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs)); } - return ctx->logits + j*ctx->model.hparams.n_vocab; + return ctx->logits + j*ctx->model.vocab.n_vocab(); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG @@ -884,7 +885,7 @@ struct llama_data_write { } void write_logits(const struct llama_context * ctx) { - const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_vocab); + const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.vocab.n_vocab()); write(&logits_size, sizeof(logits_size)); diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 3542bef49..1fe454103 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -30,7 +30,6 @@ struct llama_hparams { bool use_par_res; bool swin_norm; - uint32_t n_vocab = 0; uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; uint32_t n_embd_features = 0; @@ -41,7 +40,6 @@ struct llama_hparams { uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head uint32_t n_expert = 0; uint32_t n_expert_used = 0; - uint32_t n_vocab_type = 0; // for BERT-style token types uint32_t n_rel_attn_bkts = 0; // for WavTokenizer diff --git a/src/llama-model.cpp b/src/llama-model.cpp index fd3c96606..b81ed9437 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -402,9 +402,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { // get general kv ml.get_key(LLM_KV_GENERAL_NAME, name, false); - // get hparams kv - ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab, false); - // everything past this point is not vocab-related if (hparams.vocab_only) { return; @@ -500,6 +497,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.n_embd_head_v = 0; } + // for differentiating model types + uint32_t n_vocab = 0; + ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); + // arch-specific KVs switch (arch) { case LLM_ARCH_LLAMA: @@ -519,7 +520,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { case 26: type = LLM_TYPE_3B; break; case 28: type = LLM_TYPE_3B; break; // Llama 3.2 3B // granite uses a vocab with len 49152 - case 32: type = hparams.n_vocab == 49152 ? LLM_TYPE_3B : (hparams.n_vocab < 40000 ? LLM_TYPE_7B : LLM_TYPE_8B); break; + case 32: type = n_vocab == 49152 ? LLM_TYPE_3B : (n_vocab < 40000 ? LLM_TYPE_7B : LLM_TYPE_8B); break; case 36: type = LLM_TYPE_8B; break; // granite case 40: type = LLM_TYPE_13B; break; case 48: type = LLM_TYPE_34B; break; @@ -621,7 +622,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); switch (hparams.n_layer) { @@ -644,7 +644,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); hparams.f_max_alibi_bias = 8.0f; @@ -658,7 +657,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); if (hparams.n_layer == 12 && hparams.n_embd == 768) { @@ -1369,8 +1367,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_embd_head_v = hparams.n_embd_head_v; const int64_t n_ff = hparams.n_ff(); const int64_t n_embd_gqa = n_embd_v_gqa; - const int64_t n_vocab = hparams.n_vocab; - const int64_t n_vocab_type = hparams.n_vocab_type; + const int64_t n_vocab = vocab.n_vocab(); + const int64_t n_token_types = vocab.n_token_types(); const int64_t n_rot = hparams.n_rot; const int64_t n_expert = hparams.n_expert; const int64_t n_expert_used = hparams.n_expert_used; @@ -1815,7 +1813,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_NOMIC_BERT: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}, 0); + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); if (arch == LLM_ARCH_BERT) { pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); @@ -1869,7 +1867,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_JINA_BERT_V2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings - type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}, 0); // token_type_embeddings + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); // token_type_embeddings tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); // LayerNorm tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); //LayerNorm bias @@ -3553,7 +3551,6 @@ void llama_model::print_info() const { // hparams LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str()); - LLAMA_LOG_INFO("%s: n_vocab (hp) = %u\n", __func__, hparams.n_vocab); LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only); if (!hparams.vocab_only) { diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 2240ff8f6..c5524a116 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1205,6 +1205,7 @@ struct fragment_buffer_variant { struct llama_vocab::impl { uint32_t n_vocab = 0; + uint32_t n_token_types = 0; // for BERT-style token types std::unordered_map token_to_id; std::vector id_to_token; @@ -1286,6 +1287,7 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) { struct gguf_context * ctx = ml.meta.get(); auto & n_vocab = pimpl->n_vocab; + auto & n_token_types = pimpl->n_token_types; auto & id_to_token = pimpl->id_to_token; auto & token_to_id = pimpl->token_to_id; auto & special_eog_ids = pimpl->special_eog_ids; @@ -1300,6 +1302,8 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) { ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model); ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false); + ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, n_token_types, false); + if (tokenizer_model == "no_vocab" || tokenizer_model == "none") { type = LLAMA_VOCAB_TYPE_NONE; @@ -2013,6 +2017,10 @@ uint32_t llama_vocab::n_vocab() const { return (uint32_t) pimpl->id_to_token.size(); } +uint32_t llama_vocab::n_token_types() const { + return (uint32_t) pimpl->n_token_types; +} + std::string llama_vocab::type_name() const{ switch (type) { case LLAMA_VOCAB_TYPE_NONE: return "no vocab"; diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 84bd7c440..710464f21 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -24,8 +24,8 @@ struct llama_vocab { enum llama_vocab_type get_type() const; enum llama_vocab_pre_type get_pre_type() const; - // TODO: how to deduplicate with llama_hparams.n_vocab ? uint32_t n_vocab() const; + uint32_t n_token_types() const; std::string type_name() const; diff --git a/src/llama.cpp b/src/llama.cpp index 56d15e76d..76506abc1 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -65,11 +65,6 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam model.load_stats(ml); model.print_info(); - if (model.vocab.get_type() != LLAMA_VOCAB_TYPE_NONE && - model.hparams.n_vocab != model.vocab.n_vocab()) { - throw std::runtime_error("vocab size mismatch"); - } - if (params.vocab_only) { LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__); return 0; @@ -8467,6 +8462,7 @@ static int llama_decode_impl( const uint32_t n_tokens_all = batch.n_tokens; const auto & model = lctx.model; + const auto & vocab = model.vocab; const auto & hparams = model.hparams; const auto & cparams = lctx.cparams; @@ -8494,7 +8490,7 @@ static int llama_decode_impl( llama_kv_slot_restorer kv_slot_restorer(kv_self); const int64_t n_embd = hparams.n_embd; - const int64_t n_vocab = hparams.n_vocab; + const int64_t n_vocab = vocab.n_vocab(); uint32_t n_outputs = 0; uint32_t n_outputs_prev = 0;