Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let TensorFlow accessor and data loader handle either xarray.DataArray or xarray.Dataset inputs #107

Merged
merged 7 commits into from
Oct 13, 2022
46 changes: 30 additions & 16 deletions xbatcher/accessors.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
from typing import Union

import xarray as xr

from .generators import BatchGenerator


def _as_xarray_dataarray(xr_obj: Union[xr.Dataset, xr.DataArray]) -> xr.DataArray:
    """
    Return ``xr_obj`` as an xarray.DataArray so it can be turned into a Tensor.

    An xarray.Dataset is stacked via ``to_array()`` and the resulting
    "variable" dimension is squeezed out; any other input is handed back
    unchanged.
    """
    if not isinstance(xr_obj, xr.Dataset):
        # Nothing to convert: pass the object straight through.
        return xr_obj

    stacked = xr_obj.to_array()
    return stacked.squeeze(dim="variable")


@xr.register_dataarray_accessor("batch")
@xr.register_dataset_accessor("batch")
class BatchAccessor:
Expand All @@ -26,31 +39,32 @@ def generator(self, *args, **kwargs):
return BatchGenerator(self._obj, *args, **kwargs)


@xr.register_dataarray_accessor("tf")
@xr.register_dataset_accessor("tf")
class TFAccessor:
    """Accessor exposed as ``.tf`` on xarray.DataArray and xarray.Dataset."""

    def __init__(self, xarray_obj):
        # Keep a handle on the xarray object this accessor is attached to.
        self._obj = xarray_obj

    def to_tensor(self):
        """Convert the wrapped DataArray or Dataset to a tensorflow.Tensor"""
        # Imported inside the method, so tensorflow is only required when
        # this conversion is actually called.
        import tensorflow as tf

        values = _as_xarray_dataarray(xr_obj=self._obj).data
        return tf.convert_to_tensor(values)


@xr.register_dataarray_accessor("torch")
@xr.register_dataset_accessor("torch")
class TorchAccessor:
def __init__(self, xarray_obj):
self._obj = xarray_obj

def _as_xarray_dataarray(self, xr_obj):
"""
Convert xarray.Dataset to xarray.DataArray if needed, so that it can
be converted into a torch.Tensor object.
"""
try:
# Convert xr.Dataset to xr.DataArray
dataarray = xr_obj.to_array().squeeze(dim="variable")
except AttributeError: # 'DataArray' object has no attribute 'to_array'
# If object is already an xr.DataArray
dataarray = xr_obj

return dataarray

def to_tensor(self):
"""Convert this DataArray to a torch.Tensor"""
import torch

dataarray = self._as_xarray_dataarray(xr_obj=self._obj)
dataarray = _as_xarray_dataarray(xr_obj=self._obj)

return torch.tensor(data=dataarray.data)

Expand All @@ -62,6 +76,6 @@ def to_named_tensor(self):
"""
import torch

dataarray = self._as_xarray_dataarray(xr_obj=self._obj)
dataarray = _as_xarray_dataarray(xr_obj=self._obj)

return torch.tensor(data=dataarray.data, names=tuple(dataarray.sizes))
43 changes: 43 additions & 0 deletions xbatcher/tests/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,30 @@ def sample_ds_3d():
return ds


@pytest.fixture(scope="module")
def sample_dataArray():
    """Module-scoped 2x4 DataArray of "i4" zeros named "foo" with dims x, y."""
    zeros = np.zeros((2, 4), dtype="i4")
    return xr.DataArray(zeros, dims=("x", "y"), name="foo")


@pytest.fixture(scope="module")
def sample_Dataset():
    """Module-scoped Dataset built from "x" and "foo", both along dim "x"."""
    x_values = xr.DataArray(np.arange(10), dims="x")
    foo_values = xr.DataArray(np.ones(10, dtype="float"), dims="x")
    return xr.Dataset({"x": x_values, "foo": foo_values})


def test_as_xarray_dataarray(sample_dataArray, sample_Dataset):
    """Both DataArray and Dataset inputs must come back as an xr.DataArray."""
    for xr_obj in (sample_dataArray, sample_Dataset):
        converted = xbatcher.accessors._as_xarray_dataarray(xr_obj)
        assert isinstance(converted, xr.DataArray)


def test_batch_accessor_ds(sample_ds_3d):
bg_class = BatchGenerator(sample_ds_3d, input_dims={"x": 5})
bg_acc = sample_ds_3d.batch.generator(input_dims={"x": 5})
Expand All @@ -40,6 +64,25 @@ def test_batch_accessor_da(sample_ds_3d):
assert batch_class.equals(batch_acc)


@pytest.mark.parametrize(
    "foo_var",
    [
        "foo",  # xr.DataArray
        ["foo"],  # xr.Dataset
    ],
)
def test_tf_to_tensor(sample_ds_3d, foo_var):
    """Check the type, shape and values produced by ``.tf.to_tensor()``."""
    tf = pytest.importorskip("tensorflow")

    selected = sample_ds_3d[foo_var]
    tensor = selected.tf.to_tensor()
    assert isinstance(tensor, tf.Tensor)
    assert tensor.shape == tuple(selected.sizes.values())

    # Only objects exposing ``to_array`` need flattening before comparison.
    if hasattr(selected, "to_array"):
        expected = selected.to_array().squeeze()
    else:
        expected = selected
    np.testing.assert_array_equal(tensor, expected.values)


@pytest.mark.parametrize(
"foo_var",
[
Expand Down