mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-27 04:23:06 +01:00
Enhance user input handling for llama-run (#11138)
The main motivation for this change is it was not handing ctrl-c/ctrl-d correctly. Modify `read_user_input` to handle EOF, "/bye" command, and empty input cases. Introduce `get_user_input` function to manage user input loop and handle different return cases. Signed-off-by: Eric Curtin <ecurtin@redhat.com>
This commit is contained in:
parent
f7cd13301c
commit
1bf839b1e8
@ -11,6 +11,8 @@
|
|||||||
# include <curl/curl.h>
|
# include <curl/curl.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#include <signal.h>
|
||||||
|
|
||||||
#include <climits>
|
#include <climits>
|
||||||
#include <cstdarg>
|
#include <cstdarg>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
@ -25,6 +27,13 @@
|
|||||||
#include "json.hpp"
|
#include "json.hpp"
|
||||||
#include "llama-cpp.h"
|
#include "llama-cpp.h"
|
||||||
|
|
||||||
|
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32)
|
||||||
|
[[noreturn]] static void sigint_handler(int) {
|
||||||
|
printf("\n");
|
||||||
|
exit(0); // not ideal, but it's the only way to guarantee exit in all cases
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
GGML_ATTRIBUTE_FORMAT(1, 2)
|
GGML_ATTRIBUTE_FORMAT(1, 2)
|
||||||
static std::string fmt(const char * fmt, ...) {
|
static std::string fmt(const char * fmt, ...) {
|
||||||
va_list ap;
|
va_list ap;
|
||||||
@ -801,7 +810,20 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
|
|||||||
|
|
||||||
static int read_user_input(std::string & user) {
|
static int read_user_input(std::string & user) {
|
||||||
std::getline(std::cin, user);
|
std::getline(std::cin, user);
|
||||||
return user.empty(); // Should have data in happy path
|
if (std::cin.eof()) {
|
||||||
|
printf("\n");
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (user == "/bye") {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (user.empty()) {
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0; // Should have data in happy path
|
||||||
}
|
}
|
||||||
|
|
||||||
// Function to generate a response based on the prompt
|
// Function to generate a response based on the prompt
|
||||||
@ -868,7 +890,25 @@ static bool is_stdout_a_terminal() {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// Function to tokenize the prompt
|
// Function to handle user input
|
||||||
|
static int get_user_input(std::string & user_input, const std::string & user) {
|
||||||
|
while (true) {
|
||||||
|
const int ret = handle_user_input(user_input, user);
|
||||||
|
if (ret == 1) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ret == 2) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Main chat loop function
|
||||||
static int chat_loop(LlamaData & llama_data, const std::string & user) {
|
static int chat_loop(LlamaData & llama_data, const std::string & user) {
|
||||||
int prev_len = 0;
|
int prev_len = 0;
|
||||||
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
|
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
|
||||||
@ -876,7 +916,8 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
|
|||||||
while (true) {
|
while (true) {
|
||||||
// Get user input
|
// Get user input
|
||||||
std::string user_input;
|
std::string user_input;
|
||||||
while (handle_user_input(user_input, user)) {
|
if (get_user_input(user_input, user) == 1) {
|
||||||
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
add_message("user", user.empty() ? user_input : user, llama_data);
|
add_message("user", user.empty() ? user_input : user, llama_data);
|
||||||
@ -917,7 +958,23 @@ static std::string read_pipe_data() {
|
|||||||
return result.str();
|
return result.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ctrl_c_handling() {
|
||||||
|
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__))
|
||||||
|
struct sigaction sigint_action;
|
||||||
|
sigint_action.sa_handler = sigint_handler;
|
||||||
|
sigemptyset(&sigint_action.sa_mask);
|
||||||
|
sigint_action.sa_flags = 0;
|
||||||
|
sigaction(SIGINT, &sigint_action, NULL);
|
||||||
|
#elif defined(_WIN32)
|
||||||
|
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
|
||||||
|
return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
|
||||||
|
};
|
||||||
|
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
int main(int argc, const char ** argv) {
|
int main(int argc, const char ** argv) {
|
||||||
|
ctrl_c_handling();
|
||||||
Opt opt;
|
Opt opt;
|
||||||
const int ret = opt.init(argc, argv);
|
const int ret = opt.init(argc, argv);
|
||||||
if (ret == 2) {
|
if (ret == 2) {
|
||||||
|
Loading…
Reference in New Issue
Block a user