Merge pull request #15 from NCAR/deep
Numba versions of T-Digest quantile and cdf
djgagne authored Apr 22, 2024
2 parents c93666e + 55f4bb1 commit 1c787f3
Showing 4 changed files with 204 additions and 13 deletions.
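
For orientation: the diff below replaces the td_obj.cdf / td_obj.quantile calls in transform_variable and inv_transform_variable with Numba-guvectorized interpolation over the digest centroids, and swaps norm.ppf / norm.cdf for scipy.special ndtri / ndtr. The following is a minimal, self-contained sketch of that interpolation idea using hand-built centroid arrays and plain NumPy; all names and values here are illustrative, and the edge-case handling in the actual helpers below is more careful.

import numpy as np

# Illustrative "centroids" standing in for a fitted T-Digest (not produced by crick).
cent_mean = np.array([1.0, 4.0, 9.0])
cent_weight = np.array([2.0, 3.0, 1.0])
t_min, t_max = 0.5, 10.0

# Cumulative weight up to the midpoint of each centroid, as computed in the helpers below.
merged_weight = np.cumsum(cent_weight) - cent_weight / 2.0
total_weight = cent_weight.sum()

def approx_cdf(x):
    # Piecewise-linear CDF through (t_min, 0), the centroid midpoints, and (t_max, 1).
    xs = np.concatenate(([t_min], cent_mean, [t_max]))
    ys = np.concatenate(([0.0], merged_weight / total_weight, [1.0]))
    return np.interp(x, xs, ys)

print(approx_cdf(np.array([0.5, 2.0, 9.5, 10.0])))
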
123 changes: 114 additions & 9 deletions bridgescaler/distributed.py
@@ -2,14 +2,15 @@
from numpy.lib.recfunctions import structured_to_unstructured, unstructured_to_structured
from copy import deepcopy
from crick import TDigest as CTDigest
from scipy.special import ndtr, ndtri
from numba_stats import norm
import pandas as pd
import xarray as xr
from pytdigest import TDigest
from functools import partial
from scipy.stats import logistic
from warnings import warn

from numba import guvectorize, float32, float64, void
CENTROID_DTYPE = np.dtype([('mean', np.float64), ('weight', np.float64)])

class DBaseScaler(object):
@@ -355,27 +356,131 @@ def fit_variable(var_index, xv_shared=None, compression=None, channels_last=None

def transform_variable(td_obj, xv,
min_val=0.000001, max_val=0.9999999, distribution="normal"):
x_transformed = td_obj.cdf(xv)
x_transformed[:] = np.minimum(x_transformed, max_val)
x_transformed[:] = np.maximum(x_transformed, min_val)
td_centroids = td_obj.centroids()
x_transformed = np.zeros_like(xv)
tdigest_cdf(xv, td_centroids["mean"], td_centroids["weight"],
td_obj.min(), td_obj.max(), x_transformed)
x_transformed = np.minimum(x_transformed, max_val)
x_transformed = np.maximum(x_transformed, min_val)
if distribution == "normal":
x_transformed[:] = norm.ppf(x_transformed, loc=0, scale=1)
x_transformed = ndtri(x_transformed)
elif distribution == "logistic":
x_transformed[:] = logistic.ppf(x_transformed)
x_transformed = logistic.ppf(x_transformed)
return x_transformed


def inv_transform_variable(td_obj, xv,
distribution="normal"):
x_transformed = np.zeros(xv.shape, dtype=xv.dtype)
td_centroids = td_obj.centroids()
x_transformed = np.zeros_like(xv)
if distribution == "normal":
x_transformed = norm.cdf(xv, loc=0, scale=1)
x_transformed = ndtr(xv)
elif distribution == "logistic":
x_transformed = logistic.cdf(xv)
x_transformed[:] = td_obj.quantile(x_transformed)
tdigest_quantile(xv, td_centroids["mean"], td_centroids["weight"],
td_obj.min(), td_obj.max(), x_transformed)
return x_transformed


@guvectorize([void(float64[:], float64[:], float64[:], float64, float64, float64[:]),
void(float32[:], float64[:], float64[:], float64, float64, float32[:])], "(m),(n),(n),(),()->(m)")
def tdigest_cdf(xv, cent_mean, cent_weight, t_min, t_max, out):
cent_merged_weight = np.zeros_like(cent_weight)
cumulative_weight = 0
for i in range(cent_weight.size):
cent_merged_weight[i] = cumulative_weight + cent_weight[i] / 2.0
cumulative_weight += cent_weight[i]
total_weight = cent_weight.sum()
for i, x in enumerate(xv):
if cent_mean.size == 0:
out[i] = np.nan
continue
# Single centroid
if cent_mean.size == 1:
if x < t_min:
out[i] = 0.0
elif x > t_max:
out[i] = 1.0
elif t_max - t_min < np.finfo(np.float64).eps:
out[i] = 0.5
else:
out[i] = (x - t_min) / (t_max - t_min)
continue
# Equality checks only apply if > 1 centroid
if x >= t_max:
out[i] = 1.0
continue
elif x <= t_min:
out[i] = 0.0
continue

# i_l = bisect_left_mean(T->merge_centroids, x, 0, T->ncentroids);
i_l = np.searchsorted(cent_mean, x, side="left")
if x < cent_mean[0]:
# min < x < first centroid
x0 = t_min
x1 = cent_mean[0]
dw = cent_merged_weight[0] / 2.0
out[i] = dw * (x - x0) / (x1 - x0) / total_weight
elif i_l == cent_mean.size:
# last centroid < x < max
x0 = cent_mean[i_l - 1]
x1 = t_max
dw = cent_weight[i_l - 1] / 2.0
out[i] = 1.0 - dw * (x1 - x) / (x1 - x0) / total_weight
elif cent_mean[i_l] == x:
# x is equal to one or more centroids
i_r = np.searchsorted(cent_mean, x, side="right")
out[i] = cent_merged_weight[i_r] / total_weight
else:
assert cent_mean[i_l] > x
x0 = cent_mean[i_l - 1]
x1 = cent_mean[i_l]
dw = 0.5 * (cent_weight[i_l - 1] + cent_weight[i_l])
out[i] = (cent_merged_weight[i_l - 1] + dw * (x - x0) / (x1 - x0)) / total_weight


@guvectorize([void(float64[:], float64[:], float64[:], float64, float64, float64[:]),
void(float32[:], float64[:], float64[:], float64, float64, float32[:])], "(m),(n),(n),(),()->(m)")
def tdigest_quantile(qv, cent_mean, cent_weight, t_min, t_max, out):
cent_merged_weight = np.zeros_like(cent_weight)
cumulative_weight = 0
for i in range(cent_weight.size):
cent_merged_weight[i] = cumulative_weight + cent_weight[i] / 2.0
cumulative_weight += cent_weight[i]
total_weight = cent_weight.sum()
for i, q in enumerate(qv):
if total_weight == 0:
out[i] = np.nan
continue
if q <= 0:
out[i] = t_min
continue
if q >= 1:
out[i] = t_max
continue
if cent_mean.size == 1:
out[i] = cent_mean[0]
continue

index = q * total_weight
b = np.searchsorted(cent_merged_weight, index, side="left")
if b == 0:
x0 = 0
y0 = t_min
else:
x0 = cent_merged_weight[b - 1]
y0 = cent_mean[b - 1]

if b == cent_mean.size:
x1 = total_weight
y1 = t_max
else:
x1 = cent_merged_weight[b]
y1 = cent_mean[b]
out[i] = y0 + (index - x0) * (y1 - y0) / (x1 - x0)


class DQuantileScaler(DBaseScaler):
"""
Distributed Quantile Scaler that uses the crick TDigest Cython library to compute quantiles across multiple
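
A minimal usage sketch of the new guvectorized helpers defined above. This assumes crick is installed, that tdigest_cdf and tdigest_quantile are importable from bridgescaler.distributed, and that the crick TDigest API used elsewhere in this file (update, centroids, min, max, a compression argument) behaves as shown; the sample data are arbitrary and the snippet is illustrative, not part of the commit.

import numpy as np
from crick import TDigest
from bridgescaler.distributed import tdigest_cdf, tdigest_quantile

# Fit a digest on arbitrary sample data (values chosen only for illustration).
rng = np.random.default_rng(0)
data = rng.normal(loc=2.0, scale=3.0, size=10_000)
td = TDigest(compression=100)
td.update(data)

centroids = td.centroids()  # structured array with "mean" and "weight" fields
x = np.array([-4.0, 2.0, 8.0])
cdf_vals = np.zeros_like(x)
tdigest_cdf(x, centroids["mean"], centroids["weight"], td.min(), td.max(), cdf_vals)

q = np.array([0.1, 0.5, 0.9])
quantile_vals = np.zeros_like(q)
tdigest_quantile(q, centroids["mean"], centroids["weight"], td.min(), td.max(), quantile_vals)
print(cdf_vals, quantile_vals)
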
39 changes: 35 additions & 4 deletions scripts/eval_scaler.py
@@ -5,15 +5,18 @@
import xarray as xr
import os
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool
import psutil
from scipy.special import ndtri
from scipy.stats import norm
from memory_profiler import profile

def make_test_data():
np.random.seed(34325)
test_data = dict()
col_names = ["a", "b", "c", "d", "e"]
test_data["means"] = np.array([0, 5.3, -2.421, 21456.3, 1.e-5])
test_data["sds"] = np.array([5, 352.2, 1e-4, 20000.3, 5.3e-2])
test_data["n_examples"] = np.array([1000000, 500, 88])
test_data["n_examples"] = np.array([100000, 500, 88])
test_data["numpy_2d"] = []
test_data["numpy_4d"] = []
test_data["pandas"] = []
@@ -98,11 +101,39 @@ def eval_dquantile_scaler(test_data):
pool.join()
return

def small_eval(test_data):
process = psutil.Process()

# Record initial memory usage

test_data_c_first = test_data["xarray"][0].transpose("batch", "variable", "y", "x").astype("float32")
xr_dss_f = DQuantileScaler(distribution="normal", channels_last=False)
xr_dss_f.fit(test_data_c_first)
bt_memory = process.memory_info().rss
initial_memory = process.memory_info().rss
print(initial_memory/1e6)
xr_dss_f.distribution = None
test_data_c_first = xr_dss_f.transform(test_data_c_first)
test_data_c_sec = ndtri(test_data_c_first)
output_arr = np.full((1000, 50, 50), 0.5)
output_arr = norm.ppf(output_arr)
output_arr = np.full((1000, 50, 50), 0.5)
output_arr = ndtri(output_arr)
at_memory = process.memory_info().rss
print("final mem:", at_memory / 1e6)

print("mem diff:", (at_memory - bt_memory) / 1e6)
return test_data_c_first


if __name__ == "__main__":
from time import perf_counter, time
from time import time

start = time()
test_data = make_test_data()
eval_dquantile_scaler(test_data)
test_data_c_first = test_data["xarray"][0].transpose("batch", "variable", "y", "x").astype("float32")
print(test_data["xarray"][0])
test_data_c_first[:] = small_eval(test_data)
#eval_dquantile_scaler(test_data)
stop = time()
print(stop - start)
18 changes: 18 additions & 0 deletions scripts/numpy_mem.py
@@ -0,0 +1,18 @@
import matplotlib
matplotlib.use('agg')
import numpy as np
import matplotlib.pyplot as plt
import psutil
import xarray as xr
mem = []
def get_data():
return np.zeros((1000, 50, 50), dtype=np.float32)
data = get_data()
for i in range(data.shape[0]):
data[i] = np.random.random((50, 50))
mem.append(psutil.virtual_memory()[1])
mem.append(psutil.virtual_memory()[1])
xd = xr.DataArray(data)
mem.append(psutil.virtual_memory()[1])
plt.plot(mem)
plt.savefig("mem_profile.png", dpi=200, bbox_inches="tight")
37 changes: 37 additions & 0 deletions scripts/scipy_ppf_example.py
@@ -0,0 +1,37 @@
from scipy.stats import norm
from scipy.special import ndtri
import numpy as np
import matplotlib.pyplot as plt
import psutil
import gc

process = psutil.Process()
n_elements = 301
mem_vals = np.zeros(n_elements)
mem_vals[0] = process.memory_info().rss / 1e6
for i in range(1, n_elements):
x = np.random.random(size=(100, 50, 50))
ppf_val = ndtri(x)
mem_vals[i] = process.memory_info().rss / 1e6
gc.collect()
plt.plot(mem_vals[1:] - mem_vals[0], label="ndtri")
mem_vals = np.zeros(n_elements)
mem_vals[0] = process.memory_info().rss / 1e6

for i in range(1, n_elements):
x = np.random.random(size=(100, 50, 50))
ppf_val = norm.ppf(x)
mem_vals[i] = process.memory_info().rss / 1e6
gc.collect()
plt.plot(mem_vals[1:] - mem_vals[0], label="norm.ppf")
mem_vals = np.zeros(n_elements)
mem_vals[0] = process.memory_info().rss / 1e6
for i in range(1, n_elements):
x = np.random.random(size=(100, 50, 50))
mem_vals[i] = process.memory_info().rss / 1e6
gc.collect()
plt.plot(mem_vals[1:] - mem_vals[0], label="control")
plt.xlabel("Iterations")
plt.ylabel("Memory usage (MB)")
plt.legend()
plt.savefig("norm_usage_tracking.png", dpi=200, bbox_inches="tight")
