Skip to content

Commit

Permalink
Add MergeIterator
Browse files Browse the repository at this point in the history
MergeIterator can merge multiple source iterators (of the same type),
optionally adding stochastic sampling over the iterators.
  • Loading branch information
mthrok committed Nov 19, 2024
1 parent 204dd7d commit 5b4ae5d
Show file tree
Hide file tree
Showing 2 changed files with 329 additions and 3 deletions.
163 changes: 161 additions & 2 deletions src/spdl/dataloader/_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

__all__ = ["DataLoader", "MapIterator"]
__all__ = ["DataLoader", "MapIterator", "MergeIterator"]

import random
import sys
from collections.abc import (
AsyncIterable,
Awaitable,
Callable,
Iterable,
Iterator,
Mapping,
Sequence,
)
from typing import Generic, TypeAlias, TypeVar

Expand Down Expand Up @@ -250,7 +253,7 @@ class MapIterator(Iterable[V]):
Args:
mapping: Object implements :py:class:`~collections.abc.Mapping` interface.
sampler: **Optional** Generator that yields key for the mapping.
Used to specify the iteratoin order over the mapping and/or to sample
Used to specify the iteration order over the mapping and/or to sample
from a subset of the mapping.
Example:
Expand Down Expand Up @@ -283,3 +286,159 @@ def __init__(
def __iter__(self) -> Iterator[V]:
    """Yield ``self.mapping[key]`` for each key of the sampler, or of the mapping itself."""
    # NOTE(review): ``or`` falls back to the mapping for *any* falsy sampler
    # (e.g. an empty list), not only for ``None`` -- confirm this is intended.
    keys = self.sampler or self.mapping
    for k in keys:
        yield self.mapping[k]


# Value of ``stop_after`` that makes the merge helpers stop as soon as any one
# of the source iterators is exhausted. It is the default of ``MergeIterator``.
_FIRST_EXHAUSTION = 0


def _ordered_iter(iterators: list[Iterator[T]], stop_after: float) -> Iterable[T]:
    """Yield items from ``iterators`` round-robin, in the given order.

    Args:
        iterators: The source iterators. Exhausted iterators are removed from
            this list in place.
        stop_after: ``_FIRST_EXHAUSTION`` (``0``) stops at the first exhausted
            iterator. A positive value stops once that many items have been
            yielded, or when all the iterators are exhausted, whichever comes
            first. Any other value runs until all the iterators are exhausted.

    Yields:
        The next item of each iterator, one iterator after another.
    """
    yielded = 0
    while iterators:
        exhausted = []

        for pos, source in enumerate(iterators):
            try:
                yield next(source)
            except StopIteration:
                if stop_after == _FIRST_EXHAUSTION:
                    return
                exhausted.append(pos)
            else:
                yielded += 1
                if stop_after > 0 and yielded >= stop_after:
                    return

        # Delete from the highest position down, so that the positions still
        # to be deleted stay valid while the list shrinks.
        for pos in reversed(exhausted):
            del iterators[pos]


def _stocastic_iter(
    iterators: list[Iterator[T]],
    weights: Sequence[float],
    stop_after: float,
    seed: int,
) -> Iterator[T]:
    """Yield items from ``iterators``, picking the source at random by weight.

    Args:
        iterators: The source iterators.
        weights: The sampling weight of each iterator. Must have the same
            length as ``iterators`` and contain only positive values.
        stop_after: ``_FIRST_EXHAUSTION`` (``0``) stops at the first exhausted
            iterator. A positive value stops once that many items have been
            yielded, or when all the iterators are exhausted, whichever comes
            first. Any other value runs until all the iterators are exhausted.
        seed: Seed of the random generator used to pick the source.

    Yields:
        The next item of the randomly chosen iterator.
    """
    # These are all checked in MergeIterator constructor
    assert len(iterators) == len(weights)
    assert all(w >= sys.float_info.epsilon for w in weights)

    rng = random.Random(seed)
    num_items = 0

    # Indices of the iterators that still have items, and their weights.
    # Exhausted iterators are dropped from both lists, so they are never
    # sampled (nor advanced) again. Otherwise, draining a low-weight iterator
    # after a high-weight one got exhausted would need a huge number of draws.
    active = list(range(len(iterators)))
    active_weights = list(weights)

    while active:
        # Draw indices in batches to amortize the cost of ``choices``.
        for i in rng.choices(active, active_weights, k=100):
            try:
                yield next(iterators[i])
            except StopIteration:
                if stop_after == _FIRST_EXHAUSTION:
                    return
                pos = active.index(i)
                del active[pos]
                del active_weights[pos]
                # The rest of the batch was drawn from the stale population,
                # so discard it and draw a new batch.
                break
            else:
                num_items += 1
                if stop_after > 0 and num_items >= stop_after:
                    return


class MergeIterator(Iterable[T]):
    """Merge multiple iterables into one, yielding one item at a time.

    Without ``weights``, the iterables are visited in the given order
    repeatedly (round-robin). With ``weights``, the iterable that provides
    the next item is picked at random according to the weights.

    Args:
        iterables: The source iterables. Must not be empty.
        weights: **Optional** The sampling weight of each iterable.
            Must have the same length as ``iterables``, and each weight must
            be positive.
            If not provided, the given iterables are visited in the given order
            repeatedly.
        stop_after: Determines the stop criteria or the behavior when one of
            the input iterables gets exhausted.
            Available values are;

            - ``0``: The iteration stops when one of the iterables is exhausted.
            - ``n > 0``: The iteration stops when the specified number of items
              are yielded or all the input iterables are exhausted.
            - ``-1``: The iteration continues until all the input iterables are
              exhausted.
        seed: Used to seed the random generator when ``weights`` is provided.

    Raises:
        ValueError: If ``iterables`` is empty, if the lengths of ``weights``
            and ``iterables`` differ, if any weight is zero or negative, or if
            ``stop_after`` is smaller than ``-1``.

    Example:
        >>> iterables = [
        ...     [0, 1, 2],
        ...     [10, 11, 12],
        ...     [20, 21, 22],
        ... ]
        >>>
        >>> print(list(MergeIterator(iterables)))
        [0, 10, 20, 1, 11, 21, 2, 12, 22]
        >>>
        >>> # By default, it stops after one iterable gets exhausted.
        >>> iterables = [
        ...     [0, 1, 2],
        ...     [10, 11],
        ...     [20, 21, 22],
        ... ]
        >>>
        >>> # The second iterable is exhausted first, so 22 is not included.
        >>> print(list(MergeIterator(iterables)))
        [0, 10, 20, 1, 11, 21, 2]
        >>>
        >>> # Stop after yielding the given number of items
        >>> print(list(MergeIterator(iterables, stop_after=5)))
        [0, 10, 20, 1, 11]
        >>>
        >>> # stop_after > 0 ignores the exhaustion.
        >>> print(list(MergeIterator(iterables, stop_after=9)))
        [0, 10, 20, 1, 11, 21, 2, 22]
        >>>
        >>> # Providing weights picks the next iterable stochastically.
        >>> # The order depends on ``seed``, but with ``stop_after=-1`` all the
        >>> # items are yielded.
        >>> merged = MergeIterator(iterables, stop_after=-1, weights=[1, 1, 1])
        >>> print(sorted(merged))
        [0, 1, 2, 10, 11, 20, 21, 22]
    """

    def __init__(
        self,
        iterables: Sequence[Iterable[T]],
        *,
        weights: Sequence[float] | None = None,
        stop_after: int = _FIRST_EXHAUSTION,
        seed: int = 0,
    ) -> None:
        if not iterables:
            raise ValueError("iterables cannot be empty.")

        if weights is not None:
            if len(weights) != len(iterables):
                raise ValueError(
                    f"The number of weights ({len(weights)}) and "
                    f"iterables ({len(iterables)}) must match."
                )

            # If any of them is 0 or negative, then there is something wrong with
            # user logic, so we raise an exception.
            if any(w < sys.float_info.epsilon for w in weights):
                raise ValueError("Weights must be non-zero and positive.")

        if not stop_after >= -1:
            msg = (
                f"`stop_after` must be greater than or equal to -1. Found: {stop_after}"
            )
            raise ValueError(msg)

        self.iterables = iterables
        self.weights = weights
        self.stop_after = stop_after
        self.seed = seed

    def __iter__(self) -> Iterator[T]:
        """Create a fresh iterator over each source iterable and merge them."""
        iterators = [iter(ite) for ite in self.iterables]

        if self.weights is None:
            yield from _ordered_iter(iterators, self.stop_after)
        else:
            yield from _stocastic_iter(
                iterators, self.weights, self.stop_after, self.seed
            )
169 changes: 168 additions & 1 deletion tests/spdl_unittest/dataloader/dataloader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import time

from spdl.dataloader import DataLoader, MapIterator
import pytest
from spdl.dataloader import DataLoader, MapIterator, MergeIterator


def get_dl(*args, timeout=3, num_threads=2, **kwargs):
Expand Down Expand Up @@ -181,3 +182,169 @@ def test_mapiterator_sampler():

result = list(MapIterator(mapping, sampler))
assert result == ["e", "c", "a"]


def test_mergeiterator_ordered():
    """Without weights, MergeIterator visits the iterables in order, repeatedly."""
    sources = [[0, 1, 2], [10, 11, 12], [20, 21, 22]]

    merged = MergeIterator(sources)

    assert list(merged) == [0, 10, 20, 1, 11, 21, 2, 12, 22]


def test_mergeiterator_ordered_stop_after_first_exhaustion():
    """With stop_after=0, MergeIterator stops at the first exhausted iterable."""
    # Each case is (source iterables, expected merged items). The shortest
    # iterable is placed first, second and last respectively.
    cases = [
        ([[0], [10, 11, 12], [20, 21, 22]], [0, 10, 20]),
        ([[0, 1, 2], [10], [20, 21, 22]], [0, 10, 20, 1]),
        ([[0, 1, 2], [10, 11], [20]], [0, 10, 20, 1, 11]),
    ]

    for sources, expected in cases:
        merged = list(MergeIterator(sources, stop_after=0))
        assert merged == expected


def test_mergeiterator_ordered_stop_after_N():
    """With stop_after=N, MergeIterator stops once N items have been yielded."""
    sources = [[0, 1, 2], [10, 11, 12], [20, 21, 22]]

    # Each case is (stop_after, expected merged items).
    cases = [
        (1, [0]),
        (5, [0, 10, 20, 1, 11]),
        (7, [0, 10, 20, 1, 11, 21, 2]),
    ]

    for limit, expected in cases:
        merged = list(MergeIterator(sources, stop_after=limit))
        assert merged == expected


def test_mergeiterator_ordered_stop_after_minus1():
    """With stop_after=-1, MergeIterator runs until every iterable is exhausted."""
    # Each case is (source iterables, expected merged items).
    cases = [
        (
            [[0, 1, 2], [10, 11, 12], [20, 21, 22]],
            [0, 10, 20, 1, 11, 21, 2, 12, 22],
        ),
        (
            [[0, 1, 2], [10], [20, 21, 22]],
            [0, 10, 20, 1, 21, 2, 22],
        ),
        (
            [[0, 1, 2], [10, 11, 12], [20]],
            [0, 10, 20, 1, 11, 2, 12],
        ),
    ]

    for sources, expected in cases:
        merged = list(MergeIterator(sources, stop_after=-1))
        assert merged == expected


def test_mergeiterator_ordered_n():
    """With stop_after=N, MergeIterator keeps going after an iterable is exhausted."""
    sources = [[0, 1, 2], [10], [20, 21, 22]]

    # Only 7 items exist in total, so stop_after=8 yields all of them.
    expected_by_limit = {
        5: [0, 10, 20, 1, 21],
        7: [0, 10, 20, 1, 21, 2, 22],
        8: [0, 10, 20, 1, 21, 2, 22],
    }

    for limit, expected in expected_by_limit.items():
        merged = list(MergeIterator(sources, stop_after=limit))
        assert merged == expected


def test_mergeiterator_stochastic_smoke_test():
    """MergeIterator with weights does not get stuck and yields every item once."""

    iterables = [
        [0, 1, 2],
        [10, 11, 12],
        [20, 21, 22],
    ]

    weights = [1, 1, 1]

    result = list(MergeIterator(iterables, weights=weights, stop_after=-1))
    # The order is random, so compare sorted lists. Unlike a set comparison,
    # this also fails when an item is yielded more than once.
    assert sorted(result) == [0, 1, 2, 10, 11, 12, 20, 21, 22]


def test_mergeiterator_stochastic_rejects_zero():
    """weight=0 is rejected."""
    # The number of iterables must match the number of weights. Otherwise the
    # length check raises first and the zero-weight check is never exercised.
    iterables = [[1], [2]]

    for weights in ([1, 0], [1, 0.0]):
        # ``match`` pins the failure to the weight validation.
        with pytest.raises(ValueError, match="Weights must be"):
            MergeIterator(iterables, weights=weights)


def test_mergeiterator_stochastic_stop_after_N():
    """Items come from the heavily weighted iterable until N items are yielded."""
    heavy = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    light = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]

    merged = MergeIterator([heavy, light], weights=[1000000, 1], stop_after=3)

    assert list(merged) == [0, 1, 2]


def test_mergeiterator_stochastic_stop_after_first_exhaustion():
    """With stop_after=0, iteration ends when the heavily weighted iterable runs out."""
    heavy = [0, 1, 2, 3]
    light = [10, 11, 12, 13]

    merged = MergeIterator([heavy, light], weights=[1000000, 1], stop_after=0)

    assert list(merged) == [0, 1, 2, 3]

0 comments on commit 5b4ae5d

Please sign in to comment.