Skip to content

Commit

Permalink
chore: fixes in tls and client sockets
Browse files Browse the repository at this point in the history
Signed-off-by: Roman Gershman <[email protected]>
  • Loading branch information
romange committed Jul 12, 2024
1 parent bd38683 commit eb166e1
Show file tree
Hide file tree
Showing 11 changed files with 119 additions and 65 deletions.
7 changes: 4 additions & 3 deletions util/fiber_socket_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ class FiberSocketBase : public io::Sink, public io::AsyncSink, public io::Source

ABSL_MUST_USE_RESULT virtual AcceptResult Accept() = 0;

ABSL_MUST_USE_RESULT virtual error_code Connect(const endpoint_type& ep) = 0;
ABSL_MUST_USE_RESULT virtual error_code Connect(const endpoint_type& ep,
std::function<void(int)> on_pre_connect = {}) = 0;

ABSL_MUST_USE_RESULT virtual error_code Close() = 0;

Expand Down Expand Up @@ -200,8 +201,8 @@ class LinuxSocketBase : public FiberSocketBase {
// gives me 256M descriptors.
int32_t fd_;

private:
uint32_t timeout_ = UINT32_MAX;
private:
uint32_t timeout_ = UINT32_MAX;
};

void SetNonBlocking(int fd);
Expand Down
6 changes: 4 additions & 2 deletions util/fibers/epoll_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ auto EpollSocket::Accept() -> AcceptResult {
return nonstd::make_unexpected(ec);
}

auto EpollSocket::Connect(const endpoint_type& ep) -> error_code {
error_code EpollSocket::Connect(const endpoint_type& ep, std::function<void(int)> on_pre_connect) {
CHECK_EQ(fd_, -1);
CHECK(proactor() && proactor()->InMyThread());

Expand All @@ -208,7 +208,9 @@ auto EpollSocket::Connect(const endpoint_type& ep) -> error_code {
write_context_ = detail::FiberActive();
absl::Cleanup clean = [this]() { write_context_ = nullptr; };

// RegisterEvents(GetProactor()->ev_loop_fd(), fd, arm_index_ + 1024);
if (on_pre_connect) {
on_pre_connect(fd);
}

DVSOCK(2) << "Connecting";

Expand Down
3 changes: 2 additions & 1 deletion util/fibers/epoll_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ class EpollSocket : public LinuxSocketBase {

ABSL_MUST_USE_RESULT AcceptResult Accept() final;

ABSL_MUST_USE_RESULT error_code Connect(const endpoint_type& ep) final;
ABSL_MUST_USE_RESULT error_code Connect(const endpoint_type& ep,
std::function<void(int)> on_pre_connect) final;
ABSL_MUST_USE_RESULT error_code Close() final;

// Really need here expected.
Expand Down
11 changes: 6 additions & 5 deletions util/fibers/uring_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ auto UringSocket::Accept() -> AcceptResult {
return fs;
}

auto UringSocket::Connect(const endpoint_type& ep) -> error_code {
auto UringSocket::Connect(const endpoint_type& ep, std::function<void(int)> on_pre_connect) -> error_code {
CHECK_EQ(fd_, -1);
CHECK(proactor() && proactor()->InMyThread());

Expand All @@ -163,12 +163,13 @@ auto UringSocket::Connect(const endpoint_type& ep) -> error_code {
// TODO: support direct descriptors. For now client sockets always use regular linux fds.
fd_ = fd << kFdShift;

IoResult io_res;
ep.data();
if (on_pre_connect) {
on_pre_connect(fd);
}

FiberCall fc(proactor, timeout());
fc->PrepConnect(fd, (const sockaddr*)ep.data(), ep.size());
io_res = fc.Get();
IoResult io_res = fc.Get();

if (io_res < 0) { // In that case connect returns -errno.
ec = error_code(-io_res, system_category());
Expand Down Expand Up @@ -333,7 +334,7 @@ io::Result<size_t> UringSocket::Recv(const io::MutableBytes& mb, int flags) {
Proactor* p = GetProactor();
DCHECK(ProactorBase::me() == p);

VSOCK(2) << "Recv [" << fd << "] " << flags;
VSOCK(2) << "Recv [" << fd << "], flags: " << flags;
ssize_t res;
while (true) {
FiberCall fc(p, timeout());
Expand Down
5 changes: 3 additions & 2 deletions util/fibers/uring_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ class UringSocket : public LinuxSocketBase {

ABSL_MUST_USE_RESULT AcceptResult Accept() final;

ABSL_MUST_USE_RESULT error_code Connect(const endpoint_type& ep) final;
ABSL_MUST_USE_RESULT error_code Connect(const endpoint_type& ep,
std::function<void(int)> on_pre_connect) final;
ABSL_MUST_USE_RESULT error_code Close() final;

io::Result<size_t> WriteSome(const iovec* v, uint32_t len) override;
Expand Down Expand Up @@ -75,7 +76,7 @@ class UringSocket : public LinuxSocketBase {

struct ErrorCbRefWrapper {
uint32_t error_cb_id = 0;
uint32_t ref_count = 2; // one for the socket reference, one for the completion lambda.
uint32_t ref_count = 2; // one for the socket reference, one for the completion lambda.
std::function<void(uint32_t)> cb;

static ErrorCbRefWrapper* New(std::function<void(uint32_t)> cb) {
Expand Down
15 changes: 10 additions & 5 deletions util/http/http_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,15 @@ std::error_code Client::Reconnect() {
return berr;

FiberSocketBase* sock = proactor_->CreateSocket();
if (on_connect_cb_) {
on_connect_cb_(sock->native_handle());
}

socket_.reset(sock);
FiberSocketBase::endpoint_type ep{address, port_};
return socket_->Connect(ep);
auto on_connect = [this](int fd) {
if (on_connect_cb_) {
on_connect_cb_(fd);
}
};
return socket_->Connect(ep, std::move(on_connect));
}

#if 0
Expand Down Expand Up @@ -181,7 +184,9 @@ std::error_code TlsClient::Connect(string_view host, string_view service, SSL_CT
// verify server cert using server hostname
SSL_dane_enable(ssl_handle, host);
ec = tls_socket->Connect(FiberSocketBase::endpoint_type{});
if (!ec) {
if (ec) {
std::ignore = tls_socket->Close();
} else {
socket_.reset(tls_socket.release());
}
}
Expand Down
15 changes: 10 additions & 5 deletions util/tls/tls_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,13 @@ Engine::Engine(SSL_CTX* context) : ssl_(::SSL_new(context)) {
// SSL_set0_[rw]bio take ownership of the passed reference,
// so if we call both with the same BIO, we need the refcount to be 2.
BIO_up_ref(int_bio);

SSL_set0_rbio(ssl_, int_bio);
SSL_set0_wbio(ssl_, int_bio);

// Debugging traces.
// SSL_set_msg_callback(ssl_, SSL_trace);
// SSL_set_msg_callback_arg(ssl_, BIO_new_fp(stdout,0));
}

Engine::~Engine() {
Expand All @@ -111,21 +116,21 @@ Engine::~Engine() {
}


auto Engine::FetchOutputBuf() -> BufResult {
auto Engine::FetchOutputBuf() -> Buffer {
char* buf = nullptr;

int res = BIO_nread(external_bio_, &buf, INT_MAX);
if (res < 0) {
unsigned long error = ::ERR_get_error();
return nonstd::make_unexpected(error);
LOG(DFATAL) << "Unexpected result " << res << " " << error;

return Buffer{};
}

return Buffer(reinterpret_cast<const uint8_t*>(buf), res);
}

// TODO: to consider replacing BufResult with Buffer since
// it seems BIO_C_NREAD0 should not return negative values when used properly.
auto Engine::PeekOutputBuf() -> BufResult {
auto Engine::PeekOutputBuf() -> Buffer {
char* buf = nullptr;

long res = BIO_ctrl(external_bio_, BIO_C_NREAD0, 0, &buf);
Expand Down
5 changes: 2 additions & 3 deletions util/tls/tls_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class Engine {
// write. In any case for non-error OpResult a caller must check OutputPending and write the
// output buffer to the appropriate channel.
using OpResult = io::Result<int, unsigned long>;
using BufResult = io::Result<Buffer, unsigned long>;

// Construct a new engine for the specified context.
explicit Engine(SSL_CTX* context);
Expand Down Expand Up @@ -67,11 +66,11 @@ class Engine {
//! Returns output (read) buffer. This operation is destructive, i.e. after calling
//! this function the buffer is being consumed.
//! See OutputPending() for checking if there is a output buffer to consume.
BufResult FetchOutputBuf();
Buffer FetchOutputBuf();

//! Returns output buffer which is the read buffer of tls engine.
//! This operation is not destructive.
BufResult PeekOutputBuf();
Buffer PeekOutputBuf();

//! Tells the engine that sz bytes were consumed from the output buffer.
//! sz should be not greater than the buffer size from the last PeekOutputBuf() call.
Expand Down
13 changes: 6 additions & 7 deletions util/tls/tls_engine_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,18 +143,17 @@ static unsigned long RunPeer(SslStreamTest::Options opts, SslStreamTest::OpCb cb
if (opts.drain_output)
src->FetchOutputBuf();
else {
auto buf_result = src->PeekOutputBuf();
CHECK(buf_result);
VLOG(1) << opts.name << " wrote " << buf_result->size() << " bytes";
CHECK(!buf_result->empty());
auto buffer = src->PeekOutputBuf();
VLOG(1) << opts.name << " wrote " << buffer.size() << " bytes";
CHECK(!buffer.empty());

if (opts.mutate_indx) {
uint8_t* mem = const_cast<uint8_t*>(buf_result->data());
mem[opts.mutate_indx % buf_result->size()] = opts.mutate_val;
uint8_t* mem = const_cast<uint8_t*>(buffer.data());
mem[opts.mutate_indx % buffer.size()] = opts.mutate_val;
opts.mutate_indx = 0;
}

auto write_result = dest->WriteBuf(*buf_result);
auto write_result = dest->WriteBuf(buffer);
if (!write_result) {
return write_result.error();
}
Expand Down
98 changes: 68 additions & 30 deletions util/tls/tls_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ auto TlsSocket::Accept() -> AcceptResult {
return make_unexpected(make_error_code(errc::connection_reset));
}
if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
ec = HandleRead();
ec = HandleSocketRead();
if (ec)
return make_unexpected(ec);
}
Expand All @@ -145,19 +145,53 @@ auto TlsSocket::Accept() -> AcceptResult {
return nullptr;
}

auto TlsSocket::Connect(const endpoint_type& endpoint) -> error_code {
error_code TlsSocket::Connect(const endpoint_type& endpoint,
std::function<void(int)> on_pre_connect) {
DCHECK(engine_);
auto io_result = engine_->Handshake(Engine::HandshakeType::CLIENT);
if (!io_result.has_value()) {
return std::error_code(io_result.error(), std::system_category());
Engine::OpResult op_result = engine_->Handshake(Engine::HandshakeType::CLIENT);
if (!op_result) {
return std::error_code(op_result.error(), std::system_category());
}

// If the socket is already open, we should not call connect on it
if (IsOpen()) {
return {};
if (!IsOpen()) {
error_code ec = next_sock_->Connect(endpoint, std::move(on_pre_connect));
if (ec)
return ec;
}

// Flush the ssl data to the socket and run the loop that ensures handshaking converges.
int op_val = *op_result;
error_code ec;

// it should guide us to write and then read.
DCHECK_EQ(op_val, Engine::NEED_READ_AND_MAYBE_WRITE);
while (op_val < 0) {
if (op_val == Engine::EOF_STREAM) {
return make_error_code(errc::connection_reset);
}

if (op_val == Engine::NEED_WRITE) {
ec = HandleSocketWrite();
if (ec)
return ec;
} else if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
ec = HandleSocketWrite();
if (ec)
return ec;

ec = HandleSocketRead();
if (ec)
return ec;
}
op_result = engine_->Handshake(Engine::HandshakeType::CLIENT);
if (!op_result) {
return std::error_code(op_result.error(), std::system_category());
}
op_val = *op_result;
}

return next_sock_->Connect(endpoint);
return ec;
}

auto TlsSocket::Close() -> error_code {
Expand Down Expand Up @@ -249,7 +283,7 @@ io::Result<size_t> TlsSocket::RecvMsg(const msghdr& msg, int flags) {
}

if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
ec = HandleRead();
ec = HandleSocketRead();
if (ec)
return make_unexpected(ec);
}
Expand Down Expand Up @@ -341,7 +375,7 @@ io::Result<size_t> TlsSocket::SendBuffer(Engine::Buffer buf) {
}

if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
ec = HandleRead();
ec = HandleSocketRead();
if (ec)
return make_unexpected(ec);
}
Expand Down Expand Up @@ -381,28 +415,10 @@ auto TlsSocket::MaybeSendOutput() -> error_code {
return error_code{};
}

auto buf_result = engine_->PeekOutputBuf();
CHECK(buf_result);

if (!buf_result->empty()) {
// we do not allow concurrent writes from multiple fibers.
state_ |= WRITE_IN_PROGRESS;
io::Result<size_t> write_result = next_sock_->WriteSome(*buf_result);

// Safe to clear here since the code below is atomic fiber-wise.
state_ &= ~WRITE_IN_PROGRESS;
DCHECK(engine_);
if (!write_result) {
return write_result.error();
}
CHECK_GT(*write_result, 0u);
engine_->ConsumeOutputBuf(*write_result);
}

return error_code{};
return HandleSocketWrite();
}

auto TlsSocket::HandleRead() -> error_code {
auto TlsSocket::HandleSocketRead() -> error_code {
if (state_ & READ_IN_PROGRESS) {
// We need to Yield because otherwise we might end up in an infinite loop.
// See also comments in MaybeSendOutput.
Expand All @@ -423,6 +439,28 @@ auto TlsSocket::HandleRead() -> error_code {
return error_code{};
}

error_code TlsSocket::HandleSocketWrite() {
Engine::Buffer buffer = engine_->PeekOutputBuf();

while (!buffer.empty()) {
// we do not allow concurrent writes from multiple fibers.
state_ |= WRITE_IN_PROGRESS;
io::Result<size_t> write_result = next_sock_->WriteSome(buffer);

// Safe to clear here since the code below is atomic fiber-wise.
state_ &= ~WRITE_IN_PROGRESS;
DCHECK(engine_);
if (!write_result) {
return write_result.error();
}
CHECK_GT(*write_result, 0u);
engine_->ConsumeOutputBuf(*write_result);
buffer.remove_prefix(*write_result);
}

return error_code{};
}

TlsSocket::endpoint_type TlsSocket::LocalEndpoint() const {
return next_sock_->LocalEndpoint();
}
Expand Down
6 changes: 4 additions & 2 deletions util/tls/tls_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class TlsSocket final : public FiberSocketBase {

// The endpoint should not really pass here, it is to keep
// the interface with FiberSocketBase.
error_code Connect(const endpoint_type&) final;
error_code Connect(const endpoint_type& ep, std::function<void(int)> on_pre_connect = {}) final;

error_code Close() final;

Expand Down Expand Up @@ -92,7 +92,9 @@ class TlsSocket final : public FiberSocketBase {
error_code MaybeSendOutput();

/// Read encrypted data from the network socket and feed it into the TLS engine.
error_code HandleRead();
error_code HandleSocketRead();

error_code HandleSocketWrite();

std::unique_ptr<FiberSocketBase> next_sock_;
std::unique_ptr<Engine> engine_;
Expand Down

0 comments on commit eb166e1

Please sign in to comment.