Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: gcs write file #297

Merged
merged 1 commit into from
Jul 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions examples/gcs_demo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@
#include "base/flags.h"
#include "base/init.h"
#include "base/logging.h"
#include "io/file_util.h"
#include "util/cloud/gcp/gcs.h"
#include "util/cloud/gcp/gcs_file.h"
#include "util/fibers/pool.h"

using namespace std;
using namespace boost;
using namespace util;

using absl::GetFlag;

ABSL_FLAG(string, bucket, "", "");
ABSL_FLAG(string, prefix, "", "");

ABSL_FLAG(bool, write, false, "");
ABSL_FLAG(uint32_t, connect_ms, 2000, "");
ABSL_FLAG(bool, epoll, false, "Whether to use epoll instead of io_uring");

Expand All @@ -32,10 +33,27 @@ void Run(SSL_CTX* ctx) {

string prefix = GetFlag(FLAGS_prefix);
if (!prefix.empty()) {
auto cb = [](cloud::GCS::ObjectItem item) {
cout << "Object: " << item.key << ", size: " << item.size << endl;
};
ec = gcs.List(GetFlag(FLAGS_bucket), prefix, false, cb);
string bucket = GetFlag(FLAGS_bucket);
auto conn_pool = gcs.CreateConnectionPool();
CHECK(!bucket.empty());

if (GetFlag(FLAGS_write)) {
auto src = io::ReadFileToString("/proc/self/exe");
CHECK(src);
io::Result<io::WriteFile*> dest_res =
cloud::OpenWriteGcsFile(bucket, prefix, &provider, conn_pool.get());
CHECK(dest_res);
unique_ptr<io::WriteFile> dest(*dest_res);
error_code ec = dest->Write(*src);
CHECK(!ec);
ec = dest->Close();
CHECK(!ec);
} else {
auto cb = [](cloud::GCS::ObjectItem item) {
cout << "Object: " << item.key << ", size: " << item.size << endl;
};
ec = gcs.List(GetFlag(FLAGS_bucket), prefix, false, cb);
}
} else {
auto cb = [](std::string_view bname) { CONSOLE_INFO << bname; };

Expand Down
58 changes: 35 additions & 23 deletions util/cloud/gcp/gcp_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@
#include "base/logging.h"
#include "util/cloud/gcp/gcp_creds_provider.h"

#define RETURN_UNEXPECTED(x) \
do { \
auto ec = (x); \
if (ec) \
return nonstd::make_unexpected(ec); \
} while (false)

namespace util::cloud {
using namespace std;
namespace h2 = boost::beast::http;
Expand All @@ -37,34 +30,54 @@ inline bool DoesServerPushback(h2::status st) {
h2::to_status_class(st) == h2::status_class::server_error;
}

constexpr auto kResumeIncomplete = h2::status::permanent_redirect;

bool IsResponseOK(h2::status st) {
// Partial content can appear because of the previous reconnect.
// For multipart uploads kResumeIncomplete can be returned.
return st == h2::status::ok || st == h2::status::partial_content || st == kResumeIncomplete;
}

} // namespace

const char GCP_API_DOMAIN[] = "www.googleapis.com";
const char GCS_API_DOMAIN[] = "storage.googleapis.com";

string AuthHeader(string_view access_token) {
return absl::StrCat("Bearer ", access_token);
}

EmptyRequest PrepareRequest(h2::verb req_verb, std::string_view url,
const string_view access_token) {
EmptyRequest req{req_verb, boost::beast::string_view{url.data(), url.size()}, 11};
req.set(h2::field::host, GCP_API_DOMAIN);
req.set(h2::field::authorization, AuthHeader(access_token));
req.keep_alive(true);
namespace detail {

EmptyRequestImpl::EmptyRequestImpl(h2::verb req_verb, std::string_view url,
const string_view access_token)
: req_{req_verb, boost::beast::string_view{url.data(), url.size()}, 11} {
req_.set(h2::field::host, GCS_API_DOMAIN);
req_.set(h2::field::authorization, AuthHeader(access_token));
// ? req_.keep_alive(true);
}

std::error_code EmptyRequestImpl::Send(http::Client* client) {
return client->Send(req_);
}

return req;
std::error_code DynamicBodyRequestImpl::Send(http::Client* client) {
return client->Send(req_);
}

} // namespace detail

RobustSender::RobustSender(unsigned num_iterations, GCPCredsProvider* provider)
: num_iterations_(num_iterations), provider_(provider) {
}

auto RobustSender::Send(http::Client* client, EmptyRequest* req) -> io::Result<HeaderParserPtr> {
auto RobustSender::Send(http::Client* client,
detail::HttpRequestBase* req) -> io::Result<HeaderParserPtr> {
error_code ec;
for (unsigned i = 0; i < num_iterations_; ++i) { // Iterate for possible token refresh.
VLOG(1) << "HttpReq" << i << ": " << *req << ", socket " << client->native_handle();
VLOG(1) << "HttpReq " << client->host() << ": " << req->GetHeaders() << ", ["
<< client->native_handle() << "]";

RETURN_UNEXPECTED(client->Send(*req));
RETURN_UNEXPECTED(req->Send(client));
HeaderParserPtr parser(new h2::response_parser<h2::empty_body>());
RETURN_UNEXPECTED(client->ReadHeader(parser.get()));
{
Expand All @@ -75,11 +88,11 @@ auto RobustSender::Send(http::Client* client, EmptyRequest* req) -> io::Result<H
LOG(FATAL) << "TBD: Schedule reconnect due to conn-close header";
}

// Partial content can appear because of the previous reconnect.
if (msg.result() == h2::status::ok || msg.result() == h2::status::partial_content) {
if (IsResponseOK(msg.result())) {
return parser;
}
}

// We have some kind of error, possibly with body that needs to be drained.
h2::response_parser<h2::string_body> drainer(std::move(*parser));
RETURN_UNEXPECTED(client->Recv(&drainer));
Expand All @@ -88,14 +101,13 @@ auto RobustSender::Send(http::Client* client, EmptyRequest* req) -> io::Result<H
if (DoesServerPushback(msg.result())) {
LOG(INFO) << "Retrying(" << client->native_handle() << ") with " << msg;

ThisFiber::SleepFor(1s);
i = 0; // Can potentially deadlock
ThisFiber::SleepFor(100ms);
continue;
}

if (IsUnauthorized(msg)) {
RETURN_UNEXPECTED(provider_->RefreshToken(client->proactor()));
req->set(h2::field::authorization, AuthHeader(provider_->access_token()));
req->SetHeader(h2::field::authorization, AuthHeader(provider_->access_token()));

continue;
}
Expand Down
94 changes: 88 additions & 6 deletions util/cloud/gcp/gcp_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,110 @@

namespace util::cloud {
class GCPCredsProvider;
extern const char GCS_API_DOMAIN[];

extern const char GCP_API_DOMAIN[];
namespace detail {
inline std::string_view FromBoostSV(boost::string_view sv) {
return std::string_view(sv.data(), sv.size());
}

using EmptyRequest = boost::beast::http::request<boost::beast::http::empty_body>;
class HttpRequestBase {
public:
HttpRequestBase(const HttpRequestBase&) = delete;
HttpRequestBase& operator=(const HttpRequestBase&) = delete;
HttpRequestBase() = default;

EmptyRequest PrepareRequest(boost::beast::http::verb req_verb, std::string_view url,
const std::string_view access_token);
virtual ~HttpRequestBase() = default;
virtual std::error_code Send(http::Client* client) = 0;

std::string AuthHeader(std::string_view access_token);
virtual const boost::beast::http::header<true>& GetHeaders() const = 0;

virtual void SetHeader(boost::beast::http::field f, std::string_view value) = 0;
};

class EmptyRequestImpl : public HttpRequestBase {
using EmptyRequest = boost::beast::http::request<boost::beast::http::empty_body>;
EmptyRequest req_;

public:
EmptyRequestImpl(boost::beast::http::verb req_verb, std::string_view url,
const std::string_view access_token);

void SetUrl(std::string_view url) {
req_.target(boost::string_view{url.data(), url.size()});
}

void Finalize() {
req_.prepare_payload();
}

void SetHeader(boost::beast::http::field f, std::string_view value) final {
req_.set(f, boost::string_view{value.data(), value.size()});
}

const boost::beast::http::header<true>& GetHeaders() const final {
return req_.base();
}

std::error_code Send(http::Client* client) final;
};

class DynamicBodyRequestImpl : public HttpRequestBase {
using DynamicBodyRequest = boost::beast::http::request<boost::beast::http::dynamic_body>;
DynamicBodyRequest req_;

public:
DynamicBodyRequestImpl(DynamicBodyRequestImpl&&) = default;

explicit DynamicBodyRequestImpl(std::string_view url)
: req_(boost::beast::http::verb::post, boost::string_view{url.data(), url.size()}, 11) {
}

template<typename BodyArgs> void SetBody(BodyArgs&& body_args) {
req_.body() = std::forward<BodyArgs>(body_args);
}

void SetHeader(boost::beast::http::field f, std::string_view value) final {
req_.set(f, boost::string_view{value.data(), value.size()});
}

void Finalize() {
req_.prepare_payload();
}

const boost::beast::http::header<true>& GetHeaders() const final {
return req_.base();
}

std::error_code Send(http::Client* client) final;
};

} // namespace detail

class RobustSender {
RobustSender(const RobustSender&) = delete;
RobustSender& operator=(const RobustSender&) = delete;

public:
using HeaderParserPtr =
std::unique_ptr<boost::beast::http::response_parser<boost::beast::http::empty_body>>;

RobustSender(unsigned num_iterations, GCPCredsProvider* provider);

io::Result<HeaderParserPtr> Send(http::Client* client, EmptyRequest* req);
io::Result<HeaderParserPtr> Send(http::Client* client, detail::HttpRequestBase* req);

private:
unsigned num_iterations_;
GCPCredsProvider* provider_;
};

std::string AuthHeader(std::string_view access_token);

#define RETURN_UNEXPECTED(x) \
do { \
auto ec = (x); \
if (ec) \
return nonstd::make_unexpected(ec); \
} while (false)

} // namespace util::cloud
Loading
Loading