
Commit

add shuffling support
bkmartinjr authored and ryan-williams committed Oct 3, 2024
1 parent 3d8bf2a commit 1cc3670
Showing 2 changed files with 249 additions and 20 deletions.
196 changes: 179 additions & 17 deletions src/tiledbsoma_ml/pytorch.py
@@ -103,8 +103,11 @@ def __init__(
X_name: str,
obs_column_names: Sequence[str] = ("soma_joinid",),
batch_size: int = 1,
shuffle: bool = True,
io_batch_size: int = 2**16,
shuffle_chunk_size: int = 64,
return_sparse_X: bool = False,
seed: int | None = None,
use_eager_fetch: bool = True,
):
"""
@@ -129,12 +132,25 @@ def __init__(
this ``IterableDataset`` to be used with :class:`torch.utils.data.DataLoader` batching, but higher
performance can be achieved by performing batching in this class, and setting the ``DataLoader``'s
``batch_size`` parameter to ``None``.
shuffle:
Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``True``.
io_batch_size:
The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts
maximum memory utilization, larger values provide better read performance, but require more memory.
The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts:
1. Maximum memory utilization: larger values provide better read performance but require more memory.
2. The number of rows read prior to shuffling (see the ``shuffle`` parameter for details).
The default value of 65,536 provides high performance but may need to be reduced in memory-limited hosts
or when using a large number of :class:`DataLoader` workers.
shuffle_chunk_size:
The number of contiguous rows sampled prior to concatenation and shuffling.
Larger numbers correspond to less randomness, but greater read performance.
If ``shuffle == False``, this parameter is ignored.
return_sparse_X:
If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the
default), will return ``X`` data as a :class:`numpy.ndarray`.
seed:
The random seed used for shuffling. Defaults to ``None`` (no seed). This argument *MUST* be specified
when using :class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions are disjoint
across worker processes.
use_eager_fetch:
Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is
made available for processing via the iterator. This allows network (or filesystem) requests to be made
@@ -147,6 +163,14 @@ def __init__(
Lifecycle:
experimental
.. warning::
When using this class in any distributed mode, calling the :meth:`set_epoch` method at
the beginning of each epoch **before** creating the :class:`DataLoader` iterator
is necessary to make shuffling work properly across multiple epochs. Otherwise,
the same ordering will always be used.
In addition, when using shuffling in a distributed configuration (e.g., ``DDP``), you
must provide a seed, ensuring that the same shuffle is used across all replicas.
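A minimal sketch of the intended per-epoch call pattern (``ds`` stands for any instance of this
iterable or one of its wrapper classes; ``num_epochs`` and the surrounding training loop are
assumptions for illustration, not part of this API)::

    num_epochs = 4
    for epoch in range(num_epochs):
        ds.set_epoch(epoch)             # must be called before the iterator is created
        for X_batch, obs_batch in ds:   # each item is an (X, obs) mini-batch
            ...                         # train on the mini-batch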
"""

super().__init__()
@@ -160,19 +184,32 @@ def __init__(
self.obs_column_names = list(obs_column_names)
self.batch_size = batch_size
self.io_batch_size = io_batch_size
self.shuffle = shuffle
self.return_sparse_X = return_sparse_X
self.use_eager_fetch = use_eager_fetch
self._obs_joinids: npt.NDArray[np.int64] | None = None
self._var_joinids: npt.NDArray[np.int64] | None = None
self.seed = (
seed if seed is not None else np.random.default_rng().integers(0, 2**32 - 1)
)
self._user_specified_seed = seed is not None
self.shuffle_chunk_size = shuffle_chunk_size
self._initialized = False
self.epoch = 0

if self.shuffle:
# round io_batch_size up to a unit of shuffle_chunk_size to simplify code.
self.io_batch_size = (
ceil(io_batch_size / shuffle_chunk_size) * shuffle_chunk_size
)
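# e.g., the default io_batch_size of 65,536 (2**16) is already a multiple of the default
# shuffle_chunk_size of 64 (1,024 shuffle chunks per IO batch), while an io_batch_size of
# 100,000 would be rounded up to 100,032.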

if not self.obs_column_names:
raise ValueError("Must specify at least one value in `obs_column_names`")

def _create_obs_joinids_partition(self) -> Iterator[npt.NDArray[np.int64]]:
"""Create iterator over obs id chunks with split size of (roughly) io_batch_size.
As appropriate, will partition per worker.
As appropriate, will chunk, shuffle, and partition per worker.
IMPORTANT: in any scenario using torch.distributed, where WORLD_SIZE > 1, this will
always partition such that each process has the same number of samples. Where
@@ -182,7 +219,8 @@ def _create_obs_joinids_partition(self) -> Iterator[npt.NDArray[np.int64]]:
Abstractly, the steps taken:
1. Split the joinids into WORLD_SIZE sections (aka number of GPUs in DDP)
2. Trim the splits to be of equal length
3. Partition by number of data loader workers (to not generate redundant batches
3. Chunk and optionally shuffle the chunks
4. Partition by number of data loader workers (to not generate redundant batches
in cases where the DataLoader is running with `n_workers>1`).
Private method.
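A rough, self-contained illustration of steps 1-4 with toy values (the real implementation relies on
private helpers such as ``_splits`` and ``_batched``; the round-robin slice below is only a stand-in
for the worker partition)::

    import numpy as np
    from math import ceil

    obs_joinids = np.arange(1_000_000)
    world_size, rank = 2, 0              # two DDP processes; this is rank 0
    n_workers, worker_id = 2, 1          # two DataLoader workers; this is worker 1
    shuffle_chunk_size, io_batch_size = 64, 2**16

    # 1-2. split across processes, then trim each split to a common length
    gpu_splits = np.array_split(obs_joinids, world_size)
    min_len = min(len(s) for s in gpu_splits)
    gpu_split = gpu_splits[rank][:min_len]

    # 3. chunk, shuffle the chunk order, and regroup into IO-batch-sized pieces
    chunks = np.array_split(gpu_split, max(1, ceil(len(gpu_split) / shuffle_chunk_size)))
    rng = np.random.default_rng(42)      # same seed on every process => identical chunk order
    rng.shuffle(chunks)
    per_io = io_batch_size // shuffle_chunk_size
    io_chunks = [np.concatenate(chunks[i:i + per_io]) for i in range(0, len(chunks), per_io)]

    # 4. partition the IO-sized chunks across DataLoader workers
    my_chunks = io_chunks[worker_id::n_workers]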
@@ -201,11 +239,29 @@ def _create_obs_joinids_partition(self) -> Iterator[npt.NDArray[np.int64]]:
assert 0 <= (np.diff(_gpu_splits).min() - min_len) <= 1
_gpu_split = _gpu_split[:min_len]

obs_joinids_chunked = np.array_split(
_gpu_split, max(1, ceil(len(_gpu_split) / self.io_batch_size))
)
# 3. Chunk and optionally shuffle chunks
if self.shuffle:
assert self.io_batch_size % self.shuffle_chunk_size == 0
shuffle_split = np.array_split(
_gpu_split, max(1, ceil(len(_gpu_split) / self.shuffle_chunk_size))
)

# 3. Partition by DataLoader worker
# Deterministically create RNG - state must be same across all processes, ensuring
# that the joinid partitions are identical across all processes.
rng = np.random.default_rng(self.seed + self.epoch + 99)
rng.shuffle(shuffle_split)
obs_joinids_chunked = list(
np.concatenate(b)
for b in _batched(
shuffle_split, self.io_batch_size // self.shuffle_chunk_size
)
)
else:
obs_joinids_chunked = np.array_split(
_gpu_split, max(1, ceil(len(_gpu_split) / self.io_batch_size))
)

# 4. Partition by DataLoader worker
n_workers, worker_id = _get_worker_world_rank()
obs_splits = _splits(len(obs_joinids_chunked), n_workers)
obs_partition_joinids = obs_joinids_chunked[
@@ -215,7 +271,7 @@ def _create_obs_joinids_partition(self) -> Iterator[npt.NDArray[np.int64]]:
if logger.isEnabledFor(logging.DEBUG):
partition_size = sum([len(chunk) for chunk in obs_partition_joinids])
logger.debug(
f"Process {os.getpid()} {rank=}, {world_size=}, {worker_id=}, n_workers={n_workers}, {partition_size=}"
f"Process {os.getpid()} {rank=}, {world_size=}, {worker_id=}, n_workers={n_workers}, epoch={self.epoch}, {partition_size=}"
)

return iter(obs_partition_joinids)
@@ -230,7 +286,9 @@ def _init_once(self, exp: soma.Experiment | None = None) -> None:
if self._initialized:
return

logger.debug("Initializing ExperimentAxisQueryIterable")
logger.debug(
f"Initializing ExperimentAxisQueryIterable (shuffle={self.shuffle})"
)

if exp is None:
# If no user-provided Experiment, open/close it ourselves
@@ -275,8 +333,12 @@ def __iter__(self) -> Iterator[XObsDatum]:
world_size, rank = _get_distributed_world_rank()
n_workers, worker_id = _get_worker_world_rank()
logger.debug(
f"Iterator created {rank=}, {world_size=}, {worker_id=}, {n_workers=}"
f"Iterator created {rank=}, {world_size=}, {worker_id=}, {n_workers=}, seed={self.seed}, epoch={self.epoch}"
)
if world_size > 1 and self.shuffle and self._user_specified_seed is None:
raise ValueError(
"ExperimentAxisQueryIterable requires an explicit `seed` when shuffle is used in a multi-process configuration."
)

with self.experiment_locator.open_experiment() as exp:
self._init_once(exp)
@@ -295,6 +357,8 @@ def __iter__(self) -> Iterator[XObsDatum]:

yield from _mini_batch_iter

self.epoch += 1

def __len__(self) -> int:
"""Return the number of batches this iterable will produce. If run in the context of :class:`torch.distributed`
or as a multi-process loader (i.e., :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the
@@ -340,6 +404,18 @@ def shape(self) -> Tuple[int, int]:
# (num batches this worker will produce, num features)
return n_batches + bool(rem), len(self._var_joinids)

def set_epoch(self, epoch: int) -> None:
"""
Set the epoch for this Data iterator.
When :attr:`shuffle=True`, this will ensure that all replicas use a different
random ordering for each epoch. Failure to call this method before each epoch
will result in the same data ordering.
This call must be made before the per-epoch iterator is created.
"""
self.epoch = epoch

def __getitem__(self, index: int) -> XObsDatum:
raise NotImplementedError(
"``ExperimentAxisQueryIterable can only be iterated - does not support mapping"
@@ -353,12 +429,17 @@ def _io_batch_iter(
) -> Iterator[Tuple[sparse.csr_matrix, pd.DataFrame]]:
"""Iterate over IO batches, i.e., SOMA query reads, producing tuples of ``(X: csr_array, obs: DataFrame)``.
``obs`` joinids read are controlled by the ``obs_joinid_iter``. Iterator results will be reindexed.
``obs`` joinids read are controlled by the ``obs_joinid_iter``. Iterator results will be reindexed and shuffled
(if shuffling enabled).
Private method.
"""
assert self._var_joinids is not None

# Create RNG - does not need to be identical across processes, but use the seed anyway
# for reproducibility.
shuffle_rng = np.random.default_rng(self.seed + self.epoch)

obs_column_names = (
list(self.obs_column_names)
if "soma_joinid" in self.obs_column_names
@@ -368,7 +449,10 @@ def _io_batch_iter(

for obs_coords in obs_joinid_iter:
st_time = time.perf_counter()
obs_indexer = soma.IntIndexer(obs_coords, context=X.context)
obs_shuffled_coords = (
obs_coords if not self.shuffle else shuffle_rng.permuted(obs_coords)
)
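# NB: Generator.permuted returns a shuffled copy; obs_coords itself is left unmodified.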
obs_indexer = soma.IntIndexer(obs_shuffled_coords, context=X.context)
logger.debug(
f"Retrieving next SOMA IO batch of length {len(obs_coords)}..."
)
@@ -392,12 +476,12 @@ def _io_batch_iter(
.concat()
.to_pandas()
.set_index("soma_joinid")
.reindex(obs_coords, copy=False)
.reindex(obs_shuffled_coords, copy=False)
.reset_index() # demote "soma_joinid" to a column
[self.obs_column_names]
) # fmt: on

del obs_indexer, obs_coords, X_tbl
del obs_indexer, obs_coords, obs_shuffled_coords, X_tbl
gc.collect()

tm = time.perf_counter() - st_time
@@ -412,7 +496,7 @@ def _mini_batch_iter(
X: soma.SparseNDArray,
obs_joinid_iter: Iterator[npt.NDArray[np.int64]],
) -> Iterator[XObsDatum]:
"""Break IO batches into mini-batch-sized chunks.
"""Break IO batches into shuffled mini-batch-sized chunks.
Private method.
"""
@@ -500,7 +584,10 @@ def __init__(
X_name: str = "raw",
obs_column_names: Sequence[str] = ("soma_joinid",),
batch_size: int = 1,
shuffle: bool = True,
seed: int | None = None,
io_batch_size: int = 2**16,
shuffle_chunk_size: int = 64,
return_sparse_X: bool = False,
use_eager_fetch: bool = True,
):
@@ -516,9 +603,12 @@ def __init__(
X_name=X_name,
obs_column_names=obs_column_names,
batch_size=batch_size,
shuffle=shuffle,
seed=seed,
io_batch_size=io_batch_size,
return_sparse_X=return_sparse_X,
use_eager_fetch=use_eager_fetch,
shuffle_chunk_size=shuffle_chunk_size,
)

def __iter__(self) -> Iterator[XObsDatum]:
@@ -553,6 +643,25 @@ def shape(self) -> Tuple[int, int]:
"""
return self._exp_iter.shape

def set_epoch(self, epoch: int) -> None:
"""
Set the epoch for this Data iterator.
When :attr:`shuffle=True`, this will ensure that all replicas use a different
random ordering for each epoch. Failure to call this method before each epoch
will result in the same data ordering.
This call must be made before the per-epoch iterator is created.
Lifecycle:
experimental
"""
self._exp_iter.set_epoch(epoch)

@property
def epoch(self) -> int:
return self._exp_iter.epoch


class ExperimentAxisQueryIterableDataset(
torch.utils.data.IterableDataset[XObsDatum] # type:ignore[misc]
Expand Down Expand Up @@ -595,6 +704,19 @@ class ExperimentAxisQueryIterableDataset(
The ``io_batch_size`` parameter determines the number of rows read, from which mini-batches are yielded. A
larger value will increase total memory usage and may reduce average read time per row.
Shuffling support is enabled with the ``shuffle`` parameter, and will normally be more performant than using
:class:`DataLoader` shuffling. The shuffling algorithm works as follows:
1. Rows selected by the query are subdivided into groups of size ``shuffle_chunk_size``, aka a "shuffle chunk".
2. A random selection of shuffle chunks is drawn and read as a single I/O buffer (of size ``io_batch_size``).
3. The entire I/O buffer is shuffled.
Put another way, we read randomly selected groups of observations from across all query results, concatenate
those into an I/O buffer, and shuffle the buffer before returning mini-batches. The randomness of the shuffle
is therefore determined by the ``io_batch_size`` (number of rows read) and the ``shuffle_chunk_size``
(number of contiguous rows in each draw). Decreasing ``shuffle_chunk_size`` increases shuffling randomness but decreases I/O
performance.
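A minimal construction sketch (the experiment URI, measurement name, and obs columns are placeholders,
and the query-object first argument is assumed, as it is elided from this diff; ``experiment_dataloader``
is the wrapper defined later in this module)::

    import tiledbsoma as soma
    from tiledbsoma_ml.pytorch import (
        ExperimentAxisQueryIterableDataset,
        experiment_dataloader,
    )

    with soma.Experiment.open("path/to/experiment") as exp:
        with exp.axis_query(measurement_name="RNA") as query:
            ds = ExperimentAxisQueryIterableDataset(
                query,                          # assumed first positional argument
                X_name="raw",
                obs_column_names=("soma_joinid", "cell_type"),
                batch_size=128,
                shuffle=True,
                seed=42,                        # required for DDP; optional otherwise
                shuffle_chunk_size=64,
            )
            dl = experiment_dataloader(ds)
            for X_batch, obs_batch in dl:
                ...                             # X_batch: ndarray, obs_batch: DataFrame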
This class will detect when run in a multiprocessing mode, including multi-worker :class:`torch.utils.data.DataLoader`
and multi-process training such as :class:`torch.nn.parallel.DistributedDataParallel`, and will automatically partition
data appropriately. In the case of distributed training, sample partitions across all processes must be equal. Any
@@ -610,7 +732,10 @@ def __init__(
X_name: str = "raw",
obs_column_names: Sequence[str] = ("soma_joinid",),
batch_size: int = 1,
shuffle: bool = True,
seed: int | None = None,
io_batch_size: int = 2**16,
shuffle_chunk_size: int = 64,
return_sparse_X: bool = False,
use_eager_fetch: bool = True,
):
@@ -636,11 +761,26 @@ def __init__(
Note that a ``batch_size`` of 1 allows this ``IterableDataset`` to be used with :class:`torch.utils.data.DataLoader`
batching, but you will achieve higher performance by performing batching in this class, and setting the ``DataLoader``
batch_size parameter to ``None``.
shuffle:
Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``True``.
io_batch_size:
The number of ``obs``/``X`` rows to retrieve when reading data from SOMA.
The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts two aspects of
this class's behavior: 1) The maximum memory utilization, with larger values providing
better read performance, but also requiring more memory; 2) The number of rows read prior to shuffling
(see ``shuffle`` parameter for details). The default value of 65,536 provides high performance, but
may need to be reduced in memory-limited hosts (or where a large number of :class:`DataLoader` workers
are employed).
shuffle_chunk_size:
The number of contiguous rows sampled prior to concatenation and shuffling.
Larger numbers correspond to less randomness, but greater read performance.
If ``shuffle == False``, this parameter is ignored.
return_sparse_X:
If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the default), will
return ``X`` data as a :class:`numpy.ndarray`.
seed:
The random seed used for shuffling. Defaults to ``None`` (no seed). This argument *must* be specified when using
:class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions are disjoint across worker
processes.
use_eager_fetch:
Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is made
available for processing via the iterator. This allows network (or filesystem) requests to be made in
@@ -660,9 +800,12 @@ def __init__(
X_name=X_name,
obs_column_names=obs_column_names,
batch_size=batch_size,
shuffle=shuffle,
seed=seed,
io_batch_size=io_batch_size,
return_sparse_X=return_sparse_X,
use_eager_fetch=use_eager_fetch,
shuffle_chunk_size=shuffle_chunk_size,
)

def __iter__(self) -> Iterator[XObsDatum]:
@@ -711,6 +854,25 @@ def shape(self) -> Tuple[int, int]:
"""
return self._exp_iter.shape

def set_epoch(self, epoch: int) -> None:
"""
Set the epoch for this Data iterator.
When :attr:`shuffle=True`, this will ensure that all replicas use a different
random ordering for each epoch. Failure to call this method before each epoch
will result in the same data ordering.
This call must be made before the per-epoch iterator is created.
Lifecycle:
experimental
"""
self._exp_iter.set_epoch(epoch)

@property
def epoch(self) -> int:
return self._exp_iter.epoch


def experiment_dataloader(
ds: torchdata.datapipes.iter.IterDataPipe | torch.utils.data.IterableDataset,
