diff --git a/examples/common.cpp b/examples/common.cpp index 1308f8410..478dbafcf 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -251,6 +251,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.model = argv[i]; + } else if (arg == "-a" || arg == "--alias") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.model_alias = argv[i]; } else if (arg == "--lora") { if (++i >= argc) { invalid_param = true; diff --git a/examples/common.h b/examples/common.h index 2b66382a6..fea9aa81a 100644 --- a/examples/common.h +++ b/examples/common.h @@ -45,6 +45,7 @@ struct gpt_params { float mirostat_eta = 0.10f; // learning rate std::string model = "models/7B/ggml-model.bin"; // model path + std::string model_alias = "unknown"; // model alias std::string prompt = ""; std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state std::string input_prefix = ""; // string to prefix user inputs with diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 3904412cb..9eacc929f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -400,8 +400,10 @@ void server_print_usage(int /*argc*/, char **argv, const gpt_params &params) fprintf(stderr, " number of layers to store in VRAM\n"); fprintf(stderr, " -m FNAME, --model FNAME\n"); fprintf(stderr, " model path (default: %s)\n", params.model.c_str()); - fprintf(stderr, " -host ip address to listen (default 127.0.0.1)\n"); - fprintf(stderr, " -port PORT port to listen (default 8080)\n"); + fprintf(stderr, " -a ALIAS, --alias ALIAS\n"); + fprintf(stderr, " set an alias for the model, will be added as `model` field in completion response\n"); + fprintf(stderr, " --host ip address to listen (default 127.0.0.1)\n"); + fprintf(stderr, " --port PORT port to listen (default 8080)\n"); fprintf(stderr, "\n"); } @@ -453,6 +455,15 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para } params.model 
= argv[i]; } + else if (arg == "-a" || arg == "--alias") + { + if (++i >= argc) + { + invalid_param = true; + break; + } + params.model_alias = argv[i]; + } else if (arg == "--embedding") { params.embedding = true; @@ -645,6 +656,7 @@ int main(int argc, char **argv) try { json data = { + {"model", llama.params.model_alias }, {"content", llama.generated_text }, {"tokens_predicted", llama.num_tokens_predicted}}; return res.set_content(data.dump(), "application/json");