Skip to content

Commit

Permalink
move utilsforecast imports to performance test
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez committed Oct 12, 2023
1 parent 80f6d50 commit be8580e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ wheel.py-api = "cp37"

[tool.cibuildwheel]
test-requires = "pytest"
test-command = "pytest {project}/tests"
test-command = "pytest {project}/tests/test_correctness.py"
25 changes: 13 additions & 12 deletions tests/test_scalers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
import numpy as np
import pytest
from utilsforecast.target_transforms import (
LocalStandardScaler as UtilsStandardScaler,
LocalMinMaxScaler as UtilsMinMaxScaler,
LocalRobustScaler as UtilsRobustScaler,
)

from coreforecast.grouped_array import GroupedArray
from coreforecast.scalers import (
LocalMinMaxScaler,
Expand Down Expand Up @@ -67,12 +61,6 @@ def scaler_inverse_transform(x, stats):
"robust-iqr": LocalRobustScaler("iqr"),
"robust-mad": LocalRobustScaler("mad"),
}
scaler2utils = {
"standard": UtilsStandardScaler(),
"minmax": UtilsMinMaxScaler(),
"robust-iqr": UtilsRobustScaler("iqr"),
"robust-mad": UtilsRobustScaler("mad"),
}
scalers = list(scaler2fns.keys())


Expand Down Expand Up @@ -121,6 +109,19 @@ def test_correctness(data, indptr, scaler_name, dtype):
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("lib", ["core", "utils"])
def test_performance(benchmark, data, indptr, scaler_name, dtype, lib):
from utilsforecast.target_transforms import (
LocalStandardScaler as UtilsStandardScaler,
LocalMinMaxScaler as UtilsMinMaxScaler,
LocalRobustScaler as UtilsRobustScaler,
)

scaler2utils = {
"standard": UtilsStandardScaler(),
"minmax": UtilsMinMaxScaler(),
"robust-iqr": UtilsRobustScaler("iqr"),
"robust-mad": UtilsRobustScaler("mad"),
}

ga = GroupedArray(data.astype(dtype), indptr)
if lib == "core":
scaler = scaler2core[scaler_name]
Expand Down

0 comments on commit be8580e

Please sign in to comment.