Remove normalize_inputs and replace with parameter transform (#431)
Summary:

`normalize_inputs` (the one that min-max scales parameters) is confusingly named (there's another `normalize_inputs` that concatenates data and ensures they're all the right types), and it is a hard-coded transformation applied to all parameters. This means there's no way to turn the behavior off selectively, nor is it obvious that it is happening.

This diff removes the `normalize_inputs` method and replaces it with a parameter transform, which also allows selective application of the transform via an index.
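For context, here is a minimal sketch (plain PyTorch with made-up bounds, not the library's actual transform classes) of the global min-max scaling that `normalize_inputs` performed, next to the kind of index-selective scaling the new parameter transform enables:

```python
import torch

lb = torch.tensor([1.0, -5.0])   # made-up lower bounds
ub = torch.tensor([100.0, 5.0])  # made-up upper bounds

def normalize_all(x: torch.Tensor) -> torch.Tensor:
    # Old behavior: every parameter is min-max scaled to [0, 1], with no opt-out.
    return (x - lb) / (ub - lb)

def normalize_selected(x: torch.Tensor, indices: list[int]) -> torch.Tensor:
    # New behavior, conceptually: only the indexed parameters are scaled.
    out = x.clone()
    out[..., indices] = (x[..., indices] - lb[indices]) / (ub[indices] - lb[indices])
    return out

x = torch.tensor([[50.5, 0.0]])
print(normalize_all(x))            # tensor([[0.5000, 0.5000]])
print(normalize_selected(x, [0]))  # tensor([[0.5000, 0.0000]])
```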

Differential Revision: D65069497
JasonKChow authored and facebook-github-bot committed Nov 8, 2024
1 parent 6de8f0f commit 941499b
Showing 15 changed files with 395 additions and 89 deletions.
45 changes: 29 additions & 16 deletions aepsych/config.py
@@ -8,6 +8,7 @@
import ast
import configparser
import json
+import logging
import re
import warnings
from types import ModuleType
@@ -166,24 +167,36 @@ def update(

# Warn if ub/lb is defined in common section
if "ub" in self["common"] and "lb" in self["common"]:
-            warnings.warn(
-                "ub and lb have been defined in common section, ignoring parameter specific blocks, be very careful!"
-            )
-        elif "parnames" in self["common"]:  # it's possible to pass no parnames
-            par_names = self.getlist(
-                "common", "parnames", element_type=str, fallback=[]
-            )
-            lb = [None] * len(par_names)
-            ub = [None] * len(par_names)
-            for i, par_name in enumerate(par_names):
-                # Validate the parameter-specific block
-                self._check_param_settings(par_name)

-                lb[i] = self[par_name]["lower_bound"]
-                ub[i] = self[par_name]["upper_bound"]

-            self["common"]["lb"] = f"[{', '.join(lb)}]"
-            self["common"]["ub"] = f"[{', '.join(ub)}]"
+            logging.warning(
+                "ub and lb have been defined in common section, parameter-specific bounds take precedence over these."
+            )

+        if "parnames" in self["common"]:  # it's possible to pass no parnames
+            try:
+                par_names = self.getlist(
+                    "common", "parnames", element_type=str, fallback=[]
+                )
+                lb = [None] * len(par_names)
+                ub = [None] * len(par_names)
+                for i, par_name in enumerate(par_names):
+                    # Validate the parameter-specific block
+                    self._check_param_settings(par_name)

+                    lb[i] = self[par_name]["lower_bound"]
+                    ub[i] = self[par_name]["upper_bound"]

+                self["common"]["lb"] = f"[{', '.join(lb)}]"
+                self["common"]["ub"] = f"[{', '.join(ub)}]"
+            except ValueError:
+                # Check if ub/lb exists in common
+                if "ub" in self["common"] and "lb" in self["common"]:
+                    logging.warning(
+                        "Parameter-specific bounds are incomplete, falling back to ub/lb in [common]"
+                    )
+                else:
+                    raise ValueError(
+                        "Missing ub or lb in [common] with incomplete parameter-specific bounds, cannot fallback!"
+                    )

# Deprecation warning for "experiment" section
if "experiment" in self:
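To illustrate the new bound handling in config.py above, here is a hypothetical config (parameter names and values invented) in the usual AEPsych INI style: when bounds appear both in `[common]` and in the parameter block, `update()` now logs a warning and the parameter-specific bounds take precedence instead of being silently ignored.

```python
from aepsych.config import Config

# Hypothetical config: bounds are given both in [common] and in the [par1] block.
config_str = """
[common]
parnames = [par1]
lb = [0]
ub = [10]

[par1]
par_type = continuous
lower_bound = 1
upper_bound = 100
"""

config = Config()
config.update(config_str=config_str)
# A warning is logged, and the parameter-specific bounds are written back into
# [common], so lb/ub end up as [1] and [100] rather than [0] and [10].
```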
9 changes: 2 additions & 7 deletions aepsych/models/base.py
@@ -337,10 +337,6 @@ def set_train_data(
if targets is not None:
self.train_targets = targets

-    def normalize_inputs(self, x: torch.Tensor) -> torch.Tensor:
-        scale = self.ub - self.lb
-        return (x - self.lb) / scale

def forward(self, x: torch.Tensor) -> gpytorch.distributions.MultivariateNormal:
"""Evaluate GP
@@ -351,9 +347,8 @@ def forward(self, x: torch.Tensor) -> gpytorch.distributions.MultivariateNormal:
gpytorch.distributions.MultivariateNormal: Distribution object
holding mean and covariance at x.
"""
-        transformed_x = self.normalize_inputs(x)
-        mean_x = self.mean_module(transformed_x)
-        covar_x = self.covar_module(transformed_x)
+        mean_x = self.mean_module(x)
+        covar_x = self.covar_module(x)
pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
return pred

8 changes: 2 additions & 6 deletions aepsych/models/monotonic_rejection_gp.py
@@ -341,11 +341,7 @@ def forward(self, x: torch.Tensor) -> gpytorch.distributions.MultivariateNormal:
gpytorch.distributions.MultivariateNormal: Distribution object
holding mean and covariance at x.
"""

-        # final dim is deriv index, we only normalize the "real" dims
-        transformed_x = x.clone()
-        transformed_x[..., :-1] = self.normalize_inputs(transformed_x[..., :-1])
-        mean_x = self.mean_module(transformed_x)
-        covar_x = self.covar_module(transformed_x)
+        mean_x = self.mean_module(x)
+        covar_x = self.covar_module(x)
latent_pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
return latent_pred
5 changes: 2 additions & 3 deletions aepsych/models/multitask_regression.py
@@ -79,9 +79,8 @@ def __init__(
def forward(
self, x: torch.Tensor
) -> gpytorch.distributions.MultitaskMultivariateNormal:
-        transformed_x = self.normalize_inputs(x)
-        mean_x = self.mean_module(transformed_x)
-        covar_x = self.covar_module(transformed_x)
+        mean_x = self.mean_module(x)
+        covar_x = self.covar_module(x)
return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)

@classmethod
11 changes: 5 additions & 6 deletions aepsych/models/semi_p.py
@@ -529,18 +529,17 @@ def forward(self, x: torch.Tensor) -> MultivariateNormal:
Returns:
MVN object evaluated at samples
"""
-        transformed_x = self.normalize_inputs(x)
        # TODO: make slope prop to intensity width.
-        slope_mean = self.slope_mean_module(transformed_x)
+        slope_mean = self.slope_mean_module(x)

        # kc mvn
-        offset_mean = self.offset_mean_module(transformed_x)
+        offset_mean = self.offset_mean_module(x)

-        slope_cov = self.slope_covar_module(transformed_x)
-        offset_cov = self.offset_covar_module(transformed_x)
+        slope_cov = self.slope_covar_module(x)
+        offset_cov = self.offset_covar_module(x)

        mean_x, cov_x = _hadamard_mvn_approx(
-            x_intensity=transformed_x[..., self.stim_dim],
+            x_intensity=x[..., self.stim_dim],
slope_mean=slope_mean,
slope_cov=slope_cov,
offset_mean=offset_mean,
3 changes: 3 additions & 0 deletions aepsych/models/utils.py
@@ -179,6 +179,9 @@ def get_extremum(
timeout_sec=max_time,
)

if hasattr(model, "transforms"):
best_point = model.transforms.untransform(best_point)

# PosteriorMean flips the sign on minimize, we flip it back
if extremum_type == "min":
best_val = -best_val
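The `get_extremum` change above hands the best point back in raw parameter space. Conceptually (a plain-torch sketch with made-up bounds, not the actual transform object), untransforming a min-max-normalized point is just the inverse scaling:

```python
import torch

lb = torch.tensor([1.0, -5.0])   # made-up raw-space bounds
ub = torch.tensor([100.0, 5.0])

def untransform(z: torch.Tensor) -> torch.Tensor:
    # Map a point from normalized [0, 1] space back to raw parameter space.
    return z * (ub - lb) + lb

print(untransform(torch.tensor([[0.5, 0.5]])))  # tensor([[50.5000, 0.0000]])
```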
135 changes: 128 additions & 7 deletions aepsych/transforms/parameters.py
@@ -17,7 +17,7 @@
from aepsych.generators.base import AEPsychGenerator
from aepsych.models.base import AEPsychMixin, ModelProtocol
from botorch.acquisition import AcquisitionFunction
-from botorch.models.transforms.input import ChainedInputTransform, Log10
+from botorch.models.transforms.input import ChainedInputTransform, Log10, Normalize
from botorch.models.transforms.utils import subset_transform
from botorch.posteriors import Posterior
from torch import Tensor
@@ -128,16 +128,34 @@ def get_config_options(
parnames = config.getlist("common", "parnames", element_type=str)

# This is the "options" dictionary, transform options is only for maintaining the right transforms
-        transform_dict = {}
+        transform_dict: Dict[str, ChainedInputTransform] = {}
        for par in parnames:
+            # This is the order that transforms are potentially applied, order matters

+            # Log scale
            if config.getboolean(par, "log_scale", fallback=False):
-                transform_dict[f"{par}_Log10Plus"] = Log10Plus.from_config(
+                log10 = Log10Plus.from_config(
                    config=config, name=par, options=transform_options
                )

+                # Transform bounds
+                transform_options["bounds"] = log10.transform(
+                    transform_options["bounds"]
+                )
+                transform_dict[f"{par}_Log10Plus"] = log10

+            # Normalize scale (defaults true)
+            if config.getboolean(par, "normalize_scale", fallback=True):
+                normalize = NormalizeScale.from_config(
+                    config=config, name=par, options=transform_options
+                )

+                # Transform bounds
+                transform_options["bounds"] = normalize.transform(
+                    transform_options["bounds"]
+                )
+                transform_dict[f"{par}_NormalizeScale"] = normalize

return transform_dict


@@ -202,9 +220,9 @@ def __init__(
# Figure out what we need to do with generator
if isinstance(generator, type):
if "lb" in kwargs:
-                kwargs["lb"] = transforms.transform(kwargs["lb"].float())
+                kwargs["lb"] = transforms.transform(kwargs["lb"].to(torch.float64))
            if "ub" in kwargs:
-                kwargs["ub"] = transforms.transform(kwargs["ub"].float())
+                kwargs["ub"] = transforms.transform(kwargs["ub"].to(torch.float64))
_base_obj = generator(**kwargs)
else:
_base_obj = generator
@@ -349,9 +367,9 @@ def __init__(
# Alternative instantiation method for analysis (and not live)
if isinstance(model, type):
if "lb" in kwargs:
-                kwargs["lb"] = transforms.transform(kwargs["lb"].float())
+                kwargs["lb"] = transforms.transform(kwargs["lb"].to(torch.float64))
            if "ub" in kwargs:
-                kwargs["ub"] = transforms.transform(kwargs["ub"].float())
+                kwargs["ub"] = transforms.transform(kwargs["ub"].to(torch.float64))
_base_obj = model(**kwargs)
else:
_base_obj = model
@@ -818,6 +836,109 @@ def get_config_options(
return options


class NormalizeScale(Normalize, ConfigurableMixin):
def __init__(
self,
d: int,
indices: Optional[Union[list[int], Tensor]] = None,
bounds: Optional[Tensor] = None,
batch_shape: torch.Size = torch.Size(),
transform_on_train: bool = True,
transform_on_eval: bool = True,
transform_on_fantasize: bool = True,
reverse: bool = False,
min_range: float = 1e-8,
learn_bounds: Optional[bool] = None,
almost_zero: float = 1e-12,
**kwargs,
) -> None:
r"""Normalizes the scale of the parameters.
Args:
d: Total number of parameters (dimensions).
indices: The indices of the inputs to normalize. If omitted,
take all dimensions of the inputs into account.
bounds: If provided, use these bounds to normalize the parameters. If
omitted, learn the bounds in train mode.
batch_shape: The batch shape of the inputs (assuming input tensors
of shape `batch_shape x n x d`). If provided, perform individual
normalization per batch, otherwise uses a single normalization.
transform_on_train: A boolean indicating whether to apply the
transforms in train() mode. Default: True.
transform_on_eval: A boolean indicating whether to apply the
transform in eval() mode. Default: True.
transform_on_fantasize: A boolean indicating whether to apply the
transform when called from within a `fantasize` call. Default: True.
reverse: A boolean indicating whether the forward pass should untransform
the parameters.
min_range: If the range of a parameter is smaller than `min_range`,
that parameter will not be normalized. This is equivalent to
using bounds of `[0, 1]` for this dimension, and helps avoid division
by zero errors and related numerical issues. See the example below.
NOTE: This only applies if `learn_bounds=True`.
learn_bounds: Whether to learn the bounds in train mode. Defaults
to False if bounds are provided, otherwise defaults to True.
**kwargs: Accepted to conform to API.
"""
super().__init__(
d=d,
indices=indices,
bounds=bounds,
batch_shape=batch_shape,
transform_on_train=transform_on_train,
transform_on_eval=transform_on_eval,
transform_on_fantasize=transform_on_fantasize,
reverse=reverse,
min_range=min_range,
learn_bounds=learn_bounds,
almost_zero=almost_zero,
)

@classmethod
def get_config_options(
cls,
config: Config,
name: Optional[str] = None,
options: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
Return a dictionary of the relevant options to initialize a NormalizeScale
transform for the named parameter within the config.
Args:
config (Config): Config to look for options in.
name (str): Parameter to find options for.
options (Dict[str, Any]): Options to override from the config.
Return:
Dict[str, Any]: A dictionary of options to initialize this class with,
including the transformed bounds.
"""
if name is None:
raise ValueError(f"{name} must be set to initialize a transform.")

if options is None:
options = {}
else:
options = deepcopy(options)

# Figure out the index of this parameter
parnames = config.getlist("common", "parnames", element_type=str)
idx = parnames.index(name)

# Make sure we have bounds ready
if "bounds" not in options:
options["bounds"] = get_bounds(config)

if "indices" not in options:
options["indices"] = [idx]

if "d" not in options:
options["d"] = len(parnames)

return options


def transform_options(
config: Config, transforms: Optional[ChainedInputTransform] = None
) -> Config:
29 changes: 28 additions & 1 deletion docs/parameters.md
@@ -15,7 +15,6 @@ what parameter types are used and whatever transformations are used.
Currently, we only support continuous parameters. More parameter types soon to come!

<h3>Continuous</h3>

```ini
[parameter]
par_type = continuous
@@ -58,3 +57,31 @@ For parameters with lower bounds that are positive but still less than 1, we will always use
a constant value of 1 (i.e., `Log10(x + 1)` and `10 ^ x - 1`). For parameters with
lower bounds that are negative, we will use a constant value of the absolute value of
the lower bound + 1 (i.e., `Log10(x + |lb| + 1)` and `10 ^ x - |lb| - 1`).
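As a quick numerical illustration of that rule (a standalone sketch, not AEPsych's actual `Log10Plus` implementation), a parameter with lower bound -5 is shifted by |lb| + 1 = 6 before the log, and the inverse undoes both steps:

```python
import torch

lb = -5.0  # example lower bound, negative

def log_transform(x: torch.Tensor) -> torch.Tensor:
    return torch.log10(x + abs(lb) + 1)  # Log10(x + |lb| + 1)

def log_untransform(z: torch.Tensor) -> torch.Tensor:
    return 10 ** z - abs(lb) - 1         # inverse of the transform above

x = torch.tensor([-5.0, 0.0, 5.0])
print(log_transform(x))                   # tensor([0.0000, 0.7782, 1.0414])
print(log_untransform(log_transform(x)))  # recovers x (up to float error)
```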

<h3>Normalize scale</h3>
By default, all parameters will have their scale min-max normalized to the range
[0, 1]. This prevents any particular parameter with a large scale from completely
dominating the other parameters. Very rarely, this behavior may not be desired, and
it can be turned off for specific parameters.

```ini
[parameter]
par_type = continuous
lower_bound = 1
upper_bound = 100
normalize_scale = False # turn it on with any of true/yes/on, turn it off with any of false/no/off; case insensitive
```

When the `normalize_scale` option is set to False, this parameter will not be scaled
before being given to the model and will therefore keep its original magnitude. This
is very rarely necessary and should be used with caution.

<h2>Order of operations</h2>
Parameter types and parameter-specific transforms are all handled by the
`ParameterTransform` API. Transforms built from config files are applied in a fixed
order, regardless of how the options are ordered in the config file. Each parameter
is transformed independently of the others.

Currently, the order is as follows (sketched below):
* Log scale
* Normalize scale
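A rough numerical sketch of that ordering for a single log-scaled parameter with bounds [1, 100] (plain torch; the constant offset that `Log10Plus` may add for small or negative lower bounds is ignored here):

```python
import torch

# One log-scaled parameter with raw bounds [1, 100].
lb, ub = torch.tensor(1.0), torch.tensor(100.0)
log_lb, log_ub = torch.log10(lb), torch.log10(ub)  # 0.0 and 2.0

def to_model_space(x: torch.Tensor) -> torch.Tensor:
    x = torch.log10(x)                       # 1) log scale
    return (x - log_lb) / (log_ub - log_lb)  # 2) normalize scale, using the logged bounds

print(to_model_space(torch.tensor([1.0, 10.0, 100.0])))  # tensor([0.0000, 0.5000, 1.0000])
```

This is also why `get_config_options` in `aepsych/transforms/parameters.py` runs each transform over `transform_options["bounds"]` before constructing the next one: the normalization step has to see the log-scaled bounds rather than the raw ones.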
5 changes: 2 additions & 3 deletions tests/generators/test_manual_generator.py
@@ -51,10 +51,9 @@ def test_manual_generator(self):
"""
config = Config()
config.update(config_str=config_str)
-        # gen = ManualGenerator.from_config(config)
        gen = ParameterTransformedGenerator.from_config(config, "init_strat")
-        npt.assert_equal(gen.lb.numpy(), np.array([10, 10]))
-        npt.assert_equal(gen.ub.numpy(), np.array([11, 11]))
+        npt.assert_equal(gen.lb.numpy(), np.array([0, 0]))
+        npt.assert_equal(gen.ub.numpy(), np.array([1, 1]))
self.assertFalse(gen.finished)

p1 = list(gen.gen()[0])