gptneox-main.cpp : tensor name map changes

This commit is contained in:
klosax 2023-08-14 10:59:18 +02:00 committed by GitHub
parent 806a15749d
commit d753dfbcc8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -370,17 +370,19 @@ bool gpt_neox_model_load(const std::string & fname, gpt_neox_model & model, gpt2
int keyidx; int keyidx;
keyidx = gguf_find_key(ggufctx, "general.name"); keyidx = gguf_find_key(ggufctx, "general.name");
if (keyidx != -1) { fprintf(stdout, "%s: model name = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } if (keyidx != -1) { fprintf(stdout, "%s: model name = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); }
keyidx = gguf_find_key(ggufctx, "general.description"); keyidx = gguf_find_key(ggufctx, "general.description");
if (keyidx != -1) { fprintf(stdout, "%s: model description = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } if (keyidx != -1) { fprintf(stdout, "%s: model description = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); }
keyidx = gguf_find_key(ggufctx, "general.author"); keyidx = gguf_find_key(ggufctx, "general.author");
if (keyidx != -1) { fprintf(stdout, "%s: model author = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } if (keyidx != -1) { fprintf(stdout, "%s: model author = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); }
keyidx = gguf_find_key(ggufctx, "general.license"); keyidx = gguf_find_key(ggufctx, "general.license");
if (keyidx != -1) { fprintf(stdout, "%s: model license = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } if (keyidx != -1) { fprintf(stdout, "%s: model license = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); }
keyidx = gguf_find_key(ggufctx, "general.architecture"); keyidx = gguf_find_key(ggufctx, "general.architecture");
if (keyidx != -1) { fprintf(stdout, "%s: model architecture = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } if (keyidx != -1) { fprintf(stdout, "%s: model architecture = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); }
keyidx = gguf_find_key(ggufctx, "general.file_type"); keyidx = gguf_find_key(ggufctx, "general.file_type");
if (keyidx != -1) { fprintf(stdout, "%s: model file type = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } if (keyidx != -1) { fprintf(stdout, "%s: model file type = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); }
keyidx = gguf_find_key(ggufctx, "general.source.hugginface.repository");
if (keyidx != -1) { fprintf(stdout, "%s: model source HF repo = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); }
} }
// check required metadata // check required metadata
@ -551,21 +553,21 @@ bool gpt_neox_model_load(const std::string & fname, gpt_neox_model & model, gpt2
model.blocks.resize(n_block); model.blocks.resize(n_block);
model.wte = ggml_get_tensor(ctx, "transformer.token_embd.weight"); model.wte = ggml_get_tensor(ctx, "token_embd.weight");
model.ln_f_g = ggml_get_tensor(ctx, "transformer.output_norm.weight"); model.ln_f_g = ggml_get_tensor(ctx, "output_norm.weight");
model.ln_f_b = ggml_get_tensor(ctx, "transformer.output_norm.bias"); model.ln_f_b = ggml_get_tensor(ctx, "output_norm.bias");
model.lmh_g = ggml_get_tensor(ctx, "transformer.output.weight"); model.lmh_g = ggml_get_tensor(ctx, "output.weight");
// map by name // map by name
model.tensors["transformer.token_embd.weight"] = model.wte; model.tensors["token_embd.weight"] = model.wte;
model.tensors["transformer.output_norm.weight"] = model.ln_f_g; model.tensors["output_norm.weight"] = model.ln_f_g;
model.tensors["transformer.output_norm.bias"] = model.ln_f_b; model.tensors["output_norm.bias"] = model.ln_f_b;
model.tensors["transformer.output.weight"] = model.lmh_g; model.tensors["output.weight"] = model.lmh_g;
for (int i = 0; i < n_block; ++i) { for (int i = 0; i < n_block; ++i) {
auto & block = model.blocks[i]; auto & block = model.blocks[i];
std::string blocknamestart = "transformer.blocks." + std::to_string(i) + "."; std::string blocknamestart = "blk." + std::to_string(i) + ".";
block.ln_1_g = get_tensor_ex(ctx, blocknamestart + "attn_norm.weight" ); block.ln_1_g = get_tensor_ex(ctx, blocknamestart + "attn_norm.weight" );
block.ln_1_b = get_tensor_ex(ctx, blocknamestart + "attn_norm.bias" ); block.ln_1_b = get_tensor_ex(ctx, blocknamestart + "attn_norm.bias" );