diff --git a/xbatcher/testing.py b/xbatcher/testing.py index b034901..4cb224d 100644 --- a/xbatcher/testing.py +++ b/xbatcher/testing.py @@ -24,9 +24,10 @@ def _get_non_specified_dims(generator: BatchGenerator) -> Dict[Hashable, int]: in the input_dims or batch_dims attributes of the batch generator. """ return { - k: v - for k, v in generator.ds.sizes.items() - if (generator.input_dims.get(k) is None and generator.batch_dims.get(k) is None) + dim: length + for dim, length in generator.ds.sizes.items() + if generator.input_dims.get(dim) is None + and generator.batch_dims.get(dim) is None } @@ -46,9 +47,30 @@ def _get_non_input_batch_dims(generator: BatchGenerator) -> Dict[Hashable, int]: not also in input_dims """ return { - k: v - for k, v in generator.batch_dims.items() - if (generator.input_dims.get(k) is None) + dim: length + for dim, length in generator.batch_dims.items() + if generator.input_dims.get(dim) is None + } + + +def _get_duplicate_batch_dims(generator: BatchGenerator) -> Dict[Hashable, int]: + """ + Return all dimensions that are in both batch_dims and input_dims. + + Parameters + ---------- + generator : xbatcher.BatchGenerator + The batch generator object. + + Returns + ------- + d : dict + Dict containing all dimensions duplicated between batch_dims and input_dims. + """ + return { + dim: length + for dim, length in generator.batch_dims.items() + if generator.input_dims.get(dim) is not None } @@ -188,17 +210,18 @@ def _get_nbatches_from_input_dims(generator: BatchGenerator) -> int: """ nbatches_from_input_dims = np.product( [ - generator.ds.sizes[k] // generator.input_dims[k] - for k in generator.input_dims.keys() - if generator.input_overlap.get(k) is None + generator.ds.sizes[dim] // length + for dim, length in generator.input_dims.items() + if generator.input_overlap.get(dim) is None + and generator.batch_dims.get(dim) is None ] ) if generator.input_overlap: nbatches_from_input_overlap = np.product( [ - (generator.ds.sizes[k] - generator.input_overlap[k]) - // (generator.input_dims[k] - generator.input_overlap[k]) - for k in generator.input_overlap + (generator.ds.sizes[dim] - overlap) + // (generator.input_dims[dim] - overlap) + for dim, overlap in generator.input_overlap.items() ] ) return int(nbatches_from_input_overlap * nbatches_from_input_dims) @@ -217,17 +240,30 @@ def validate_generator_length(generator: BatchGenerator) -> None: The batch generator object. """ non_input_batch_dims = _get_non_input_batch_dims(generator) - nbatches_from_batch_dims = np.product( + duplicate_batch_dims = _get_duplicate_batch_dims(generator) + nbatches_from_unique_batch_dims = np.product( [ - generator.ds.sizes[k] // non_input_batch_dims[k] - for k in non_input_batch_dims.keys() + generator.ds.sizes[dim] // length + for dim, length in non_input_batch_dims.items() + ] + ) + nbatches_from_duplicate_batch_dims = np.product( + [ + generator.ds.sizes[dim] // length + for dim, length in duplicate_batch_dims.items() ] ) if generator.concat_input_dims: - expected_length = int(nbatches_from_batch_dims) + expected_length = int( + nbatches_from_unique_batch_dims * nbatches_from_duplicate_batch_dims + ) else: nbatches_from_input_dims = _get_nbatches_from_input_dims(generator) - expected_length = int(nbatches_from_batch_dims * nbatches_from_input_dims) + expected_length = int( + nbatches_from_unique_batch_dims + * nbatches_from_duplicate_batch_dims + * nbatches_from_input_dims + ) TestCase().assertEqual( expected_length, len(generator), diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index bf2b796..bba84dd 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -109,12 +109,30 @@ def test_batch_1d_concat(sample_ds_1d, input_size): ) validate_generator_length(bg) expected_dims = get_batch_dimensions(bg) - for n, ds_batch in enumerate(bg): + for ds_batch in bg: assert isinstance(ds_batch, xr.Dataset) validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) assert "x" in ds_batch.coords +@pytest.mark.xfail( + reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/131" +) +def test_batch_1d_concat_duplicate_dim(sample_ds_1d): + """ + Test batch generation for a 1D dataset using ``concat_input_dims`` when + the same dimension occurs in ``input_dims`` and `batch_dims`` + """ + bg = BatchGenerator( + sample_ds_1d, input_dims={"x": 5}, batch_dims={"x": 10}, concat_input_dims=True + ) + validate_generator_length(bg) + expected_dims = get_batch_dimensions(bg) + for ds_batch in bg: + assert isinstance(ds_batch, xr.Dataset) + validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) + + @pytest.mark.parametrize("input_size", [5, 10]) def test_batch_1d_no_coordinate(sample_ds_1d, input_size): """ @@ -148,7 +166,7 @@ def test_batch_1d_concat_no_coordinate(sample_ds_1d, input_size): ) validate_generator_length(bg) expected_dims = get_batch_dimensions(bg) - for n, ds_batch in enumerate(bg): + for ds_batch in bg: assert isinstance(ds_batch, xr.Dataset) validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) assert "x" not in ds_batch.coords @@ -201,6 +219,9 @@ def test_batch_3d_1d_input(sample_ds_3d, input_size): validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) +@pytest.mark.xfail( + reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/131" +) @pytest.mark.parametrize("concat", [True, False]) def test_batch_3d_1d_input_batch_dims(sample_ds_3d, concat): """ @@ -218,6 +239,26 @@ def test_batch_3d_1d_input_batch_dims(sample_ds_3d, concat): validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) +@pytest.mark.xfail( + reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/131" +) +def test_batch_3d_1d_input_batch_concat_duplicate_dim(sample_ds_3d): + """ + Test batch generation for a 3D dataset using ``concat_input_dims`` when + the same dimension occurs in ``input_dims`` and batch_dims``. + """ + bg = BatchGenerator( + sample_ds_3d, + input_dims={"x": 5, "y": 10}, + batch_dims={"x": 10, "y": 20}, + concat_input_dims=True, + ) + validate_generator_length(bg) + expected_dims = get_batch_dimensions(bg) + for ds_batch in bg: + validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) + + @pytest.mark.parametrize("input_size", [5, 10]) def test_batch_3d_2d_input(sample_ds_3d, input_size): """ @@ -257,7 +298,7 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, input_size): ) validate_generator_length(bg) expected_dims = get_batch_dimensions(bg) - for n, ds_batch in enumerate(bg): + for ds_batch in bg: assert isinstance(ds_batch, xr.Dataset) validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) @@ -268,7 +309,7 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, input_size): ) validate_generator_length(bg) expected_dims = get_batch_dimensions(bg) - for n, ds_batch in enumerate(bg): + for ds_batch in bg: assert isinstance(ds_batch, xr.Dataset) validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)