From 0a5d0cc576fc89c6b3badcc53e4e1f88b1f9dc4a Mon Sep 17 00:00:00 2001 From: Max Jones Date: Fri, 18 Nov 2022 15:37:38 -0500 Subject: [PATCH 01/25] Typing for generators and accessors --- xbatcher/accessors.py | 16 +++++------ xbatcher/generators.py | 65 ++++++++++++++++++++++++------------------ 2 files changed, 45 insertions(+), 36 deletions(-) diff --git a/xbatcher/accessors.py b/xbatcher/accessors.py index a9d19be..af7fed1 100644 --- a/xbatcher/accessors.py +++ b/xbatcher/accessors.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Any, Union import xarray as xr @@ -19,13 +19,13 @@ def _as_xarray_dataarray(xr_obj: Union[xr.Dataset, xr.DataArray]) -> xr.DataArra @xr.register_dataarray_accessor("batch") @xr.register_dataset_accessor("batch") class BatchAccessor: - def __init__(self, xarray_obj): + def __init__(self, xarray_obj: Union[xr.Dataset, xr.DataArray]): """ Batch accessor returning a BatchGenerator object via the `generator method` """ self._obj = xarray_obj - def generator(self, *args, **kwargs): + def generator(self, *args, **kwargs) -> BatchGenerator: """ Return a BatchGenerator via the batch accessor @@ -42,10 +42,10 @@ def generator(self, *args, **kwargs): @xr.register_dataarray_accessor("tf") @xr.register_dataset_accessor("tf") class TFAccessor: - def __init__(self, xarray_obj): + def __init__(self, xarray_obj: Union[xr.Dataset, xr.DataArray]): self._obj = xarray_obj - def to_tensor(self): + def to_tensor(self) -> Any: """Convert this DataArray to a tensorflow.Tensor""" import tensorflow as tf @@ -57,10 +57,10 @@ def to_tensor(self): @xr.register_dataarray_accessor("torch") @xr.register_dataset_accessor("torch") class TorchAccessor: - def __init__(self, xarray_obj): + def __init__(self, xarray_obj: Union[xr.Dataset, xr.DataArray]): self._obj = xarray_obj - def to_tensor(self): + def to_tensor(self) -> Any: """Convert this DataArray to a torch.Tensor""" import torch @@ -68,7 +68,7 @@ def to_tensor(self): return torch.tensor(data=dataarray.data) - def to_named_tensor(self): + def to_named_tensor(self) -> Any: """ Convert this DataArray to a torch.Tensor with named dimensions. 
diff --git a/xbatcher/generators.py b/xbatcher/generators.py index ec3b2b6..17365cb 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -1,13 +1,12 @@ """Classes for iterating through xarray datarrays / datasets in batches.""" import itertools -from collections import OrderedDict -from typing import Any, Dict, Hashable, Iterator +from typing import Any, Dict, Hashable, Iterator, List, OrderedDict, Sequence, Union import xarray as xr -def _slices(dimsize, size, overlap=0): +def _slices(dimsize: int, size: int, overlap: int = 0) -> Any: # return a list of slices to chop up a single dimension if overlap >= size: raise ValueError( @@ -23,7 +22,11 @@ def _slices(dimsize, size, overlap=0): return slices -def _iterate_through_dataset(ds, dims, overlap={}): +def _iterate_through_dataset( + ds: Union[xr.Dataset, xr.DataArray], + dims: OrderedDict[Hashable, int], + overlap: Dict[Hashable, int] = {}, +) -> Any: dim_slices = [] for dim in dims: dimsize = ds.sizes[dim] @@ -43,12 +46,16 @@ def _iterate_through_dataset(ds, dims, overlap={}): yield selector -def _drop_input_dims(ds, input_dims, suffix="_input"): +def _drop_input_dims( + ds: Union[xr.Dataset, xr.DataArray], + input_dims: OrderedDict[Hashable, int], + suffix: str = "_input", +) -> Union[xr.Dataset, xr.DataArray]: # remove input_dims coordinates from datasets, rename the dimensions # then put intput_dims back in as coordinates out = ds.copy() - for dim in input_dims: - newdim = dim + suffix + for dim in input_dims.keys(): + newdim = f"{dim}{suffix}" out = out.rename({dim: newdim}) # extra steps needed if there is a coordinate if newdim in out: @@ -57,13 +64,16 @@ def _drop_input_dims(ds, input_dims, suffix="_input"): return out -def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name="sample"): +def _maybe_stack_batch_dims( + ds: Union[xr.Dataset, xr.DataArray], + input_dims: Sequence[Hashable], +) -> Union[xr.Dataset, xr.DataArray]: batch_dims = [d for d in ds.sizes if d not in input_dims] if len(batch_dims) < 2: return ds - ds_stack = ds.stack(**{stacked_dim_name: batch_dims}) + ds_stack = ds.stack(sample=batch_dims) # ensure correct order - dim_order = (stacked_dim_name,) + tuple(input_dims) + dim_order = ("sample",) + tuple(input_dims) return ds_stack.transpose(*dim_order) @@ -105,7 +115,7 @@ class BatchGenerator: def __init__( self, - ds: xr.Dataset, + ds: Union[xr.Dataset, xr.DataArray], input_dims: Dict[Hashable, int], input_overlap: Dict[Hashable, int] = {}, batch_dims: Dict[Hashable, int] = {}, @@ -122,14 +132,14 @@ def __init__( self.preload_batch = preload_batch self._batches: Dict[int, Any] = self._gen_batches() # dict cache for batches - def __iter__(self) -> Iterator[xr.Dataset]: + def __iter__(self) -> Iterator[Union[xr.DataArray, xr.Dataset]]: for idx in self._batches: yield self[idx] def __len__(self) -> int: return len(self._batches) - def __getitem__(self, idx: int) -> xr.Dataset: + def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]: if not isinstance(idx, int): raise NotImplementedError( @@ -143,14 +153,15 @@ def __getitem__(self, idx: int) -> xr.Dataset: if self.concat_input_dims: new_dim_suffix = "_input" - all_dsets = [ - _drop_input_dims( - self.ds.isel(**ds_input_select), - list(self.input_dims), - suffix=new_dim_suffix, + all_dsets: List = [] + for ds_input_select in self._batches[idx]: + all_dsets.append( + _drop_input_dims( + self.ds.isel(**ds_input_select), + self.input_dims, + suffix=new_dim_suffix, + ) ) - for ds_input_select in self._batches[idx] - ] dsc = 
xr.concat(all_dsets, dim="input_batch") new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims] return _maybe_stack_batch_dims(dsc, new_input_dims) @@ -167,13 +178,11 @@ def _gen_batches(self) -> dict: # going the eager route for now is allowing me to fill out the loader api # but it is likely to perform poorly. batches = [] - for ds_batch_selector in self._iterate_batch_dims(self.ds): + for ds_batch_selector in self._iterate_batch_dims(): ds_batch = self.ds.isel(**ds_batch_selector) if self.preload_batch: ds_batch.load() - - input_generator = self._iterate_input_dims(ds_batch) - + input_generator = self._iterate_input_dims() if self.concat_input_dims: batches.append(list(input_generator)) else: @@ -181,8 +190,8 @@ def _gen_batches(self) -> dict: return dict(zip(range(len(batches)), batches)) - def _iterate_batch_dims(self, ds): - return _iterate_through_dataset(ds, self.batch_dims) + def _iterate_batch_dims(self) -> Any: + return _iterate_through_dataset(self.ds, self.batch_dims) - def _iterate_input_dims(self, ds): - return _iterate_through_dataset(ds, self.input_dims, self.input_overlap) + def _iterate_input_dims(self) -> Any: + return _iterate_through_dataset(self.ds, self.input_dims, self.input_overlap) From 022ebf7c260b67c384c654cf5acde3d328513e23 Mon Sep 17 00:00:00 2001 From: Max Jones Date: Fri, 18 Nov 2022 18:09:44 -0500 Subject: [PATCH 02/25] Add recommendations from code review --- xbatcher/generators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 17365cb..c2c92d9 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -42,7 +42,7 @@ def _iterate_through_dataset( dim_slices.append(_slices(dimsize, size, olap)) for slices in itertools.product(*dim_slices): - selector = {key: slice for key, slice in zip(dims, slices)} + selector = dict(zip(dims, slices)) yield selector @@ -188,7 +188,7 @@ def _gen_batches(self) -> dict: else: batches += list(input_generator) - return dict(zip(range(len(batches)), batches)) + return dict(enumerate(batches)) def _iterate_batch_dims(self) -> Any: return _iterate_through_dataset(self.ds, self.batch_dims) From 42758282a25e5005acb012d3112060790bce6185 Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Thu, 1 Dec 2022 17:59:09 -0500 Subject: [PATCH 03/25] More informative function name --- xbatcher/generators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index c2c92d9..23dc933 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -6,7 +6,7 @@ import xarray as xr -def _slices(dimsize: int, size: int, overlap: int = 0) -> Any: +def _gen_slices(dimsize: int, size: int, overlap: int = 0) -> Any: # return a list of slices to chop up a single dimension if overlap >= size: raise ValueError( @@ -39,7 +39,7 @@ def _iterate_through_dataset( f"is greater than the dimension length of {dimsize} " f"for {dim}" ) - dim_slices.append(_slices(dimsize, size, olap)) + dim_slices.append(_gen_slices(dimsize, size, olap)) for slices in itertools.product(*dim_slices): selector = dict(zip(dims, slices)) From 8787b7b197129d1c3b7c65fc291d6e96b66958f2 Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Thu, 1 Dec 2022 18:06:00 -0500 Subject: [PATCH 04/25] Use keyword only args for internal functions --- xbatcher/generators.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 
deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 23dc933..c988a61 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -6,40 +6,43 @@ import xarray as xr -def _gen_slices(dimsize: int, size: int, overlap: int = 0) -> Any: +def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> Any: # return a list of slices to chop up a single dimension - if overlap >= size: + if overlap >= slice_size: raise ValueError( "input overlap must be less than the input sample length, but " - f"the input sample length is {size} and the overlap is {overlap}" + f"the input sample length is {slice_size} and the overlap is {overlap}" ) slices = [] - stride = size - overlap - for start in range(0, dimsize, stride): - end = start + size - if end <= dimsize: + stride = slice_size - overlap + for start in range(0, dim_size, stride): + end = start + slice_size + if end <= dim_size: slices.append(slice(start, end)) return slices def _iterate_through_dataset( ds: Union[xr.Dataset, xr.DataArray], + *, dims: OrderedDict[Hashable, int], overlap: Dict[Hashable, int] = {}, ) -> Any: dim_slices = [] for dim in dims: - dimsize = ds.sizes[dim] - size = dims[dim] + dim_size = ds.sizes[dim] + slice_size = dims[dim] olap = overlap.get(dim, 0) - if size > dimsize: + if slice_size > dim_size: raise ValueError( "input sample length must be less than or equal to the " - f"dimension length, but the sample length of {size} " - f"is greater than the dimension length of {dimsize} " + f"dimension length, but the sample length of {slice_size} " + f"is greater than the dimension length of {dim_size} " f"for {dim}" ) - dim_slices.append(_gen_slices(dimsize, size, olap)) + dim_slices.append( + _gen_slices(dim_size=dim_size, slice_size=slice_size, overlap=olap) + ) for slices in itertools.product(*dim_slices): selector = dict(zip(dims, slices)) @@ -191,7 +194,9 @@ def _gen_batches(self) -> dict: return dict(enumerate(batches)) def _iterate_batch_dims(self) -> Any: - return _iterate_through_dataset(self.ds, self.batch_dims) + return _iterate_through_dataset(self.ds, dims=self.batch_dims) def _iterate_input_dims(self) -> Any: - return _iterate_through_dataset(self.ds, self.input_dims, self.input_overlap) + return _iterate_through_dataset( + self.ds, dims=self.input_dims, overlap=self.input_overlap + ) From 7ea67b68d62a7c06ca6ce28cc0f7a1259843f840 Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Thu, 1 Dec 2022 18:07:42 -0500 Subject: [PATCH 05/25] Use dict over OrderedDict --- xbatcher/generators.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index c988a61..a5632bd 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -1,7 +1,7 @@ """Classes for iterating through xarray datarrays / datasets in batches.""" import itertools -from typing import Any, Dict, Hashable, Iterator, List, OrderedDict, Sequence, Union +from typing import Any, Dict, Hashable, Iterator, List, Sequence, Union import xarray as xr @@ -25,7 +25,7 @@ def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> Any: def _iterate_through_dataset( ds: Union[xr.Dataset, xr.DataArray], *, - dims: OrderedDict[Hashable, int], + dims: Dict[Hashable, int], overlap: Dict[Hashable, int] = {}, ) -> Any: dim_slices = [] @@ -51,7 +51,7 @@ def _iterate_through_dataset( def _drop_input_dims( ds: Union[xr.Dataset, xr.DataArray], - input_dims: OrderedDict[Hashable, int], + input_dims: 
Dict[Hashable, int], suffix: str = "_input", ) -> Union[xr.Dataset, xr.DataArray]: # remove input_dims coordinates from datasets, rename the dimensions @@ -127,10 +127,9 @@ def __init__( ): self.ds = ds - # should be a dict - self.input_dims = OrderedDict(input_dims) + self.input_dims = dict(input_dims) self.input_overlap = input_overlap - self.batch_dims = OrderedDict(batch_dims) + self.batch_dims = dict(batch_dims) self.concat_input_dims = concat_input_dims self.preload_batch = preload_batch self._batches: Dict[int, Any] = self._gen_batches() # dict cache for batches From d15d6b135919978f6fb1bf01af3acc00b5246bdf Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Thu, 1 Dec 2022 18:09:42 -0500 Subject: [PATCH 06/25] Add type hint for _gen_slices output --- xbatcher/generators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index a5632bd..8d6320d 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -6,7 +6,7 @@ import xarray as xr -def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> Any: +def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> List[slice]: # return a list of slices to chop up a single dimension if overlap >= slice_size: raise ValueError( From ec399bce024b51c649a4a45a3f9455cee9702b71 Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Thu, 1 Dec 2022 18:11:05 -0500 Subject: [PATCH 07/25] More informative function name --- xbatcher/generators.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 8d6320d..0601301 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -22,7 +22,7 @@ def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> List[sli return slices -def _iterate_through_dataset( +def _iterate_over_dimensions( ds: Union[xr.Dataset, xr.DataArray], *, dims: Dict[Hashable, int], @@ -193,9 +193,9 @@ def _gen_batches(self) -> dict: return dict(enumerate(batches)) def _iterate_batch_dims(self) -> Any: - return _iterate_through_dataset(self.ds, dims=self.batch_dims) + return _iterate_over_dimensions(self.ds, dims=self.batch_dims) def _iterate_input_dims(self) -> Any: - return _iterate_through_dataset( + return _iterate_over_dimensions( self.ds, dims=self.input_dims, overlap=self.input_overlap ) From 1770d361f2162657e5d23aee4e6aa3c3c4c4f2f4 Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Thu, 1 Dec 2022 18:11:45 -0500 Subject: [PATCH 08/25] More informative variable name --- xbatcher/generators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 0601301..f4ddb29 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -32,7 +32,7 @@ def _iterate_over_dimensions( for dim in dims: dim_size = ds.sizes[dim] slice_size = dims[dim] - olap = overlap.get(dim, 0) + slice_overlap = overlap.get(dim, 0) if slice_size > dim_size: raise ValueError( "input sample length must be less than or equal to the " @@ -41,7 +41,7 @@ def _iterate_over_dimensions( f"for {dim}" ) dim_slices.append( - _gen_slices(dim_size=dim_size, slice_size=slice_size, overlap=olap) + _gen_slices(dim_size=dim_size, slice_size=slice_size, overlap=slice_overlap) ) for slices in itertools.product(*dim_slices): From bcdcd727e655e2ff70c47130a396b8ef07cc79bc Mon Sep 17 00:00:00 2001 From: Max 
Jones <14077947+maxrjones@users.noreply.github.com> Date: Thu, 1 Dec 2022 18:13:47 -0500 Subject: [PATCH 09/25] More informative name for batch selectors --- xbatcher/generators.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index f4ddb29..1f62d0a 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -132,14 +132,16 @@ def __init__( self.batch_dims = dict(batch_dims) self.concat_input_dims = concat_input_dims self.preload_batch = preload_batch - self._batches: Dict[int, Any] = self._gen_batches() # dict cache for batches + self._batch_selectors: Dict[ + int, Any + ] = self._gen_batch_selectors() # dict cache for batches def __iter__(self) -> Iterator[Union[xr.DataArray, xr.Dataset]]: - for idx in self._batches: + for idx in self._batch_selectors: yield self[idx] def __len__(self) -> int: - return len(self._batches) + return len(self._batch_selectors) def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]: @@ -149,14 +151,14 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]: ) if idx < 0: - idx = list(self._batches)[idx] + idx = list(self._batch_selectors)[idx] - if idx in self._batches: + if idx in self._batch_selectors: if self.concat_input_dims: new_dim_suffix = "_input" all_dsets: List = [] - for ds_input_select in self._batches[idx]: + for ds_input_select in self._batch_selectors[idx]: all_dsets.append( _drop_input_dims( self.ds.isel(**ds_input_select), @@ -170,12 +172,12 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]: else: return _maybe_stack_batch_dims( - self.ds.isel(**self._batches[idx]), list(self.input_dims) + self.ds.isel(**self._batch_selectors[idx]), list(self.input_dims) ) else: raise IndexError("list index out of range") - def _gen_batches(self) -> dict: + def _gen_batch_selectors(self) -> dict: # in the future, we will want to do the batch generation lazily # going the eager route for now is allowing me to fill out the loader api # but it is likely to perform poorly. 
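
(Illustrative usage sketch, not part of the patch series: the dataset and dimension names below are hypothetical. It shows the public behaviour that the refactors in patches 01-09 preserve — slices taken over ``input_dims``, with the remaining dimensions stacked into a ``sample`` dimension.)

    import numpy as np
    import xarray as xr
    from xbatcher import BatchGenerator

    # hypothetical 3-D dataset: 10 time steps on a 20 x 20 grid
    ds = xr.Dataset(
        {"temperature": (("time", "lat", "lon"), np.random.rand(10, 20, 20))},
        coords={"time": np.arange(10), "lat": np.arange(20), "lon": np.arange(20)},
    )

    # slice ``lon`` into samples of length 10; the other dimensions
    # (``time`` and ``lat``) are stacked into a single ``sample`` dimension
    gen = BatchGenerator(ds, input_dims={"lon": 10})
    # equivalently, via the accessor: gen = ds.batch.generator(input_dims={"lon": 10})

    assert len(gen) == 2  # lon of length 20 yields two non-overlapping slices of 10
    for batch in gen:
        assert batch["temperature"].dims == ("sample", "lon")
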
From cdd92e2df5ba756576192bd24b88f5d3b7edba7c Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Thu, 1 Dec 2022 18:15:46 -0500 Subject: [PATCH 10/25] Type hint for _iterate_over_dimensions output --- xbatcher/generators.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 1f62d0a..1fd4a39 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -27,7 +27,7 @@ def _iterate_over_dimensions( *, dims: Dict[Hashable, int], overlap: Dict[Hashable, int] = {}, -) -> Any: +) -> Iterator[Dict[Hashable, slice]]: dim_slices = [] for dim in dims: dim_size = ds.sizes[dim] @@ -43,7 +43,6 @@ def _iterate_over_dimensions( dim_slices.append( _gen_slices(dim_size=dim_size, slice_size=slice_size, overlap=slice_overlap) ) - for slices in itertools.product(*dim_slices): selector = dict(zip(dims, slices)) yield selector From cc9f04c2642d841096c098390853b301e3ac9d41 Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Fri, 2 Dec 2022 13:01:28 -0500 Subject: [PATCH 11/25] Fix batch_dims bug --- xbatcher/generators.py | 70 +++++++++++++++++++++++++----------------- 1 file changed, 41 insertions(+), 29 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 1fd4a39..6562565 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -1,10 +1,18 @@ """Classes for iterating through xarray datarrays / datasets in batches.""" import itertools -from typing import Any, Dict, Hashable, Iterator, List, Sequence, Union +from typing import Any, Dict, Hashable, Iterator, List, Sequence, Tuple, Union import xarray as xr +DimSelector = Dict[Hashable, slice] +ConcatBatchSelector = Tuple[DimSelector, List[DimSelector]] +BatchSelector = Union[ + List[ConcatBatchSelector], + Iterator[DimSelector], +] +BatchSelectors = Union[Dict[int, ConcatBatchSelector], Dict[int, DimSelector]] + def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> List[slice]: # return a list of slices to chop up a single dimension @@ -157,10 +165,11 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]: if self.concat_input_dims: new_dim_suffix = "_input" all_dsets: List = [] - for ds_input_select in self._batch_selectors[idx]: + batch_dims_selector, input_dims_selectors = self._batch_selectors[idx] + for selector in input_dims_selectors: all_dsets.append( _drop_input_dims( - self.ds.isel(**ds_input_select), + self.ds.isel(dict(**batch_dims_selector, **selector)), self.input_dims, suffix=new_dim_suffix, ) @@ -169,34 +178,37 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]: new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims] return _maybe_stack_batch_dims(dsc, new_input_dims) else: - return _maybe_stack_batch_dims( - self.ds.isel(**self._batch_selectors[idx]), list(self.input_dims) + self.ds.isel(self._batch_selectors[idx]), + list(self.input_dims), ) else: raise IndexError("list index out of range") - def _gen_batch_selectors(self) -> dict: - # in the future, we will want to do the batch generation lazily - # going the eager route for now is allowing me to fill out the loader api - # but it is likely to perform poorly. 
- batches = [] - for ds_batch_selector in self._iterate_batch_dims(): - ds_batch = self.ds.isel(**ds_batch_selector) - if self.preload_batch: - ds_batch.load() - input_generator = self._iterate_input_dims() - if self.concat_input_dims: - batches.append(list(input_generator)) - else: - batches += list(input_generator) - - return dict(enumerate(batches)) - - def _iterate_batch_dims(self) -> Any: - return _iterate_over_dimensions(self.ds, dims=self.batch_dims) - - def _iterate_input_dims(self) -> Any: - return _iterate_over_dimensions( - self.ds, dims=self.input_dims, overlap=self.input_overlap - ) + def _gen_batch_selectors( + self, + ) -> BatchSelectors: + """ + Create batch selectors dict, which can be used to create a batch + from an xarray data object. + """ + if self.concat_input_dims: + batch_dim_selectors = _iterate_over_dimensions( + self.ds, dims=self.batch_dims + ) + # TODO: Consider iterator protocol rather than copying to list + input_dim_selectors = list( + _iterate_over_dimensions( + self.ds, dims=self.input_dims, overlap=self.input_overlap + ) + ) + batch_selectors: BatchSelector = [ + (selector, input_dim_selectors) for selector in batch_dim_selectors + ] + else: + batch_selectors = _iterate_over_dimensions( + self.ds, + dims=dict(**self.batch_dims, **self.input_dims), + overlap=self.input_overlap, + ) + return dict(enumerate(batch_selectors)) From 463546e7739e68b10f1ae456fb910a1628de1e5c Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Fri, 2 Dec 2022 16:41:44 -0500 Subject: [PATCH 12/25] Support preload_batch parameter --- xbatcher/generators.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 6562565..b1ed5f7 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -166,10 +166,13 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]: new_dim_suffix = "_input" all_dsets: List = [] batch_dims_selector, input_dims_selectors = self._batch_selectors[idx] + batch_ds = self.ds.isel(batch_dims_selector) + if self.preload_batch: + batch_ds.load() for selector in input_dims_selectors: all_dsets.append( _drop_input_dims( - self.ds.isel(dict(**batch_dims_selector, **selector)), + batch_ds.isel(dict(**selector)), self.input_dims, suffix=new_dim_suffix, ) @@ -178,8 +181,11 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]: new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims] return _maybe_stack_batch_dims(dsc, new_input_dims) else: + batch_ds = self.ds.isel(self._batch_selectors[idx]) + if self.preload_batch: + batch_ds.load() return _maybe_stack_batch_dims( - self.ds.isel(self._batch_selectors[idx]), + batch_ds, list(self.input_dims), ) else: From 191a644ff863e5029a8e10263c3aa969355fb021 Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Tue, 13 Dec 2022 21:49:19 -0500 Subject: [PATCH 13/25] Try out dataclass for batch selectors --- xbatcher/__init__.py | 2 +- xbatcher/generators.py | 226 +++++++++++++++++++++++++++++++---------- 2 files changed, 172 insertions(+), 56 deletions(-) diff --git a/xbatcher/__init__.py b/xbatcher/__init__.py index 7282157..6fb8d75 100644 --- a/xbatcher/__init__.py +++ b/xbatcher/__init__.py @@ -3,7 +3,7 @@ from . 
import testing # noqa: F401 from .accessors import BatchAccessor # noqa: F401 -from .generators import BatchGenerator # noqa: F401 +from .generators import BatchGenerator, BatchSchema # noqa: F401 from .util.print_versions import show_versions # noqa: F401 try: diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 1a16647..0bcd4e4 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -1,17 +1,142 @@ """Classes for iterating through xarray datarrays / datasets in batches.""" import itertools -from typing import Any, Dict, Hashable, Iterator, List, Sequence, Tuple, Union +from dataclasses import dataclass +from operator import itemgetter +from typing import Dict, Hashable, Iterator, List, Sequence, Union import xarray as xr -DimSelector = Dict[Hashable, slice] -ConcatBatchSelector = Tuple[DimSelector, List[DimSelector]] -BatchSelector = Union[ - List[ConcatBatchSelector], - Iterator[DimSelector], -] -BatchSelectors = Union[Dict[int, ConcatBatchSelector], Dict[int, DimSelector]] +BatchSelector = List[Dict[Hashable, slice]] +BatchSelectorSet = Dict[int, BatchSelector] + + +@dataclass +class BatchSchema: + """ + A representation of the indices and stacking/transposing parameters needed + to generator batches from Xarray Datasets and DataArrays using + xbatcher.BatchGenerator. + + Parameters + ---------- + ds : ``xarray.Dataset`` or ``xarray.DataArray`` + The data to iterate over. Unlike for the BatchGenerator, the data is + not retained as a class attribute for the BatchSchema. + input_dims : dict + A dictionary specifying the size of the inputs in each dimension, + e.g. ``{'lat': 30, 'lon': 30}`` + These are the dimensions the ML library will see. All other dimensions + will be stacked into one dimension called ``sample``. + input_overlap : dict, optional + A dictionary specifying the overlap along each dimension + e.g. ``{'lat': 3, 'lon': 3}`` + batch_dims : dict, optional + A dictionary specifying the size of the batch along each dimension + e.g. ``{'time': 10}``. These will always be iterated over. + concat_input_dims : bool, optional + If ``True``, the dimension chunks specified in ``input_dims`` will be + concatenated and stacked into the ``sample`` dimension. The batch index + will be included as a new level ``input_batch`` in the ``sample`` + coordinate. + If ``False``, the dimension chunks specified in ``input_dims`` will be + iterated over. + preload_batch : bool, optional + If ``True``, each batch will be loaded into memory before reshaping / + processing, triggering any dask arrays to be computed. + + Notes + ----- + The BatchSchema is experimental and subject to change without notice. + """ + + def __init__( + self, + ds: Union[xr.Dataset, xr.DataArray], + input_dims: Dict[Hashable, int], + input_overlap: Dict[Hashable, int] = {}, + batch_dims: Dict[Hashable, int] = {}, + concat_input_bins: bool = True, + preload_batch: bool = True, + ): + self.input_dims = dict(input_dims) + self.input_overlap = input_overlap + self.batch_dims = dict(batch_dims) + self.concat_input_dims = concat_input_bins + self.preload_batch = preload_batch + self.selectors: BatchSelectorSet = self._gen_batch_selectors(ds) + + def _gen_batch_selectors(self, ds) -> BatchSelectorSet: + """ + Create batch selectors dict, which can be used to create a batch + from an xarray data object. 
+ """ + # Separate batch_dims that are/are not also included in input_dims + self._duplicate_batch_dims = { + dim: length + for dim, length in self.batch_dims.items() + if self.input_dims.get(dim) is not None + } + self._unique_batch_dims = { + dim: length + for dim, length in self.batch_dims.items() + if self.input_dims.get(dim) is None + } + # Create an iterator that returns an object usable for .isel in xarray + patch_selectors = self._gen_patch_selectors(ds) + # Create the Dict containing batch selectors + if self.concat_input_dims: # Combine the patches into batches + batch_selectors = self._combine_patches_into_batch(ds, patch_selectors) + return dict(enumerate(batch_selectors)) + else: # Each patch gets its own batch + return {ind: [value] for ind, value in enumerate(patch_selectors)} + + def _gen_patch_selectors(self, ds) -> Iterator[Dict[Hashable, slice]]: + """ + Create an iterator that can be used to index an Xarray Dataset/DataArray. + """ + if self._duplicate_batch_dims and not self.concat_input_dims: + raise UserWarning( + f""" + The following dimensions were included in both ``input_dims`` + and ``batch_dims``. Since ``concat_input_dims`` is ``False``, + these dimensions will not impact batch generation: {self._duplicate_batch_dims} + """ + ) + # Generate the slices by iterating over batch_dims and input_dims + all_slices = _iterate_through_dimensions( + ds, + dims=dict(**self._unique_batch_dims, **self.input_dims), + overlap=self.input_overlap, + ) + return all_slices + + def _combine_patches_into_batch( + self, ds, patch_selectors + ) -> List[List[Dict[Hashable, slice]]]: + """ + Combine the patch selectors to form a batch + """ + # Check that patches are only combined with concat_input_dims + if not self.concat_input_dims: + raise AssertionError( + "Patches should only be combined into batches when ``concat_input_dims`` is ``True``" + ) + # If ``batch_dims`` isn't used, all patches will be included in a single batch + if not self.batch_dims: + batch_selectors = [list(patch_selectors)] + elif self._duplicate_batch_dims: + raise NotImplementedError("Not Implemented") + # Group patches based on the unique slices for dimensions in ``batch_dims`` + else: + batch_selectors = [ + list(value) + for _, value in itertools.groupby( + patch_selectors, key=itemgetter(*self.batch_dims) + ) + ] + # Group patches based on the unique dimensions in ``batch_dims`` + return batch_selectors def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> List[slice]: @@ -134,21 +259,41 @@ def __init__( ): self.ds = ds - self.input_dims = dict(input_dims) - self.input_overlap = input_overlap - self.batch_dims = dict(batch_dims) - self.concat_input_dims = concat_input_dims - self.preload_batch = preload_batch - self._batch_selectors: Dict[ - int, Any - ] = self._gen_batch_selectors() # dict cache for batches + self._batch_selectors: BatchSchema = BatchSchema( + ds, + input_dims=input_dims, + input_overlap=input_overlap, + batch_dims=batch_dims, + concat_input_bins=concat_input_dims, + preload_batch=preload_batch, + ) + + @property + def input_dims(self): + return self._batch_selectors.input_dims + + @property + def input_overlap(self): + return self._batch_selectors.input_overlap + + @property + def batch_dims(self): + return self._batch_selectors.batch_dims + + @property + def concat_input_dims(self): + return self._batch_selectors.concat_input_dims + + @property + def preload_batch(self): + return self._batch_selectors.preload_batch def __iter__(self) -> Iterator[Union[xr.DataArray, 
xr.Dataset]]: - for idx in self._batch_selectors: + for idx in self._batch_selectors.selectors: yield self[idx] def __len__(self) -> int: - return len(self._batch_selectors) + return len(self._batch_selectors.selectors) def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]: @@ -158,21 +303,20 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]: ) if idx < 0: - idx = list(self._batch_selectors)[idx] + idx = list(self._batch_selectors.selectors)[idx] - if idx in self._batch_selectors: + if idx in self._batch_selectors.selectors: if self.concat_input_dims: new_dim_suffix = "_input" all_dsets: List = [] - batch_dims_selector, input_dims_selectors = self._batch_selectors[idx] - batch_ds = self.ds.isel(batch_dims_selector) - if self.preload_batch: - batch_ds.load() - for selector in input_dims_selectors: + for selector in self._batch_selectors.selectors[idx]: + batch_ds = self.ds.isel(selector) + if self.preload_batch: + batch_ds.load() all_dsets.append( _drop_input_dims( - batch_ds.isel(dict(**selector)), + batch_ds, self.input_dims, suffix=new_dim_suffix, ) @@ -181,7 +325,7 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]: new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims] return _maybe_stack_batch_dims(dsc, new_input_dims) else: - batch_ds = self.ds.isel(self._batch_selectors[idx]) + batch_ds = self.ds.isel(self._batch_selectors.selectors[idx][0]) if self.preload_batch: batch_ds.load() return _maybe_stack_batch_dims( @@ -190,31 +334,3 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]: ) else: raise IndexError("list index out of range") - - def _gen_batch_selectors( - self, - ) -> BatchSelectors: - """ - Create batch selectors dict, which can be used to create a batch - from an xarray data object. 
- """ - if self.concat_input_dims: - batch_dim_selectors = _iterate_through_dimensions( - self.ds, dims=self.batch_dims - ) - # TODO: Consider iterator protocol rather than copying to list - input_dim_selectors = list( - _iterate_through_dimensions( - self.ds, dims=self.input_dims, overlap=self.input_overlap - ) - ) - batch_selectors: BatchSelector = [ - (selector, input_dim_selectors) for selector in batch_dim_selectors - ] - else: - batch_selectors = _iterate_through_dimensions( - self.ds, - dims=dict(**self.batch_dims, **self.input_dims), - overlap=self.input_overlap, - ) - return dict(enumerate(batch_selectors)) From 6fdd8f07e3ff47ff6c717b59be415b03b1f4df0b Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Wed, 14 Dec 2022 21:53:21 -0500 Subject: [PATCH 14/25] Update comment --- xbatcher/generators.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 0bcd4e4..806a146 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -135,7 +135,6 @@ def _combine_patches_into_batch( patch_selectors, key=itemgetter(*self.batch_dims) ) ] - # Group patches based on the unique dimensions in ``batch_dims`` return batch_selectors From a0a5614bbda5ef5826e7bd5ea459ed71c6d7282a Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Sun, 18 Dec 2022 10:22:36 -0500 Subject: [PATCH 15/25] Remove xfail markers --- xbatcher/tests/test_generators.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index bba84dd..0c252e3 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -115,9 +115,6 @@ def test_batch_1d_concat(sample_ds_1d, input_size): assert "x" in ds_batch.coords -@pytest.mark.xfail( - reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/131" -) def test_batch_1d_concat_duplicate_dim(sample_ds_1d): """ Test batch generation for a 1D dataset using ``concat_input_dims`` when @@ -219,9 +216,6 @@ def test_batch_3d_1d_input(sample_ds_3d, input_size): validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) -@pytest.mark.xfail( - reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/131" -) @pytest.mark.parametrize("concat", [True, False]) def test_batch_3d_1d_input_batch_dims(sample_ds_3d, concat): """ @@ -239,9 +233,6 @@ def test_batch_3d_1d_input_batch_dims(sample_ds_3d, concat): validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) -@pytest.mark.xfail( - reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/131" -) def test_batch_3d_1d_input_batch_concat_duplicate_dim(sample_ds_3d): """ Test batch generation for a 3D dataset using ``concat_input_dims`` when From 32f736f10a8ae608d971527cb76752a70f9d3eae Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Sun, 18 Dec 2022 15:08:39 -0500 Subject: [PATCH 16/25] Account for duplicate batch and input dims in testing utils --- xbatcher/testing.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/xbatcher/testing.py b/xbatcher/testing.py index 4cb224d..66546ef 100644 --- a/xbatcher/testing.py +++ b/xbatcher/testing.py @@ -101,9 +101,10 @@ def _get_sample_length( """ if generator.concat_input_dims: batch_concat_dims = [ - generator.ds.sizes.get(k) - // np.nanmax([v, generator.batch_dims.get(k, np.nan)]) - for k, v in generator.input_dims.items() + 
generator.batch_dims.get(dim) // length + if generator.batch_dims.get(dim) + else generator.ds.sizes.get(dim) // length + for dim, length in generator.input_dims.items() ] else: batch_concat_dims = [] From d7ea56426ef7db0ad78d957d755b8a71a628fad0 Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Sun, 18 Dec 2022 15:09:46 -0500 Subject: [PATCH 17/25] Support duplicate dims in batch and input --- xbatcher/generators.py | 123 ++++++++++++++++++++++++++++++++++------- 1 file changed, 103 insertions(+), 20 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 806a146..b33e1bd 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -3,10 +3,12 @@ import itertools from dataclasses import dataclass from operator import itemgetter -from typing import Dict, Hashable, Iterator, List, Sequence, Union +from typing import Any, Dict, Hashable, Iterator, List, Sequence, Union +import numpy as np import xarray as xr +PatchGenerator = Iterator[Dict[Hashable, slice]] BatchSelector = List[Dict[Hashable, slice]] BatchSelectorSet = Dict[int, BatchSelector] @@ -66,32 +68,42 @@ def __init__( self.preload_batch = preload_batch self.selectors: BatchSelectorSet = self._gen_batch_selectors(ds) - def _gen_batch_selectors(self, ds) -> BatchSelectorSet: + def _gen_batch_selectors( + self, ds: Union[xr.DataArray, xr.Dataset] + ) -> BatchSelectorSet: """ Create batch selectors dict, which can be used to create a batch from an xarray data object. """ # Separate batch_dims that are/are not also included in input_dims - self._duplicate_batch_dims = { + self._duplicate_batch_dims: Dict[Hashable, int] = { dim: length for dim, length in self.batch_dims.items() if self.input_dims.get(dim) is not None } - self._unique_batch_dims = { + self._unique_batch_dims: Dict[Hashable, int] = { dim: length for dim, length in self.batch_dims.items() if self.input_dims.get(dim) is None } + self._input_stride: Dict[Hashable, int] = { + dim: length - self.input_overlap.get(dim, 0) + for dim, length in self.input_dims.items() + } + self._all_sliced_dims: Dict[Hashable, int] = dict( + **self._unique_batch_dims, **self.input_dims + ) # Create an iterator that returns an object usable for .isel in xarray patch_selectors = self._gen_patch_selectors(ds) # Create the Dict containing batch selectors if self.concat_input_dims: # Combine the patches into batches - batch_selectors = self._combine_patches_into_batch(ds, patch_selectors) - return dict(enumerate(batch_selectors)) + return self._combine_patches_into_batch(ds, patch_selectors) else: # Each patch gets its own batch return {ind: [value] for ind, value in enumerate(patch_selectors)} - def _gen_patch_selectors(self, ds) -> Iterator[Dict[Hashable, slice]]: + def _gen_patch_selectors( + self, ds: Union[xr.DataArray, xr.Dataset] + ) -> PatchGenerator: """ Create an iterator that can be used to index an Xarray Dataset/DataArray. 
""" @@ -106,14 +118,14 @@ def _gen_patch_selectors(self, ds) -> Iterator[Dict[Hashable, slice]]: # Generate the slices by iterating over batch_dims and input_dims all_slices = _iterate_through_dimensions( ds, - dims=dict(**self._unique_batch_dims, **self.input_dims), + dims=self._all_sliced_dims, overlap=self.input_overlap, ) return all_slices def _combine_patches_into_batch( - self, ds, patch_selectors - ) -> List[List[Dict[Hashable, slice]]]: + self, ds: Union[xr.DataArray, xr.Dataset], patch_selectors: PatchGenerator + ) -> BatchSelectorSet: """ Combine the patch selectors to form a batch """ @@ -122,19 +134,90 @@ def _combine_patches_into_batch( raise AssertionError( "Patches should only be combined into batches when ``concat_input_dims`` is ``True``" ) - # If ``batch_dims`` isn't used, all patches will be included in a single batch if not self.batch_dims: - batch_selectors = [list(patch_selectors)] + return self._combine_patches_into_one_batch(patch_selectors) elif self._duplicate_batch_dims: - raise NotImplementedError("Not Implemented") - # Group patches based on the unique slices for dimensions in ``batch_dims`` + return self._combine_patches_grouped_by_input_and_batch_dims( + ds=ds, patch_selectors=patch_selectors + ) else: - batch_selectors = [ - list(value) - for _, value in itertools.groupby( - patch_selectors, key=itemgetter(*self.batch_dims) - ) - ] + return self._combine_patches_grouped_by_batch_dims(patch_selectors) + + def _combine_patches_into_one_batch( + self, patch_selectors: PatchGenerator + ) -> BatchSelectorSet: + """ + Group all patches into a single batch + """ + return dict(enumerate([list(patch_selectors)])) + + def _combine_patches_grouped_by_batch_dims( + self, patch_selectors: PatchGenerator + ) -> BatchSelectorSet: + """ + Group patches based on the unique slices for dimensions in ``batch_dims`` + """ + batch_selectors = [ + list(value) + for _, value in itertools.groupby( + patch_selectors, key=itemgetter(*self.batch_dims) + ) + ] + return dict(enumerate(batch_selectors)) + + def _combine_patches_grouped_by_input_and_batch_dims( + self, ds: Union[xr.DataArray, xr.Dataset], patch_selectors: PatchGenerator + ) -> BatchSelectorSet: + """ + Combine patches with multiple slices along ``batch_dims`` grouped into + each patch. Required when a dimension is duplicated between ``batch_dims`` + and ``input_dims``. 
+ """ + if self._unique_batch_dims: + raise NotImplementedError("Not implemented") + n_patches_per_batch: Dict[Hashable, int] = { + dim: int(np.ceil(length / self._input_stride[dim])) + for dim, length in self.batch_dims.items() + } + n_patches_per_dim: Dict[Hashable, int] = { + dim: int((ds.sizes[dim] - self.input_overlap.get(dim, 0)) // length) + for dim, length in self._input_stride.items() + } + n_batches_per_dim: Dict[Hashable, int] = { + dim: int(ds.sizes[dim] // self.batch_dims.get(dim, ds.sizes[dim])) + for dim in self._all_sliced_dims.keys() + } + batch_id_per_dim: Dict[Hashable, Any] = { + dim: np.floor( + np.arange(0, n_patches) / n_patches_per_batch.get(dim, n_patches + 1) + ).astype(np.int64) + for dim, n_patches in n_patches_per_dim.items() + } + batch_id_per_patch = np.array( + list(itertools.product(*batch_id_per_dim.values())) + ).transpose() + batch_id_maximum = np.fromiter(n_batches_per_dim.values(), dtype=int) + batch_id_maximum = np.pad( + batch_id_maximum, + (0, (len(n_patches_per_dim) - len(n_batches_per_dim))), + constant_values=(1), + ) + batch_id_maximum = batch_id_maximum[:, np.newaxis] + batch_in_range_for_each_patch = np.all( + batch_id_per_patch < batch_id_maximum, axis=0 + ) + batch_id_per_patch = np.ravel_multi_index( + multi_index=batch_id_per_patch, + dims=tuple(n_batches_per_dim.values()), + mode="clip", + ) # type: ignore + n_batches = np.product(list(n_batches_per_dim.values())) + batch_selectors: Dict[int, List[Dict[Hashable, slice]]] = { + k: [] for k in range(n_batches) + } + for i, patch in enumerate(patch_selectors): + if batch_in_range_for_each_patch[i]: + batch_selectors[batch_id_per_patch[i]].append(patch) return batch_selectors From 1eca3959bc4bdc3dd834ce3d8c67e24bc74633d5 Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Mon, 19 Dec 2022 13:36:49 -0500 Subject: [PATCH 18/25] Split _combine_patches_grouped_by_input_and_batch_dims() --- xbatcher/generators.py | 89 ++++++++++++++++++++++++++++++------------ 1 file changed, 65 insertions(+), 24 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index b33e1bd..e59e009 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -174,51 +174,92 @@ def _combine_patches_grouped_by_input_and_batch_dims( and ``input_dims``. """ if self._unique_batch_dims: - raise NotImplementedError("Not implemented") - n_patches_per_batch: Dict[Hashable, int] = { + raise NotImplementedError( + "Currently either all or no batch_dims must be duplicated as input_dims." + ) + self._gen_patch_numbers(ds) + self._gen_batch_numbers(ds) + batch_id_per_patch = self._get_batch_multi_index_per_patch() + patch_in_range = self._get_batch_in_range_per_batch( + batch_multi_index=batch_id_per_patch + ) + batch_id_per_patch = self._ravel_batch_multi_index(batch_id_per_patch) + batch_selectors = self._gen_empty_batch_selectors() + for i, patch in enumerate(patch_selectors): + if patch_in_range[i]: + batch_selectors[batch_id_per_patch[i]].append(patch) + return batch_selectors + + def _gen_empty_batch_selectors(self) -> BatchSelectorSet: + """ + Create an empty batch selector set that can be populated by appending + patches to each batch. + """ + n_batches = np.product(list(self._n_batches_per_dim.values())) + return {k: [] for k in range(n_batches)} + + def _gen_patch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]): + """ + Calculate the number of patches per dimension and the number of patches + in each batch per dimension. 
+ """ + self._n_patches_per_batch: Dict[Hashable, int] = { dim: int(np.ceil(length / self._input_stride[dim])) for dim, length in self.batch_dims.items() } - n_patches_per_dim: Dict[Hashable, int] = { + self._n_patches_per_dim: Dict[Hashable, int] = { dim: int((ds.sizes[dim] - self.input_overlap.get(dim, 0)) // length) for dim, length in self._input_stride.items() } - n_batches_per_dim: Dict[Hashable, int] = { + + def _gen_batch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]): + """ + Calculate the number of batches per dimension + """ + self._n_batches_per_dim: Dict[Hashable, int] = { dim: int(ds.sizes[dim] // self.batch_dims.get(dim, ds.sizes[dim])) for dim in self._all_sliced_dims.keys() } + + def _get_batch_multi_index_per_patch(self): + """ + Calculate the batch multi-index for each patch + """ batch_id_per_dim: Dict[Hashable, Any] = { dim: np.floor( - np.arange(0, n_patches) / n_patches_per_batch.get(dim, n_patches + 1) + np.arange(0, n_patches) + / self._n_patches_per_batch.get(dim, n_patches + 1) ).astype(np.int64) - for dim, n_patches in n_patches_per_dim.items() + for dim, n_patches in self._n_patches_per_dim.items() } batch_id_per_patch = np.array( list(itertools.product(*batch_id_per_dim.values())) ).transpose() - batch_id_maximum = np.fromiter(n_batches_per_dim.values(), dtype=int) + return batch_id_per_patch + + def _ravel_batch_multi_index(self, batch_multi_index): + """ + Convert the batch multi-index to a flat index for each patch + """ + return np.ravel_multi_index( + multi_index=batch_multi_index, + dims=tuple(self._n_batches_per_dim.values()), + mode="clip", + ) + + def _get_batch_in_range_per_batch(self, batch_multi_index): + """ + Determine whether each patch is contained within any of the batches. + """ + batch_id_maximum = np.fromiter(self._n_batches_per_dim.values(), dtype=int) batch_id_maximum = np.pad( batch_id_maximum, - (0, (len(n_patches_per_dim) - len(n_batches_per_dim))), + (0, (len(self._n_patches_per_dim) - len(self._n_batches_per_dim))), constant_values=(1), ) batch_id_maximum = batch_id_maximum[:, np.newaxis] - batch_in_range_for_each_patch = np.all( - batch_id_per_patch < batch_id_maximum, axis=0 - ) - batch_id_per_patch = np.ravel_multi_index( - multi_index=batch_id_per_patch, - dims=tuple(n_batches_per_dim.values()), - mode="clip", - ) # type: ignore - n_batches = np.product(list(n_batches_per_dim.values())) - batch_selectors: Dict[int, List[Dict[Hashable, slice]]] = { - k: [] for k in range(n_batches) - } - for i, patch in enumerate(patch_selectors): - if batch_in_range_for_each_patch[i]: - batch_selectors[batch_id_per_patch[i]].append(patch) - return batch_selectors + batch_in_range_per_patch = np.all(batch_multi_index < batch_id_maximum, axis=0) + return batch_in_range_per_patch def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> List[slice]: From 8a43ed6a08cacd9c93ae1a523eb669231ea5c64c Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Mon, 19 Dec 2022 13:41:44 -0500 Subject: [PATCH 19/25] Mark test with xfail --- xbatcher/tests/test_generators.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index 0c252e3..3a9f98f 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -216,7 +216,18 @@ def test_batch_3d_1d_input(sample_ds_3d, input_size): validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch) -@pytest.mark.parametrize("concat", 
[True, False]) +@pytest.mark.parametrize( + "concat", + [ + True, + pytest.param( + False, + marks=pytest.mark.xfail( + reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/126" + ), + ), + ], +) def test_batch_3d_1d_input_batch_dims(sample_ds_3d, concat): """ Test batch generation for a 3D dataset using ``input_dims`` and batch_dims``. From 98d5d19a9a835da68701d19c972f129464590f4a Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Mon, 19 Dec 2022 13:50:46 -0500 Subject: [PATCH 20/25] Remove dataclass decorator --- xbatcher/generators.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index e59e009..2b6f2e0 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -1,7 +1,6 @@ """Classes for iterating through xarray datarrays / datasets in batches.""" import itertools -from dataclasses import dataclass from operator import itemgetter from typing import Any, Dict, Hashable, Iterator, List, Sequence, Union @@ -13,7 +12,6 @@ BatchSelectorSet = Dict[int, BatchSelector] -@dataclass class BatchSchema: """ A representation of the indices and stacking/transposing parameters needed @@ -66,16 +64,7 @@ def __init__( self.batch_dims = dict(batch_dims) self.concat_input_dims = concat_input_bins self.preload_batch = preload_batch - self.selectors: BatchSelectorSet = self._gen_batch_selectors(ds) - - def _gen_batch_selectors( - self, ds: Union[xr.DataArray, xr.Dataset] - ) -> BatchSelectorSet: - """ - Create batch selectors dict, which can be used to create a batch - from an xarray data object. - """ - # Separate batch_dims that are/are not also included in input_dims + # Store helpful information based on arguments self._duplicate_batch_dims: Dict[Hashable, int] = { dim: length for dim, length in self.batch_dims.items() @@ -93,6 +82,15 @@ def _gen_batch_selectors( self._all_sliced_dims: Dict[Hashable, int] = dict( **self._unique_batch_dims, **self.input_dims ) + self.selectors: BatchSelectorSet = self._gen_batch_selectors(ds) + + def _gen_batch_selectors( + self, ds: Union[xr.DataArray, xr.Dataset] + ) -> BatchSelectorSet: + """ + Create batch selectors dict, which can be used to create a batch + from an xarray data object. + """ # Create an iterator that returns an object usable for .isel in xarray patch_selectors = self._gen_patch_selectors(ds) # Create the Dict containing batch selectors From b781280e7f1ce1ab52e048e8a2b2f7caf1a6c847 Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Mon, 19 Dec 2022 16:31:15 -0500 Subject: [PATCH 21/25] Fix warning --- xbatcher/generators.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 2b6f2e0..df0d1bf 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -1,6 +1,7 @@ """Classes for iterating through xarray datarrays / datasets in batches.""" import itertools +import warnings from operator import itemgetter from typing import Any, Dict, Hashable, Iterator, List, Sequence, Union @@ -106,12 +107,10 @@ def _gen_patch_selectors( Create an iterator that can be used to index an Xarray Dataset/DataArray. """ if self._duplicate_batch_dims and not self.concat_input_dims: - raise UserWarning( - f""" - The following dimensions were included in both ``input_dims`` - and ``batch_dims``. 
Since ``concat_input_dims`` is ``False``, - these dimensions will not impact batch generation: {self._duplicate_batch_dims} - """ + warnings.warn( + "The following dimensions were included in both ``input_dims`` " + "and ``batch_dims``. Since ``concat_input_dims`` is ``False``, " + f"these dimensions will not impact batch generation: {self._duplicate_batch_dims}" ) # Generate the slices by iterating over batch_dims and input_dims all_slices = _iterate_through_dimensions( From 29f871bcd1d5bf0fdb447683b97c179896d60781 Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Mon, 19 Dec 2022 21:19:24 -0500 Subject: [PATCH 22/25] Compute dask arrays before selecting on patches --- xbatcher/generators.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index df0d1bf..b30c088 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -430,13 +430,21 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]: if self.concat_input_dims: new_dim_suffix = "_input" all_dsets: List = [] + batch_selector = {} + for dim in self._batch_selectors.batch_dims.keys(): + starts = [ + x[dim].start for x in self._batch_selectors.selectors[idx] + ] + stops = [x[dim].stop for x in self._batch_selectors.selectors[idx]] + batch_selector[dim] = slice(min(starts), max(stops)) + batch_ds = self.ds.isel(batch_selector) + if self.preload_batch: + batch_ds.load() for selector in self._batch_selectors.selectors[idx]: - batch_ds = self.ds.isel(selector) - if self.preload_batch: - batch_ds.load() + patch_ds = self.ds.isel(selector) all_dsets.append( _drop_input_dims( - batch_ds, + patch_ds, self.input_dims, suffix=new_dim_suffix, ) From 0b6c495b2391b25a456ddbc8a54470285f2d2669 Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Tue, 20 Dec 2022 10:37:13 -0500 Subject: [PATCH 23/25] Remove NotImplementedError --- xbatcher/generators.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index b30c088..20b51ea 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -170,10 +170,6 @@ def _combine_patches_grouped_by_input_and_batch_dims( each patch. Required when a dimension is duplicated between ``batch_dims`` and ``input_dims``. """ - if self._unique_batch_dims: - raise NotImplementedError( - "Currently either all or no batch_dims must be duplicated as input_dims." - ) self._gen_patch_numbers(ds) self._gen_batch_numbers(ds) batch_id_per_patch = self._get_batch_multi_index_per_patch() From 9d0fcb054f08d75c9c6cfc8086a1d6c8e7beaaaa Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Tue, 20 Dec 2022 10:53:30 -0500 Subject: [PATCH 24/25] Fix case with more batch_dims than input_dims --- xbatcher/generators.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 20b51ea..ad83b55 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -197,12 +197,15 @@ def _gen_patch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]): in each batch per dimension. 
""" self._n_patches_per_batch: Dict[Hashable, int] = { - dim: int(np.ceil(length / self._input_stride[dim])) + dim: int(np.ceil(length / self._input_stride.get(dim, length))) for dim, length in self.batch_dims.items() } self._n_patches_per_dim: Dict[Hashable, int] = { - dim: int((ds.sizes[dim] - self.input_overlap.get(dim, 0)) // length) - for dim, length in self._input_stride.items() + dim: int( + (ds.sizes[dim] - self.input_overlap.get(dim, 0)) + // (length - self.input_overlap.get(dim, 0)) + ) + for dim, length in self._all_sliced_dims.items() } def _gen_batch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]): From 288f0dc13196af384c62deae8ddad3494f5f0427 Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Tue, 3 Jan 2023 13:04:32 -0500 Subject: [PATCH 25/25] Update xbatcher/generators.py Co-authored-by: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com> --- xbatcher/generators.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index ad83b55..3da074c 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -55,11 +55,15 @@ def __init__( self, ds: Union[xr.Dataset, xr.DataArray], input_dims: Dict[Hashable, int], - input_overlap: Dict[Hashable, int] = {}, - batch_dims: Dict[Hashable, int] = {}, + input_overlap: Dict[Hashable, int] = None, + batch_dims: Dict[Hashable, int] = None, concat_input_bins: bool = True, preload_batch: bool = True, ): + if input_overlap is None: + input_overlap = {} + if batch_dims is None: + batch_dims = {} self.input_dims = dict(input_dims) self.input_overlap = input_overlap self.batch_dims = dict(batch_dims)