Skip to content

Commit

Permalink
Merge pull request #54 from ggmarshall/main
Browse files Browse the repository at this point in the history
FFT factory funcs fix and bugfix for interpolated time point thresh
  • Loading branch information
iguinn authored Feb 20, 2024
2 parents e7fb840 + 293ef8c commit 00318ed
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 16 deletions.
15 changes: 3 additions & 12 deletions src/dspeed/processing_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,20 +1406,11 @@ def __init__(
)

# reshape just in case there are some missing dimensions
arshape = list(param.shape)
arshape = list(param.buffer.shape)
for idim in range(-1, -1 - len(shape), -1):
if (
len(arshape) < len(shape) + 1 + idim
or arshape[idim] != shape[idim]
):
if len(arshape) < -idim or arshape[idim] != shape[idim]:
arshape.insert(len(arshape) + idim + 1, 1)

if param.is_const:
param = param.get_buffer(grid).reshape(arshape)
else:
param = param.get_buffer(grid).reshape(
tuple([self.proc_chain._block_width] + arshape)
)
param = param.get_buffer(grid).reshape(arshape)

elif isinstance(param, str):
# Convert string into integer buffer if appropriate
Expand Down
6 changes: 3 additions & 3 deletions src/dspeed/processors/time_point_thresh.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,19 +157,19 @@ def interpolated_time_point_thresh(
for i in range(int(t_start), len(w_in) - 1, 1):
if w_in[i] <= a_threshold < w_in[i + 1]:
i_cross = i
return
break
else:
for i in range(int(t_start), 1, -1):
if w_in[i - 1] < a_threshold <= w_in[i]:
i_cross = i - 1
return
break

if i_cross == -1:
return

if mode_in == ord("i"): # return index before crossing
t_out[0] = i_cross
elif mode_in == ord("f"): # return index before crossing
elif mode_in == ord("f"): # return index after crossing
t_out[0] = i_cross + 1
elif mode_in == ord("c"): # return index before crossing
t_out[0] = i_cross
Expand Down
9 changes: 8 additions & 1 deletion tests/configs/icpc-dsp-config.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
"cuspEmax",
"zacEmax",
"zacEftp",
"cuspEftp"
"cuspEftp",
"wf_psd"
],
"processors": {
"tp_min, tp_max, wf_min, wf_max": {
Expand Down Expand Up @@ -357,6 +358,12 @@
"module": "numpy",
"args": ["tp_0_est", "tp_aoe_max/16", "tp_aoe_samp"],
"unit": "ns"
},
"wf_psd": {
"function": "psd",
"module": "dspeed.processors",
"args": ["wf_blsub", "wf_psd"],
"init_args": ["wf_blsub", "wf_psd"]
}
}
}
184 changes: 184 additions & 0 deletions tests/processors/test_time_point_thresh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
import numpy as np
import pytest

from dspeed.errors import DSPFatal
from dspeed.processors import interpolated_time_point_thresh, time_point_thresh


def test_time_point_thresh(compare_numba_vs_python):
"""Testing function for the time_point_thresh processor."""

# test for nan if w_in has a nan
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
w_in[4] = np.nan
assert np.isnan(
compare_numba_vs_python(
time_point_thresh,
w_in,
1,
11,
0,
)
)

# test for nan if nan is passed to a_threshold
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
assert np.isnan(
compare_numba_vs_python(
time_point_thresh,
w_in,
np.nan,
11,
0,
)
)

# test for nan if nan is passed to t_start
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
assert np.isnan(
compare_numba_vs_python(
time_point_thresh,
w_in,
1,
np.nan,
0,
)
)

# test for nan if nan is passed to walk_forward
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
assert np.isnan(
compare_numba_vs_python(
time_point_thresh,
w_in,
1,
11,
np.nan,
)
)

# test for error if t_start non integer
with pytest.raises(DSPFatal):
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
time_point_thresh(w_in, 1, 10.5, 0, np.array([0.0]))

# test for error if walk_forward non integer
with pytest.raises(DSPFatal):
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
time_point_thresh(w_in, 1, 11, 0.5, np.array([0.0]))

# test for error if t_start out of range
with pytest.raises(DSPFatal):
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
time_point_thresh(w_in, 1, 12, 0, np.array([0.0]))

# test walk backward
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
assert compare_numba_vs_python(time_point_thresh, w_in, 1, 11, 0) == 8.0

# test walk forward
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
assert compare_numba_vs_python(time_point_thresh, w_in, 3, 0, 1) == 4.0


def test_interpolated_time_point_thresh(compare_numba_vs_python):
"""Testing function for the interpolated_time_point_thresh processor."""

# test for nan if w_in has a nan
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
w_in[4] = np.nan
assert np.isnan(
compare_numba_vs_python(interpolated_time_point_thresh, w_in, 1.0, 11.0, 0, 105)
)

# test for nan if nan is passed to a_threshold
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
assert np.isnan(
compare_numba_vs_python(
interpolated_time_point_thresh, w_in, np.nan, 11.0, 0, 105
)
)

# test for nan if nan is passed to t_start
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
assert np.isnan(
compare_numba_vs_python(
interpolated_time_point_thresh, w_in, 1.0, np.nan, 0, 105
)
)

# test for nan if t_start out of range
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
assert np.isnan(
compare_numba_vs_python(interpolated_time_point_thresh, w_in, 1.0, 12, 0, 105)
)

# test walk backward mode 'i'
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
assert (
compare_numba_vs_python(interpolated_time_point_thresh, w_in, 1, 11, 0, 105)
== 7.0
)

# test walk forward mode 'i'
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
assert (
compare_numba_vs_python(interpolated_time_point_thresh, w_in, 3, 0, 1, 105)
== 4.0
)

# test walk backward mode 'f'
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
assert (
compare_numba_vs_python(interpolated_time_point_thresh, w_in, 1, 11, 0, 102)
== 8.0
)

# test walk forward mode 'f'
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
assert (
compare_numba_vs_python(interpolated_time_point_thresh, w_in, 3, 0, 1, 102)
== 5.0
)

# test walk backward mode 'f'
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
assert (
compare_numba_vs_python(interpolated_time_point_thresh, w_in, 1, 11, 0, 99)
== 7.0
)

# test walk forward mode 'f'
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
assert (
compare_numba_vs_python(interpolated_time_point_thresh, w_in, 3, 0, 1, 99)
== 4.0
)

# test walk backward mode 'n'
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
assert (
compare_numba_vs_python(interpolated_time_point_thresh, w_in, 1, 11, 0, 110)
== 7.5
)

# test walk forward mode 'n'
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
assert (
compare_numba_vs_python(interpolated_time_point_thresh, w_in, 3, 0, 1, 110)
== 4.5
)

# test walk backward mode 'l'
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
assert (
compare_numba_vs_python(interpolated_time_point_thresh, w_in, 1.5, 11, 0, 108)
== 8.5
)

# test walk forward mode 'l'
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
assert (
compare_numba_vs_python(interpolated_time_point_thresh, w_in, 3.5, 0, 1, 108)
== 4.5
)

0 comments on commit 00318ed

Please sign in to comment.