diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 5e343742d..971e0fcb2 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -205,6 +205,8 @@ class Model: return OrionModel if model_architecture == "InternLM2ForCausalLM": return InternLM2Model + if model_architecture == "BertModel": + return BertModel return Model def _is_model_safetensors(self) -> bool: @@ -258,6 +260,8 @@ class Model: return gguf.MODEL_ARCH.ORION if arch == "InternLM2ForCausalLM": return gguf.MODEL_ARCH.INTERNLM2 + if arch == "BertModel": + return gguf.MODEL_ARCH.BERT raise NotImplementedError(f'Architecture "{arch}" not supported!') @@ -1521,6 +1525,55 @@ in chat mode so that the conversation can end normally.") self.post_write_tensors(tensor_map, name, data_torch) +class BertModel(Model): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.block_count = self.hparams["num_hidden_layers"] + + def set_gguf_parameters(self): + # TODO(cebtenzzre): merge with parent class + self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"]) + self.gguf_writer.add_file_type(self.ftype) + + def write_tensors(self): + tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + tensors = dict(self.get_tensors()) + for name, data_torch in tensors.items(): + # we are only using BERT for embeddings so we don't need the pooling layer + if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"): + continue # we don't need these + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + data = data_torch.squeeze().numpy() + n_dims = len(data.shape) + new_dtype: type[np.floating[Any]] + + if self.ftype == 1 and name.endswith(".weight") and n_dims == 2: + # if f16 desired, convert any float32 2-dim weight tensors to float16 + new_dtype = np.float16 + else: + # if f32 desired, convert any float16 to float32 + new_dtype = np.float32 + + print(f"{new_name}, n_dims = {n_dims}, {data_torch.dtype} --> {new_dtype}") + + if data.dtype != new_dtype: + data = data.astype(new_dtype) + + self.gguf_writer.add_tensor(new_name, data) + + ###### CONVERSION LOGIC ###### diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index ed8e26f83..4b42e488c 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -60,22 +60,23 @@ class Keys: SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" class Tokenizer: - MODEL = "tokenizer.ggml.model" - LIST = "tokenizer.ggml.tokens" - TOKEN_TYPE = "tokenizer.ggml.token_type" - SCORES = "tokenizer.ggml.scores" - MERGES = "tokenizer.ggml.merges" - BOS_ID = "tokenizer.ggml.bos_token_id" - EOS_ID = "tokenizer.ggml.eos_token_id" - UNK_ID = "tokenizer.ggml.unknown_token_id" - SEP_ID = "tokenizer.ggml.seperator_token_id" - PAD_ID = "tokenizer.ggml.padding_token_id" - ADD_BOS = "tokenizer.ggml.add_bos_token" - ADD_EOS = "tokenizer.ggml.add_eos_token" - ADD_PREFIX = "tokenizer.ggml.add_space_prefix" - HF_JSON = "tokenizer.huggingface.json" - RWKV = "tokenizer.rwkv.world" - CHAT_TEMPLATE = "tokenizer.chat_template" + MODEL = "tokenizer.ggml.model" + LIST = "tokenizer.ggml.tokens" + TOKEN_TYPE = "tokenizer.ggml.token_type" + TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count" # for BERT-style token types + SCORES = "tokenizer.ggml.scores" + MERGES = "tokenizer.ggml.merges" + BOS_ID = "tokenizer.ggml.bos_token_id" + EOS_ID = "tokenizer.ggml.eos_token_id" + UNK_ID = "tokenizer.ggml.unknown_token_id" + SEP_ID = "tokenizer.ggml.seperator_token_id" + PAD_ID = "tokenizer.ggml.padding_token_id" + ADD_BOS = "tokenizer.ggml.add_bos_token" + ADD_EOS = "tokenizer.ggml.add_eos_token" + ADD_PREFIX = "tokenizer.ggml.add_space_prefix" + HF_JSON = "tokenizer.huggingface.json" + RWKV = "tokenizer.rwkv.world" + CHAT_TEMPLATE = "tokenizer.chat_template" # @@ -121,6 +122,7 @@ class MODEL_TENSOR(IntEnum): ATTN_OUT = auto() ATTN_NORM = auto() ATTN_NORM_2 = auto() + ATTN_OUT_NORM = auto() ATTN_ROT_EMBD = auto() FFN_GATE_INP = auto() FFN_NORM = auto() @@ -133,6 +135,7 @@ class MODEL_TENSOR(IntEnum): FFN_UP_EXP = auto() ATTN_Q_NORM = auto() ATTN_K_NORM = auto() + LAYER_OUT_NORM = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -176,6 +179,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", + MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm", MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp", MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", @@ -185,6 +189,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate.{xid}", MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down.{xid}", MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up.{xid}", + MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -260,17 +265,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { ], MODEL_ARCH.BERT: [ MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, MODEL_TENSOR.TOKEN_TYPES, MODEL_TENSOR.POS_EMBD, MODEL_TENSOR.OUTPUT_NORM, - MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_OUT_NORM, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, MODEL_TENSOR.ATTN_OUT, - MODEL_TENSOR.FFN_NORM, MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.LAYER_OUT_NORM, ], MODEL_ARCH.MPT: [ MODEL_TENSOR.TOKEN_EMBD, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 16808196e..20d2e5cff 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -387,6 +387,9 @@ class GGUFWriter: def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None: self.add_array(Keys.Tokenizer.TOKEN_TYPE, types) + def add_token_type_count(self, value: int) -> None: + self.add_uint32(Keys.Tokenizer.TOKEN_TYPE_COUNT, value) + def add_token_scores(self, scores: Sequence[float]) -> None: self.add_array(Keys.Tokenizer.SCORES, scores) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 4f16d8504..c7ba1420e 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -30,6 +30,7 @@ class TensorNameMap: # Normalization of token embeddings MODEL_TENSOR.TOKEN_EMBD_NORM: ( "word_embeddings_layernorm", # bloom + "embeddings.LayerNorm", # bert ), # Position embeddings @@ -54,7 +55,6 @@ class TensorNameMap: "transformer.ln_f", # gpt2 gpt-j falcon "model.norm", # llama-hf baichuan internlm2 "norm", # llama-pth - "embeddings.LayerNorm", # bert "transformer.norm_f", # mpt "ln_f", # refact bloom qwen gpt2 "language_model.encoder.final_layernorm", # persimmon @@ -79,7 +79,6 @@ class TensorNameMap: "transformer.h.{bid}.ln_mlp", # falcon40b "model.layers.{bid}.input_layernorm", # llama-hf "layers.{bid}.attention_norm", # llama-pth - "encoder.layer.{bid}.attention.output.LayerNorm", # bert "language_model.encoder.layers.{bid}.input_layernorm", # persimmon "model.layers.{bid}.ln1", # yi "h.{bid}.ln_1", # gpt2 @@ -155,6 +154,11 @@ class TensorNameMap: "model.layers.{bid}.attention.wo", # internlm2 ), + # Attention output norm + MODEL_TENSOR.ATTN_OUT_NORM: ( + "encoder.layer.{bid}.attention.output.LayerNorm", # bert + ), + # Rotary embeddings MODEL_TENSOR.ATTN_ROT_EMBD: ( "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf @@ -171,7 +175,6 @@ class TensorNameMap: "transformer.blocks.{bid}.norm_2", # mpt "model.layers.{bid}.post_attention_layernorm", # llama-hf "layers.{bid}.ffn_norm", # llama-pth - "encoder.layer.{bid}.output.LayerNorm", # bert "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon "model.layers.{bid}.ln2", # yi "h.{bid}.ln_2", # gpt2 @@ -266,6 +269,10 @@ class TensorNameMap: MODEL_TENSOR.ROPE_FREQS: ( "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon ), + + MODEL_TENSOR.LAYER_OUT_NORM: ( + "encoder.layer.{bid}.output.LayerNorm", # bert + ) } mapping: dict[str, tuple[MODEL_TENSOR, str]] diff --git a/llama.cpp b/llama.cpp index 65e399adc..575b60f56 100644 --- a/llama.cpp +++ b/llama.cpp @@ -196,6 +196,7 @@ enum llm_arch { LLM_ARCH_STARCODER, LLM_ARCH_PERSIMMON, LLM_ARCH_REFACT, + LLM_ARCH_BERT, LLM_ARCH_BLOOM, LLM_ARCH_STABLELM, LLM_ARCH_QWEN, @@ -219,6 +220,7 @@ static std::map LLM_ARCH_NAMES = { { LLM_ARCH_STARCODER, "starcoder" }, { LLM_ARCH_PERSIMMON, "persimmon" }, { LLM_ARCH_REFACT, "refact" }, + { LLM_ARCH_BERT, "bert" }, { LLM_ARCH_BLOOM, "bloom" }, { LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_QWEN, "qwen" }, @@ -271,6 +273,7 @@ enum llm_kv { LLM_KV_TOKENIZER_MODEL, LLM_KV_TOKENIZER_LIST, LLM_KV_TOKENIZER_TOKEN_TYPE, + LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, LLM_KV_TOKENIZER_SCORES, LLM_KV_TOKENIZER_MERGES, LLM_KV_TOKENIZER_BOS_ID, @@ -326,6 +329,7 @@ static std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, @@ -534,6 +538,24 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_BERT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_POS_EMBD, "position_embd" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_BLOOM, { @@ -1388,6 +1410,9 @@ static llama_state g_state; // available llama models enum e_model { MODEL_UNKNOWN, + MODEL_22M, + MODEL_33M, + MODEL_109M, MODEL_0_5B, MODEL_1B, MODEL_3B, @@ -1428,6 +1453,7 @@ struct llama_hparams { uint32_t n_ff; uint32_t n_expert = 0; uint32_t n_expert_used = 0; + uint32_t n_vocab_type = 0; // for BERT-style token types float f_norm_eps; float f_norm_rms_eps; @@ -1532,6 +1558,8 @@ struct llama_layer { struct ggml_tensor * bqkv; // normalization + struct ggml_tensor * attn_out_norm; + struct ggml_tensor * attn_out_norm_b; struct ggml_tensor * ffn_norm; struct ggml_tensor * ffn_norm_b; @@ -1550,6 +1578,10 @@ struct llama_layer { struct ggml_tensor * ffn_down_b; // b2 struct ggml_tensor * ffn_up_b; // b3 struct ggml_tensor * ffn_act; + + // normalization + struct ggml_tensor * layer_out_norm; + struct ggml_tensor * layer_out_norm_b; }; struct llama_kv_cell { @@ -1667,6 +1699,7 @@ struct llama_model { llama_vocab vocab; struct ggml_tensor * tok_embd; + struct ggml_tensor * type_embd; struct ggml_tensor * pos_embd; struct ggml_tensor * tok_norm; struct ggml_tensor * tok_norm_b; @@ -1791,8 +1824,10 @@ struct llama_context { struct ggml_tensor * inp_tokens; // I32 [n_batch] struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] struct ggml_tensor * inp_pos; // I32 [n_batch] + struct ggml_tensor * inp_type; // I32 [n_batch] struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch] struct ggml_tensor * inp_K_shift; // I32 [n_ctx] + struct ggml_tensor * inp_sum; // F32 [1, n_batch] #ifdef GGML_USE_MPI ggml_mpi_context * ctx_mpi = NULL; @@ -2933,6 +2968,21 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_BERT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_TOKEN_TYPE_COUNT, hparams.n_vocab_type); + + switch (hparams.n_embd) { + case 384: // MiniLM + switch (hparams.n_layer) { + case 6: model.type = e_model::MODEL_22M; break; + case 12: model.type = e_model::MODEL_33M; break; + } break; + case 768: // BERT-Base + model.type = e_model::MODEL_109M; break; + } + } break; case LLM_ARCH_BLOOM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -3719,6 +3769,45 @@ static bool llm_load_tensors( layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {64}); } } break; + case LLM_ARCH_BERT: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_vocab_type, n_embd}); + model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}); + model.tok_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); + model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); + layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}); + + layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); + layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); + + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + } + } break; case LLM_ARCH_BLOOM: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -5561,6 +5650,113 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_bert() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0); + + inpL = ggml_add(ctx0, + ggml_get_rows(ctx0, model.type_embd, lctx.inp_type), + inpL); + inpL = ggml_add(ctx0, + ggml_get_rows(ctx0, model.pos_embd, lctx.inp_pos), + inpL); + + inpL = llm_build_norm(ctx0, inpL, hparams, + model.tok_norm, + model.tok_norm_b, + LLM_NORM, cb, -1); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * cur = inpL; + + // self-attention + { + // compute Q and K + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head, n_tokens); + struct ggml_tensor * K = ggml_permute(ctx0, Kcur, 0, 2, 1, 3); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head, n_tokens); + struct ggml_tensor * V = ggml_permute(ctx0, Vcur, 0, 2, 1, 3); + + struct ggml_tensor *KQ = ggml_mul_mat(ctx0, K, Q); + // KQ = soft_max(KQ / sqrt(head width)) + KQ = ggml_soft_max(ctx0, + ggml_scale(ctx0, KQ, 1.0f / sqrt((float)n_embd_head))); + + V = ggml_cont(ctx0, ggml_transpose(ctx0, V)); + struct ggml_tensor *KQV = ggml_mul_mat(ctx0, V, KQ); + KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV), n_embd, N); + + // attention output + cur = ggml_add(ctx0, + model.layers[il].bo, + ggml_mul_mat(ctx0, model.layers[il].wo, cur)); + } + + // re-add the layer input + cur = ggml_add(ctx0, cur, inpSA); + + // attention layer norm + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].attn_out_norm, + model.layers[il].attn_out_norm_b, + LLM_NORM, cb, il); + + struct ggml_tensor * ffn_inp = cur; + + // feed-forward network + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + + // attentions bypass the intermediate layer + cur = ggml_add(ctx0, cur, ffn_inp); + + // output layer norm + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].layer_out_norm, + model.layers[il].layer_out_norm_b, + LLM_NORM, cb, il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + // pooling (sum = [L, 1, B]) + cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), lctx.inp_sum); // [E, 1, B] + + ggml_build_forward_expand(gf, cur); + + return gf; + } + struct ggml_cgraph * build_bloom() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -6835,6 +7031,12 @@ static struct ggml_cgraph * llama_build_graph( ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); } + { + // for embedding models, token type is always zero ("sentence A") + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_type->buffer)); + memset(lctx.inp_type->data, 0, batch.n_tokens * ggml_element_size(lctx.inp_type)); + } + { const int64_t n_kv = llm.n_kv; const int64_t n_tokens = batch.n_tokens; @@ -6870,6 +7072,15 @@ static struct ggml_cgraph * llama_build_graph( data[i] = lctx.kv_self.cells[i].delta; } } + + { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_inp_sum->buffer)); + float * data = (float *) lctx.inp_sum->data; + + for (int i = 0; i < batch.n_tokens; ++i) { + data[i] = 1.0f/float(batch.n_tokens); + } + } } llm.init(); @@ -6899,6 +7110,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_refact(); } break; + case LLM_ARCH_BERT: + { + result = llm.build_bert(); + } break; case LLM_ARCH_BLOOM: { result = llm.build_bloom(); @@ -10574,7 +10789,7 @@ struct llama_context * llama_new_context_with_model( // graph inputs { ggml_init_params init_params = { - /* .mem_size */ ggml_tensor_overhead()*5, + /* .mem_size */ ggml_tensor_overhead()*7, /* .mem_buffer */ nullptr, /* .no_alloc */ true, }; @@ -10583,14 +10798,18 @@ struct llama_context * llama_new_context_with_model( ctx->inp_tokens = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); ctx->inp_embd = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch); ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); + ctx->inp_type = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch); ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx); + ctx->inp_sum = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, 1, cparams.n_batch); ggml_set_name(ctx->inp_tokens, "inp_tokens"); ggml_set_name(ctx->inp_embd, "inp_embd"); ggml_set_name(ctx->inp_pos, "inp_pos"); + ggml_set_name(ctx->inp_type, "inp_type"); ggml_set_name(ctx->inp_KQ_mask, "inp_KQ_mask"); ggml_set_name(ctx->inp_K_shift, "inp_K_shift"); + ggml_set_name(ctx->inp_sum, "inp_sum"); ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true));