pass transforms around instead of making duplicates (#416)
Summary:

Instead of creating a duplicate transform whenever one is needed, we create a single transform from the config and initialize the wrapped model and wrapped generators with that one transform. Passing the same transform object around lets the transformations learn parameters while staying in sync across the wrapped objects.
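A minimal sketch of the resulting pattern, based on the calls visible in the diff below (the import paths and the build_wrapped_objects helper are assumptions for illustration; config is a loaded Config and name is a strategy section name):

    from aepsych.config import Config
    from aepsych.transforms import (
        ParameterTransformedGenerator,
        ParameterTransformedModel,
        ParameterTransforms,
    )

    def build_wrapped_objects(config: Config, name: str):
        # Build the transform once from the config...
        transforms = ParameterTransforms.from_config(config)

        # ...then hand the same instance to every wrapped object, so any
        # parameters the transforms learn stay in sync across all of them.
        generator = ParameterTransformedGenerator.from_config(
            config, name, options={"transforms": transforms}
        )
        model = ParameterTransformedModel.from_config(
            config, name, options={"transforms": transforms}
        )
        return transforms, generator, model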

Differential Revision: D65155103
JasonKChow authored and facebook-github-bot committed Nov 8, 2024
1 parent 3b6bae0 commit 89f8bce
Showing 3 changed files with 34 additions and 8 deletions.
8 changes: 6 additions & 2 deletions aepsych/strategy.py
@@ -468,11 +468,15 @@ def from_config(cls, config: Config, name: str) -> Strategy:
stimuli_per_trial = config.getint(name, "stimuli_per_trial", fallback=1)
outcome_types = config.getlist(name, "outcome_types", element_type=str)

generator = ParameterTransformedGenerator.from_config(config, name)
generator = ParameterTransformedGenerator.from_config(
config, name, options={"transforms": transforms}
)

model_cls = config.getobj(name, "model", fallback=None)
if model_cls is not None:
model = ParameterTransformedModel.from_config(config, name)
model = ParameterTransformedModel.from_config(
config, name, options={"transforms": transforms}
)
use_gpu_modeling = config.getboolean(
model._base_obj.__class__.__name__, "use_gpu", fallback=False
)
26 changes: 20 additions & 6 deletions aepsych/transforms/parameters.py
@@ -183,6 +183,8 @@ def __init__(
being set in init (e.g., `Wrapper(Type, lb=lb, ub=ub)`), `lb` and `ub`
should be in the raw parameter space.
The object's name will be ParameterTransformed<Generator.__name__>.
Args:
model (Type | AEPsychGenerator): Generator to wrap, this could either be a
completely initialized generator or just the generator class. An
@@ -277,17 +279,22 @@ def get_config_options(
"""
if options is None:
options = {}
options["transforms"] = ParameterTransforms.from_config(config)
else:
# Check if there's a transform already; if so, save it so it persists over copying
if "transforms" in options:
transforms = options["transforms"]
else:
transforms = ParameterTransforms.from_config(config)

options = deepcopy(options)
options["transforms"] = transforms

if name is None:
raise ValueError("name of strategy must be set to initialize a generator")
else:
gen_cls = config.getobj(name, "generator")

if "transforms" not in options:
options["transforms"] = ParameterTransforms.from_config(config)

# Transform config
transformed_config = transform_options(config, options["transforms"])

@@ -323,6 +330,8 @@ def __init__(
being set in init (e.g., `Wrapper(Type, lb=lb, ub=ub)`), `lb` and `ub`
should be in the raw parameter space.
The object's name will be ParameterTransformed<Model.__name__>.
Args:
model (Type | ModelProtocol): Model to wrap, this could either be a
completely initialized model or just the model class. An initialized
@@ -669,17 +678,22 @@ def get_config_options(
"""
if options is None:
options = {}
options["transforms"] = ParameterTransforms.from_config(config)
else:
# Check if there's a transform already; if so, save it so it persists over copying
if "transforms" in options:
transforms = options["transforms"]
else:
transforms = ParameterTransforms.from_config(config)

options = deepcopy(options)
options["transforms"] = transforms

if name is None:
raise ValueError("name of strategy must be set to initialize a model")
else:
model_cls = config.getobj(name, "model")

if "transforms" not in options:
options["transforms"] = ParameterTransforms.from_config(config)

# Transform config
transformed_config = transform_options(config, options["transforms"])

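The comment about persisting over copying is the crux of both get_config_options hunks: a bare deepcopy(options) would clone the transform and silently break the shared-object guarantee. A standalone sketch of that ordering (DummyTransform is a hypothetical stand-in for a ParameterTransforms instance):

    import copy

    class DummyTransform:
        """Hypothetical stand-in for a ParameterTransforms instance."""

    transforms = DummyTransform()
    options = {"transforms": transforms, "seed": 0}

    # A bare deepcopy duplicates the transform and breaks object identity:
    assert copy.deepcopy(options)["transforms"] is not transforms

    # get_config_options instead saves the reference, deep-copies the rest
    # of the options, and then restores the original transform object:
    saved = options["transforms"]
    options = copy.deepcopy(options)
    options["transforms"] = saved
    assert options["transforms"] is saved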
8 changes: 8 additions & 0 deletions tests/test_transforms.py
@@ -105,6 +105,14 @@ def test_model_init_equivalent(self):

def test_transforms_in_strategy(self):
for _strat in self.strat.strat_list:
# Check if the same transform is passed around everywhere
self.assertTrue(id(_strat.transforms) == id(_strat.generator.transforms))
if _strat.model is not None:
self.assertTrue(
id(_strat.generator.transforms) == id(_strat.model.transforms)
)

# Check all the transform bits are the same
for strat_transform, gen_transform in zip(
_strat.transforms.items(), _strat.generator.transforms.items()
):
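For intuition, a toy analogue (not AEPsych code) of why the test checks identity rather than equality: a shared transform sees learned updates everywhere, while a duplicate silently keeps stale state.

    import copy

    class ToyTransform:
        """Hypothetical stand-in for a learnable parameter transform."""

        def __init__(self) -> None:
            self.scale = 1.0

    shared = ToyTransform()
    strat_view = gen_view = shared     # one object passed around, as in this commit
    duplicate = copy.deepcopy(shared)  # what creating duplicate transforms did

    strat_view.scale = 2.0             # the transform "learns" a parameter
    assert gen_view.scale == 2.0       # shared views stay in sync
    assert duplicate.scale == 1.0      # the duplicate drifts out of sync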
