mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 06:39:25 +01:00
llama-run : include temperature option (#10899)
This commit updates the `examples/run/README.md` file to include a new option for setting the temperature and updates the `run.cpp` file to parse this option. Signed-off-by: Eric Curtin <ecurtin@redhat.com>
This commit is contained in:
parent
7024d59e6a
commit
dab76c92cc
@ -19,6 +19,8 @@ Options:
|
|||||||
Context size (default: 2048)
|
Context size (default: 2048)
|
||||||
-n, --ngl <value>
|
-n, --ngl <value>
|
||||||
Number of GPU layers (default: 0)
|
Number of GPU layers (default: 0)
|
||||||
|
--temp <value>
|
||||||
|
Temperature (default: 0.8)
|
||||||
-v, --verbose, --log-verbose
|
-v, --verbose, --log-verbose
|
||||||
Set verbosity level to infinity (i.e. log all messages, useful for debugging)
|
Set verbosity level to infinity (i.e. log all messages, useful for debugging)
|
||||||
-h, --help
|
-h, --help
|
||||||
|
@ -55,29 +55,51 @@ static int printe(const char * fmt, ...) {
|
|||||||
class Opt {
|
class Opt {
|
||||||
public:
|
public:
|
||||||
int init(int argc, const char ** argv) {
|
int init(int argc, const char ** argv) {
|
||||||
|
ctx_params = llama_context_default_params();
|
||||||
|
model_params = llama_model_default_params();
|
||||||
|
context_size_default = ctx_params.n_batch;
|
||||||
|
ngl_default = model_params.n_gpu_layers;
|
||||||
|
common_params_sampling sampling;
|
||||||
|
temperature_default = sampling.temp;
|
||||||
|
|
||||||
|
if (argc < 2) {
|
||||||
|
printe("Error: No arguments provided.\n");
|
||||||
|
print_help();
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
// Parse arguments
|
// Parse arguments
|
||||||
if (parse(argc, argv)) {
|
if (parse(argc, argv)) {
|
||||||
printe("Error: Failed to parse arguments.\n");
|
printe("Error: Failed to parse arguments.\n");
|
||||||
help();
|
print_help();
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If help is requested, show help and exit
|
// If help is requested, show help and exit
|
||||||
if (help_) {
|
if (help) {
|
||||||
help();
|
print_help();
|
||||||
return 2;
|
return 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx_params.n_batch = context_size >= 0 ? context_size : context_size_default;
|
||||||
|
model_params.n_gpu_layers = ngl >= 0 ? ngl : ngl_default;
|
||||||
|
temperature = temperature >= 0 ? temperature : temperature_default;
|
||||||
|
|
||||||
return 0; // Success
|
return 0; // Success
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_context_params ctx_params;
|
||||||
|
llama_model_params model_params;
|
||||||
std::string model_;
|
std::string model_;
|
||||||
std::string user_;
|
std::string user;
|
||||||
int context_size_ = -1, ngl_ = -1;
|
int context_size = -1, ngl = -1;
|
||||||
bool verbose_ = false;
|
float temperature = -1;
|
||||||
|
bool verbose = false;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool help_ = false;
|
int context_size_default = -1, ngl_default = -1;
|
||||||
|
float temperature_default = -1;
|
||||||
|
bool help = false;
|
||||||
|
|
||||||
bool parse_flag(const char ** argv, int i, const char * short_opt, const char * long_opt) {
|
bool parse_flag(const char ** argv, int i, const char * short_opt, const char * long_opt) {
|
||||||
return strcmp(argv[i], short_opt) == 0 || strcmp(argv[i], long_opt) == 0;
|
return strcmp(argv[i], short_opt) == 0 || strcmp(argv[i], long_opt) == 0;
|
||||||
@ -89,6 +111,17 @@ class Opt {
|
|||||||
}
|
}
|
||||||
|
|
||||||
option_value = std::atoi(argv[++i]);
|
option_value = std::atoi(argv[++i]);
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int handle_option_with_value(int argc, const char ** argv, int & i, float & option_value) {
|
||||||
|
if (i + 1 >= argc) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
option_value = std::atof(argv[++i]);
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -96,18 +129,22 @@ class Opt {
|
|||||||
bool options_parsing = true;
|
bool options_parsing = true;
|
||||||
for (int i = 1, positional_args_i = 0; i < argc; ++i) {
|
for (int i = 1, positional_args_i = 0; i < argc; ++i) {
|
||||||
if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
|
if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
|
||||||
if (handle_option_with_value(argc, argv, i, context_size_) == 1) {
|
if (handle_option_with_value(argc, argv, i, context_size) == 1) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
} else if (options_parsing && (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0)) {
|
} else if (options_parsing && (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0)) {
|
||||||
if (handle_option_with_value(argc, argv, i, ngl_) == 1) {
|
if (handle_option_with_value(argc, argv, i, ngl) == 1) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
} else if (options_parsing && strcmp(argv[i], "--temp") == 0) {
|
||||||
|
if (handle_option_with_value(argc, argv, i, temperature) == 1) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
} else if (options_parsing &&
|
} else if (options_parsing &&
|
||||||
(parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) {
|
(parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) {
|
||||||
verbose_ = true;
|
verbose = true;
|
||||||
} else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
|
} else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
|
||||||
help_ = true;
|
help = true;
|
||||||
return 0;
|
return 0;
|
||||||
} else if (options_parsing && strcmp(argv[i], "--") == 0) {
|
} else if (options_parsing && strcmp(argv[i], "--") == 0) {
|
||||||
options_parsing = false;
|
options_parsing = false;
|
||||||
@ -120,16 +157,16 @@ class Opt {
|
|||||||
model_ = argv[i];
|
model_ = argv[i];
|
||||||
} else if (positional_args_i == 1) {
|
} else if (positional_args_i == 1) {
|
||||||
++positional_args_i;
|
++positional_args_i;
|
||||||
user_ = argv[i];
|
user = argv[i];
|
||||||
} else {
|
} else {
|
||||||
user_ += " " + std::string(argv[i]);
|
user += " " + std::string(argv[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
void help() const {
|
void print_help() const {
|
||||||
printf(
|
printf(
|
||||||
"Description:\n"
|
"Description:\n"
|
||||||
" Runs a llm\n"
|
" Runs a llm\n"
|
||||||
@ -142,6 +179,8 @@ class Opt {
|
|||||||
" Context size (default: %d)\n"
|
" Context size (default: %d)\n"
|
||||||
" -n, --ngl <value>\n"
|
" -n, --ngl <value>\n"
|
||||||
" Number of GPU layers (default: %d)\n"
|
" Number of GPU layers (default: %d)\n"
|
||||||
|
" --temp <value>\n"
|
||||||
|
" Temperature (default: %.1f)\n"
|
||||||
" -v, --verbose, --log-verbose\n"
|
" -v, --verbose, --log-verbose\n"
|
||||||
" Set verbosity level to infinity (i.e. log all messages, useful for debugging)\n"
|
" Set verbosity level to infinity (i.e. log all messages, useful for debugging)\n"
|
||||||
" -h, --help\n"
|
" -h, --help\n"
|
||||||
@ -170,7 +209,7 @@ class Opt {
|
|||||||
" llama-run file://some-file3.gguf\n"
|
" llama-run file://some-file3.gguf\n"
|
||||||
" llama-run --ngl 999 some-file4.gguf\n"
|
" llama-run --ngl 999 some-file4.gguf\n"
|
||||||
" llama-run --ngl 999 some-file5.gguf Hello World\n",
|
" llama-run --ngl 999 some-file5.gguf Hello World\n",
|
||||||
llama_context_default_params().n_batch, llama_model_default_params().n_gpu_layers);
|
context_size_default, ngl_default, temperature_default);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -495,12 +534,12 @@ class LlamaData {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
context = initialize_context(model, opt.context_size_);
|
context = initialize_context(model, opt);
|
||||||
if (!context) {
|
if (!context) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
sampler = initialize_sampler();
|
sampler = initialize_sampler(opt);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -619,14 +658,12 @@ class LlamaData {
|
|||||||
// Initializes the model and returns a unique pointer to it
|
// Initializes the model and returns a unique pointer to it
|
||||||
llama_model_ptr initialize_model(Opt & opt) {
|
llama_model_ptr initialize_model(Opt & opt) {
|
||||||
ggml_backend_load_all();
|
ggml_backend_load_all();
|
||||||
llama_model_params model_params = llama_model_default_params();
|
|
||||||
model_params.n_gpu_layers = opt.ngl_ >= 0 ? opt.ngl_ : model_params.n_gpu_layers;
|
|
||||||
resolve_model(opt.model_);
|
resolve_model(opt.model_);
|
||||||
printe(
|
printe(
|
||||||
"\r%*s"
|
"\r%*s"
|
||||||
"\rLoading model",
|
"\rLoading model",
|
||||||
get_terminal_width(), " ");
|
get_terminal_width(), " ");
|
||||||
llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), model_params));
|
llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), opt.model_params));
|
||||||
if (!model) {
|
if (!model) {
|
||||||
printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str());
|
printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str());
|
||||||
}
|
}
|
||||||
@ -636,10 +673,8 @@ class LlamaData {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Initializes the context with the specified parameters
|
// Initializes the context with the specified parameters
|
||||||
llama_context_ptr initialize_context(const llama_model_ptr & model, const int n_ctx) {
|
llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) {
|
||||||
llama_context_params ctx_params = llama_context_default_params();
|
llama_context_ptr context(llama_new_context_with_model(model.get(), opt.ctx_params));
|
||||||
ctx_params.n_ctx = ctx_params.n_batch = n_ctx >= 0 ? n_ctx : ctx_params.n_batch;
|
|
||||||
llama_context_ptr context(llama_new_context_with_model(model.get(), ctx_params));
|
|
||||||
if (!context) {
|
if (!context) {
|
||||||
printe("%s: error: failed to create the llama_context\n", __func__);
|
printe("%s: error: failed to create the llama_context\n", __func__);
|
||||||
}
|
}
|
||||||
@ -648,10 +683,10 @@ class LlamaData {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Initializes and configures the sampler
|
// Initializes and configures the sampler
|
||||||
llama_sampler_ptr initialize_sampler() {
|
llama_sampler_ptr initialize_sampler(const Opt & opt) {
|
||||||
llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params()));
|
llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params()));
|
||||||
llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1));
|
llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1));
|
||||||
llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(0.8f));
|
llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(opt.temperature));
|
||||||
llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
|
llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
|
||||||
|
|
||||||
return sampler;
|
return sampler;
|
||||||
@ -798,9 +833,9 @@ static int apply_chat_template_with_error_handling(LlamaData & llama_data, const
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to handle user input
|
// Helper function to handle user input
|
||||||
static int handle_user_input(std::string & user_input, const std::string & user_) {
|
static int handle_user_input(std::string & user_input, const std::string & user) {
|
||||||
if (!user_.empty()) {
|
if (!user.empty()) {
|
||||||
user_input = user_;
|
user_input = user;
|
||||||
return 0; // No need for interactive input
|
return 0; // No need for interactive input
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -832,17 +867,17 @@ static bool is_stdout_a_terminal() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Function to tokenize the prompt
|
// Function to tokenize the prompt
|
||||||
static int chat_loop(LlamaData & llama_data, const std::string & user_) {
|
static int chat_loop(LlamaData & llama_data, const std::string & user) {
|
||||||
int prev_len = 0;
|
int prev_len = 0;
|
||||||
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
|
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
|
||||||
static const bool stdout_a_terminal = is_stdout_a_terminal();
|
static const bool stdout_a_terminal = is_stdout_a_terminal();
|
||||||
while (true) {
|
while (true) {
|
||||||
// Get user input
|
// Get user input
|
||||||
std::string user_input;
|
std::string user_input;
|
||||||
while (handle_user_input(user_input, user_)) {
|
while (handle_user_input(user_input, user)) {
|
||||||
}
|
}
|
||||||
|
|
||||||
add_message("user", user_.empty() ? user_input : user_, llama_data);
|
add_message("user", user.empty() ? user_input : user, llama_data);
|
||||||
int new_len;
|
int new_len;
|
||||||
if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) {
|
if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) {
|
||||||
return 1;
|
return 1;
|
||||||
@ -854,7 +889,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user_) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!user_.empty()) {
|
if (!user.empty()) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -869,7 +904,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user_) {
|
|||||||
|
|
||||||
static void log_callback(const enum ggml_log_level level, const char * text, void * p) {
|
static void log_callback(const enum ggml_log_level level, const char * text, void * p) {
|
||||||
const Opt * opt = static_cast<Opt *>(p);
|
const Opt * opt = static_cast<Opt *>(p);
|
||||||
if (opt->verbose_ || level == GGML_LOG_LEVEL_ERROR) {
|
if (opt->verbose || level == GGML_LOG_LEVEL_ERROR) {
|
||||||
printe("%s", text);
|
printe("%s", text);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -890,11 +925,11 @@ int main(int argc, const char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (!is_stdin_a_terminal()) {
|
if (!is_stdin_a_terminal()) {
|
||||||
if (!opt.user_.empty()) {
|
if (!opt.user.empty()) {
|
||||||
opt.user_ += "\n\n";
|
opt.user += "\n\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
opt.user_ += read_pipe_data();
|
opt.user += read_pipe_data();
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_log_set(log_callback, &opt);
|
llama_log_set(log_callback, &opt);
|
||||||
@ -903,7 +938,7 @@ int main(int argc, const char ** argv) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (chat_loop(llama_data, opt.user_)) {
|
if (chat_loop(llama_data, opt.user)) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user