diff --git a/examples/batched-bench/README.md b/examples/batched-bench/README.md index fa98bf24e..28dbbdca9 100644 --- a/examples/batched-bench/README.md +++ b/examples/batched-bench/README.md @@ -10,13 +10,16 @@ There are 2 modes of operation: - `prompt is shared` - there is a common prompt of size `PP` used by all batches (i.e. `N_KV = PP + B*TG`) ```bash -./batched-bench MODEL_PATH [N_KV_MAX] [IS_PP_SHARED] [NGL] +./batched-bench MODEL_PATH [N_KV_MAX] [IS_PP_SHARED] [NGL] # LLaMA 7B, F16, N_KV_MAX = 16384 (8GB), prompt not shared ./batched-bench ./models/llama-7b/ggml-model-f16.gguf 16384 0 99 # LLaMA 7B, Q8_0, N_KV_MAX = 16384 (8GB), prompt is shared ./batched-bench ./models/llama-7b/ggml-model-q8_0.gguf 16384 1 99 + +# custom set of batches +./batched-bench ./models/llama-7b/ggml-model-q8_0.gguf 2048 0 999 128,256,512 128,256 1,2,4,8,16,32 ``` ## Sample results diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index b2ffdd987..675651316 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -7,11 +7,34 @@ #include #include +// mutates the input string +static std::vector parse_list(char * p) { + std::vector ret; + + char * q = p; + + while (*p) { + if (*p == ',') { + *p = '\0'; + ret.push_back(std::atoi(q)); + q = p + 1; + } + + ++p; + } + + ret.push_back(std::atoi(q)); + + return ret; +} + int main(int argc, char ** argv) { gpt_params params; if (argc == 1 || argv[1][0] == '-') { - printf("usage: %s MODEL_PATH [N_KV_MAX] [IS_PP_SHARED] [NGL]\n" , argv[0]); + printf("usage: %s MODEL_PATH [N_KV_MAX] [IS_PP_SHARED] [NGL] \n" , argv[0]); + printf(" , and PL are comma-separated lists of numbers without spaces\n\n"); + printf(" example: %s ggml-model-f16.gguf 2048 0 999 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]); return 1 ; } @@ -40,6 +63,18 @@ int main(int argc, char ** argv) { n_gpu_layers = std::atoi(argv[4]); } + if (argc >= 6) { + n_pp = parse_list(argv[5]); + } + + if (argc >= 7) { + n_tg = parse_list(argv[6]); + } + + if (argc >= 8) { + n_pl = parse_list(argv[7]); + } + // init LLM llama_backend_init(params.numa);