diff --git a/src/napari_spatialdata/utils/_utils.py b/src/napari_spatialdata/utils/_utils.py index f8b664c..61c3c97 100644 --- a/src/napari_spatialdata/utils/_utils.py +++ b/src/napari_spatialdata/utils/_utils.py @@ -12,7 +12,6 @@ import pandas as pd from anndata import AnnData from dask.dataframe import DataFrame as DaskDataFrame -from datatree import DataTree from geopandas import GeoDataFrame from loguru import logger from matplotlib.colors import is_color_like, to_rgb @@ -31,9 +30,10 @@ from scipy.sparse import issparse, spmatrix from scipy.spatial import KDTree from spatialdata import SpatialData, get_extent, join_spatialelement_table +from spatialdata._utils import skip_non_dimension_nodes from spatialdata.models import SpatialElement, get_axes_names from spatialdata.transformations import get_transformation -from xarray import DataArray +from xarray import DataArray, Dataset, DataTree from napari_spatialdata.constants._pkg_constants import Key from napari_spatialdata.utils._categoricals_utils import ( @@ -222,6 +222,21 @@ def _points_inside_triangles(points: ArrayLike, triangles: ArrayLike) -> ArrayLi return out +@skip_non_dimension_nodes +def transpose(ds: Dataset, *args: Any, **kwargs: Any) -> Dataset: + return ds.transpose(*args, **kwargs) + + +@skip_non_dimension_nodes +def reindex(ds: Dataset, *args: Any, **kwargs: Any) -> Dataset: + + # A copy is required as a dataset view as used in map_over_datasets is not mutable + # TODO: Check whether setting item on wrapping datatree node would be better than copy or this can be dropped. + ds_copy = ds.copy() + ds_copy["image"] = ds_copy["image"].reindex(*args, **kwargs) + return ds_copy + + def _adjust_channels_order(element: DataArray | DataTree) -> tuple[DataArray | list[DataArray], bool]: """Swap the axes to y, x, c and check if an image supports rgb(a) visualization. @@ -256,7 +271,11 @@ def _adjust_channels_order(element: DataArray | DataTree) -> tuple[DataArray | l if len(c_coords) != 0 and set(c_coords) - {"r", "g", "b"} <= {"a"}: rgb = True - new_raster = element.transpose("y", "x", "c").reindex(c=["r", "g", "b", "a"][: len(c_coords)]) + if isinstance(element, DataArray): + new_raster = element.transpose("y", "x", "c").reindex(c=["r", "g", "b", "a"][: len(c_coords)]) + else: + new_raster = element.map_over_datasets(transpose, "y", "x", "c") + new_raster = new_raster.map_over_datasets(reindex, {"c": ["r", "g", "b", "a"][: len(c_coords)]}) else: rgb = False new_raster = element diff --git a/tests/test_spatialdata.py b/tests/test_spatialdata.py index 18d68ce..4ea0b79 100644 --- a/tests/test_spatialdata.py +++ b/tests/test_spatialdata.py @@ -8,7 +8,6 @@ from dask.array.random import randint from dask.dataframe import DataFrame as DaskDataFrame from dask.dataframe import from_dask_array -from datatree import DataTree from multiscale_spatial_image import to_multiscale from napari.layers import Image, Labels, Points from napari.utils.events import EventedList @@ -19,7 +18,7 @@ from spatialdata.models import PointsModel, TableModel from spatialdata.transformations import Identity from spatialdata.transformations.operations import set_transformation -from xarray import DataArray +from xarray import DataArray, DataTree from napari_spatialdata import QtAdataViewWidget from napari_spatialdata._sdata_widgets import CoordinateSystemWidget, ElementWidget, SdataWidget