Skip to content

Commit

Permalink
Add MergeIterator (#283)
Browse files Browse the repository at this point in the history
MergeIterator can merge multiple source iterators (of same types),
optionally adding stochastic sampling over the iterators.
  • Loading branch information
mthrok authored Nov 19, 2024
1 parent 204dd7d commit ce85f8e
Show file tree
Hide file tree
Showing 2 changed files with 329 additions and 3 deletions.
163 changes: 161 additions & 2 deletions src/spdl/dataloader/_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

__all__ = ["DataLoader", "MapIterator"]
__all__ = ["DataLoader", "MapIterator", "MergeIterator"]

import random
import sys
from collections.abc import (
AsyncIterable,
Awaitable,
Callable,
Iterable,
Iterator,
Mapping,
Sequence,
)
from typing import Generic, TypeAlias, TypeVar

Expand Down Expand Up @@ -250,7 +253,7 @@ class MapIterator(Iterable[V]):
Args:
mapping: Object implements :py:class:`~collections.abc.Mapping` interface.
sampler: **Optional** Generator that yields key for the mapping.
Used to specify the iteratoin order over the mapping and/or to sample
Used to specify the iteration order over the mapping and/or to sample
from a subset of the mapping.
Example:
Expand Down Expand Up @@ -283,3 +286,159 @@ def __init__(
def __iter__(self) -> Iterator[V]:
for key in self.sampler or self.mapping:
yield self.mapping[key]


_FIRST_EXHAUSTION = 0


def _ordered_iter(iterators: list[Iterator[T]], stop_after: float) -> Iterable[T]:
num_items = 0
while iterators:
remove = []

for i, iterator in enumerate(iterators):
try:
yield next(iterator)
except StopIteration:
if stop_after == _FIRST_EXHAUSTION:
return
# Insert in reversed order beacause we use this for popping from list
remove.insert(0, i)
continue

num_items += 1
if stop_after > 0 and num_items >= stop_after:
return

if remove:
for i in remove:
iterators.pop(i)


def _stocastic_iter(
iterators: list[Iterator[T]],
weights: Sequence[float],
stop_after: float,
seed: int,
) -> Iterable[T]:
# These are all checked in MergeIterator constructor
assert len(iterators) == len(weights)
assert all(w >= sys.float_info.epsilon for w in weights)

population = list(range(len(iterators)))
rng = random.Random(seed)
num_items = 0

not_exhausted = [True for _ in range(len(iterators))]
while any(not_exhausted):
for i in rng.choices(population, weights, k=100):
try:
yield next(iterators[i])
except StopIteration:
not_exhausted[i] = False
if stop_after == _FIRST_EXHAUSTION:
return
continue

num_items += 1
if stop_after > 0 and num_items >= stop_after:
return


class MergeIterator(Iterable[T]):
"""Iterate over given iterables and yield one item from each iterator.
Args:
iterables: The source iterables
probabilities: The probability to choose the next iterable.
If not provided, the given iterables are visited in the given order
repeatedly.
stop_after: Determines the stop criteria or the behavior when one of
the input iterables gets exhausted,
Available values are;
- ``0``: The iteration stops when one of the iterator is exhausted.
- ``n > 0``: The iteration stops when the specified number of items
are yielded or all the input iterables are exhausted.
- ``-1``: The iteration continues until all the input iterables are
exhausted.
seed: Used to seed the random generator when probabilities is provided.
Example:
>>> iterables = [
... [0, 1, 2],
... [10, 11, 12],
... [20, 21, 22],
... ]
>>>
>>> print(list(MergeIterator(iterables)))
[0, 10, 20, 1, 11, 21, 2, 12, 22]
>>>
>>> # By default, it stops after one iterable gets exhausted.
>>> iterables = [
... [0, 1, 2],
... [10, 11],
... [20, 21, 22],
... ]
>>>
>>> print(list(MergeIterator(iterables)))
[0, 10, 20, 1, 21, 2] # 22 is not included
>>>
>>> # Stop after yielding the given number of items
>>> print(list(MergeIterator(iterables, stop_after=5)))
[0, 10, 20, 1, 11]
>>>
>>> # stop_after>1 ignores the exhaustion.
>>> print(list(MergeIterator(iterables, stop_after=9)))
[0, 10, 20, 1, 11, 21, 2, 22]
>>>
>>> # Providing weights will pick up the iterable stocastically.
>>> print(list(MergeIterator(iterables, stop_after=9, weights=[1, 1, 1])))
[0, 1, 10, 11, 20, 2, 21, 22]
"""

def __init__(
self,
iterables: Sequence[Iterable[T]],
*,
weights: Sequence[float] | None = None,
stop_after: int = _FIRST_EXHAUSTION,
seed: int = 0,
) -> None:
if not iterables:
raise ValueError("iterables cannot be empty.")

if weights is not None:
if len(weights) != len(iterables):
raise ValueError(
f"The number of probabilities ({len(weights)}) and "
f"iterables ({len(iterables)}) must match."
)

# If any of them is 0 or negative, then there is something wrong with
# user logic, so we raise an exception.
if any(w < sys.float_info.epsilon for w in weights):
raise ValueError("Weights must be non-zero and positive.")

if not stop_after >= -1:
msg = (
f"`stop_after` must be greater than or equal to -1. Found: {stop_after}"
)
raise ValueError(msg)

self.iterables = iterables
self.weights = weights
self.stop_after = stop_after
self.seed = seed

def __iter__(self) -> Iterator[T]:
iterators = [iter(ite) for ite in self.iterables]

if self.weights is None:
yield from _ordered_iter(iterators, self.stop_after)
else:
yield from _stocastic_iter(
iterators, self.weights, self.stop_after, self.seed
)
169 changes: 168 additions & 1 deletion tests/spdl_unittest/dataloader/dataloader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import time

from spdl.dataloader import DataLoader, MapIterator
import pytest
from spdl.dataloader import DataLoader, MapIterator, MergeIterator


def get_dl(*args, timeout=3, num_threads=2, **kwargs):
Expand Down Expand Up @@ -181,3 +182,169 @@ def test_mapiterator_sampler():

result = list(MapIterator(mapping, sampler))
assert result == ["e", "c", "a"]


def test_mergeiterator_ordered():
"""MergeIterator iterates multiple iterators"""

iterables = [
[0, 1, 2],
[10, 11, 12],
[20, 21, 22],
]

result = list(MergeIterator(iterables))
assert result == [0, 10, 20, 1, 11, 21, 2, 12, 22]


def test_mergeiterator_ordered_stop_after_first_exhaustion():
"""MergeIterator stops after the first exhaustion"""

iterables = [
[0],
[10, 11, 12],
[20, 21, 22],
]

result = list(MergeIterator(iterables, stop_after=0))
assert result == [0, 10, 20]

iterables = [
[0, 1, 2],
[10],
[20, 21, 22],
]

result = list(MergeIterator(iterables, stop_after=0))
assert result == [0, 10, 20, 1]

iterables = [
[0, 1, 2],
[10, 11],
[20],
]

result = list(MergeIterator(iterables, stop_after=0))
assert result == [0, 10, 20, 1, 11]


def test_mergeiterator_ordered_stop_after_N():
"""MergeIterator stops after N items are yielded"""

iterables = [
[0, 1, 2],
[10, 11, 12],
[20, 21, 22],
]

result = list(MergeIterator(iterables, stop_after=1))
assert result == [0]

result = list(MergeIterator(iterables, stop_after=5))
assert result == [0, 10, 20, 1, 11]

result = list(MergeIterator(iterables, stop_after=7))
assert result == [0, 10, 20, 1, 11, 21, 2]


def test_mergeiterator_ordered_stop_after_minus1():
"""MergeIterator stops after all the iterables are exhausted"""

iterables = [
[0, 1, 2],
[10, 11, 12],
[20, 21, 22],
]

result = list(MergeIterator(iterables, stop_after=-1))
assert result == [0, 10, 20, 1, 11, 21, 2, 12, 22]

iterables = [
[0, 1, 2],
[10],
[20, 21, 22],
]

result = list(MergeIterator(iterables, stop_after=-1))
assert result == [0, 10, 20, 1, 21, 2, 22]

iterables = [
[0, 1, 2],
[10, 11, 12],
[20],
]

result = list(MergeIterator(iterables, stop_after=-1))
assert result == [0, 10, 20, 1, 11, 2, 12]


def test_mergeiterator_ordered_n():
"""with stop_after=N, MergeIterator continues iterating after encountering an exhaustion."""
iterables = [
[0, 1, 2],
[10],
[20, 21, 22],
]

result = list(MergeIterator(iterables, stop_after=5))
assert result == [0, 10, 20, 1, 21]

result = list(MergeIterator(iterables, stop_after=7))
assert result == [0, 10, 20, 1, 21, 2, 22]

result = list(MergeIterator(iterables, stop_after=8))
assert result == [0, 10, 20, 1, 21, 2, 22]


def test_mergeiterator_stochastic_smoke_test():
"""MergeIterator with probabilitiies do not get stuck."""

iterables = [
[0, 1, 2],
[10, 11, 12],
[20, 21, 22],
]

weights = [1, 1, 1]

result = list(MergeIterator(iterables, weights=weights, stop_after=-1))
assert set(result) == {0, 1, 2, 10, 11, 12, 20, 21, 22}


def test_mergeiterator_stochastic_rejects_zero():
"""weight=0 is rejected."""
weights = [1, 0]

with pytest.raises(ValueError):
MergeIterator([[1]], weights=weights)

weights = [1, 0.0]

with pytest.raises(ValueError):
MergeIterator([[1]], weights=weights)


def test_mergeiterator_stochastic_stop_after_N():
"""Values are taken from iterables with higher weights"""
weights = [1000000, 1]

iterables = [
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
]

result = list(MergeIterator(iterables, weights=weights, stop_after=3))
assert result == [0, 1, 2]


def test_mergeiterator_stochastic_stop_after_first_exhaustion():
"""Values are taken from iterables with higher weights"""
weights = [1000000, 1]

iterables = [
[0, 1, 2, 3],
[10, 11, 12, 13],
]

result = list(MergeIterator(iterables, weights=weights, stop_after=0))
assert result == [0, 1, 2, 3]

0 comments on commit ce85f8e

Please sign in to comment.