Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for duplicate dims in batch_dims and input_dims #139

Merged
merged 6 commits into from
Dec 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 53 additions & 17 deletions xbatcher/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ def _get_non_specified_dims(generator: BatchGenerator) -> Dict[Hashable, int]:
in the input_dims or batch_dims attributes of the batch generator.
"""
return {
k: v
for k, v in generator.ds.sizes.items()
if (generator.input_dims.get(k) is None and generator.batch_dims.get(k) is None)
dim: length
for dim, length in generator.ds.sizes.items()
if generator.input_dims.get(dim) is None
and generator.batch_dims.get(dim) is None
}


Expand All @@ -46,9 +47,30 @@ def _get_non_input_batch_dims(generator: BatchGenerator) -> Dict[Hashable, int]:
not also in input_dims
"""
return {
k: v
for k, v in generator.batch_dims.items()
if (generator.input_dims.get(k) is None)
dim: length
for dim, length in generator.batch_dims.items()
if generator.input_dims.get(dim) is None
}


def _get_duplicate_batch_dims(generator: BatchGenerator) -> Dict[Hashable, int]:
"""
Return all dimensions that are in both batch_dims and input_dims.

Parameters
----------
generator : xbatcher.BatchGenerator
The batch generator object.

Returns
-------
d : dict
Dict containing all dimensions duplicated between batch_dims and input_dims.
"""
return {
dim: length
for dim, length in generator.batch_dims.items()
if generator.input_dims.get(dim) is not None
}


Expand Down Expand Up @@ -188,17 +210,18 @@ def _get_nbatches_from_input_dims(generator: BatchGenerator) -> int:
"""
nbatches_from_input_dims = np.product(
[
generator.ds.sizes[k] // generator.input_dims[k]
for k in generator.input_dims.keys()
if generator.input_overlap.get(k) is None
generator.ds.sizes[dim] // length
for dim, length in generator.input_dims.items()
if generator.input_overlap.get(dim) is None
and generator.batch_dims.get(dim) is None
]
)
if generator.input_overlap:
nbatches_from_input_overlap = np.product(
[
(generator.ds.sizes[k] - generator.input_overlap[k])
// (generator.input_dims[k] - generator.input_overlap[k])
for k in generator.input_overlap
(generator.ds.sizes[dim] - overlap)
// (generator.input_dims[dim] - overlap)
for dim, overlap in generator.input_overlap.items()
]
)
return int(nbatches_from_input_overlap * nbatches_from_input_dims)
Expand All @@ -217,17 +240,30 @@ def validate_generator_length(generator: BatchGenerator) -> None:
The batch generator object.
"""
non_input_batch_dims = _get_non_input_batch_dims(generator)
nbatches_from_batch_dims = np.product(
duplicate_batch_dims = _get_duplicate_batch_dims(generator)
nbatches_from_unique_batch_dims = np.product(
[
generator.ds.sizes[k] // non_input_batch_dims[k]
for k in non_input_batch_dims.keys()
generator.ds.sizes[dim] // length
for dim, length in non_input_batch_dims.items()
]
)
nbatches_from_duplicate_batch_dims = np.product(
[
generator.ds.sizes[dim] // length
for dim, length in duplicate_batch_dims.items()
]
)
if generator.concat_input_dims:
expected_length = int(nbatches_from_batch_dims)
expected_length = int(
nbatches_from_unique_batch_dims * nbatches_from_duplicate_batch_dims
)
else:
nbatches_from_input_dims = _get_nbatches_from_input_dims(generator)
expected_length = int(nbatches_from_batch_dims * nbatches_from_input_dims)
expected_length = int(
nbatches_from_unique_batch_dims
* nbatches_from_duplicate_batch_dims
* nbatches_from_input_dims
)
TestCase().assertEqual(
expected_length,
len(generator),
Expand Down
49 changes: 45 additions & 4 deletions xbatcher/tests/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,30 @@ def test_batch_1d_concat(sample_ds_1d, input_size):
)
validate_generator_length(bg)
expected_dims = get_batch_dimensions(bg)
for n, ds_batch in enumerate(bg):
for ds_batch in bg:
assert isinstance(ds_batch, xr.Dataset)
validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)
assert "x" in ds_batch.coords


maxrjones marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.xfail(
reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/131"
)
def test_batch_1d_concat_duplicate_dim(sample_ds_1d):
"""
Test batch generation for a 1D dataset using ``concat_input_dims`` when
the same dimension occurs in ``input_dims`` and `batch_dims``
"""
bg = BatchGenerator(
sample_ds_1d, input_dims={"x": 5}, batch_dims={"x": 10}, concat_input_dims=True
)
validate_generator_length(bg)
expected_dims = get_batch_dimensions(bg)
for ds_batch in bg:
assert isinstance(ds_batch, xr.Dataset)
validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)


@pytest.mark.parametrize("input_size", [5, 10])
def test_batch_1d_no_coordinate(sample_ds_1d, input_size):
"""
Expand Down Expand Up @@ -148,7 +166,7 @@ def test_batch_1d_concat_no_coordinate(sample_ds_1d, input_size):
)
validate_generator_length(bg)
expected_dims = get_batch_dimensions(bg)
for n, ds_batch in enumerate(bg):
for ds_batch in bg:
assert isinstance(ds_batch, xr.Dataset)
validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)
assert "x" not in ds_batch.coords
Expand Down Expand Up @@ -201,6 +219,9 @@ def test_batch_3d_1d_input(sample_ds_3d, input_size):
validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)


@pytest.mark.xfail(
reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/131"
)
@pytest.mark.parametrize("concat", [True, False])
def test_batch_3d_1d_input_batch_dims(sample_ds_3d, concat):
"""
Expand All @@ -218,6 +239,26 @@ def test_batch_3d_1d_input_batch_dims(sample_ds_3d, concat):
validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)


maxrjones marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.xfail(
reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/131"
)
def test_batch_3d_1d_input_batch_concat_duplicate_dim(sample_ds_3d):
"""
Test batch generation for a 3D dataset using ``concat_input_dims`` when
the same dimension occurs in ``input_dims`` and batch_dims``.
"""
bg = BatchGenerator(
sample_ds_3d,
input_dims={"x": 5, "y": 10},
batch_dims={"x": 10, "y": 20},
concat_input_dims=True,
)
validate_generator_length(bg)
expected_dims = get_batch_dimensions(bg)
for ds_batch in bg:
validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)


@pytest.mark.parametrize("input_size", [5, 10])
def test_batch_3d_2d_input(sample_ds_3d, input_size):
"""
Expand Down Expand Up @@ -257,7 +298,7 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, input_size):
)
validate_generator_length(bg)
expected_dims = get_batch_dimensions(bg)
for n, ds_batch in enumerate(bg):
for ds_batch in bg:
assert isinstance(ds_batch, xr.Dataset)
validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)

Expand All @@ -268,7 +309,7 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, input_size):
)
validate_generator_length(bg)
expected_dims = get_batch_dimensions(bg)
for n, ds_batch in enumerate(bg):
for ds_batch in bg:
assert isinstance(ds_batch, xr.Dataset)
validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)

Expand Down