Skip to content

Commit

Permalink
refactor: generalize event calculation functions with improved structure
Browse files Browse the repository at this point in the history
  • Loading branch information
Beforerr committed Nov 9, 2024
1 parent e8d1f9c commit f0adc0b
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 33 deletions.
30 changes: 14 additions & 16 deletions notebooks/02_ids_properties.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -97,31 +97,29 @@
"outputs": [],
"source": [
"# | export\n",
"def calc_events_duration(df: pl.DataFrame, data, tr_cols=[\"tstart\", \"tstop\"], **kwargs):\n",
" # TODO: Add support for parallel processing\n",
" results = [\n",
" calc_duration(select_data_by_timerange(data, row[0], row[1]), **kwargs)\n",
" for row in df.select(tr_cols).iter_rows()\n",
" ]\n",
" return df.with_columns(**ld2dl(results)).drop_nulls()\n",
"\n",
"\n",
"def calc_events_mva_features(\n",
" df: pl.DataFrame,\n",
" data: xr.DataArray,\n",
" tr_cols=[\"t.d_start\", \"t.d_end\"],\n",
" **kwargs,\n",
"def calc_events_features(\n",
" df: pl.DataFrame, data, tr_cols=[\"tstart\", \"tstop\"], func=None, **kwargs\n",
"):\n",
" tranges = df.select(tr_cols).to_numpy()\n",
" data_ref = ray.put(data)\n",
"\n",
" @ray.remote\n",
" def remote(tr, **kwargs):\n",
" data = select_data_by_timerange(ray.get(data_ref), tr[0], tr[1])\n",
" return calc_mva_features_all(data, **kwargs)\n",
" return func(data, **kwargs)\n",
"\n",
" results = ray.get([remote.remote(tr, **kwargs) for tr in tranges])\n",
" return df.with_columns(**ld2dl(results))"
" return df.with_columns(**ld2dl(results))\n",
"\n",
"\n",
"def calc_events_duration(df, data, tr_cols=[\"tstart\", \"tstop\"], **kwargs):\n",
" return calc_events_features(\n",
" df, data, tr_cols, func=calc_duration, **kwargs\n",
" ).drop_nulls()\n",
"\n",
"\n",
"def calc_events_mva_features(df, data, tr_cols=[\"t.d_start\", \"t.d_end\"], **kwargs):\n",
" return calc_events_features(df, data, tr_cols, func=calc_mva_features_all, **kwargs)"
]
},
{
Expand Down
2 changes: 2 additions & 0 deletions src/discontinuitypy/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
'discontinuitypy/core/propeties.py'),
'discontinuitypy.core.propeties.calc_events_duration': ( 'ids_properties.html#calc_events_duration',
'discontinuitypy/core/propeties.py'),
'discontinuitypy.core.propeties.calc_events_features': ( 'ids_properties.html#calc_events_features',
'discontinuitypy/core/propeties.py'),
'discontinuitypy.core.propeties.calc_events_mva_features': ( 'ids_properties.html#calc_events_mva_features',
'discontinuitypy/core/propeties.py'),
'discontinuitypy.core.propeties.calc_events_normal_direction': ( 'ids_properties.html#calc_events_normal_direction',
Expand Down
32 changes: 15 additions & 17 deletions src/discontinuitypy/core/propeties.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

# %% auto 0
__all__ = ['get_data_at_times', 'select_data_by_timerange', 'get_candidate_data', 'calc_candidate_duration',
'calc_events_duration', 'calc_events_mva_features', 'calc_normal_direction', 'calc_events_normal_direction',
'calc_events_vec_change', 'process_events']
'calc_events_features', 'calc_events_duration', 'calc_events_mva_features', 'calc_normal_direction',
'calc_events_normal_direction', 'calc_events_vec_change', 'process_events']

# %% ../../../notebooks/02_ids_properties.ipynb 1
# | code-summary: "Import all the packages needed for the project"
Expand Down Expand Up @@ -50,32 +50,30 @@ def calc_candidate_duration(candidate, data, **kwargs):
return calc_duration(candidate_data, **kwargs)

# %% ../../../notebooks/02_ids_properties.ipynb 6
def calc_events_duration(df: pl.DataFrame, data, tr_cols=["tstart", "tstop"], **kwargs):
# TODO: Add support for parallel processing
results = [
calc_duration(select_data_by_timerange(data, row[0], row[1]), **kwargs)
for row in df.select(tr_cols).iter_rows()
]
return df.with_columns(**ld2dl(results)).drop_nulls()


def calc_events_mva_features(
df: pl.DataFrame,
data: xr.DataArray,
tr_cols=["t.d_start", "t.d_end"],
**kwargs,
def calc_events_features(
df: pl.DataFrame, data, tr_cols=["tstart", "tstop"], func=None, **kwargs
):
tranges = df.select(tr_cols).to_numpy()
data_ref = ray.put(data)

@ray.remote
def remote(tr, **kwargs):
data = select_data_by_timerange(ray.get(data_ref), tr[0], tr[1])
return calc_mva_features_all(data, **kwargs)
return func(data, **kwargs)

results = ray.get([remote.remote(tr, **kwargs) for tr in tranges])
return df.with_columns(**ld2dl(results))


def calc_events_duration(df, data, tr_cols=["tstart", "tstop"], **kwargs):
return calc_events_features(
df, data, tr_cols, func=calc_duration, **kwargs
).drop_nulls()


def calc_events_mva_features(df, data, tr_cols=["t.d_start", "t.d_end"], **kwargs):
return calc_events_features(df, data, tr_cols, func=calc_mva_features_all, **kwargs)

# %% ../../../notebooks/02_ids_properties.ipynb 8
def calc_normal_direction(v1, v2, normalize=True) -> np.ndarray:
"""
Expand Down

0 comments on commit f0adc0b

Please sign in to comment.