From 4ad0676927330ccc84c66b8ab7c27ddf18aea43d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 28 Sep 2023 15:48:38 +0300 Subject: [PATCH] parallel : fix crash when `-n -1` --- examples/parallel/parallel.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index c7fb6d81a..790189af9 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -114,7 +114,7 @@ int main(int argc, char ** argv) { for (size_t i = 0; i < clients.size(); ++i) { auto & client = clients[i]; client.id = i; - client.tokens_prev.resize(params.n_predict); + client.tokens_prev.resize(std::max(256, params.n_predict)); std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0); } @@ -321,7 +321,8 @@ int main(int argc, char ** argv) { // client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str()); if (client.n_decoded > 2 && - (id == llama_token_eos(ctx) || client.n_decoded + client.n_prompt >= params.n_predict || + (id == llama_token_eos(ctx) || + (params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) || client.response.find("User:") != std::string::npos || client.response.find('\n') != std::string::npos)) { // basic reverse prompt