mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-02-05 16:10:42 +01:00
Implement s3:// protocol (#11511)
For those that want to pull from s3 Signed-off-by: Eric Curtin <ecurtin@redhat.com>
This commit is contained in:
parent
5bbc7362cb
commit
ecef206ccb
@ -65,6 +65,13 @@ static int printe(const char * fmt, ...) {
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static std::string strftime_fmt(const char * fmt, const std::tm & tm) {
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << std::put_time(&tm, fmt);
|
||||||
|
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
|
||||||
class Opt {
|
class Opt {
|
||||||
public:
|
public:
|
||||||
int init(int argc, const char ** argv) {
|
int init(int argc, const char ** argv) {
|
||||||
@ -698,6 +705,39 @@ class LlamaData {
|
|||||||
return download(url, bn, true);
|
return download(url, bn, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int s3_dl(const std::string & model, const std::string & bn) {
|
||||||
|
const size_t slash_pos = model.find('/');
|
||||||
|
if (slash_pos == std::string::npos) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string bucket = model.substr(0, slash_pos);
|
||||||
|
const std::string key = model.substr(slash_pos + 1);
|
||||||
|
const char * access_key = std::getenv("AWS_ACCESS_KEY_ID");
|
||||||
|
const char * secret_key = std::getenv("AWS_SECRET_ACCESS_KEY");
|
||||||
|
if (!access_key || !secret_key) {
|
||||||
|
printe("AWS credentials not found in environment\n");
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate AWS Signature Version 4 headers
|
||||||
|
// (Implementation requires HMAC-SHA256 and date handling)
|
||||||
|
// Get current timestamp
|
||||||
|
const time_t now = time(nullptr);
|
||||||
|
const tm tm = *gmtime(&now);
|
||||||
|
const std::string date = strftime_fmt("%Y%m%d", tm);
|
||||||
|
const std::string datetime = strftime_fmt("%Y%m%dT%H%M%SZ", tm);
|
||||||
|
const std::vector<std::string> headers = {
|
||||||
|
"Authorization: AWS4-HMAC-SHA256 Credential=" + std::string(access_key) + "/" + date +
|
||||||
|
"/us-east-1/s3/aws4_request",
|
||||||
|
"x-amz-content-sha256: UNSIGNED-PAYLOAD", "x-amz-date: " + datetime
|
||||||
|
};
|
||||||
|
|
||||||
|
const std::string url = "https://" + bucket + ".s3.amazonaws.com/" + key;
|
||||||
|
|
||||||
|
return download(url, bn, true, headers);
|
||||||
|
}
|
||||||
|
|
||||||
std::string basename(const std::string & path) {
|
std::string basename(const std::string & path) {
|
||||||
const size_t pos = path.find_last_of("/\\");
|
const size_t pos = path.find_last_of("/\\");
|
||||||
if (pos == std::string::npos) {
|
if (pos == std::string::npos) {
|
||||||
@ -738,6 +778,9 @@ class LlamaData {
|
|||||||
rm_until_substring(model_, "github:");
|
rm_until_substring(model_, "github:");
|
||||||
rm_until_substring(model_, "://");
|
rm_until_substring(model_, "://");
|
||||||
ret = github_dl(model_, bn);
|
ret = github_dl(model_, bn);
|
||||||
|
} else if (string_starts_with(model_, "s3://")) {
|
||||||
|
rm_until_substring(model_, "://");
|
||||||
|
ret = s3_dl(model_, bn);
|
||||||
} else { // ollama:// or nothing
|
} else { // ollama:// or nothing
|
||||||
rm_until_substring(model_, "ollama.com/library/");
|
rm_until_substring(model_, "ollama.com/library/");
|
||||||
rm_until_substring(model_, "://");
|
rm_until_substring(model_, "://");
|
||||||
|
Loading…
Reference in New Issue
Block a user