common : refactor nested if causing error C1061 on MSVC (#6101)

* Refactor nested if causing error C1061 on MSVC.

* Revert back and remove else's.

* Add flag to track found arguments.
This commit is contained in:
DAN™ 2024-03-16 11:39:15 -04:00 committed by GitHub
parent a56d09a440
commit 15961ec04d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -151,13 +151,17 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
std::replace(arg.begin(), arg.end(), '_', '-'); std::replace(arg.begin(), arg.end(), '_', '-');
} }
bool arg_found = false;
if (arg == "-s" || arg == "--seed") { if (arg == "-s" || arg == "--seed") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.seed = std::stoul(argv[i]); params.seed = std::stoul(argv[i]);
} else if (arg == "-t" || arg == "--threads") { }
if (arg == "-t" || arg == "--threads") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -166,7 +170,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
if (params.n_threads <= 0) { if (params.n_threads <= 0) {
params.n_threads = std::thread::hardware_concurrency(); params.n_threads = std::thread::hardware_concurrency();
} }
} else if (arg == "-tb" || arg == "--threads-batch") { }
if (arg == "-tb" || arg == "--threads-batch") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -175,7 +181,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
if (params.n_threads_batch <= 0) { if (params.n_threads_batch <= 0) {
params.n_threads_batch = std::thread::hardware_concurrency(); params.n_threads_batch = std::thread::hardware_concurrency();
} }
} else if (arg == "-td" || arg == "--threads-draft") { }
if (arg == "-td" || arg == "--threads-draft") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -184,7 +192,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
if (params.n_threads_draft <= 0) { if (params.n_threads_draft <= 0) {
params.n_threads_draft = std::thread::hardware_concurrency(); params.n_threads_draft = std::thread::hardware_concurrency();
} }
} else if (arg == "-tbd" || arg == "--threads-batch-draft") { }
if (arg == "-tbd" || arg == "--threads-batch-draft") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -193,25 +203,37 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
if (params.n_threads_batch_draft <= 0) { if (params.n_threads_batch_draft <= 0) {
params.n_threads_batch_draft = std::thread::hardware_concurrency(); params.n_threads_batch_draft = std::thread::hardware_concurrency();
} }
} else if (arg == "-p" || arg == "--prompt") { }
if (arg == "-p" || arg == "--prompt") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.prompt = argv[i]; params.prompt = argv[i];
} else if (arg == "-e" || arg == "--escape") { }
if (arg == "-e" || arg == "--escape") {
arg_found = true;
params.escape = true; params.escape = true;
} else if (arg == "--prompt-cache") { }
if (arg == "--prompt-cache") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.path_prompt_cache = argv[i]; params.path_prompt_cache = argv[i];
} else if (arg == "--prompt-cache-all") { }
if (arg == "--prompt-cache-all") {
arg_found = true;
params.prompt_cache_all = true; params.prompt_cache_all = true;
} else if (arg == "--prompt-cache-ro") { }
if (arg == "--prompt-cache-ro") {
arg_found = true;
params.prompt_cache_ro = true; params.prompt_cache_ro = true;
} else if (arg == "-bf" || arg == "--binary-file") { }
if (arg == "-bf" || arg == "--binary-file") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -228,7 +250,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
ss << file.rdbuf(); ss << file.rdbuf();
params.prompt = ss.str(); params.prompt = ss.str();
fprintf(stderr, "Read %zu bytes from binary file %s\n", params.prompt.size(), argv[i]); fprintf(stderr, "Read %zu bytes from binary file %s\n", params.prompt.size(), argv[i]);
} else if (arg == "-f" || arg == "--file") { }
if (arg == "-f" || arg == "--file") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -245,51 +269,67 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
if (!params.prompt.empty() && params.prompt.back() == '\n') { if (!params.prompt.empty() && params.prompt.back() == '\n') {
params.prompt.pop_back(); params.prompt.pop_back();
} }
} else if (arg == "-n" || arg == "--n-predict") { }
if (arg == "-n" || arg == "--n-predict") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_predict = std::stoi(argv[i]); params.n_predict = std::stoi(argv[i]);
} else if (arg == "--top-k") { }
if (arg == "--top-k") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.top_k = std::stoi(argv[i]); sparams.top_k = std::stoi(argv[i]);
} else if (arg == "-c" || arg == "--ctx-size") { }
if (arg == "-c" || arg == "--ctx-size") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_ctx = std::stoi(argv[i]); params.n_ctx = std::stoi(argv[i]);
} else if (arg == "--grp-attn-n" || arg == "-gan") { }
if (arg == "--grp-attn-n" || arg == "-gan") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.grp_attn_n = std::stoi(argv[i]); params.grp_attn_n = std::stoi(argv[i]);
} else if (arg == "--grp-attn-w" || arg == "-gaw") { }
if (arg == "--grp-attn-w" || arg == "-gaw") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.grp_attn_w = std::stoi(argv[i]); params.grp_attn_w = std::stoi(argv[i]);
} else if (arg == "--rope-freq-base") { }
if (arg == "--rope-freq-base") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.rope_freq_base = std::stof(argv[i]); params.rope_freq_base = std::stof(argv[i]);
} else if (arg == "--rope-freq-scale") { }
if (arg == "--rope-freq-scale") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.rope_freq_scale = std::stof(argv[i]); params.rope_freq_scale = std::stof(argv[i]);
} else if (arg == "--rope-scaling") { }
if (arg == "--rope-scaling") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -299,43 +339,57 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; } else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; } else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
else { invalid_param = true; break; } else { invalid_param = true; break; }
} else if (arg == "--rope-scale") { }
if (arg == "--rope-scale") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.rope_freq_scale = 1.0f/std::stof(argv[i]); params.rope_freq_scale = 1.0f/std::stof(argv[i]);
} else if (arg == "--yarn-orig-ctx") { }
if (arg == "--yarn-orig-ctx") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.yarn_orig_ctx = std::stoi(argv[i]); params.yarn_orig_ctx = std::stoi(argv[i]);
} else if (arg == "--yarn-ext-factor") { }
if (arg == "--yarn-ext-factor") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.yarn_ext_factor = std::stof(argv[i]); params.yarn_ext_factor = std::stof(argv[i]);
} else if (arg == "--yarn-attn-factor") { }
if (arg == "--yarn-attn-factor") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.yarn_attn_factor = std::stof(argv[i]); params.yarn_attn_factor = std::stof(argv[i]);
} else if (arg == "--yarn-beta-fast") { }
if (arg == "--yarn-beta-fast") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.yarn_beta_fast = std::stof(argv[i]); params.yarn_beta_fast = std::stof(argv[i]);
} else if (arg == "--yarn-beta-slow") { }
if (arg == "--yarn-beta-slow") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.yarn_beta_slow = std::stof(argv[i]); params.yarn_beta_slow = std::stof(argv[i]);
} else if (arg == "--pooling") { }
if (arg == "--pooling") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -345,118 +399,156 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
else { invalid_param = true; break; } else { invalid_param = true; break; }
} else if (arg == "--defrag-thold" || arg == "-dt") { }
if (arg == "--defrag-thold" || arg == "-dt") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.defrag_thold = std::stof(argv[i]); params.defrag_thold = std::stof(argv[i]);
} else if (arg == "--samplers") { }
if (arg == "--samplers") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
const auto sampler_names = string_split(argv[i], ';'); const auto sampler_names = string_split(argv[i], ';');
sparams.samplers_sequence = sampler_types_from_names(sampler_names, true); sparams.samplers_sequence = sampler_types_from_names(sampler_names, true);
} else if (arg == "--sampling-seq") { }
if (arg == "--sampling-seq") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.samplers_sequence = sampler_types_from_chars(argv[i]); sparams.samplers_sequence = sampler_types_from_chars(argv[i]);
} else if (arg == "--top-p") { }
if (arg == "--top-p") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.top_p = std::stof(argv[i]); sparams.top_p = std::stof(argv[i]);
} else if (arg == "--min-p") { }
if (arg == "--min-p") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.min_p = std::stof(argv[i]); sparams.min_p = std::stof(argv[i]);
} else if (arg == "--temp") { }
if (arg == "--temp") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.temp = std::stof(argv[i]); sparams.temp = std::stof(argv[i]);
sparams.temp = std::max(sparams.temp, 0.0f); sparams.temp = std::max(sparams.temp, 0.0f);
} else if (arg == "--tfs") { }
if (arg == "--tfs") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.tfs_z = std::stof(argv[i]); sparams.tfs_z = std::stof(argv[i]);
} else if (arg == "--typical") { }
if (arg == "--typical") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.typical_p = std::stof(argv[i]); sparams.typical_p = std::stof(argv[i]);
} else if (arg == "--repeat-last-n") { }
if (arg == "--repeat-last-n") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.penalty_last_n = std::stoi(argv[i]); sparams.penalty_last_n = std::stoi(argv[i]);
sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n); sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n);
} else if (arg == "--repeat-penalty") { }
if (arg == "--repeat-penalty") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.penalty_repeat = std::stof(argv[i]); sparams.penalty_repeat = std::stof(argv[i]);
} else if (arg == "--frequency-penalty") { }
if (arg == "--frequency-penalty") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.penalty_freq = std::stof(argv[i]); sparams.penalty_freq = std::stof(argv[i]);
} else if (arg == "--presence-penalty") { }
if (arg == "--presence-penalty") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.penalty_present = std::stof(argv[i]); sparams.penalty_present = std::stof(argv[i]);
} else if (arg == "--dynatemp-range") { }
if (arg == "--dynatemp-range") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.dynatemp_range = std::stof(argv[i]); sparams.dynatemp_range = std::stof(argv[i]);
} else if (arg == "--dynatemp-exp") { }
if (arg == "--dynatemp-exp") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.dynatemp_exponent = std::stof(argv[i]); sparams.dynatemp_exponent = std::stof(argv[i]);
} else if (arg == "--mirostat") { }
if (arg == "--mirostat") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.mirostat = std::stoi(argv[i]); sparams.mirostat = std::stoi(argv[i]);
} else if (arg == "--mirostat-lr") { }
if (arg == "--mirostat-lr") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.mirostat_eta = std::stof(argv[i]); sparams.mirostat_eta = std::stof(argv[i]);
} else if (arg == "--mirostat-ent") { }
if (arg == "--mirostat-ent") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.mirostat_tau = std::stof(argv[i]); sparams.mirostat_tau = std::stof(argv[i]);
} else if (arg == "--cfg-negative-prompt") { }
if (arg == "--cfg-negative-prompt") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.cfg_negative_prompt = argv[i]; sparams.cfg_negative_prompt = argv[i];
} else if (arg == "--cfg-negative-prompt-file") { }
if (arg == "--cfg-negative-prompt-file") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -471,86 +563,114 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') { if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') {
sparams.cfg_negative_prompt.pop_back(); sparams.cfg_negative_prompt.pop_back();
} }
} else if (arg == "--cfg-scale") { }
if (arg == "--cfg-scale") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.cfg_scale = std::stof(argv[i]); sparams.cfg_scale = std::stof(argv[i]);
} else if (arg == "-b" || arg == "--batch-size") { }
if (arg == "-b" || arg == "--batch-size") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_batch = std::stoi(argv[i]); params.n_batch = std::stoi(argv[i]);
} else if (arg == "-ub" || arg == "--ubatch-size") { }
if (arg == "-ub" || arg == "--ubatch-size") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_ubatch = std::stoi(argv[i]); params.n_ubatch = std::stoi(argv[i]);
} else if (arg == "--keep") { }
if (arg == "--keep") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_keep = std::stoi(argv[i]); params.n_keep = std::stoi(argv[i]);
} else if (arg == "--draft") { }
if (arg == "--draft") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_draft = std::stoi(argv[i]); params.n_draft = std::stoi(argv[i]);
} else if (arg == "--chunks") { }
if (arg == "--chunks") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_chunks = std::stoi(argv[i]); params.n_chunks = std::stoi(argv[i]);
} else if (arg == "-np" || arg == "--parallel") { }
if (arg == "-np" || arg == "--parallel") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_parallel = std::stoi(argv[i]); params.n_parallel = std::stoi(argv[i]);
} else if (arg == "-ns" || arg == "--sequences") { }
if (arg == "-ns" || arg == "--sequences") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_sequences = std::stoi(argv[i]); params.n_sequences = std::stoi(argv[i]);
} else if (arg == "--p-split" || arg == "-ps") { }
if (arg == "--p-split" || arg == "-ps") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.p_split = std::stof(argv[i]); params.p_split = std::stof(argv[i]);
} else if (arg == "-m" || arg == "--model") { }
if (arg == "-m" || arg == "--model") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.model = argv[i]; params.model = argv[i];
} else if (arg == "-md" || arg == "--model-draft") { }
if (arg == "-md" || arg == "--model-draft") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.model_draft = argv[i]; params.model_draft = argv[i];
} else if (arg == "-a" || arg == "--alias") { }
if (arg == "-a" || arg == "--alias") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.model_alias = argv[i]; params.model_alias = argv[i];
} else if (arg == "--lora") { }
if (arg == "--lora") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.lora_adapter.emplace_back(argv[i], 1.0f); params.lora_adapter.emplace_back(argv[i], 1.0f);
params.use_mmap = false; params.use_mmap = false;
} else if (arg == "--lora-scaled") { }
if (arg == "--lora-scaled") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -562,19 +682,25 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
} }
params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i])); params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i]));
params.use_mmap = false; params.use_mmap = false;
} else if (arg == "--lora-base") { }
if (arg == "--lora-base") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.lora_base = argv[i]; params.lora_base = argv[i];
} else if (arg == "--control-vector") { }
if (arg == "--control-vector") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.control_vectors.push_back({ 1.0f, argv[i], }); params.control_vectors.push_back({ 1.0f, argv[i], });
} else if (arg == "--control-vector-scaled") { }
if (arg == "--control-vector-scaled") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -585,7 +711,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break; break;
} }
params.control_vectors.push_back({ std::stof(argv[i]), fname, }); params.control_vectors.push_back({ std::stof(argv[i]), fname, });
} else if (arg == "--control-vector-layer-range") { }
if (arg == "--control-vector-layer-range") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -596,49 +724,85 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break; break;
} }
params.control_vector_layer_end = std::stoi(argv[i]); params.control_vector_layer_end = std::stoi(argv[i]);
} else if (arg == "--mmproj") { }
if (arg == "--mmproj") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.mmproj = argv[i]; params.mmproj = argv[i];
} else if (arg == "--image") { }
if (arg == "--image") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.image = argv[i]; params.image = argv[i];
} else if (arg == "-i" || arg == "--interactive") { }
if (arg == "-i" || arg == "--interactive") {
arg_found = true;
params.interactive = true; params.interactive = true;
} else if (arg == "--embedding") { }
if (arg == "--embedding") {
arg_found = true;
params.embedding = true; params.embedding = true;
} else if (arg == "--interactive-first") { }
if (arg == "--interactive-first") {
arg_found = true;
params.interactive_first = true; params.interactive_first = true;
} else if (arg == "-ins" || arg == "--instruct") { }
if (arg == "-ins" || arg == "--instruct") {
arg_found = true;
params.instruct = true; params.instruct = true;
} else if (arg == "-cml" || arg == "--chatml") { }
if (arg == "-cml" || arg == "--chatml") {
arg_found = true;
params.chatml = true; params.chatml = true;
} else if (arg == "--infill") { }
if (arg == "--infill") {
arg_found = true;
params.infill = true; params.infill = true;
} else if (arg == "-dkvc" || arg == "--dump-kv-cache") { }
if (arg == "-dkvc" || arg == "--dump-kv-cache") {
arg_found = true;
params.dump_kv_cache = true; params.dump_kv_cache = true;
} else if (arg == "-nkvo" || arg == "--no-kv-offload") { }
if (arg == "-nkvo" || arg == "--no-kv-offload") {
arg_found = true;
params.no_kv_offload = true; params.no_kv_offload = true;
} else if (arg == "-ctk" || arg == "--cache-type-k") { }
if (arg == "-ctk" || arg == "--cache-type-k") {
arg_found = true;
params.cache_type_k = argv[++i]; params.cache_type_k = argv[++i];
} else if (arg == "-ctv" || arg == "--cache-type-v") { }
if (arg == "-ctv" || arg == "--cache-type-v") {
arg_found = true;
params.cache_type_v = argv[++i]; params.cache_type_v = argv[++i];
} else if (arg == "--multiline-input") { }
if (arg == "--multiline-input") {
arg_found = true;
params.multiline_input = true; params.multiline_input = true;
} else if (arg == "--simple-io") { }
if (arg == "--simple-io") {
arg_found = true;
params.simple_io = true; params.simple_io = true;
} else if (arg == "-cb" || arg == "--cont-batching") { }
if (arg == "-cb" || arg == "--cont-batching") {
arg_found = true;
params.cont_batching = true; params.cont_batching = true;
} else if (arg == "--color") { }
if (arg == "--color") {
arg_found = true;
params.use_color = true; params.use_color = true;
} else if (arg == "--mlock") { }
if (arg == "--mlock") {
arg_found = true;
params.use_mlock = true; params.use_mlock = true;
} else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") { }
if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -648,7 +812,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n"); fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
} }
} else if (arg == "--gpu-layers-draft" || arg == "-ngld" || arg == "--n-gpu-layers-draft") { }
if (arg == "--gpu-layers-draft" || arg == "-ngld" || arg == "--n-gpu-layers-draft") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -658,7 +824,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n"); fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
} }
} else if (arg == "--main-gpu" || arg == "-mg") { }
if (arg == "--main-gpu" || arg == "-mg") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -667,7 +835,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
#ifndef GGML_USE_CUBLAS_SYCL #ifndef GGML_USE_CUBLAS_SYCL
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS/SYCL. Setting the main GPU has no effect.\n"); fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS/SYCL. Setting the main GPU has no effect.\n");
#endif // GGML_USE_CUBLAS_SYCL #endif // GGML_USE_CUBLAS_SYCL
} else if (arg == "--split-mode" || arg == "-sm") { }
if (arg == "--split-mode" || arg == "-sm") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -691,7 +861,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS/SYCL. Setting the split mode has no effect.\n"); fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS/SYCL. Setting the split mode has no effect.\n");
#endif // GGML_USE_CUBLAS_SYCL #endif // GGML_USE_CUBLAS_SYCL
} else if (arg == "--tensor-split" || arg == "-ts") { }
if (arg == "--tensor-split" || arg == "-ts") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -716,9 +888,13 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
#ifndef GGML_USE_CUBLAS_SYCL_VULKAN #ifndef GGML_USE_CUBLAS_SYCL_VULKAN
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS/SYCL/Vulkan. Setting a tensor split has no effect.\n"); fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS/SYCL/Vulkan. Setting a tensor split has no effect.\n");
#endif // GGML_USE_CUBLAS_SYCL #endif // GGML_USE_CUBLAS_SYCL
} else if (arg == "--no-mmap") { }
if (arg == "--no-mmap") {
arg_found = true;
params.use_mmap = false; params.use_mmap = false;
} else if (arg == "--numa") { }
if (arg == "--numa") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -728,17 +904,25 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; } else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; }
else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; } else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; }
else { invalid_param = true; break; } else { invalid_param = true; break; }
} else if (arg == "--verbose-prompt") { }
if (arg == "--verbose-prompt") {
arg_found = true;
params.verbose_prompt = true; params.verbose_prompt = true;
} else if (arg == "--no-display-prompt") { }
if (arg == "--no-display-prompt") {
arg_found = true;
params.display_prompt = false; params.display_prompt = false;
} else if (arg == "-r" || arg == "--reverse-prompt") { }
if (arg == "-r" || arg == "--reverse-prompt") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.antiprompt.emplace_back(argv[i]); params.antiprompt.emplace_back(argv[i]);
} else if (arg == "-ld" || arg == "--logdir") { }
if (arg == "-ld" || arg == "--logdir") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -748,63 +932,93 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
if (params.logdir.back() != DIRECTORY_SEPARATOR) { if (params.logdir.back() != DIRECTORY_SEPARATOR) {
params.logdir += DIRECTORY_SEPARATOR; params.logdir += DIRECTORY_SEPARATOR;
} }
} else if (arg == "--save-all-logits" || arg == "--kl-divergence-base") { }
if (arg == "--save-all-logits" || arg == "--kl-divergence-base") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.logits_file = argv[i]; params.logits_file = argv[i];
} else if (arg == "--perplexity" || arg == "--all-logits") { }
if (arg == "--perplexity" || arg == "--all-logits") {
arg_found = true;
params.logits_all = true; params.logits_all = true;
} else if (arg == "--ppl-stride") { }
if (arg == "--ppl-stride") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.ppl_stride = std::stoi(argv[i]); params.ppl_stride = std::stoi(argv[i]);
} else if (arg == "-ptc" || arg == "--print-token-count") { }
if (arg == "-ptc" || arg == "--print-token-count") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_print = std::stoi(argv[i]); params.n_print = std::stoi(argv[i]);
} else if (arg == "--ppl-output-type") { }
if (arg == "--ppl-output-type") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.ppl_output_type = std::stoi(argv[i]); params.ppl_output_type = std::stoi(argv[i]);
} else if (arg == "--hellaswag") { }
if (arg == "--hellaswag") {
arg_found = true;
params.hellaswag = true; params.hellaswag = true;
} else if (arg == "--hellaswag-tasks") { }
if (arg == "--hellaswag-tasks") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.hellaswag_tasks = std::stoi(argv[i]); params.hellaswag_tasks = std::stoi(argv[i]);
} else if (arg == "--winogrande") { }
if (arg == "--winogrande") {
arg_found = true;
params.winogrande = true; params.winogrande = true;
} else if (arg == "--winogrande-tasks") { }
if (arg == "--winogrande-tasks") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.winogrande_tasks = std::stoi(argv[i]); params.winogrande_tasks = std::stoi(argv[i]);
} else if (arg == "--multiple-choice") { }
if (arg == "--multiple-choice") {
arg_found = true;
params.multiple_choice = true; params.multiple_choice = true;
} else if (arg == "--multiple-choice-tasks") { }
if (arg == "--multiple-choice-tasks") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.multiple_choice_tasks = std::stoi(argv[i]); params.multiple_choice_tasks = std::stoi(argv[i]);
} else if (arg == "--kl-divergence") { }
if (arg == "--kl-divergence") {
arg_found = true;
params.kl_divergence = true; params.kl_divergence = true;
} else if (arg == "--ignore-eos") { }
if (arg == "--ignore-eos") {
arg_found = true;
params.ignore_eos = true; params.ignore_eos = true;
} else if (arg == "--no-penalize-nl") { }
if (arg == "--no-penalize-nl") {
arg_found = true;
sparams.penalize_nl = false; sparams.penalize_nl = false;
} else if (arg == "-l" || arg == "--logit-bias") { }
if (arg == "-l" || arg == "--logit-bias") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -823,36 +1037,51 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
invalid_param = true; invalid_param = true;
break; break;
} }
} else if (arg == "-h" || arg == "--help") { }
if (arg == "-h" || arg == "--help") {
arg_found = true;
return false; return false;
}
} else if (arg == "--version") { if (arg == "--version") {
arg_found = true;
fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT); fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET); fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET);
exit(0); exit(0);
} else if (arg == "--random-prompt") { }
if (arg == "--random-prompt") {
arg_found = true;
params.random_prompt = true; params.random_prompt = true;
} else if (arg == "--in-prefix-bos") { }
if (arg == "--in-prefix-bos") {
arg_found = true;
params.input_prefix_bos = true; params.input_prefix_bos = true;
} else if (arg == "--in-prefix") { }
if (arg == "--in-prefix") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.input_prefix = argv[i]; params.input_prefix = argv[i];
} else if (arg == "--in-suffix") { }
if (arg == "--in-suffix") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.input_suffix = argv[i]; params.input_suffix = argv[i];
} else if (arg == "--grammar") { }
if (arg == "--grammar") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.grammar = argv[i]; sparams.grammar = argv[i];
} else if (arg == "--grammar-file") { }
if (arg == "--grammar-file") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -868,7 +1097,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
std::istreambuf_iterator<char>(), std::istreambuf_iterator<char>(),
std::back_inserter(sparams.grammar) std::back_inserter(sparams.grammar)
); );
} else if (arg == "--override-kv") { }
if (arg == "--override-kv") {
arg_found = true;
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -911,10 +1142,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
params.kv_overrides.push_back(kvo); params.kv_overrides.push_back(kvo);
#ifndef LOG_DISABLE_LOGS #ifndef LOG_DISABLE_LOGS
// Parse args for logging parameters // Parse args for logging parameters
} else if ( log_param_single_parse( argv[i] ) ) { }
if ( log_param_single_parse( argv[i] ) ) {
arg_found = true;
// Do nothing, log_param_single_parse automatically does it's thing // Do nothing, log_param_single_parse automatically does it's thing
// and returns if a match was found and parsed. // and returns if a match was found and parsed.
} else if ( log_param_pair_parse( /*check_but_dont_parse*/ true, argv[i] ) ) { }
if ( log_param_pair_parse( /*check_but_dont_parse*/ true, argv[i] ) ) {
arg_found = true;
// We have a matching known parameter requiring an argument, // We have a matching known parameter requiring an argument,
// now we need to check if there is anything after this argv // now we need to check if there is anything after this argv
// and flag invalid_param or parse it. // and flag invalid_param or parse it.
@ -928,7 +1163,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
} }
// End of Parse args for logging parameters // End of Parse args for logging parameters
#endif // LOG_DISABLE_LOGS #endif // LOG_DISABLE_LOGS
} else { }
if (!arg_found) {
throw std::invalid_argument("error: unknown argument: " + arg); throw std::invalid_argument("error: unknown argument: " + arg);
} }
} }