Skip to content

Commit

Permalink
add utility functions for datatree
Browse files Browse the repository at this point in the history
  • Loading branch information
melonora committed Nov 4, 2024
1 parent dcc6b68 commit 14b8aa8
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
25 changes: 22 additions & 3 deletions src/napari_spatialdata/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import pandas as pd
from anndata import AnnData
from dask.dataframe import DataFrame as DaskDataFrame
from datatree import DataTree
from geopandas import GeoDataFrame
from loguru import logger
from matplotlib.colors import is_color_like, to_rgb
Expand All @@ -31,9 +30,10 @@
from scipy.sparse import issparse, spmatrix
from scipy.spatial import KDTree
from spatialdata import SpatialData, get_extent, join_spatialelement_table
from spatialdata._utils import skip_non_dimension_nodes
from spatialdata.models import SpatialElement, get_axes_names
from spatialdata.transformations import get_transformation
from xarray import DataArray
from xarray import DataArray, Dataset, DataTree

from napari_spatialdata.constants._pkg_constants import Key
from napari_spatialdata.utils._categoricals_utils import (
Expand Down Expand Up @@ -222,6 +222,21 @@ def _points_inside_triangles(points: ArrayLike, triangles: ArrayLike) -> ArrayLi
return out


@skip_non_dimension_nodes
def transpose(ds: Dataset, *args: Any, **kwargs: Any) -> Dataset:
return ds.transpose(*args, **kwargs)


@skip_non_dimension_nodes
def reindex(ds: Dataset, *args: Any, **kwargs: Any) -> Dataset:

# A copy is required as a dataset view as used in map_over_datasets is not mutable
# TODO: Check whether setting item on wrapping datatree node would be better than copy or this can be dropped.
ds_copy = ds.copy()
ds_copy["image"] = ds_copy["image"].reindex(*args, **kwargs)
return ds_copy


def _adjust_channels_order(element: DataArray | DataTree) -> tuple[DataArray | list[DataArray], bool]:
"""Swap the axes to y, x, c and check if an image supports rgb(a) visualization.
Expand Down Expand Up @@ -256,7 +271,11 @@ def _adjust_channels_order(element: DataArray | DataTree) -> tuple[DataArray | l

if len(c_coords) != 0 and set(c_coords) - {"r", "g", "b"} <= {"a"}:
rgb = True
new_raster = element.transpose("y", "x", "c").reindex(c=["r", "g", "b", "a"][: len(c_coords)])
if isinstance(element, DataArray):
new_raster = element.transpose("y", "x", "c").reindex(c=["r", "g", "b", "a"][: len(c_coords)])
else:
new_raster = element.map_over_datasets(transpose, "y", "x", "c")
new_raster = new_raster.map_over_datasets(reindex, {"c": ["r", "g", "b", "a"][: len(c_coords)]})
else:
rgb = False
new_raster = element
Expand Down
3 changes: 1 addition & 2 deletions tests/test_spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from dask.array.random import randint
from dask.dataframe import DataFrame as DaskDataFrame
from dask.dataframe import from_dask_array
from datatree import DataTree
from multiscale_spatial_image import to_multiscale
from napari.layers import Image, Labels, Points
from napari.utils.events import EventedList
Expand All @@ -19,7 +18,7 @@
from spatialdata.models import PointsModel, TableModel
from spatialdata.transformations import Identity
from spatialdata.transformations.operations import set_transformation
from xarray import DataArray
from xarray import DataArray, DataTree

from napari_spatialdata import QtAdataViewWidget
from napari_spatialdata._sdata_widgets import CoordinateSystemWidget, ElementWidget, SdataWidget
Expand Down

0 comments on commit 14b8aa8

Please sign in to comment.