mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 06:39:25 +01:00
added support for Authorization Bearer tokens when downloading model (#8307)
* added support for Authorization Bearer tokens * removed auth_token, removed set_ function, other small fixes * Update common/common.cpp --------- Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
This commit is contained in:
parent
60d83a0149
commit
86e7299ef5
@ -190,6 +190,12 @@ int32_t cpu_get_num_math() {
|
|||||||
// CLI argument parsing
|
// CLI argument parsing
|
||||||
//
|
//
|
||||||
|
|
||||||
|
void gpt_params_handle_hf_token(gpt_params & params) {
|
||||||
|
if (params.hf_token.empty() && std::getenv("HF_TOKEN")) {
|
||||||
|
params.hf_token = std::getenv("HF_TOKEN");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void gpt_params_handle_model_default(gpt_params & params) {
|
void gpt_params_handle_model_default(gpt_params & params) {
|
||||||
if (!params.hf_repo.empty()) {
|
if (!params.hf_repo.empty()) {
|
||||||
// short-hand to avoid specifying --hf-file -> default it to --model
|
// short-hand to avoid specifying --hf-file -> default it to --model
|
||||||
@ -237,6 +243,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
|||||||
|
|
||||||
gpt_params_handle_model_default(params);
|
gpt_params_handle_model_default(params);
|
||||||
|
|
||||||
|
gpt_params_handle_hf_token(params);
|
||||||
|
|
||||||
if (params.escape) {
|
if (params.escape) {
|
||||||
string_process_escapes(params.prompt);
|
string_process_escapes(params.prompt);
|
||||||
string_process_escapes(params.input_prefix);
|
string_process_escapes(params.input_prefix);
|
||||||
@ -652,6 +660,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
|||||||
params.model_url = argv[i];
|
params.model_url = argv[i];
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (arg == "-hft" || arg == "--hf-token") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
params.hf_token = argv[i];
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (arg == "-hfr" || arg == "--hf-repo") {
|
if (arg == "-hfr" || arg == "--hf-repo") {
|
||||||
CHECK_ARG
|
CHECK_ARG
|
||||||
params.hf_repo = argv[i];
|
params.hf_repo = argv[i];
|
||||||
@ -1576,6 +1592,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
|||||||
options.push_back({ "*", "-mu, --model-url MODEL_URL", "model download url (default: unused)" });
|
options.push_back({ "*", "-mu, --model-url MODEL_URL", "model download url (default: unused)" });
|
||||||
options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" });
|
options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" });
|
||||||
options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" });
|
options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" });
|
||||||
|
options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" });
|
||||||
|
|
||||||
options.push_back({ "retrieval" });
|
options.push_back({ "retrieval" });
|
||||||
options.push_back({ "retrieval", " --context-file FNAME", "file to load context from (repeat to specify multiple files)" });
|
options.push_back({ "retrieval", " --context-file FNAME", "file to load context from (repeat to specify multiple files)" });
|
||||||
@ -2015,9 +2032,9 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
|
|||||||
llama_model * model = nullptr;
|
llama_model * model = nullptr;
|
||||||
|
|
||||||
if (!params.hf_repo.empty() && !params.hf_file.empty()) {
|
if (!params.hf_repo.empty() && !params.hf_file.empty()) {
|
||||||
model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams);
|
model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
|
||||||
} else if (!params.model_url.empty()) {
|
} else if (!params.model_url.empty()) {
|
||||||
model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
|
model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
|
||||||
} else {
|
} else {
|
||||||
model = llama_load_model_from_file(params.model.c_str(), mparams);
|
model = llama_load_model_from_file(params.model.c_str(), mparams);
|
||||||
}
|
}
|
||||||
@ -2205,7 +2222,7 @@ static bool starts_with(const std::string & str, const std::string & prefix) {
|
|||||||
return str.rfind(prefix, 0) == 0;
|
return str.rfind(prefix, 0) == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool llama_download_file(const std::string & url, const std::string & path) {
|
static bool llama_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
|
||||||
|
|
||||||
// Initialize libcurl
|
// Initialize libcurl
|
||||||
std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
|
std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
|
||||||
@ -2220,6 +2237,15 @@ static bool llama_download_file(const std::string & url, const std::string & pat
|
|||||||
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
|
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
|
||||||
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
|
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
|
||||||
|
|
||||||
|
// Check if hf-token or bearer-token was specified
|
||||||
|
if (!hf_token.empty()) {
|
||||||
|
std::string auth_header = "Authorization: Bearer ";
|
||||||
|
auth_header += hf_token.c_str();
|
||||||
|
struct curl_slist *http_headers = NULL;
|
||||||
|
http_headers = curl_slist_append(http_headers, auth_header.c_str());
|
||||||
|
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers);
|
||||||
|
}
|
||||||
|
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
|
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
|
||||||
// operating system. Currently implemented under MS-Windows.
|
// operating system. Currently implemented under MS-Windows.
|
||||||
@ -2415,6 +2441,7 @@ static bool llama_download_file(const std::string & url, const std::string & pat
|
|||||||
struct llama_model * llama_load_model_from_url(
|
struct llama_model * llama_load_model_from_url(
|
||||||
const char * model_url,
|
const char * model_url,
|
||||||
const char * path_model,
|
const char * path_model,
|
||||||
|
const char * hf_token,
|
||||||
const struct llama_model_params & params) {
|
const struct llama_model_params & params) {
|
||||||
// Basic validation of the model_url
|
// Basic validation of the model_url
|
||||||
if (!model_url || strlen(model_url) == 0) {
|
if (!model_url || strlen(model_url) == 0) {
|
||||||
@ -2422,7 +2449,7 @@ struct llama_model * llama_load_model_from_url(
|
|||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!llama_download_file(model_url, path_model)) {
|
if (!llama_download_file(model_url, path_model, hf_token)) {
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2470,14 +2497,14 @@ struct llama_model * llama_load_model_from_url(
|
|||||||
// Prepare download in parallel
|
// Prepare download in parallel
|
||||||
std::vector<std::future<bool>> futures_download;
|
std::vector<std::future<bool>> futures_download;
|
||||||
for (int idx = 1; idx < n_split; idx++) {
|
for (int idx = 1; idx < n_split; idx++) {
|
||||||
futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split](int download_idx) -> bool {
|
futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool {
|
||||||
char split_path[PATH_MAX] = {0};
|
char split_path[PATH_MAX] = {0};
|
||||||
llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split);
|
llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split);
|
||||||
|
|
||||||
char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
|
char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
|
||||||
llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);
|
llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);
|
||||||
|
|
||||||
return llama_download_file(split_url, split_path);
|
return llama_download_file(split_url, split_path, hf_token);
|
||||||
}, idx));
|
}, idx));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2496,6 +2523,7 @@ struct llama_model * llama_load_model_from_hf(
|
|||||||
const char * repo,
|
const char * repo,
|
||||||
const char * model,
|
const char * model,
|
||||||
const char * path_model,
|
const char * path_model,
|
||||||
|
const char * hf_token,
|
||||||
const struct llama_model_params & params) {
|
const struct llama_model_params & params) {
|
||||||
// construct hugging face model url:
|
// construct hugging face model url:
|
||||||
//
|
//
|
||||||
@ -2511,7 +2539,7 @@ struct llama_model * llama_load_model_from_hf(
|
|||||||
model_url += "/resolve/main/";
|
model_url += "/resolve/main/";
|
||||||
model_url += model;
|
model_url += model;
|
||||||
|
|
||||||
return llama_load_model_from_url(model_url.c_str(), path_model, params);
|
return llama_load_model_from_url(model_url.c_str(), path_model, hf_token, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
#else
|
#else
|
||||||
@ -2519,6 +2547,7 @@ struct llama_model * llama_load_model_from_hf(
|
|||||||
struct llama_model * llama_load_model_from_url(
|
struct llama_model * llama_load_model_from_url(
|
||||||
const char * /*model_url*/,
|
const char * /*model_url*/,
|
||||||
const char * /*path_model*/,
|
const char * /*path_model*/,
|
||||||
|
const char * /*hf_token*/,
|
||||||
const struct llama_model_params & /*params*/) {
|
const struct llama_model_params & /*params*/) {
|
||||||
fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
|
fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -2528,6 +2557,7 @@ struct llama_model * llama_load_model_from_hf(
|
|||||||
const char * /*repo*/,
|
const char * /*repo*/,
|
||||||
const char * /*model*/,
|
const char * /*model*/,
|
||||||
const char * /*path_model*/,
|
const char * /*path_model*/,
|
||||||
|
const char * /*hf_token*/,
|
||||||
const struct llama_model_params & /*params*/) {
|
const struct llama_model_params & /*params*/) {
|
||||||
fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
|
fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -108,6 +108,7 @@ struct gpt_params {
|
|||||||
std::string model_draft = ""; // draft model for speculative decoding
|
std::string model_draft = ""; // draft model for speculative decoding
|
||||||
std::string model_alias = "unknown"; // model alias
|
std::string model_alias = "unknown"; // model alias
|
||||||
std::string model_url = ""; // model url to download
|
std::string model_url = ""; // model url to download
|
||||||
|
std::string hf_token = ""; // HF token
|
||||||
std::string hf_repo = ""; // HF repo
|
std::string hf_repo = ""; // HF repo
|
||||||
std::string hf_file = ""; // HF file
|
std::string hf_file = ""; // HF file
|
||||||
std::string prompt = "";
|
std::string prompt = "";
|
||||||
@ -256,6 +257,7 @@ struct gpt_params {
|
|||||||
bool spm_infill = false; // suffix/prefix/middle pattern for infill
|
bool spm_infill = false; // suffix/prefix/middle pattern for infill
|
||||||
};
|
};
|
||||||
|
|
||||||
|
void gpt_params_handle_hf_token(gpt_params & params);
|
||||||
void gpt_params_handle_model_default(gpt_params & params);
|
void gpt_params_handle_model_default(gpt_params & params);
|
||||||
|
|
||||||
bool gpt_params_parse_ex (int argc, char ** argv, gpt_params & params);
|
bool gpt_params_parse_ex (int argc, char ** argv, gpt_params & params);
|
||||||
@ -311,8 +313,8 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
|
|||||||
struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params);
|
struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params);
|
||||||
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
|
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
|
||||||
|
|
||||||
struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params);
|
struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params);
|
||||||
struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params);
|
struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params);
|
||||||
|
|
||||||
// Batch utils
|
// Batch utils
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user