diff --git a/common/common.cpp b/common/common.cpp
index cc230c9ff..1efc2eda5 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -647,14 +647,6 @@ static bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg,
         params.model = argv[i];
         return true;
     }
-    if (arg == "-mu" || arg == "--model-url") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
-        params.model_url = argv[i];
-        return true;
-    }
     if (arg == "-md" || arg == "--model-draft") {
         if (++i >= argc) {
             invalid_param = true;
@@ -671,6 +663,30 @@ static bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg,
         params.model_alias = argv[i];
         return true;
     }
+    if (arg == "-mu" || arg == "--model-url") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.model_url = argv[i];
+        return true;
+    }
+    if (arg == "-hfr" || arg == "--hf-repo") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.hf_repo = argv[i];
+        return true;
+    }
+    if (arg == "-hff" || arg == "--hf-file") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.hf_file = argv[i];
+        return true;
+    }
     if (arg == "--lora") {
         if (++i >= argc) {
             invalid_param = true;
@@ -1403,10 +1419,14 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        layer range to apply the control vector(s) to, start and end inclusive\n");
     printf("  -m FNAME, --model FNAME\n");
     printf("                        model path (default: %s)\n", params.model.c_str());
-    printf("  -mu MODEL_URL, --model-url MODEL_URL\n");
-    printf("                        model download url (default: %s)\n", params.model_url.c_str());
     printf("  -md FNAME, --model-draft FNAME\n");
     printf("                        draft model for speculative decoding\n");
+    printf("  -mu MODEL_URL, --model-url MODEL_URL\n");
+    printf("                        model download url (default: %s)\n", params.model_url.c_str());
+    printf("  -hfr REPO, --hf-repo REPO\n");
+    printf("                        Hugging Face model repository (default: %s)\n", params.hf_repo.c_str());
+    printf("  -hff FILE, --hf-file FILE\n");
+    printf("                        Hugging Face model file (default: %s)\n", params.hf_file.c_str());
     printf("  -ld LOGDIR, --logdir LOGDIR\n");
     printf("                        path under which to save YAML logs (no logging if unset)\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
@@ -1655,8 +1675,10 @@ void llama_batch_add(
 
 #ifdef LLAMA_USE_CURL
 
-struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model,
-                                               struct llama_model_params params) {
+struct llama_model * llama_load_model_from_url(
+        const char * model_url,
+        const char * path_model,
+        const struct llama_model_params & params) {
     // Basic validation of the model_url
     if (!model_url || strlen(model_url) == 0) {
         fprintf(stderr, "%s: invalid model_url\n", __func__);
@@ -1850,25 +1872,62 @@ struct llama_model * llama_load_model_from_url(const char * model_url, const cha
     return llama_load_model_from_file(path_model, params);
 }
 
+struct llama_model * llama_load_model_from_hf(
+        const char * repo,
+        const char * model,
+        const char * path_model,
+        const struct llama_model_params & params) {
+    // construct hugging face model url:
+    //
+    //   --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf
+    //     https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf
+    //
+    //   --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf
+    //     https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf
+    //
+
+    std::string model_url = "https://huggingface.co/";
+    model_url += repo;
+    model_url += "/resolve/main/";
+    model_url += model;
+
+    return llama_load_model_from_url(model_url.c_str(), path_model, params);
+}
+
 #else
 
-struct llama_model * llama_load_model_from_url(const char * /*model_url*/, const char * /*path_model*/,
-                                               struct llama_model_params /*params*/) {
+struct llama_model * llama_load_model_from_url(
+        const char * /*model_url*/,
+        const char * /*path_model*/,
+        const struct llama_model_params & /*params*/) {
     fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
     return nullptr;
 }
 
+struct llama_model * llama_load_model_from_hf(
+        const char * /*repo*/,
+        const char * /*model*/,
+        const char * /*path_model*/,
+        const struct llama_model_params & /*params*/) {
+    fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
+    return nullptr;
+}
+
 #endif // LLAMA_USE_CURL
 
 std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
     auto mparams = llama_model_params_from_gpt_params(params);
 
     llama_model * model = nullptr;
-    if (!params.model_url.empty()) {
+
+    if (!params.hf_repo.empty() && !params.hf_file.empty()) {
+        model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams);
+    } else if (!params.model_url.empty()) {
         model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
     } else {
         model = llama_load_model_from_file(params.model.c_str(), mparams);
     }
+
     if (model == NULL) {
         fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
         return std::make_tuple(nullptr, nullptr);
@@ -1908,7 +1967,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     }
 
     for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
-        const std::string& lora_adapter = std::get<0>(params.lora_adapter[i]);
+        const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
         float lora_scale = std::get<1>(params.lora_adapter[i]);
         int err = llama_model_apply_lora_from_file(model,
                                                    lora_adapter.c_str(),
diff --git a/common/common.h b/common/common.h
index 31fd401b6..d827d4df7 100644
--- a/common/common.h
+++ b/common/common.h
@@ -89,9 +89,11 @@ struct gpt_params {
     struct llama_sampling_params sparams;
 
     std::string model             = "models/7B/ggml-model-f16.gguf"; // model path
-    std::string model_url         = ""; // model url to download
-    std::string model_draft       = ""; // draft model for speculative decoding
+    std::string model_draft       = ""; // draft model for speculative decoding
     std::string model_alias       = "unknown"; // model alias
+    std::string model_url         = ""; // model url to download
+    std::string hf_repo           = ""; // HF repo
+    std::string hf_file           = ""; // HF file
     std::string prompt            = "";
     std::string prompt_file       = ""; // store the external prompt file name
     std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
@@ -192,8 +194,8 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 struct llama_model_params   llama_model_params_from_gpt_params  (const gpt_params & params);
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
 
-struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model,
-                                               struct llama_model_params params);
+struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params);
+struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params);
 
 // Batch utils