Skip to content

Commit

Permalink
add templates to support float64 (#4)
Browse files: browse the repository at this point in the history
  • Loading branch information
jmoralez authored Oct 12, 2023
1 parent 4ef1c96 commit c3f4fc7
Show file tree
Hide file tree
Showing 11 changed files with 303 additions and 122 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/build-python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@ jobs:
CIBW_TEST_SKIP: '*-macosx_arm64'

- uses: actions/upload-artifact@v3
if: github.event_name == 'push'
with:
path: ./wheelhouse/*.whl

build_sdist:
name: Build sdist
if: github.event_name == 'push'
runs-on: ubuntu-latest

steps:
Expand All @@ -59,7 +61,8 @@ jobs:
with:
path: ./dist/*.tar.gz

upload_all:
upload_artifacts:
name: 'Upload to PyPI'
if: github.repository == 'Nixtla/coreforecast' && startsWith(github.ref, 'refs/tags/v')
needs: [build_wheels, build_sdist]
runs-on: ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ jobs:
run: pip install --no-build-isolation -v .

- name: Run tests
run: pytest
run: pytest --benchmark-group-by=param:scaler_name --benchmark-sort=fullname
1 change: 0 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ if(NOT CMAKE_BUILD_TYPE)
endif()

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_FLAGS_DEBUG "-g")

if(APPLE)
set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
Expand Down
41 changes: 30 additions & 11 deletions coreforecast/grouped_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from importlib.resources import files


# Element-type tags handed to the native library wherever it takes a
# `data_type` argument. The values mirror the DTYPE_FLOAT32 / DTYPE_FLOAT64
# macros in include/coreforecast.h and have to stay in sync with them.
DTYPE_FLOAT32 = ctypes.c_int(0)
DTYPE_FLOAT64 = ctypes.c_int(1)

if platform.system() in ("Windows", "Microsoft"):
prefix = "Release"
extension = "dll"
Expand All @@ -22,52 +25,68 @@
)


def _data_as_ptr(arr: np.ndarray, dtype):
    """Return a typed ctypes pointer to the start of ``arr``'s buffer.

    Args:
        arr: Array whose memory is exposed; the pointer aliases the array's
            own buffer, nothing is copied.
        dtype: ctypes scalar type of the pointee, e.g. ``ctypes.c_float``.

    Returns:
        A ``ctypes.POINTER(dtype)`` instance addressing ``arr``'s data.
    """
    return arr.ctypes.data_as(ctypes.POINTER(dtype))
def _data_as_void_ptr(arr: np.ndarray):
    """Return an untyped ctypes pointer to the start of ``arr``'s buffer.

    The pointer aliases the array's own memory; nothing is copied.

    Args:
        arr: Array whose memory is exposed to the native library.

    Returns:
        A ``ctypes.POINTER(ctypes.c_void_p)`` instance holding the address of
        ``arr``'s first element.
    """
    # NOTE(review): POINTER(c_void_p) is nominally `void **`, while the header
    # declares these parameters as `void *`. The address passed is the same,
    # so the return type is kept unchanged for existing callers.
    return arr.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p))


class GroupedArray:
    """Python-side owner of a native grouped array.

    ``data`` stores the values of all groups back to back and ``indptr``
    stores the group boundaries, so group ``i`` is
    ``data[indptr[i]:indptr[i + 1]]``.
    """

    def __init__(self, data: np.ndarray, indptr: np.ndarray):
        """Create the native object for ``data`` split by ``indptr``.

        Args:
            data: Values of every group. float32 and float64 are used as is;
                any other dtype is converted to float32.
            indptr: Group boundaries; converted to int32 when needed.
        """
        if data.dtype == np.float32:
            self.dtype = DTYPE_FLOAT32
        elif data.dtype == np.float64:
            self.dtype = DTYPE_FLOAT64
        else:
            self.dtype = DTYPE_FLOAT32
            data = data.astype(np.float32)
        # The native call only receives a base address and an element count,
        # so both buffers must be densely laid out. These are no-ops (same
        # object returned) for inputs that are already contiguous.
        data = np.ascontiguousarray(data)
        indptr = np.ascontiguousarray(indptr, dtype=np.int32)
        # Keeping both arrays on the instance keeps their memory alive for as
        # long as the handle exists (presumably the native object refers to
        # these buffers rather than copying them; verify against the C++ side).
        self.data = data
        self.indptr = indptr
        self._handle = ctypes.c_void_p()
        _LIB.GroupedArray_CreateFromArrays(
            _data_as_void_ptr(data),
            ctypes.c_int32(data.size),
            indptr.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
            ctypes.c_int32(indptr.size),
            self.dtype,
            ctypes.byref(self._handle),
        )

    def __del__(self):
        """Release the native object, if one was created."""
        # __del__ also runs when __init__ raised part-way through, in which
        # case `_handle` may be missing or still NULL; there is nothing to
        # free then.
        handle = getattr(self, "_handle", None)
        if handle is None or not handle.value:
            return
        _LIB.GroupedArray_Delete(handle, self.dtype)

    def __len__(self):
        """Return the number of groups."""
        return self.indptr.size - 1

    def __getitem__(self, i):
        """Return the slice of ``data`` belonging to group ``i``."""
        return self.data[self.indptr[i] : self.indptr[i + 1]]

    def scaler_fit(self, stats_fn_name: str) -> np.ndarray:
        """Compute per-group scaler statistics with a native routine.

        Args:
            stats_fn_name: Name of the exported library function to call,
                e.g. ``"GroupedArray_StandardScalerStats"``.

        Returns:
            Array of shape ``(len(self), 2)`` with the dtype of ``data``.
            It is pre-filled with NaN, so entries the native routine does
            not write stay NaN.
        """
        stats = np.full((len(self), 2), np.nan, dtype=self.data.dtype)
        stats_fn = _LIB[stats_fn_name]
        stats_fn(
            self._handle,
            self.dtype,
            _data_as_void_ptr(stats),
        )
        return stats

    def scaler_transform(self, stats: np.ndarray) -> np.ndarray:
        """Apply the scaling described by ``stats`` to every group.

        Args:
            stats: Per-group statistics as returned by ``scaler_fit``.

        Returns:
            Array shaped like ``data``; pre-filled with NaN, so entries the
            native routine does not write stay NaN.
        """
        # The native side interprets the stats buffer using the same element
        # type as the data, so coerce dtype and layout before taking a pointer.
        stats = np.ascontiguousarray(stats, dtype=self.data.dtype)
        out = np.full_like(self.data, np.nan)
        _LIB.GroupedArray_ScalerTransform(
            self._handle,
            _data_as_void_ptr(stats),
            self.dtype,
            _data_as_void_ptr(out),
        )
        return out

    def scaler_inverse_transform(self, stats: np.ndarray) -> np.ndarray:
        """Undo the scaling described by ``stats`` for every group.

        Args:
            stats: Per-group statistics as returned by ``scaler_fit``.

        Returns:
            Array shaped like ``data`` holding the native routine's output.
        """
        # Same dtype/layout requirement as in `scaler_transform`.
        stats = np.ascontiguousarray(stats, dtype=self.data.dtype)
        out = np.empty_like(self.data)
        _LIB.GroupedArray_ScalerInverseTransform(
            self._handle,
            _data_as_void_ptr(stats),
            self.dtype,
            _data_as_void_ptr(out),
        )
        return out
12 changes: 6 additions & 6 deletions coreforecast/scalers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ class LocalStandardScaler(BaseLocalScaler):
stats_fn_name = "GroupedArray_StandardScalerStats"


class LocalRobustScalerIqr(BaseLocalScaler):
    """Local scaler bound to the native ``GroupedArray_RobustScalerIqrStats``.

    Only the statistics routine name is set here; all behaviour comes from
    ``BaseLocalScaler``.
    """

    # Name of the exported library function that computes the statistics
    # (presumably looked up by the base class; confirm in BaseLocalScaler).
    stats_fn_name = "GroupedArray_RobustScalerIqrStats"


class LocalRobustScalerMad(BaseLocalScaler):
    """Local scaler bound to the native ``GroupedArray_RobustScalerMadStats``.

    Only the statistics routine name is set here; all behaviour comes from
    ``BaseLocalScaler``.
    """

    # Name of the exported library function that computes the statistics
    # (presumably looked up by the base class; confirm in BaseLocalScaler).
    stats_fn_name = "GroupedArray_RobustScalerMadStats"
class LocalRobustScaler(BaseLocalScaler):
    """Local robust scaler with a selectable statistics routine.

    Args:
        scale: ``"iqr"`` selects the native
            ``GroupedArray_RobustScalerIqrStats`` routine and ``"mad"``
            selects ``GroupedArray_RobustScalerMadStats``.

    Raises:
        ValueError: If ``scale`` is neither ``"iqr"`` nor ``"mad"``.
    """

    def __init__(self, scale: str):
        if scale == "iqr":
            self.stats_fn_name = "GroupedArray_RobustScalerIqrStats"
        elif scale == "mad":
            self.stats_fn_name = "GroupedArray_RobustScalerMadStats"
        else:
            # Previously any other value (including typos such as "IQR")
            # silently fell through to the MAD routine.
            raise ValueError(
                f"scale must be either 'iqr' or 'mad', got {scale!r}"
            )
3 changes: 3 additions & 0 deletions dev_environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ dependencies:
- cmake
- mypy
- ninja
- numba
- numpy
- pytest
- pytest-benchmark
- scikit-build-core
- utilsforecast>=0.0.7
3 changes: 3 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ channels:
dependencies:
- cmake
- ninja
- numba
- numpy
- pytest
- pytest-benchmark
- scikit-build-core
- utilsforecast>=0.0.7
22 changes: 14 additions & 8 deletions include/coreforecast.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,34 @@

// Opaque handle to a grouped array; produced by GroupedArray_CreateFromArrays
// and released with GroupedArray_Delete.
typedef void *GroupedArrayHandle;

// Element-type tags accepted by every `data_type` parameter of this API.
#define DTYPE_FLOAT32 (0)
#define DTYPE_FLOAT64 (1)

extern "C" {
// All functions below take `data_type` as DTYPE_FLOAT32 or DTYPE_FLOAT64; it
// selects how the untyped (`void *`) buffers are interpreted and should be the
// same value for every call made on a given handle.
// NOTE(review): the meaning of the int return values is not documented in
// this header; confirm against the implementation.

// Creates a grouped array over `n_data` values in `data`, split into groups by
// the offsets in `indptr`, and stores the new handle in `*out`.
// NOTE(review): the Python binding passes the length of `indptr` as
// `n_groups`; confirm which convention the implementation expects.
DLL_EXPORT int GroupedArray_CreateFromArrays(const void *data, int32_t n_data,
                                             int32_t *indptr, int32_t n_groups,
                                             int data_type,
                                             GroupedArrayHandle *out);

// Releases a handle obtained from GroupedArray_CreateFromArrays.
DLL_EXPORT int GroupedArray_Delete(GroupedArrayHandle handle, int data_type);

// Per-group statistics routines. Each writes its results into `out`; the
// Python binding supplies room for two values per group, of the element type
// selected by `data_type`.
DLL_EXPORT int GroupedArray_MinMaxScalerStats(GroupedArrayHandle handle,
                                              int data_type, void *out);

DLL_EXPORT int GroupedArray_StandardScalerStats(GroupedArrayHandle handle,
                                                int data_type, void *out);

DLL_EXPORT int GroupedArray_RobustScalerIqrStats(GroupedArrayHandle handle,
                                                 int data_type, void *out);

DLL_EXPORT int GroupedArray_RobustScalerMadStats(GroupedArrayHandle handle,
                                                 int data_type, void *out);

// Applies the scaling described by `stats` (as produced by one of the stats
// routines above) and writes the result into `out`; the Python binding
// supplies an `out` buffer with the same shape and element type as the data.
DLL_EXPORT int GroupedArray_ScalerTransform(GroupedArrayHandle handle,
                                            const void *stats, int data_type,
                                            void *out);

// Reverses GroupedArray_ScalerTransform using the same `stats`; `out` has the
// same requirements as above.
DLL_EXPORT int GroupedArray_ScalerInverseTransform(GroupedArrayHandle handle,
                                                   const void *stats,
                                                   int data_type, void *out);
}
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ wheel.py-api = "cp37"

[tool.cibuildwheel]
test-requires = "pytest"
test-command = "pytest {project}/tests"
test-command = "pytest {project}/tests -k correct"
Loading

0 comments on commit c3f4fc7

Please sign in to comment.