Skip to content

Commit

Permalink
Merge pull request #17 from stardist/tests
Browse files Browse the repository at this point in the history
Refactor code and add tests
  • Loading branch information
uschmidt83 authored Apr 8, 2022
2 parents 9736d7e + 52898f4 commit 5846abd
Show file tree
Hide file tree
Showing 12 changed files with 356 additions and 72 deletions.
44 changes: 44 additions & 0 deletions .github/actions/test/action.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
name: 'Run tests'

inputs:
python-version:
required: true
install-packages:
required: true
default: "'.[test]'"

runs:
using: "composite"
steps:
- name: Set up Python ${{ inputs.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ inputs.python-version }}

# these libraries enable testing on qt on linux
- uses: tlambert03/setup-qt-libs@v1

# strategy borrowed from vispy for installing opengl libs on windows
- name: Install Windows OpenGL
if: runner.os == 'Windows'
shell: pwsh
run: |
git clone --depth 1 https://github.com/pyvista/gl-ci-helpers.git
powershell gl-ci-helpers/appveyor/install_opengl.ps1
- name: Install package
shell: bash
run: |
python -m pip install --upgrade pip wheel setuptools
python -m pip install ${{ inputs.install-packages }}
- name: Find test directory
shell: bash
run: |
python -W ignore -c "import stardist_napari; print(f'pytest_dir={stardist_napari.__path__[0]}')" >> $GITHUB_ENV
# run tests inside the installed stardist_napari package
- name: Test with pytest
uses: GabrielBB/xvfb-action@v1
with:
run: python -m pytest -v --color=yes --durations=0 ${{ env.pytest_dir }}
24 changes: 24 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
name: Test

on:
push:
pull_request:

jobs:
test:
name: ${{ matrix.platform }} py${{ matrix.python-version }}
runs-on: ${{ matrix.platform }}
strategy:
fail-fast: false
matrix:
platform: [ubuntu-latest, windows-latest, macos-latest]
python-version: ['3.7', '3.8', '3.9', '3.10']
exclude:
# TODO: no stardist wheels yet
- python-version: '3.10'

steps:
- uses: actions/checkout@v2
- uses: ./.github/actions/test
with:
python-version: ${{ matrix.python-version }}
25 changes: 25 additions & 0 deletions .github/workflows/test_pypi.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
name: Test (PyPI)

on:
schedule:
- cron: "0 18 * * *"

jobs:
test:
name: ${{ matrix.platform }} py${{ matrix.python-version }}
runs-on: ${{ matrix.platform }}
strategy:
fail-fast: false
matrix:
platform: [ubuntu-latest, windows-latest, macos-latest]
python-version: ['3.7', '3.8', '3.9', '3.10']
exclude:
# TODO: no stardist wheels yet
- python-version: '3.10'

steps:
- uses: actions/checkout@v2
- uses: ./.github/actions/test
with:
python-version: ${{ matrix.python-version }}
install-packages: "'stardist-napari[test]'"
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 22.1.0
rev: 22.3.0
hooks:
- id: black
language_version: python3.9
Expand Down
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Framework :: napari",
],
install_requires=[
Expand All @@ -50,4 +51,7 @@
"napari>=0.4.13",
"magicgui>=0.4.0",
],
extras_require={
"test": ["pytest", "pytest-qt", "napari[pyqt]>=0.4.13"],
},
)
13 changes: 13 additions & 0 deletions stardist_napari/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,14 @@
import os

DEBUG = os.environ.get("STARDIST_NAPARI_DEBUG", "").lower() in (
"y",
"yes",
"t",
"true",
"on",
"1",
)
del os

from ._dock_widget import plugin_wrapper as make_dock_widget
from ._version import __version__
164 changes: 94 additions & 70 deletions stardist_napari/_dock_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@

import napari
import numpy as np
from csbdeep.utils import (
_raise,
axes_check_and_normalize,
axes_dict,
load_json,
normalize,
)
from magicgui import magicgui
from magicgui import widgets as mw
from magicgui.application import use_app
Expand All @@ -28,6 +35,74 @@
from psygnal import Signal
from qtpy.QtWidgets import QSizePolicy

from . import DEBUG

# -------------------------------------------------------------------------

CUSTOM_MODEL = "CUSTOM_MODEL"


class Output(Enum):
Labels = "Label Image"
Polys = "Polygons / Polyhedra"
Both = "Both"


output_choices = [Output.Labels.value, Output.Polys.value, Output.Both.value]


class TimelapseLabels(Enum):
Match = "Match to previous frame (via overlap)"
Unique = "Unique through time"
Separate = "Separate per frame (no processing)"


timelapse_opts = [
TimelapseLabels.Match.value,
TimelapseLabels.Unique.value,
TimelapseLabels.Separate.value,
]


# -------------------------------------------------------------------------


def get_model_config_and_thresholds(path):
config = load_json(str(path / "config.json"))
thresholds = None
try:
# not all models have associated thresholds
thresholds = load_json(str(path / "thresholds.json"))
except FileNotFoundError:
pass
return config, thresholds


def get_data(image):
image = image.data[0] if image.multiscale else image.data
# enforce dense numpy array in case we are given a dask array etc
return np.asarray(image)


def change_handler(*widgets, init=True, debug=DEBUG):
def decorator_change_handler(handler):
@functools.wraps(handler)
def wrapper(*args):
source = Signal.sender()
emitter = Signal.current_emitter()
if debug:
# print(f"{emitter}: {source} = {args!r}")
print(f"{str(emitter.name).upper()}: {source.name} = {args!r}")
return handler(*args)

for widget in widgets:
widget.changed.connect(wrapper)
if init:
widget.changed(widget.value)
return wrapper

return decorator_change_handler


def surface_from_polys(polys):
from stardist.geometry import dist_to_coord3D
Expand All @@ -51,54 +126,17 @@ def surface_from_polys(polys):
return [np.array(vertices), np.array(faces), np.array(values)]


# -------------------------------------------------------------------------


def plugin_wrapper():
# delay imports until plugin is requested by user
# -> especially those importing tensorflow (csbdeep.models.*, stardist.models.*)
from csbdeep.models.pretrained import get_model_folder, get_registered_models
from csbdeep.utils import (
_raise,
axes_check_and_normalize,
axes_dict,
load_json,
normalize,
)
from stardist.matching import group_matching_labels
from stardist.models import StarDist2D, StarDist3D
from stardist.utils import abspath

DEBUG = os.environ.get("STARDIST_NAPARI_DEBUG", "").lower() in (
"y",
"yes",
"t",
"true",
"on",
"1",
)

def get_data(image):
image = image.data[0] if image.multiscale else image.data
# enforce dense numpy array in case we are given a dask array etc
return np.asarray(image)

def change_handler(*widgets, init=True, debug=DEBUG):
def decorator_change_handler(handler):
@functools.wraps(handler)
def wrapper(*args):
source = Signal.sender()
emitter = Signal.current_emitter()
if debug:
# print(f"{emitter}: {source} = {args!r}")
print(f"{str(emitter.name).upper()}: {source.name} = {args!r}")
return handler(*args)

for widget in widgets:
widget.changed.connect(wrapper)
if init:
widget.changed(widget.value)
return wrapper

return decorator_change_handler

# -------------------------------------------------------------------------

_models, _aliases = {}, {}
Expand All @@ -116,7 +154,6 @@ def wrapper(*args):
model_threshs = dict()
model_selected = None

CUSTOM_MODEL = "CUSTOM_MODEL"
model_type_choices = [
("2D", StarDist2D),
("3D", StarDist3D),
Expand All @@ -128,34 +165,16 @@ def get_model(model_type, model):
if model_type == CUSTOM_MODEL:
path = Path(model)
path.is_dir() or _raise(FileNotFoundError(f"{path} is not a directory"))
config = model_configs[(model_type, model)]
config = model_configs.get(
(model_type, model), get_model_config_and_thresholds(path)[0]
)
model_class = StarDist2D if config["n_dim"] == 2 else StarDist3D
return model_class(None, name=path.name, basedir=str(path.parent))
else:
return model_type.from_pretrained(model)

# -------------------------------------------------------------------------

class Output(Enum):
Labels = "Label Image"
Polys = "Polygons / Polyhedra"
Both = "Both"

output_choices = [Output.Labels.value, Output.Polys.value, Output.Both.value]

class TimelapseLabels(Enum):
Match = "Match to previous frame (via overlap)"
Unique = "Unique through time"
Separate = "Separate per frame (no processing)"

timelapse_opts = [
TimelapseLabels.Match.value,
TimelapseLabels.Unique.value,
TimelapseLabels.Separate.value,
]

# -------------------------------------------------------------------------

DEFAULTS = dict(
model_type=StarDist2D,
model2d=models_reg[StarDist2D][0][1],
Expand Down Expand Up @@ -315,7 +334,12 @@ def plugin(
progress_bar: mw.ProgressBar,
) -> List[napari.types.LayerDataTuple]:

model = get_model(*model_selected)
model = get_model(
model_type,
{StarDist2D: model2d, StarDist3D: model3d, CUSTOM_MODEL: model_folder}[
model_type
],
)
if model._is_multiclass():
warn("multi-class mode not supported yet, ignoring classification output")

Expand Down Expand Up @@ -571,8 +595,10 @@ def progress(it, **kwargs):
)
)
if output_type in (Output.Polys.value, Output.Both.value):
n_objects = len(polys["points"])
if isinstance(model, StarDist3D):
if "T" in axes:
raise NotImplementedError("Polyhedra output for 3D timelapse")
n_objects = len(polys["points"])
surface = surface_from_polys(polys)
layers.append(
(
Expand Down Expand Up @@ -1008,12 +1034,10 @@ def _get_model_folder():

def _process_model_folder(path):
try:
model_configs[key] = load_json(str(path / "config.json"))
try:
# not all models have associated thresholds
model_threshs[key] = load_json(str(path / "thresholds.json"))
except FileNotFoundError:
pass
_config, _thresholds = get_model_config_and_thresholds(path)
model_configs[key] = _config
if _thresholds is not None:
model_threshs[key] = _thresholds
finally:
select_model(key)
plugin.progress_bar.hide()
Expand Down
Empty file.
Loading

0 comments on commit 5846abd

Please sign in to comment.