Skip to content

Commit

Permalink
Added decorator to show way to process data (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus authored Sep 20, 2024
1 parent 5f9611f commit c40fb69
Show file tree
Hide file tree
Showing 7 changed files with 219 additions and 45 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ repos:
additional_dependencies: ["tomli"]
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.6.5
rev: v0.6.6
hooks:
# Run the linter.
- id: ruff
Expand Down
3 changes: 2 additions & 1 deletion albucore/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
__version__ = "0.0.16"
__version__ = "0.0.17"

from .decorators import *
from .functions import *
from .utils import *
51 changes: 51 additions & 0 deletions albucore/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import sys
from functools import wraps
from typing import Callable

import numpy as np

from albucore.utils import MONO_CHANNEL_DIMENSIONS, NUM_MULTI_CHANNEL_DIMENSIONS, P

if sys.version_info >= (3, 10):
from typing import Concatenate
else:
from typing_extensions import Concatenate


def contiguous(
func: Callable[Concatenate[np.ndarray, P], np.ndarray],
) -> Callable[Concatenate[np.ndarray, P], np.ndarray]:
"""Ensure that input img is contiguous and the output array is also contiguous."""

@wraps(func)
def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray:
# Ensure the input array is contiguous
img = np.require(img, requirements=["C_CONTIGUOUS"])
# Call the original function with the contiguous input
result = func(img, *args, **kwargs)
# Ensure the output array is contiguous
if not result.flags["C_CONTIGUOUS"]:
return np.require(result, requirements=["C_CONTIGUOUS"])

return result

return wrapped_function


def preserve_channel_dim(
func: Callable[Concatenate[np.ndarray, P], np.ndarray],
) -> Callable[Concatenate[np.ndarray, P], np.ndarray]:
"""Preserve dummy channel dim."""

@wraps(func)
def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray:
shape = img.shape
result = func(img, *args, **kwargs)
if len(shape) == NUM_MULTI_CHANNEL_DIMENSIONS and shape[-1] == 1 and result.ndim == MONO_CHANNEL_DIMENSIONS:
return np.expand_dims(result, axis=-1)

if len(shape) == MONO_CHANNEL_DIMENSIONS and result.ndim == NUM_MULTI_CHANNEL_DIMENSIONS:
return result[:, :, 0]
return result

return wrapped_function
80 changes: 77 additions & 3 deletions albucore/functions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

from typing import Literal
from functools import wraps
from typing import Any, Callable, Literal

import cv2
import numpy as np

from albucore.decorators import contiguous, preserve_channel_dim
from albucore.utils import (
MAX_OPENCV_WORKING_CHANNELS,
MAX_VALUES_BY_DTYPE,
Expand All @@ -13,11 +15,9 @@
ValueType,
clip,
clipped,
contiguous,
convert_value,
get_max_value,
get_num_channels,
preserve_channel_dim,
)

np_operations = {"multiply": np.multiply, "add": np.add, "power": np.power}
Expand Down Expand Up @@ -570,6 +570,10 @@ def to_float_lut(img: np.ndarray, max_value: float | None = None) -> np.ndarray:


def to_float(img: np.ndarray, max_value: float | None = None) -> np.ndarray:
if img.dtype == np.float64:
return img.astype(np.float32)
if img.dtype == np.float32:
return img
if img.dtype == np.uint8:
return to_float_lut(img, max_value)
return to_float_numpy(img, max_value)
Expand Down Expand Up @@ -620,6 +624,12 @@ def from_float(img: np.ndarray, target_dtype: np.dtype, max_value: float | None
- For other input types, it falls back to a numpy-based implementation.
- The function clips values to ensure they fit within the range of the target data type.
"""
if target_dtype == np.float32:
return img

if target_dtype == np.float64:
return img.astype(np.float32)

if img.dtype == np.float32:
return from_float_opencv(img, target_dtype, max_value)

Expand Down Expand Up @@ -652,3 +662,67 @@ def vflip_numpy(img: np.ndarray) -> np.ndarray:

def vflip(img: np.ndarray) -> np.ndarray:
return vflip_cv2(img)


def float32_io(func: Callable[..., np.ndarray]) -> Callable[..., np.ndarray]:
"""Decorator to ensure float32 input/output for image processing functions.
This decorator converts the input image to float32 before passing it to the wrapped function,
and then converts the result back to the original dtype if it wasn't float32.
Args:
func (Callable[..., np.ndarray]): The image processing function to be wrapped.
Returns:
Callable[..., np.ndarray]: A wrapped function that handles float32 conversion.
Example:
@float32_io
def some_image_function(img: np.ndarray) -> np.ndarray:
# Function implementation
return processed_img
"""

@wraps(func)
def float32_wrapper(img: np.ndarray, *args: Any, **kwargs: Any) -> np.ndarray:
input_dtype = img.dtype
if input_dtype != np.float32:
img = to_float(img)
result = func(img, *args, **kwargs)

return from_float(result, target_dtype=input_dtype) if input_dtype != np.float32 else result

return float32_wrapper


def uint8_io(func: Callable[..., np.ndarray]) -> Callable[..., np.ndarray]:
"""Decorator to ensure uint8 input/output for image processing functions.
This decorator converts the input image to uint8 before passing it to the wrapped function,
and then converts the result back to the original dtype if it wasn't uint8.
Args:
func (Callable[..., np.ndarray]): The image processing function to be wrapped.
Returns:
Callable[..., np.ndarray]: A wrapped function that handles uint8 conversion.
Example:
@uint8_io
def some_image_function(img: np.ndarray) -> np.ndarray:
# Function implementation
return processed_img
"""

@wraps(func)
def uint8_wrapper(img: np.ndarray, *args: Any, **kwargs: Any) -> np.ndarray:
input_dtype = img.dtype

if input_dtype != np.uint8:
img = from_float(img, target_dtype=np.uint8)

result = func(img, *args, **kwargs)

return to_float(result) if input_dtype != np.uint8 else result

return uint8_wrapper
39 changes: 0 additions & 39 deletions albucore/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,25 +109,6 @@ def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.n
return wrapped_function


def preserve_channel_dim(
func: Callable[Concatenate[np.ndarray, P], np.ndarray],
) -> Callable[Concatenate[np.ndarray, P], np.ndarray]:
"""Preserve dummy channel dim."""

@wraps(func)
def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray:
shape = img.shape
result = func(img, *args, **kwargs)
if len(shape) == NUM_MULTI_CHANNEL_DIMENSIONS and shape[-1] == 1 and result.ndim == MONO_CHANNEL_DIMENSIONS:
return np.expand_dims(result, axis=-1)

if len(shape) == MONO_CHANNEL_DIMENSIONS and result.ndim == NUM_MULTI_CHANNEL_DIMENSIONS:
return result[:, :, 0]
return result

return wrapped_function


def get_num_channels(image: np.ndarray) -> int:
return image.shape[2] if image.ndim == NUM_MULTI_CHANNEL_DIMENSIONS else 1

Expand All @@ -151,26 +132,6 @@ def is_multispectral_image(image: np.ndarray) -> bool:
return num_channels not in {1, 3}


def contiguous(
func: Callable[Concatenate[np.ndarray, P], np.ndarray],
) -> Callable[Concatenate[np.ndarray, P], np.ndarray]:
"""Ensure that input img is contiguous and the output array is also contiguous."""

@wraps(func)
def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray:
# Ensure the input array is contiguous
img = np.require(img, requirements=["C_CONTIGUOUS"])
# Call the original function with the contiguous input
result = func(img, *args, **kwargs)
# Ensure the output array is contiguous
if not result.flags["C_CONTIGUOUS"]:
return np.require(result, requirements=["C_CONTIGUOUS"])

return result

return wrapped_function


def convert_value(value: np.ndarray | float, num_channels: int) -> float | np.ndarray:
"""Convert a multiplier to a float / int or a numpy array.
Expand Down
20 changes: 20 additions & 0 deletions tests/test_to_from_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,23 @@ def test_from_float_opencv_input_unchanged(dtype, channels):
img_copy = img.copy()
_ = from_float_opencv(img, dtype, max_value)
np.testing.assert_array_equal(img, img_copy)


def test_to_float_returns_same_object_for_float32():
float32_image = np.random.rand(10, 10, 3).astype(np.float32)
result = to_float(float32_image)
assert result is float32_image # Check if it's the same object


@pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.float32])
def test_to_float_from_float_roundtrip(dtype):
if dtype == np.float32:
original = np.random.rand(10, 10, 3).astype(dtype)
else:
original = np.random.randint(0, 256, (10, 10, 3)).astype(dtype)

float_version = to_float(original)
roundtrip = from_float(float_version, dtype)

assert roundtrip.dtype == dtype
np.testing.assert_allclose(original, roundtrip, rtol=1e-5, atol=1e-8)
69 changes: 68 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numpy as np
import pytest
import cv2
from albucore.utils import NPDTYPE_TO_OPENCV_DTYPE, clip, convert_value, get_opencv_dtype_from_numpy, contiguous
from albucore.decorators import contiguous
from albucore.functions import float32_io, from_float, to_float, uint8_io
from albucore.utils import NPDTYPE_TO_OPENCV_DTYPE, clip, convert_value, get_opencv_dtype_from_numpy


@pytest.mark.parametrize("input_img, dtype, expected", [
Expand Down Expand Up @@ -88,3 +90,68 @@ def test_contiguous_decorator(input_array):
# Check if the content is correct (same as reversing the original array)
expected_output = input_array[::-1, ::-1]
np.testing.assert_array_equal(output_array, expected_output), "Output array content is not as expected"


# Sample functions to be wrapped
@float32_io
def dummy_float32_func(img):
return img * 2

@uint8_io
def dummy_uint8_func(img):
return np.clip(img + 10, 0, 255).astype(np.uint8)

# Test data
@pytest.fixture(params=[
np.uint8, np.float32
])
def test_image(request):
dtype = request.param
if np.issubdtype(dtype, np.integer):
return np.random.randint(0, 256, (10, 10, 3), dtype=dtype)
else:
return np.random.rand(10, 10, 3).astype(dtype)

# Tests
@pytest.mark.parametrize("wrapper,func, image", [
(float32_io, dummy_float32_func, np.random.randint(0, 256, (10, 10, 3), dtype=np.uint8)),
(uint8_io, dummy_uint8_func, np.random.rand(10, 10, 3).astype(np.float32))
])
def test_io_wrapper(wrapper, func, image):
input_dtype = image.dtype
result = func(image)

# Check if the output dtype matches the input dtype
assert result.dtype == input_dtype

# Check if the function was actually applied
if wrapper == float32_io:
expected = from_float(to_float(image) * 2, input_dtype)
else: # uint8_io
expected = to_float(from_float(image, np.uint8) + 10)

np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5)

@pytest.mark.parametrize("wrapper,func,expected_intermediate_dtype", [
(float32_io, dummy_float32_func, np.float32),
(uint8_io, dummy_uint8_func, np.uint8)
])
def test_intermediate_dtype(wrapper, func, expected_intermediate_dtype, test_image):
original_func = func.__wrapped__ # Access the original function

def check_dtype(img):
assert img.dtype == expected_intermediate_dtype
return original_func(img)

wrapped_func = wrapper(check_dtype)
wrapped_func(test_image) # This will raise an assertion error if the intermediate dtype is incorrect

def test_float32_io_preserves_float32(test_image):
if test_image.dtype == np.float32:
result = dummy_float32_func(test_image)
assert result.dtype == np.float32

def test_uint8_io_preserves_uint8(test_image):
if test_image.dtype == np.uint8:
result = dummy_uint8_func(test_image)
assert result.dtype == np.uint8

0 comments on commit c40fb69

Please sign in to comment.