Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let TensorFlow accessor and data loader handle either xarray.DataArray or xarray.Dataset inputs #107

Merged
merged 7 commits into from
Oct 13, 2022
48 changes: 32 additions & 16 deletions xbatcher/accessors.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
from typing import Union

import xarray as xr

from .generators import BatchGenerator


def _as_xarray_dataarray(xr_obj: Union[xr.Dataset, xr.DataArray]) -> xr.DataArray:
    """
    Convert xarray.Dataset to xarray.DataArray if needed, so that it can
    be converted into a tensor object.

    Parameters
    ----------
    xr_obj : xarray.Dataset or xarray.DataArray
        Input object. A Dataset is stacked along a temporary ``variable``
        dimension via :py:meth:`xarray.Dataset.to_array`, which is then
        squeezed out. NOTE(review): ``squeeze(dim="variable")`` raises if
        the Dataset holds more than one effective data variable — confirm
        callers only pass single-variable Datasets.
    Returns
    -------
    xarray.DataArray
        ``xr_obj`` itself when it already is a DataArray, otherwise the
        single-variable DataArray view of the Dataset.
    """
    if isinstance(xr_obj, xr.Dataset):
        # Stack data variables into a "variable" dim, then drop that
        # length-1 dim so the result matches a plain DataArray's shape.
        xr_obj = xr_obj.to_array().squeeze(dim="variable")

    return xr_obj


@xr.register_dataarray_accessor("batch")
@xr.register_dataset_accessor("batch")
class BatchAccessor:
Expand All @@ -26,31 +41,32 @@ def generator(self, *args, **kwargs):
return BatchGenerator(self._obj, *args, **kwargs)


@xr.register_dataarray_accessor("keras")
@xr.register_dataset_accessor("keras")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure that we want to use keras instead of tensorflow as the accessor name (considering that the returned tensors are tf.Tensor objects)?

Copy link
Member

@weiji14 weiji14 Oct 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh wait, I guess we're using xbatcher.loaders.keras.CustomTFDataset in the dataloader at https://github.com/xarray-contrib/xbatcher/blob/ed45a99da54503de2e94cc90f12510f590ea9be6/doc/api.rst#dataloaders, so might as well stick with keras to be consistent (unless anyone is keen to change everything to tensorflow).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with your point. What do you and @norlandrhagen think of tf for the accessor name and TFAccessor as the class name to keep it shorter?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the feedback @weiji14 @maxrjones

Good point on the naming. I can update the accessor name and class name. Should we update the data loader naming (xbatcher.loaders.keras.CustomTFDataset) in this PR?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's rename the tensorflow dataloader in a separate PR (so that it shows up in the changelog as a backward incompatible change).

class KerasAccessor:
    """``.keras`` accessor: convert xarray objects to TensorFlow tensors."""

    def __init__(self, xarray_obj):
        # The wrapped xr.Dataset or xr.DataArray.
        self._obj = xarray_obj

    def to_tensor(self):
        """Convert this DataArray to a tensorflow.Tensor"""
        # Imported lazily so tensorflow stays an optional dependency.
        import tensorflow as tf

        return tf.convert_to_tensor(_as_xarray_dataarray(xr_obj=self._obj).data)


@xr.register_dataarray_accessor("torch")
@xr.register_dataset_accessor("torch")
class TorchAccessor:
def __init__(self, xarray_obj):
self._obj = xarray_obj

def _as_xarray_dataarray(self, xr_obj):
"""
Convert xarray.Dataset to xarray.DataArray if needed, so that it can
be converted into a torch.Tensor object.
"""
try:
# Convert xr.Dataset to xr.DataArray
dataarray = xr_obj.to_array().squeeze(dim="variable")
except AttributeError: # 'DataArray' object has no attribute 'to_array'
# If object is already an xr.DataArray
dataarray = xr_obj

return dataarray

def to_tensor(self):
"""Convert this DataArray to a torch.Tensor"""
import torch

dataarray = self._as_xarray_dataarray(xr_obj=self._obj)
dataarray = _as_xarray_dataarray(xr_obj=self._obj)

return torch.tensor(data=dataarray.data)

Expand All @@ -62,6 +78,6 @@ def to_named_tensor(self):
"""
import torch

dataarray = self._as_xarray_dataarray(xr_obj=self._obj)
dataarray = _as_xarray_dataarray(xr_obj=self._obj)

return torch.tensor(data=dataarray.data, names=tuple(dataarray.sizes))
43 changes: 43 additions & 0 deletions xbatcher/tests/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,30 @@ def sample_ds_3d():
return ds


@pytest.fixture(scope="module")
def sample_dataArray():
    """A small 2-D int32 DataArray named 'foo' for the accessor tests."""
    data = np.zeros((2, 4), dtype="i4")
    return xr.DataArray(data, dims=("x", "y"), name="foo")


@pytest.fixture(scope="module")
def sample_Dataset():
    """A 1-D Dataset with an 'x' coordinate-like variable and a 'foo' variable."""
    x_var = xr.DataArray(np.arange(10), dims="x")
    foo_var = xr.DataArray(np.ones(10, dtype="float"), dims="x")
    return xr.Dataset({"x": x_var, "foo": foo_var})


def test_as_xarray_dataarray(sample_dataArray, sample_Dataset):
    """Both DataArray and Dataset inputs must come back as a DataArray."""
    for xr_obj in (sample_dataArray, sample_Dataset):
        converted = xbatcher.accessors._as_xarray_dataarray(xr_obj)
        assert isinstance(converted, xr.DataArray)


def test_batch_accessor_ds(sample_ds_3d):
bg_class = BatchGenerator(sample_ds_3d, input_dims={"x": 5})
bg_acc = sample_ds_3d.batch.generator(input_dims={"x": 5})
Expand All @@ -40,6 +64,25 @@ def test_batch_accessor_da(sample_ds_3d):
assert batch_class.equals(batch_acc)


@pytest.mark.parametrize(
    "foo_var",
    [
        "foo",  # xr.DataArray
        ["foo"],  # xr.Dataset
    ],
)
def test_keras_to_tensor(sample_ds_3d, foo_var):
    """The .keras accessor should yield a tf.Tensor matching the source data."""
    tensorflow = pytest.importorskip("tensorflow")

    foo = sample_ds_3d[foo_var]
    tensor = foo.keras.to_tensor()

    assert isinstance(tensor, tensorflow.Tensor)
    assert tensor.shape == tuple(foo.sizes.values())

    # A Dataset selection must first be collapsed to a DataArray for comparison.
    expected = foo.to_array().squeeze() if hasattr(foo, "to_array") else foo
    np.testing.assert_array_equal(tensor, expected.values)


@pytest.mark.parametrize(
"foo_var",
[
Expand Down