From ecef206ccb186a1cde8dd2523b1da3e12f593f9e Mon Sep 17 00:00:00 2001 From: Eric Curtin Date: Sat, 1 Feb 2025 11:30:54 +0100 Subject: [PATCH] Implement s3:// protocol (#11511) For those that want to pull from s3 Signed-off-by: Eric Curtin --- examples/run/run.cpp | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 9cecae48c..cf61f4add 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -65,6 +65,13 @@ static int printe(const char * fmt, ...) { return ret; } +static std::string strftime_fmt(const char * fmt, const std::tm & tm) { + std::ostringstream oss; + oss << std::put_time(&tm, fmt); + + return oss.str(); +} + class Opt { public: int init(int argc, const char ** argv) { @@ -698,6 +705,39 @@ class LlamaData { return download(url, bn, true); } + int s3_dl(const std::string & model, const std::string & bn) { + const size_t slash_pos = model.find('/'); + if (slash_pos == std::string::npos) { + return 1; + } + + const std::string bucket = model.substr(0, slash_pos); + const std::string key = model.substr(slash_pos + 1); + const char * access_key = std::getenv("AWS_ACCESS_KEY_ID"); + const char * secret_key = std::getenv("AWS_SECRET_ACCESS_KEY"); + if (!access_key || !secret_key) { + printe("AWS credentials not found in environment\n"); + return 1; + } + + // Generate AWS Signature Version 4 headers + // (Implementation requires HMAC-SHA256 and date handling) + // Get current timestamp + const time_t now = time(nullptr); + const tm tm = *gmtime(&now); + const std::string date = strftime_fmt("%Y%m%d", tm); + const std::string datetime = strftime_fmt("%Y%m%dT%H%M%SZ", tm); + const std::vector headers = { + "Authorization: AWS4-HMAC-SHA256 Credential=" + std::string(access_key) + "/" + date + + "/us-east-1/s3/aws4_request", + "x-amz-content-sha256: UNSIGNED-PAYLOAD", "x-amz-date: " + datetime + }; + + const std::string url = "https://" + bucket + ".s3.amazonaws.com/" + key; + + return download(url, bn, true, headers); + } + std::string basename(const std::string & path) { const size_t pos = path.find_last_of("/\\"); if (pos == std::string::npos) { @@ -738,6 +778,9 @@ class LlamaData { rm_until_substring(model_, "github:"); rm_until_substring(model_, "://"); ret = github_dl(model_, bn); + } else if (string_starts_with(model_, "s3://")) { + rm_until_substring(model_, "://"); + ret = s3_dl(model_, bn); } else { // ollama:// or nothing rm_until_substring(model_, "ollama.com/library/"); rm_until_substring(model_, "://");