From f06324199734c4a60a214dc516bb9ca53f216145 Mon Sep 17 00:00:00 2001 From: Raphael Hagen Date: Tue, 11 Oct 2022 22:32:34 -0700 Subject: [PATCH 1/7] moved _as_xarray_dataarray outside of TorchAccessor. Added tests for util --- xbatcher/accessors.py | 33 ++++++++++++++++---------------- xbatcher/tests/test_accessors.py | 24 +++++++++++++++++++++++ 2 files changed, 41 insertions(+), 16 deletions(-) diff --git a/xbatcher/accessors.py b/xbatcher/accessors.py index 43a7d9c..9a85c58 100644 --- a/xbatcher/accessors.py +++ b/xbatcher/accessors.py @@ -1,8 +1,23 @@ +from typing import Union + import xarray as xr from .generators import BatchGenerator +def _as_xarray_dataarray( + xr_obj: Union[xr.Dataset, xr.DataArray] +) -> Union[xr.Dataset, xr.DataArray]: + """ + Convert xarray.Dataset to xarray.DataArray if needed, so that it can + be converted into a torch.Tensor object. + """ + if isinstance(xr_obj, xr.Dataset): + xr_obj = xr_obj.to_array().squeeze(dim="variable") + + return xr_obj + + @xr.register_dataarray_accessor("batch") @xr.register_dataset_accessor("batch") class BatchAccessor: @@ -32,25 +47,11 @@ class TorchAccessor: def __init__(self, xarray_obj): self._obj = xarray_obj - def _as_xarray_dataarray(self, xr_obj): - """ - Convert xarray.Dataset to xarray.DataArray if needed, so that it can - be converted into a torch.Tensor object. 
- """ - try: - # Convert xr.Dataset to xr.DataArray - dataarray = xr_obj.to_array().squeeze(dim="variable") - except AttributeError: # 'DataArray' object has no attribute 'to_array' - # If object is already an xr.DataArray - dataarray = xr_obj - - return dataarray - def to_tensor(self): """Convert this DataArray to a torch.Tensor""" import torch - dataarray = self._as_xarray_dataarray(xr_obj=self._obj) + dataarray = _as_xarray_dataarray(xr_obj=self._obj) return torch.tensor(data=dataarray.data) @@ -62,6 +63,6 @@ def to_named_tensor(self): """ import torch - dataarray = self._as_xarray_dataarray(xr_obj=self._obj) + dataarray = _as_xarray_dataarray(xr_obj=self._obj) return torch.tensor(data=dataarray.data, names=tuple(dataarray.sizes)) diff --git a/xbatcher/tests/test_accessors.py b/xbatcher/tests/test_accessors.py index 18d24e0..6a719da 100644 --- a/xbatcher/tests/test_accessors.py +++ b/xbatcher/tests/test_accessors.py @@ -22,6 +22,30 @@ def sample_ds_3d(): return ds +@pytest.fixture(scope="module") +def sample_dataArray(): + return xr.DataArray(np.zeros((2, 4), dtype="i4"), dims=("x", "y"), name="foo") + + +@pytest.fixture(scope="module") +def sample_Dataset(): + return xr.Dataset( + { + "x": xr.DataArray(np.arange(10), dims="x"), + "foo": xr.DataArray(np.ones(10, dtype="float"), dims="x"), + } + ) + + +def test_as_xarray_dataarray(sample_dataArray, sample_Dataset): + assert isinstance( + xbatcher.accessors._as_xarray_dataarray(sample_dataArray), xr.DataArray + ) + assert isinstance( + xbatcher.accessors._as_xarray_dataarray(sample_Dataset), xr.DataArray + ) + + def test_batch_accessor_ds(sample_ds_3d): bg_class = BatchGenerator(sample_ds_3d, input_dims={"x": 5}) bg_acc = sample_ds_3d.batch.generator(input_dims={"x": 5}) From 261b91152517a810a55777f74fc6ace9426b04c6 Mon Sep 17 00:00:00 2001 From: Raphael Hagen Date: Wed, 12 Oct 2022 09:02:59 -0700 Subject: [PATCH 2/7] attempt at keras accessors, tests still needed --- xbatcher/accessors.py | 27 
+++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/xbatcher/accessors.py b/xbatcher/accessors.py index 9a85c58..c1ca021 100644 --- a/xbatcher/accessors.py +++ b/xbatcher/accessors.py @@ -41,6 +41,33 @@ def generator(self, *args, **kwargs): return BatchGenerator(self._obj, *args, **kwargs) +@xr.register_dataarray_accessor("keras") +@xr.register_dataset_accessor("keras") +class KerasAccessor: + def __init__(self, xarray_obj): + self._obj = xarray_obj + + def to_tensor(self): + """Convert this DataArray to a torch.Tensor""" + import tensorflow as tf + + dataarray = _as_xarray_dataarray(xr_obj=self._obj) + + return tf.convert_to_tensor(dataarray.data) + + def to_named_tensor(self): + """ + Convert this DataArray to a torch.Tensor with named dimensions. + + See https://pytorch.org/docs/stable/named_tensor.html + """ + import tensorflow as tf + + dataarray = _as_xarray_dataarray(xr_obj=self._obj) + + return tf.convert_to_tensor(dataarray.data, name=tuple(dataarray.sizes)) + + @xr.register_dataarray_accessor("torch") @xr.register_dataset_accessor("torch") class TorchAccessor: From d146227959b71591ea5d5248e1b0502bee612a09 Mon Sep 17 00:00:00 2001 From: Raphael Hagen Date: Wed, 12 Oct 2022 15:30:52 -0700 Subject: [PATCH 3/7] test_keras_to_tensor unit test added --- xbatcher/accessors.py | 20 +++++++++++--------- xbatcher/tests/test_accessors.py | 19 +++++++++++++++++++ 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/xbatcher/accessors.py b/xbatcher/accessors.py index c1ca021..cfd1113 100644 --- a/xbatcher/accessors.py +++ b/xbatcher/accessors.py @@ -48,24 +48,26 @@ def __init__(self, xarray_obj): self._obj = xarray_obj def to_tensor(self): - """Convert this DataArray to a torch.Tensor""" + """Convert this DataArray to a tensorflow.Tensor""" import tensorflow as tf dataarray = _as_xarray_dataarray(xr_obj=self._obj) return tf.convert_to_tensor(dataarray.data) - def to_named_tensor(self): - """ - Convert this DataArray to a torch.Tensor 
with named dimensions. + def to_named_tensor( + self, + ): # There does not seem to be a named tensor for tensorflow? + pass - See https://pytorch.org/docs/stable/named_tensor.html - """ - import tensorflow as tf + # """ + # Convert this DataArray to a .Tensor with named dimensions. + # """ + # import tensorflow as tf - dataarray = _as_xarray_dataarray(xr_obj=self._obj) + # dataarray = _as_xarray_dataarray(xr_obj=self._obj) - return tf.convert_to_tensor(dataarray.data, name=tuple(dataarray.sizes)) + # return tf.convert_to_tensor(dataarray.data, name=tuple(dataarray.sizes)) @xr.register_dataarray_accessor("torch") diff --git a/xbatcher/tests/test_accessors.py b/xbatcher/tests/test_accessors.py index 6a719da..774411e 100644 --- a/xbatcher/tests/test_accessors.py +++ b/xbatcher/tests/test_accessors.py @@ -64,6 +64,25 @@ def test_batch_accessor_da(sample_ds_3d): assert batch_class.equals(batch_acc) +@pytest.mark.parametrize( + "foo_var", + [ + "foo", # xr.DataArray + ["foo"], # xr.Dataset + ], +) +def test_keras_to_tensor(sample_ds_3d, foo_var): + tensorflow = pytest.importorskip("tensorflow") + + foo = sample_ds_3d[foo_var] + t = foo.keras.to_tensor() + assert isinstance(t, tensorflow.Tensor) + assert t.shape == tuple(foo.sizes.values()) + + foo_array = foo.to_array().squeeze() if hasattr(foo, "to_array") else foo + np.testing.assert_array_equal(t, foo_array.values) + + @pytest.mark.parametrize( "foo_var", [ From 2ab5cf3a00c07590f8ffde8a15eb3691d65e7b5e Mon Sep 17 00:00:00 2001 From: Raphael Hagen Date: Wed, 12 Oct 2022 15:32:02 -0700 Subject: [PATCH 4/7] removed notes --- xbatcher/accessors.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/xbatcher/accessors.py b/xbatcher/accessors.py index cfd1113..6c67d8e 100644 --- a/xbatcher/accessors.py +++ b/xbatcher/accessors.py @@ -55,20 +55,6 @@ def to_tensor(self): return tf.convert_to_tensor(dataarray.data) - def to_named_tensor( - self, - ): # There does not seem to be a named tensor for tensorflow? 
- pass - - # """ - # Convert this DataArray to a .Tensor with named dimensions. - # """ - # import tensorflow as tf - - # dataarray = _as_xarray_dataarray(xr_obj=self._obj) - - # return tf.convert_to_tensor(dataarray.data, name=tuple(dataarray.sizes)) - @xr.register_dataarray_accessor("torch") @xr.register_dataset_accessor("torch") From 55bc5348861dbddd698e9aa3b0b8890e7ab31739 Mon Sep 17 00:00:00 2001 From: Raphael Hagen Date: Thu, 13 Oct 2022 10:14:17 -0700 Subject: [PATCH 5/7] updated accessor and class names from Keras to tensorflow --- xbatcher/accessors.py | 10 ++++------ xbatcher/tests/test_accessors.py | 4 ++-- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/xbatcher/accessors.py b/xbatcher/accessors.py index 6c67d8e..dc4044d 100644 --- a/xbatcher/accessors.py +++ b/xbatcher/accessors.py @@ -5,9 +5,7 @@ from .generators import BatchGenerator -def _as_xarray_dataarray( - xr_obj: Union[xr.Dataset, xr.DataArray] -) -> Union[xr.Dataset, xr.DataArray]: +def _as_xarray_dataarray(xr_obj: Union[xr.Dataset, xr.DataArray]) -> xr.DataArray: """ Convert xarray.Dataset to xarray.DataArray if needed, so that it can be converted into a torch.Tensor object. 
@@ -41,9 +39,9 @@ def generator(self, *args, **kwargs): return BatchGenerator(self._obj, *args, **kwargs) -@xr.register_dataarray_accessor("keras") -@xr.register_dataset_accessor("keras") -class KerasAccessor: +@xr.register_dataarray_accessor("tf") +@xr.register_dataset_accessor("tf") +class TFAccessor: def __init__(self, xarray_obj): self._obj = xarray_obj diff --git a/xbatcher/tests/test_accessors.py b/xbatcher/tests/test_accessors.py index 774411e..225b659 100644 --- a/xbatcher/tests/test_accessors.py +++ b/xbatcher/tests/test_accessors.py @@ -71,11 +71,11 @@ def test_batch_accessor_da(sample_ds_3d): ["foo"], # xr.Dataset ], ) -def test_keras_to_tensor(sample_ds_3d, foo_var): +def test_tf_to_tensor(sample_ds_3d, foo_var): tensorflow = pytest.importorskip("tensorflow") foo = sample_ds_3d[foo_var] - t = foo.keras.to_tensor() + t = foo.tf.to_tensor() assert isinstance(t, tensorflow.Tensor) assert t.shape == tuple(foo.sizes.values()) From f8b8bcb38f802cc2b224f6bffdb7f5c6f290ed0a Mon Sep 17 00:00:00 2001 From: Raphael Hagen Date: Thu, 13 Oct 2022 10:28:23 -0700 Subject: [PATCH 6/7] Update xbatcher/tests/test_accessors.py Co-authored-by: Wei Ji <23487320+weiji14@users.noreply.github.com> --- xbatcher/tests/test_accessors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xbatcher/tests/test_accessors.py b/xbatcher/tests/test_accessors.py index 225b659..cb3b37e 100644 --- a/xbatcher/tests/test_accessors.py +++ b/xbatcher/tests/test_accessors.py @@ -72,11 +72,11 @@ def test_batch_accessor_da(sample_ds_3d): ], ) def test_tf_to_tensor(sample_ds_3d, foo_var): - tensorflow = pytest.importorskip("tensorflow") + tf = pytest.importorskip("tensorflow") foo = sample_ds_3d[foo_var] t = foo.tf.to_tensor() - assert isinstance(t, tensorflow.Tensor) + assert isinstance(t, tf.Tensor) assert t.shape == tuple(foo.sizes.values()) foo_array = foo.to_array().squeeze() if hasattr(foo, "to_array") else foo From e4f6ef1da1626603ed55f83604dd8a001eafae6a Mon Sep 17 
00:00:00 2001 From: Raphael Hagen Date: Thu, 13 Oct 2022 10:28:39 -0700 Subject: [PATCH 7/7] Update xbatcher/accessors.py Co-authored-by: Wei Ji <23487320+weiji14@users.noreply.github.com> --- xbatcher/accessors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xbatcher/accessors.py b/xbatcher/accessors.py index dc4044d..a9d19be 100644 --- a/xbatcher/accessors.py +++ b/xbatcher/accessors.py @@ -8,7 +8,7 @@ def _as_xarray_dataarray(xr_obj: Union[xr.Dataset, xr.DataArray]) -> xr.DataArray: """ Convert xarray.Dataset to xarray.DataArray if needed, so that it can - be converted into a torch.Tensor object. + be converted into a Tensor object. """ if isinstance(xr_obj, xr.Dataset): xr_obj = xr_obj.to_array().squeeze(dim="variable")