diff --git a/llama.cpp b/llama.cpp index a833d4c15..43deadd85 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3909,9 +3909,57 @@ static bool llm_load_tensors( return true; } +// check if the URL is a HuggingFace model, and if so, try to download it +static void hf_try_download_model(std::string & url) { + bool is_url = false; + + if (url.size() > 22) { + is_url = (url.compare(0, 22, "https://huggingface.co") == 0); + } + + if (!is_url) { + return; + } + + // Examples: + // + // https://huggingface.co/TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF/resolve/main/mixtral-8x7b-instruct-v0.1.Q2_K.gguf + + std::string basename; + basename = url.substr(url.find_last_of("/\\") + 1); + + LLAMA_LOG_INFO("%s: detected URL, attempting to download %s\n", __func__, basename.c_str()); + + { + const std::string cmd = "wget -q --show-progress -c -O " + basename + " " + url; + LLAMA_LOG_INFO("%s: %s\n", __func__, cmd.c_str()); + + const int ret = system(cmd.c_str()); + if (ret == 0) { + url = basename; + return; + } + } + + { + const std::string cmd = "curl -C - -f -o " + basename + " -L " + url; + LLAMA_LOG_INFO("%s: %s\n", __func__, cmd.c_str()); + + const int ret = system(cmd.c_str()); + if (ret == 0) { + url = basename; + return; + } + } + + LLAMA_LOG_WARN("%s: failed to download\n", __func__); +} + // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback -static int llama_model_load(const std::string & fname, llama_model & model, const llama_model_params & params) { +static int llama_model_load(std::string fname, llama_model & model, const llama_model_params & params) { try { + hf_try_download_model(fname); + llama_model_loader ml(fname, params.use_mmap, params.kv_overrides); model.hparams.vocab_only = params.vocab_only;