Skip to content

Commit

Permalink
Add tests for duplicate dims in batch_dims and input_dims (#139)
Browse files Browse the repository at this point in the history
* Add tests for duplicate dims in batch_dims and input_dims

* Update docstring

* Update xbatcher/tests/test_generators.py

Co-authored-by: Anderson Banihirwe <[email protected]>

* Remove more unnecessary enumerate()

* Update xbatcher/tests/test_generators.py

Co-authored-by: Raphael Hagen <[email protected]>

* Add more xfail markers

Co-authored-by: Anderson Banihirwe <[email protected]>
Co-authored-by: Raphael Hagen <[email protected]>
  • Loading branch information
3 people authored Dec 16, 2022
1 parent 663efa5 commit c0dd002
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 21 deletions.
70 changes: 53 additions & 17 deletions xbatcher/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ def _get_non_specified_dims(generator: BatchGenerator) -> Dict[Hashable, int]:
in the input_dims or batch_dims attributes of the batch generator.
"""
return {
k: v
for k, v in generator.ds.sizes.items()
if (generator.input_dims.get(k) is None and generator.batch_dims.get(k) is None)
dim: length
for dim, length in generator.ds.sizes.items()
if generator.input_dims.get(dim) is None
and generator.batch_dims.get(dim) is None
}


Expand All @@ -46,9 +47,30 @@ def _get_non_input_batch_dims(generator: BatchGenerator) -> Dict[Hashable, int]:
not also in input_dims
"""
return {
k: v
for k, v in generator.batch_dims.items()
if (generator.input_dims.get(k) is None)
dim: length
for dim, length in generator.batch_dims.items()
if generator.input_dims.get(dim) is None
}


def _get_duplicate_batch_dims(generator: BatchGenerator) -> Dict[Hashable, int]:
"""
Return all dimensions that are in both batch_dims and input_dims.
Parameters
----------
generator : xbatcher.BatchGenerator
The batch generator object.
Returns
-------
d : dict
Dict containing all dimensions duplicated between batch_dims and input_dims.
"""
return {
dim: length
for dim, length in generator.batch_dims.items()
if generator.input_dims.get(dim) is not None
}


Expand Down Expand Up @@ -188,17 +210,18 @@ def _get_nbatches_from_input_dims(generator: BatchGenerator) -> int:
"""
nbatches_from_input_dims = np.product(
[
generator.ds.sizes[k] // generator.input_dims[k]
for k in generator.input_dims.keys()
if generator.input_overlap.get(k) is None
generator.ds.sizes[dim] // length
for dim, length in generator.input_dims.items()
if generator.input_overlap.get(dim) is None
and generator.batch_dims.get(dim) is None
]
)
if generator.input_overlap:
nbatches_from_input_overlap = np.product(
[
(generator.ds.sizes[k] - generator.input_overlap[k])
// (generator.input_dims[k] - generator.input_overlap[k])
for k in generator.input_overlap
(generator.ds.sizes[dim] - overlap)
// (generator.input_dims[dim] - overlap)
for dim, overlap in generator.input_overlap.items()
]
)
return int(nbatches_from_input_overlap * nbatches_from_input_dims)
Expand All @@ -217,17 +240,30 @@ def validate_generator_length(generator: BatchGenerator) -> None:
The batch generator object.
"""
non_input_batch_dims = _get_non_input_batch_dims(generator)
nbatches_from_batch_dims = np.product(
duplicate_batch_dims = _get_duplicate_batch_dims(generator)
nbatches_from_unique_batch_dims = np.product(
[
generator.ds.sizes[k] // non_input_batch_dims[k]
for k in non_input_batch_dims.keys()
generator.ds.sizes[dim] // length
for dim, length in non_input_batch_dims.items()
]
)
nbatches_from_duplicate_batch_dims = np.product(
[
generator.ds.sizes[dim] // length
for dim, length in duplicate_batch_dims.items()
]
)
if generator.concat_input_dims:
expected_length = int(nbatches_from_batch_dims)
expected_length = int(
nbatches_from_unique_batch_dims * nbatches_from_duplicate_batch_dims
)
else:
nbatches_from_input_dims = _get_nbatches_from_input_dims(generator)
expected_length = int(nbatches_from_batch_dims * nbatches_from_input_dims)
expected_length = int(
nbatches_from_unique_batch_dims
* nbatches_from_duplicate_batch_dims
* nbatches_from_input_dims
)
TestCase().assertEqual(
expected_length,
len(generator),
Expand Down
49 changes: 45 additions & 4 deletions xbatcher/tests/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,30 @@ def test_batch_1d_concat(sample_ds_1d, input_size):
)
validate_generator_length(bg)
expected_dims = get_batch_dimensions(bg)
for n, ds_batch in enumerate(bg):
for ds_batch in bg:
assert isinstance(ds_batch, xr.Dataset)
validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)
assert "x" in ds_batch.coords


@pytest.mark.xfail(
    reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/131"
)
def test_batch_1d_concat_duplicate_dim(sample_ds_1d):
    """
    Test batch generation for a 1D dataset using ``concat_input_dims`` when
    the same dimension occurs in ``input_dims`` and ``batch_dims``.
    """
    # "x" is deliberately listed in both input_dims and batch_dims.
    bg = BatchGenerator(
        sample_ds_1d, input_dims={"x": 5}, batch_dims={"x": 10}, concat_input_dims=True
    )
    validate_generator_length(bg)
    expected_dims = get_batch_dimensions(bg)
    for ds_batch in bg:
        assert isinstance(ds_batch, xr.Dataset)
        validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)


@pytest.mark.parametrize("input_size", [5, 10])
def test_batch_1d_no_coordinate(sample_ds_1d, input_size):
"""
Expand Down Expand Up @@ -148,7 +166,7 @@ def test_batch_1d_concat_no_coordinate(sample_ds_1d, input_size):
)
validate_generator_length(bg)
expected_dims = get_batch_dimensions(bg)
for n, ds_batch in enumerate(bg):
for ds_batch in bg:
assert isinstance(ds_batch, xr.Dataset)
validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)
assert "x" not in ds_batch.coords
Expand Down Expand Up @@ -201,6 +219,9 @@ def test_batch_3d_1d_input(sample_ds_3d, input_size):
validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)


@pytest.mark.xfail(
reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/131"
)
@pytest.mark.parametrize("concat", [True, False])
def test_batch_3d_1d_input_batch_dims(sample_ds_3d, concat):
"""
Expand All @@ -218,6 +239,26 @@ def test_batch_3d_1d_input_batch_dims(sample_ds_3d, concat):
validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)


@pytest.mark.xfail(
    reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/131"
)
def test_batch_3d_1d_input_batch_concat_duplicate_dim(sample_ds_3d):
    """
    Test batch generation for a 3D dataset using ``concat_input_dims`` when
    the same dimension occurs in ``input_dims`` and ``batch_dims``.
    """
    # Both "x" and "y" are deliberately listed in input_dims and batch_dims.
    bg = BatchGenerator(
        sample_ds_3d,
        input_dims={"x": 5, "y": 10},
        batch_dims={"x": 10, "y": 20},
        concat_input_dims=True,
    )
    validate_generator_length(bg)
    expected_dims = get_batch_dimensions(bg)
    for ds_batch in bg:
        validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)


@pytest.mark.parametrize("input_size", [5, 10])
def test_batch_3d_2d_input(sample_ds_3d, input_size):
"""
Expand Down Expand Up @@ -257,7 +298,7 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, input_size):
)
validate_generator_length(bg)
expected_dims = get_batch_dimensions(bg)
for n, ds_batch in enumerate(bg):
for ds_batch in bg:
assert isinstance(ds_batch, xr.Dataset)
validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)

Expand All @@ -268,7 +309,7 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, input_size):
)
validate_generator_length(bg)
expected_dims = get_batch_dimensions(bg)
for n, ds_batch in enumerate(bg):
for ds_batch in bg:
assert isinstance(ds_batch, xr.Dataset)
validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)

Expand Down

0 comments on commit c0dd002

Please sign in to comment.