Skip to content

Commit

Permalink
chore: revisit tls_socket code
Browse files Browse the repository at this point in the history
Signed-off-by: Roman Gershman <[email protected]>
  • Loading branch information
romange committed Jul 12, 2024
1 parent 1fe964d commit ed4aeb2
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 103 deletions.
48 changes: 20 additions & 28 deletions util/tls/tls_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,36 +41,28 @@ static Engine::OpResult ToOpResult(const SSL* ssl, int result, const char* locat
return nonstd::make_unexpected(error);
}

int want = SSL_want(ssl);

if (want == SSL_NOTHING) {
int ssl_error = SSL_get_error(ssl, result);
int io_err = errno;

switch (ssl_error) {
case SSL_ERROR_ZERO_RETURN:
break;
case SSL_ERROR_SYSCALL:
LOG(WARNING) << "SSL syscall error " << io_err << ":" << result << " " << location;
break;
case SSL_ERROR_SSL:
LOG(WARNING) << "SSL protocol error " << io_err << ":" << result << " " << location;
break;
default:
LOG(WARNING) << "Unexpected SSL error " << io_err << ":" << result << " " << location;
break;
}

return Engine::EOF_STREAM;
int ssl_error = SSL_get_error(ssl, result);
int io_err = errno;

switch (ssl_error) {
case SSL_ERROR_ZERO_RETURN:
break;
case SSL_ERROR_WANT_READ:
return Engine::NEED_READ_AND_MAYBE_WRITE;
case SSL_ERROR_WANT_WRITE:
VLOG(1) << "SSL_ERROR_WANT_WRITE " << location;
return Engine::NEED_WRITE;
case SSL_ERROR_SYSCALL:
LOG(WARNING) << "SSL syscall error " << io_err << ":" << result << " " << location;
break;
case SSL_ERROR_SSL:
LOG(WARNING) << "SSL protocol error " << io_err << ":" << result << " " << location;
break;
default:
LOG(WARNING) << "Unexpected SSL error " << io_err << ":" << result << " " << location;
break;
}

if (SSL_WRITING == want)
return Engine::NEED_WRITE;
if (SSL_READING == want)
return Engine::NEED_READ_AND_MAYBE_WRITE;

LOG(ERROR) << "Unsupported want value " << want << ", ssl_error: " << SSL_get_error(ssl, result);

return Engine::EOF_STREAM;
}

Expand Down
9 changes: 8 additions & 1 deletion util/tls/tls_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ class Engine {
enum HandshakeType { CLIENT = 1, SERVER = 2 };
enum OpCode {
EOF_STREAM = -1,

// We use BIO buffers, therefore any SSL operation can end up writing to the internal BIO
// and result in success, even though the data has not been flushed to the underlying socket.
// See https://www.openssl.org/docs/man1.0.2/man3/BIO_new_bio_pair.html
// As a result, we must flush output buffer (if OutputPending() > 0)if before we do any
// Socket reads. We could flush after each SSL operation but that would result in fragmented
// Socket writes which we want to avoid.
NEED_READ_AND_MAYBE_WRITE = -2,
NEED_WRITE = -3,
};
Expand Down Expand Up @@ -89,7 +96,7 @@ class Engine {
void CommitInput(unsigned sz);

// Returns size of pending data that needs to be flushed out from SSL to I/O.
// See https://www.openssl.org/docs/man1.1.0/man3/BIO_new_bio_pair.html
// See https://www.openssl.org/docs/man1.0.2/man3/BIO_new_bio_pair.html
// Specifically, warning that says: "An application must not rely on the error value of
// SSL_operation() but must assure that the write buffer is always flushed first".
size_t OutputPending() const {
Expand Down
139 changes: 65 additions & 74 deletions util/tls/tls_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ auto TlsSocket::Shutdown(int how) -> error_code {
Engine::OpResult op_result = engine_->Shutdown();
if (op_result) {
// engine_ could send notification messages to the peer.
MaybeSendOutput();
std::ignore = MaybeSendOutput();
}

// In any case we should also shutdown the underlying TCP socket without relying on the
Expand Down Expand Up @@ -132,14 +132,10 @@ auto TlsSocket::Accept() -> AcceptResult {
if (op_val >= 0) { // Shutdown or empty read/write may return 0.
break;
}
if (op_val == Engine::EOF_STREAM) {
return make_unexpected(make_error_code(errc::connection_reset));
}
if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
ec = HandleSocketRead();
if (ec)
return make_unexpected(ec);
}

ec = HandleOp(op_val);
if (ec)
return make_unexpected(ec);
}

return nullptr;
Expand All @@ -162,36 +158,26 @@ error_code TlsSocket::Connect(const endpoint_type& endpoint,

// Flush the ssl data to the socket and run the loop that ensures handshaking converges.
int op_val = *op_result;
error_code ec;

// it should guide us to write and then read.
DCHECK_EQ(op_val, Engine::NEED_READ_AND_MAYBE_WRITE);
while (op_val < 0) {
if (op_val == Engine::EOF_STREAM) {
return make_error_code(errc::connection_reset);
}
error_code ec = HandleOp(op_val);
if (ec)
return ec;

if (op_val == Engine::NEED_WRITE) {
ec = HandleSocketWrite();
if (ec)
return ec;
} else if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
ec = HandleSocketWrite();
if (ec)
return ec;

ec = HandleSocketRead();
if (ec)
return ec;
}
op_result = engine_->Handshake(Engine::HandshakeType::CLIENT);
if (!op_result) {
return std::error_code(op_result.error(), std::system_category());
}
op_val = *op_result;
}

return ec;
const auto* cipher = SSL_get_current_cipher(engine_->native_handle());
VLOG(1) << "SSL handshake success, chosen " << SSL_CIPHER_get_name(cipher) << "/"
<< SSL_CIPHER_get_version(cipher);

return {};
}

auto TlsSocket::Close() -> error_code {
Expand Down Expand Up @@ -245,11 +231,6 @@ io::Result<size_t> TlsSocket::RecvMsg(const msghdr& msg, int flags) {
return make_unexpected(SSL2Error(op_result.error()));
}

error_code ec = MaybeSendOutput();
if (ec) {
return make_unexpected(ec);
}

int op_val = *op_result;
if (spin_count.Check(op_val <= 0)) {
// Once every 30 seconds.
Expand All @@ -267,26 +248,18 @@ io::Result<size_t> TlsSocket::RecvMsg(const msghdr& msg, int flags) {
++io;
--io_len;
if (io_len == 0)
break;
break; // Finished reading everything.
dest = Engine::MutableBuffer{reinterpret_cast<uint8_t*>(io->iov_base), io->iov_len};
}
continue; // We read everything we asked for - lets retry.
// We read everything we asked for but there are still buffers left to fill.
continue;
}
break;
}

if (read_total) // if we read something lets return it before we handle other states.
break;

if (op_val == Engine::EOF_STREAM) {
return make_unexpected(make_error_code(errc::connection_reset));
}

if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
ec = HandleSocketRead();
if (ec)
return make_unexpected(ec);
}
error_code ec = HandleOp(op_val);
if (ec)
return make_unexpected(ec);
}
return read_total;
}
Expand All @@ -307,12 +280,12 @@ io::Result<size_t> TlsSocket::WriteSome(const iovec* ptr, uint32_t len) {
// Chosen to be sufficiently smaller than the usual MTU (1500) and a multiple of 16.
// IP - max 24 bytes. TCP - max 60 bytes. TLS - max 21 bytes.
constexpr size_t kBufferSize = 1392;
io::Result<size_t> ec;
io::Result<size_t> res;
size_t total_sent = 0;

while (len) {
if (ptr->iov_len > kBufferSize || len == 1) {
ec = SendBuffer(Engine::Buffer{reinterpret_cast<uint8_t*>(ptr->iov_base), ptr->iov_len});
res = SendBuffer(Engine::Buffer{reinterpret_cast<uint8_t*>(ptr->iov_base), ptr->iov_len});
ptr++;
len--;
} else {
Expand All @@ -324,18 +297,18 @@ io::Result<size_t> TlsSocket::WriteSome(const iovec* ptr, uint32_t len) {
ptr++;
len--;
}
ec = SendBuffer({scratch, buffered_size});
res = SendBuffer({scratch, buffered_size});
}
if (!ec.has_value()) {
return ec;
} else {
total_sent += ec.value();
if (!res) {
return res;
}
total_sent += *res;
}
return total_sent;
}

io::Result<size_t> TlsSocket::SendBuffer(Engine::Buffer buf) {
// Sending buffer into ssl.
DCHECK(engine_);
DCHECK_GT(buf.size(), 0u);

Expand All @@ -348,17 +321,7 @@ io::Result<size_t> TlsSocket::SendBuffer(Engine::Buffer buf) {
return make_unexpected(SSL2Error(op_result.error()));
}

error_code ec = MaybeSendOutput();
if (ec) {
return make_unexpected(ec);
}

int op_val = *op_result;
if (spin_count.Check(op_val <= 0)) {
// Once every 30 seconds.
LOG_EVERY_T(WARNING, 30) << "IO loop spin limit reached. Limit: " << spin_count.Limit()
<< " Spins: " << spin_count.Spins();
}

if (op_val > 0) {
send_total += op_val;
Expand All @@ -370,15 +333,15 @@ io::Result<size_t> TlsSocket::SendBuffer(Engine::Buffer buf) {
}
}

if (op_val == Engine::EOF_STREAM) {
return make_unexpected(make_error_code(errc::connection_reset));
if (spin_count.Check(op_val <= 0)) {
// Once every 30 seconds.
LOG_EVERY_T(WARNING, 30) << "IO loop spin limit reached. Limit: " << spin_count.Limit()
<< " Spins: " << spin_count.Spins();
}

if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
ec = HandleSocketRead();
if (ec)
return make_unexpected(ec);
}
error_code ec = HandleOp(op_val);
if (ec)
return make_unexpected(ec);
}

return send_total;
Expand All @@ -395,6 +358,9 @@ SSL* TlsSocket::ssl_handle() {
}

auto TlsSocket::MaybeSendOutput() -> error_code {
if (engine_->OutputPending() == 0)
return {};

// This function is present in both read and write paths.
// meaning that both of them can be called concurrently from differrent fibers and then
// race over flushing the output buffer. We use state_ to prevent that.
Expand All @@ -419,6 +385,10 @@ auto TlsSocket::MaybeSendOutput() -> error_code {
}

auto TlsSocket::HandleSocketRead() -> error_code {
error_code ec = MaybeSendOutput();
if (ec)
return ec;

if (state_ & READ_IN_PROGRESS) {
// We need to Yield because otherwise we might end up in an infinite loop.
// See also comments in MaybeSendOutput.
Expand All @@ -434,33 +404,54 @@ auto TlsSocket::HandleSocketRead() -> error_code {
return esz.error();
}

DVLOG(1) << "TlsSocket:Read " << *esz << " bytes";

engine_->CommitInput(*esz);

return error_code{};
}

error_code TlsSocket::HandleSocketWrite() {
Engine::Buffer buffer = engine_->PeekOutputBuf();
DCHECK(!buffer.empty());

if (buffer.empty())
return {};

// we do not allow concurrent writes from multiple fibers.
state_ |= WRITE_IN_PROGRESS;
while (!buffer.empty()) {
// we do not allow concurrent writes from multiple fibers.
state_ |= WRITE_IN_PROGRESS;
io::Result<size_t> write_result = next_sock_->WriteSome(buffer);

// Safe to clear here since the code below is atomic fiber-wise.
state_ &= ~WRITE_IN_PROGRESS;
DCHECK(engine_);
if (!write_result) {
state_ &= ~WRITE_IN_PROGRESS;

return write_result.error();
}
CHECK_GT(*write_result, 0u);
engine_->ConsumeOutputBuf(*write_result);
buffer.remove_prefix(*write_result);
}
DCHECK_EQ(engine_->OutputPending(), 0u);

state_ &= ~WRITE_IN_PROGRESS;

return error_code{};
}

error_code TlsSocket::HandleOp(int op_val) {
switch (op_val) {
case Engine::EOF_STREAM:
return make_error_code(errc::connection_reset);
case Engine::NEED_READ_AND_MAYBE_WRITE:
return HandleSocketRead();
default:
LOG(DFATAL) << "Unsupported " << op_val;
}
return {};
}

TlsSocket::endpoint_type TlsSocket::LocalEndpoint() const {
return next_sock_->LocalEndpoint();
}
Expand Down
1 change: 1 addition & 0 deletions util/tls/tls_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class TlsSocket final : public FiberSocketBase {
error_code HandleSocketRead();

error_code HandleSocketWrite();
error_code HandleOp(int op);

std::unique_ptr<FiberSocketBase> next_sock_;
std::unique_ptr<Engine> engine_;
Expand Down

0 comments on commit ed4aeb2

Please sign in to comment.