Skip to content

Commit

Permalink
chore: iouring/recv-multishot (#325)
Browse files Browse the repository at this point in the history
Provide basis code for multishot support for socket receive operations together
with provided buffers.

Signed-off-by: Roman Gershman <[email protected]>
  • Loading branch information
romange committed Oct 23, 2024
1 parent 0829556 commit ad57475
Show file tree
Hide file tree
Showing 15 changed files with 531 additions and 165 deletions.
2 changes: 1 addition & 1 deletion base/cxx_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ TEST_F(CxxTest, UnderDebugger) {
vector<HasVector> table;

vector<int> ints(1024);
table.emplace_back(HasVector{.vals = move(ints)});
table.emplace_back(HasVector{.vals = std::move(ints)});
table.emplace_back(); // verified that HasVector was moved without copying the array.
}

Expand Down
60 changes: 42 additions & 18 deletions examples/echo_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#ifdef __linux__
#include "util/fibers/uring_socket.h"
#include "util/fibers/uring_proactor.h"
#endif

using namespace util;
Expand Down Expand Up @@ -56,6 +57,8 @@ ABSL_FLAG(bool, raw, true,
"If true, does not send/receive size parameter during "
"the connection handshake");
ABSL_FLAG(bool, tcp_nodelay, false, "if true - use tcp_nodelay option for server sockets");
ABSL_FLAG(bool, multishot, false, "If true, iouring sockets use multishot receives");
ABSL_FLAG(uint16_t, bufring_size, 256, "Size of the buffer ring for iouring sockets");

VarzQps ping_qps("ping-qps");
VarzCount connections("connections");
Expand All @@ -70,23 +73,26 @@ class EchoConnection : public Connection {
private:
void HandleRequests() final;

std::error_code ReadMsg(size_t* sz);
std::error_code ReadMsg();

std::queue<FiberSocketBase::ProvidedBuffer> prov_buffers_;
size_t pending_read_bytes_ = 0, first_buf_offset_ = 0;
size_t req_len_ = 0;
};

std::error_code EchoConnection::ReadMsg(size_t* sz) {
std::error_code EchoConnection::ReadMsg() {
FiberSocketBase::ProvidedBuffer pb[8];

while (pending_read_bytes_ < req_len_) {
auto res = socket_->RecvProvided(8, pb);
if (!res)
return res.error();
unsigned num_buf = *res;
unsigned num_bufs = socket_->RecvProvided(8, pb);

for (unsigned i = 0; i < num_buf; ++i) {
for (unsigned i = 0; i < num_bufs; ++i) {
if (pb[i].err_no > 0) {
DCHECK_EQ(i, 0u);
return error_code(pb[i].err_no, system_category());
}

DCHECK(!pb[i].buffer.empty());
prov_buffers_.push(pb[i]);
pending_read_bytes_ += pb[i].buffer.size();
}
Expand All @@ -104,7 +110,6 @@ void EchoConnection::HandleRequests() {
ThisFiber::SetName("HandleRequests");

std::error_code ec;
size_t sz;
vector<iovec> vec;
uint8_t buf[8];
vec.resize(2);
Expand All @@ -114,6 +119,13 @@ void EchoConnection::HandleRequests() {
CHECK_EQ(0, setsockopt(socket_->native_handle(), IPPROTO_TCP, TCP_NODELAY, &yes, sizeof(yes)));
}

#ifdef __linux__
bool is_multishot = GetFlag(FLAGS_multishot);
bool is_iouring = socket_->proactor()->GetKind() == ProactorBase::IOURING;
if (is_multishot && is_iouring) {
static_cast<fb2::UringSocket*>(socket_.get())->EnableRecvMultishot();
}
#endif
connections.IncBy(1);

vec[0].iov_base = buf;
Expand Down Expand Up @@ -149,18 +161,19 @@ void EchoConnection::HandleRequests() {
vector<FiberSocketBase::ProvidedBuffer> returned_buffers;

// after the handshake.
uint64_t replies = 0;
while (true) {
ec = ReadMsg(&sz);
ec = ReadMsg();
if (FiberSocketBase::IsConnClosed(ec)) {
VLOG(1) << "Closing " << ep;
VLOG(1) << "Closing " << ep << " after " << replies << " replies";
break;
}
CHECK(!ec) << ec;
ping_qps.Inc();

vec[0].iov_base = buf;
vec[0].iov_len = 4;
absl::little_endian::Store32(buf, sz);
absl::little_endian::Store32(buf, req_len_);
vec.resize(1);

size_t prepare_len = 0;
Expand All @@ -170,12 +183,14 @@ void EchoConnection::HandleRequests() {
DCHECK(!prov_buffers_.empty());
size_t needed = req_len_ - prepare_len;
const auto& pbuf = prov_buffers_.front();
size_t has_bytes = pbuf.buffer.size() - first_buf_offset_;
if (has_bytes <= needed) {
vec.push_back({const_cast<uint8_t*>(pbuf.buffer.data()) + first_buf_offset_, has_bytes});
prepare_len += has_bytes;
DCHECK_GE(pending_read_bytes_, has_bytes);
pending_read_bytes_ -= has_bytes;
size_t bytes_count = pbuf.buffer.size() - first_buf_offset_;
DCHECK(!pbuf.buffer.empty());

if (bytes_count <= needed) {
vec.push_back({const_cast<uint8_t*>(pbuf.buffer.data()) + first_buf_offset_, bytes_count});
prepare_len += bytes_count;
DCHECK_GE(pending_read_bytes_, bytes_count);
pending_read_bytes_ -= bytes_count;
returned_buffers.push_back(pbuf);
prov_buffers_.pop();
first_buf_offset_ = 0;
Expand All @@ -201,6 +216,7 @@ void EchoConnection::HandleRequests() {
socket_->ReturnProvided(pb);
}
returned_buffers.clear();
++replies;
if (ec)
break;
}
Expand Down Expand Up @@ -254,6 +270,8 @@ class Driver {

Driver(const Driver&) = delete;

uint64_t send_id_ = 0;

public:
Driver(ProactorBase* p);

Expand Down Expand Up @@ -335,6 +353,9 @@ void Driver::Connect(unsigned index, const tcp::endpoint& ep) {
void Driver::SendSingle() {
size_t req_size = absl::GetFlag(FLAGS_size);
std::unique_ptr<uint8_t[]> msg(new uint8_t[req_size]);
memcpy(msg.get(), &send_id_, std::min(sizeof(send_id_), req_size));
++send_id_;

error_code ec = socket_->Write(io::Bytes{msg.get(), req_size});
CHECK(!ec) << ec.message();
auto res = socket_->Read(io::MutableBytes(msg.get(), req_size));
Expand Down Expand Up @@ -369,6 +390,8 @@ size_t Driver::Run(base::Histogram* dest) {
auto start = absl::GetCurrentTimeNanos();

for (size_t j = 0; j < pipeline_cnt; ++j) {
memcpy(msg.get(), &send_id_, std::min(sizeof(send_id_), req_size));
++send_id_;
error_code ec = socket_->Write(io::Bytes{msg.get(), req_size});
if (ec && FiberSocketBase::IsConnClosed(ec)) {
conn_close = true;
Expand Down Expand Up @@ -498,7 +521,8 @@ int main(int argc, char* argv[]) {
#ifdef __linux__
if (!absl::GetFlag(FLAGS_epoll)) {
pp->AwaitBrief([](unsigned, auto* pb) {
fb2::UringSocket::InitProvidedBuffers(512, 64, static_cast<fb2::UringProactor*>(pb));
fb2::UringProactor* up = static_cast<fb2::UringProactor*>(pb);
up->RegisterBufferRing(0, absl::GetFlag(FLAGS_bufring_size), 64);
});
}
#endif
Expand Down
15 changes: 11 additions & 4 deletions util/fiber_socket_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,20 @@ class FiberSocketBase : public io::Sink, public io::AsyncSink, public io::Source
struct ProvidedBuffer {
io::Bytes buffer;
uint32_t allocated;
uint8_t cookie; // Used by the socket to identify the buffer.
uint16_t err_no; // Relevant only if buffer is empty.
uint8_t cookie; // Used by the socket to identify the buffer source.

void SetError(uint16_t err) {
err_no = err;
allocated = 0;
buffer = {};
}
};

// Unlike Recv/ReadSome, this method returns a buffer managed by the socket.
// Unlike Recv/ReadSome, this method returns buffers managed by the socket.
// They should be returned back to the socket after the data is read.
// small is an optional buffer that can be used for small messages.
virtual ::io::Result<unsigned> RecvProvided(unsigned buf_len, ProvidedBuffer* dest) = 0;
// Returns - number of buffers filled.
virtual unsigned RecvProvided(unsigned buf_len, ProvidedBuffer* dest) = 0;

virtual void ReturnProvided(const ProvidedBuffer& pbuf) = 0;

Expand Down
1 change: 1 addition & 0 deletions util/fibers/detail/fiber_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ void FiberInterface::ActivateOther(FiberInterface* other) {
// Check first if we the fiber belongs to the active thread.
if (other->scheduler_ == scheduler_) {
DVLOG(1) << "Activating " << other->name() << " from " << this->name();
DCHECK_EQ(other->flags_.load(std::memory_order_relaxed) & kTerminatedBit, 0);

// In case `other` times out on wait, it could be added to the ready queue already by
// ProcessSleep.
Expand Down
20 changes: 8 additions & 12 deletions util/fibers/epoll_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ auto EpollSocket::RecvMsg(const msghdr& msg, int flags) -> Result<size_t> {
return nonstd::make_unexpected(std::move(ec));
}

io::Result<unsigned> EpollSocket::RecvProvided(unsigned buf_len, ProvidedBuffer* dest) {
unsigned EpollSocket::RecvProvided(unsigned buf_len, ProvidedBuffer* dest) {
DCHECK_GT(buf_len, 0u);

int fd = native_handle();
Expand Down Expand Up @@ -435,7 +435,8 @@ io::Result<unsigned> EpollSocket::RecvProvided(unsigned buf_len, ProvidedBuffer*
}

if (SuspendMyself(read_context_, &ec) && ec) {
return nonstd::make_unexpected(std::move(ec));
res = ec.value();
break;
}
}

Expand All @@ -448,17 +449,12 @@ io::Result<unsigned> EpollSocket::RecvProvided(unsigned buf_len, ProvidedBuffer*

DVSOCK(1) << "Got " << res;

// ETIMEDOUT can happen if a socket does not have keepalive enabled or for some reason
// TCP connection did indeed stopped getting tcp keep alive packets.
if (!base::_in(res, {ECONNABORTED, EPIPE, ECONNRESET, ETIMEDOUT})) {
LOG(ERROR) << "sock[" << fd << "] Unexpected error " << res << "/" << strerror(res) << " "
<< RemoteEndpoint();
}

ec = std::error_code(res, std::system_category());
VSOCK(1) << "Error on " << RemoteEndpoint() << ": " << ec.message();
dest[0].buffer = {};
dest[0].allocated = 0;
dest[0].cookie = 1;
dest[0].err_no = res;

return nonstd::make_unexpected(std::move(ec));
return 1;
}

void EpollSocket::ReturnProvided(const ProvidedBuffer& pbuf) {
Expand Down
2 changes: 1 addition & 1 deletion util/fibers/epoll_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class EpollSocket : public LinuxSocketBase {

error_code Shutdown(int how) override;

io::Result<unsigned> RecvProvided(unsigned buf_len, ProvidedBuffer* dest) final;
unsigned RecvProvided(unsigned buf_len, ProvidedBuffer* dest) final;
void ReturnProvided(const ProvidedBuffer& pbuf) final;

void RegisterOnErrorCb(std::function<void(uint32_t)> cb) final;
Expand Down
Loading

0 comments on commit ad57475

Please sign in to comment.