From 2290be269eaed9d5ddf2d3ae5abda23a1a513537 Mon Sep 17 00:00:00 2001 From: Roman Gershman Date: Sat, 13 Jul 2024 17:35:23 +0300 Subject: [PATCH] chore: gcs write file (#297) Signed-off-by: Roman Gershman --- examples/gcs_demo.cc | 35 +++++- util/cloud/gcp/gcp_utils.cc | 58 +++++---- util/cloud/gcp/gcp_utils.h | 96 +++++++++++++- util/cloud/gcp/gcs.cc | 91 +++++++++----- util/cloud/gcp/gcs.h | 4 + util/cloud/gcp/gcs_file.cc | 244 ++++++++++++++++++++++++++++++++---- util/cloud/gcp/gcs_file.h | 32 +---- util/tls/tls_socket.cc | 10 ++ 8 files changed, 460 insertions(+), 110 deletions(-) diff --git a/examples/gcs_demo.cc b/examples/gcs_demo.cc index 5aeedd4b..faf492c8 100644 --- a/examples/gcs_demo.cc +++ b/examples/gcs_demo.cc @@ -1,21 +1,23 @@ // Copyright 2024, Roman Gershman. All rights reserved. // See LICENSE for licensing terms. +#include #include "base/flags.h" #include "base/init.h" #include "base/logging.h" +#include "io/file_util.h" #include "util/cloud/gcp/gcs.h" +#include "util/cloud/gcp/gcs_file.h" #include "util/fibers/pool.h" using namespace std; -using namespace boost; using namespace util; using absl::GetFlag; ABSL_FLAG(string, bucket, "", ""); ABSL_FLAG(string, prefix, "", ""); - +ABSL_FLAG(uint32_t, write, 0, ""); ABSL_FLAG(uint32_t, connect_ms, 2000, ""); ABSL_FLAG(bool, epoll, false, "Whether to use epoll instead of io_uring"); @@ -32,10 +34,31 @@ void Run(SSL_CTX* ctx) { string prefix = GetFlag(FLAGS_prefix); if (!prefix.empty()) { - auto cb = [](cloud::GCS::ObjectItem item) { - cout << "Object: " << item.key << ", size: " << item.size << endl; - }; - ec = gcs.List(GetFlag(FLAGS_bucket), prefix, false, cb); + string bucket = GetFlag(FLAGS_bucket); + auto conn_pool = gcs.CreateConnectionPool(); + CHECK(!bucket.empty()); + + if (GetFlag(FLAGS_write) > 0) { + auto src = io::ReadFileToString("/proc/self/exe"); + CHECK(src); + for (unsigned i = 0; i < GetFlag(FLAGS_write); ++i) { + string dest_key = absl::StrCat(prefix, "_", i); + io::Result 
dest_res = + cloud::OpenWriteGcsFile(bucket, dest_key, &provider, conn_pool.get()); + CHECK(dest_res) << "Could not open " << dest_key << " " << dest_res.error().message(); + unique_ptr dest(*dest_res); + error_code ec = dest->Write(*src); + CHECK(!ec); + ec = dest->Close(); + CHECK(!ec); + CONSOLE_INFO << "Written " << dest_key; + } + } else { + auto cb = [](cloud::GCS::ObjectItem item) { + cout << "Object: " << item.key << ", size: " << item.size << endl; + }; + ec = gcs.List(GetFlag(FLAGS_bucket), prefix, false, cb); + } } else { auto cb = [](std::string_view bname) { CONSOLE_INFO << bname; }; diff --git a/util/cloud/gcp/gcp_utils.cc b/util/cloud/gcp/gcp_utils.cc index 7358048b..ac7aa8dd 100644 --- a/util/cloud/gcp/gcp_utils.cc +++ b/util/cloud/gcp/gcp_utils.cc @@ -10,13 +10,6 @@ #include "base/logging.h" #include "util/cloud/gcp/gcp_creds_provider.h" -#define RETURN_UNEXPECTED(x) \ - do { \ - auto ec = (x); \ - if (ec) \ - return nonstd::make_unexpected(ec); \ - } while (false) - namespace util::cloud { using namespace std; namespace h2 = boost::beast::http; @@ -37,34 +30,54 @@ inline bool DoesServerPushback(h2::status st) { h2::to_status_class(st) == h2::status_class::server_error; } +constexpr auto kResumeIncomplete = h2::status::permanent_redirect; + +bool IsResponseOK(h2::status st) { + // Partial content can appear because of the previous reconnect. + // For multipart uploads kResumeIncomplete can be returned. 
+ return st == h2::status::ok || st == h2::status::partial_content || st == kResumeIncomplete; +} + } // namespace -const char GCP_API_DOMAIN[] = "www.googleapis.com"; +const char GCS_API_DOMAIN[] = "storage.googleapis.com"; string AuthHeader(string_view access_token) { return absl::StrCat("Bearer ", access_token); } -EmptyRequest PrepareRequest(h2::verb req_verb, std::string_view url, - const string_view access_token) { - EmptyRequest req{req_verb, boost::beast::string_view{url.data(), url.size()}, 11}; - req.set(h2::field::host, GCP_API_DOMAIN); - req.set(h2::field::authorization, AuthHeader(access_token)); - req.keep_alive(true); +namespace detail { + +EmptyRequestImpl::EmptyRequestImpl(h2::verb req_verb, std::string_view url, + const string_view access_token) + : req_{req_verb, boost::beast::string_view{url.data(), url.size()}, 11} { + req_.set(h2::field::host, GCS_API_DOMAIN); + req_.set(h2::field::authorization, AuthHeader(access_token)); + // ? req_.keep_alive(true); +} + +std::error_code EmptyRequestImpl::Send(http::Client* client) { + return client->Send(req_); +} - return req; +std::error_code DynamicBodyRequestImpl::Send(http::Client* client) { + return client->Send(req_); } +} // namespace detail + RobustSender::RobustSender(unsigned num_iterations, GCPCredsProvider* provider) : num_iterations_(num_iterations), provider_(provider) { } -auto RobustSender::Send(http::Client* client, EmptyRequest* req) -> io::Result { +auto RobustSender::Send(http::Client* client, + detail::HttpRequestBase* req) -> io::Result { error_code ec; for (unsigned i = 0; i < num_iterations_; ++i) { // Iterate for possible token refresh. 
- VLOG(1) << "HttpReq" << i << ": " << *req << ", socket " << client->native_handle(); + VLOG(1) << "HttpReq " << client->host() << ": " << req->GetHeaders() << ", [" + << client->native_handle() << "]"; - RETURN_UNEXPECTED(client->Send(*req)); + RETURN_UNEXPECTED(req->Send(client)); HeaderParserPtr parser(new h2::response_parser()); RETURN_UNEXPECTED(client->ReadHeader(parser.get())); { @@ -75,11 +88,11 @@ auto RobustSender::Send(http::Client* client, EmptyRequest* req) -> io::Result drainer(std::move(*parser)); RETURN_UNEXPECTED(client->Recv(&drainer)); @@ -88,14 +101,13 @@ auto RobustSender::Send(http::Client* client, EmptyRequest* req) -> io::Resultnative_handle() << ") with " << msg; - ThisFiber::SleepFor(1s); - i = 0; // Can potentially deadlock + ThisFiber::SleepFor(100ms); continue; } if (IsUnauthorized(msg)) { RETURN_UNEXPECTED(provider_->RefreshToken(client->proactor())); - req->set(h2::field::authorization, AuthHeader(provider_->access_token())); + req->SetHeader(h2::field::authorization, AuthHeader(provider_->access_token())); continue; } diff --git a/util/cloud/gcp/gcp_utils.h b/util/cloud/gcp/gcp_utils.h index 0c5bd8a3..794b0d4f 100644 --- a/util/cloud/gcp/gcp_utils.h +++ b/util/cloud/gcp/gcp_utils.h @@ -11,28 +11,112 @@ namespace util::cloud { class GCPCredsProvider; +extern const char GCS_API_DOMAIN[]; -extern const char GCP_API_DOMAIN[]; +namespace detail { +inline std::string_view FromBoostSV(boost::string_view sv) { + return std::string_view(sv.data(), sv.size()); +} -using EmptyRequest = boost::beast::http::request; +class HttpRequestBase { + public: + HttpRequestBase(const HttpRequestBase&) = delete; + HttpRequestBase& operator=(const HttpRequestBase&) = delete; + HttpRequestBase() = default; -EmptyRequest PrepareRequest(boost::beast::http::verb req_verb, std::string_view url, - const std::string_view access_token); + virtual ~HttpRequestBase() = default; + virtual std::error_code Send(http::Client* client) = 0; -std::string 
AuthHeader(std::string_view access_token); + virtual const boost::beast::http::header& GetHeaders() const = 0; + + virtual void SetHeader(boost::beast::http::field f, std::string_view value) = 0; +}; + +class EmptyRequestImpl : public HttpRequestBase { + using EmptyRequest = boost::beast::http::request; + EmptyRequest req_; + + public: + EmptyRequestImpl(boost::beast::http::verb req_verb, std::string_view url, + const std::string_view access_token); + + void SetUrl(std::string_view url) { + req_.target(boost::string_view{url.data(), url.size()}); + } + + void Finalize() { + req_.prepare_payload(); + } + + void SetHeader(boost::beast::http::field f, std::string_view value) final { + req_.set(f, boost::string_view{value.data(), value.size()}); + } + + const boost::beast::http::header& GetHeaders() const final { + return req_.base(); + } + + std::error_code Send(http::Client* client) final; +}; + +class DynamicBodyRequestImpl : public HttpRequestBase { + using DynamicBodyRequest = boost::beast::http::request; + DynamicBodyRequest req_; + + public: + DynamicBodyRequestImpl(DynamicBodyRequestImpl&&) = default; + + explicit DynamicBodyRequestImpl(std::string_view url) + : req_(boost::beast::http::verb::post, boost::string_view{url.data(), url.size()}, 11) { + } + + template void SetBody(BodyArgs&& body_args) { + req_.body() = std::forward(body_args); + } + + void SetHeader(boost::beast::http::field f, std::string_view value) final { + req_.set(f, boost::string_view{value.data(), value.size()}); + } + + void Finalize() { + req_.prepare_payload(); + } + + const boost::beast::http::header& GetHeaders() const final { + return req_.base(); + } + + std::error_code Send(http::Client* client) final; +}; + +} // namespace detail class RobustSender { + RobustSender(const RobustSender&) = delete; + RobustSender& operator=(const RobustSender&) = delete; + public: using HeaderParserPtr = std::unique_ptr>; RobustSender(unsigned num_iterations, GCPCredsProvider* provider); - io::Result 
Send(http::Client* client, EmptyRequest* req); + io::Result Send(http::Client* client, detail::HttpRequestBase* req); private: unsigned num_iterations_; GCPCredsProvider* provider_; }; +std::string AuthHeader(std::string_view access_token); + +#define RETURN_UNEXPECTED(x) \ + do { \ + auto ec = (x); \ + if (ec) { \ + VLOG(1) << "Failed " << #x << ": " << ec.message(); \ + return nonstd::make_unexpected(ec); \ + } \ + } while (false) + } // namespace util::cloud \ No newline at end of file diff --git a/util/cloud/gcp/gcs.cc b/util/cloud/gcp/gcs.cc index 946c5422..03a70e61 100644 --- a/util/cloud/gcp/gcs.cc +++ b/util/cloud/gcp/gcs.cc @@ -30,13 +30,6 @@ auto Unexpected(std::errc code) { return nonstd::make_unexpected(make_error_code(code)); } -#define RETURN_UNEXPECTED(x) \ - do { \ - auto ec = (x); \ - if (ec) \ - return nonstd::make_unexpected(ec); \ - } while (false) - #define RETURN_ERROR(x) \ do { \ auto ec = (x); \ @@ -44,8 +37,7 @@ auto Unexpected(std::errc code) { return ec; \ } while (false) - -io::Result ExpandFile(string_view path) { +io::Result ExpandFilePath(string_view path) { io::Result res = io::StatFiles(path); if (!res) { @@ -60,7 +52,7 @@ io::Result ExpandFile(string_view path) { } std::error_code LoadGCPConfig(string* account_id, string* project_id) { - io::Result path = ExpandFile("~/.config/gcloud/configurations/config_default"); + io::Result path = ExpandFilePath("~/.config/gcloud/configurations/config_default"); if (!path) { return path.error(); } @@ -153,17 +145,49 @@ io::Result ParseTokenResponse(std::string&& response) { return result; } -#define FETCH_ARRAY_MEMBER(val) \ - if (!(val).IsArray()) \ - return make_error_code(errc::bad_message); \ - auto array = val.GetArray() +constexpr unsigned kTcpKeepAliveInterval = 30; + +error_code EnableKeepAlive(int fd) { + int val = 1; + if (setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &val, sizeof(val)) < 0) { + return std::error_code(errno, std::system_category()); + } + + val = kTcpKeepAliveInterval; + 
if (setsockopt(fd, IPPROTO_TCP, TCP_KEEPINTVL, &val, sizeof(val)) < 0) { + return std::error_code(errno, std::system_category()); + } + + val = kTcpKeepAliveInterval; +#ifdef __APPLE__ + if (setsockopt(fd, IPPROTO_TCP, TCP_KEEPALIVE, &val, sizeof(val)) < 0) { + return std::error_code(errno, std::system_category()); + } +#else + if (setsockopt(fd, IPPROTO_TCP, TCP_KEEPIDLE, &val, sizeof(val)) < 0) { + return std::error_code(errno, std::system_category()); + } +#endif + + val = 3; + if (setsockopt(fd, IPPROTO_TCP, TCP_KEEPCNT, &val, sizeof(val)) < 0) { + return std::error_code(errno, std::system_category()); + } + + return std::error_code{}; +} + +#define FETCH_ARRAY_MEMBER(val) \ + if (!(val).IsArray()) \ + return make_error_code(errc::bad_message); \ + auto array = val.GetArray() } // namespace error_code GCPCredsProvider::Init(unsigned connect_ms, fb2::ProactorBase* pb) { CHECK_GT(connect_ms, 0u); - io::Result root_path = ExpandFile("~/.config/gcloud"); + io::Result root_path = ExpandFilePath("~/.config/gcloud"); if (!root_path) { return root_path.error(); } @@ -213,8 +237,10 @@ error_code GCPCredsProvider::RefreshToken(fb2::ProactorBase* pb) { error_code ec = https_client.Connect(kDomain, "443", context); http::TlsClient::FreeContext(context); - if (ec) + if (ec) { + VLOG(1) << "Could not connect to " << kDomain; return ec; + } h2::request req{h2::verb::post, "/token", 11}; req.set(h2::field::host, kDomain); req.set(h2::field::content_type, "application/x-www-form-urlencoded"); @@ -255,22 +281,25 @@ GCS::~GCS() { std::error_code GCS::Connect(unsigned msec) { client_->set_connect_timeout_ms(msec); - - return client_->Connect(GCP_API_DOMAIN, "443", ssl_ctx_); + client_->AssignOnConnect([](int fd) { + auto ec = EnableKeepAlive(fd); + LOG_IF(WARNING, ec) << "Error setting keep alive " << ec.message() << " " << fd; + }); + return client_->Connect(GCS_API_DOMAIN, "443", ssl_ctx_); } error_code GCS::ListBuckets(ListBucketCb cb) { string url = 
absl::StrCat("/storage/v1/b?project=", creds_provider_.project_id()); - absl::StrAppend(&url, "&maxResults=50&fields=items,nextPageToken"); + absl::StrAppend(&url, "&maxResults=50&fields=items/id,nextPageToken"); - auto http_req = PrepareRequest(h2::verb::get, url, creds_provider_.access_token()); + detail::EmptyRequestImpl empty_req(h2::verb::get, url, creds_provider_.access_token()); rj::Document doc; RobustSender sender(2, &creds_provider_); while (true) { - io::Result parse_res = sender.Send(client_.get(), &http_req); + io::Result parse_res = sender.Send(client_.get(), &empty_req); if (!parse_res) return parse_res.error(); RobustSender::HeaderParserPtr empty_parser = std::move(*parse_res); @@ -294,7 +323,7 @@ error_code GCS::ListBuckets(ListBucketCb cb) { for (size_t i = 0; i < array.Size(); ++i) { const auto& item = array[i]; - auto it = item.FindMember("name"); + auto it = item.FindMember("id"); if (it != item.MemberEnd()) { cb(string_view{it->value.GetString(), it->value.GetStringLength()}); } @@ -305,13 +334,12 @@ error_code GCS::ListBuckets(ListBucketCb cb) { break; } absl::string_view page_token{it->value.GetString(), it->value.GetStringLength()}; - http_req.target(absl::StrCat(url, "&pageToken=", page_token)); + empty_req.SetUrl(absl::StrCat(url, "&pageToken=", page_token)); } return {}; } -error_code GCS::List(string_view bucket, string_view prefix, bool recursive, - ListObjectCb cb) { +error_code GCS::List(string_view bucket, string_view prefix, bool recursive, ListObjectCb cb) { CHECK(!bucket.empty()); string url = "/storage/v1/b/"; @@ -320,12 +348,13 @@ error_code GCS::List(string_view bucket, string_view prefix, bool recursive, if (!recursive) { absl::StrAppend(&url, "&delimiter=%2f"); } - auto http_req = PrepareRequest(h2::verb::get, url, creds_provider_.access_token()); + + detail::EmptyRequestImpl empty_req(h2::verb::get, url, creds_provider_.access_token()); rj::Document doc; RobustSender sender(2, &creds_provider_); while (true) { - io::Result 
parse_res = sender.Send(client_.get(), &http_req); + io::Result parse_res = sender.Send(client_.get(), &empty_req); if (!parse_res) return parse_res.error(); RobustSender::HeaderParserPtr empty_parser = std::move(*parse_res); @@ -372,10 +401,16 @@ error_code GCS::List(string_view bucket, string_view prefix, bool recursive, break; } absl::string_view page_token{it->value.GetString(), it->value.GetStringLength()}; - http_req.target(absl::StrCat(url, "&pageToken=", page_token)); + empty_req.SetUrl(absl::StrCat(url, "&pageToken=", page_token)); } return {}; } +unique_ptr GCS::CreateConnectionPool() const { + unique_ptr res( + new http::ClientPool(GCS_API_DOMAIN, ssl_ctx_, client_->proactor())); + return res; +} + } // namespace cloud } // namespace util \ No newline at end of file diff --git a/util/cloud/gcp/gcs.h b/util/cloud/gcp/gcs.h index ee1ca918..50c63692 100644 --- a/util/cloud/gcp/gcs.h +++ b/util/cloud/gcp/gcs.h @@ -9,6 +9,7 @@ #include "util/cloud/gcp/gcp_creds_provider.h" #include "util/http/http_client.h" +#include "util/http/https_client_pool.h" typedef struct ssl_ctx_st SSL_CTX; @@ -35,6 +36,9 @@ class GCS { std::error_code ListBuckets(ListBucketCb cb); std::error_code List(std::string_view bucket, std::string_view prefix, bool recursive, ListObjectCb cb); + + std::unique_ptr CreateConnectionPool() const; + private: GCPCredsProvider& creds_provider_; SSL_CTX* ssl_ctx_; diff --git a/util/cloud/gcp/gcs_file.cc b/util/cloud/gcp/gcs_file.cc index fb2959d8..3dcc5536 100644 --- a/util/cloud/gcp/gcs_file.cc +++ b/util/cloud/gcp/gcs_file.cc @@ -4,52 +4,254 @@ #include "util/cloud/gcp/gcs_file.h" #include +#include #include +#include +#include "base/flags.h" +#include "base/logging.h" #include "strings/escaping.h" #include "util/cloud/gcp/gcp_utils.h" +#include "util/http/http_common.h" + +ABSL_FLAG(bool, gcs_dry_upload, false, ""); namespace util { namespace cloud { using namespace std; namespace h2 = boost::beast::http; +using boost::beast::multi_buffer; +using 
HeaderParserPtr = RobustSender::HeaderParserPtr; namespace { +//! [from, to) limited range out of total. If total is < 0 then it's unknown. +string ContentRangeHeader(size_t from, size_t to, ssize_t total) { + DCHECK_LE(from, to); + string tmp{"bytes "}; + + if (from < to) { // common case. + absl::StrAppend(&tmp, from, "-", to - 1, "/"); // content-range is inclusive. + if (total >= 0) { + absl::StrAppend(&tmp, total); + } else { + tmp.push_back('*'); + } + } else { + // We can write empty ranges only when we finalize the file and total is known. + DCHECK_GE(total, 0); + absl::StrAppend(&tmp, "*/", total); + } + + return tmp; +} + +// File handle that writes to GCS. +// +// This uses multipart uploads, where it will buffer upto the configured part +// size before uploading. +class GcsWriteFile : public io::WriteFile { + public: + // Writes bytes to the GCS object. This will either buffer internally or + // write a part to GCS. + io::Result WriteSome(const iovec* v, uint32_t len) override; + + // Closes the object and completes the multipart upload. Therefore the object + // will not be uploaded unless Close is called. 
+ error_code Close() override; + + GcsWriteFile(const string_view key, string_view upload_id, size_t part_size, + http::ClientPool* pool, GCPCredsProvider* creds_provider); + + private: + error_code FillBuf(const uint8* buffer, size_t length); + error_code Upload(); + + using UploadRequest = detail::DynamicBodyRequestImpl; + unique_ptr PrepareRequest(size_t to, ssize_t total); + + string upload_id_; + multi_buffer body_mb_; + size_t uploaded_ = 0; + http::ClientPool* pool_; + GCPCredsProvider* creds_provider_; +}; + +GcsWriteFile::GcsWriteFile(string_view key, string_view upload_id, size_t part_size, + http::ClientPool* pool, GCPCredsProvider* creds_provider) + : io::WriteFile(key), upload_id_(upload_id), body_mb_(part_size), pool_(pool), + creds_provider_(creds_provider) { +} + +io::Result GcsWriteFile::WriteSome(const iovec* v, uint32_t len) { + size_t total = 0; + for (uint32_t i = 0; i < len; ++i) { + RETURN_UNEXPECTED(FillBuf(reinterpret_cast(v[i].iov_base), v[i].iov_len)); + total += v[i].iov_len; + } + return total; +} + +error_code GcsWriteFile::Close() { + size_t to = uploaded_ + body_mb_.size(); + auto req = PrepareRequest(to, to); + + string body; + if (!absl::GetFlag(FLAGS_gcs_dry_upload)) { + RobustSender sender(3, creds_provider_); + auto client_handle = pool_->GetHandle(); + io::Result res = sender.Send(client_handle.get(), req.get()); + if (!res) { + LOG(ERROR) << "Error closing GCS file " << create_file_name() << " for request: \n" + << req->GetHeaders() << ", status " << res.error().message(); + return res.error(); + } + HeaderParserPtr head_parser = std::move(*res); + h2::response_parser resp(std::move(*head_parser)); + auto ec = client_handle->Recv(&resp); + if (ec) + return ec; + body = std::move(resp.get().body()); + + /* + body is in a json response with all the metadata of the object + { + "kind": "storage#object", + "id": "mybucket/roman/bar_0/1720889888465538", + "selfLink": "https://www.googleapis.com/storage/v1/b/mybucket/o/roman%2Fbar_0", +
"mediaLink": + "https://storage.googleapis.com/download/storage/v1/b/mybucket/o/roman%2Fbar_0?generation=1720889888465538&alt=media", + "name": "roman/bar_0", + "bucket": "mybucket", + "generation": "1720889888465538", + "metageneration": "1", + "storageClass": "STANDARD", + "size": "27270144", + "md5Hash": "O/P7e3k8qxRQaomHhn0H9Q==", + "crc32c": "8s2ltw==", + "etag": "CILF/7O+pIcDEAE=", + "timeCreated": "2024-07-13T16:58:08.476Z", + "updated": "2024-07-13T16:58:08.476Z", + "timeStorageClassUpdated": "2024-07-13T16:58:08.476Z" + } + + */ + } + + VLOG(1) << "Closed file " << req->GetHeaders() << "\n" << body; + + return {}; +} + +error_code GcsWriteFile::FillBuf(const uint8* buffer, size_t length) { + while (length >= body_mb_.max_size() - body_mb_.size()) { + size_t prepare_size = body_mb_.max_size() - body_mb_.size(); + auto mbs = body_mb_.prepare(prepare_size); + size_t offs = 0; + for (auto mb : mbs) { + memcpy(mb.data(), buffer + offs, mb.size()); + offs += mb.size(); + } + DCHECK_EQ(offs, prepare_size); + body_mb_.commit(prepare_size); + + auto ec = Upload(); + if (ec) + return ec; + + length -= prepare_size; + buffer += prepare_size; + } + + if (length) { + auto mbs = body_mb_.prepare(length); + for (auto mb : mbs) { + memcpy(mb.data(), buffer, mb.size()); + buffer += mb.size(); + } + body_mb_.commit(length); + } + return {}; +} + +error_code GcsWriteFile::Upload() { + size_t body_size = body_mb_.size(); + CHECK_GT(body_size, 0u); + CHECK_EQ(0u, body_size % (1U << 18)) << body_size; // Must be multiple of 256KB. + + size_t to = uploaded_ + body_size; + + auto req = PrepareRequest(to, -1); + + error_code res; + if (!absl::GetFlag(FLAGS_gcs_dry_upload)) { + // TODO: RobustSender must access the entire pool, not just a single client. 
+ RobustSender sender(3, creds_provider_); + auto client_handle = pool_->GetHandle(); + io::Result res = sender.Send(client_handle.get(), req.get()); + if (!res) + return res.error(); + + VLOG(1) << "Uploaded range " << uploaded_ << "/" << to << " for " << upload_id_; + HeaderParserPtr parser_ptr = std::move(*res); + const auto& resp_msg = parser_ptr->get(); + auto it = resp_msg.find(h2::field::range); + CHECK(it != resp_msg.end()) << resp_msg; + + string_view range = detail::FromBoostSV(it->value()); + CHECK(absl::ConsumePrefix(&range, "bytes=")); + size_t pos = range.find('-'); + CHECK_LT(pos, range.size()); + size_t uploaded_pos = 0; + CHECK(absl::SimpleAtoi(range.substr(pos + 1), &uploaded_pos)); + CHECK_EQ(uploaded_pos + 1, to); + } + + uploaded_ = to; + return {}; +} + +auto GcsWriteFile::PrepareRequest(size_t to, ssize_t total) -> unique_ptr { + unique_ptr upload_req(new UploadRequest(upload_id_)); + + upload_req->SetBody(std::move(body_mb_)); + upload_req->SetHeader(h2::field::content_range, ContentRangeHeader(uploaded_, to, total)); + upload_req->SetHeader(h2::field::content_type, http::kBinMime); + upload_req->Finalize(); + + return upload_req; +} } // namespace -io::Result GcsWriteFile::Open(const string& bucket, const string& key, - GCPCredsProvider* creds_provider, - http::ClientPool* pool, size_t part_size) { +io::Result OpenWriteGcsFile(const string& bucket, const string& key, + GCPCredsProvider* creds_provider, + http::ClientPool* pool, size_t part_size) { string url = "/upload/storage/v1/b/"; absl::StrAppend(&url, bucket, "/o?uploadType=resumable&name="); strings::AppendUrlEncoded(key, &url); string token = creds_provider->access_token(); - auto req = PrepareRequest(h2::verb::post, url, token); - string upload_id; -#if 0 - ApiSenderDynamicBody sender("start_write", gce, pool); - auto res = sender.SendGeneric(3, std::move(req)); - if (!res.ok()) - return res.status; - - const auto& resp = sender.parser()->get(); + detail::EmptyRequestImpl 
empty_req(h2::verb::post, url, token); + empty_req.Finalize(); // it's post request so it's required. - // HttpsClientPool::ClientHandle handle = std::move(res.obj); - - auto it = resp.find(h2::field::location); - if (it == resp.end()) { - return Status(StatusCode::PARSE_ERROR, "Can not find location header"); + RobustSender sender(3, creds_provider); + auto client_handle = pool->GetHandle(); + io::Result res = sender.Send(client_handle.get(), &empty_req); + if (!res) { + return nonstd::make_unexpected(res.error()); } - string upload_id = string(it->value()); - -#endif + HeaderParserPtr parser_ptr = std::move(*res); + const auto& headers = parser_ptr->get(); + auto it = headers.find(h2::field::location); + if (it == headers.end()) { + LOG(ERROR) << "Could not find location in " << headers; + return nonstd::make_unexpected(make_error_code(errc::connection_refused)); + } - return new GcsWriteFile(key, upload_id, part_size, pool); + return new GcsWriteFile(key, detail::FromBoostSV(it->value()), part_size, pool, creds_provider); } } // namespace cloud diff --git a/util/cloud/gcp/gcs_file.h b/util/cloud/gcp/gcs_file.h index 8b0ab968..c6bffe93 100644 --- a/util/cloud/gcp/gcs_file.h +++ b/util/cloud/gcp/gcs_file.h @@ -4,39 +4,19 @@ #pragma once #include "io/file.h" -#include "util/http/https_client_pool.h" #include "util/cloud/gcp/gcp_creds_provider.h" +#include "util/http/https_client_pool.h" namespace util { namespace cloud { -// File handle that writes to GCS. -// -// This uses multipart uploads, where it will buffer upto the configured part -// size before uploading. -class GcsWriteFile : public io::WriteFile { - public: - static constexpr size_t kDefaultPartSize = 1ULL << 23; // 8MB. - - // Writes bytes to the GCS object. This will either buffer internally or - // write a part to GCS. - io::Result WriteSome(const iovec* v, uint32_t len) override; - - // Closes the object and completes the multipart upload. 
Therefore the object - // will not be uploaded unless Close is called. - std::error_code Close() override; - - static io::Result Open(const std::string& bucket, const std::string& key, - GCPCredsProvider* creds_provider, - http::ClientPool* pool, size_t part_size = kDefaultPartSize); - - private: - GcsWriteFile(const std::string& key, const std::string& upload_id, - size_t part_size, http::ClientPool* pool); +static constexpr size_t kDefaultGCPPartSize = 1ULL << 23; // 8MB. - std::string upload_id_; -}; +io::Result OpenWriteGcsFile(const std::string& bucket, const std::string& key, + GCPCredsProvider* creds_provider, + http::ClientPool* pool, + size_t part_size = kDefaultGCPPartSize); } // namespace cloud } // namespace util \ No newline at end of file diff --git a/util/tls/tls_socket.cc b/util/tls/tls_socket.cc index a54375f4..53c40447 100644 --- a/util/tls/tls_socket.cc +++ b/util/tls/tls_socket.cc @@ -327,6 +327,8 @@ io::Result TlsSocket::WriteSome(const iovec* ptr, uint32_t len) { } io::Result TlsSocket::SendBuffer(Engine::Buffer buf) { + DVLOG(2) << "TlsSocket::SendBuffer " << buf.size() << " bytes"; + // Sending buffer into ssl. DCHECK(engine_); DCHECK_GT(buf.size(), 0u); @@ -364,6 +366,13 @@ io::Result TlsSocket::SendBuffer(Engine::Buffer buf) { return make_unexpected(ec); } + // Usually we want to batch writes as much as possible, but here we cannot know if more writes + // will follow. We must flush the output buffer, so that data will be sent down the socket. + error_code ec = MaybeSendOutput(); + if (ec) { + return make_unexpected(ec); + } + return send_total; } @@ -436,6 +445,7 @@ error_code TlsSocket::HandleSocketWrite() { if (buffer.empty()) return {}; + DVLOG(2) << "HandleSocketWrite " << buffer.size(); // we do not allow concurrent writes from multiple fibers. state_ |= WRITE_IN_PROGRESS; while (!buffer.empty()) {