Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

4x Faster LUT via StringZilla #36

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions albucore/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import cv2
import numpy as np
import stringzilla as sz

from albucore.decorators import contiguous, preserve_channel_dim
from albucore.utils import (
Expand All @@ -26,7 +27,9 @@


def create_lut_array(
dtype: type[np.number], value: float | np.ndarray, operation: Literal["add", "multiply", "power"]
dtype: type[np.number],
value: float | np.ndarray,
operation: Literal["add", "multiply", "power"],
) -> np.ndarray:
max_value = MAX_VALUES_BY_DTYPE[dtype]

Expand All @@ -42,16 +45,27 @@ def create_lut_array(
raise ValueError(f"Unsupported operation: {operation}")


def apply_lut(img: np.ndarray, value: float | np.ndarray, operation: Literal["add", "multiply", "power"]) -> np.ndarray:
def apply_lut(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): Consider refactoring the implementation to improve code organization and clarity.

While the new implementation using stringzilla may offer performance benefits, it does increase code complexity. Consider the following suggestions to balance performance and readability:

  1. Move serialize_lookup_recover outside of apply_lut:
def serialize_lookup_recover(img: np.ndarray, lut: np.ndarray) -> np.ndarray:
    img_bytes = img.tobytes()
    lut_bytes = lut.tobytes()
    sz.translate(img_bytes, lut_bytes)
    return np.frombuffer(img_bytes, dtype=img.dtype).reshape(img.shape)

def apply_lut(
    img: np.ndarray,
    value: float | np.ndarray,
    operation: Literal["add", "multiply", "power"],
) -> np.ndarray:
    dtype = img.dtype
    if isinstance(value, (int, float)):
        lut = create_lut_array(dtype, value, operation)
        return serialize_lookup_recover(img, clip(lut, dtype))
    num_channels = img.shape[-1]
    luts = create_lut_array(dtype, value, operation)
    return cv2.merge([serialize_lookup_recover(img[:, :, i], clip(luts[i], dtype)) for i in range(num_channels)])
  1. Add comments explaining the performance benefits:
def serialize_lookup_recover(img: np.ndarray, lut: np.ndarray) -> np.ndarray:
    # This function uses stringzilla for efficient byte-level LUT application,
    # which can be faster than cv2.LUT for large images or frequent calls.
    img_bytes = img.tobytes()
    lut_bytes = lut.tobytes()
    sz.translate(img_bytes, lut_bytes)
    return np.frombuffer(img_bytes, dtype=img.dtype).reshape(img.shape)
  1. Consider adding a benchmark comparison between this method and cv2.LUT to justify the added complexity. If the performance gain is minimal, you might want to revert to the simpler cv2.LUT implementation.

  2. If you keep this implementation, add a note in the function docstring explaining why this approach was chosen over cv2.LUT.

These changes will help maintain the potential performance benefits while improving code readability and maintainability.

img: np.ndarray,
value: float | np.ndarray,
operation: Literal["add", "multiply", "power"],
) -> np.ndarray:
dtype = img.dtype

def serialize_lookup_recover(img: np.ndarray, lut: np.ndarray) -> np.ndarray:
# Encode image into bytes, perform the lookups and then decode the bytes back to numpy array
img_bytes = img.tobytes()
lut_bytes = lut.tobytes()
sz.translate(img_bytes, lut_bytes)
return np.frombuffer(img_bytes, dtype=img.dtype).reshape(img.shape)
Comment on lines +55 to +60
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question (performance): Can you provide context for replacing cv2.LUT with serialize_lookup_recover?

This change seems significant. Could you share any performance benchmarks or explain the rationale behind this new approach? It would be helpful to understand the benefits over the previous cv2.LUT method.


if isinstance(value, (int, float)):
lut = create_lut_array(dtype, value, operation)
return cv2.LUT(img, clip(lut, dtype))
return serialize_lookup_recover(img, clip(lut, dtype))

num_channels = img.shape[-1]
luts = create_lut_array(dtype, value, operation)
return cv2.merge([cv2.LUT(img[:, :, i], clip(luts[i], dtype)) for i in range(num_channels)])
return cv2.merge([serialize_lookup_recover(img[:, :, i], clip(luts[i], dtype)) for i in range(num_channels)])


def prepare_value_opencv(
Expand Down Expand Up @@ -84,7 +98,9 @@ def prepare_value_opencv(


def apply_numpy(
img: np.ndarray, value: float | np.ndarray, operation: Literal["add", "multiply", "power"]
img: np.ndarray,
value: float | np.ndarray,
operation: Literal["add", "multiply", "power"],
) -> np.ndarray:
if operation == "add" and img.dtype == np.uint8:
value = np.int16(value)
Expand Down
1 change: 1 addition & 0 deletions benchmark/requirements.in
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
-f https://download.pytorch.org/whl/torch_stable.html

numpy
stringzilla
opencv-python-headless
tqdm
pandas
Expand Down
4 changes: 3 additions & 1 deletion benchmark/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ pytz==2024.2
setuptools==74.1.2
# via pytablewriter
six==1.16.0
# via pandas
stringzilla==5.10.0
# via python-dateutil
sympy==1.13.2
# via torch
Expand All @@ -75,4 +77,4 @@ typepy==1.3.2
typing-extensions==4.12.2
# via torch
tzdata==2024.1
# via pandas
# via -r requirements.in
Loading