From 3dc4e5b354288de248a59bb6801659ad4a5d4061 Mon Sep 17 00:00:00 2001 From: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> Date: Wed, 25 Sep 2024 11:39:49 +0300 Subject: [PATCH 1/7] Fix bug adata in model not being reset (#317) * fix bug adata in model not being reset * changelog --- CHANGELOG.md | 4 +++ src/napari_spatialdata/_view.py | 37 +++++++++++++++++--------- src/napari_spatialdata/_viewer.py | 5 ++-- src/napari_spatialdata/_widgets.py | 2 +- src/napari_spatialdata/utils/_utils.py | 4 ++- 5 files changed, 34 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8969d730..afbc7199 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning][]. ## [0.x.x] - 2024-xx-xx +### Fixed + +- Bug table was not reset after an element without table was added #317 + ## [0.5.2] - 2024-08-16 ### Minor diff --git a/src/napari_spatialdata/_view.py b/src/napari_spatialdata/_view.py index aa6ac626..31e0fb83 100644 --- a/src/napari_spatialdata/_view.py +++ b/src/napari_spatialdata/_view.py @@ -137,9 +137,13 @@ def _update_adata(self) -> None: ) layer.metadata["adata"] = table - if layer is not None and "adata" in layer.metadata: - with self.model.events.adata.blocker(): - self.model.adata = layer.metadata["adata"] + if layer is not None: + if "adata" in layer.metadata: + with self.model.events.adata.blocker(): + self.model.adata = layer.metadata["adata"] + else: + with self.model.events.adata.blocker(): + self.model.adata = None if self.model.adata.shape == (0, 0): return @@ -186,8 +190,8 @@ def _select_layer(self) -> None: self.color_widget.clear() return - if layer is not None and "adata" in layer.metadata: - self.model.adata = layer.metadata["adata"] + if layer is not None: + self.model.adata = layer.metadata.get("adata", None) def screenshot(self) -> Any: return QImg2array(self.grab().toImage()) @@ -384,10 +388,11 @@ def _select_layer(self) -> None: if isinstance(layer, (Points, Shapes)) and (cols_df := layer.metadata.get("_columns_df")) is not None: self.dataframe_columns_widget.addItems(map(str, cols_df.columns)) self.model.system_name = layer.metadata.get("name", None) + self.model.adata = None return - if layer is not None and "adata" in layer.metadata: - self.model.adata = layer.metadata["adata"] + if layer is not None: + self.model.adata = layer.metadata.get("adata", None) if self.model.adata.shape == (0, 0): return @@ -418,9 +423,13 @@ def _update_adata(self) -> None: ) layer.metadata["adata"] = table - if layer is not None and "adata" in layer.metadata: - with self.model.events.adata.blocker(): - self.model.adata = layer.metadata["adata"] + if layer is not None: + if "adata" in layer.metadata: + with self.model.events.adata.blocker(): + self.model.adata = layer.metadata["adata"] + else: + with self.model.events.adata.blocker(): + self.model.adata = None if self.model.adata.shape == (0, 0): return @@ -440,10 +449,12 @@ def _update_adata(self) -> None: return def _get_adata_layer(self) -> Sequence[str | None]: + if self.model.adata is None: + return [None] adata_layers = list(self.model.adata.layers.keys()) - if len(adata_layers): - return adata_layers - return [None] + if len(adata_layers) == 0: + return [None] + return adata_layers def _change_color_by(self) -> None: self.color_by.setText(f"Color by: {self.model.color_by}") diff --git a/src/napari_spatialdata/_viewer.py b/src/napari_spatialdata/_viewer.py index 67ba4201..98dbe123 100644 --- a/src/napari_spatialdata/_viewer.py +++ b/src/napari_spatialdata/_viewer.py @@ -17,6 +17,7 @@ from shapely import Polygon from spatialdata import get_element_annotators, get_element_instances from spatialdata._core.query.relational_query import _left_join_spatialelement_table +from spatialdata._types import ArrayLike from spatialdata.models import PointsModel, ShapesModel, TableModel, force_2d, get_channels from spatialdata.transformations import Affine, Identity from spatialdata.transformations._utils import scale_radii @@ -254,15 +255,13 @@ def _save_shapes_to_sdata( for shape in layer_to_save._data_view.shapes ] - def _fix_coords(coords: np.ndarray) -> np.ndarray: + def _fix_coords(coords: ArrayLike) -> ArrayLike: remove_z = coords.shape[1] == 3 first_index = 1 if remove_z else 0 coords = coords[:, first_index::] return np.fliplr(coords) polygons: list[Polygon] = [Polygon(_fix_coords(p)) for p in coords] - # polygons: list[Polygon] = [Polygon(i) for i in _transform_coordinates(coords, f=lambda x: x[::-1])] - # polygons: list[Polygon] = [Polygon(i) for i in _transform_coordinates(layer_to_save.data, f=lambda x: x[::-1])] gdf = GeoDataFrame({"geometry": polygons}) force_2d(gdf) diff --git a/src/napari_spatialdata/_widgets.py b/src/napari_spatialdata/_widgets.py index eb18e266..b21d079a 100644 --- a/src/napari_spatialdata/_widgets.py +++ b/src/napari_spatialdata/_widgets.py @@ -227,7 +227,7 @@ def _(self, vec: pd.Series, **kwargs: Any) -> dict[str, Any]: f"The {vec_color_name} column must have unique values for the each {vec.name} level. Found:\n" f"{unique_colors}" ) - color_dict = unique_colors.to_dict()["genes_color"] + color_dict = unique_colors.to_dict()[f"{vec.name}_color"] if self.model.instance_key is not None and self.model.instance_key == vec.index.name: merge_df = pd.merge( diff --git a/src/napari_spatialdata/utils/_utils.py b/src/napari_spatialdata/utils/_utils.py index dff137e7..56c503a5 100644 --- a/src/napari_spatialdata/utils/_utils.py +++ b/src/napari_spatialdata/utils/_utils.py @@ -419,7 +419,7 @@ def get_itemindex_by_text( return widget_item -def _get_init_table_list(layer: Layer) -> Sequence[str | None] | None: +def _get_init_table_list(layer: Layer | None) -> Sequence[str | None] | None: """ Get the table names annotating the SpatialElement upon creating the napari layer. @@ -432,6 +432,8 @@ def _get_init_table_list(layer: Layer) -> Sequence[str | None] | None: ------ The list of table names annotating the SpatialElement if any. """ + if layer is None: + return None table_names: Sequence[str | None] | None if table_names := layer.metadata.get("table_names"): return table_names # type: ignore[no-any-return] From 803c35e632eb17732c17de90d02794f5c2e6065b Mon Sep 17 00:00:00 2001 From: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> Date: Wed, 25 Sep 2024 11:04:00 +0200 Subject: [PATCH 2/7] Update CHANGELOG.md --- CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index afbc7199..8dd7d896 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,11 +8,13 @@ and this project adheres to [Semantic Versioning][]. [keep a changelog]: https://keepachangelog.com/en/1.0.0/ [semantic versioning]: https://semver.org/spec/v2.0.0.html -## [0.x.x] - 2024-xx-xx +## [0.5.3] - 2024-09-25 ### Fixed - Bug table was not reset after an element without table was added #317 +- Bug when changing channel for a multichannel image #301 #302 +- Bug when plotting catgorical annotations on points #304 ## [0.5.2] - 2024-08-16 From 8df1eb632b4f95042ef9b31d07a4546a97e17680 Mon Sep 17 00:00:00 2001 From: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> Date: Mon, 30 Sep 2024 14:27:26 +0200 Subject: [PATCH 3/7] fix radii of transformed circles (#318) * fix radii of transformed circles * fix pre-commit * fix mypy * remove last NDArray, replaced with ArrayLike * removed another NDArrayA --- CHANGELOG.md | 6 ++++ src/napari_spatialdata/_interactive.py | 4 +-- src/napari_spatialdata/_model.py | 11 ++++--- src/napari_spatialdata/_scatterwidgets.py | 21 ++++++------ src/napari_spatialdata/_view.py | 2 +- src/napari_spatialdata/_viewer.py | 5 +-- src/napari_spatialdata/_widgets.py | 20 +++++++----- src/napari_spatialdata/utils/_test_utils.py | 9 +++--- src/napari_spatialdata/utils/_utils.py | 36 ++++++++------------- tests/conftest.py | 4 +-- tests/test_widgets.py | 12 +++---- 11 files changed, 65 insertions(+), 65 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8dd7d896..3a5e6320 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning][]. [keep a changelog]: https://keepachangelog.com/en/1.0.0/ [semantic versioning]: https://semver.org/spec/v2.0.0.html +## [0.5.4] - 2024-xx-xx + +### Fixed + +- Bug wrong radii transformed circles (e.g. with Visium lowres) + ## [0.5.3] - 2024-09-25 ### Fixed diff --git a/src/napari_spatialdata/_interactive.py b/src/napari_spatialdata/_interactive.py index 2f8b15f2..e515bd61 100644 --- a/src/napari_spatialdata/_interactive.py +++ b/src/napari_spatialdata/_interactive.py @@ -4,10 +4,10 @@ import napari from napari.utils.events import EventedList +from spatialdata._types import ArrayLike from napari_spatialdata._sdata_widgets import SdataWidget from napari_spatialdata.utils._utils import ( - NDArrayA, get_duplicate_element_names, get_elements_meta_mapping, get_itemindex_by_text, @@ -97,6 +97,6 @@ def run(self) -> None: """Run the napari application.""" napari.run() - def screenshot(self) -> NDArrayA | Any: + def screenshot(self) -> ArrayLike | Any: """Take a screenshot of the viewer in its current state.""" return self._viewer.screenshot(canvas_only=False) diff --git a/src/napari_spatialdata/_model.py b/src/napari_spatialdata/_model.py index 6144b8ab..c7a0ca83 100644 --- a/src/napari_spatialdata/_model.py +++ b/src/napari_spatialdata/_model.py @@ -9,10 +9,11 @@ from anndata import AnnData from napari.layers import Layer from napari.utils.events import EmitterGroup, Event +from spatialdata._types import ArrayLike from spatialdata.models import get_table_keys from napari_spatialdata.constants._constants import Symbol -from napari_spatialdata.utils._utils import NDArrayA, _ensure_dense_vector +from napari_spatialdata.utils._utils import _ensure_dense_vector __all__ = ["DataModel"] @@ -71,7 +72,7 @@ def get_items(self, attr: str) -> tuple[str, ...] | None: return None @_ensure_dense_vector - def get_obs(self, name: str, **_: Any) -> tuple[pd.Series | NDArrayA | None, str]: # TODO(giovp): fix docstring + def get_obs(self, name: str, **_: Any) -> tuple[pd.Series | ArrayLike | None, str]: # TODO(giovp): fix docstring """ Return an observation. @@ -95,7 +96,7 @@ def get_obs(self, name: str, **_: Any) -> tuple[pd.Series | NDArrayA | None, str return obs_column, self._format_key(name) @_ensure_dense_vector - def get_columns_df(self, name: str | int, **_: Any) -> tuple[NDArrayA | None, str]: + def get_columns_df(self, name: str | int, **_: Any) -> tuple[ArrayLike | None, str]: """ Return a column of the dataframe of the SpatialElement. @@ -113,7 +114,7 @@ def get_columns_df(self, name: str | int, **_: Any) -> tuple[NDArrayA | None, st return self.layer.metadata["_columns_df"][name], self._format_key(name) @_ensure_dense_vector - def get_var(self, name: str | int, **_: Any) -> tuple[NDArrayA | None, str]: # TODO(giovp): fix docstring + def get_var(self, name: str | int, **_: Any) -> tuple[ArrayLike | None, str]: # TODO(giovp): fix docstring """ Return a column in anndata.var_names. @@ -135,7 +136,7 @@ def get_var(self, name: str | int, **_: Any) -> tuple[NDArrayA | None, str]: # return self.adata._get_X(layer=self.adata_layer)[ix], self._format_key(name, adata_layer=True) @_ensure_dense_vector - def get_obsm(self, name: str, index: int | str = 0) -> tuple[NDArrayA | None, str]: + def get_obsm(self, name: str, index: int | str = 0) -> tuple[ArrayLike | None, str]: """ Return a vector from :attr:`anndata.AnnData.obsm`. diff --git a/src/napari_spatialdata/_scatterwidgets.py b/src/napari_spatialdata/_scatterwidgets.py index 05885a76..4c188e79 100644 --- a/src/napari_spatialdata/_scatterwidgets.py +++ b/src/napari_spatialdata/_scatterwidgets.py @@ -19,12 +19,13 @@ from pandas.api.types import CategoricalDtype from qtpy import QtWidgets from qtpy.QtCore import Signal +from spatialdata._types import ArrayLike from napari_spatialdata._model import DataModel from napari_spatialdata._widgets import AListWidget, ComponentWidget from napari_spatialdata.constants.config import POINT_SIZE_SCATTERPLOT_WIDGET from napari_spatialdata.utils._categoricals_utils import _add_categorical_legend -from napari_spatialdata.utils._utils import NDArrayA, _get_categorical, _set_palette +from napari_spatialdata.utils._utils import _get_categorical, _set_palette __all__ = [ "MatplotlibWidget", @@ -63,7 +64,7 @@ def __init__( model: DataModel, ax: Axes, collection: Collection, - data: list[NDArrayA], + data: list[ArrayLike], alpha_other: float = 0.3, viewer: Viewer | None = None, ): @@ -89,7 +90,7 @@ def __init__( self.selector = LassoSelector(ax, onselect=self.onselect) - self.ind: NDArrayA | None = None + self.ind: ArrayLike | None = None def export(self, adata: AnnData) -> None: model_layer: Layer = self.model.layer @@ -116,7 +117,7 @@ def export(self, adata: AnnData) -> None: sdata[table_name].obs[obs_name] = self.exported_data show_info(f"Exported selected coordinates to obs in AnnData as: {obs_name}") - def onselect(self, verts: list[NDArrayA]) -> None: + def onselect(self, verts: list[ArrayLike]) -> None: self.path = Path(verts) self.ind = np.nonzero(self.path.contains_points(self.xys))[0] @@ -140,7 +141,7 @@ def __init__(self, model: DataModel, attr: str, color: bool, **kwargs: Any): AListWidget.__init__(self, None, model, attr, **kwargs) self.attrChanged.connect(self._onChange) self._color = color - self._data: NDArrayA | dict[str, Any] | None = None + self._data: ArrayLike | dict[str, Any] | None = None self.itemClicked.connect(lambda item: self._onOneClick((item.text(),))) def _onChange(self) -> None: @@ -223,11 +224,11 @@ def chosen(self, chosen: str | None) -> None: self._chosen = chosen if chosen is not None else None @property - def data(self) -> NDArrayA | dict[str, Any] | None: + def data(self) -> ArrayLike | dict[str, Any] | None: return self._data @data.setter - def data(self, data: NDArrayA | dict[str, Any]) -> None: + def data(self, data: ArrayLike | dict[str, Any]) -> None: self._data = data @@ -251,9 +252,9 @@ def __init__(self, viewer: Viewer | None, model: DataModel): def _onClick( self, - x_data: NDArrayA | pd.Series, - y_data: NDArrayA | pd.Series, - color_data: NDArrayA | dict[str, NDArrayA | pd.Series | dict[str, str]], + x_data: ArrayLike | pd.Series, + y_data: ArrayLike | pd.Series, + color_data: ArrayLike | dict[str, ArrayLike | pd.Series | dict[str, str]], x_label: str | None, y_label: str | None, color_label: str | None, diff --git a/src/napari_spatialdata/_view.py b/src/napari_spatialdata/_view.py index 31e0fb83..94c2850b 100644 --- a/src/napari_spatialdata/_view.py +++ b/src/napari_spatialdata/_view.py @@ -97,7 +97,7 @@ def __init__(self, napari_viewer: Viewer, model: DataModel | None = None): lambda: self.matplotlib_widget._onClick( self.x_widget.widget.data, self.y_widget.widget.data, - self.color_widget.widget.data, # type:ignore[arg-type] + self.color_widget.widget.data, self.x_widget.getFormattedLabel(), self.y_widget.getFormattedLabel(), self.color_widget.getFormattedLabel(), diff --git a/src/napari_spatialdata/_viewer.py b/src/napari_spatialdata/_viewer.py index 98dbe123..1cf6c62e 100644 --- a/src/napari_spatialdata/_viewer.py +++ b/src/napari_spatialdata/_viewer.py @@ -20,7 +20,6 @@ from spatialdata._types import ArrayLike from spatialdata.models import PointsModel, ShapesModel, TableModel, force_2d, get_channels from spatialdata.transformations import Affine, Identity -from spatialdata.transformations._utils import scale_radii from napari_spatialdata._model import DataModel from napari_spatialdata.constants import config @@ -741,8 +740,6 @@ def _adjust_radii_of_points_layer(self, layer: Layer, affine: npt.ArrayLike) -> raise ValueError(f"Invalid affine shape: {affine.shape}") affine_transformation = Affine(affine, input_axes=axes, output_axes=axes) - new_radii = scale_radii(radii=radii, affine=affine_transformation, axes=axes) - # the points size is the diameter, in "data pixels" of the current coordinate system, so we need to scale by # scale factor of the affine transformation. This scale factor is an approximation when the affine # transformation is anisotropic. @@ -751,7 +748,7 @@ def _adjust_radii_of_points_layer(self, layer: Layer, affine: npt.ArrayLike) -> modules = np.absolute(eigenvalues) scale_factor = np.mean(modules) - layer.size = 2 * new_radii / scale_factor + layer.size = 2 * radii * scale_factor def _affine_transform_layers(self, coordinate_system: str) -> None: for layer in self.viewer.layers: diff --git a/src/napari_spatialdata/_widgets.py b/src/napari_spatialdata/_widgets.py index b21d079a..87daf177 100644 --- a/src/napari_spatialdata/_widgets.py +++ b/src/napari_spatialdata/_widgets.py @@ -20,13 +20,14 @@ from qtpy.QtCore import Qt, Signal from scanpy.plotting._utils import _set_colors_for_categorical_obs from sklearn.preprocessing import MinMaxScaler +from spatialdata._types import ArrayLike from superqt import QRangeSlider from vispy import scene from vispy.color.colormap import Colormap, MatplotlibColormap from vispy.scene.widgets import ColorBarWidget from napari_spatialdata._model import DataModel -from napari_spatialdata.utils._utils import NDArrayA, _min_max_norm, get_napari_version +from napari_spatialdata.utils._utils import _min_max_norm, get_napari_version __all__ = [ "AListWidget", @@ -188,7 +189,7 @@ def _handle_already_present(self, layer_name: str) -> None: self.viewer.layers.selection.select_only(self.viewer.layers[layer_name]) @singledispatchmethod - def _get_points_properties(self, vec: NDArrayA | pd.Series, **kwargs: Any) -> dict[str, Any]: + def _get_points_properties(self, vec: ArrayLike | pd.Series, **kwargs: Any) -> dict[str, Any]: raise NotImplementedError(type(vec)) @_get_points_properties.register(pd.Series) @@ -252,7 +253,7 @@ def _(self, vec: pd.Series, **kwargs: Any) -> dict[str, Any]: } @_get_points_properties.register(np.ndarray) - def _(self, vec: NDArrayA, **kwargs: Any) -> dict[str, Any]: + def _(self, vec: ArrayLike, **kwargs: Any) -> dict[str, Any]: layer = kwargs.pop("layer", None) # Here kwargs['key'] is actually the column name. @@ -271,9 +272,9 @@ def _(self, vec: NDArrayA, **kwargs: Any) -> dict[str, Any]: layer_meta = self.model.layer.metadata if self.model.layer is not None else None element_indices = pd.Series(layer_meta["indices"], name="element_indices") if isinstance(layer, Labels): - vec = vec.drop(index=0) if 0 in vec.index else vec # type:ignore[attr-defined] + vec = vec.drop(index=0) if 0 in vec.index else vec # element_indices = element_indices[element_indices != 0] - diff_element_table = set(element_indices).difference(set(vec.index)) # type:ignore[attr-defined] + diff_element_table = set(element_indices).difference(set(vec.index)) merge_vec = pd.merge(element_indices, vec, left_on="element_indices", right_index=True, how="left")[ "vec" ].fillna(0, axis=0) @@ -541,7 +542,10 @@ def _onValueChange(self, percentile: tuple[float, float]) -> None: if "data" not in layer.metadata: return None # noqa: RET501 v = layer.metadata["data"] - clipped = np.clip(v, *np.percentile(v, percentile)) + # this code is currently not used since the slider is not enabled; so I silenced the mypy error; 2. there is a + # mismatch for this error with the mypy in the CI, so I silenced the unused-ignore from the local mypy. + # when this code is re-enabled, let's fix mypy + clipped = np.clip(v, *np.percentile(v, percentile)) # type: ignore[misc,unused-ignore] if isinstance(layer, Points): layer.metadata = {**layer.metadata, "perc": percentile} @@ -559,7 +563,7 @@ def _onValueChange(self, percentile: tuple[float, float]) -> None: self._colorbar.setClim((np.min(layer.properties["value"]), np.max(layer.properties["value"]))) self._colorbar.update_color() - def _scale_vec(self, vec: NDArrayA) -> NDArrayA: + def _scale_vec(self, vec: ArrayLike) -> ArrayLike: ominn, omaxx = self._colorbar.getOclim() delta = omaxx - ominn + 1e-12 @@ -567,7 +571,7 @@ def _scale_vec(self, vec: NDArrayA) -> NDArrayA: minn = (minn - ominn) / delta maxx = (maxx - ominn) / delta scaler = MinMaxScaler(feature_range=(minn, maxx)) - return scaler.fit_transform(vec.reshape(-1, 1)) # type: ignore[no-any-return] + return scaler.fit_transform(vec.reshape(-1, 1)) @property def viewer(self) -> napari.Viewer: diff --git a/src/napari_spatialdata/utils/_test_utils.py b/src/napari_spatialdata/utils/_test_utils.py index cf9681c7..64242c8b 100644 --- a/src/napari_spatialdata/utils/_test_utils.py +++ b/src/napari_spatialdata/utils/_test_utils.py @@ -12,8 +12,7 @@ import napari from loguru import logger from PIL import Image - -from napari_spatialdata.utils._utils import NDArrayA +from spatialdata._types import ArrayLike def get_center_pos_listitem(widget: QListWidget, text: str) -> QPoint: @@ -72,7 +71,7 @@ def click_list_widget_item( raise ValueError(f"{click} is not a valid click") -def take_screenshot(viewer: napari.Viewer, canvas_only: bool = False) -> NDArrayA | Any: +def take_screenshot(viewer: napari.Viewer, canvas_only: bool = False) -> ArrayLike | Any: """Take screenshot of the Napari viewer. Parameters @@ -84,7 +83,7 @@ def take_screenshot(viewer: napari.Viewer, canvas_only: bool = False) -> NDArray Returns ------- - The screenshot as an NDArray + The screenshot as an array. """ logger.info("Taking screenshot of viewer") # to distinguish between the black of the image background and the black of the napari background (now white) @@ -98,7 +97,7 @@ def take_screenshot(viewer: napari.Viewer, canvas_only: bool = False) -> NDArray return interactive_screenshot -def save_image(image_np: NDArrayA, file_path: str) -> None: +def save_image(image_np: ArrayLike, file_path: str) -> None: """Save image to file. Parameters diff --git a/src/napari_spatialdata/utils/_utils.py b/src/napari_spatialdata/utils/_utils.py index 56c503a5..74d583af 100644 --- a/src/napari_spatialdata/utils/_utils.py +++ b/src/napari_spatialdata/utils/_utils.py @@ -47,16 +47,9 @@ from napari_spatialdata._sdata_widgets import CoordinateSystemWidget, ElementWidget -try: - from numpy.typing import NDArray +from spatialdata._types import ArrayLike - NDArrayA = NDArray[Any] -except (ImportError, TypeError): - NDArray = np.ndarray # type: ignore[misc] - NDArrayA = np.ndarray # type: ignore[misc] - - -Vector_name_t = tuple[Optional[Union[pd.Series, NDArrayA]], Optional[str]] +Vector_name_t = tuple[Optional[Union[pd.Series, ArrayLike]], Optional[str]] def _ensure_dense_vector(fn: Callable[..., Vector_name_t]) -> Callable[..., Vector_name_t]: @@ -137,7 +130,7 @@ def _get_categorical( vec: pd.Series | None = None, palette: str | None = None, colordict: pd.Series | dict[Any, Any] | None = None, -) -> NDArrayA: +) -> ArrayLike: categorical = vec if vec is not None else adata.obs[key] if not isinstance(colordict, dict): col_dict = _set_palette(adata, key, palette, colordict) @@ -155,7 +148,7 @@ def _get_categorical( return np.array([col_dict[v] for v in categorical]) -def _position_cluster_labels(coords: NDArrayA, clusters: pd.Series) -> dict[str, NDArrayA]: +def _position_cluster_labels(coords: ArrayLike, clusters: pd.Series) -> dict[str, ArrayLike]: if clusters is not None and not isinstance(clusters.dtype, pd.CategoricalDtype): raise TypeError(f"Expected `clusters` to be `categorical`, found `{infer_dtype(clusters)}`.") coords = coords[:, 1:] @@ -170,7 +163,7 @@ def _position_cluster_labels(coords: NDArrayA, clusters: pd.Series) -> dict[str, return {"clusters": clusters} -def _min_max_norm(vec: spmatrix | NDArrayA) -> NDArrayA: +def _min_max_norm(vec: spmatrix | ArrayLike) -> ArrayLike: if issparse(vec): if TYPE_CHECKING: assert isinstance(vec, spmatrix) @@ -179,18 +172,17 @@ def _min_max_norm(vec: spmatrix | NDArrayA) -> NDArrayA: if vec.ndim != 1: raise ValueError(f"Expected `1` dimension, found `{vec.ndim}`.") - maxx, minn = np.nanmax(vec), np.nanmin(vec) + maxx: ArrayLike = np.nanmax(vec) + minn: ArrayLike = np.nanmin(vec) - return ( # type: ignore[no-any-return] - np.ones_like(vec) if np.isclose(minn, maxx) else ((vec - minn) / (maxx - minn)) - ) + return np.ones_like(vec) if np.isclose(minn, maxx) else ((vec - minn) / (maxx - minn)) def _transform_coordinates(data: list[Any], f: Callable[..., Any]) -> list[Any]: return [[f(xy) for xy in sublist] for sublist in data] -def _get_transform(element: SpatialElement, coordinate_system_name: str | None = None) -> None | NDArrayA: +def _get_transform(element: SpatialElement, coordinate_system_name: str | None = None) -> None | ArrayLike: if not isinstance(element, (DataArray, DataTree, DaskDataFrame, GeoDataFrame)): raise RuntimeError("Cannot get transform for {type(element)}") @@ -198,12 +190,12 @@ def _get_transform(element: SpatialElement, coordinate_system_name: str | None = cs = transformations.keys().__iter__().__next__() if coordinate_system_name is None else coordinate_system_name ct = transformations.get(cs) if ct: - return ct.to_affine_matrix(input_axes=("y", "x"), output_axes=("y", "x")) # type: ignore + return ct.to_affine_matrix(input_axes=("y", "x"), output_axes=("y", "x")) return None @njit(cache=True, fastmath=True) -def _point_inside_triangles(triangles: NDArrayA) -> np.bool_: +def _point_inside_triangles(triangles: ArrayLike) -> np.bool_: # modified from napari AB = triangles[:, 1, :] - triangles[:, 0, :] AC = triangles[:, 2, :] - triangles[:, 0, :] @@ -217,7 +209,7 @@ def _point_inside_triangles(triangles: NDArrayA) -> np.bool_: @njit(parallel=True) -def _points_inside_triangles(points: NDArrayA, triangles: NDArrayA) -> NDArrayA: +def _points_inside_triangles(points: ArrayLike, triangles: ArrayLike) -> ArrayLike: out = np.empty( len( points, @@ -459,7 +451,7 @@ def generate_random_color_hex() -> str: return f"#{randint(0, 255):02x}{randint(0, 255):02x}{randint(0, 255):02x}ff" -def _get_ellipses_from_circles(yx: NDArrayA, radii: NDArrayA) -> NDArrayA: +def _get_ellipses_from_circles(yx: ArrayLike, radii: ArrayLike) -> ArrayLike: """Convert circles to ellipses. Parameters @@ -471,7 +463,7 @@ def _get_ellipses_from_circles(yx: NDArrayA, radii: NDArrayA) -> NDArrayA: Returns ------- - NDArrayA + ArrayLike Ellipses. """ ndim = yx.shape[1] diff --git a/tests/conftest.py b/tests/conftest.py index 77bf8024..de7732de 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,10 +13,10 @@ from loguru import logger from matplotlib.testing.compare import compare_images from napari_spatialdata.utils._test_utils import save_image, take_screenshot -from napari_spatialdata.utils._utils import NDArrayA from scipy import ndimage as ndi from skimage import data from spatialdata import SpatialData +from spatialdata._types import ArrayLike from spatialdata.datasets import blobs HERE: Path = Path(__file__).parent @@ -118,7 +118,7 @@ def labels(): return blobs -def _get_blobs_galaxy() -> tuple[NDArrayA, NDArrayA]: +def _get_blobs_galaxy() -> tuple[ArrayLike, ArrayLike]: blobs = data.binary_blobs(rng=SEED) blobs = ndi.label(blobs)[0] return blobs, data.hubble_deep_field()[: blobs.shape[0], : blobs.shape[0]] diff --git a/tests/test_widgets.py b/tests/test_widgets.py index 00ee6021..b27bb634 100644 --- a/tests/test_widgets.py +++ b/tests/test_widgets.py @@ -10,8 +10,8 @@ from napari_spatialdata._model import DataModel from napari_spatialdata._sdata_widgets import SdataWidget from napari_spatialdata._view import QtAdataScatterWidget, QtAdataViewWidget -from napari_spatialdata.utils._utils import NDArrayA from spatialdata import SpatialData +from spatialdata._types import ArrayLike # make_napari_viewer is a pytest fixture that returns a napari viewer object @@ -51,7 +51,7 @@ def test_creating_widget_with_no_adata(make_napari_viewer: Any, widget: Any) -> def test_model( make_napari_viewer: Any, widget: Any, - labels: NDArrayA, + labels: ArrayLike, sdata_blobs: SpatialData, ) -> None: # make viewer and add an image layer using our fixture @@ -121,7 +121,7 @@ def test_scatterlistwidget( make_napari_viewer: Any, widget: Any, adata_labels: AnnData, - image: NDArrayA, + image: ArrayLike, attr: str, item: str, text: Union[str, int, None], @@ -162,7 +162,7 @@ def test_categorical_and_error( make_napari_viewer: Any, widget: Any, adata_labels: AnnData, - image: NDArrayA, + image: ArrayLike, attr: str, item: str, ) -> None: @@ -200,7 +200,7 @@ def test_component_widget( make_napari_viewer: Any, widget: Any, adata_labels: AnnData, - image: NDArrayA, + image: ArrayLike, ) -> None: viewer = make_napari_viewer() layer_name = "labels" @@ -243,7 +243,7 @@ def test_component_widget( @pytest.mark.parametrize("widget", [QtAdataViewWidget, QtAdataScatterWidget]) -def test_layer_selection(make_napari_viewer: Any, image: NDArrayA, widget: Any, sdata_blobs: SpatialData): +def test_layer_selection(make_napari_viewer: Any, image: ArrayLike, widget: Any, sdata_blobs: SpatialData): viewer = make_napari_viewer() sdata_widget = SdataWidget(viewer, EventedList([sdata_blobs])) sdata_widget.viewer_model.add_sdata_labels(sdata_blobs, "blobs_labels", "global", False) From f84b79bcbc2616ecf1736720f4ed029724b3d10c Mon Sep 17 00:00:00 2001 From: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> Date: Tue, 1 Oct 2024 13:18:17 +0200 Subject: [PATCH 4/7] Fix instance id shift bug; fix obsm viz; fix background color labels (#320) --- README.md | 2 +- docs/index.md | 1 + docs/limitations.md | 8 +++++ src/napari_spatialdata/_model.py | 38 ++++++++++++++--------- src/napari_spatialdata/_scatterwidgets.py | 10 +++--- src/napari_spatialdata/_widgets.py | 22 ++++++++----- src/napari_spatialdata/utils/_utils.py | 20 ++++++------ tests/conftest.py | 11 +++++-- tests/test_spatialdata.py | 2 +- tests/test_widgets.py | 6 ++-- 10 files changed, 75 insertions(+), 45 deletions(-) create mode 100644 docs/limitations.md diff --git a/README.md b/README.md index bcbd85c2..8ff45493 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ [![napari hub](https://img.shields.io/endpoint?url=https://api.napari-hub.org/shields/napari-spatialdata)](https://napari-hub.org/plugins/napari-spatialdata) [![DOI](https://zenodo.org/badge/477021400.svg)](https://zenodo.org/badge/latestdoi/477021400) -This repository contains a napari plugin for interactively exploring and annotating SpatialData objects. `napari-spatialdata` is part of the `SpatialData` ecosystem. To learn more about SpatialData, please see the [documentation](https://spatialdata.scverse.org/). +This repository contains a napari plugin for interactively exploring and annotating SpatialData objects. Here you can find the [napari-spatialdata documentation](https://spatialdata.scverse.org/projects/napari/en/latest/notebooks/spatialdata.html). `napari-spatialdata` is part of the `SpatialData` ecosystem. To learn more about SpatialData, please see the [spatialdata documentation](https://spatialdata.scverse.org/). ## Installation diff --git a/docs/index.md b/docs/index.md index 431a7edc..dad71f37 100644 --- a/docs/index.md +++ b/docs/index.md @@ -9,6 +9,7 @@ api.md cli.md contributing.md +limitations.md references.md ``` diff --git a/docs/limitations.md b/docs/limitations.md new file mode 100644 index 00000000..e211ef3e --- /dev/null +++ b/docs/limitations.md @@ -0,0 +1,8 @@ +# Limitations + +Contributions are welcomed! + +- On macOS, the maximum point size is constrained ([details in the Vispy issue tracker](https://github.com/vispy/vispy/issues/2078)); + it is important to keep this in mind when zooming in on points to [explore the points property on mouse hover](https://github.com/scverse/napari-spatialdata/issues/35#issuecomment-2383792431). +- 3D data representation is supported in `spatialdata`; 3D data visualization is supported in `napari`. + Still, in `napari-spatialdata` we currently don't support 3D data visualization. diff --git a/src/napari_spatialdata/_model.py b/src/napari_spatialdata/_model.py index c7a0ca83..835b90e3 100644 --- a/src/napari_spatialdata/_model.py +++ b/src/napari_spatialdata/_model.py @@ -72,7 +72,9 @@ def get_items(self, attr: str) -> tuple[str, ...] | None: return None @_ensure_dense_vector - def get_obs(self, name: str, **_: Any) -> tuple[pd.Series | ArrayLike | None, str]: # TODO(giovp): fix docstring + def get_obs( + self, name: str, **_: Any + ) -> tuple[pd.Series | ArrayLike | None, str, pd.Index]: # TODO(giovp): fix docstring """ Return an observation. @@ -83,7 +85,7 @@ def get_obs(self, name: str, **_: Any) -> tuple[pd.Series | ArrayLike | None, st Returns ------- - The values and the formatted ``name``. + The values, the formatted ``name`` and the `instance_key` values. """ if name not in self.adata.obs.columns: raise KeyError(f"Key `{name}` not found in `adata.obs`.") @@ -93,10 +95,10 @@ def get_obs(self, name: str, **_: Any) -> tuple[pd.Series | ArrayLike | None, st else: obs_column = self.adata.obs[name].copy() obs_column.index = self.adata.obs[self.instance_key] - return obs_column, self._format_key(name) + return obs_column, self._format_key(name), obs_column.index @_ensure_dense_vector - def get_columns_df(self, name: str | int, **_: Any) -> tuple[ArrayLike | None, str]: + def get_columns_df(self, name: str | int, **_: Any) -> tuple[ArrayLike | None, str, pd.Index]: """ Return a column of the dataframe of the SpatialElement. @@ -107,14 +109,17 @@ def get_columns_df(self, name: str | int, **_: Any) -> tuple[ArrayLike | None, s Returns ------- - The dataframe column of interest and the formatted name of the column. + The dataframe column of interest, the formatted name of the column and the `instance_key` valus. """ if self.layer is None: raise ValueError("Layer must be present") - return self.layer.metadata["_columns_df"][name], self._format_key(name) + column = self.layer.metadata["_columns_df"][name] + return column, self._format_key(name), column.index @_ensure_dense_vector - def get_var(self, name: str | int, **_: Any) -> tuple[ArrayLike | None, str]: # TODO(giovp): fix docstring + def get_var( + self, name: str | int, **_: Any + ) -> tuple[ArrayLike | None, str, pd.Index]: # TODO(giovp): fix docstring """ Return a column in anndata.var_names. @@ -126,17 +131,19 @@ def get_var(self, name: str | int, **_: Any) -> tuple[ArrayLike | None, str]: # Returns ------- - The values and the formatted ``name``. + The values, the formatted ``name`` and the `instance_key` values. """ try: ix = self.adata._normalize_indices((slice(None), name)) except KeyError: raise KeyError(f"Key `{name}` not found in `adata.var_names`.") from None - return self.adata._get_X(layer=self.adata_layer)[ix], self._format_key(name, adata_layer=True) + column = self.adata._get_X(layer=self.adata_layer)[ix] + index = self.adata.obs[[self.instance_key]].set_index(self.instance_key).index + return column, self._format_key(name, adata_layer=True), index @_ensure_dense_vector - def get_obsm(self, name: str, index: int | str = 0) -> tuple[ArrayLike | None, str]: + def get_obsm(self, name: str, index: int | str = 0) -> tuple[ArrayLike | None, str, pd.Index]: """ Return a vector from :attr:`anndata.AnnData.obsm`. @@ -149,19 +156,20 @@ def get_obsm(self, name: str, index: int | str = 0) -> tuple[ArrayLike | None, s Returns ------- - The values and the formatted ``name``. + The values, the formatted ``name`` and the `instance_key` values. """ if name not in self.adata.obsm: raise KeyError(f"Unable to find key `{name!r}` in `adata.obsm`.") res = self.adata.obsm[name] pretty_name = self._format_key(name, index=index) + adata_index = self.adata.obs[[self.instance_key]].set_index(self.instance_key).index if isinstance(res, pd.DataFrame): try: if isinstance(index, str): - return res[index], pretty_name + return res[index], pretty_name, adata_index if isinstance(index, int): - return res.iloc[:, index], self._format_key(name, index=res.columns[index]) + return res.iloc[:, index], self._format_key(name, index=res.columns[index]), adata_index except KeyError: raise KeyError(f"Key `{index}` not found in `adata.obsm[{name!r}].`") from None @@ -173,8 +181,8 @@ def get_obsm(self, name: str, index: int | str = 0) -> tuple[ArrayLike | None, s f"Unable to convert `{index}` to an integer when accessing `adata.obsm[{name!r}]`." ) from None res = np.asarray(res) - - return (res if res.ndim == 1 else res[:, index]), pretty_name + column = res if res.ndim == 1 else res[:, index] + return column, pretty_name, adata_index def _format_key(self, key: str | int, index: int | str | None = None, adata_layer: bool = False) -> str: if index is not None: diff --git a/src/napari_spatialdata/_scatterwidgets.py b/src/napari_spatialdata/_scatterwidgets.py index 4c188e79..1800bf25 100644 --- a/src/napari_spatialdata/_scatterwidgets.py +++ b/src/napari_spatialdata/_scatterwidgets.py @@ -152,11 +152,11 @@ def _onChange(self) -> None: def _onAction(self, items: Iterable[str]) -> None: for item in sorted(set(items)): - try: - vec, _ = self._getter(item, index=self.getIndex()) - except Exception as e: # noqa: BLE001 - logger.error(e) - continue + # try: + vec, _, index = self._getter(item, index=self.getIndex()) + # except Exception as e: # noqa: BLE001 + # logger.error(e) + # continue self.chosen = item if isinstance(vec, np.ndarray): self.data = vec diff --git a/src/napari_spatialdata/_widgets.py b/src/napari_spatialdata/_widgets.py index 87daf177..3b07face 100644 --- a/src/napari_spatialdata/_widgets.py +++ b/src/napari_spatialdata/_widgets.py @@ -127,14 +127,22 @@ def _onAction(self, items: Iterable[str]) -> None: i = self.model.layer.metadata["adata"].var.index.get_loc(item) self.viewer.dims.set_point(0, i) else: - vec, name = self._getter(item, index=self.getIndex()) + vec, name, index = self._getter(item, index=self.getIndex()) if self.model.layer is not None: + # update the features (properties for each instance displayed on mouse hover in the bottom bar) + self.getIndex() + features_name = f"{item}_{self.getIndex()}" if self._attr == "obsm" else item + features = pd.DataFrame({features_name: vec}) + # we need this secret column "index", as explained here + # https://forum.image.sc/t/napari-labels-layer-properties/57649/2 + features["index"] = index + self.model.layer.features = features + properties = self._get_points_properties(vec, key=item, layer=self.model.layer) self.model.color_by = "" if self.model.system_name is None else item if isinstance(self.model.layer, (Points, Shapes)): self.model.layer.text = None # needed because of the text-feature order of updates - # self.model.layer.features = properties.get("features", None) self.model.layer.face_color = properties["face_color"] # self.model.layer.edge_color = properties["face_color"] self.model.layer.text = properties["text"] @@ -200,7 +208,7 @@ def _(self, vec: pd.Series, **kwargs: Any) -> dict[str, Any]: if isinstance(layer, Labels): element_indices = element_indices[element_indices != 0] # When merging if the row is not present in the other table it will be nan so we can give it a default color - vec_color_name = vec.name + "_color" + vec_color_name = vec.name + "_colors" if self._attr != "columns_df": if vec_color_name not in self.model.adata.uns: colorer = AnnData(shape=(len(vec), 0), obs=pd.DataFrame(index=vec.index, data={"vec": vec})) @@ -228,7 +236,7 @@ def _(self, vec: pd.Series, **kwargs: Any) -> dict[str, Any]: f"The {vec_color_name} column must have unique values for the each {vec.name} level. Found:\n" f"{unique_colors}" ) - color_dict = unique_colors.to_dict()[f"{vec.name}_color"] + color_dict = unique_colors.to_dict()[f"{vec.name}_colors"] if self.model.instance_key is not None and self.model.instance_key == vec.index.name: merge_df = pd.merge( @@ -240,7 +248,6 @@ def _(self, vec: pd.Series, **kwargs: Any) -> dict[str, Any]: merge_df["color"] = merge_df[vec.name].map(color_dict) if layer is not None and isinstance(layer, Labels): index_color_mapping = dict(zip(merge_df["element_indices"], merge_df["color"])) - index_color_mapping[0] = "#000000ff" return { "color": index_color_mapping, "properties": {"value": vec}, @@ -262,6 +269,7 @@ def _(self, vec: ArrayLike, **kwargs: Any) -> dict[str, Any]: (adata := self.model.adata) is not None and kwargs["key"] not in adata.obs.columns and kwargs["key"] not in adata.var.index + and kwargs["key"] not in adata.obsm ) or adata is None: merge_vec = layer.metadata["_columns_df"][kwargs["key"]] element_indices = merge_vec.index @@ -290,8 +298,6 @@ def _(self, vec: ArrayLike, **kwargs: Any) -> dict[str, Any]: element_indices_list = element_indices.to_list() change_index = element_indices_list.index(i) color_vec[change_index] = np.array([0.5, 0.5, 0.5, 1.0]) - if isinstance(layer, Labels): - color_vec[0] = np.array([0.0, 0.0, 0.0, 1.0]) if layer is not None and isinstance(layer, Labels): return { @@ -441,7 +447,7 @@ def __init_UI(self) -> None: clim=self.getClim(), border_width=1.0, border_color="black", - padding=(0.33, 0.167), + padding=(0.3, 0.167), axis_ratio=0.05, ) diff --git a/src/napari_spatialdata/utils/_utils.py b/src/napari_spatialdata/utils/_utils.py index 74d583af..f8b664c9 100644 --- a/src/napari_spatialdata/utils/_utils.py +++ b/src/napari_spatialdata/utils/_utils.py @@ -49,29 +49,29 @@ from spatialdata._types import ArrayLike -Vector_name_t = tuple[Optional[Union[pd.Series, ArrayLike]], Optional[str]] +Vector_name_index_t = tuple[Optional[Union[pd.Series, ArrayLike]], Optional[str], Optional[pd.Index]] -def _ensure_dense_vector(fn: Callable[..., Vector_name_t]) -> Callable[..., Vector_name_t]: +def _ensure_dense_vector(fn: Callable[..., Vector_name_index_t]) -> Callable[..., Vector_name_index_t]: @wraps(fn) - def decorator(self: Any, *args: Any, **kwargs: Any) -> Vector_name_t: + def decorator(self: Any, *args: Any, **kwargs: Any) -> Vector_name_index_t: normalize = kwargs.pop("normalize", False) - res, fmt = fn(self, *args, **kwargs) + res, name, index = fn(self, *args, **kwargs) if res is None: - return None, None + return None, None, None if isinstance(res, pd.Series): if isinstance(res.dtype, pd.CategoricalDtype): - return res, fmt + return res, name, index if is_string_dtype(res) or is_object_dtype(res) or is_bool_dtype(res): - return res.astype("category"), fmt + return res.astype("category"), name, index if is_integer_dtype(res): unique = res.unique() n_uniq = len(unique) if n_uniq <= 2 and (set(unique) & {0, 1}): - return res.astype(bool).astype("category"), fmt + return res.astype(bool).astype("category"), name, index if len(unique) <= len(res) // 100: - return res.astype("category"), fmt + return res.astype("category"), name, index elif not is_numeric_dtype(res): raise TypeError(f"Unable to process `pandas.Series` of type `{infer_dtype(res)}`.") res = res.to_numpy() @@ -86,7 +86,7 @@ def decorator(self: Any, *args: Any, **kwargs: Any) -> Vector_name_t: if res.ndim != 1: raise ValueError(f"Expected 1-dimensional array, found `{res.ndim}`.") - return (_min_max_norm(res) if normalize else res), fmt + return (_min_max_norm(res) if normalize else res), name, index return decorator diff --git a/tests/conftest.py b/tests/conftest.py index de7732de..de166f82 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,6 +18,7 @@ from spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata.datasets import blobs +from spatialdata.models import TableModel HERE: Path = Path(__file__).parent @@ -42,7 +43,8 @@ def adata_labels() -> AnnData: { "a": rng.normal(size=(n_obs_labels,)), "categorical": pd.Categorical(rng.integers(0, 2, size=(n_obs_labels,))), - "cell_id": pd.Categorical(seg), + "cell_id": seg, + "region": ["labels" for _ in range(n_obs_labels)], }, index=np.arange(n_obs_labels), ) @@ -58,7 +60,12 @@ def adata_labels() -> AnnData: } } obsm_labels = {"spatial": rng.integers(0, blobs.shape[0], size=(n_obs_labels, 2))} - return generate_adata(n_var, obs_labels, obsm_labels, uns_labels) + return TableModel.parse( + generate_adata(n_var, obs_labels, obsm_labels, uns_labels), + region="labels", + region_key="region", + instance_key="cell_id", + ) @pytest.fixture diff --git a/tests/test_spatialdata.py b/tests/test_spatialdata.py index ce6d5cf7..ba17d5d1 100644 --- a/tests/test_spatialdata.py +++ b/tests/test_spatialdata.py @@ -73,7 +73,7 @@ def test_sdatawidget_images(make_napari_viewer: Any, blobs_extra_cs: SpatialData del blobs_extra_cs.images["image"] -def test_sdatawidget_labels(make_napari_viewer: Any, blobs_extra_cs: SpatialData): +def test_sdatawidget_labels(qtbot, make_napari_viewer: Any, blobs_extra_cs: SpatialData): viewer = make_napari_viewer() widget = SdataWidget(viewer, EventedList([blobs_extra_cs])) assert len(widget.viewer_model.viewer.layers) == 0 diff --git a/tests/test_widgets.py b/tests/test_widgets.py index b27bb634..157c9f81 100644 --- a/tests/test_widgets.py +++ b/tests/test_widgets.py @@ -132,7 +132,7 @@ def test_scatterlistwidget( viewer.add_labels( image, name=layer_name, - metadata={"adata": adata_labels, "region_key": "cell_id"}, + metadata={"adata": adata_labels}, ) model = DataModel() widget = widget(viewer, model) @@ -173,7 +173,7 @@ def test_categorical_and_error( viewer.add_labels( image, name=layer_name, - metadata={"adata": adata_labels, "region_key": "cell_id"}, + metadata={"adata": adata_labels}, ) # widget._select_layer() @@ -208,7 +208,7 @@ def test_component_widget( viewer.add_labels( image, name=layer_name, - metadata={"adata": adata_labels, "region_key": "cell_id"}, + metadata={"adata": adata_labels}, ) model = DataModel() widget = widget(viewer, model) From 9b8bb4a9465c5b0ce8b1d4d69c269cd8d3f9ae37 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 2 Oct 2024 20:31:32 +0200 Subject: [PATCH 5/7] [pre-commit.ci] pre-commit autoupdate (#300) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [pre-commit.ci] pre-commit autoupdate updates: - [github.com/pre-commit/mirrors-mypy: v1.11.1 → v1.11.2](https://github.com/pre-commit/mirrors-mypy/compare/v1.11.1...v1.11.2) - [github.com/astral-sh/ruff-pre-commit: v0.5.7 → v0.6.8](https://github.com/astral-sh/ruff-pre-commit/compare/v0.5.7...v0.6.8) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 4 +-- benchmarks/benchmark_qt_widget.py | 1 + docs/notebooks/mibitof_analysis.ipynb | 21 ++++++------- docs/notebooks/nanostring_analysis.ipynb | 22 ++++++------- docs/notebooks/scatterwidget.ipynb | 39 ++++++++++++------------ docs/notebooks/spatialdata.ipynb | 23 +++++++------- examples/spatialdata_visium.py | 3 +- tests/conftest.py | 3 +- tests/test_cli.py | 3 +- tests/test_interactive.py | 2 +- tests/test_scatterwidgets.py | 1 + tests/test_spatialdata.py | 9 +++--- tests/test_utils.py | 3 +- tests/test_view.py | 5 +-- tests/test_viewer.py | 9 +++--- tests/test_widgets.py | 5 +-- 16 files changed, 82 insertions(+), 71 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 96a9f72a..899506b7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ default_stages: minimum_pre_commit_version: 2.9.3 repos: - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.1 + rev: v1.11.2 hooks: - id: mypy additional_dependencies: [numpy>=1.23] @@ -49,7 +49,7 @@ repos: hooks: - id: blacken-docs - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.7 + rev: v0.6.8 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/benchmarks/benchmark_qt_widget.py b/benchmarks/benchmark_qt_widget.py index d09973db..684e23ff 100644 --- a/benchmarks/benchmark_qt_widget.py +++ b/benchmarks/benchmark_qt_widget.py @@ -2,6 +2,7 @@ from napari import Viewer from napari.layers import Image from napari.utils.events import EventedList + from napari_spatialdata._sdata_widgets import SdataWidget diff --git a/docs/notebooks/mibitof_analysis.ipynb b/docs/notebooks/mibitof_analysis.ipynb index 624faa71..d5981515 100644 --- a/docs/notebooks/mibitof_analysis.ipynb +++ b/docs/notebooks/mibitof_analysis.ipynb @@ -46,12 +46,11 @@ "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", - "from napari_spatialdata import Interactive\n", "from spatialdata import SpatialData\n", - "import squidpy as sq\n", - "import scanpy as sc\n", "\n", - "plt.rcParams['figure.figsize'] = (20, 20)" + "from napari_spatialdata import Interactive\n", + "\n", + "plt.rcParams[\"figure.figsize\"] = (20, 20)" ] }, { @@ -205,7 +204,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] }, { @@ -254,7 +253,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] }, { @@ -295,7 +294,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] }, { @@ -347,7 +346,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] }, { @@ -398,7 +397,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] }, { @@ -439,7 +438,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] }, { @@ -480,7 +479,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] } ], diff --git a/docs/notebooks/nanostring_analysis.ipynb b/docs/notebooks/nanostring_analysis.ipynb index 8b1ec202..3923b775 100644 --- a/docs/notebooks/nanostring_analysis.ipynb +++ b/docs/notebooks/nanostring_analysis.ipynb @@ -46,12 +46,12 @@ "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", - "from napari_spatialdata import Interactive\n", - "from spatialdata import SpatialData\n", - "import squidpy as sq\n", "import scanpy as sc\n", + "from spatialdata import SpatialData\n", + "\n", + "from napari_spatialdata import Interactive\n", "\n", - "plt.rcParams['figure.figsize'] = (20, 20)" + "plt.rcParams[\"figure.figsize\"] = (20, 20)" ] }, { @@ -278,7 +278,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] }, { @@ -327,7 +327,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] }, { @@ -390,7 +390,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] }, { @@ -431,7 +431,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] }, { @@ -483,7 +483,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] }, { @@ -572,7 +572,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] }, { @@ -613,7 +613,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] } ], diff --git a/docs/notebooks/scatterwidget.ipynb b/docs/notebooks/scatterwidget.ipynb index bbe8611a..d0205346 100644 --- a/docs/notebooks/scatterwidget.ipynb +++ b/docs/notebooks/scatterwidget.ipynb @@ -41,12 +41,13 @@ "cell_type": "code", "execution_count": 10, "metadata": {}, + "outputs": [], "source": [ - "from napari_spatialdata import QtAdataScatterWidget\n", + "import matplotlib.pyplot as plt\n", "import squidpy as sq\n", - "import matplotlib.pyplot as plt" - ], - "outputs": [] + "\n", + "from napari_spatialdata import QtAdataScatterWidget" + ] }, { "attachments": {}, @@ -60,10 +61,10 @@ "cell_type": "code", "execution_count": 11, "metadata": {}, + "outputs": [], "source": [ "adata = sq.datasets.visium_hne_adata()" - ], - "outputs": [] + ] }, { "attachments": {}, @@ -77,10 +78,10 @@ "cell_type": "code", "execution_count": 13, "metadata": {}, + "outputs": [], "source": [ "%gui qt5" - ], - "outputs": [] + ] }, { "attachments": {}, @@ -94,11 +95,11 @@ "cell_type": "code", "execution_count": 14, "metadata": {}, + "outputs": [], "source": [ "widget = QtAdataScatterWidget(adata)\n", "widget.show()" - ], - "outputs": [] + ] }, { "attachments": {}, @@ -112,11 +113,11 @@ "cell_type": "code", "execution_count": 15, "metadata": {}, + "outputs": [], "source": [ "plt.imshow(widget.screenshot())\n", - "plt.axis('off')" - ], - "outputs": [] + "plt.axis(\"off\")" + ] }, { "attachments": {}, @@ -130,21 +131,21 @@ "cell_type": "code", "execution_count": 19, "metadata": {}, + "outputs": [], "source": [ "plt.imshow(widget.screenshot())\n", - "plt.axis('off')" - ], - "outputs": [] + "plt.axis(\"off\")" + ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, + "outputs": [], "source": [ "plt.imshow(widget.screenshot())\n", - "plt.axis('off')" - ], - "outputs": [] + "plt.axis(\"off\")" + ] } ], "metadata": { diff --git a/docs/notebooks/spatialdata.ipynb b/docs/notebooks/spatialdata.ipynb index 540714a2..30dcbace 100644 --- a/docs/notebooks/spatialdata.ipynb +++ b/docs/notebooks/spatialdata.ipynb @@ -66,10 +66,11 @@ "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", - "from napari_spatialdata import Interactive\n", "from spatialdata import SpatialData\n", "\n", - "plt.rcParams['figure.figsize'] = (20, 20)\n", + "from napari_spatialdata import Interactive\n", + "\n", + "plt.rcParams[\"figure.figsize\"] = (20, 20)\n", "\n", "FILE_PATH = \"../../../data/cosmx/data.zarr\" # Change this\n", "sdata = SpatialData.read(FILE_PATH)" @@ -156,7 +157,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] }, { @@ -246,7 +247,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] }, { @@ -299,7 +300,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] }, { @@ -358,7 +359,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] }, { @@ -401,7 +402,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] }, { @@ -459,7 +460,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] }, { @@ -502,7 +503,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] }, { @@ -543,7 +544,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] }, { @@ -584,7 +585,7 @@ ], "source": [ "plt.imshow(interactive.screenshot())\n", - "plt.axis('off')" + "plt.axis(\"off\")" ] } ], diff --git a/examples/spatialdata_visium.py b/examples/spatialdata_visium.py index 52e97db8..1b0cb6af 100644 --- a/examples/spatialdata_visium.py +++ b/examples/spatialdata_visium.py @@ -3,9 +3,10 @@ # The dataset can be downloaded from https://spatialdata.scverse.org/en/latest/tutorials/notebooks/datasets/README.html -from napari_spatialdata import Interactive from spatialdata import SpatialData +from napari_spatialdata import Interactive + if __name__ == "__main__": sdata = SpatialData.read("../data/visium/data.zarr") # Change this path! i = Interactive(sdata) diff --git a/tests/conftest.py b/tests/conftest.py index de166f82..531d8d10 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,6 @@ from anndata import AnnData from loguru import logger from matplotlib.testing.compare import compare_images -from napari_spatialdata.utils._test_utils import save_image, take_screenshot from scipy import ndimage as ndi from skimage import data from spatialdata import SpatialData @@ -20,6 +19,8 @@ from spatialdata.datasets import blobs from spatialdata.models import TableModel +from napari_spatialdata.utils._test_utils import save_image, take_screenshot + HERE: Path = Path(__file__).parent SEED = 42 diff --git a/tests/test_cli.py b/tests/test_cli.py index 6f889ae5..617a122d 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -2,9 +2,10 @@ from click.testing import CliRunner from napari.viewer import Viewer -from napari_spatialdata.__main__ import cli from spatialdata.datasets import blobs +from napari_spatialdata.__main__ import cli + def test_view_exists(): runner = CliRunner() diff --git a/tests/test_interactive.py b/tests/test_interactive.py index 9a465d15..88ed8d18 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -1,7 +1,7 @@ -from napari_spatialdata._interactive import Interactive from spatialdata import SpatialData from spatialdata.models import Image2DModel +from napari_spatialdata._interactive import Interactive from tests.conftest import PlotTester, PlotTesterMeta diff --git a/tests/test_scatterwidgets.py b/tests/test_scatterwidgets.py index 1cbbff3e..0fe80652 100644 --- a/tests/test_scatterwidgets.py +++ b/tests/test_scatterwidgets.py @@ -2,6 +2,7 @@ import numpy as np import pandas as pd + from napari_spatialdata._model import DataModel from napari_spatialdata._scatterwidgets import MatplotlibWidget diff --git a/tests/test_spatialdata.py b/tests/test_spatialdata.py index ba17d5d1..18d68ce0 100644 --- a/tests/test_spatialdata.py +++ b/tests/test_spatialdata.py @@ -12,10 +12,6 @@ from multiscale_spatial_image import to_multiscale from napari.layers import Image, Labels, Points from napari.utils.events import EventedList -from napari_spatialdata import QtAdataViewWidget -from napari_spatialdata._sdata_widgets import CoordinateSystemWidget, ElementWidget, SdataWidget -from napari_spatialdata.constants import config -from napari_spatialdata.utils._test_utils import click_list_widget_item, get_center_pos_listitem from numpy import int64 from spatialdata import SpatialData, deepcopy from spatialdata._core.query.relational_query import get_element_instances @@ -25,6 +21,11 @@ from spatialdata.transformations.operations import set_transformation from xarray import DataArray +from napari_spatialdata import QtAdataViewWidget +from napari_spatialdata._sdata_widgets import CoordinateSystemWidget, ElementWidget, SdataWidget +from napari_spatialdata.constants import config +from napari_spatialdata.utils._test_utils import click_list_widget_item, get_center_pos_listitem + RNG = np.random.default_rng(seed=0) diff --git a/tests/test_utils.py b/tests/test_utils.py index d4f16f05..428aedc5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,6 +4,8 @@ import numpy as np import pytest from anndata import AnnData +from spatialdata.datasets import blobs + from napari_spatialdata.utils._utils import ( _adjust_channels_order, _get_categorical, @@ -13,7 +15,6 @@ _position_cluster_labels, _set_palette, ) -from spatialdata.datasets import blobs def test_get_categorical(adata_labels: AnnData): diff --git a/tests/test_view.py b/tests/test_view.py index 61b9f3d5..6d373d92 100644 --- a/tests/test_view.py +++ b/tests/test_view.py @@ -3,11 +3,12 @@ import numpy as np import pytest from napari.utils.events import EventedList +from spatialdata.datasets import blobs +from spatialdata.transformations import Affine, set_transformation + from napari_spatialdata._sdata_widgets import SdataWidget from napari_spatialdata._view import QtAdataViewWidget from napari_spatialdata.utils._test_utils import click_list_widget_item, get_center_pos_listitem -from spatialdata.datasets import blobs -from spatialdata.transformations import Affine, set_transformation @pytest.mark.parametrize("widget", [QtAdataViewWidget]) diff --git a/tests/test_viewer.py b/tests/test_viewer.py index d5405bc4..d66a5b7f 100644 --- a/tests/test_viewer.py +++ b/tests/test_viewer.py @@ -4,15 +4,16 @@ import numpy as np import pytest from napari.utils.events import EventedList -from napari_spatialdata import QtAdataViewWidget -from napari_spatialdata._sdata_widgets import SdataWidget -from napari_spatialdata.utils._test_utils import click_list_widget_item, get_center_pos_listitem -from napari_spatialdata.utils._utils import _get_transform from qtpy.QtCore import Qt from spatialdata.datasets import blobs from spatialdata.models import Image2DModel from spatialdata.transformations import Scale, Translation, set_transformation +from napari_spatialdata import QtAdataViewWidget +from napari_spatialdata._sdata_widgets import SdataWidget +from napari_spatialdata.utils._test_utils import click_list_widget_item, get_center_pos_listitem +from napari_spatialdata.utils._utils import _get_transform + sdata = blobs(extra_coord_system="space") diff --git a/tests/test_widgets.py b/tests/test_widgets.py index 157c9f81..82b2779d 100644 --- a/tests/test_widgets.py +++ b/tests/test_widgets.py @@ -7,11 +7,12 @@ from anndata.tests.helpers import assert_equal from napari.layers import Image, Labels from napari.utils.events import EventedList +from spatialdata import SpatialData +from spatialdata._types import ArrayLike + from napari_spatialdata._model import DataModel from napari_spatialdata._sdata_widgets import SdataWidget from napari_spatialdata._view import QtAdataScatterWidget, QtAdataViewWidget -from spatialdata import SpatialData -from spatialdata._types import ArrayLike # make_napari_viewer is a pytest fixture that returns a napari viewer object From 7d49c2388850b1aff7d2bac2ab1ef3f22c45b88b Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Thu, 3 Oct 2024 18:07:11 +0200 Subject: [PATCH 6/7] Add teardown steep for example benchamark (#311) Co-authored-by: Luca Marconato --- benchmarks/benchmark_qt_widget.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/benchmarks/benchmark_qt_widget.py b/benchmarks/benchmark_qt_widget.py index 684e23ff..b3c91fec 100644 --- a/benchmarks/benchmark_qt_widget.py +++ b/benchmarks/benchmark_qt_widget.py @@ -18,3 +18,6 @@ def time_create_widget(self) -> None: def time_layer_added(self) -> None: self.viewer.add_layer(self.image) + + def teardown(self) -> None: + self.viewer.close() From a38e2866c9195866fecf1eca0f911d52587b163a Mon Sep 17 00:00:00 2001 From: Minh Trinh <159905267+minhtien-trinh@users.noreply.github.com> Date: Wed, 9 Oct 2024 15:39:34 +0200 Subject: [PATCH 7/7] Add functions to interactive (#315) * Add functions to interactive In this commit two functions are added to the interactive class. get_layer grabs the layer name and returns the layer if it matches the user input. add_text_to_polygons adds annotations to the chosen polygons/shapes layer * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add tests for new functions in interactive * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Loosen tests for interactive * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Edit interactive polygon tests to use correct example dataset * Fix interactive test * Fix mixed line endings * trigger ci * Apply suggestions from code review improve readability Co-authored-by: Grzegorz Bokota * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Luca Marconato Co-authored-by: Grzegorz Bokota --- CHANGELOG.md | 7 +++++++ src/napari_spatialdata/_interactive.py | 25 ++++++++++++++++++++++ tests/test_interactive.py | 29 ++++++++++++++++++++++++++ 3 files changed, 61 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a5e6320..b2cb8710 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,13 @@ and this project adheres to [Semantic Versioning][]. [keep a changelog]: https://keepachangelog.com/en/1.0.0/ [semantic versioning]: https://semver.org/spec/v2.0.0.html +## [0.5.5] - 2024-10-07 + +### Added + +- New function to grab layer by name #315 @minhtien-trinh +- New annotation function to add text to polygons #315 @minhtien-trinh + ## [0.5.4] - 2024-xx-xx ### Fixed diff --git a/src/napari_spatialdata/_interactive.py b/src/napari_spatialdata/_interactive.py index e515bd61..91bd5a8d 100644 --- a/src/napari_spatialdata/_interactive.py +++ b/src/napari_spatialdata/_interactive.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any import napari +from napari.layers import Image, Labels, Points, Shapes from napari.utils.events import EventedList from spatialdata._types import ArrayLike @@ -100,3 +101,27 @@ def run(self) -> None: def screenshot(self) -> ArrayLike | Any: """Take a screenshot of the viewer in its current state.""" return self._viewer.screenshot(canvas_only=False) + + def get_layer(self, layer_name: str) -> Image | Labels | Points | Shapes | None: + """Get a layer by name.""" + try: + return self._viewer.layers[layer_name] + except KeyError: + return None + + def add_text_to_polygons(self, layer_name: str, text_annotations: list[str]) -> None: + """Add text annotations to a polygon layer.""" + polygon_layer = self.get_layer(layer_name) + if not polygon_layer: + raise ValueError(f"Polygon layer '{layer_name}' not found.") + if len(text_annotations) != len(polygon_layer.data): + raise ValueError( + f"The number of text annotations must match the number of polygons. " + f"Polygons: {len(polygon_layer.data)}, Text: {len(text_annotations)}." + ) + polygon_layer.text = { + "string": text_annotations, + "size": 10, + "color": "red", + "anchor": "center", + } diff --git a/tests/test_interactive.py b/tests/test_interactive.py index 88ed8d18..ecb74fd6 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -44,3 +44,32 @@ def test_plot_can_add_element_switch_cs(sdata_blobs: SpatialData): assert i._sdata_widget.coordinate_system_widget._system == "global" assert i._viewer.layers[-1].visible i._viewer.close() + + +class TestInteractive(PlotTester, metaclass=PlotTesterMeta): + def test_get_layer_existing(self, sdata_blobs: SpatialData): + i = Interactive(sdata=sdata_blobs, headless=True) + i.add_element(element="blobs_image", element_coordinate_system="global") + layer = i.get_layer("blobs_image") + assert layer is not None, "Expected to retrieve the blobs_image layer, but got None" + assert layer.name == "blobs_image", f"Expected layer name 'blobs_image', got {layer.name}" + i._viewer.close() + + def test_get_layer_non_existing(self, sdata_blobs: SpatialData): + i = Interactive(sdata=sdata_blobs, headless=True) + layer = i.get_layer("non_existing_layer") + assert layer is None, "Expected None for a non-existing layer, but got a layer" + i._viewer.close() + + def test_add_text_to_polygons(self, sdata_blobs: SpatialData): + i = Interactive(sdata=sdata_blobs, headless=True) + i.add_element(element="blobs_polygons", element_coordinate_system="global") + + # Mock polygon layer with some polygon data + text_annotations = ["Label 1", "Label 2", "Label 3", "Label 4", "Label 5"] + polygon_layer = i.get_layer("blobs_polygons") + + # Verify that text is added + i.add_text_to_polygons(layer_name="blobs_polygons", text_annotations=text_annotations) + assert polygon_layer.text is not None, "Text annotations were not added to the polygon layer" + i._viewer.close()