Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[backend] implement pthread backend #46

Merged
merged 13 commits into from
Sep 24, 2024
27 changes: 23 additions & 4 deletions csrc/backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
#ifndef DISABLE_AIO
#include "aio.h"
#endif
#ifndef DISABLE_PTHREAD
#include "pthread_backend.h"
#endif

std::unordered_set<std::string> get_backends()
{
Expand All @@ -20,6 +23,9 @@ std::unordered_set<std::string> get_backends()
#endif
#ifndef DISABLE_AIO
backends.insert("aio");
#endif
#ifndef DISABLE_PTHREAD
backends.insert("pthread");
#endif
return backends;
}
Expand All @@ -35,18 +41,27 @@ void probe_asyncio(const std::string &backend)
try
{
std::unique_ptr<AsyncIO> aio;
if (backend == "uring")
if (backend == "uring") {
#ifndef DISABLE_URING
aio.reset(new UringAsyncIO(2));
#else
throw std::runtime_error("backend is not installed\n");
throw std::runtime_error("backend uring is not installed\n");
#endif
else
} else if (backend == "aio") {
#ifndef DISABLE_AIO
aio.reset(new AIOAsyncIO(2));
#else
throw std::runtime_error("backend is not installed\n");
throw std::runtime_error("backend aio is not installed\n");
#endif
} else if (backend == "pthread") {
#ifndef DISABLE_PTHREAD
aio.reset(new PthreadAsyncIO(2));
#else
throw std::runtime_error("backend pthread is not installed\n");
#endif
} else {
throw std::runtime_error("unknown backend");
}

int fd = fileno(fp);
const int n_loop = 5, n_len = 18;
Expand Down Expand Up @@ -120,6 +135,10 @@ AsyncIO *create_asyncio(unsigned int n_entries, const std::string &backend)
#ifndef DISABLE_AIO
if (backend == "aio")
return new AIOAsyncIO(n_entries);
#endif
#ifndef DISABLE_PTHREAD
if (backend == "pthread")
return new PthreadAsyncIO(n_entries);
#endif
throw std::runtime_error("Unsupported backend: " + backend);
}
79 changes: 79 additions & 0 deletions csrc/pthread_backend.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#include "pthread_backend.h"

void PthreadAsyncIO::write(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback) {
auto fut = this->pool.submit_task(
[fd, buffer, n_bytes, offset] {
return pwrite(fd, buffer, n_bytes, offset);
}
);
this->write_fut.push_back(std::make_tuple(std::move(fut), callback));
}

void PthreadAsyncIO::writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback) {
auto fut = this->pool.submit_task(
[fd, iov, iovcnt, offset] {
return pwritev(fd, iov, iovcnt, offset);
}
);
this->write_fut.push_back(std::make_tuple(std::move(fut), callback));
}

void PthreadAsyncIO::read(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback) {
auto fut = this->pool.submit_task(
[fd, buffer, n_bytes, offset] {
return pread(fd, buffer, n_bytes, offset);
}
);
this->read_fut.push_back(std::make_tuple(std::move(fut), callback));
}

void PthreadAsyncIO::readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback) {
auto fut = this->pool.submit_task(
[fd, iov, iovcnt, offset] {
return preadv(fd, iov, iovcnt, offset);
}
);
this->read_fut.push_back(std::make_tuple(std::move(fut), callback));
}

void PthreadAsyncIO::get_event(WaitType wt) {
if (wt == NOWAIT) return;
this->sync_write_events();
this->sync_read_events();
}

void PthreadAsyncIO::sync_write_events() {
while (this->write_fut.size() > 0) {
auto front = std::move(this->write_fut.front());
this->write_fut.pop_front();

auto fut(std::move(std::get<0>(front)));
fut.wait();

auto callback = std::get<1>(front);
if (callback != nullptr) {
callback();
}
}
}

void PthreadAsyncIO::sync_read_events() {
while (this->read_fut.size() > 0) {
auto front = std::move(this->read_fut.front());
this->read_fut.pop_front();

auto fut(std::move(std::get<0>(front)));
fut.wait();

auto callback = std::get<1>(front);
if (callback != nullptr) {
callback();
}
}
}

void PthreadAsyncIO::synchronize() {
this->get_event(WAIT);
}

void PthreadAsyncIO::register_file(int fd) {}
2 changes: 2 additions & 0 deletions csrc/py_api.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include "offload.h"
#include "async_file_io.h"
#include "backend.h"
Expand Down
2 changes: 2 additions & 0 deletions include/offload.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#pragma once

#include "asyncio.h"
#include <ATen/ATen.h>

#include "space_mgr.h"
#ifndef DISABLE_URING
#include "uring.h"
Expand Down
41 changes: 41 additions & 0 deletions include/pthread_backend.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#pragma once

#include <stdexcept>
#include <sys/io.h>
#include <sys/uio.h>
#include <unistd.h>
#include <cstdlib>
#include <future>
#include <queue>
#include <tuple>
#include <functional>

#include "asyncio.h"
#include "threadpool.hpp"


class PthreadAsyncIO : public AsyncIO
{
private:
BS::thread_pool pool;
std::deque<std::tuple<std::future<ssize_t>, callback_t>> write_fut;
std::deque<std::tuple<std::future<ssize_t>, callback_t>> read_fut;

public:
PthreadAsyncIO(unsigned int n_entries)
: pool(n_entries) {}

~PthreadAsyncIO() {}

void write(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback);
void read(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback);
void writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);
void readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);

void get_event(WaitType wt);
void sync_write_events();
void sync_read_events();
void synchronize();

void register_file(int fd);
};
Loading