Skip to content

Commit

Permalink
chore: gcs write file (#297)
Browse files Browse the repository at this point in the history
Signed-off-by: Roman Gershman <[email protected]>
  • Loading branch information
romange committed Jul 13, 2024
1 parent 281d740 commit 2290be2
Show file tree
Hide file tree
Showing 8 changed files with 460 additions and 110 deletions.
35 changes: 29 additions & 6 deletions examples/gcs_demo.cc
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
// Copyright 2024, Roman Gershman. All rights reserved.
// See LICENSE for licensing terms.

#include <absl/strings/str_cat.h>
#include "base/flags.h"
#include "base/init.h"
#include "base/logging.h"
#include "io/file_util.h"
#include "util/cloud/gcp/gcs.h"
#include "util/cloud/gcp/gcs_file.h"
#include "util/fibers/pool.h"

using namespace std;
using namespace boost;
using namespace util;

using absl::GetFlag;

// Name of the GCS bucket to operate on; checked to be non-empty when --prefix is given.
ABSL_FLAG(string, bucket, "", "");
// Object-name prefix: used for listing, or as the stem of the keys written when --write > 0.
ABSL_FLAG(string, prefix, "", "");

// Number of objects to upload; each one receives the contents of /proc/self/exe.
ABSL_FLAG(uint32_t, write, 0, "");
// NOTE(review): presumably a connect timeout in milliseconds - its use is not visible here.
ABSL_FLAG(uint32_t, connect_ms, 2000, "");
ABSL_FLAG(bool, epoll, false, "Whether to use epoll instead of io_uring");

Expand All @@ -32,10 +34,31 @@ void Run(SSL_CTX* ctx) {

string prefix = GetFlag(FLAGS_prefix);
if (!prefix.empty()) {
auto cb = [](cloud::GCS::ObjectItem item) {
cout << "Object: " << item.key << ", size: " << item.size << endl;
};
ec = gcs.List(GetFlag(FLAGS_bucket), prefix, false, cb);
string bucket = GetFlag(FLAGS_bucket);
auto conn_pool = gcs.CreateConnectionPool();
CHECK(!bucket.empty());

if (GetFlag(FLAGS_write) > 0) {
auto src = io::ReadFileToString("/proc/self/exe");
CHECK(src);
for (unsigned i = 0; i < GetFlag(FLAGS_write); ++i) {
string dest_key = absl::StrCat(prefix, "_", i);
io::Result<io::WriteFile*> dest_res =
cloud::OpenWriteGcsFile(bucket, dest_key, &provider, conn_pool.get());
CHECK(dest_res) << "Could not open " << dest_key << " " << dest_res.error().message();
unique_ptr<io::WriteFile> dest(*dest_res);
error_code ec = dest->Write(*src);
CHECK(!ec);
ec = dest->Close();
CHECK(!ec);
CONSOLE_INFO << "Written " << dest_key;
}
} else {
auto cb = [](cloud::GCS::ObjectItem item) {
cout << "Object: " << item.key << ", size: " << item.size << endl;
};
ec = gcs.List(GetFlag(FLAGS_bucket), prefix, false, cb);
}
} else {
auto cb = [](std::string_view bname) { CONSOLE_INFO << bname; };

Expand Down
58 changes: 35 additions & 23 deletions util/cloud/gcp/gcp_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@
#include "base/logging.h"
#include "util/cloud/gcp/gcp_creds_provider.h"

// Evaluates `x` once and, when the result converts to true, returns it from the
// enclosing function wrapped by nonstd::make_unexpected.
// NOTE(review): this copy looks like the pre-commit definition that the diff view
// shows as removed; gcp_utils.h carries a variant with extra logging - confirm
// which one is live.
#define RETURN_UNEXPECTED(x) \
do { \
auto ec = (x); \
if (ec) \
return nonstd::make_unexpected(ec); \
} while (false)

namespace util::cloud {
using namespace std;
namespace h2 = boost::beast::http;
Expand All @@ -37,34 +30,54 @@ inline bool DoesServerPushback(h2::status st) {
h2::to_status_class(st) == h2::status_class::server_error;
}

// Status reported while a multipart upload is still in progress. It shares its
// numeric code with the status Beast names permanent_redirect.
constexpr auto kResumeIncomplete = h2::status::permanent_redirect;

// Tells whether `st` is one of the statuses the sender accepts as a success.
bool IsResponseOK(h2::status st) {
  switch (st) {
    case h2::status::ok:
    case h2::status::partial_content:  // Can appear because of the previous reconnect.
    case kResumeIncomplete:            // Can be returned for multipart uploads.
      return true;
    default:
      return false;
  }
}

} // namespace

// Host name "www.googleapis.com".
// NOTE(review): the diff appears to replace this constant with GCS_API_DOMAIN - confirm it is still needed.
const char GCP_API_DOMAIN[] = "www.googleapis.com";
// Host name of the storage endpoint; EmptyRequestImpl sends it as the Host header.
const char GCS_API_DOMAIN[] = "storage.googleapis.com";

// Produces the value of an Authorization header: the "Bearer " scheme followed
// by `access_token`.
string AuthHeader(string_view access_token) {
  constexpr string_view kScheme = "Bearer ";
  return absl::StrCat(kScheme, access_token);
}

EmptyRequest PrepareRequest(h2::verb req_verb, std::string_view url,
const string_view access_token) {
EmptyRequest req{req_verb, boost::beast::string_view{url.data(), url.size()}, 11};
req.set(h2::field::host, GCP_API_DOMAIN);
req.set(h2::field::authorization, AuthHeader(access_token));
req.keep_alive(true);
namespace detail {

// Creates a body-less request with verb `req_verb` and target `url` (HTTP version
// field 11), pre-populated with the Host header for the storage endpoint and an
// Authorization header derived from `access_token`.
EmptyRequestImpl::EmptyRequestImpl(h2::verb req_verb, std::string_view url,
                                   const string_view access_token)
    : req_{req_verb, boost::beast::string_view{url.data(), url.size()}, 11} {
  const string auth_value = AuthHeader(access_token);

  // Insertion order is kept as host first, authorization second.
  req_.set(h2::field::host, GCS_API_DOMAIN);
  req_.set(h2::field::authorization, auth_value);
  // ? req_.keep_alive(true);
}

// Hands the wrapped request to `client` and reports the outcome of the send.
std::error_code EmptyRequestImpl::Send(http::Client* client) {
  std::error_code send_ec = client->Send(req_);
  return send_ec;
}

return req;
// Hands the wrapped request (headers and body) to `client` and reports the
// outcome of the send.
std::error_code DynamicBodyRequestImpl::Send(http::Client* client) {
  std::error_code send_ec = client->Send(req_);
  return send_ec;
}

} // namespace detail

RobustSender::RobustSender(unsigned num_iterations, GCPCredsProvider* provider)
: num_iterations_(num_iterations), provider_(provider) {
}

auto RobustSender::Send(http::Client* client, EmptyRequest* req) -> io::Result<HeaderParserPtr> {
auto RobustSender::Send(http::Client* client,
detail::HttpRequestBase* req) -> io::Result<HeaderParserPtr> {
error_code ec;
for (unsigned i = 0; i < num_iterations_; ++i) { // Iterate for possible token refresh.
VLOG(1) << "HttpReq" << i << ": " << *req << ", socket " << client->native_handle();
VLOG(1) << "HttpReq " << client->host() << ": " << req->GetHeaders() << ", ["
<< client->native_handle() << "]";

RETURN_UNEXPECTED(client->Send(*req));
RETURN_UNEXPECTED(req->Send(client));
HeaderParserPtr parser(new h2::response_parser<h2::empty_body>());
RETURN_UNEXPECTED(client->ReadHeader(parser.get()));
{
Expand All @@ -75,11 +88,11 @@ auto RobustSender::Send(http::Client* client, EmptyRequest* req) -> io::Result<H
LOG(FATAL) << "TBD: Schedule reconnect due to conn-close header";
}

// Partial content can appear because of the previous reconnect.
if (msg.result() == h2::status::ok || msg.result() == h2::status::partial_content) {
if (IsResponseOK(msg.result())) {
return parser;
}
}

// We have some kind of error, possibly with body that needs to be drained.
h2::response_parser<h2::string_body> drainer(std::move(*parser));
RETURN_UNEXPECTED(client->Recv(&drainer));
Expand All @@ -88,14 +101,13 @@ auto RobustSender::Send(http::Client* client, EmptyRequest* req) -> io::Result<H
if (DoesServerPushback(msg.result())) {
LOG(INFO) << "Retrying(" << client->native_handle() << ") with " << msg;

ThisFiber::SleepFor(1s);
i = 0; // Can potentially deadlock
ThisFiber::SleepFor(100ms);
continue;
}

if (IsUnauthorized(msg)) {
RETURN_UNEXPECTED(provider_->RefreshToken(client->proactor()));
req->set(h2::field::authorization, AuthHeader(provider_->access_token()));
req->SetHeader(h2::field::authorization, AuthHeader(provider_->access_token()));

continue;
}
Expand Down
96 changes: 90 additions & 6 deletions util/cloud/gcp/gcp_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,112 @@

namespace util::cloud {
class GCPCredsProvider;
extern const char GCS_API_DOMAIN[];

extern const char GCP_API_DOMAIN[];
namespace detail {
// Converts a boost::string_view into a std::string_view over the same characters.
// No data is copied; the result refers to the memory `sv` refers to.
inline std::string_view FromBoostSV(boost::string_view sv) {
  const char* first = sv.data();
  const auto length = sv.size();
  return {first, length};
}

using EmptyRequest = boost::beast::http::request<boost::beast::http::empty_body>;
// Abstract interface for an outgoing HTTP request: it can be sent through an
// http::Client and exposes its header block for reading and modification.
// Copying is disabled; because the copy operations are user-declared, no move
// operations are implicitly generated either.
class HttpRequestBase {
public:
HttpRequestBase(const HttpRequestBase&) = delete;
HttpRequestBase& operator=(const HttpRequestBase&) = delete;
HttpRequestBase() = default;

// NOTE(review): the next declaration reads like a line removed by this commit that
// the diff view interleaved here - confirm against the real header.
EmptyRequest PrepareRequest(boost::beast::http::verb req_verb, std::string_view url,
const std::string_view access_token);
virtual ~HttpRequestBase() = default;
// Sends the request through `client` and returns the resulting error code.
virtual std::error_code Send(http::Client* client) = 0;

// NOTE(review): likewise apparently a leftover of the free-function declaration
// that the commit moved further down the header - confirm.
std::string AuthHeader(std::string_view access_token);
// Read-only access to the request's header block.
virtual const boost::beast::http::header<true>& GetHeaders() const = 0;

// Sets header field `f` to `value`.
virtual void SetHeader(boost::beast::http::field f, std::string_view value) = 0;
};

class EmptyRequestImpl : public HttpRequestBase {
using EmptyRequest = boost::beast::http::request<boost::beast::http::empty_body>;
EmptyRequest req_;

public:
EmptyRequestImpl(boost::beast::http::verb req_verb, std::string_view url,
const std::string_view access_token);

void SetUrl(std::string_view url) {
req_.target(boost::string_view{url.data(), url.size()});
}

void Finalize() {
req_.prepare_payload();
}

void SetHeader(boost::beast::http::field f, std::string_view value) final {
req_.set(f, boost::string_view{value.data(), value.size()});
}

const boost::beast::http::header<true>& GetHeaders() const final {
return req_.base();
}

std::error_code Send(http::Client* client) final;
};

class DynamicBodyRequestImpl : public HttpRequestBase {
using DynamicBodyRequest = boost::beast::http::request<boost::beast::http::dynamic_body>;
DynamicBodyRequest req_;

public:
DynamicBodyRequestImpl(DynamicBodyRequestImpl&&) = default;

explicit DynamicBodyRequestImpl(std::string_view url)
: req_(boost::beast::http::verb::post, boost::string_view{url.data(), url.size()}, 11) {
}

template <typename BodyArgs> void SetBody(BodyArgs&& body_args) {
req_.body() = std::forward<BodyArgs>(body_args);
}

void SetHeader(boost::beast::http::field f, std::string_view value) final {
req_.set(f, boost::string_view{value.data(), value.size()});
}

void Finalize() {
req_.prepare_payload();
}

const boost::beast::http::header<true>& GetHeaders() const final {
return req_.base();
}

std::error_code Send(http::Client* client) final;
};

} // namespace detail

// Sends HTTP requests through an http::Client inside a loop bounded by
// `num_iterations`, so that a request can be issued again (the visible part of the
// implementation retries on server push-back and after refreshing the access
// token through the credentials provider). Not copyable.
class RobustSender {
RobustSender(const RobustSender&) = delete;
RobustSender& operator=(const RobustSender&) = delete;

public:
// Owning pointer to a Beast response parser instantiated with empty_body.
using HeaderParserPtr =
std::unique_ptr<boost::beast::http::response_parser<boost::beast::http::empty_body>>;

// `num_iterations` bounds the send loop; `provider` is stored as a raw pointer.
RobustSender(unsigned num_iterations, GCPCredsProvider* provider);

// NOTE(review): the next overload looks like the pre-commit signature that the diff
// view shows as removed - confirm against the real header.
io::Result<HeaderParserPtr> Send(http::Client* client, EmptyRequest* req);
// Sends `req` through `client`. When the response status is accepted, returns the
// parser holding the response headers; send/read failures are returned as errors.
io::Result<HeaderParserPtr> Send(http::Client* client, detail::HttpRequestBase* req);

private:
unsigned num_iterations_;
GCPCredsProvider* provider_;
};

std::string AuthHeader(std::string_view access_token);

// Evaluates `x` exactly once. When the result converts to true (an error is set),
// logs the failed expression with its message at VLOG(1) and returns the error
// from the enclosing function wrapped by nonstd::make_unexpected; otherwise
// execution continues after the macro.
//
// The temporary has a deliberately unusual name: callers commonly have a local
// called `ec`, and a temporary with that name would shadow it and make
// RETURN_UNEXPECTED(ec) ill-formed (the variable would be used inside its own
// `auto` initializer).
#define RETURN_UNEXPECTED(x)                                               \
  do {                                                                     \
    auto return_unexpected_ec = (x);                                       \
    if (return_unexpected_ec) {                                            \
      VLOG(1) << "Failed " << #x << ": " << return_unexpected_ec.message(); \
      return nonstd::make_unexpected(return_unexpected_ec);                \
    }                                                                      \
  } while (false)

} // namespace util::cloud
Loading

0 comments on commit 2290be2

Please sign in to comment.