forked from facebookresearch/aepsych
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Split out transform classes/functions to new files (facebookresearch#453
) Summary: parameters.py was getting too big. Transforms themselves were moved to ops.py, leaving the base class and wrapper classes in parameters.py. Generic utility function that could be used elsewhere moved to base utils.py New parameter handling overwrites old ax support, so we remove ax related functions/tests Differential Revision: D65898366
- Loading branch information
1 parent
cabd911
commit ecf8b4b
Showing
9 changed files
with
385 additions
and
650 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,347 @@ | ||
from copy import deepcopy | ||
from typing import Any, Dict, Literal, Optional, Union | ||
|
||
import numpy as np | ||
import torch | ||
from aepsych.config import Config, ConfigurableMixin | ||
from aepsych.utils import get_bounds | ||
from botorch.models.transforms.input import Log10, Normalize, ReversibleInputTransform | ||
from botorch.models.transforms.utils import subset_transform | ||
from torch import Tensor | ||
|
||
|
||
def _get_parameter_options( | ||
config: Config, name: Optional[str] = None, options: Optional[Dict[str, Any]] = None | ||
) -> Dict[str, Any]: | ||
"""Return options for a parameter in a config. | ||
Args: | ||
config (Config): Config to search for parameter. | ||
name (str): Name of parameter. | ||
options (Dict[str, Any], optional): dictionary of options to overwrite config | ||
options, defaults to an empty dictionary. | ||
Returns: | ||
Dict[str, Any]: Dictionary of options to initialize a transform from config. | ||
""" | ||
if name is None: | ||
raise ValueError(f"{name} must be set to initialize a transform.") | ||
|
||
if options is None: | ||
options = {} | ||
else: | ||
options = deepcopy(options) | ||
|
||
# Figure out the index of this parameter | ||
parnames = config.getlist("common", "parnames", element_type=str) | ||
idx = parnames.index(name) | ||
|
||
if "indices" not in options: | ||
options["indices"] = [idx] | ||
|
||
return options | ||
|
||
|
||
class Log10Plus(Log10, ConfigurableMixin): | ||
"""Base-10 log transform that we add a constant to the values""" | ||
|
||
def __init__( | ||
self, | ||
indices: list[int], | ||
constant: float = 0.0, | ||
transform_on_train: bool = True, | ||
transform_on_eval: bool = True, | ||
transform_on_fantasize: bool = True, | ||
reverse: bool = False, | ||
**kwargs, | ||
) -> None: | ||
"""Initalize transform | ||
Args: | ||
indices: The indices of the parameters to log transform. | ||
constant: The constant to add to inputs before log transforming. Defaults to | ||
0.0. | ||
transform_on_train: A boolean indicating whether to apply the | ||
transforms in train() mode. Default: True. | ||
transform_on_eval: A boolean indicating whether to apply the | ||
transform in eval() mode. Default: True. | ||
transform_on_fantasize: A boolean indicating whether to apply the | ||
transform when called from within a `fantasize` call. Default: True. | ||
reverse: A boolean indicating whether the forward pass should untransform | ||
the inputs. | ||
**kwargs: Accepted to conform to API. | ||
""" | ||
super().__init__( | ||
indices=indices, | ||
transform_on_train=transform_on_train, | ||
transform_on_eval=transform_on_eval, | ||
transform_on_fantasize=transform_on_fantasize, | ||
reverse=reverse, | ||
) | ||
self.register_buffer("constant", torch.tensor(constant, dtype=torch.long)) | ||
|
||
@subset_transform | ||
def _transform(self, X: Tensor) -> Tensor: | ||
r"""Add the constant then log transform the inputs. | ||
Args: | ||
X: A `batch_shape x n x d`-dim tensor of inputs. | ||
Returns: | ||
A `batch_shape x n x d`-dim tensor of transformed inputs. | ||
""" | ||
X = X + (torch.ones_like(X) * self.constant) | ||
return X.log10() | ||
|
||
@subset_transform | ||
def _untransform(self, X: Tensor) -> Tensor: | ||
r"""Reverse the log transformation then subtract the constant. | ||
Args: | ||
X: A `batch_shape x n x d`-dim tensor of transformed inputs. | ||
Returns: | ||
A `batch_shape x n x d`-dim tensor of untransformed inputs. | ||
""" | ||
X = 10.0**X | ||
return X - (torch.ones_like(X) * self.constant) | ||
|
||
@classmethod | ||
def get_config_options( | ||
cls, | ||
config: Config, | ||
name: Optional[str] = None, | ||
options: Optional[Dict[str, Any]] = None, | ||
) -> Dict[str, Any]: | ||
"""Return a dictionary of the relevant options to initialize a Log10Plus | ||
transform for the named parameter within the config. | ||
Args: | ||
config (Config): Config to look for options in. | ||
name (str): Parameter to find options for. | ||
options (Dict[str, Any]): Options to override from the config. | ||
Returns: | ||
Dict[str, Any]: A diciontary of options to initialize this class with, | ||
including the transformed bounds. | ||
""" | ||
options = _get_parameter_options(config, name, options) | ||
|
||
# Make sure we have bounds ready | ||
if "bounds" not in options: | ||
options["bounds"] = get_bounds(config) | ||
|
||
if "constant" not in options: | ||
lb = options["bounds"][0, options["indices"]] | ||
if lb < 0.0: | ||
constant = np.abs(lb) + 1.0 | ||
elif lb < 1.0: | ||
constant = 1.0 | ||
else: | ||
constant = 0.0 | ||
|
||
options["constant"] = constant | ||
|
||
return options | ||
|
||
|
||
class NormalizeScale(Normalize, ConfigurableMixin): | ||
def __init__( | ||
self, | ||
d: int, | ||
indices: Optional[Union[list[int], Tensor]] = None, | ||
bounds: Optional[Tensor] = None, | ||
batch_shape: torch.Size = torch.Size(), | ||
transform_on_train: bool = True, | ||
transform_on_eval: bool = True, | ||
transform_on_fantasize: bool = True, | ||
reverse: bool = False, | ||
min_range: float = 1e-8, | ||
learn_bounds: Optional[bool] = None, | ||
almost_zero: float = 1e-12, | ||
**kwargs, | ||
) -> None: | ||
r"""Normalizes the scale of the parameters. | ||
Args: | ||
d: Total number of parameters (dimensions). | ||
indices: The indices of the inputs to normalize. If omitted, | ||
take all dimensions of the inputs into account. | ||
bounds: If provided, use these bounds to normalize the parameters. If | ||
omitted, learn the bounds in train mode. | ||
batch_shape: The batch shape of the inputs (assuming input tensors | ||
of shape `batch_shape x n x d`). If provided, perform individual | ||
normalization per batch, otherwise uses a single normalization. | ||
transform_on_train: A boolean indicating whether to apply the | ||
transforms in train() mode. Default: True. | ||
transform_on_eval: A boolean indicating whether to apply the | ||
transform in eval() mode. Default: True. | ||
transform_on_fantasize: A boolean indicating whether to apply the | ||
transform when called from within a `fantasize` call. Default: True. | ||
reverse: A boolean indicating whether the forward pass should untransform | ||
the parameters. | ||
min_range: If the range of a parameter is smaller than `min_range`, | ||
that parameter will not be normalized. This is equivalent to | ||
using bounds of `[0, 1]` for this dimension, and helps avoid division | ||
by zero errors and related numerical issues. See the example below. | ||
NOTE: This only applies if `learn_bounds=True`. | ||
learn_bounds: Whether to learn the bounds in train mode. Defaults | ||
to False if bounds are provided, otherwise defaults to True. | ||
**kwargs: Accepted to conform to API. | ||
""" | ||
super().__init__( | ||
d=d, | ||
indices=indices, | ||
bounds=bounds, | ||
batch_shape=batch_shape, | ||
transform_on_train=transform_on_train, | ||
transform_on_eval=transform_on_eval, | ||
transform_on_fantasize=transform_on_fantasize, | ||
reverse=reverse, | ||
min_range=min_range, | ||
learn_bounds=learn_bounds, | ||
almost_zero=almost_zero, | ||
) | ||
|
||
@classmethod | ||
def get_config_options( | ||
cls, | ||
config: Config, | ||
name: Optional[str] = None, | ||
options: Optional[Dict[str, Any]] = None, | ||
) -> Dict[str, Any]: | ||
""" | ||
Return a dictionary of the relevant options to initialize a NormalizeScale | ||
transform for the named parameter within the config. | ||
Args: | ||
config (Config): Config to look for options in. | ||
name (str): Parameter to find options for. | ||
options (Dict[str, Any]): Options to override from the config. | ||
Return: | ||
Dict[str, Any]: A diciontary of options to initialize this class with, | ||
including the transformed bounds. | ||
""" | ||
options = _get_parameter_options(config, name, options) | ||
|
||
# Make sure we have bounds ready | ||
if "bounds" not in options: | ||
options["bounds"] = get_bounds(config) | ||
|
||
if "d" not in options: | ||
options["d"] = options["bounds"].shape[1] | ||
|
||
return options | ||
|
||
|
||
class Discretize(ReversibleInputTransform, torch.nn.Module, ConfigurableMixin): | ||
def __init__( | ||
self, | ||
indices: list[int], | ||
transform_on_train: bool = True, | ||
transform_on_eval: bool = True, | ||
transform_on_fantasize: bool = True, | ||
reverse: bool = False, | ||
**kwargs, | ||
) -> None: | ||
"""Initialize a discretize transform. | ||
Args: | ||
indices: The indices of the inputs to round. | ||
transform_on_train: A boolean indicating whether to apply the | ||
transforms in train() mode. Default: True. | ||
transform_on_eval: A boolean indicating whether to apply the | ||
transform in eval() mode. Default: True. | ||
transform_on_fantasize: Currently will not do anything, here to conform to | ||
API. | ||
reverse: Whether to round in forward or backward passes. | ||
**kwargs: Accepted to conform to API. | ||
""" | ||
super().__init__() | ||
self.register_buffer("indices", torch.tensor(indices, dtype=torch.long)) | ||
self.transform_on_train = transform_on_train | ||
self.transform_on_eval = transform_on_eval | ||
self.transform_on_fantasize = transform_on_fantasize | ||
self.reverse = reverse | ||
|
||
@subset_transform | ||
def _transform(self, X: torch.Tensor) -> torch.Tensor: | ||
r"""Round the inputs to a model to be discrete. | ||
Args: | ||
X (torch.Tensor): A `batch_shape x n x d`-dim tensor of inputs. | ||
Returns: | ||
torch.Tensor: The input tensor with values rounded. | ||
""" | ||
return X.round() | ||
|
||
@subset_transform | ||
def _untransform(self, X: Tensor) -> Tensor: | ||
r"""This does nothing as rounding cannot be reversed. Typically, Discretize will | ||
be initialized with reverse=True, such that only values coming out of a model/ | ||
generator is discretized (the model needs to work in a continuous space). | ||
Args: | ||
X (torch.Tensor): A `batch_shape x n x d`-dim tensor of transformed inputs. | ||
Returns: | ||
torch.Tensor: The same tensor as the input with no change. | ||
""" | ||
return X | ||
|
||
def transform_bounds( | ||
self, X: torch.Tensor, bound: Optional[Literal["lb", "ub"]] = None | ||
) -> torch.Tensor: | ||
r"""Return the bounds X transformed. | ||
Args: | ||
X (torch.Tensor): Either a `[1, dim]` or `[2, dim]` tensor of parameter | ||
bounds. | ||
bound (Literal["lb", "ub"], optional): The bound that this is, if None, we | ||
will assume the input is both bounds with a `[2, dim]` X. | ||
Returns: | ||
torch.Tensor: A transformed set of parameter bounds. | ||
""" | ||
X = X.clone() | ||
|
||
if bound == "lb": | ||
X[0, self.indices] -= torch.tensor([0.5] * len(self.indices)) | ||
elif bound == "ub": | ||
X[0, self.indices] += torch.tensor([0.5 - 1e-6] * len(self.indices)) | ||
else: # Both bounds | ||
X[0, self.indices] -= torch.tensor([0.5] * len(self.indices)) | ||
X[1, self.indices] += torch.tensor([0.5 - 1e-6] * len(self.indices)) | ||
|
||
return X | ||
|
||
@classmethod | ||
def get_config_options( | ||
cls, | ||
config: Config, | ||
name: Optional[str] = None, | ||
options: Optional[Dict[str, Any]] = None, | ||
) -> Dict[str, Any]: | ||
""" | ||
Return a dictionary of the relevant options to initialize the discretize transform | ||
from the config for the named transform. | ||
Args: | ||
config (Config): Config to look for options in. | ||
name (str, optional): The parameter to find options for. | ||
options (Dict[str, Any], optional): Options to override from the config, | ||
defaults to None. | ||
Return: | ||
Dict[str, Any]: A dictionary of options to initialize this class. | ||
""" | ||
options = _get_parameter_options(config, name, options) | ||
|
||
# When Round is made from config, we need it to untransform (inputs are expected | ||
# to be discrete) values coming out of models/generators need to be discretized. | ||
if "reverse" not in options: | ||
options["reverse"] = True | ||
|
||
return options |
Oops, something went wrong.