Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added decorator to show way to process data #31

Merged
merged 1 commit into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ repos:
additional_dependencies: ["tomli"]
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.6.5
rev: v0.6.6
hooks:
# Run the linter.
- id: ruff
Expand Down
3 changes: 2 additions & 1 deletion albucore/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
__version__ = "0.0.16"
__version__ = "0.0.17"

from .decorators import *
from .functions import *
from .utils import *
51 changes: 51 additions & 0 deletions albucore/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import sys
from functools import wraps
from typing import Callable

import numpy as np

from albucore.utils import MONO_CHANNEL_DIMENSIONS, NUM_MULTI_CHANNEL_DIMENSIONS, P

if sys.version_info >= (3, 10):
from typing import Concatenate
else:
from typing_extensions import Concatenate


def contiguous(
func: Callable[Concatenate[np.ndarray, P], np.ndarray],
) -> Callable[Concatenate[np.ndarray, P], np.ndarray]:
"""Ensure that input img is contiguous and the output array is also contiguous."""

@wraps(func)
def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray:
# Ensure the input array is contiguous
img = np.require(img, requirements=["C_CONTIGUOUS"])
# Call the original function with the contiguous input
result = func(img, *args, **kwargs)
# Ensure the output array is contiguous
if not result.flags["C_CONTIGUOUS"]:
return np.require(result, requirements=["C_CONTIGUOUS"])

return result

return wrapped_function


def preserve_channel_dim(
func: Callable[Concatenate[np.ndarray, P], np.ndarray],
) -> Callable[Concatenate[np.ndarray, P], np.ndarray]:
"""Preserve dummy channel dim."""

@wraps(func)
def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray:
shape = img.shape
result = func(img, *args, **kwargs)
if len(shape) == NUM_MULTI_CHANNEL_DIMENSIONS and shape[-1] == 1 and result.ndim == MONO_CHANNEL_DIMENSIONS:
return np.expand_dims(result, axis=-1)

if len(shape) == MONO_CHANNEL_DIMENSIONS and result.ndim == NUM_MULTI_CHANNEL_DIMENSIONS:
return result[:, :, 0]
return result

return wrapped_function
80 changes: 77 additions & 3 deletions albucore/functions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

from typing import Literal
from functools import wraps
from typing import Any, Callable, Literal

import cv2
import numpy as np

from albucore.decorators import contiguous, preserve_channel_dim
from albucore.utils import (
MAX_OPENCV_WORKING_CHANNELS,
MAX_VALUES_BY_DTYPE,
Expand All @@ -13,11 +15,9 @@
ValueType,
clip,
clipped,
contiguous,
convert_value,
get_max_value,
get_num_channels,
preserve_channel_dim,
)

np_operations = {"multiply": np.multiply, "add": np.add, "power": np.power}
Expand Down Expand Up @@ -570,6 +570,10 @@ def to_float_lut(img: np.ndarray, max_value: float | None = None) -> np.ndarray:


def to_float(img: np.ndarray, max_value: float | None = None) -> np.ndarray:
if img.dtype == np.float64:
return img.astype(np.float32)
if img.dtype == np.float32:
return img
if img.dtype == np.uint8:
return to_float_lut(img, max_value)
return to_float_numpy(img, max_value)
Expand Down Expand Up @@ -620,6 +624,12 @@ def from_float(img: np.ndarray, target_dtype: np.dtype, max_value: float | None
- For other input types, it falls back to a numpy-based implementation.
- The function clips values to ensure they fit within the range of the target data type.
"""
if target_dtype == np.float32:
return img

if target_dtype == np.float64:
return img.astype(np.float32)

if img.dtype == np.float32:
return from_float_opencv(img, target_dtype, max_value)

Expand Down Expand Up @@ -652,3 +662,67 @@ def vflip_numpy(img: np.ndarray) -> np.ndarray:

def vflip(img: np.ndarray) -> np.ndarray:
return vflip_cv2(img)


def float32_io(func: Callable[..., np.ndarray]) -> Callable[..., np.ndarray]:
ternaus marked this conversation as resolved.
Show resolved Hide resolved
"""Decorator to ensure float32 input/output for image processing functions.

This decorator converts the input image to float32 before passing it to the wrapped function,
and then converts the result back to the original dtype if it wasn't float32.

Args:
func (Callable[..., np.ndarray]): The image processing function to be wrapped.

Returns:
Callable[..., np.ndarray]: A wrapped function that handles float32 conversion.

Example:
@float32_io
def some_image_function(img: np.ndarray) -> np.ndarray:
# Function implementation
return processed_img
"""

@wraps(func)
def float32_wrapper(img: np.ndarray, *args: Any, **kwargs: Any) -> np.ndarray:
input_dtype = img.dtype
if input_dtype != np.float32:
img = to_float(img)
result = func(img, *args, **kwargs)

return from_float(result, target_dtype=input_dtype) if input_dtype != np.float32 else result

return float32_wrapper


def uint8_io(func: Callable[..., np.ndarray]) -> Callable[..., np.ndarray]:
"""Decorator to ensure uint8 input/output for image processing functions.

This decorator converts the input image to uint8 before passing it to the wrapped function,
and then converts the result back to the original dtype if it wasn't uint8.

Args:
func (Callable[..., np.ndarray]): The image processing function to be wrapped.

Returns:
Callable[..., np.ndarray]: A wrapped function that handles uint8 conversion.

Example:
@uint8_io
def some_image_function(img: np.ndarray) -> np.ndarray:
# Function implementation
return processed_img
"""

@wraps(func)
def uint8_wrapper(img: np.ndarray, *args: Any, **kwargs: Any) -> np.ndarray:
input_dtype = img.dtype

if input_dtype != np.uint8:
img = from_float(img, target_dtype=np.uint8)

result = func(img, *args, **kwargs)

return to_float(result) if input_dtype != np.uint8 else result

return uint8_wrapper
39 changes: 0 additions & 39 deletions albucore/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,25 +109,6 @@ def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.n
return wrapped_function


def preserve_channel_dim(
func: Callable[Concatenate[np.ndarray, P], np.ndarray],
) -> Callable[Concatenate[np.ndarray, P], np.ndarray]:
"""Preserve dummy channel dim."""

@wraps(func)
def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray:
shape = img.shape
result = func(img, *args, **kwargs)
if len(shape) == NUM_MULTI_CHANNEL_DIMENSIONS and shape[-1] == 1 and result.ndim == MONO_CHANNEL_DIMENSIONS:
return np.expand_dims(result, axis=-1)

if len(shape) == MONO_CHANNEL_DIMENSIONS and result.ndim == NUM_MULTI_CHANNEL_DIMENSIONS:
return result[:, :, 0]
return result

return wrapped_function


def get_num_channels(image: np.ndarray) -> int:
return image.shape[2] if image.ndim == NUM_MULTI_CHANNEL_DIMENSIONS else 1

Expand All @@ -151,26 +132,6 @@ def is_multispectral_image(image: np.ndarray) -> bool:
return num_channels not in {1, 3}


def contiguous(
func: Callable[Concatenate[np.ndarray, P], np.ndarray],
) -> Callable[Concatenate[np.ndarray, P], np.ndarray]:
"""Ensure that input img is contiguous and the output array is also contiguous."""

@wraps(func)
def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray:
# Ensure the input array is contiguous
img = np.require(img, requirements=["C_CONTIGUOUS"])
# Call the original function with the contiguous input
result = func(img, *args, **kwargs)
# Ensure the output array is contiguous
if not result.flags["C_CONTIGUOUS"]:
return np.require(result, requirements=["C_CONTIGUOUS"])

return result

return wrapped_function


def convert_value(value: np.ndarray | float, num_channels: int) -> float | np.ndarray:
"""Convert a multiplier to a float / int or a numpy array.

Expand Down
20 changes: 20 additions & 0 deletions tests/test_to_from_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,23 @@ def test_from_float_opencv_input_unchanged(dtype, channels):
img_copy = img.copy()
_ = from_float_opencv(img, dtype, max_value)
np.testing.assert_array_equal(img, img_copy)


def test_to_float_returns_same_object_for_float32():
float32_image = np.random.rand(10, 10, 3).astype(np.float32)
result = to_float(float32_image)
assert result is float32_image # Check if it's the same object


@pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.float32])
def test_to_float_from_float_roundtrip(dtype):
if dtype == np.float32:
original = np.random.rand(10, 10, 3).astype(dtype)
else:
original = np.random.randint(0, 256, (10, 10, 3)).astype(dtype)
ternaus marked this conversation as resolved.
Show resolved Hide resolved

float_version = to_float(original)
roundtrip = from_float(float_version, dtype)

assert roundtrip.dtype == dtype
np.testing.assert_allclose(original, roundtrip, rtol=1e-5, atol=1e-8)
Comment on lines +335 to +346
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Consider adding more dtypes to the roundtrip test

The test covers important dtypes, but consider adding int32, int64, and float64 to ensure comprehensive coverage of potential input types.

@pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.int32, np.int64, np.float32, np.float64])
def test_to_float_from_float_roundtrip(dtype):
    if np.issubdtype(dtype, np.floating):
        original = np.random.rand(10, 10, 3).astype(dtype)
    else:
        original = np.random.randint(0, 256, (10, 10, 3)).astype(dtype)

    float_version = to_float(original)
    roundtrip = from_float(float_version, dtype)

    assert roundtrip.dtype == dtype
    np.testing.assert_allclose(original, roundtrip, rtol=1e-5, atol=1e-8)

69 changes: 68 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numpy as np
import pytest
import cv2
from albucore.utils import NPDTYPE_TO_OPENCV_DTYPE, clip, convert_value, get_opencv_dtype_from_numpy, contiguous
from albucore.decorators import contiguous
from albucore.functions import float32_io, from_float, to_float, uint8_io
from albucore.utils import NPDTYPE_TO_OPENCV_DTYPE, clip, convert_value, get_opencv_dtype_from_numpy


@pytest.mark.parametrize("input_img, dtype, expected", [
Expand Down Expand Up @@ -88,3 +90,68 @@ def test_contiguous_decorator(input_array):
# Check if the content is correct (same as reversing the original array)
expected_output = input_array[::-1, ::-1]
np.testing.assert_array_equal(output_array, expected_output), "Output array content is not as expected"


# Sample functions to be wrapped
@float32_io
def dummy_float32_func(img):
return img * 2

@uint8_io
def dummy_uint8_func(img):
return np.clip(img + 10, 0, 255).astype(np.uint8)

# Test data
@pytest.fixture(params=[
np.uint8, np.float32
])
ternaus marked this conversation as resolved.
Show resolved Hide resolved
def test_image(request):
dtype = request.param
if np.issubdtype(dtype, np.integer):
return np.random.randint(0, 256, (10, 10, 3), dtype=dtype)
else:
return np.random.rand(10, 10, 3).astype(dtype)
ternaus marked this conversation as resolved.
Show resolved Hide resolved

# Tests
@pytest.mark.parametrize("wrapper,func, image", [
(float32_io, dummy_float32_func, np.random.randint(0, 256, (10, 10, 3), dtype=np.uint8)),
(uint8_io, dummy_uint8_func, np.random.rand(10, 10, 3).astype(np.float32))
])
def test_io_wrapper(wrapper, func, image):
input_dtype = image.dtype
result = func(image)

# Check if the output dtype matches the input dtype
assert result.dtype == input_dtype

# Check if the function was actually applied
if wrapper == float32_io:
expected = from_float(to_float(image) * 2, input_dtype)
else: # uint8_io
expected = to_float(from_float(image, np.uint8) + 10)
ternaus marked this conversation as resolved.
Show resolved Hide resolved

np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5)
ternaus marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.parametrize("wrapper,func,expected_intermediate_dtype", [
(float32_io, dummy_float32_func, np.float32),
(uint8_io, dummy_uint8_func, np.uint8)
])
def test_intermediate_dtype(wrapper, func, expected_intermediate_dtype, test_image):
original_func = func.__wrapped__ # Access the original function

def check_dtype(img):
assert img.dtype == expected_intermediate_dtype
return original_func(img)

wrapped_func = wrapper(check_dtype)
wrapped_func(test_image) # This will raise an assertion error if the intermediate dtype is incorrect

def test_float32_io_preserves_float32(test_image):
if test_image.dtype == np.float32:
result = dummy_float32_func(test_image)
assert result.dtype == np.float32
ternaus marked this conversation as resolved.
Show resolved Hide resolved
ternaus marked this conversation as resolved.
Show resolved Hide resolved

def test_uint8_io_preserves_uint8(test_image):
if test_image.dtype == np.uint8:
result = dummy_uint8_func(test_image)
assert result.dtype == np.uint8
ternaus marked this conversation as resolved.
Show resolved Hide resolved
ternaus marked this conversation as resolved.
Show resolved Hide resolved
Loading