batched : add len CLI argument

This commit is contained in:
Georgi Gerganov 2023-10-22 08:37:20 +03:00
parent 465219b914
commit 22c69a2794
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@@ -11,12 +11,16 @@ int main(int argc, char ** argv) {
     gpt_params params;
 
     if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL]\n" , argv[0]);
+        printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN]\n" , argv[0]);
         return 1 ;
     }
 
+    // number of parallel batches
     int n_parallel = 1;
 
+    // total length of the sequences including the prompt
+    int n_len = 32;
+
     if (argc >= 2) {
         params.model = argv[1];
     }
@@ -29,13 +33,14 @@ int main(int argc, char ** argv) {
         n_parallel = std::atoi(argv[3]);
     }
 
+    if (argc >= 5) {
+        n_len = std::atoi(argv[4]);
+    }
+
     if (params.prompt.empty()) {
         params.prompt = "Hello my name is";
     }
 
-    // total length of the sequences including the prompt
-    const int n_len = 32;
-
     // init LLM
 
     llama_backend_init(params.numa);