Add experiment_dataloader helper API (PR 5 of N) #8

Merged 1 commit on Oct 15, 2024
2 changes: 2 additions & 0 deletions src/tiledbsoma_ml/__init__.py
@@ -8,11 +8,13 @@
from .pytorch import (
ExperimentAxisQueryIterableDataset,
ExperimentAxisQueryIterDataPipe,
experiment_dataloader,
)

__version__ = "0.1.0-dev"

__all__ = [
"ExperimentAxisQueryIterDataPipe",
"ExperimentAxisQueryIterableDataset",
"experiment_dataloader",
]
87 changes: 87 additions & 0 deletions src/tiledbsoma_ml/pytorch.py
@@ -712,6 +712,73 @@ def shape(self) -> Tuple[int, int]:
return self._exp_iter.shape


def experiment_dataloader(
ds: torchdata.datapipes.iter.IterDataPipe | torch.utils.data.IterableDataset,
**dataloader_kwargs: Any,
) -> torch.utils.data.DataLoader:
"""Factory method for :class:`torch.utils.data.DataLoader`. This method can be used to safely instantiate a
:class:`torch.utils.data.DataLoader` that works with :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset`
or :class:`tiledbsoma_ml.ExperimentAxisQueryIterDataPipe`.

Several :class:`torch.utils.data.DataLoader` constructor parameters are not applicable, or are non-performant,
when using loaders from this module, including ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``.
Specifying any of these parameters will result in an error.

Refer to ``https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader`` for more information on
:class:`torch.utils.data.DataLoader` parameters.

Args:
ds:
A :class:`torch.utils.data.IterableDataset` or a :class:`torchdata.datapipes.iter.IterDataPipe`. May
include chained data pipes.
**dataloader_kwargs:
Additional keyword arguments to pass to the :class:`torch.utils.data.DataLoader` constructor,
except for ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``, which are not
supported when using data loaders in this module.

Returns:
A :class:`torch.utils.data.DataLoader`.

Raises:
ValueError: if any of the ``shuffle``, ``batch_size``, ``sampler``, or ``batch_sampler`` params
are passed as keyword arguments.

Lifecycle:
experimental
"""
unsupported_dataloader_args = [
"shuffle",
"batch_size",
"sampler",
"batch_sampler",
]
if set(unsupported_dataloader_args).intersection(dataloader_kwargs.keys()):
raise ValueError(
f"The {','.join(unsupported_dataloader_args)} DataLoader parameters are not supported"
)

if dataloader_kwargs.get("num_workers", 0) > 0:
_init_multiprocessing()

if "collate_fn" not in dataloader_kwargs:
dataloader_kwargs["collate_fn"] = _collate_noop

return torch.utils.data.DataLoader(
ds,
batch_size=None, # batching is handled by upstream iterator
shuffle=False, # shuffling is handled by upstream iterator
**dataloader_kwargs,
)
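
For context (not part of this diff), here is a minimal usage sketch of the new helper; the experiment URI, layer names, and worker count are illustrative placeholders:

    import tiledbsoma
    import tiledbsoma_ml

    # Hypothetical experiment URI; any SOMA Experiment with an "RNA" measurement works.
    with tiledbsoma.Experiment.open("path/to/experiment") as exp:
        with exp.axis_query(measurement_name="RNA") as query:
            ds = tiledbsoma_ml.ExperimentAxisQueryIterableDataset(
                query,
                X_name="raw",
                obs_column_names=["soma_joinid"],
                batch_size=128,
            )
            # Batching and shuffling are handled by the dataset itself, so they must
            # not be passed here; num_workers > 0 switches the start method to "spawn".
            dl = tiledbsoma_ml.experiment_dataloader(ds, num_workers=2)
            for X_batch, obs_batch in dl:
                pass  # X_batch: np.ndarray, obs_batch: pd.DataFrame

The factory pins ``batch_size=None`` and ``shuffle=False`` on the DataLoader and installs a no-op ``collate_fn``, because the upstream iterator already yields fully formed ``(X, obs)`` batches.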


def _collate_noop(datum: _T) -> _T:
"""Noop collation for use with a dataloader instance.

Private.
"""
return datum


def _splits(total_length: int, sections: int) -> npt.NDArray[np.intp]:
"""For ``total_length`` points, compute start/stop offsets that split the length into roughly equal sizes.

@@ -784,3 +851,23 @@ def _get_worker_world_rank() -> Tuple[int, int]:
num_workers = worker_info.num_workers
worker = worker_info.id
return num_workers, worker


def _init_multiprocessing() -> None:
"""Ensures use of "spawn" for starting child processes with multiprocessing.

Forked processes are known to be problematic:
https://pytorch.org/docs/stable/notes/multiprocessing.html#avoiding-and-fighting-deadlocks
Also, CUDA does not support forked child processes:
https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing

Private.
"""
orig_start_method = torch.multiprocessing.get_start_method()
if orig_start_method != "spawn":
if orig_start_method:
logger.warning(
"switching torch multiprocessing start method from "
f'"{orig_start_method}" to "spawn"'
)
torch.multiprocessing.set_start_method("spawn", force=True)
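
Because ``_init_multiprocessing`` forces the "spawn" start method whenever ``num_workers > 0``, calling scripts must be import-safe. A minimal sketch of that pattern, using the same placeholder experiment URI as the usage sketch above:

    import tiledbsoma
    import tiledbsoma_ml

    def main() -> None:
        # Under "spawn", each worker re-imports this module, so DataLoader
        # construction stays inside main() rather than at module import time.
        with tiledbsoma.Experiment.open("path/to/experiment") as exp:
            with exp.axis_query(measurement_name="RNA") as query:
                ds = tiledbsoma_ml.ExperimentAxisQueryIterableDataset(
                    query, X_name="raw", batch_size=64
                )
                dl = tiledbsoma_ml.experiment_dataloader(ds, num_workers=2)
                for X_batch, obs_batch in dl:
                    pass

    if __name__ == "__main__":
        main()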
195 changes: 194 additions & 1 deletion tests/test_pytorch.py
@@ -7,10 +7,11 @@

from functools import partial
from pathlib import Path
from typing import Callable, Optional, Sequence, Union
from typing import Any, Callable, Optional, Sequence, Tuple, Union
from unittest.mock import patch

import numpy as np
import numpy.typing as npt
import pandas as pd
import pyarrow as pa
import pytest
@@ -26,6 +27,7 @@
ExperimentAxisQueryIterable,
ExperimentAxisQueryIterableDataset,
ExperimentAxisQueryIterDataPipe,
experiment_dataloader,
)

assert_array_equal = partial(np.testing.assert_array_equal, strict=True)
@@ -436,6 +438,37 @@ def test_batching__partial_soma_batches_are_concatenated(
assert [len(batch[0]) for batch in batches] == [3, 3, 3, 1]


@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen)]
)
@pytest.mark.parametrize(
"PipeClass",
(ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
)
def test_multiprocessing__returns_full_result(
PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
soma_experiment: Experiment,
) -> None:
"""Tests the ExperimentAxisQueryIterDataPipe provides all data, as collected from multiple processes that are managed by a
PyTorch DataLoader with multiple workers configured."""
with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = PipeClass(
query,
X_name="raw",
obs_column_names=["soma_joinid", "label"],
io_batch_size=3, # two chunks, one per worker
)
# Note: we test via a DataLoader, since the DataLoader is what sets up the multiprocessing.
dl = experiment_dataloader(dp, num_workers=2)

full_result = list(iter(dl))

soma_joinids = np.concatenate(
[t[1]["soma_joinid"].to_numpy() for t in full_result]
)
assert sorted(soma_joinids) == list(range(6))


@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen",
[(6, 3, pytorch_x_value_gen), (7, 3, pytorch_x_value_gen)],
@@ -545,6 +578,166 @@ def test_distributed_and_multiprocessing__returns_data_partition_for_rank(
assert soma_joinids == expected_joinids


@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen,use_eager_fetch",
[(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
)
@pytest.mark.parametrize(
"PipeClass",
(ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
)
def test_experiment_dataloader__non_batched(
PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
soma_experiment: Experiment,
use_eager_fetch: bool,
) -> None:
with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = PipeClass(
query,
X_name="raw",
obs_column_names=["label"],
use_eager_fetch=use_eager_fetch,
)
dl = experiment_dataloader(dp)
data = [row for row in dl]
assert all(d[0].shape == (3,) for d in data)
assert all(d[1].shape == (1, 1) for d in data)

row = data[0]
assert row[0].tolist() == [0, 1, 0]
assert row[1]["label"].tolist() == ["0"]


@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen,use_eager_fetch",
[(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
)
@pytest.mark.parametrize(
"PipeClass",
(ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
)
def test_experiment_dataloader__batched(
PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
soma_experiment: Experiment,
use_eager_fetch: bool,
) -> None:
with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = PipeClass(
query,
X_name="raw",
batch_size=3,
use_eager_fetch=use_eager_fetch,
)
dl = experiment_dataloader(dp)
data = [row for row in dl]

batch = data[0]
assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
assert batch[1].to_numpy().tolist() == [[0], [1], [2]]


@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen,use_eager_fetch",
[
(10, 3, pytorch_x_value_gen, use_eager_fetch)
for use_eager_fetch in (True, False)
],
)
@pytest.mark.parametrize(
"PipeClass",
(ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
)
def test_experiment_dataloader__batched_length(
PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
soma_experiment: Experiment,
use_eager_fetch: bool,
) -> None:
with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = PipeClass(
query,
X_name="raw",
obs_column_names=["label"],
batch_size=3,
use_eager_fetch=use_eager_fetch,
)
dl = experiment_dataloader(dp)
assert len(dl) == len(list(dl))


@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen,batch_size",
[(10, 3, pytorch_x_value_gen, batch_size) for batch_size in (1, 3, 10)],
)
@pytest.mark.parametrize(
"PipeClass",
(ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
)
def test_experiment_dataloader__collate_fn(
PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
soma_experiment: Experiment,
batch_size: int,
) -> None:
def collate_fn(
batch_size: int, data: Tuple[npt.NDArray[np.number[Any]], pd.DataFrame]
) -> Tuple[npt.NDArray[np.number[Any]], pd.DataFrame]:
assert isinstance(data, tuple)
assert len(data) == 2
assert isinstance(data[0], np.ndarray) and isinstance(data[1], pd.DataFrame)
if batch_size > 1:
assert data[0].shape[0] == data[1].shape[0]
assert data[0].shape[0] <= batch_size
else:
assert data[0].ndim == 1
assert data[1].shape[1] <= batch_size
return data

with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = PipeClass(
query,
X_name="raw",
obs_column_names=["label"],
batch_size=batch_size,
)
dl = experiment_dataloader(dp, collate_fn=partial(collate_fn, batch_size))
assert len(list(dl)) > 0


@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen", [(10, 1, pytorch_x_value_gen)]
)
def test__pytorch_splitting(
soma_experiment: Experiment,
) -> None:
with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = ExperimentAxisQueryIterDataPipe(
query,
X_name="raw",
obs_column_names=["label"],
)
# random_split() is not yet available for ExperimentAxisQueryIterableDataset
dp_train, dp_test = dp.random_split(
weights={"train": 0.7, "test": 0.3}, seed=1234
)
dl = experiment_dataloader(dp_train)

all_rows = list(iter(dl))
assert len(all_rows) == 7


def test_experiment_dataloader__unsupported_params__fails() -> None:
with patch(
"tiledbsoma_ml.pytorch.ExperimentAxisQueryIterDataPipe"
) as dummy_exp_data_pipe:
with pytest.raises(ValueError):
experiment_dataloader(dummy_exp_data_pipe, shuffle=True)
with pytest.raises(ValueError):
experiment_dataloader(dummy_exp_data_pipe, batch_size=3)
with pytest.raises(ValueError):
experiment_dataloader(dummy_exp_data_pipe, batch_sampler=[])
with pytest.raises(ValueError):
experiment_dataloader(dummy_exp_data_pipe, sampler=[])


def test_batched() -> None:
from tiledbsoma_ml.pytorch import _batched
