add support for fixed parameters (facebookresearch#457)
Summary:
Pull Request resolved: facebookresearch#457

Add support for fixed parameters. These parameters are set in the config and are removed before any model or generator is aware of them (then added back in whenever a model or generator is asked for an output).

Differential Revision: D66012863
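A minimal sketch of the behavior (condensed from the unit tests added in this diff; variable names are illustrative): a Fixed transform strips fixed values out of input tensors and re-inserts them on untransform, and multiple Fixed ops passed to ParameterTransforms are merged into a single combined transform internally.

```python
import torch

from aepsych.transforms import ParameterTransforms
from aepsych.transforms.ops import Fixed

# Fix column 3 at 0.3 and columns 1 and 2 at 0.1 and 0.2; ParameterTransforms
# merges these into one combined Fixed transform under the hood.
transforms = ParameterTransforms(
    fixed1=Fixed([3], values=[0.3]),
    fixed2=Fixed([1, 2], values=[0.1, 0.2]),
)

x = torch.tensor([[1.0, 100.0, 100.0, 100.0, 1.0]])
stripped = transforms.transform(x)           # fixed columns removed -> shape (1, 2)
restored = transforms.untransform(stripped)  # fixed values re-inserted -> shape (1, 5)
# restored == tensor([[1.0, 0.1, 0.2, 0.3, 1.0]])
```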
JasonKChow authored and facebook-github-bot committed Nov 15, 2024
1 parent cb9f00c commit e0b6c77
Showing 6 changed files with 288 additions and 6 deletions.
6 changes: 6 additions & 0 deletions aepsych/config.py
@@ -291,6 +291,12 @@ def _check_param_settings(self, param_name: str) -> None:
f"Parameter {param_name} is binary and shouldn't have bounds."
)

elif param_block["par_type"] == "fixed":
if "value" not in param_block:
raise ParameterConfigError(
f"Parameter {param_name} is fixed and needs to have value set."
)

else:
raise ParameterConfigError(
f"Parameter {param_name} has an unsupported parameter type {param_block['par_type']}."
3 changes: 2 additions & 1 deletion aepsych/transforms/ops/__init__.py
@@ -5,8 +5,9 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .fixed import Fixed
from .log10_plus import Log10Plus
from .normalize_scale import NormalizeScale
from .round import Round

__all__ = ["Log10Plus", "NormalizeScale", "Round"]
__all__ = ["Log10Plus", "NormalizeScale", "Round", "Fixed"]
121 changes: 121 additions & 0 deletions aepsych/transforms/ops/fixed.py
@@ -0,0 +1,121 @@
#!/usr/bin/env python3
# Copyright (c) Meta, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List, Optional, Union

import torch
from aepsych.config import Config
from aepsych.transforms.ops.base import Transform


class Fixed(Transform, torch.nn.Module):
def __init__(
self,
indices: List[int],
values: List[Union[float, int]],
transform_on_train: bool = True,
transform_on_eval: bool = True,
transform_on_fantasize: bool = True,
reverse: bool = False,
**kwargs,
) -> None:
"""Initialize a fixed transform. It will add and remove fixed values from
tensors.
Args:
indices (List[int]): The indices of the parameters to be fixed.
values (List[Union[float, int]]): The values of the fixed parameters.
transform_on_train (bool): A boolean indicating whether to apply the
transforms in train() mode. Default: True.
transform_on_eval (bool): A boolean indicating whether to apply the
transform in eval() mode. Default: True.
transform_on_fantasize (bool): A boolean indicating whether to apply the
transform when called from within a `fantasize` call. Default: True.
reverse (bool): A boolean indicating whether the forward pass should
untransform the inputs. Default: False.
**kwargs: Accepted to conform to API.
"""
# Turn indices and values into tensors and sort
indices_ = torch.tensor(indices, dtype=torch.long)
values_ = torch.tensor(values, dtype=torch.float64)

# Sort indices and values
sort_idx = torch.argsort(indices_)
indices_ = indices_[sort_idx]
values_ = values_[sort_idx]

super().__init__()
self.register_buffer("indices", indices_)
self.register_buffer("values", values_)
self.transform_on_train = transform_on_train
self.transform_on_eval = transform_on_eval
self.transform_on_fantasize = transform_on_fantasize
self.reverse = reverse

def _transform(self, X: torch.Tensor) -> torch.Tensor:
r"""Transform the input Tensor by popping out the fixed parameters at the
specified indices.
Args:
X (torch.Tensor): A `batch_shape x n x d`-dim tensor of inputs.
Returns:
torch.Tensor: The input tensor with fixed parameters removed.
"""
X = X.clone()

mask = ~torch.isin(torch.arange(X.shape[1]), self.indices)

X = X[:, mask]

return X

def _untransform(self, X: torch.Tensor) -> torch.Tensor:
r"""Transform the input tensor by adding back in the fixed parameters at the
specified indices.
Args:
X (torch.Tensor): A `batch_shape x n x d`-dim tensor of transformed inputs.
Returns:
torch.Tensor: The same tensor as the input with the fixed parameters added
back in.
"""
X = X.clone()

for i, idx in enumerate(self.indices):
pre_fixed = X[:, :idx]
post_fixed = X[:, idx:]
fixed = torch.tile(self.values[i], (X.shape[0], 1))
X = torch.cat((pre_fixed, fixed, post_fixed), dim=1)

return X

@classmethod
def get_config_options(
cls,
config: Config,
name: Optional[str] = None,
options: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Return a dictionary of the relevant options to initialize a Fixed parameter
transform for the named parameter within the config.
Args:
config (Config): Config to look for options in.
name (str, optional): Parameter to find options for.
options (Dict[str, Any], optional): Options to override from the config.
Returns:
Dict[str, Any]: A dictionary of options to initialize this class with,
including the transformed bounds.
"""
options = super().get_config_options(config=config, name=name, options=options)

if "values" not in options:
options["values"] = [config.getfloat(name, "value")]

return options
36 changes: 33 additions & 3 deletions aepsych/transforms/parameters.py
@@ -26,7 +26,8 @@
from aepsych.config import Config, ConfigurableMixin
from aepsych.generators.base import AEPsychGenerator
from aepsych.models.base import AEPsychMixin, ModelProtocol
from aepsych.transforms.ops import Log10Plus, NormalizeScale, Round
from aepsych.transforms.ops import Fixed, Log10Plus, NormalizeScale, Round
from aepsych.transforms.ops.base import Transform
from aepsych.utils import get_bounds
from botorch.acquisition import AcquisitionFunction
from botorch.models.transforms.input import ChainedInputTransform
@@ -49,6 +50,27 @@ class ParameterTransforms(ChainedInputTransform, ConfigurableMixin):
space back into raw space.
"""

def __init__(
self,
**transforms: Transform,
) -> None:
fixed_values = []
fixed_indices = []
transform_keys = list(transforms.keys())
for key in transform_keys:
if isinstance(transforms[key], Fixed):
transform = transforms.pop(key)
fixed_values += transform.values.tolist()
fixed_indices += transform.indices.tolist()

if len(fixed_values) > 0:
# Combine Fixed parameters
transforms["_CombinedFixed"] = Fixed(
indices=fixed_indices, values=fixed_values
)

super().__init__(**transforms)

def _temporary_reshape(func: Callable) -> Callable:
# Decorator to reshape tensors to the expected 2D shape, even if the input was
# 1D or 3D and after the transform reshape it back to the original.
@@ -183,6 +205,14 @@ def get_config_options(
)
transform_dict[f"{par}_Round"] = round

if par_type == "fixed":
fixed = Fixed.from_config(
config=config, name=par, options=transform_options
)

# We don't mess with bounds since we don't want to modify indices
transform_dict[f"{par}_Fixed"] = fixed

# Log scale
if config.getboolean(par, "log_scale", fallback=False):
log10 = Log10Plus.from_config(
@@ -198,7 +228,7 @@
# Normalize scale (defaults true)
if config.getboolean(
par, "normalize_scale", fallback=True
) and par_type not in ["discrete", "binary"]:
) and par_type not in ["binary", "fixed"]:
normalize = NormalizeScale.from_config(
config=config, name=par, options=transform_options
)
@@ -391,7 +421,7 @@ def __init__(
transforms: ChainedInputTransform = ChainedInputTransform(**{}),
**kwargs: Any,
) -> None:
f"""Wraps a Model with parameter transforms. This will transform any relevant
"""Wraps a Model with parameter transforms. This will transform any relevant
model arguments (e.g., bounds) and model data (e.g., training data, x) to be
transformed into the transformed space. The wrapper surfaces the API of the
raw model such that the wrapper can be used like a raw model.
13 changes: 13 additions & 0 deletions docs/parameters.md
@@ -61,6 +61,19 @@ lower_bound = 0
upper_bound = 1
```

<h3>Fixed</h3>

```ini
[parameter]
par_type = fixed
value = 4.5
```

Fixed parameters are never passed to the model or the generators; they are removed
from tells and added back into asks. Fixed parameters are useful when running multiple
conditions with certain parameters held constant: instead of removing the parameter
entirely, it can be fixed at a specific value in specific configs.
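As a rough sketch of how this plays out end to end (the config and assertions below are illustrative, adapted from this feature's unit tests rather than a canonical recipe), a fixed parameter comes back in asks at its configured value even though generators and models only operate over the remaining dimensions:

```python
import numpy as np

from aepsych.config import Config
from aepsych.strategy import SequentialStrategy

# Illustrative config: one continuous parameter plus one fixed parameter.
config_str = """
[common]
parnames = [signal1, signal2]
stimuli_per_trial = 1
outcome_types = [binary]
strategy_names = [init_strat]

[signal1]
par_type = continuous
lower_bound = 1
upper_bound = 100

[signal2]
par_type = fixed
value = 4.5

[init_strat]
generator = SobolGenerator
min_asks = 1
"""

config = Config()
config.update(config_str=config_str)
strat = SequentialStrategy.from_config(config)

points = strat.gen()  # asks still contain signal2, pinned at 4.5
strat.add_data(points, int(np.random.rand() > 0.5))  # tells accept the full point

# The underlying generator only sees the single non-fixed dimension.
assert len(strat.strat_list[0].generator.lb) == 1
```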

<h2>Parameter Transformations</h2>
Currently, we only support a log scale transformation to parameters. More parameter
transformations to come! In general, you can define your parameters in the raw
115 changes: 113 additions & 2 deletions tests/test_transforms.py
@@ -17,7 +17,7 @@
ParameterTransformedModel,
ParameterTransforms,
)
from aepsych.transforms.ops import Log10Plus, NormalizeScale, Round
from aepsych.transforms.ops import Fixed, Log10Plus, NormalizeScale, Round


class TransformsConfigTest(unittest.TestCase):
@@ -362,7 +362,7 @@ def test_normalize_scale(self):
self.assertTrue(torch.allclose(transforms.untransform(transformed), values))


class TransformInteger(unittest.TestCase):
class TransformsInteger(unittest.TestCase):
def test_integer_bounds(self):
config_str = """
[common]
@@ -525,3 +525,114 @@ def test_binary(self):

with self.assertRaises(ParameterConfigError):
config.update(config_str=bad_config_str)


class TransformsFixed(unittest.TestCase):
def test_fixed_from_config(self):
np.random.seed(1)
torch.manual_seed(1)

config_str = """
[common]
parnames = [signal1, signal2, signal3]
stimuli_per_trial = 1
outcome_types = [binary]
strategy_names = [init_strat, opt_strat]
[signal1]
par_type = binary
[signal2]
par_type = fixed
value = 4.5
[signal3]
par_type = continuous
lower_bound = 1
upper_bound = 100
log_scale = True
[init_strat]
generator = SobolGenerator
min_asks = 1
[opt_strat]
generator = OptimizeAcqfGenerator
acqf = MCLevelSetEstimation
model = GPClassificationModel
min_asks = 1
"""
config = Config()
config.update(config_str=config_str)

strat = SequentialStrategy.from_config(config)

while not strat.finished:
points = strat.gen()
self.assertTrue(points[0][1].item() == 4.5)
strat.add_data(points, int(np.random.rand() > 0.5))

self.assertTrue(len(strat.strat_list[0].generator.lb) == 2)
self.assertTrue(len(strat.strat_list[0].generator.ub) == 2)

bad_config_str = """
[common]
parnames = [signal1, signal2, signal3]
stimuli_per_trial = 1
outcome_types = [binary]
strategy_names = [init_strat, opt_strat]
[signal1]
par_type = binary
[signal2]
par_type = fixed
[signal3]
par_type = continuous
lower_bound = 1
upper_bound = 100
log_scale = True
[init_strat]
generator = SobolGenerator
min_asks = 1
[opt_strat]
generator = OptimizeAcqfGenerator
acqf = MCLevelSetEstimation
model = GPClassificationModel
min_asks = 1
"""
config = Config()
with self.assertRaises(ParameterConfigError):
config.update(config_str=bad_config_str)

def test_fixed_standalone(self):
fixed1 = Fixed([3], values=[0.3])
fixed2 = Fixed([1, 2], values=[0.1, 0.2])

transforms = ParameterTransforms(fixed1=fixed1, fixed2=fixed2)

self.assertTrue(len(transforms) == 1)
self.assertTrue(
torch.all(transforms["_CombinedFixed"].indices == torch.tensor([1, 2, 3]))
)
self.assertTrue(
torch.all(
transforms["_CombinedFixed"].values == torch.tensor([0.1, 0.2, 0.3])
)
)

input = torch.tensor([[1, 100, 100, 100, 1], [2, 100, 100, 100, 2]])
transformed = transforms.transform(input)
untransformed = transforms.untransform(transformed)

self.assertTrue(transformed.shape[0] == 2)
self.assertTrue(torch.all(transformed[:, 0] == torch.tensor([1, 2])))
self.assertTrue(
torch.all(torch.tensor([1, 0.1, 0.2, 0.3, 1]) == untransformed[0])
)
self.assertTrue(
torch.all(torch.tensor([2, 0.1, 0.2, 0.3, 2]) == untransformed[1])
)
