mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 21:37:19 +01:00
bitnet : pad tensors to 256
This commit is contained in:
parent
569a03ed97
commit
e9f2abfc8c
@ -1423,6 +1423,20 @@ class BitnetModel(Model):
|
||||
"o_proj.weight")):
|
||||
data_torch = self.weight_quant(data_torch)
|
||||
|
||||
# pad 1D tensors
|
||||
# TODO: is padding with 0s an invariant, or do we also need some scaling factor?
|
||||
if name.endswith(("input_layernorm.weight", "post_attention_layernorm.weight", "model.norm.weight")):
|
||||
data_torch = torch.nn.functional.pad(data_torch, (0, 256 - data_torch.size(0) % 256), mode='constant', value=0)
|
||||
logger.info(f"pad {name} to {data_torch.size()}")
|
||||
|
||||
# pad 2D tensors
|
||||
# TODO: double-check that this is the correct way to pad the rows
|
||||
if name.endswith(("embed_tokens.weight", "q_proj.weight", "k_proj.weight", "v_proj.weight",
|
||||
"down_proj.weight", "up_proj.weight", "gate_proj.weight",
|
||||
"o_proj.weight")):
|
||||
data_torch = torch.nn.functional.pad(data_torch, (0, 256 - data_torch.size(1) % 256), mode='constant', value=0)
|
||||
logger.info(f"pad {name} to {data_torch.size()}")
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
|
46
llama.cpp
46
llama.cpp
@ -2021,9 +2021,9 @@ struct llama_layer {
|
||||
struct ggml_tensor * wkv_b;
|
||||
|
||||
// attention bias
|
||||
struct ggml_tensor * bq;
|
||||
struct ggml_tensor * bk;
|
||||
struct ggml_tensor * bv;
|
||||
struct ggml_tensor * bq = nullptr;
|
||||
struct ggml_tensor * bk = nullptr;
|
||||
struct ggml_tensor * bv = nullptr;
|
||||
struct ggml_tensor * bo;
|
||||
struct ggml_tensor * bqkv;
|
||||
|
||||
@ -6440,14 +6440,18 @@ static bool llm_load_tensors(
|
||||
} break;
|
||||
case LLM_ARCH_BITNET:
|
||||
{
|
||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||
const uint32_t n_ff = hparams.n_ff;
|
||||
const uint32_t n_ff_pad = GGML_PAD(n_ff, 256);
|
||||
|
||||
const int64_t n_embd_pad = GGML_PAD(n_embd, 256);
|
||||
|
||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd_pad, n_vocab});
|
||||
|
||||
// output
|
||||
{
|
||||
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd_pad});
|
||||
}
|
||||
|
||||
const uint32_t n_ff = hparams.n_ff;
|
||||
model.layers.resize(n_layer);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
@ -6456,20 +6460,20 @@ static bool llm_load_tensors(
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd_pad});
|
||||
layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd});
|
||||
|
||||
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
|
||||
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
|
||||
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
|
||||
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
|
||||
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd_pad, n_embd});
|
||||
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd_pad, n_embd_gqa});
|
||||
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd_pad, n_embd_gqa});
|
||||
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_pad, n_embd});
|
||||
|
||||
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd_pad});
|
||||
layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff});
|
||||
|
||||
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
|
||||
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
|
||||
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
|
||||
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd_pad, n_ff});
|
||||
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_pad, n_embd});
|
||||
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd_pad, n_ff});
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
@ -11658,6 +11662,7 @@ struct llm_build_context {
|
||||
|
||||
ggml_build_forward_expand(graph, cur_attn);
|
||||
|
||||
cur_attn = ggml_pad(ctx0, cur_attn, (256 - cur_attn->ne[0] % 256) % 256, 0, 0, 0);
|
||||
cur = ggml_mul_mat(ctx0, wo, cur_attn);
|
||||
|
||||
cb(cur, "kqv_out", il);
|
||||
@ -11670,6 +11675,8 @@ struct llm_build_context {
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
cur = ggml_pad(ctx0, cur, (256 - cur->ne[0] % 256) % 256, 0, 0, 0);
|
||||
|
||||
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
@ -11680,8 +11687,8 @@ struct llm_build_context {
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);
|
||||
|
||||
struct ggml_tensor * tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);
|
||||
|
||||
cb(tmp, "ffn_up", il);
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur);
|
||||
@ -11700,9 +11707,14 @@ struct llm_build_context {
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_sub_norm", il);
|
||||
|
||||
cur = ggml_pad(ctx0, cur, (256 - cur->ne[0] % 256) % 256, 0, 0, 0);
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur);
|
||||
cb(cur, "ffn_down", il);
|
||||
}
|
||||
|
||||
cur = ggml_pad(ctx0, cur, (256 - cur->ne[0] % 256) % 256, 0, 0, 0);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user