-
-
Notifications
You must be signed in to change notification settings - Fork 271
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add cadCAD_tools as a submodule (#320)
* init * add profile run & types * add visualizations * relative viz * update types * update types * add functions for cleaning initial state & params * assign params can now accept lists/sets * add cartesian sweep * add generic suf f * fix for running under cadcad 0.4.18 * minor changes * add support for custom exec_mode * Handle lambda parameters for when assign_parameters is True. * Handle functions as well as lambda. * move cadCAD_tools to cadCAD.tools * add modifications + nb --------- Co-authored-by: Shawn Anderson <[email protected]>
- Loading branch information
1 parent
676f0e0
commit ffe3b65
Showing
10 changed files
with
19,511 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from cadCAD.tools.execution import easy_run | ||
from cadCAD.tools.profiling import profile_run | ||
from cadCAD.tools.utils import generic_suf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from cadCAD.tools.execution.easy_run import easy_run |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import inspect | ||
import types | ||
from typing import Dict, Union | ||
|
||
import pandas as pd # type: ignore | ||
from cadCAD.configuration import Experiment | ||
from cadCAD.configuration.utils import config_sim | ||
from cadCAD.engine import ExecutionContext, ExecutionMode, Executor | ||
|
||
|
||
def describe_or_return(v: object) -> object: | ||
""" | ||
Thanks @LinuxIsCool! | ||
""" | ||
if isinstance(v, types.FunctionType): | ||
return f'function: {v.__name__}' | ||
elif isinstance(v, types.LambdaType) and v.__name__ == '<lambda>': | ||
return f'lambda: {inspect.signature(v)}' | ||
else: | ||
return v | ||
|
||
|
||
def select_M_dict(M_dict: Dict[str, object], keys: set) -> Dict[str, object]: | ||
""" | ||
Thanks @LinuxIsCool! | ||
""" | ||
return {k: describe_or_return(v) for k, v in M_dict.items() if k in keys} | ||
|
||
|
||
def select_config_M_dict(configs: list, i: int, keys: set) -> Dict[str, object]: | ||
return select_M_dict(configs[i].sim_config['M'], keys) | ||
|
||
|
||
def easy_run( | ||
state_variables, | ||
params, | ||
psubs, | ||
N_timesteps, | ||
N_samples, | ||
use_label=False, | ||
assign_params: Union[bool, set] = True, | ||
drop_substeps=True, | ||
exec_mode='local', | ||
) -> pd.DataFrame: | ||
""" | ||
Run cadCAD simulations without headaches. | ||
""" | ||
|
||
# Set-up sim_config | ||
simulation_parameters = {'N': N_samples, 'T': range(N_timesteps), 'M': params} | ||
sim_config = config_sim(simulation_parameters) # type: ignore | ||
|
||
# Create a new experiment | ||
exp = Experiment() | ||
exp.append_configs( | ||
sim_configs=sim_config, | ||
initial_state=state_variables, | ||
partial_state_update_blocks=psubs, | ||
) | ||
configs = exp.configs | ||
|
||
# Set-up cadCAD executor | ||
if exec_mode == 'local': | ||
_exec_mode = ExecutionMode().local_mode | ||
elif exec_mode == 'single': | ||
_exec_mode = ExecutionMode().single_mode | ||
exec_context = ExecutionContext(_exec_mode) | ||
executor = Executor(exec_context=exec_context, configs=configs) | ||
|
||
# Execute the cadCAD experiment | ||
(records, tensor_field, _) = executor.execute() | ||
|
||
# Parse the output as a pandas DataFrame | ||
df = pd.DataFrame(records) | ||
|
||
if drop_substeps == True: | ||
# Drop all intermediate substeps | ||
first_ind = (df.substep == 0) & (df.timestep == 0) | ||
last_ind = df.substep == max(df.substep) | ||
inds_to_drop = first_ind | last_ind | ||
df = df.loc[inds_to_drop].drop(columns=['substep']) | ||
else: | ||
pass | ||
|
||
if assign_params == False: | ||
pass | ||
else: | ||
M_dict = configs[0].sim_config['M'] | ||
params_set = set(M_dict.keys()) | ||
|
||
if assign_params == True: | ||
pass | ||
else: | ||
params_set &= assign_params # type: ignore | ||
|
||
# Logic for getting the assign params criteria | ||
if type(assign_params) is list: | ||
selected_params = set(assign_params) & params_set # type: ignore | ||
elif type(assign_params) is set: | ||
selected_params = assign_params & params_set | ||
else: | ||
selected_params = params_set | ||
|
||
# Attribute parameters to each row | ||
df = df.assign(**select_config_M_dict(configs, 0, selected_params)) | ||
for i, (_, n_df) in enumerate(df.groupby(['simulation', 'subset', 'run'])): | ||
df.loc[n_df.index] = n_df.assign( | ||
**select_config_M_dict(configs, i, selected_params) | ||
) | ||
|
||
# Based on Vitor Marthendal (@marthendalnunes) snippet | ||
if use_label == True: | ||
psub_map = { | ||
order + 1: psub.get('label', '') for (order, psub) in enumerate(psubs) | ||
} | ||
psub_map[0] = 'Initial State' | ||
df['substep_label'] = df.substep.map(psub_map) | ||
|
||
return df |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from typing import Mapping | ||
from cadCAD.tools.execution import easy_run | ||
from pandas import DataFrame # type: ignore | ||
from cadCAD.types import * | ||
from cadCAD.tools.types import * | ||
from itertools import product | ||
from dataclasses import dataclass | ||
|
||
|
||
def sweep_cartesian_product(sweep_params: SweepableParameters) -> SweepableParameters: | ||
""" | ||
Makes a cartesian product from dictionary values. | ||
This is useful for plugging inside the sys_params dict, like: | ||
```python | ||
sweep_params = {'a': [0.1, 0.2], 'b': [1, 2]} | ||
product_sweep | ||
sys_params = {**cartesian_product_sweep(sweep_params), | ||
'c': [0.1]} | ||
``` | ||
Usage: | ||
>>> sweep_params = {'a': [0.1, 0.2], 'b': [1, 2]} | ||
>>> cartesian_product_sweep(sweep_params) | ||
{'a': [0.1, 0.1, 0.2, 0.2], 'b': [1, 2, 1, 2]} | ||
""" | ||
cartesian_product = product(*sweep_params.values()) | ||
transpose_cartesian_product = zip(*cartesian_product) | ||
zipped_sweep_params = zip(sweep_params.keys(), transpose_cartesian_product) | ||
sweep_dict = dict(zipped_sweep_params) | ||
sweep_dict = {k: tuple(v) for k, v in sweep_dict.items()} | ||
return sweep_dict | ||
|
||
|
||
def prepare_params(params: SystemParameters, | ||
cartesian_sweep: bool = False) -> Mapping[str, List[object]]: | ||
simple_params = {k: [v.value] | ||
for k, v in params.items() | ||
if type(v) is Param} | ||
|
||
sweep_params: SweepableParameters = {k: v.value | ||
for k, v in params.items() | ||
if type(v) is ParamSweep} | ||
if cartesian_sweep is True: | ||
sweep_params = sweep_cartesian_product(sweep_params) | ||
else: | ||
pass | ||
|
||
cleaned_params = {**simple_params, **sweep_params} | ||
return cleaned_params | ||
|
||
|
||
def prepare_state(state: InitialState) -> Mapping[str, object]: | ||
cleaned_state = {k: v.value | ||
for k, v in state.items()} | ||
return cleaned_state | ||
|
||
|
||
@dataclass | ||
class ConfigurationWrapper(): | ||
initial_state: InitialState | ||
params: SystemParameters | ||
timestep_block: StateUpdateBlocks | ||
timesteps: int | ||
samples: int | ||
|
||
def run(self, *args, **kwargs) -> DataFrame: | ||
output = easy_run(prepare_state(self.initial_state), | ||
prepare_params(self.params), | ||
self.timestep_block, | ||
self.timesteps, | ||
self.samples, | ||
*args, | ||
**kwargs) | ||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from cadCAD.tools.profiling.profile_run import profile_run |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from typing import Dict | ||
from cadCAD.tools import easy_run | ||
from cadCAD.types import StateUpdateBlocks, Parameters, State, StateUpdateBlock | ||
from time import time | ||
import pandas as pd # type: ignore | ||
|
||
|
||
def MEASURE_TIME_SUF(p, s, h, v, p_i): return ('run_time', time()) | ||
|
||
|
||
MEASURING_BLOCK: StateUpdateBlock = { | ||
'label': 'Time Measure', | ||
'policies': {}, | ||
'variables': { | ||
'run_time': MEASURE_TIME_SUF | ||
} | ||
} # type: ignore | ||
|
||
|
||
def profile_psubs(psubs: StateUpdateBlocks, profile_substeps=True) -> StateUpdateBlocks: | ||
""" | ||
Updates a TimestepBlock so that a time measuring function is added. | ||
""" | ||
new_timestep_block: StateUpdateBlocks = [] | ||
new_timestep_block.append(MEASURING_BLOCK) | ||
if profile_substeps is True: | ||
for psub in psubs: | ||
new_timestep_block.append(psub) | ||
new_timestep_block.append(MEASURING_BLOCK) | ||
else: | ||
pass | ||
return new_timestep_block | ||
|
||
|
||
def profile_run(state_variables: State, | ||
params: Parameters, | ||
psubs: StateUpdateBlocks, | ||
*args, | ||
profile_substeps=True, | ||
**kwargs) -> pd.DataFrame: | ||
|
||
if profile_substeps is True: | ||
kwargs.update(drop_substeps=False) | ||
|
||
new_psubs = profile_psubs(psubs, profile_substeps) | ||
state_variables.update({'run_time': None}) | ||
|
||
return easy_run(state_variables, | ||
params, | ||
new_psubs, | ||
*args, | ||
**kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from tqdm.auto import tqdm | ||
import pandas as pd | ||
import plotly.express as px | ||
import numpy as np | ||
|
||
|
||
def visualize_elapsed_time_per_ts(df: pd.DataFrame, relative=False) -> None: | ||
indexes = ['simulation', 'run', 'timestep', 'substep'] | ||
|
||
z_df = df.set_index(indexes) | ||
first_time = z_df.query( | ||
'timestep == 1 & substep == 1').reset_index([-1, -2]).run_time | ||
s = (z_df.run_time - first_time) | ||
s.name = 'time_since_start' | ||
|
||
z_df = z_df.join(s) | ||
s = z_df.groupby(indexes[:-1]).time_since_start.max() | ||
|
||
fig_df = s.reset_index() | ||
if relative is True: | ||
s = fig_df.groupby(indexes[:-2]).time_since_start.diff() | ||
s.name = 'psub_duration' | ||
fig_df = fig_df.join(s) | ||
|
||
y_col = 'psub_duration' | ||
else: | ||
y_col = 'time_since_start' | ||
|
||
fig = px.box(fig_df, | ||
x='timestep', | ||
y=y_col) | ||
|
||
return fig | ||
|
||
|
||
def visualize_substep_impact(df: pd.DataFrame, relative=True, **kwargs) -> None: | ||
indexes = ['simulation', 'run', 'timestep', 'substep'] | ||
|
||
new_df = df.copy() | ||
new_df = new_df.assign(psub_time=np.nan).set_index(indexes) | ||
|
||
# Calculate the run time associated with PSUBs | ||
for ind, gg_df in tqdm(df.query('substep > 0').groupby(indexes[:-1])): | ||
g_df = gg_df.reset_index() | ||
N_rows = len(g_df) | ||
substep_rows = list(range(N_rows))[1:-1:2] | ||
|
||
for substep_row in substep_rows: | ||
t1 = g_df.run_time[substep_row - 1] | ||
t2 = g_df.run_time[substep_row + 1] | ||
dt = t2 - t1 | ||
g_df.loc[substep_row, 'psub_time'] = dt | ||
g_df = g_df.set_index(indexes) | ||
new_df.loc[g_df.index, 'psub_time'] = g_df.psub_time | ||
|
||
fig_df = new_df.reset_index().dropna(subset=['psub_time']) | ||
|
||
|
||
if 'substep_label' in fig_df.columns: | ||
x_col = 'substep_label' | ||
else: | ||
x_col = 'substep' | ||
fig_df[x_col] = fig_df[x_col] / 2 | ||
|
||
if relative is True: | ||
fig_df = fig_df.assign(relative_psub_time=fig_df.groupby(indexes[:-1]).psub_time.apply(lambda x: x / x.sum())) | ||
y_col = 'relative_psub_time' | ||
else: | ||
y_col = 'psub_time' | ||
|
||
inds = fig_df[y_col] < fig_df[y_col].quantile(0.95) | ||
inds &= fig_df[y_col] > fig_df[y_col].quantile(0.05) | ||
|
||
fig = px.box(fig_df[inds], | ||
x=x_col, | ||
y=y_col, | ||
**kwargs) | ||
|
||
return fig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from typing import NamedTuple, Tuple, Dict, Union, List | ||
|
||
class InitialValue(NamedTuple): | ||
value: object | ||
type: type | ||
|
||
|
||
class Param(NamedTuple): | ||
value: object | ||
type: type | ||
|
||
|
||
class ParamSweep(NamedTuple): | ||
value: List[object] | ||
type: type | ||
|
||
InitialState = Dict[str, InitialValue] | ||
SystemParameters = Dict[str, Union[Param, ParamSweep]] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from cadCAD.types import * | ||
|
||
def generic_suf(variable: str, | ||
signal: str='') -> StateUpdateFunction: | ||
""" | ||
Generate a State Update Function that assigns the signal value to the | ||
given variable. By default, the signal has the same identifier as the | ||
variable. | ||
""" | ||
if signal is '': | ||
signal = variable | ||
else: | ||
pass | ||
|
||
def suf(_1, _2, _3, _4, signals: PolicyOutput) -> StateUpdateTuple: | ||
return (variable, signals[signal]) | ||
return suf |
Oops, something went wrong.