main : make reverse prompt option act as a stop token in non-interactive mode (#1032)

* Make reverse prompt option act as a stop token in non-interactive scenarios

* Making requested review changes

* Update gpt_params_parse and fix a merge error

* Revert "Update gpt_params_parse and fix a merge error"

This reverts commit 2bb2ff1748513591ad45b175a75ed1d8089d84c8.

* Update gpt_params_parse and fix a merge error take 2
This commit is contained in:
Jason McCartney 2023-05-19 10:24:59 -07:00 committed by GitHub
parent 79e3efb0e9
commit 7694b52b9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 11 deletions

View File

@ -351,7 +351,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} }
if (params.prompt_cache_all && if (params.prompt_cache_all &&
(params.interactive || params.interactive_first || (params.interactive || params.interactive_first ||
params.instruct || params.antiprompt.size())) { params.instruct)) {
fprintf(stderr, "error: --prompt-cache-all not supported in interactive mode yet\n"); fprintf(stderr, "error: --prompt-cache-all not supported in interactive mode yet\n");
gpt_print_usage(argc, argv, default_params); gpt_print_usage(argc, argv, default_params);
exit(1); exit(1);
@ -373,8 +373,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n"); fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n");
fprintf(stderr, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); fprintf(stderr, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n");
fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n"); fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n");
fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n"); fprintf(stderr, " halt generation at PROMPT, return control in interactive mode\n");
fprintf(stderr, " specified more than once for multiple prompts).\n"); fprintf(stderr, " (can be specified more than once for multiple prompts).\n");
fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n"); fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n"); fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);

View File

@ -208,8 +208,8 @@ int main(int argc, char ** argv) {
params.antiprompt.push_back("### Instruction:\n\n"); params.antiprompt.push_back("### Instruction:\n\n");
} }
// enable interactive mode if reverse prompt or interactive start is specified // enable interactive mode if interactive start is specified
if (params.antiprompt.size() != 0 || params.interactive_first) { if (params.interactive_first) {
params.interactive = true; params.interactive = true;
} }
@ -305,7 +305,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd; std::vector<llama_token> embd;
while (n_remain != 0 || params.interactive) { while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict // predict
if (embd.size() > 0) { if (embd.size() > 0) {
// infinite text generation via context swapping // infinite text generation via context swapping
@ -503,9 +503,8 @@ int main(int argc, char ** argv) {
console_set_color(con_st, CONSOLE_COLOR_DEFAULT); console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
} }
// in interactive mode, and not currently processing queued inputs; // if not currently processing queued inputs;
// check if we should prompt the user for more if ((int) embd_inp.size() <= n_consumed) {
if (params.interactive && (int) embd_inp.size() <= n_consumed) {
// check for reverse prompt // check for reverse prompt
if (params.antiprompt.size()) { if (params.antiprompt.size()) {
@ -516,10 +515,21 @@ int main(int argc, char ** argv) {
is_antiprompt = false; is_antiprompt = false;
// Check if each of the reverse prompts appears at the end of the output. // Check if each of the reverse prompts appears at the end of the output.
// If we're not running interactively, the reverse prompt might be tokenized with some following characters
// so we'll compensate for that by widening the search window a bit.
for (std::string & antiprompt : params.antiprompt) { for (std::string & antiprompt : params.antiprompt) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { size_t extra_padding = params.interactive ? 0 : 2;
is_interacting = true; size_t search_start_pos = last_output.length() > static_cast<size_t>(antiprompt.length() + extra_padding)
? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding)
: 0;
if (last_output.find(antiprompt.c_str(), search_start_pos) != std::string::npos) {
if (params.interactive) {
is_interacting = true;
console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
}
is_antiprompt = true; is_antiprompt = true;
fflush(stdout);
break; break;
} }
} }