From e0b6c77f86d1a070c041ba362c60adbd8e1d4dfc Mon Sep 17 00:00:00 2001 From: Jason Chow Date: Fri, 15 Nov 2024 12:03:31 -0800 Subject: [PATCH] add support for fixed parameters (#457) Summary: Pull Request resolved: https://github.com/facebookresearch/aepsych/pull/457 Add support for fixed parameters. These parameters are just set in the config and will be removed before any model or generator is aware of them (then added back in whenever a model or generator is asked for an output. Differential Revision: D66012863 --- aepsych/config.py | 6 ++ aepsych/transforms/ops/__init__.py | 3 +- aepsych/transforms/ops/fixed.py | 121 +++++++++++++++++++++++++++++ aepsych/transforms/parameters.py | 36 ++++++++- docs/parameters.md | 13 ++++ tests/test_transforms.py | 115 ++++++++++++++++++++++++++- 6 files changed, 288 insertions(+), 6 deletions(-) create mode 100644 aepsych/transforms/ops/fixed.py diff --git a/aepsych/config.py b/aepsych/config.py index 46be218d1..d5328e7b6 100644 --- a/aepsych/config.py +++ b/aepsych/config.py @@ -291,6 +291,12 @@ def _check_param_settings(self, param_name: str) -> None: f"Parameter {param_name} is binary and shouldn't have bounds." ) + elif param_block["par_type"] == "fixed": + if "value" not in param_block: + raise ParameterConfigError( + f"Parameter {param_name} is fixed and needs to have value set." + ) + else: raise ParameterConfigError( f"Parameter {param_name} has an unsupported parameter type {param_block['par_type']}." diff --git a/aepsych/transforms/ops/__init__.py b/aepsych/transforms/ops/__init__.py index bdf10602b..8a6034fcb 100644 --- a/aepsych/transforms/ops/__init__.py +++ b/aepsych/transforms/ops/__init__.py @@ -5,8 +5,9 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +from .fixed import Fixed from .log10_plus import Log10Plus from .normalize_scale import NormalizeScale from .round import Round -__all__ = ["Log10Plus", "NormalizeScale", "Round"] +__all__ = ["Log10Plus", "NormalizeScale", "Round", "Fixed"] diff --git a/aepsych/transforms/ops/fixed.py b/aepsych/transforms/ops/fixed.py new file mode 100644 index 000000000..159d7d91f --- /dev/null +++ b/aepsych/transforms/ops/fixed.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +from typing import Any, Dict, List, Optional, Union + +import torch +from aepsych.config import Config +from aepsych.transforms.ops.base import Transform + + +class Fixed(Transform, torch.nn.Module): + def __init__( + self, + indices: List[int], + values: List[Union[float, int]], + transform_on_train: bool = True, + transform_on_eval: bool = True, + transform_on_fantasize: bool = True, + reverse: bool = False, + **kwargs, + ) -> None: + """Initialize a fixed transform. It will add and remove fixed values from + tensors. + + Args: + indices (List[int]): The indices of the parameters to be fixed. + values (List[Union[float, int]]): The values of the fixed parameters. + transform_on_train (bool): A boolean indicating whether to apply the + transforms in train() mode. Default: True. + transform_on_eval (bool): A boolean indicating whether to apply the + transform in eval() mode. Default: True. + transform_on_fantasize (bool): A boolean indicating whether to apply the + transform when called from within a `fantasize` call. Default: True. + reverse (bool): A boolean indicating whether the forward pass should + untransform the inputs. Default: False. + **kwargs: Accepted to conform to API. + """ + # Turn indices and values into tensors and sort + indices_ = torch.tensor(indices, dtype=torch.long) + values_ = torch.tensor(values, dtype=torch.float64) + + # Sort indices and values + sort_idx = torch.argsort(indices_) + indices_ = indices_[sort_idx] + values_ = values_[sort_idx] + + super().__init__() + self.register_buffer("indices", indices_) + self.register_buffer("values", values_) + self.transform_on_train = transform_on_train + self.transform_on_eval = transform_on_eval + self.transform_on_fantasize = transform_on_fantasize + self.reverse = reverse + + def _transform(self, X: torch.Tensor) -> torch.Tensor: + r"""Transform the input Tensor by popping out the fixed parameters at the + specified indices. + + Args: + X (torch.Tensor): A `batch_shape x n x d`-dim tensor of inputs. + + Returns: + torch.Tensor: The input tensor with fixed parameters removed. + """ + X = X.clone() + + mask = ~torch.isin(torch.arange(X.shape[1]), self.indices) + + X = X[:, mask] + + return X + + def _untransform(self, X: torch.Tensor) -> torch.Tensor: + r"""Transform the input tensor by adding back in the fixed parameters at the + specified indices. + + Args: + X (torch.Tensor): A `batch_shape x n x d`-dim tensor of transformed inputs. + + Returns: + torch.Tensor: The same tensor as the input with the fixed parameters added + back in. + """ + X = X.clone() + + for i, idx in enumerate(self.indices): + pre_fixed = X[:, :idx] + post_fixed = X[:, idx:] + fixed = torch.tile(self.values[i], (X.shape[0], 1)) + X = torch.cat((pre_fixed, fixed, post_fixed), dim=1) + + return X + + @classmethod + def get_config_options( + cls, + config: Config, + name: Optional[str] = None, + options: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Return a dictionary of the relevant options to initialize a Fixed parameter + transform for the named parameter within the config. + + Args: + config (Config): Config to look for options in. + name (str, optional): Parameter to find options for. + options (Dict[str, Any], optional): Options to override from the config. + + Returns: + Dict[str, Any]: A dictionary of options to initialize this class with, + including the transformed bounds. + """ + options = super().get_config_options(config=config, name=name, options=options) + + if "values" not in options: + options["values"] = [config.getfloat(name, "value")] + + return options diff --git a/aepsych/transforms/parameters.py b/aepsych/transforms/parameters.py index 29e6648c1..5f09915b5 100644 --- a/aepsych/transforms/parameters.py +++ b/aepsych/transforms/parameters.py @@ -26,7 +26,8 @@ from aepsych.config import Config, ConfigurableMixin from aepsych.generators.base import AEPsychGenerator from aepsych.models.base import AEPsychMixin, ModelProtocol -from aepsych.transforms.ops import Log10Plus, NormalizeScale, Round +from aepsych.transforms.ops import Fixed, Log10Plus, NormalizeScale, Round +from aepsych.transforms.ops.base import Transform from aepsych.utils import get_bounds from botorch.acquisition import AcquisitionFunction from botorch.models.transforms.input import ChainedInputTransform @@ -49,6 +50,27 @@ class ParameterTransforms(ChainedInputTransform, ConfigurableMixin): space back into raw space. """ + def __init__( + self, + **transforms: Transform, + ) -> None: + fixed_values = [] + fixed_indices = [] + transform_keys = list(transforms.keys()) + for key in transform_keys: + if isinstance(transforms[key], Fixed): + transform = transforms.pop(key) + fixed_values += transform.values.tolist() + fixed_indices += transform.indices.tolist() + + if len(fixed_values) > 0: + # Combine Fixed parameters + transforms["_CombinedFixed"] = Fixed( + indices=fixed_indices, values=fixed_values + ) + + super().__init__(**transforms) + def _temporary_reshape(func: Callable) -> Callable: # Decorator to reshape tensors to the expected 2D shape, even if the input was # 1D or 3D and after the transform reshape it back to the original. @@ -183,6 +205,14 @@ def get_config_options( ) transform_dict[f"{par}_Round"] = round + if par_type == "fixed": + fixed = Fixed.from_config( + config=config, name=par, options=transform_options + ) + + # We don't mess with bounds since we don't want to modify indices + transform_dict[f"{par}_Fixed"] = fixed + # Log scale if config.getboolean(par, "log_scale", fallback=False): log10 = Log10Plus.from_config( @@ -198,7 +228,7 @@ def get_config_options( # Normalize scale (defaults true) if config.getboolean( par, "normalize_scale", fallback=True - ) and par_type not in ["discrete", "binary"]: + ) and par_type not in ["binary", "fixed"]: normalize = NormalizeScale.from_config( config=config, name=par, options=transform_options ) @@ -391,7 +421,7 @@ def __init__( transforms: ChainedInputTransform = ChainedInputTransform(**{}), **kwargs: Any, ) -> None: - f"""Wraps a Model with parameter transforms. This will transform any relevant + """Wraps a Model with parameter transforms. This will transform any relevant model arguments (e.g., bounds) and model data (e.g., training data, x) to be transformed into the transformed space. The wrapper surfaces the API of the raw model such that the wrapper can be used like a raw model. diff --git a/docs/parameters.md b/docs/parameters.md index 9bdd0a539..4015309f4 100644 --- a/docs/parameters.md +++ b/docs/parameters.md @@ -61,6 +61,19 @@ lower_bound = 0 upper_bound = 1 ``` +

Fixed

+ +```ini +[parameter] +par_type = fixed +value = 4.5 +``` + +Fixed parameters will never be passed to the model or the generators but will be removed +and added to tells and asks, respectively. Fixed parameters when running multiple +conditions with certain parameters fixed; instead of removing the parameter entirely, +a parameter can be set to fixed at a certain value in certain configs. +

Parameter Transformations

Currently, we only support a log scale transformation to parameters. More parameter transformations to come! In general, you can define your parameters in the raw diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 0bcaa71ad..248af7a79 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -17,7 +17,7 @@ ParameterTransformedModel, ParameterTransforms, ) -from aepsych.transforms.ops import Log10Plus, NormalizeScale, Round +from aepsych.transforms.ops import Fixed, Log10Plus, NormalizeScale, Round class TransformsConfigTest(unittest.TestCase): @@ -362,7 +362,7 @@ def test_normalize_scale(self): self.assertTrue(torch.allclose(transforms.untransform(transformed), values)) -class TransformInteger(unittest.TestCase): +class TransformsInteger(unittest.TestCase): def test_integer_bounds(self): config_str = """ [common] @@ -525,3 +525,114 @@ def test_binary(self): with self.assertRaises(ParameterConfigError): config.update(config_str=bad_config_str) + + +class TransformsFixed(unittest.TestCase): + def test_fixed_from_config(self): + np.random.seed(1) + torch.manual_seed(1) + + config_str = """ + [common] + parnames = [signal1, signal2, signal3] + stimuli_per_trial = 1 + outcome_types = [binary] + strategy_names = [init_strat, opt_strat] + + [signal1] + par_type = binary + + [signal2] + par_type = fixed + value = 4.5 + + [signal3] + par_type = continuous + lower_bound = 1 + upper_bound = 100 + log_scale = True + + [init_strat] + generator = SobolGenerator + min_asks = 1 + + [opt_strat] + generator = OptimizeAcqfGenerator + acqf = MCLevelSetEstimation + model = GPClassificationModel + min_asks = 1 + """ + config = Config() + config.update(config_str=config_str) + + strat = SequentialStrategy.from_config(config) + + while not strat.finished: + points = strat.gen() + self.assertTrue(points[0][1].item() == 4.5) + strat.add_data(points, int(np.random.rand() > 0.5)) + + self.assertTrue(len(strat.strat_list[0].generator.lb) == 2) + self.assertTrue(len(strat.strat_list[0].generator.ub) == 2) + + bad_config_str = """ + [common] + parnames = [signal1, signal2, signal3] + stimuli_per_trial = 1 + outcome_types = [binary] + strategy_names = [init_strat, opt_strat] + + [signal1] + par_type = binary + + [signal2] + par_type = fixed + + [signal3] + par_type = continuous + lower_bound = 1 + upper_bound = 100 + log_scale = True + + [init_strat] + generator = SobolGenerator + min_asks = 1 + + [opt_strat] + generator = OptimizeAcqfGenerator + acqf = MCLevelSetEstimation + model = GPClassificationModel + min_asks = 1 + """ + config = Config() + with self.assertRaises(ParameterConfigError): + config.update(config_str=bad_config_str) + + def test_fixed_standalone(self): + fixed1 = Fixed([3], values=[0.3]) + fixed2 = Fixed([1, 2], values=[0.1, 0.2]) + + transforms = ParameterTransforms(fixed1=fixed1, fixed2=fixed2) + + self.assertTrue(len(transforms) == 1) + self.assertTrue( + torch.all(transforms["_CombinedFixed"].indices == torch.tensor([1, 2, 3])) + ) + self.assertTrue( + torch.all( + transforms["_CombinedFixed"].values == torch.tensor([0.1, 0.2, 0.3]) + ) + ) + + input = torch.tensor([[1, 100, 100, 100, 1], [2, 100, 100, 100, 2]]) + transformed = transforms.transform(input) + untransformed = transforms.untransform(transformed) + + self.assertTrue(transformed.shape[0] == 2) + self.assertTrue(torch.all(transformed[:, 0] == torch.tensor([1, 2]))) + self.assertTrue( + torch.all(torch.tensor([1, 0.1, 0.2, 0.3, 1]) == untransformed[0]) + ) + self.assertTrue( + torch.all(torch.tensor([2, 0.1, 0.2, 0.3, 2]) == untransformed[1]) + )