Skip to content

Commit

Permalink
[backend] add environ to overwrite passed backend
Browse files Browse the repository at this point in the history
  • Loading branch information
botbw committed Oct 1, 2024
1 parent 51ed242 commit e28a80a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 4 deletions.
24 changes: 22 additions & 2 deletions csrc/offload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,29 @@ iovec *tensors_to_iovec(const std::vector<at::Tensor> &tensors)
return iovs;
}

std::string Offloader::get_default_backend() {
const char* env = getenv("TENSORNVME_BACKEND");
if (env == nullptr) {
return std::string("");
}
return std::string(env);
}

Offloader::Offloader(const std::string &filename, unsigned int n_entries, const std::string &backend) : filename(filename), space_mgr(SpaceManager(0))
{
this->aio = create_asyncio(n_entries, backend);
{
std::string default_backend = get_default_backend();
if (default_backend.size() > 0) {
if (get_backends().count(default_backend) == 0) {
throw std::runtime_error("Cannot find backend: " + default_backend + ", please check if TENSORNVME_BACKEND is set correctly");
}
this->aio = create_asyncio(n_entries, default_backend);
} else {
if (get_backends().count(backend) == 0) {
throw std::runtime_error("Cannot find backend: " + backend + ", please check the passed backend is set correctly");
}
this->aio = create_asyncio(n_entries, backend);
}

this->fd = open(filename.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
this->aio->register_file(fd);
}
Expand Down
2 changes: 2 additions & 0 deletions include/offload.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "aio.h"
#endif

#include <cstdlib>
class Offloader
{
public:
Expand All @@ -31,6 +32,7 @@ class Offloader
void async_readv(const std::vector<at::Tensor> &tensors, const std::string &key, callback_t callback = nullptr);
void sync_writev(const std::vector<at::Tensor> &tensors, const std::string &key);
void sync_readv(const std::vector<at::Tensor> &tensors, const std::string &key);
static std::string get_default_backend();
private:
const std::string filename;
int fd;
Expand Down
2 changes: 0 additions & 2 deletions tensornvme/offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

class DiskOffloader(Offloader):
def __init__(self, dir_name: str, n_entries: int = 16, backend: str = 'uring') -> None:
assert backend in get_backends(
), f'Unsupported backend: {backend}, please install tensornvme with this backend'
if not os.path.exists(dir_name):
os.mkdir(dir_name)
assert os.path.isdir(dir_name)
Expand Down

0 comments on commit e28a80a

Please sign in to comment.