From 0526560759f858c0d49721be29f1c8b7274210b8 Mon Sep 17 00:00:00 2001
From: vvhg1 <vvhg1@users.noreply.github.com>
Date: Sat, 7 Oct 2023 09:08:30 +0200
Subject: [PATCH] only rm when params.escape, rm space if possible which is
 added back or rm added space token

---
 examples/infill/infill.cpp | 8 ++++++--
 examples/server/server.cpp | 9 +++++++--
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index 4288ffed7..47ece81c6 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -233,12 +233,16 @@ int main(int argc, char ** argv) {
     const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM;
     LOG("add_bos: %d\n", add_bos);
 
+    bool suff_rm_leading_spc = params.escape;
+    if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
+        params.input_suffix.erase(0, 1);
+        suff_rm_leading_spc = false;
+    }
     std::vector<llama_token> embd_inp;
     std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false);
-    // params.input_suffix.erase(0, params.input_suffix.find_first_not_of(" "));
     std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false);
     const int space_token = 29871;
-    if (params.escape && inp_sfx.size() > 1 && inp_sfx[0] == space_token) {
+    if (suff_rm_leading_spc && !inp_sfx.empty() && inp_sfx[0] == space_token) {
         inp_sfx.erase(inp_sfx.begin());
     }
     inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 53c0fb800..667f5db71 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -344,11 +344,16 @@ struct llama_server_context
 
     void loadInfill()
     {
-        params.input_suffix.erase(0, params.input_suffix.find_first_not_of(" "));
+        bool suff_rm_leading_spc = params.escape;
+        if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
+            params.input_suffix.erase(0, 1);
+            suff_rm_leading_spc = false;
+        }
+
         auto prefix_tokens = tokenize(params.input_prefix, false);
         auto suffix_tokens = tokenize(params.input_suffix, false);
         const int space_token = 29871;
-        if (params.escape && suffix_tokens.size() > 1 && suffix_tokens[0] == space_token) {
+        if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
             suffix_tokens.erase(suffix_tokens.begin());
         }
         prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));