From 9b82476ee9e73065a759f8bcc4cf27ec7ab2ed8c Mon Sep 17 00:00:00 2001 From: fairydreaming <166155368+fairydreaming@users.noreply.github.com> Date: Thu, 23 May 2024 11:49:53 +0200 Subject: [PATCH] Add missing inference support for GPTNeoXForCausalLM (Pythia and GPT-NeoX base models) (#7461) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * convert-hf : add conversion of bloom-style qkv tensor to gpt-style qkv (code borrowed from BloomModel) * llama : add inference support for LLM_ARCH_GPTNEOX * llama : add model types for every Pythia variant and GPT-NeoX Co-authored-by: Stanisław Szymczyk --- convert-hf-to-gguf.py | 38 +++++++ llama.cpp | 236 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 273 insertions(+), 1 deletion(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index daad1c4fc..5a00a5e89 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -673,6 +673,44 @@ class GPTNeoXModel(Model): self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True)) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"]) + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) + n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) + + tensors: list[tuple[str, Tensor]] = [] + + if re.match(r"gpt_neox\.layers\.\d+\.attention\.query_key_value\.weight", name): + # Map bloom-style qkv_linear to gpt-style qkv_linear + # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252 # noqa + # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312 # noqa + qkv_weights = data_torch.reshape((n_head, 3, n_embed // n_head, n_embed)) + data_torch = torch.cat( + ( + qkv_weights[:, 0, :, :].reshape((-1, n_embed)), + qkv_weights[:, 1, :, :].reshape((-1, n_embed)), + qkv_weights[:, 2, :, :].reshape((-1, n_embed)), + ), + dim=0, + ) + logger.info("re-format attention.linear_qkv.weight") + elif re.match(r"gpt_neox\.layers\.\d+\.attention\.query_key_value\.bias", name): + qkv_bias = data_torch.reshape((n_head, 3, n_embed // n_head)) + data_torch = torch.cat( + ( + qkv_bias[:, 0, :].reshape((n_embed,)), + qkv_bias[:, 1, :].reshape((n_embed,)), + qkv_bias[:, 2, :].reshape((n_embed,)), + ), + dim=0, + ) + logger.info("re-format attention.linear_qkv.bias") + + tensors.append((self.map_tensor_name(name), data_torch)) + + return tensors + @Model.register("BloomForCausalLM") class BloomModel(Model): diff --git a/llama.cpp b/llama.cpp index 3e09a2390..5ff186a57 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1692,17 +1692,24 @@ static llama_state g_state; // available llama models enum e_model { MODEL_UNKNOWN, + MODEL_14M, MODEL_17M, MODEL_22M, MODEL_33M, + MODEL_70M, MODEL_109M, MODEL_137M, + MODEL_160M, MODEL_335M, + MODEL_410M, MODEL_0_5B, MODEL_1B, + MODEL_1_4B, MODEL_2B, + MODEL_2_8B, MODEL_3B, MODEL_4B, + MODEL_6_9B, MODEL_7B, MODEL_8B, MODEL_12B, @@ -1734,6 +1741,7 @@ static const size_t GiB = 1024*MiB; struct llama_hparams { bool vocab_only; bool rope_finetuned; + bool use_par_res; uint32_t n_vocab; uint32_t n_ctx_train; // context size the model was trained on @@ -3773,17 +3781,24 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { static const char * llama_model_type_name(e_model type) { switch (type) { + case MODEL_14M: return "14M"; case MODEL_17M: return "17M"; case MODEL_22M: return "22M"; case MODEL_33M: return "33M"; + case MODEL_70M: return "70M"; case MODEL_109M: return "109M"; case MODEL_137M: return "137M"; + case MODEL_160M: return "160M"; case MODEL_335M: return "335M"; + case MODEL_410M: return "410M"; case MODEL_0_5B: return "0.5B"; case MODEL_1B: return "1B"; + case MODEL_1_4B: return "1.4B"; case MODEL_2B: return "2B"; + case MODEL_2_8B: return "2.8B"; case MODEL_3B: return "3B"; case MODEL_4B: return "4B"; + case MODEL_6_9B: return "6.9B"; case MODEL_7B: return "7B"; case MODEL_8B: return "8B"; case MODEL_12B: return "12B"; @@ -4282,6 +4297,52 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_GPTNEOX: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); + switch (hparams.n_layer) { + case 6: + switch (hparams.n_ff) { + case 512: model.type = e_model::MODEL_14M; break; + case 2048: model.type = e_model::MODEL_70M; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 12: + switch (hparams.n_ff) { + case 3072: model.type = e_model::MODEL_160M; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 16: + switch (hparams.n_ff) { + case 8192: model.type = e_model::MODEL_1B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 24: + switch (hparams.n_ff) { + case 4096: model.type = e_model::MODEL_410M; break; + case 8192: model.type = e_model::MODEL_1_4B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 32: + switch (hparams.n_ff) { + case 10240: model.type = e_model::MODEL_2_8B; break; + case 16384: model.type = e_model::MODEL_6_9B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 36: + switch (hparams.n_ff) { + case 20480: model.type = e_model::MODEL_12B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 44: + switch (hparams.n_ff) { + case 24576: model.type = e_model::MODEL_20B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -6033,6 +6094,41 @@ static bool llm_load_tensors( layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; + case LLM_ARCH_GPTNEOX: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + + layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); + + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -10560,6 +10656,140 @@ struct llm_build_context { return gf; } + + struct ggml_cgraph * build_gptneox() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // ffn + if (hparams.use_par_res) { + // attention and ffn are computed in parallel + // x = x + attn(ln1(x)) + ffn(ln2(x)) + + struct ggml_tensor * attn_out = cur; + + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, inpL); + cb(cur, "ffn_out", il); + + inpL = ggml_add(ctx0, cur, attn_out); + cb(inpL, "l_out", il); + } else { + // attention and ffn are computed sequentially + // x = x + attn(ln1(x)) + // x = x + ffn(ln2(x)) + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + } + + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, + model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -10770,6 +11000,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_olmo(); } break; + case LLM_ARCH_GPTNEOX: + { + result = llm.build_gptneox(); + } break; default: GGML_ASSERT(false); } @@ -15762,7 +15996,6 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { // these models do not use RoPE case LLM_ARCH_GPT2: case LLM_ARCH_GPTJ: - case LLM_ARCH_GPTNEOX: case LLM_ARCH_MPT: case LLM_ARCH_REFACT: case LLM_ARCH_BLOOM: @@ -15798,6 +16031,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_PHI3: case LLM_ARCH_GEMMA: case LLM_ARCH_STARCODER2: + case LLM_ARCH_GPTNEOX: return LLAMA_ROPE_TYPE_NEOX; // all model arches should be listed explicitly here