diff --git a/src/spdl/dataloader/_dataloader.py b/src/spdl/dataloader/_dataloader.py index 2c758a0b..9ed58428 100644 --- a/src/spdl/dataloader/_dataloader.py +++ b/src/spdl/dataloader/_dataloader.py @@ -4,8 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -__all__ = ["DataLoader", "MapIterator"] +__all__ = ["DataLoader", "MapIterator", "MergeIterator"] +import random +import sys from collections.abc import ( AsyncIterable, Awaitable, @@ -13,6 +15,7 @@ Iterable, Iterator, Mapping, + Sequence, ) from typing import Generic, TypeAlias, TypeVar @@ -250,7 +253,7 @@ class MapIterator(Iterable[V]): Args: mapping: Object implements :py:class:`~collections.abc.Mapping` interface. sampler: **Optional** Generator that yields key for the mapping. - Used to specify the iteratoin order over the mapping and/or to sample + Used to specify the iteration order over the mapping and/or to sample from a subset of the mapping. Example: @@ -283,3 +286,159 @@ def __init__( def __iter__(self) -> Iterator[V]: for key in self.sampler or self.mapping: yield self.mapping[key] + + +_FIRST_EXHAUSTION = 0 + + +def _ordered_iter(iterators: list[Iterator[T]], stop_after: float) -> Iterable[T]: + num_items = 0 + while iterators: + remove = [] + + for i, iterator in enumerate(iterators): + try: + yield next(iterator) + except StopIteration: + if stop_after == _FIRST_EXHAUSTION: + return + # Insert in reversed order beacause we use this for popping from list + remove.insert(0, i) + continue + + num_items += 1 + if stop_after > 0 and num_items >= stop_after: + return + + if remove: + for i in remove: + iterators.pop(i) + + +def _stocastic_iter( + iterators: list[Iterator[T]], + weights: Sequence[float], + stop_after: float, + seed: int, +) -> Iterable[T]: + # These are all checked in MergeIterator constructor + assert len(iterators) == len(weights) + assert all(w >= sys.float_info.epsilon for w in weights) + + population = list(range(len(iterators))) + rng = random.Random(seed) + num_items = 0 + + not_exhausted = [True for _ in range(len(iterators))] + while any(not_exhausted): + for i in rng.choices(population, weights, k=100): + try: + yield next(iterators[i]) + except StopIteration: + not_exhausted[i] = False + if stop_after == _FIRST_EXHAUSTION: + return + continue + + num_items += 1 + if stop_after > 0 and num_items >= stop_after: + return + + +class MergeIterator(Iterable[T]): + """Iterate over given iterables and yield one item from each iterator. + + + Args: + iterables: The source iterables + probabilities: The probability to choose the next iterable. + If not provided, the given iterables are visited in the given order + repeatedly. + stop_after: Determines the stop criteria or the behavior when one of + the input iterables gets exhausted, + Available values are; + + - ``0``: The iteration stops when one of the iterator is exhausted. + - ``n > 0``: The iteration stops when the specified number of items + are yielded or all the input iterables are exhausted. + - ``-1``: The iteration continues until all the input iterables are + exhausted. + seed: Used to seed the random generator when probabilities is provided. + + Example: + + >>> iterables = [ + ... [0, 1, 2], + ... [10, 11, 12], + ... [20, 21, 22], + ... ] + >>> + >>> print(list(MergeIterator(iterables))) + [0, 10, 20, 1, 11, 21, 2, 12, 22] + >>> + >>> # By default, it stops after one iterable gets exhausted. + >>> iterables = [ + ... [0, 1, 2], + ... [10, 11], + ... [20, 21, 22], + ... ] + >>> + >>> print(list(MergeIterator(iterables))) + [0, 10, 20, 1, 21, 2] # 22 is not included + >>> + >>> # Stop after yielding the given number of items + >>> print(list(MergeIterator(iterables, stop_after=5))) + [0, 10, 20, 1, 11] + >>> + >>> # stop_after>1 ignores the exhaustion. + >>> print(list(MergeIterator(iterables, stop_after=9))) + [0, 10, 20, 1, 11, 21, 2, 22] + >>> + >>> # Providing weights will pick up the iterable stocastically. + >>> print(list(MergeIterator(iterables, stop_after=9, weights=[1, 1, 1]))) + [0, 1, 10, 11, 20, 2, 21, 22] + """ + + def __init__( + self, + iterables: Sequence[Iterable[T]], + *, + weights: Sequence[float] | None = None, + stop_after: int = _FIRST_EXHAUSTION, + seed: int = 0, + ) -> None: + if not iterables: + raise ValueError("iterables cannot be empty.") + + if weights is not None: + if len(weights) != len(iterables): + raise ValueError( + f"The number of probabilities ({len(weights)}) and " + f"iterables ({len(iterables)}) must match." + ) + + # If any of them is 0 or negative, then there is something wrong with + # user logic, so we raise an exception. + if any(w < sys.float_info.epsilon for w in weights): + raise ValueError("Weights must be non-zero and positive.") + + if not stop_after >= -1: + msg = ( + f"`stop_after` must be greater than or equal to -1. Found: {stop_after}" + ) + raise ValueError(msg) + + self.iterables = iterables + self.weights = weights + self.stop_after = stop_after + self.seed = seed + + def __iter__(self) -> Iterator[T]: + iterators = [iter(ite) for ite in self.iterables] + + if self.weights is None: + yield from _ordered_iter(iterators, self.stop_after) + else: + yield from _stocastic_iter( + iterators, self.weights, self.stop_after, self.seed + ) diff --git a/tests/spdl_unittest/dataloader/dataloader_test.py b/tests/spdl_unittest/dataloader/dataloader_test.py index 8670029a..16530917 100644 --- a/tests/spdl_unittest/dataloader/dataloader_test.py +++ b/tests/spdl_unittest/dataloader/dataloader_test.py @@ -6,7 +6,8 @@ import time -from spdl.dataloader import DataLoader, MapIterator +import pytest +from spdl.dataloader import DataLoader, MapIterator, MergeIterator def get_dl(*args, timeout=3, num_threads=2, **kwargs): @@ -181,3 +182,169 @@ def test_mapiterator_sampler(): result = list(MapIterator(mapping, sampler)) assert result == ["e", "c", "a"] + + +def test_mergeiterator_ordered(): + """MergeIterator iterates multiple iterators""" + + iterables = [ + [0, 1, 2], + [10, 11, 12], + [20, 21, 22], + ] + + result = list(MergeIterator(iterables)) + assert result == [0, 10, 20, 1, 11, 21, 2, 12, 22] + + +def test_mergeiterator_ordered_stop_after_first_exhaustion(): + """MergeIterator stops after the first exhaustion""" + + iterables = [ + [0], + [10, 11, 12], + [20, 21, 22], + ] + + result = list(MergeIterator(iterables, stop_after=0)) + assert result == [0, 10, 20] + + iterables = [ + [0, 1, 2], + [10], + [20, 21, 22], + ] + + result = list(MergeIterator(iterables, stop_after=0)) + assert result == [0, 10, 20, 1] + + iterables = [ + [0, 1, 2], + [10, 11], + [20], + ] + + result = list(MergeIterator(iterables, stop_after=0)) + assert result == [0, 10, 20, 1, 11] + + +def test_mergeiterator_ordered_stop_after_N(): + """MergeIterator stops after N items are yielded""" + + iterables = [ + [0, 1, 2], + [10, 11, 12], + [20, 21, 22], + ] + + result = list(MergeIterator(iterables, stop_after=1)) + assert result == [0] + + result = list(MergeIterator(iterables, stop_after=5)) + assert result == [0, 10, 20, 1, 11] + + result = list(MergeIterator(iterables, stop_after=7)) + assert result == [0, 10, 20, 1, 11, 21, 2] + + +def test_mergeiterator_ordered_stop_after_minus1(): + """MergeIterator stops after all the iterables are exhausted""" + + iterables = [ + [0, 1, 2], + [10, 11, 12], + [20, 21, 22], + ] + + result = list(MergeIterator(iterables, stop_after=-1)) + assert result == [0, 10, 20, 1, 11, 21, 2, 12, 22] + + iterables = [ + [0, 1, 2], + [10], + [20, 21, 22], + ] + + result = list(MergeIterator(iterables, stop_after=-1)) + assert result == [0, 10, 20, 1, 21, 2, 22] + + iterables = [ + [0, 1, 2], + [10, 11, 12], + [20], + ] + + result = list(MergeIterator(iterables, stop_after=-1)) + assert result == [0, 10, 20, 1, 11, 2, 12] + + +def test_mergeiterator_ordered_n(): + """with stop_after=N, MergeIterator continues iterating after encountering an exhaustion.""" + iterables = [ + [0, 1, 2], + [10], + [20, 21, 22], + ] + + result = list(MergeIterator(iterables, stop_after=5)) + assert result == [0, 10, 20, 1, 21] + + result = list(MergeIterator(iterables, stop_after=7)) + assert result == [0, 10, 20, 1, 21, 2, 22] + + result = list(MergeIterator(iterables, stop_after=8)) + assert result == [0, 10, 20, 1, 21, 2, 22] + + +def test_mergeiterator_stochastic_smoke_test(): + """MergeIterator with probabilitiies do not get stuck.""" + + iterables = [ + [0, 1, 2], + [10, 11, 12], + [20, 21, 22], + ] + + weights = [1, 1, 1] + + result = list(MergeIterator(iterables, weights=weights, stop_after=-1)) + assert set(result) == {0, 1, 2, 10, 11, 12, 20, 21, 22} + + +def test_mergeiterator_stochastic_rejects_zero(): + """weight=0 is rejected.""" + weights = [1, 0] + + with pytest.raises(ValueError): + MergeIterator([[1]], weights=weights) + + weights = [1, 0.0] + + with pytest.raises(ValueError): + MergeIterator([[1]], weights=weights) + + +def test_mergeiterator_stochastic_stop_after_N(): + """Values are taken from iterables with higher weights""" + weights = [1000000, 1] + + iterables = [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], + ] + + result = list(MergeIterator(iterables, weights=weights, stop_after=3)) + assert result == [0, 1, 2] + + +def test_mergeiterator_stochastic_stop_after_first_exhaustion(): + """Values are taken from iterables with higher weights""" + weights = [1000000, 1] + + iterables = [ + [0, 1, 2, 3], + [10, 11, 12, 13], + ] + + result = list(MergeIterator(iterables, weights=weights, stop_after=0)) + assert result == [0, 1, 2, 3]