Skip to content

Commit

Permalink
add tests and docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez committed Jan 26, 2024
1 parent 1191e83 commit 640a6cf
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 39 deletions.
40 changes: 39 additions & 1 deletion coreforecast/grouped_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,16 @@ def _data_as_void_ptr(arr: np.ndarray):

def _ensure_float(x: np.ndarray) -> np.ndarray:
    """Return ``x`` as a floating point array.

    Args:
        x (np.ndarray): Input array.

    Returns:
        np.ndarray: ``x`` itself when its dtype is already float32 or
            float64, otherwise a copy cast to float32.
    """
    # Single exit point: cast only when needed, then fall through to the
    # shared return so no statement is left unreachable.
    if x.dtype not in (np.float32, np.float64):
        x = x.astype(np.float32)
    return x


def _pyfloat_to_np_c(x: float, t: np.dtype) -> Union[ctypes.c_float, ctypes.c_double]:
    """Wrap a Python float in the ctypes scalar that matches a NumPy dtype.

    Args:
        x (float): Value to wrap.
        t (np.dtype): dtype of the array the value will be used with.

    Returns:
        Union[ctypes.c_float, ctypes.c_double]: ``c_float`` when ``t`` is
            float32, ``c_double`` for anything else.
    """
    c_type = ctypes.c_float if t == np.float32 else ctypes.c_double
    return c_type(x)


class GroupedArray:
"""Array of grouped data
Expand Down Expand Up @@ -313,3 +319,35 @@ def _exponentially_weighted_transform(
_data_as_void_ptr(out),
)
return out

def _boxcox_fit(
    self, season_length: int, lower: float, upper: float, method: str
) -> np.ndarray:
    """Estimate the Box-Cox lambdas through the native library.

    Args:
        season_length (int): Length of the seasonal period.
        lower (float): Lower bound for the lambda.
        upper (float): Upper bound for the lambda.
        method (str): Suffix of the native ``BoxCoxLambda`` routine to call.

    Returns:
        np.ndarray: Array of shape ``(len(self), 2)`` with the values written
            by the native routine in the first column and ones in the second.
    """
    dtype = self.data.dtype
    lambdas = np.empty_like(self.data, shape=(len(self), 1))
    fit_fn = _LIB[f"{self.prefix}_BoxCoxLambda{method}"]
    fit_fn(
        self._handle,
        ctypes.c_int(season_length),
        _pyfloat_to_np_c(lower, dtype),
        _pyfloat_to_np_c(upper, dtype),
        _data_as_void_ptr(lambdas),
    )
    # dummy scales to be compatible with GroupedArray's ScalerTransform
    unit_scales = np.ones_like(lambdas)
    return np.hstack([lambdas, unit_scales])

def _boxcox_transform(self, stats: np.ndarray) -> np.ndarray:
    """Apply the native Box-Cox transformation to the grouped data.

    Args:
        stats (np.ndarray): Statistics handed to the native routine, e.g.
            the array returned by ``_boxcox_fit``.

    Returns:
        np.ndarray: New array with the same shape and dtype as ``self.data``
            that the native routine writes its result into.
    """
    transformed = np.empty_like(self.data)
    transform_fn = _LIB[f"{self.prefix}_BoxCoxTransform"]
    transform_fn(
        self._handle,
        _data_as_void_ptr(stats),
        _data_as_void_ptr(transformed),
    )
    return transformed

def _boxcox_inverse_transform(self, stats: np.ndarray) -> np.ndarray:
    """Apply the native inverse Box-Cox transformation to the grouped data.

    Args:
        stats (np.ndarray): Statistics handed to the native routine, e.g.
            the array returned by ``_boxcox_fit``.

    Returns:
        np.ndarray: New array with the same shape and dtype as ``self.data``
            that the native routine writes its result into.
    """
    restored = np.empty_like(self.data)
    inverse_fn = _LIB[f"{self.prefix}_BoxCoxInverseTransform"]
    inverse_fn(
        self._handle,
        _data_as_void_ptr(stats),
        _data_as_void_ptr(restored),
    )
    return restored
93 changes: 59 additions & 34 deletions coreforecast/scalers.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,31 @@
import ctypes
from typing import Union

import numpy as np

from .grouped_array import _LIB, _data_as_void_ptr, _ensure_float, GroupedArray
from .grouped_array import (
_LIB,
_data_as_void_ptr,
_ensure_float,
_pyfloat_to_np_c,
GroupedArray,
)


# Public API of the module: scaler classes first, then the array-level
# Box-Cox helpers. Each name is listed exactly once.
__all__ = [
    "LocalBoxCoxScaler",
    "LocalMinMaxScaler",
    "LocalRobustScaler",
    "LocalStandardScaler",
    "boxcox",
    "boxcox_lambda",
    "inv_boxcox",
]


_LIB.Float32_BoxCoxLambdaGuerrero.restype = ctypes.c_float
_LIB.Float64_BoxCoxLambdaGuerrero.restype = ctypes.c_double


def _pyfloat_to_np_c(x: float, t: np.dtype) -> Union[ctypes.c_float, ctypes.c_double]:
    """Convert a Python float to the ctypes scalar matching a NumPy dtype.

    Args:
        x (float): Value to convert.
        t (np.dtype): dtype that decides the C type.

    Returns:
        Union[ctypes.c_float, ctypes.c_double]: ``c_float`` for float32,
            ``c_double`` for any other dtype.
    """
    use_single_precision = t == np.float32
    converter = ctypes.c_float if use_single_precision else ctypes.c_double
    return converter(x)


def boxcox_lambda(
x: np.ndarray,
season_length: int,
Expand All @@ -45,8 +46,6 @@ def boxcox_lambda(
float: Optimum lambda."""
if method != "guerrero":
raise NotImplementedError(f"Method {method} not implemented")
if any(x <= 0):
raise ValueError("All values in x must be positive")
if lower >= upper:
raise ValueError("lower must be less than upper")
x = _ensure_float(x)
Expand All @@ -65,6 +64,14 @@ def boxcox_lambda(


def boxcox(x: np.ndarray, lmbda: float) -> np.ndarray:
"""Apply the Box-Cox transformation
Args:
x (np.ndarray): Array with data to transform.
lmbda (float): Lambda value to use.
Returns:
np.ndarray: Array with the transformed data."""
x = _ensure_float(x)
if x.dtype == np.float32:
fn = "Float32_BoxCoxTransform"
Expand All @@ -81,6 +88,14 @@ def boxcox(x: np.ndarray, lmbda: float) -> np.ndarray:


def inv_boxcox(x: np.ndarray, lmbda: float) -> np.ndarray:
"""Invert the Box-Cox transformation
Args:
x (np.ndarray): Array with data to transform.
lmbda (float): Lambda value to use.
Returns:
np.ndarray: Array with the inverted transformation."""
x = _ensure_float(x)
if x.dtype == np.float32:
fn = "Float32_BoxCoxInverseTransform"
Expand Down Expand Up @@ -159,6 +174,14 @@ def __init__(self, scale: str):


class LocalBoxCoxScaler(_BaseLocalScaler):
"""Find the optimum lambda for the Box-Cox transformation by group and apply it
Args:
season_length (int): Length of the seasonal period.
lower (float): Lower bound for the lambda.
upper (float): Upper bound for the lambda.
method (str): Method to use. Valid options are 'guerrero'."""

def __init__(
self,
season_length: int,
Expand All @@ -174,32 +197,34 @@ def __init__(
self.method = method.capitalize()

def fit(self, ga: GroupedArray) -> "_BaseLocalScaler":
self.stats = np.empty_like(ga.data, shape=(len(ga), 1))
_LIB[f"{ga.prefix}_BoxCoxLambda{self.method}"](
ga._handle,
ctypes.c_int(self.season_length),
_pyfloat_to_np_c(self.lower, ga.data.dtype),
_pyfloat_to_np_c(self.upper, ga.data.dtype),
_data_as_void_ptr(self.stats),
"""Compute the statistics for each group.
Args:
ga (GroupedArray): Array with grouped data.
Returns:
self: The fitted scaler object."""
self.stats = ga._boxcox_fit(
self.season_length, self.lower, self.upper, self.method
)
# this is to use ones as scale. I know.
self.stats = np.hstack([self.stats, np.ones_like(self.stats)])
return self

def transform(self, ga: GroupedArray) -> np.ndarray:
out = np.empty_like(ga.data)
_LIB[f"{ga.prefix}_BoxCoxTransform"](
ga._handle,
_data_as_void_ptr(self.stats),
_data_as_void_ptr(out),
)
return out
"""Use the computed lambdas to apply the transformation.
Args:
ga (GroupedArray): Array with grouped data.
Returns:
np.ndarray: Array with the transformed data."""
return ga._boxcox_transform(self.stats)

def inverse_transform(self, ga: GroupedArray) -> np.ndarray:
out = np.empty_like(ga.data)
_LIB[f"{ga.prefix}_BoxCoxInverseTransform"](
ga._handle,
_data_as_void_ptr(self.stats),
_data_as_void_ptr(out),
)
return out
"""Use the computed lambdas to invert the transformation.
Args:
ga (GroupedArray): Array with grouped data.
Returns:
np.ndarray: Array with the inverted transformation."""
return ga._boxcox_inverse_transform(self.stats)
2 changes: 1 addition & 1 deletion include/grouped_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ template <class T> class GroupedArray {
scale = static_cast<T>(1.0);
}
for (indptr_t j = start; j < end; ++j) {
out[j] = f(data_[j], scale, offset);
out[j] = f(data_[j], offset, scale);
}
}
}
Expand Down
16 changes: 13 additions & 3 deletions src/scalers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
#include <vector>

template <typename T>
inline T CommonScalerTransform(T data, T scale, T offset) {
inline T CommonScalerTransform(T data, T offset, T scale) {
return (data - offset) / scale;
}

template <typename T>
inline T CommonScalerInverseTransform(T data, T scale, T offset) {
inline T CommonScalerInverseTransform(T data, T offset, T scale) {
return data * scale + offset;
}

Expand Down Expand Up @@ -88,6 +88,16 @@ T GuerreroCV(T lambda, const std::vector<T> &x_mean,
template <typename T>
void BoxCoxLambda_Guerrero(const T *x, int n, T *out, int period, T lower,
T upper) {
if (n <= 2 * period) {
*out = static_cast<T>(1.0);
return;
}
for (int i = 0; i < n; ++i) {
if (x[i] <= 0.0) {
lower = std::max(lower, static_cast<T>(0.0));
break;
}
}
int n_seasons = n / period;
int n_full = n_seasons * period;
// build matrix with subseries having full periods
Expand Down Expand Up @@ -136,7 +146,7 @@ template <typename T> inline T BoxCoxTransform(T x, T lambda, T /*unused*/) {
if (x > 0) {
return std::expm1(lambda * std::log(x)) / lambda;
}
return -std::expm1(-lambda * std::log(-x)) / lambda;
return (-std::exp(lambda * std::log(-x)) - 1) / lambda;
}

template <typename T>
Expand Down
22 changes: 22 additions & 0 deletions tests/test_scalers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@

from coreforecast.grouped_array import GroupedArray
from coreforecast.scalers import (
LocalBoxCoxScaler,
LocalMinMaxScaler,
LocalRobustScaler,
LocalStandardScaler,
boxcox_lambda,
boxcox,
inv_boxcox,
)


Expand Down Expand Up @@ -107,6 +111,24 @@ def test_correctness(data, indptr, scaler_name, dtype):
np.testing.assert_allclose(restored, expected_restored, atol=1e-6, rtol=1e-6)


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_boxcox_correctness(data, indptr, dtype):
    """Check the grouped Box-Cox scaler round-trips and agrees with the array functions."""
    season_length = 10
    ga = GroupedArray(data.astype(dtype), indptr)
    scaler = LocalBoxCoxScaler(season_length=season_length)
    scaler.fit(ga)

    # transform followed by inverse_transform should give back the input;
    # float32 gets a looser absolute tolerance than float64
    transformed = scaler.transform(ga)
    restored = scaler.inverse_transform(GroupedArray(transformed, ga.indptr))
    if dtype == np.float32:
        atol = 5e-4
    else:
        atol = 1e-8
    np.testing.assert_allclose(ga.data, restored, atol=atol)

    # the lambda stored for the first group must equal the standalone estimate
    lmbda = boxcox_lambda(ga[0], season_length)
    np.testing.assert_allclose(lmbda, scaler.stats[0, 0])

    # the array-level functions must reproduce the grouped results for that group
    first_grp = slice(indptr[0], indptr[1])
    first_tfm = boxcox(ga[0], lmbda)
    first_restored = inv_boxcox(first_tfm, lmbda)
    np.testing.assert_allclose(first_tfm, transformed[first_grp])
    np.testing.assert_allclose(first_restored, restored[first_grp])


@pytest.mark.parametrize("scaler_name", scalers)
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("lib", ["core", "utils"])
Expand Down

0 comments on commit 640a6cf

Please sign in to comment.