llama : support NUL bytes in tokens

This commit is contained in:
Francis Couture-Harpin 2024-08-11 21:00:03 -04:00
parent 4134999e01
commit faaac59d16
7 changed files with 28 additions and 18 deletions

View File

@ -2224,9 +2224,8 @@ class InternLM2Model(Model):
def set_vocab(self): def set_vocab(self):
# (TODO): Is there a better way? # (TODO): Is there a better way?
# Copy from _set_vocab_sentencepiece, The only difference is that we will treat the character # Copy from _set_vocab_sentencepiece, The only difference is that we find mislabeled UNUSED tokens,
# \x00 specially and convert it into an emoji character to prevent it from being mistakenly # and that we set '<|im_end|>' as the eos token for chat models.
# recognized as an empty string in C++.
from sentencepiece import SentencePieceProcessor from sentencepiece import SentencePieceProcessor
from sentencepiece import sentencepiece_model_pb2 as model from sentencepiece import sentencepiece_model_pb2 as model
@ -2253,11 +2252,6 @@ class InternLM2Model(Model):
piece = tokenizer.IdToPiece(token_id) piece = tokenizer.IdToPiece(token_id)
text = piece.encode("utf-8") text = piece.encode("utf-8")
score = tokenizer.GetScore(token_id) score = tokenizer.GetScore(token_id)
if text == b"\x00":
# (TODO): fixme
# Hack here and replace the \x00 characters.
logger.warning(f"InternLM2 convert token '{text}' to '🐉'!")
text = "🐉".encode("utf-8")
toktype = SentencePieceTokenTypes.NORMAL toktype = SentencePieceTokenTypes.NORMAL
if tokenizer.IsUnknown(token_id): if tokenizer.IsUnknown(token_id):

View File

@ -561,7 +561,7 @@ static void load_vocab(const char * filename, const Config * config, struct llam
vocab->id_to_token.resize(n_vocab); vocab->id_to_token.resize(n_vocab);
for (uint32_t i = 0; i < n_vocab; i++) { for (uint32_t i = 0; i < n_vocab; i++) {
std::string word = gguf_get_arr_str(ctx, token_idx, i); std::string word(gguf_get_arr_str(ctx, token_idx, i), gguf_get_arr_str_n(ctx, token_idx, i));
vocab->token_to_id[word] = i; vocab->token_to_id[word] = i;

View File

@ -12,7 +12,7 @@ static bool g_verbose = false;
static std::string get_kv_str(struct gguf_context * ctx_gguf, const std::string & key){ static std::string get_kv_str(struct gguf_context * ctx_gguf, const std::string & key){
int id = gguf_find_key(ctx_gguf, key.c_str()); int id = gguf_find_key(ctx_gguf, key.c_str());
return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf, id)); return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf, id), gguf_get_val_str_n(ctx_gguf, id));
} }
static float get_kv_f32(struct gguf_context * ctx_gguf, const std::string & key) { static float get_kv_f32(struct gguf_context * ctx_gguf, const std::string & key) {

View File

@ -225,7 +225,7 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
switch (type) { switch (type) {
case GGUF_TYPE_STRING: case GGUF_TYPE_STRING:
return gguf_get_val_str(ctx_gguf, i); return std::string(gguf_get_val_str(ctx_gguf, i), gguf_get_val_str_n(ctx_gguf, i));
case GGUF_TYPE_ARRAY: case GGUF_TYPE_ARRAY:
{ {
const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i); const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i);
@ -235,7 +235,7 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
ss << "["; ss << "[";
for (int j = 0; j < arr_n; j++) { for (int j = 0; j < arr_n; j++) {
if (arr_type == GGUF_TYPE_STRING) { if (arr_type == GGUF_TYPE_STRING) {
std::string val = gguf_get_arr_str(ctx_gguf, i, j); std::string val(gguf_get_arr_str(ctx_gguf, i, j), gguf_get_arr_str_n(ctx_gguf, i, j));
// escape quotes // escape quotes
replace_all(val, "\\", "\\\\"); replace_all(val, "\\", "\\\\");
replace_all(val, "\"", "\\\""); replace_all(val, "\"", "\\\"");

View File

@ -2313,10 +2313,12 @@ extern "C" {
GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id); GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id); GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id); GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
GGML_API int gguf_get_val_str_n(const struct gguf_context * ctx, int key_id);
GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id); GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id); GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id);
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id); GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i); GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
GGML_API int gguf_get_arr_str_n(const struct gguf_context * ctx, int key_id, int i);
GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx); GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx);
GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name); GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name);

View File

@ -21335,6 +21335,14 @@ const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i
return str->data; return str->data;
} }
int gguf_get_arr_str_n(const struct gguf_context * ctx, int key_id, int i) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
struct gguf_kv * kv = &ctx->kv[key_id];
struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
return str->n;
}
int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) { int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
@ -21413,6 +21421,12 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
return ctx->kv[key_id].value.str.data; return ctx->kv[key_id].value.str.data;
} }
int gguf_get_val_str_n(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
return ctx->kv[key_id].value.str.n;
}
const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) { const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY); GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY);

View File

@ -1406,7 +1406,7 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
switch (type) { switch (type) {
case GGUF_TYPE_STRING: case GGUF_TYPE_STRING:
return gguf_get_val_str(ctx_gguf, i); return std::string(gguf_get_val_str(ctx_gguf, i), gguf_get_val_str_n(ctx_gguf, i));
case GGUF_TYPE_ARRAY: case GGUF_TYPE_ARRAY:
{ {
const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i); const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i);
@ -1416,7 +1416,7 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
ss << "["; ss << "[";
for (int j = 0; j < arr_n; j++) { for (int j = 0; j < arr_n; j++) {
if (arr_type == GGUF_TYPE_STRING) { if (arr_type == GGUF_TYPE_STRING) {
std::string val = gguf_get_arr_str(ctx_gguf, i, j); std::string val(gguf_get_arr_str(ctx_gguf, i, j), gguf_get_arr_str_n(ctx_gguf, i, j));
// escape quotes // escape quotes
replace_all(val, "\\", "\\\\"); replace_all(val, "\\", "\\\\");
replace_all(val, "\"", "\\\""); replace_all(val, "\"", "\\\"");
@ -3436,7 +3436,7 @@ namespace GGUFMeta {
static constexpr gguf_type gt = GGUF_TYPE_STRING; static constexpr gguf_type gt = GGUF_TYPE_STRING;
static std::string getter(const gguf_context * ctx, const int kid) { static std::string getter(const gguf_context * ctx, const int kid) {
return gguf_get_val_str(ctx, kid); return std::string(gguf_get_val_str(ctx, kid), gguf_get_val_str_n(ctx, kid));
} }
}; };
@ -5316,7 +5316,7 @@ static void llm_load_vocab(
const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
for (int i = 0; i < n_merges; i++) { for (int i = 0; i < n_merges; i++) {
const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); const std::string word(gguf_get_arr_str(ctx, merges_keyidx, i), gguf_get_arr_str_n(ctx, merges_keyidx, i));
GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
std::string first; std::string first;
@ -5521,7 +5521,7 @@ static void llm_load_vocab(
vocab.id_to_token.resize(n_vocab); vocab.id_to_token.resize(n_vocab);
for (uint32_t i = 0; i < n_vocab; i++) { for (uint32_t i = 0; i < n_vocab; i++) {
std::string word = gguf_get_arr_str(ctx, token_idx, i); std::string word(gguf_get_arr_str(ctx, token_idx, i), gguf_get_arr_str_n(ctx, token_idx, i));
GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
vocab.token_to_id[word] = i; vocab.token_to_id[word] = i;
@ -16207,7 +16207,7 @@ static void llama_lora_adapter_init_internal(struct llama_model * model, const c
{ {
auto get_kv_str = [&](const std::string & key) -> std::string { auto get_kv_str = [&](const std::string & key) -> std::string {
int id = gguf_find_key(ctx_gguf, key.c_str()); int id = gguf_find_key(ctx_gguf, key.c_str());
return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf, id)); return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf, id), gguf_get_val_str_n(ctx_gguf, id));
}; };
auto get_kv_f32 = [&](const std::string & key) -> float { auto get_kv_f32 = [&](const std::string & key) -> float {
int id = gguf_find_key(ctx_gguf, key.c_str()); int id = gguf_find_key(ctx_gguf, key.c_str());