only rm when params.escape, rm space if possible which is added back or rm added space token

This commit is contained in:
vvhg1 2023-10-07 09:08:30 +02:00
parent b4046aabbf
commit 0526560759
2 changed files with 13 additions and 4 deletions

View File

@ -233,12 +233,16 @@ int main(int argc, char ** argv) {
const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM; const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM;
LOG("add_bos: %d\n", add_bos); LOG("add_bos: %d\n", add_bos);
bool suff_rm_leading_spc = params.escape;
if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
params.input_suffix.erase(0, 1);
suff_rm_leading_spc = false;
}
std::vector<llama_token> embd_inp; std::vector<llama_token> embd_inp;
std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false); std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false);
// params.input_suffix.erase(0, params.input_suffix.find_first_not_of(" "));
std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false); std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false);
const int space_token = 29871; const int space_token = 29871;
if (params.escape && inp_sfx.size() > 1 && inp_sfx[0] == space_token) { if (suff_rm_leading_spc && inp_sfx[0] == space_token) {
inp_sfx.erase(inp_sfx.begin()); inp_sfx.erase(inp_sfx.begin());
} }
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx)); inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));

View File

@ -344,11 +344,16 @@ struct llama_server_context
void loadInfill() void loadInfill()
{ {
params.input_suffix.erase(0, params.input_suffix.find_first_not_of(" ")); bool suff_rm_leading_spc = params.escape;
if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
params.input_suffix.erase(0, 1);
suff_rm_leading_spc = false;
}
auto prefix_tokens = tokenize(params.input_prefix, false); auto prefix_tokens = tokenize(params.input_prefix, false);
auto suffix_tokens = tokenize(params.input_suffix, false); auto suffix_tokens = tokenize(params.input_suffix, false);
const int space_token = 29871; const int space_token = 29871;
if (params.escape && suffix_tokens.size() > 1 && suffix_tokens[0] == space_token) { if (suff_rm_leading_spc && suffix_tokens[0] == space_token) {
suffix_tokens.erase(suffix_tokens.begin()); suffix_tokens.erase(suffix_tokens.begin());
} }
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx)); prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));