added torch post-hook, enum cleanup
jinensetpal committed Aug 4, 2023
1 parent b135def commit dc90d61
Showing 3 changed files with 42 additions and 8 deletions.
5 changes: 4 additions & 1 deletion dagshub/data_engine/client/loaders/base.py
@@ -81,7 +81,10 @@ def _get_file_columns(self):
             [self.entries[0].metadata[col] for col in self.metadata_columns],
         ):
             try:
-                if self.source == "repo":
+                if (
+                    self.datasource.source.source_type
+                    == self.datasource.source.source_type.REPOSITORY
+                ):
                     self.repo.list_path((self.datasource_root / str(value)).as_posix())
                 else:
                     self.repo.list_storage_path(
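The hunk above replaces a hard-coded "repo" string check with a comparison against the data source's source-type enum (the "enum cleanup" from the commit message). Below is a minimal, self-contained sketch of the same pattern; the SourceType and Source names are illustrative stand-ins, not the DagsHub classes.

from enum import Enum, auto


class SourceType(Enum):
    # Hypothetical enum standing in for the data source's source_type values.
    REPOSITORY = auto()
    BUCKET = auto()


class Source:
    def __init__(self, source_type: SourceType):
        self.source_type = source_type


def list_entry(source: Source, path: str) -> str:
    # Comparing against an enum member instead of the string "repo" catches
    # typos at definition time and reads the same way as the diff above.
    if source.source_type == SourceType.REPOSITORY:
        return f"repo listing: {path}"
    return f"storage listing: {path}"


print(list_entry(Source(SourceType.REPOSITORY), "images/0001.png"))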
12 changes: 7 additions & 5 deletions dagshub/data_engine/client/loaders/tf.py
@@ -32,13 +32,11 @@ def __init__(
         batch_size=1,
         shuffle=True,
         seed=None,
-        pre_hook=lambda x: x,
         post_hook=lambda x: x,
     ):
         self.dataset = dataset
         self.batch_size = batch_size
         self.shuffle = shuffle
-        self.pre_hook = pre_hook
         self.post_hook = post_hook
 
         if seed:
@@ -56,7 +54,6 @@ def __len__(self) -> int:
         return self.dataset.__len__() // self.batch_size
 
     def __getitem__(self, index: int) -> tf.Tensor:
-        index = self.pre_hook(index)
         samples = [
             self.dataset.__getitem__(index)
             for index in self.indices[
@@ -70,7 +67,7 @@ def __getitem__(self, index: int) -> tf.Tensor:
             for idx, tensor in enumerate(sample):
                 batch[idx].append(tensor)
 
-        return self.post_hook([tf.stack(column) for column in batch])
+        return tuple(self.post_hook([tf.stack(column) for column in batch]))
 
     def on_epoch_end(self) -> None:
         self.indices = np.arange(self.dataset.__len__())
@@ -85,7 +82,12 @@ def image(filepath: str) -> tf.Tensor:
 
     @staticmethod
    def audio(filepath: str) -> tf.Tensor:
-        raise NotImplementedError("Coming Soon!")
+        return tf.audio.decode_wav(
+            tf.io.read_file(str(filepath)),
+            desired_channels=-1,
+            desired_samples=-1,
+            name=None,
+        )
 
     @staticmethod
     def video(filepath: str) -> tf.Tensor:
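The last hunk fills in the audio tensorizer with tf.audio.decode_wav, and the __getitem__ change wraps the post-hooked batch in tuple(), presumably so Keras receives the usual (inputs, targets) tuple rather than a list. A standalone sketch of the decoder call follows; "sample.wav" is a placeholder path, not a file from the repository.

import tensorflow as tf

# decode_wav returns a (audio, sample_rate) pair; audio is a float32 tensor
# of shape [samples, channels] with values scaled to [-1.0, 1.0].
decoded = tf.audio.decode_wav(
    tf.io.read_file("sample.wav"),
    desired_channels=-1,  # keep however many channels the file has
    desired_samples=-1,   # keep the full length
)
print(decoded.audio.shape, decoded.sample_rate)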
33 changes: 31 additions & 2 deletions dagshub/data_engine/client/loaders/torch.py
@@ -1,6 +1,6 @@
 from multiprocessing import Process
-from typing import TYPE_CHECKING, Union
 from dagshub.common.util import lazy_load
+from typing import TYPE_CHECKING, Union, Any
 from dagshub.data_engine.client.loaders.base import DagsHubDataset
 
 torch = lazy_load("torch")
@@ -18,13 +18,42 @@ def __init__(self, *args, **kwargs):
         self.type = "torch"
 
 
+class _BaseDataLoaderIter:
+    def __next__(self) -> Any:
+        return self.post_hook(super().__next__())
+
+
+class _SingleProcessDataLoaderIter(
+    _BaseDataLoaderIter, torch.utils.data.dataloader._SingleProcessDataLoaderIter
+):
+    def __init__(self, *args, post_hook, **kwargs):
+        self.post_hook = post_hook
+        super().__init__(*args, **kwargs)
+
+
+class _MultiProcessingDataLoaderIter(
+    _BaseDataLoaderIter, torch.utils.data.dataloader._MultiProcessingDataLoaderIter
+):
+    def __init__(self, *args, post_hook, **kwargs):
+        self.post_hook = post_hook
+        super().__init__(*args, **kwargs)
+
+
 class PyTorchDataLoader(torch.utils.data.DataLoader):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, post_hook=lambda x: x, **kwargs):
         super().__init__(*args, **kwargs)
+        self.post_hook = post_hook
         self.dataset.order = list(self.sampler)
         if self.dataset.strategy == "background":
             Process(target=self.dataset.pull).start()
+
+    def _get_iterator(self) -> "_BaseDataLoaderIter":
+        if self.num_workers == 0:
+            return _SingleProcessDataLoaderIter(self, post_hook=self.post_hook)
+        else:
+            self.check_worker_number_rationality()
+            return _MultiProcessingDataLoaderIter(self, post_hook=self.post_hook)
 
 
 class Tensorizers:
     @staticmethod
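For context, here is a minimal, self-contained sketch of the post_hook plumbing the new iterator classes add, using a plain TensorDataset and only the single-process path; every name below (HookedDataLoader, _HookedSingleIter) is illustrative and not part of the DagsHub API.

from typing import Any

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataloader import _SingleProcessDataLoaderIter


class _HookedSingleIter(_SingleProcessDataLoaderIter):
    def __init__(self, loader, post_hook):
        self.post_hook = post_hook
        super().__init__(loader)

    def __next__(self) -> Any:
        # Apply the hook to every batch the stock iterator produces.
        return self.post_hook(super().__next__())


class HookedDataLoader(DataLoader):
    def __init__(self, *args, post_hook=lambda x: x, **kwargs):
        super().__init__(*args, **kwargs)
        self.post_hook = post_hook

    def _get_iterator(self):
        # Sketch covers num_workers=0 only; the commit handles the
        # multiprocessing iterator with the same mixin pattern.
        if self.num_workers == 0:
            return _HookedSingleIter(self, post_hook=self.post_hook)
        raise NotImplementedError("sketch covers num_workers=0 only")


dataset = TensorDataset(torch.arange(8, dtype=torch.float32).unsqueeze(1))
loader = HookedDataLoader(dataset, batch_size=4,
                          post_hook=lambda batch: [t * 2 for t in batch])
for batch in loader:
    print(batch)

Assigning post_hook before calling super().__init__ mirrors the commit's iterator subclasses, so the hook is in place before the base iterator starts yielding batches.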
