From b300da4fc82ee24162af5b9b03c9c1fc14409ab8 Mon Sep 17 00:00:00 2001
From: Jason Chow
Date: Wed, 30 Oct 2024 17:51:45 -0700
Subject: [PATCH] pass transforms around instead of making duplicates (#416)

Summary:
Pull Request resolved: https://github.com/facebookresearch/aepsych/pull/416

Instead of creating duplicate transforms whenever we need one, we create a
single transform from the config and initialize the wrapped model and wrapped
generators with that one transform. Passing the same transform object around
allows the transforms to learn parameters while staying in sync across the
wrapped objects.

Differential Revision: D65155103
---
 aepsych/strategy.py              |  4 ++--
 aepsych/transforms/parameters.py | 19 +++++++++++++------
 tests/test_transforms.py         |  8 ++++++++
 3 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/aepsych/strategy.py b/aepsych/strategy.py
index 5e9495ae3..7c9bcab26 100644
--- a/aepsych/strategy.py
+++ b/aepsych/strategy.py
@@ -410,9 +410,9 @@ def from_config(cls, config: Config, name: str) -> Strategy:
         stimuli_per_trial = config.getint(name, "stimuli_per_trial", fallback=1)
         outcome_types = config.getlist(name, "outcome_types", element_type=str)
 
-        generator = GeneratorWrapper.from_config(name, config)
+        generator = GeneratorWrapper.from_config(name, config, transforms)
 
-        model = ModelWrapper.from_config(name, config)
+        model = ModelWrapper.from_config(name, config, transforms)
 
         acqf_cls = config.getobj(name, "acqf", fallback=None)
         if acqf_cls is not None and hasattr(generator, "acqf"):
diff --git a/aepsych/transforms/parameters.py b/aepsych/transforms/parameters.py
index 195437fe2..2b4184b3e 100644
--- a/aepsych/transforms/parameters.py
+++ b/aepsych/transforms/parameters.py
@@ -173,12 +173,14 @@ def from_config(
         cls,
         name: str,
         config: Config,
+        transforms: Optional[ChainedInputTransform] = None,
     ):
         gen_cls = config.getobj(name, "generator", fallback=SobolGenerator)
-        transforms = ParameterTransforms.from_config(config)
+        if transforms is None:
+            transforms = ParameterTransforms.from_config(config)
 
         # We need transformed values from config but we don't want to edit config
-        transformed_config = transform_options(config)
+        transformed_config = transform_options(config, transforms)
 
         gen = gen_cls.from_config(transformed_config)
 
@@ -298,27 +300,32 @@ def from_config(
         cls,
         name: str,
         config: Config,
+        transforms: Optional[ChainedInputTransform] = None,
     ):
         # We don't always have models
         model_cls = config.getobj(name, "model", fallback=None)
         if model_cls is None:
             return None
 
-        transforms = ParameterTransforms.from_config(config)
+        if transforms is None:
+            transforms = ParameterTransforms.from_config(config)
 
         # Need transformed values
-        transformed_config = transform_options(config)
+        transformed_config = transform_options(config, transforms)
 
         model = model_cls.from_config(transformed_config)
 
         return cls(model, transforms)
 
 
-def transform_options(config: Config) -> Config:
+def transform_options(
+    config: Config, transforms: Optional[ChainedInputTransform] = None
+) -> Config:
     """
     Return a copy of the config with the options transformed.
     The config
     """
-    transforms = ParameterTransforms.from_config(config)
+    if transforms is None:
+        transforms = ParameterTransforms.from_config(config)
 
     configClone = deepcopy(config)
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 3514ce62f..134190117 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -100,6 +100,14 @@ def test_model_init_equivalent(self):
 
     def test_transforms_in_strategy(self):
         for _strat in self.strat.strat_list:
+            # Check if the same transform is passed around everywhere
+            self.assertTrue(id(_strat.transforms) == id(_strat.generator.transforms))
+            if _strat.model is not None:
+                self.assertTrue(
+                    id(_strat.generator.transforms) == id(_strat.model.transforms)
+                )
+
+            # Check all the transform bits are the same
             for strat_transform, gen_transform in zip(
                 _strat.transforms.items(), _strat.generator.transforms.items()
             ):
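
A minimal sketch of the behavior this patch is after, using illustrative
stand-in classes (Transform and Wrapper below are not aepsych's
ParameterTransforms, ModelWrapper, or GeneratorWrapper): when every wrapper
holds a reference to one shared transform object instead of rebuilding its
own from the config, parameters the transform learns are visible to all
wrapped objects, which is what the new identity assertions in
tests/test_transforms.py verify.

# Sketch only: Transform and Wrapper are stand-ins, not aepsych classes.


class Transform:
    """Stand-in for a parameter transform that can learn state from data."""

    def __init__(self) -> None:
        self.learned_scale = 1.0  # placeholder for a fitted parameter

    def fit(self, scale: float) -> None:
        self.learned_scale = scale


class Wrapper:
    """Stand-in for a wrapped model/generator: holds a transform reference."""

    def __init__(self, transforms: Transform) -> None:
        self.transforms = transforms


# Old behavior: each wrapper built its own transform from the config,
# so fitting one copy left the other copy stale.
model_w = Wrapper(Transform())
gen_w = Wrapper(Transform())
model_w.transforms.fit(2.5)
assert gen_w.transforms.learned_scale == 1.0  # out of sync

# New behavior: build the transform once and hand the same object to both
# wrappers, so learned parameters stay synced across the wrapped objects.
shared = Transform()
model_w = Wrapper(shared)
gen_w = Wrapper(shared)
model_w.transforms.fit(2.5)
assert gen_w.transforms is model_w.transforms  # same object, as the test checks via id()
assert gen_w.transforms.learned_scale == 2.5  # now in sync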