Fix n^2 loop in tokenization (#254)

The inner loop scanned every substring from position i to the end of the text, so tokenization was O(n^2) in prompt length and long prompts parsed very slowly. Bounding the substring length by MAX_TOKEN_LEN makes the forward pass O(n * MAX_TOKEN_LEN).
commit a81d0c2a17
parent b2de7f18df
Author: Gary Linscott
Date:   2023-03-18 04:17:19 -07:00 (committed by GitHub)

@@ -302,7 +302,7 @@ std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::st
     // Forward pass
     for (int i = 0; i < len; i++) {
         int max_len = std::min(len - i, MAX_TOKEN_LEN);
-        for (int sub_len = 1; sub_len <= len - i; sub_len++) {
+        for (int sub_len = 1; sub_len <= max_len; sub_len++) {
             auto sub = text.substr(i, sub_len);
             auto token = vocab.token_to_id.find(sub);
             if (token != vocab.token_to_id.end()) {
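
For illustration, here is a minimal, self-contained sketch of the bounded forward-pass lookup this hunk changes. The toy vocabulary, the MAX_TOKEN_LEN value, and printing matches instead of recording them are assumptions made for the example; the full llama_tokenize additionally records the best match per position and reconstructs the token sequence in a backward pass.

// Sketch of the bounded forward pass; toy vocab and MAX_TOKEN_LEN
// are illustrative assumptions, not values from the repository.
#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_map>

static const int MAX_TOKEN_LEN = 18; // assumed cap on token length in bytes

int main() {
    // Hypothetical stand-in for vocab.token_to_id.
    std::unordered_map<std::string, int> token_to_id = {
        {"lo", 1}, {"long", 2}, {" prompt", 3}, {"p", 4},
    };

    const std::string text = "long prompt";
    const int len = (int) text.size();

    // Forward pass: only substrings up to MAX_TOKEN_LEN bytes can be vocab
    // tokens, so the inner loop does O(MAX_TOKEN_LEN) work per position
    // instead of O(len - i), and the whole pass is O(n * MAX_TOKEN_LEN)
    // rather than O(n^2).
    for (int i = 0; i < len; i++) {
        int max_len = std::min(len - i, MAX_TOKEN_LEN);
        for (int sub_len = 1; sub_len <= max_len; sub_len++) {
            auto sub = text.substr(i, sub_len);
            auto token = token_to_id.find(sub);
            if (token != token_to_id.end()) {
                std::cout << "match at " << i << ": \"" << sub
                          << "\" -> id " << token->second << "\n";
            }
        }
    }
    return 0;
}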