llama : fix detokenization of non-special added-tokens (#4916)

Co-authored-by: goerch <jhr.walter@t-online.de>
This commit is contained in:
Georgi Gerganov 2024-01-13 18:47:38 +02:00 committed by GitHub
parent 2d57de5255
commit f172de03f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -10305,6 +10305,8 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
if (0 <= token && token < llama_n_vocab(model)) { if (0 <= token && token < llama_n_vocab(model)) {
switch (llama_vocab_get_type(model->vocab)) { switch (llama_vocab_get_type(model->vocab)) {
case LLAMA_VOCAB_TYPE_SPM: { case LLAMA_VOCAB_TYPE_SPM: {
// NOTE: we accept all unsupported token types,
// suppressing them like CONTROL tokens.
if (llama_is_normal_token(model->vocab, token)) { if (llama_is_normal_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].text; std::string result = model->vocab.id_to_token[token].text;
llama_unescape_whitespace(result); llama_unescape_whitespace(result);
@ -10313,6 +10315,13 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
} }
memcpy(buf, result.c_str(), result.length()); memcpy(buf, result.c_str(), result.length());
return result.length(); return result.length();
} else if (llama_is_user_defined_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].text;
if (length < (int) result.length()) {
return -result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
if (length < 3) { if (length < 3) {
return -3; return -3;
@ -10327,14 +10336,12 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
} }
buf[0] = llama_token_to_byte(model->vocab, token); buf[0] = llama_token_to_byte(model->vocab, token);
return 1; return 1;
} else {
// TODO: for now we accept all unsupported token types,
// suppressing them like CONTROL tokens.
// GGML_ASSERT(false);
} }
break; break;
} }
case LLAMA_VOCAB_TYPE_BPE: { case LLAMA_VOCAB_TYPE_BPE: {
// NOTE: we accept all unsupported token types,
// suppressing them like CONTROL tokens.
if (llama_is_normal_token(model->vocab, token)) { if (llama_is_normal_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].text; std::string result = model->vocab.id_to_token[token].text;
result = llama_decode_text(result); result = llama_decode_text(result);
@ -10343,12 +10350,15 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
} }
memcpy(buf, result.c_str(), result.length()); memcpy(buf, result.c_str(), result.length());
return result.length(); return result.length();
} else if (llama_is_user_defined_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].text;
if (length < (int) result.length()) {
return -result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (llama_is_control_token(model->vocab, token)) { } else if (llama_is_control_token(model->vocab, token)) {
; ;
} else {
// TODO: for now we accept all unsupported token types,
// suppressing them like CONTROL tokens.
// GGML_ASSERT(false);
} }
break; break;
} }