mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-27 20:43:07 +01:00
hf bitnet v1
This commit is contained in:
parent
3b38d48609
commit
076b4a197b
@ -1390,6 +1390,26 @@ class LlamaModel(Model):
|
||||
if len(experts) > 0:
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
@Model.register("BitnetForCausalLM")
|
||||
class BitnetModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.BITNET
|
||||
def set_vocab(self):
|
||||
self._set_vocab_sentencepiece()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.hparams
|
||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||
self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])
|
||||
|
||||
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
|
||||
if self.hparams["rope_scaling"].get("type") == "linear":
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
|
||||
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
|
||||
|
||||
# def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
|
||||
# return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
@Model.register("GrokForCausalLM")
|
||||
class GrokModel(Model):
|
||||
|
111
ggml.c
111
ggml.c
@ -2621,6 +2621,22 @@ inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
|
||||
*s = idx;
|
||||
}
|
||||
|
||||
inline static void ggml_vec_absmaxclamp_f32(const int n, float * s, const float * x, float min) {
|
||||
float max = min;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
max = MAX(max, fabs(x[i]));
|
||||
}
|
||||
*s = max;
|
||||
}
|
||||
inline static void ggml_vec_scaleroundclamp_f32(const int n, float * s, const float * x, float scale, float min, float max) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
s[i] = round(x[i] * scale);
|
||||
if (s[i] > max) s[i] = max;
|
||||
if (s[i] < min) s[i] = min;
|
||||
s[i] /= scale;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// data types
|
||||
//
|
||||
@ -2709,9 +2725,11 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
|
||||
"CROSS_ENTROPY_LOSS",
|
||||
"CROSS_ENTROPY_LOSS_BACK",
|
||||
|
||||
"BITLINEAR_QUANT"
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
|
||||
static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
@ -2797,9 +2815,11 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
|
||||
"cross_entropy_loss(x,y)",
|
||||
"cross_entropy_loss_back(x,y)",
|
||||
|
||||
"bitlinear(x)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
|
||||
static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
@ -4830,6 +4850,28 @@ struct ggml_tensor * ggml_mean(
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_bitlinear_quant for bitnet
|
||||
|
||||
struct ggml_tensor * ggml_bitlinear_quant(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
bool is_node = false;
|
||||
|
||||
if (a->grad) {
|
||||
GGML_ASSERT(false); // TODO: implement
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
int64_t ne[GGML_MAX_DIMS] = { a->ne[0], a->ne[1], a->ne[2], a->ne[3] };
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, ggml_n_dims(a), ne);
|
||||
|
||||
result->op = GGML_OP_BITLINEAR_QUANT;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src[0] = a;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_argmax
|
||||
|
||||
struct ggml_tensor * ggml_argmax(
|
||||
@ -10740,6 +10782,62 @@ static void ggml_compute_forward_mean(
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_bitlinear_quant_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
struct ggml_tensor * dst) {
|
||||
assert(params->ith == 0);
|
||||
|
||||
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
|
||||
return;
|
||||
}
|
||||
|
||||
assert(src0->nb[0] == sizeof(float));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
assert(ne0 == ne00);
|
||||
assert(ne1 == ne01);
|
||||
assert(ne2 == ne02);
|
||||
assert(ne3 == ne03);
|
||||
|
||||
UNUSED(ne0);
|
||||
UNUSED(ne1);
|
||||
UNUSED(ne2);
|
||||
UNUSED(ne3);
|
||||
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
||||
float rowmax = 0.00001;
|
||||
ggml_vec_absmaxclamp_f32(ne00, &rowmax, (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), 0.00001);
|
||||
float s = 127 / rowmax;
|
||||
|
||||
ggml_vec_scaleroundclamp_f32(ne00,
|
||||
(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
|
||||
(float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
|
||||
s, -128, 127);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_bitlinear_quant(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
struct ggml_tensor * dst) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_bitlinear_quant_f32(params, src0, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ASSERT(false);
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_argmax
|
||||
|
||||
static void ggml_compute_forward_argmax_f32(
|
||||
@ -17318,6 +17416,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_mean(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_BITLINEAR_QUANT:
|
||||
{
|
||||
ggml_compute_forward_bitlinear_quant(params, tensor->src[0], tensor);
|
||||
} break;
|
||||
case GGML_OP_ARGMAX:
|
||||
{
|
||||
ggml_compute_forward_argmax(params, tensor);
|
||||
@ -18484,6 +18586,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||
{
|
||||
GGML_ASSERT(false); // TODO: not implemented
|
||||
} break;
|
||||
case GGML_OP_BITLINEAR_QUANT:
|
||||
{
|
||||
GGML_ASSERT(false); // TODO: not implemented
|
||||
} break;
|
||||
case GGML_OP_ARGSORT:
|
||||
{
|
||||
GGML_ASSERT(false); // TODO: not implemented
|
||||
@ -19249,6 +19355,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
|
||||
case GGML_OP_GET_REL_POS:
|
||||
case GGML_OP_MAP_UNARY:
|
||||
case GGML_OP_MAP_BINARY:
|
||||
case GGML_OP_BITLINEAR_QUANT:
|
||||
case GGML_OP_MAP_CUSTOM1_F32:
|
||||
case GGML_OP_MAP_CUSTOM2_F32:
|
||||
case GGML_OP_MAP_CUSTOM3_F32:
|
||||
|
7
ggml.h
7
ggml.h
@ -506,6 +506,8 @@ extern "C" {
|
||||
GGML_OP_CROSS_ENTROPY_LOSS,
|
||||
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
|
||||
|
||||
GGML_OP_BITLINEAR_QUANT,
|
||||
|
||||
GGML_OP_COUNT,
|
||||
};
|
||||
|
||||
@ -993,6 +995,11 @@ extern "C" {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// for bitnet
|
||||
GGML_API struct ggml_tensor * ggml_bitlinear_quant(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// argmax along rows
|
||||
GGML_API struct ggml_tensor * ggml_argmax(
|
||||
struct ggml_context * ctx,
|
||||
|
@ -148,6 +148,7 @@ class MODEL_ARCH(IntEnum):
|
||||
OLMO = auto()
|
||||
ARCTIC = auto()
|
||||
DEEPSEEK2 = auto()
|
||||
BITNET = auto()
|
||||
|
||||
|
||||
class MODEL_TENSOR(IntEnum):
|
||||
@ -199,6 +200,8 @@ class MODEL_TENSOR(IntEnum):
|
||||
ATTN_KV_B = auto()
|
||||
ATTN_Q_A_NORM = auto()
|
||||
ATTN_KV_A_NORM = auto()
|
||||
FFN_SUB_NORM = auto()
|
||||
ATTN_SUB_NORM = auto()
|
||||
|
||||
|
||||
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
@ -236,6 +239,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.OLMO: "olmo",
|
||||
MODEL_ARCH.ARCTIC: "arctic",
|
||||
MODEL_ARCH.DEEPSEEK2: "deepseek2",
|
||||
MODEL_ARCH.BITNET: "bitnet",
|
||||
}
|
||||
|
||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
@ -287,6 +291,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
|
||||
MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
|
||||
MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm",
|
||||
}
|
||||
|
||||
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
@ -806,6 +812,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
],
|
||||
MODEL_ARCH.BITNET: [
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_QKV,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.ATTN_SUB_NORM,
|
||||
MODEL_TENSOR.FFN_SUB_NORM,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
|
@ -410,6 +410,14 @@ class TensorNameMap:
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM: (
|
||||
"model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_SUB_NORM: (
|
||||
"model.layers.{bid}.self_attn.inner_attn_ln", # bitnet
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_SUB_NORM: (
|
||||
"model.layers.{bid}.mlp.ffn_layernorm", # bitnet
|
||||
),
|
||||
}
|
||||
|
||||
# architecture-specific block mappings
|
||||
|
248
llama.cpp
248
llama.cpp
@ -223,6 +223,7 @@ enum llm_arch {
|
||||
LLM_ARCH_OLMO,
|
||||
LLM_ARCH_ARCTIC,
|
||||
LLM_ARCH_DEEPSEEK2,
|
||||
LLM_ARCH_BITNET,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
@ -261,6 +262,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_OLMO, "olmo" },
|
||||
{ LLM_ARCH_ARCTIC, "arctic" },
|
||||
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
|
||||
{ LLM_ARCH_BITNET, "bitnet" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
@ -496,6 +498,8 @@ enum llm_tensor {
|
||||
LLM_TENSOR_ATTN_KV_B,
|
||||
LLM_TENSOR_ATTN_Q_A_NORM,
|
||||
LLM_TENSOR_ATTN_KV_A_NORM,
|
||||
LLM_TENSOR_ATTN_SUB_NORM,
|
||||
LLM_TENSOR_FFN_SUB_NORM,
|
||||
};
|
||||
|
||||
static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
|
||||
@ -1108,6 +1112,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_BITNET,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
@ -1985,6 +2007,8 @@ struct llama_layer {
|
||||
struct ggml_tensor * attn_out_norm_b;
|
||||
struct ggml_tensor * attn_q_a_norm;
|
||||
struct ggml_tensor * attn_kv_a_norm;
|
||||
struct ggml_tensor * attn_sub_norm;
|
||||
struct ggml_tensor * ffn_sub_norm;
|
||||
|
||||
// attention
|
||||
struct ggml_tensor * wq;
|
||||
@ -4499,6 +4523,15 @@ static void llm_load_hparams(
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_BITNET:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 26: model.type = e_model::MODEL_3B; break;
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
default: (void)0;
|
||||
}
|
||||
|
||||
@ -6409,6 +6442,40 @@ static bool llm_load_tensors(
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_BITNET:
|
||||
{
|
||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||
|
||||
// output
|
||||
{
|
||||
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||
}
|
||||
|
||||
const uint32_t n_ff = hparams.n_ff;
|
||||
model.layers.resize(n_layer);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||
layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd});
|
||||
|
||||
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
|
||||
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
|
||||
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
|
||||
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
|
||||
|
||||
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff});
|
||||
|
||||
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
|
||||
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
|
||||
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
@ -6761,6 +6828,15 @@ static struct ggml_tensor * llm_build_norm(
|
||||
return cur;
|
||||
}
|
||||
|
||||
static struct ggml_tensor * llm_build_qbitlinear(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * cur)
|
||||
{
|
||||
return ggml_bitlinear_quant(ctx, cur);
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
static struct ggml_tensor * llm_build_ffn(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * cur,
|
||||
@ -6963,6 +7039,7 @@ static struct ggml_tensor * llm_build_kqv(
|
||||
struct ggml_tensor * wo_b,
|
||||
struct ggml_tensor * q_cur,
|
||||
struct ggml_tensor * kq_mask,
|
||||
struct ggml_tensor * attn_sub_norm,
|
||||
int32_t n_tokens,
|
||||
int32_t n_kv,
|
||||
float kq_scale,
|
||||
@ -7057,6 +7134,17 @@ static struct ggml_tensor * llm_build_kqv(
|
||||
cb(cur, "kqv_merged_cont", il);
|
||||
}
|
||||
|
||||
if (model.arch == LLM_ARCH_BITNET)
|
||||
{
|
||||
cur = llm_build_norm(ctx, cur, hparams,
|
||||
attn_sub_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_sub_norm", il);
|
||||
|
||||
// B2 for wo
|
||||
cur = llm_build_qbitlinear(ctx, cur);
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(graph, cur);
|
||||
|
||||
cur = ggml_mul_mat(ctx, wo, cur);
|
||||
@ -7102,7 +7190,7 @@ static struct ggml_tensor * llm_build_kv(
|
||||
struct ggml_tensor * cur;
|
||||
|
||||
cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b,
|
||||
q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il);
|
||||
q_cur, kq_mask, nullptr, n_tokens, n_kv, kq_scale, cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
return cur;
|
||||
@ -11448,6 +11536,159 @@ struct llm_build_context {
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * build_bitnet() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
struct ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
// B1.Q
|
||||
cur = llm_build_qbitlinear(ctx0, cur);
|
||||
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
if (model.layers[il].bq) {
|
||||
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
cb(Qcur, "Qcur", il);
|
||||
}
|
||||
|
||||
// B1.K
|
||||
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
if (model.layers[il].bk) {
|
||||
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
cb(Kcur, "Kcur", il);
|
||||
}
|
||||
|
||||
// B1.V
|
||||
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
if (model.layers[il].bv) {
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
cb(Vcur, "Vcur", il);
|
||||
}
|
||||
|
||||
Qcur = ggml_rope_custom(
|
||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
|
||||
n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
Kcur = ggml_rope_custom(
|
||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
|
||||
n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
|
||||
cur = llm_build_kqv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, KQ_mask, model.layers[il].attn_sub_norm, n_tokens, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il
|
||||
);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// feed-forward forward
|
||||
if (model.layers[il].ffn_gate_inp == nullptr) {
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// cur = llm_build_ffn(ctx0, cur,
|
||||
// model.layers[il].ffn_up, NULL,
|
||||
// model.layers[il].ffn_gate, NULL,
|
||||
// model.layers[il].ffn_down, NULL,
|
||||
// NULL,
|
||||
// LLM_FFN_SILU, LLM_FFN_PAR, cb, il, hparams, model.layers[il].ffn_sub_norm, isbitnet);
|
||||
// cb(cur, "ffn_out", il);
|
||||
|
||||
|
||||
cur = llm_build_qbitlinear(ctx0, cur);
|
||||
|
||||
struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);
|
||||
|
||||
cb(tmp, "ffn_up", il);
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur);
|
||||
|
||||
cb(cur, "ffn_gate", il);
|
||||
|
||||
|
||||
cur = ggml_silu(ctx0, cur);
|
||||
cb(cur, "ffn_silu", il);
|
||||
|
||||
cur = ggml_mul(ctx0, cur, tmp);
|
||||
cb(cur, "ffn_gate_par", il);
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.layers[il].ffn_sub_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_sub_norm", il);
|
||||
|
||||
// B4 for w2
|
||||
cur = llm_build_qbitlinear(ctx0, cur);
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur);
|
||||
cb(cur, "ffn_down", il);
|
||||
|
||||
}
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
// lm_head
|
||||
cur = ggml_mul_mat(ctx0, model.tok_embd, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
|
||||
@ -11670,6 +11911,10 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
{
|
||||
result = llm.build_deepseek2();
|
||||
} break;
|
||||
case LLM_ARCH_BITNET:
|
||||
{
|
||||
result = llm.build_bitnet();
|
||||
} break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
@ -16677,6 +16922,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_BERT:
|
||||
case LLM_ARCH_NOMIC_BERT:
|
||||
case LLM_ARCH_STABLELM:
|
||||
case LLM_ARCH_BITNET:
|
||||
case LLM_ARCH_QWEN:
|
||||
case LLM_ARCH_QWEN2:
|
||||
case LLM_ARCH_QWEN2MOE:
|
||||
|
482
tokenization_bitnet.py
Normal file
482
tokenization_bitnet.py
Normal file
@ -0,0 +1,482 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tokenization classes for LLaMA."""
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
from transformers.convert_slow_tokenizer import import_protobuf
|
||||
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {
|
||||
"hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
|
||||
},
|
||||
"tokenizer_file": {
|
||||
"hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
|
||||
},
|
||||
}
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
"hf-internal-testing/llama-tokenizer": 2048,
|
||||
}
|
||||
SPIECE_UNDERLINE = "▁"
|
||||
|
||||
B_INST, E_INST = "[INST]", "[/INST]"
|
||||
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
||||
|
||||
# fmt: off
|
||||
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
|
||||
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
|
||||
that your responses are socially unbiased and positive in nature.
|
||||
|
||||
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
|
||||
correct. If you don't know the answer to a question, please don't share false information."""
|
||||
# fmt: on
|
||||
|
||||
|
||||
class BitnetTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
Construct a Bitnet tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
|
||||
no padding token in the original model.
|
||||
|
||||
Args:
|
||||
vocab_file (`str`):
|
||||
Path to the vocabulary file.
|
||||
unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead.
|
||||
bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
|
||||
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
|
||||
eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
|
||||
The end of sequence token.
|
||||
pad_token (`str` or `tokenizers.AddedToken`, *optional*):
|
||||
A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
|
||||
attention mechanisms or loss computation.
|
||||
sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*):
|
||||
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
|
||||
SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
|
||||
to set:
|
||||
|
||||
- `enable_sampling`: Enable subword regularization.
|
||||
- `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
|
||||
|
||||
- `nbest_size = {0,1}`: No sampling is performed.
|
||||
- `nbest_size > 1`: samples from the nbest_size results.
|
||||
- `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
|
||||
using forward-filtering-and-backward-sampling algorithm.
|
||||
|
||||
- `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
|
||||
BPE-dropout.
|
||||
|
||||
add_bos_token (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to add an `bos_token` at the start of sequences.
|
||||
add_eos_token (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to add an `eos_token` at the end of sequences.
|
||||
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
|
||||
extra spaces.
|
||||
use_default_system_prompt (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the default system prompt for Bitnet should be used.
|
||||
spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to add spaces between special tokens.
|
||||
legacy (`bool`, *optional*):
|
||||
Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622
|
||||
and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple
|
||||
example:
|
||||
|
||||
- `legacy=True`:
|
||||
```python
|
||||
>>> from transformers import T5Tokenizer
|
||||
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=True)
|
||||
>>> tokenizer.encode("Hello <extra_id_0>.")
|
||||
[8774, 32099, 3, 5, 1]
|
||||
```
|
||||
- `legacy=False`:
|
||||
```python
|
||||
>>> from transformers import T5Tokenizer
|
||||
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False)
|
||||
>>> tokenizer.encode("Hello <extra_id_0>.") # the extra space `[3]` is no longer here
|
||||
[8774, 32099, 5, 1]
|
||||
```
|
||||
Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
|
||||
add_prefix_space (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to add an initial space to the input. This allows to treat the leading word just as any
|
||||
other word.
|
||||
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
unk_token="<unk>",
|
||||
bos_token="<s>",
|
||||
eos_token="</s>",
|
||||
pad_token=None,
|
||||
sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
add_bos_token=True,
|
||||
add_eos_token=False,
|
||||
clean_up_tokenization_spaces=False,
|
||||
use_default_system_prompt=False,
|
||||
spaces_between_special_tokens=False,
|
||||
legacy=None,
|
||||
add_prefix_space=True,
|
||||
**kwargs,
|
||||
):
|
||||
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
||||
bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
|
||||
eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
|
||||
unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
|
||||
pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
|
||||
|
||||
if legacy is None:
|
||||
logger.warning_once(
|
||||
f"You are using the default legacy behaviour of the {self.__class__}. This is"
|
||||
" expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
|
||||
" If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
|
||||
" means, and thoroughly read the reason why this was added as explained in"
|
||||
" https://github.com/huggingface/transformers/pull/24565"
|
||||
)
|
||||
legacy = True
|
||||
|
||||
self.legacy = legacy
|
||||
self.vocab_file = vocab_file
|
||||
self.add_bos_token = add_bos_token
|
||||
self.add_eos_token = add_eos_token
|
||||
self.use_default_system_prompt = use_default_system_prompt
|
||||
self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
|
||||
self.add_prefix_space = add_prefix_space
|
||||
|
||||
super().__init__(
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
unk_token=unk_token,
|
||||
pad_token=pad_token,
|
||||
add_bos_token=add_bos_token,
|
||||
add_eos_token=add_eos_token,
|
||||
sp_model_kwargs=self.sp_model_kwargs,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
use_default_system_prompt=use_default_system_prompt,
|
||||
spaces_between_special_tokens=spaces_between_special_tokens,
|
||||
legacy=legacy,
|
||||
add_prefix_space=add_prefix_space,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def unk_token_length(self):
|
||||
return len(self.sp_model.encode(str(self.unk_token)))
|
||||
|
||||
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
|
||||
def get_spm_processor(self, from_slow=False):
|
||||
tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
||||
if self.legacy or from_slow: # no dependency on protobuf
|
||||
tokenizer.Load(self.vocab_file)
|
||||
return tokenizer
|
||||
|
||||
with open(self.vocab_file, "rb") as f:
|
||||
sp_model = f.read()
|
||||
model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
|
||||
model = model_pb2.ModelProto.FromString(sp_model)
|
||||
normalizer_spec = model_pb2.NormalizerSpec()
|
||||
normalizer_spec.add_dummy_prefix = False
|
||||
model.normalizer_spec.MergeFrom(normalizer_spec)
|
||||
sp_model = model.SerializeToString()
|
||||
tokenizer.LoadFromSerializedProto(sp_model)
|
||||
return tokenizer
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state["sp_model"] = None
|
||||
state["sp_model_proto"] = self.sp_model.serialized_model_proto()
|
||||
return state
|
||||
|
||||
def __setstate__(self, d):
|
||||
self.__dict__ = d
|
||||
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
||||
self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
"""Returns vocab size"""
|
||||
return self.sp_model.get_piece_size()
|
||||
|
||||
def get_vocab(self):
|
||||
"""Returns vocab as a dict"""
|
||||
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
||||
vocab.update(self.added_tokens_encoder)
|
||||
return vocab
|
||||
|
||||
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
|
||||
def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
|
||||
"""
|
||||
Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
|
||||
first token is special.
|
||||
"""
|
||||
if self.legacy or len(text) == 0:
|
||||
return super().tokenize(text, **kwargs)
|
||||
|
||||
text = text.replace(SPIECE_UNDERLINE, " ")
|
||||
if self.add_prefix_space:
|
||||
text = SPIECE_UNDERLINE + text
|
||||
|
||||
tokens = super().tokenize(text, **kwargs)
|
||||
|
||||
if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
|
||||
tokens = tokens[1:]
|
||||
return tokens
|
||||
|
||||
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
|
||||
def _tokenize(self, text, **kwargs):
|
||||
"""
|
||||
Returns a tokenized string.
|
||||
|
||||
We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
|
||||
SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
|
||||
`['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
|
||||
`unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
|
||||
`self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
|
||||
"""
|
||||
tokens = self.sp_model.encode(text, out_type=str)
|
||||
if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
|
||||
return tokens
|
||||
|
||||
# 1. Encode string + prefix ex: "<unk> Hey"
|
||||
tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
|
||||
# 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
|
||||
return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
"""Converts a token (str) in an id using the vocab."""
|
||||
return self.sp_model.piece_to_id(token)
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
token = self.sp_model.IdToPiece(index)
|
||||
return token
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
"""Converts a sequence of tokens (string) in a single string."""
|
||||
# since we manually add the prefix space, we have to remove it when decoding
|
||||
if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space:
|
||||
tokens[0] = tokens[0][1:]
|
||||
|
||||
current_sub_tokens = []
|
||||
out_string = ""
|
||||
prev_is_special = False
|
||||
for i, token in enumerate(tokens):
|
||||
# make sure that special tokens are not decoded using sentencepiece model
|
||||
if token in self.all_special_tokens:
|
||||
if not prev_is_special and i != 0 and self.legacy:
|
||||
out_string += " "
|
||||
out_string += self.sp_model.decode(current_sub_tokens) + token
|
||||
prev_is_special = True
|
||||
current_sub_tokens = []
|
||||
else:
|
||||
current_sub_tokens.append(token)
|
||||
prev_is_special = False
|
||||
out_string += self.sp_model.decode(current_sub_tokens)
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
"""
|
||||
Save the vocabulary and special tokens file to a directory.
|
||||
|
||||
Args:
|
||||
save_directory (`str`):
|
||||
The directory in which to save the vocabulary.
|
||||
|
||||
Returns:
|
||||
`Tuple(str)`: Paths to the files saved.
|
||||
"""
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
return
|
||||
out_vocab_file = os.path.join(
|
||||
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
||||
)
|
||||
|
||||
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
elif not os.path.isfile(self.vocab_file):
|
||||
with open(out_vocab_file, "wb") as fi:
|
||||
content_spiece_model = self.sp_model.serialized_model_proto()
|
||||
fi.write(content_spiece_model)
|
||||
|
||||
return (out_vocab_file,)
|
||||
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
||||
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
||||
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
||||
|
||||
output = bos_token_id + token_ids_0 + eos_token_id
|
||||
|
||||
if token_ids_1 is not None:
|
||||
output = output + bos_token_id + token_ids_1 + eos_token_id
|
||||
|
||||
return output
|
||||
|
||||
def get_special_tokens_mask(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||
) -> List[int]:
|
||||
"""
|
||||
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||
special tokens using the tokenizer `prepare_for_model` method.
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the token list is already formatted with special tokens for the model.
|
||||
|
||||
Returns:
|
||||
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||
"""
|
||||
if already_has_special_tokens:
|
||||
return super().get_special_tokens_mask(
|
||||
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
||||
)
|
||||
|
||||
bos_token_id = [1] if self.add_bos_token else []
|
||||
eos_token_id = [1] if self.add_eos_token else []
|
||||
|
||||
if token_ids_1 is None:
|
||||
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
|
||||
return (
|
||||
bos_token_id
|
||||
+ ([0] * len(token_ids_0))
|
||||
+ eos_token_id
|
||||
+ bos_token_id
|
||||
+ ([0] * len(token_ids_1))
|
||||
+ eos_token_id
|
||||
)
|
||||
|
||||
def create_token_type_ids_from_sequences(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
|
||||
sequence pair mask has the following format:
|
||||
|
||||
```
|
||||
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
||||
| first sequence | second sequence |
|
||||
```
|
||||
|
||||
if token_ids_1 is None, only returns the first portion of the mask (0s).
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of ids.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
|
||||
"""
|
||||
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
||||
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
||||
|
||||
output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
|
||||
|
||||
if token_ids_1 is not None:
|
||||
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
|
||||
|
||||
return output
|
||||
|
||||
@property
|
||||
def default_chat_template(self):
|
||||
"""
|
||||
LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
|
||||
Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
|
||||
user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
|
||||
rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
|
||||
results in an unusual token ordering when it is present. This template should definitely be changed if you wish
|
||||
to fine-tune a model with more flexible role ordering!
|
||||
|
||||
The output should look something like:
|
||||
|
||||
<bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos><bos>[INST] Prompt [/INST] Answer <eos>
|
||||
<bos>[INST] Prompt [/INST]
|
||||
|
||||
The reference for this chat template is [this code
|
||||
snippet](https://github.com/facebookresearch/llama/blob/556949fdfb72da27c2f4a40b7f0e4cf0b8153a28/llama/generation.py#L320-L362)
|
||||
in the original repository.
|
||||
"""
|
||||
logger.warning_once(
|
||||
"\nNo chat template is defined for this tokenizer - using the default template "
|
||||
f"for the {self.__class__.__name__} class. If the default is not appropriate for "
|
||||
"your model, please set `tokenizer.chat_template` to an appropriate template. "
|
||||
"See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
|
||||
)
|
||||
template = (
|
||||
"{% if messages[0]['role'] == 'system' %}"
|
||||
"{% set loop_messages = messages[1:] %}" # Extract system message if it's present
|
||||
"{% set system_message = messages[0]['content'] %}"
|
||||
"{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
|
||||
"{% set loop_messages = messages %}" # Or use the default system message if the flag is set
|
||||
"{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
|
||||
"{% else %}"
|
||||
"{% set loop_messages = messages %}"
|
||||
"{% set system_message = false %}"
|
||||
"{% endif %}"
|
||||
"{% for message in loop_messages %}" # Loop over all non-system messages
|
||||
"{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
|
||||
"{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
|
||||
"{% endif %}"
|
||||
"{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message
|
||||
"{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
|
||||
"{% else %}"
|
||||
"{% set content = message['content'] %}"
|
||||
"{% endif %}"
|
||||
"{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way
|
||||
"{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
|
||||
"{% elif message['role'] == 'system' %}"
|
||||
"{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
|
||||
"{% elif message['role'] == 'assistant' %}"
|
||||
"{{ ' ' + content.strip() + ' ' + eos_token }}"
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
)
|
||||
template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
|
||||
default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
|
||||
template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
|
||||
|
||||
return template
|
Loading…
Reference in New Issue
Block a user