Skip to content

Commit

Permalink
Cleanup & add tests for CSV loading
Browse files Browse the repository at this point in the history
  • Loading branch information
jluethi committed Jul 25, 2024
1 parent 263779b commit b728bdb
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 50 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,7 @@ venv/

# written by setuptools_scm
**/_version.py

# Files produced during testing
*.csv
*.clf
2 changes: 2 additions & 0 deletions example/feature_visualization_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
d = {
"test": [-100, 200, 300, 500, 900, 300],
"label": [1, 2, 3, 4, 5, 6],
"index": [1, 2, 3, 4, 5, 6],
"feature1": [100, 200, 300, 500, 900, 1001],
"feature2": [2200, 2100, 2000, 1500, 1300, 1001],
}
df = pd.DataFrame(data=d)
# df.to_csv("example_data.csv", index=False)

viewer = napari.Viewer()
label_layer = viewer.add_labels(lbl_img_np)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"scikit-image",
"matplotlib",
"pandas",
"packaging",
]

[project.optional-dependencies]
Expand Down
103 changes: 71 additions & 32 deletions src/napari_feature_visualization/_tests/test_feature_vis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import importlib

import napari
import numpy as np
import pandas as pd
import pytest
from packaging import version

from napari_feature_visualization.feature_vis import feature_vis

Expand All @@ -18,55 +19,93 @@ def create_label_img():
return lbl_img_np


def test_feature_vis_widget(make_napari_viewer):
lbl_img = create_label_img()

# Dummy df for this test
def create_feature_df():
d = {
"test": [-100, 200, 300, 500, 900, 300],
"label": [1, 2, 3, 4, 5, 6],
"index": [1, 2, 3, 4, 5, 6],
"feature1": [100, 200, 300, 500, 900, 1001],
"feature2": [2200, 2100, 2000, 1500, 1300, 1001],
}
df = pd.DataFrame(data=d)
# Ensure index is the labels for correct matching
return df


@pytest.mark.parametrize(
"load_features_from", ["CSV File", "Layer Properties"]
)
def test_feature_vis_widget(make_napari_viewer, load_features_from):
lbl_img = create_label_img()
df = create_feature_df()
viewer = make_napari_viewer()
label_layer = viewer.add_labels(lbl_img)
label_layer.features = df

feature_vis_widget = feature_vis()

# FIXME: It appears feature test is used. How do I test other feature
# selection? Setting it in the function appears to have no effect.

# if we "call" this object, it'll execute our function
feature_vis_widget(
label_layer=label_layer,
load_features_from="Layer Properties",
feature="feature1",
)

# Test differently depending on napari version, as colormap class has
# changed
print(
importlib.util.find_spec("napari.utils.colormaps.DirectLabelColormap")
)
colormaps_module = importlib.import_module("napari.utils.colormaps")
DirectLabelColormap = getattr(
colormaps_module, "DirectLabelColormap", None
)
if DirectLabelColormap is not None:
# napari >= 0.4.19 tests
from napari.utils.colormaps import DirectLabelColormap
if load_features_from == "CSV File":
df.to_csv("example_data.csv", index=False)
# if we "call" this object, it'll execute our function
feature_vis_widget(
label_layer=label_layer,
load_features_from=load_features_from,
DataFrame="example_data.csv",
label_column="label",
feature="feature1",
)
elif load_features_from == "Layer Properties":
label_layer.features = df
# if we "call" this object, it'll execute our function
feature_vis_widget(
label_layer=label_layer,
load_features_from=load_features_from,
feature="feature1",
label_column="label",
)

napari_version = version.parse(napari.__version__)
if napari_version >= version.parse("0.4.19"):
assert len(label_layer.colormap.color_dict) == 8
np.testing.assert_array_almost_equal(
label_layer.colormap.color_dict[3],
np.array([0.229739, 0.322361, 0.545706, 1.0]),
)
else:
# napari < 0.4.19 test
assert len(label_layer.colormap.colors) == 6
assert len(label_layer.colormap.colors) == 7
np.testing.assert_array_almost_equal(
label_layer.colormap.colors[2],
label_layer.colormap.colors[3],
np.array([0.229739, 0.322361, 0.545706, 1.0]),
)


# def test_feature_vis_from_csv(make_napari_viewer):
# lbl_img = create_label_img()
# df = create_feature_df()

# viewer = make_napari_viewer()
# label_layer = viewer.add_labels(lbl_img)
# label_layer.features = df

# feature_vis_widget = feature_vis()

# # if we "call" this object, it'll execute our function
# feature_vis_widget(
# label_layer=label_layer,
# load_features_from="CSV File",
# DataFrame="example_data.csv",
# feature="feature1",
# )

# napari_version = version.parse(napari.__version__)
# if napari_version >= version.parse("0.4.19"):
# assert len(label_layer.colormap.color_dict) == 8
# np.testing.assert_array_almost_equal(
# label_layer.colormap.color_dict[3],
# np.array([0.229739, 0.322361, 0.545706, 1.0]),
# )
# else:
# assert len(label_layer.colormap.colors) == 7
# np.testing.assert_array_almost_equal(
# label_layer.colormap.colors[3],
# np.array([0.229739, 0.322361, 0.545706, 1.0]),
# )
50 changes: 32 additions & 18 deletions src/napari_feature_visualization/feature_vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,26 @@

import pathlib

import matplotlib.pyplot as plt
import matplotlib
import napari
import numpy as np
import pandas as pd
from magicgui import magic_factory
from packaging import version

from napari_feature_visualization.utils import ColormapChoices, get_df


def check_default_label_column(df):
if "label" in df:
return "label"
elif "Label" in df:
return "Label"
elif "index" in df:
return "index"
return ""


def _init(widget):
def get_feature_choices(*args):
if widget.load_features_from.value == "CSV File":
Expand Down Expand Up @@ -41,12 +52,7 @@ def update_df_columns(event):
widget.feature.reset_choices()
widget.label_column.reset_choices()
features = widget.feature.choices
if "label" in features:
widget.label_column.value = "label"
elif "Label" in features:
widget.label_column.value = "Label"
elif "index" in features:
widget.label_column.value = "index"
widget.label_column.value = check_default_label_column(features)

# if load_features_from is toggled, make the widget.DataFrame disappear
if widget.load_features_from.value == "Layer Properties":
Expand Down Expand Up @@ -109,15 +115,18 @@ def feature_vis(
site_df = get_df(DataFrame)
else:
site_df = pd.DataFrame(label_layer.properties)
label_column = "label"

if label_column == "":
label_column = check_default_label_column(site_df)

site_df.loc[:, "label"] = site_df[str(label_column)].astype(int)
# Check that there is one unique label for every entry in the dataframe
# => It's a site dataframe, not one containing many different sites
# TODO: How to feedback this issue to the user?
assert len(site_df["label"].unique()) == len(
site_df
), "A feature dataframe with non-unique labels was provided. The visualize_feature_on_label_layer function is not designed for this."
if len(site_df["label"].unique()) != len(site_df):
raise ValueError(
"A feature dataframe with non-unique labels was provided. The "
"visualize_feature_on_label_layer function is not designed for "
"this."
)
# Rescale feature between 0 & 1 to make a colormap
site_df["feature_scaled"] = (site_df[feature] - lower_contrast_limit) / (
upper_contrast_limit - lower_contrast_limit
Expand All @@ -126,7 +135,9 @@ def feature_vis(
site_df.loc[site_df["feature_scaled"] < 0, "feature_scaled"] = 0
site_df.loc[site_df["feature_scaled"] > 1, "feature_scaled"] = 1

colors = plt.cm.get_cmap(Colormap.value)(site_df["feature_scaled"])
colors = matplotlib.colormaps.get_cmap(Colormap.value)(
site_df["feature_scaled"]
)

# Create an array where the index is the label value and the value is
# the feature value
Expand All @@ -135,13 +146,16 @@ def feature_vis(
label_properties = {feature: np.round(properties_array, decimals=2)}

colormap = dict(zip(site_df["label"], colors))
colormap[None] = [0.0, 0.0, 0.0, 0.0]
# If in napari >= 0.4.19, use DirectLabelColormap
try:
# Show missing objects as black
colormap[None] = [0.0, 0.0, 0.0, 1.0]

# Handle different colormap APIs depending on the napari version
napari_version = version.parse(napari.__version__)
if napari_version >= version.parse("0.4.19"):
from napari.utils.colormaps import DirectLabelColormap

label_layer.colormap = DirectLabelColormap(color_dict=colormap)
except ImportError:
else:
label_layer.color = colormap

if load_features_from == "CSV File":
Expand Down

0 comments on commit b728bdb

Please sign in to comment.