diff --git a/aepsych/config.py b/aepsych/config.py index afbc323be..e468069fb 100644 --- a/aepsych/config.py +++ b/aepsych/config.py @@ -8,6 +8,7 @@ import ast import configparser import json +import logging import re import warnings from types import ModuleType @@ -166,24 +167,36 @@ def update( # Warn if ub/lb is defined in common section if "ub" in self["common"] and "lb" in self["common"]: - warnings.warn( - "ub and lb have been defined in common section, ignoring parameter specific blocks, be very careful!" - ) - elif "parnames" in self["common"]: # it's possible to pass no parnames - par_names = self.getlist( - "common", "parnames", element_type=str, fallback=[] + logging.warning( + "ub and lb have been defined in common section, parameter-specific bounds take precendence over these." ) - lb = [None] * len(par_names) - ub = [None] * len(par_names) - for i, par_name in enumerate(par_names): - # Validate the parameter-specific block - self._check_param_settings(par_name) - lb[i] = self[par_name]["lower_bound"] - ub[i] = self[par_name]["upper_bound"] - - self["common"]["lb"] = f"[{', '.join(lb)}]" - self["common"]["ub"] = f"[{', '.join(ub)}]" + if "parnames" in self["common"]: # it's possible to pass no parnames + try: + par_names = self.getlist( + "common", "parnames", element_type=str, fallback=[] + ) + lb = [None] * len(par_names) + ub = [None] * len(par_names) + for i, par_name in enumerate(par_names): + # Validate the parameter-specific block + self._check_param_settings(par_name) + + lb[i] = self[par_name]["lower_bound"] + ub[i] = self[par_name]["upper_bound"] + + self["common"]["lb"] = f"[{', '.join(lb)}]" + self["common"]["ub"] = f"[{', '.join(ub)}]" + except ValueError: + # Check if ub/lb exists in common + if "ub" in self["common"] and "lb" in self["common"]: + logging.warning( + "Parameter-specific bounds are incomplete, falling back to ub/lb in [common]" + ) + else: + raise ValueError( + "Missing ub or lb in [common] with incomplete parameter-specific bounds, cannot fallback!" + ) # Deprecation warning for "experiment" section if "experiment" in self: diff --git a/aepsych/models/base.py b/aepsych/models/base.py index 67f2af75a..4d554be81 100644 --- a/aepsych/models/base.py +++ b/aepsych/models/base.py @@ -337,10 +337,6 @@ def set_train_data( if targets is not None: self.train_targets = targets - def normalize_inputs(self, x: torch.Tensor) -> torch.Tensor: - scale = self.ub - self.lb - return (x - self.lb) / scale - def forward(self, x: torch.Tensor) -> gpytorch.distributions.MultivariateNormal: """Evaluate GP @@ -351,9 +347,8 @@ def forward(self, x: torch.Tensor) -> gpytorch.distributions.MultivariateNormal: gpytorch.distributions.MultivariateNormal: Distribution object holding mean and covariance at x. """ - transformed_x = self.normalize_inputs(x) - mean_x = self.mean_module(transformed_x) - covar_x = self.covar_module(transformed_x) + mean_x = self.mean_module(x) + covar_x = self.covar_module(x) pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x) return pred diff --git a/aepsych/models/monotonic_rejection_gp.py b/aepsych/models/monotonic_rejection_gp.py index ec391ff50..9bf761abe 100644 --- a/aepsych/models/monotonic_rejection_gp.py +++ b/aepsych/models/monotonic_rejection_gp.py @@ -341,11 +341,7 @@ def forward(self, x: torch.Tensor) -> gpytorch.distributions.MultivariateNormal: gpytorch.distributions.MultivariateNormal: Distribution object holding mean and covariance at x. """ - - # final dim is deriv index, we only normalize the "real" dims - transformed_x = x.clone() - transformed_x[..., :-1] = self.normalize_inputs(transformed_x[..., :-1]) - mean_x = self.mean_module(transformed_x) - covar_x = self.covar_module(transformed_x) + mean_x = self.mean_module(x) + covar_x = self.covar_module(x) latent_pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x) return latent_pred diff --git a/aepsych/models/multitask_regression.py b/aepsych/models/multitask_regression.py index cc73d3967..aab0b396c 100644 --- a/aepsych/models/multitask_regression.py +++ b/aepsych/models/multitask_regression.py @@ -79,9 +79,8 @@ def __init__( def forward( self, x: torch.Tensor ) -> gpytorch.distributions.MultitaskMultivariateNormal: - transformed_x = self.normalize_inputs(x) - mean_x = self.mean_module(transformed_x) - covar_x = self.covar_module(transformed_x) + mean_x = self.mean_module(x) + covar_x = self.covar_module(x) return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x) @classmethod diff --git a/aepsych/models/semi_p.py b/aepsych/models/semi_p.py index e1aa518e3..da8c98dfb 100644 --- a/aepsych/models/semi_p.py +++ b/aepsych/models/semi_p.py @@ -529,18 +529,17 @@ def forward(self, x: torch.Tensor) -> MultivariateNormal: Returns: MVN object evaluated at samples """ - transformed_x = self.normalize_inputs(x) # TODO: make slope prop to intensity width. - slope_mean = self.slope_mean_module(transformed_x) + slope_mean = self.slope_mean_module(x) # kc mvn - offset_mean = self.offset_mean_module(transformed_x) + offset_mean = self.offset_mean_module(x) - slope_cov = self.slope_covar_module(transformed_x) - offset_cov = self.offset_covar_module(transformed_x) + slope_cov = self.slope_covar_module(x) + offset_cov = self.offset_covar_module(x) mean_x, cov_x = _hadamard_mvn_approx( - x_intensity=transformed_x[..., self.stim_dim], + x_intensity=x[..., self.stim_dim], slope_mean=slope_mean, slope_cov=slope_cov, offset_mean=offset_mean, diff --git a/aepsych/models/utils.py b/aepsych/models/utils.py index 71fe55781..25aa21f0a 100644 --- a/aepsych/models/utils.py +++ b/aepsych/models/utils.py @@ -179,6 +179,9 @@ def get_extremum( timeout_sec=max_time, ) + if hasattr(model, "transforms"): + best_point = model.transforms.untransform(best_point) + # PosteriorMean flips the sign on minimize, we flip it back if extremum_type == "min": best_val = -best_val diff --git a/aepsych/transforms/parameters.py b/aepsych/transforms/parameters.py index 432ac22e1..a96b6e68f 100644 --- a/aepsych/transforms/parameters.py +++ b/aepsych/transforms/parameters.py @@ -17,7 +17,7 @@ from aepsych.generators.base import AEPsychGenerator from aepsych.models.base import AEPsychMixin, ModelProtocol from botorch.acquisition import AcquisitionFunction -from botorch.models.transforms.input import ChainedInputTransform, Log10 +from botorch.models.transforms.input import ChainedInputTransform, Log10, Normalize from botorch.models.transforms.utils import subset_transform from botorch.posteriors import Posterior from torch import Tensor @@ -128,16 +128,34 @@ def get_config_options( parnames = config.getlist("common", "parnames", element_type=str) # This is the "options" dictionary, transform options is only for maintaining the right transforms - transform_dict = {} + transform_dict: Dict[str, ChainedInputTransform] = {} for par in parnames: # This is the order that transforms are potentially applied, order matters # Log scale if config.getboolean(par, "log_scale", fallback=False): - transform_dict[f"{par}_Log10Plus"] = Log10Plus.from_config( + log10 = Log10Plus.from_config( config=config, name=par, options=transform_options ) + # Transform bounds + transform_options["bounds"] = log10.transform( + transform_options["bounds"] + ) + transform_dict[f"{par}_Log10Plus"] = log10 + + # Normalize scale (defaults true) + if config.getboolean(par, "normalize_scale", fallback=True): + normalize = NormalizeScale.from_config( + config=config, name=par, options=transform_options + ) + + # Transform bounds + transform_options["bounds"] = normalize.transform( + transform_options["bounds"] + ) + transform_dict[f"{par}_NormalizeScale"] = normalize + return transform_dict @@ -202,9 +220,9 @@ def __init__( # Figure out what we need to do with generator if isinstance(generator, type): if "lb" in kwargs: - kwargs["lb"] = transforms.transform(kwargs["lb"].float()) + kwargs["lb"] = transforms.transform(kwargs["lb"].to(torch.float64)) if "ub" in kwargs: - kwargs["ub"] = transforms.transform(kwargs["ub"].float()) + kwargs["ub"] = transforms.transform(kwargs["ub"].to(torch.float64)) _base_obj = generator(**kwargs) else: _base_obj = generator @@ -349,9 +367,9 @@ def __init__( # Alternative instantiation method for analysis (and not live) if isinstance(model, type): if "lb" in kwargs: - kwargs["lb"] = transforms.transform(kwargs["lb"].float()) + kwargs["lb"] = transforms.transform(kwargs["lb"].to(torch.float64)) if "ub" in kwargs: - kwargs["ub"] = transforms.transform(kwargs["ub"].float()) + kwargs["ub"] = transforms.transform(kwargs["ub"].to(torch.float64)) _base_obj = model(**kwargs) else: _base_obj = model @@ -818,6 +836,109 @@ def get_config_options( return options +class NormalizeScale(Normalize, ConfigurableMixin): + def __init__( + self, + d: int, + indices: Optional[Union[list[int], Tensor]] = None, + bounds: Optional[Tensor] = None, + batch_shape: torch.Size = torch.Size(), + transform_on_train: bool = True, + transform_on_eval: bool = True, + transform_on_fantasize: bool = True, + reverse: bool = False, + min_range: float = 1e-8, + learn_bounds: Optional[bool] = None, + almost_zero: float = 1e-12, + **kwargs, + ) -> None: + r"""Normalizes the scale of the parameters. + + Args: + d: Total number of parameters (dimensions). + indices: The indices of the inputs to normalize. If omitted, + take all dimensions of the inputs into account. + bounds: If provided, use these bounds to normalize the parameters. If + omitted, learn the bounds in train mode. + batch_shape: The batch shape of the inputs (assuming input tensors + of shape `batch_shape x n x d`). If provided, perform individual + normalization per batch, otherwise uses a single normalization. + transform_on_train: A boolean indicating whether to apply the + transforms in train() mode. Default: True. + transform_on_eval: A boolean indicating whether to apply the + transform in eval() mode. Default: True. + transform_on_fantasize: A boolean indicating whether to apply the + transform when called from within a `fantasize` call. Default: True. + reverse: A boolean indicating whether the forward pass should untransform + the parameters. + min_range: If the range of a parameter is smaller than `min_range`, + that parameter will not be normalized. This is equivalent to + using bounds of `[0, 1]` for this dimension, and helps avoid division + by zero errors and related numerical issues. See the example below. + NOTE: This only applies if `learn_bounds=True`. + learn_bounds: Whether to learn the bounds in train mode. Defaults + to False if bounds are provided, otherwise defaults to True. + **kwargs: Accepted to conform to API. + """ + super().__init__( + d=d, + indices=indices, + bounds=bounds, + batch_shape=batch_shape, + transform_on_train=transform_on_train, + transform_on_eval=transform_on_eval, + transform_on_fantasize=transform_on_fantasize, + reverse=reverse, + min_range=min_range, + learn_bounds=learn_bounds, + almost_zero=almost_zero, + ) + + @classmethod + def get_config_options( + cls, + config: Config, + name: Optional[str] = None, + options: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """ + Return a dictionary of the relevant options to initialize a NormalizeScale + transform for the named parameter within the config. + + Args: + config (Config): Config to look for options in. + name (str): Parameter to find options for. + options (Dict[str, Any]): Options to override from the config. + + Return: + Dict[str, Any]: A diciontary of options to initialize this class with, + including the transformed bounds. + """ + if name is None: + raise ValueError(f"{name} must be set to initialize a transform.") + + if options is None: + options = {} + else: + options = deepcopy(options) + + # Figure out the index of this parameter + parnames = config.getlist("common", "parnames", element_type=str) + idx = parnames.index(name) + + # Make sure we have bounds ready + if "bounds" not in options: + options["bounds"] = get_bounds(config) + + if "indices" not in options: + options["indices"] = [idx] + + if "d" not in options: + options["d"] = len(parnames) + + return options + + def transform_options( config: Config, transforms: Optional[ChainedInputTransform] = None ) -> Config: diff --git a/docs/parameters.md b/docs/parameters.md index 4aa0157d1..ef7dfaf14 100644 --- a/docs/parameters.md +++ b/docs/parameters.md @@ -15,7 +15,6 @@ what parameter types are used and whatever transformations are used. Currently, we only support continuous parameters. More parameter types soon to come!

Continuous

- ```ini [parameter] par_type = continuous @@ -58,3 +57,31 @@ For parameters with lower bounds that are positive but still less 1, we will alw a constant value of 1 (i.e., `Log10(x + 1)` and `10 ^ (x - 1)`). For parameters with lower bounds that are negative, we will use a constant value of the absolute value of the lower bound + 1 (i.e., `Log10(x + |lb| + 1)` and `10 ^ (x - |lb| - 1)`). + +

Normalize scale

+By default, all parameters will have their scale min-max normalized to the range of +[0, 1]. This prevents any particular parameter with a large scale to completely dominate +the other parameters. Very rarely, this behavior may not be desired and can be turned +off for specific parameters. + +```ini +[parameter] +par_type = continuous +lower_bound = 1 +upper_bound = 100 +normalize_scale = False # turn it on with any of true/yes/on, turn it off with any of false/no/off; case insensitive +``` + +By setting the `normalize_scale` option to False, this parameter will not be scaled +before being given to the model and therefore maintain its original magnitude. This is +very rarely necessary and should be used with caution. + +

Order of operations

+Parameter types and parameter-specific transforms are all handled by the +`ParameterTransform` API. Transforms built from config files will have a specific order +of operation, regardless of how the options were set in the config file. Each parameter +is transformed entirely separately. + +Currently, the order is as follows: +* Log scale +* Normalize scale \ No newline at end of file diff --git a/tests/generators/test_manual_generator.py b/tests/generators/test_manual_generator.py index d72daf274..e517b2808 100644 --- a/tests/generators/test_manual_generator.py +++ b/tests/generators/test_manual_generator.py @@ -51,10 +51,9 @@ def test_manual_generator(self): """ config = Config() config.update(config_str=config_str) - # gen = ManualGenerator.from_config(config) gen = ParameterTransformedGenerator.from_config(config, "init_strat") - npt.assert_equal(gen.lb.numpy(), np.array([10, 10])) - npt.assert_equal(gen.ub.numpy(), np.array([11, 11])) + npt.assert_equal(gen.lb.numpy(), np.array([0, 0])) + npt.assert_equal(gen.ub.numpy(), np.array([1, 1])) self.assertFalse(gen.finished) p1 = list(gen.gen()[0]) diff --git a/tests/models/test_gp_classification.py b/tests/models/test_gp_classification.py index d2d84045a..465bc6c2b 100644 --- a/tests/models/test_gp_classification.py +++ b/tests/models/test_gp_classification.py @@ -23,6 +23,8 @@ from aepsych.generators import OptimizeAcqfGenerator, SobolGenerator from aepsych.models import GPClassificationModel from aepsych.strategy import SequentialStrategy, Strategy +from aepsych.transforms import ParameterTransformedModel, ParameterTransforms +from aepsych.transforms.parameters import Normalize from botorch.acquisition import qUpperConfidenceBound from botorch.optim.fit import fit_gpytorch_mll_torch from botorch.optim.stopping import ExpMAStoppingCriterion @@ -211,11 +213,19 @@ def test_1d_classification_different_scales(self): X, y = torch.Tensor(X), torch.Tensor(y) X[:, 0] = X[:, 0] * 1000 X[:, 1] = X[:, 1] / 1000 - lb = [-3000, -0.003] - ub = [3000, 0.003] - - model = GPClassificationModel(lb=lb, ub=ub, inducing_size=20) + lb = torch.tensor([-3000, -0.003]) + ub = torch.tensor([3000, 0.003]) + transforms = ParameterTransforms( + normalize=Normalize(2, bounds=torch.stack((lb, ub))) + ) + model = ParameterTransformedModel( + model=GPClassificationModel, + lb=lb, + ub=ub, + inducing_size=20, + transforms=transforms, + ) model.fit(X[:50], y[:50]) # pspace diff --git a/tests/models/test_gp_regression.py b/tests/models/test_gp_regression.py index 41e3b5ff8..263a1e58c 100644 --- a/tests/models/test_gp_regression.py +++ b/tests/models/test_gp_regression.py @@ -88,8 +88,8 @@ def test_extremum(self): def test_from_config(self): model = self.server.strat.model - npt.assert_allclose(model.lb, [-1.0]) - npt.assert_allclose(model.ub, [3.0]) + npt.assert_allclose(model.transforms.untransform(model.lb), [-1.0]) + npt.assert_allclose(model.transforms.untransform(model.ub), [3.0]) self.assertEqual(model.dim, 1) self.assertIsInstance(model.likelihood, GaussianLikelihood) self.assertEqual(model.max_fit_time, 1) diff --git a/tests/models/test_pairwise_probit.py b/tests/models/test_pairwise_probit.py index 4b4f2eedf..d299e3f23 100644 --- a/tests/models/test_pairwise_probit.py +++ b/tests/models/test_pairwise_probit.py @@ -22,6 +22,12 @@ from aepsych.server.message_handlers.handle_setup import configure from aepsych.server.message_handlers.handle_tell import tell from aepsych.strategy import SequentialStrategy, Strategy +from aepsych.transforms import ( + ParameterTransformedGenerator, + ParameterTransformedModel, + ParameterTransforms, +) +from aepsych.transforms.parameters import Normalize from botorch.acquisition import qUpperConfidenceBound from botorch.acquisition.active_learning import PairwiseMCPosteriorVariance from scipy.stats import bernoulli, norm, pearsonr @@ -192,30 +198,49 @@ def test_1d_pairwise_probit(self): np.random.seed(seed) n_init = 50 n_opt = 1 - lb = -4.0 - ub = 4.0 + lb = torch.tensor([-4.0]) + ub = torch.tensor([4.0]) extra_acqf_args = {"beta": 3.84} + transforms = ParameterTransforms( + normalize=Normalize(d=1, bounds=torch.stack([lb, ub])) + ) + sobol_gen = ParameterTransformedGenerator( + generator=SobolGenerator, + lb=lb, + ub=ub, + seed=seed, + stimuli_per_trial=2, + transforms=transforms, + ) + acqf_gen = ParameterTransformedGenerator( + generator=OptimizeAcqfGenerator, + acqf=qUpperConfidenceBound, + acqf_kwargs=extra_acqf_args, + stimuli_per_trial=2, + transforms=transforms, + ) + probit_model = ParameterTransformedModel( + model=PairwiseProbitModel, lb=lb, ub=ub, transforms=transforms + ) model_list = [ Strategy( lb=lb, ub=ub, - generator=SobolGenerator(lb=lb, ub=ub, seed=seed, stimuli_per_trial=2), + generator=sobol_gen, min_asks=n_init, stimuli_per_trial=2, outcome_types=["binary"], + transforms=transforms, ), Strategy( lb=lb, ub=ub, - model=PairwiseProbitModel(lb=lb, ub=ub), - generator=OptimizeAcqfGenerator( - acqf=qUpperConfidenceBound, - acqf_kwargs=extra_acqf_args, - stimuli_per_trial=2, - ), + model=probit_model, + generator=acqf_gen, min_asks=n_opt, stimuli_per_trial=2, outcome_types=["binary"], + transforms=transforms, ), ] diff --git a/tests/test_config.py b/tests/test_config.py index 1ad61772d..87ef7d209 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -31,7 +31,7 @@ from aepsych.server.message_handlers.handle_setup import configure from aepsych.strategy import SequentialStrategy, Strategy from aepsych.transforms import ParameterTransforms, transform_options -from aepsych.transforms.parameters import Log10Plus +from aepsych.transforms.parameters import Log10Plus, NormalizeScale from aepsych.version import __version__ from botorch.acquisition import qLogNoisyExpectedImprovement from botorch.acquisition.active_learning import PairwiseMCPosteriorVariance @@ -50,12 +50,12 @@ def test_single_probit_config(self): [par1] par_type = continuous - lower_bound = 0 - upper_bound = 1 + lower_bound = 1 + upper_bound = 10 [par2] par_type = continuous - lower_bound = 0 + lower_bound = -1 upper_bound = 1 [init_strat] @@ -123,9 +123,19 @@ def test_single_probit_config(self): self.assertTrue(strat.strat_list[0].outcome_types == ["binary"]) self.assertTrue(strat.strat_list[1].min_asks == 20) self.assertTrue(torch.all(strat.strat_list[0].lb == strat.strat_list[1].lb)) - self.assertTrue(torch.all(strat.strat_list[1].model.lb == torch.Tensor([0, 0]))) + self.assertTrue( + torch.all( + strat.transforms.untransform(strat.strat_list[1].model.lb) + == torch.Tensor([1, -1]) + ) + ) self.assertTrue(torch.all(strat.strat_list[0].ub == strat.strat_list[1].ub)) - self.assertTrue(torch.all(strat.strat_list[1].model.ub == torch.Tensor([1, 1]))) + self.assertTrue( + torch.all( + strat.transforms.untransform(strat.strat_list[1].model.ub) + == torch.Tensor([10, 1]) + ) + ) self.assertEqual(strat.strat_list[0].min_total_outcome_occurrences, 5) self.assertEqual(strat.strat_list[0].min_post_range, None) @@ -1038,6 +1048,7 @@ def test_derived_bounds(self): par_type = continuous lower_bound = -10 upper_bound = 10 + normalize_scale = False [init_strat] min_total_tells = 10 @@ -1061,7 +1072,7 @@ def test_derived_bounds(self): self.assertTrue(torch.all(model.lb == torch.Tensor([0, -10]))) self.assertTrue(torch.all(model.ub == torch.Tensor([1, 10]))) - def test_ignore_specific_bounds(self): + def test_ignore_common_bounds(self): config_str = """ [common] parnames = [par1, par2] @@ -1081,6 +1092,7 @@ def test_ignore_specific_bounds(self): par_type = continuous lower_bound = -5 upper_bound = 1 + normalize_scale = False [init_strat] min_total_tells = 10 @@ -1101,9 +1113,53 @@ def test_ignore_specific_bounds(self): opt_strat = strat.strat_list[1] model = opt_strat.model - self.assertTrue(torch.all(model.lb == torch.Tensor([0, 0]))) + self.assertTrue(torch.all(model.lb == torch.Tensor([0, -5]))) self.assertTrue(torch.all(model.ub == torch.Tensor([1, 1]))) + def test_common_fallback_bounds(self): + config_str = """ + [common] + parnames = [par1, par2] + lb = [0, 0] + ub = [1, 100] + stimuli_per_trial = 1 + outcome_types = [binary] + target = 0.75 + strategy_names = [init_strat, opt_strat] + + [par1] + par_type = continuous + lower_bound = 1 + upper_bound = 100 + + [par2] + par_type = continuous + # lower_bound = -5 + # upper_bound = 1 + normalize_scale = False + + [init_strat] + min_total_tells = 10 + generator = SobolGenerator + + [opt_strat] + min_total_tells = 20 + refit_every = 5 + generator = OptimizeAcqfGenerator + acqf = MCLevelSetEstimation + model = GPClassificationModel + """ + + config = Config() + config.update(config_str=config_str) + + strat = SequentialStrategy.from_config(config) + opt_strat = strat.strat_list[1] + model = opt_strat.model + + self.assertTrue(torch.all(model.lb == torch.Tensor([0, 0]))) + self.assertTrue(torch.all(model.ub == torch.Tensor([1, 100]))) + def test_parameter_setting_block_validation(self): config_str = """ [common] @@ -1213,8 +1269,8 @@ def test_clone_transform_options(self): self.assertFalse(torch.all(config_points == xformed_points)) self.assertFalse(torch.all(config_window == xformed_window)) - self.assertTrue(torch.allclose(xformed_lb, torch.tensor([0.0, 1.0]))) - self.assertTrue(torch.allclose(xformed_ub, torch.tensor([1.0, 2.0]))) + self.assertTrue(torch.allclose(xformed_lb, torch.tensor([0.0, 0.0]))) + self.assertTrue(torch.allclose(xformed_ub, torch.tensor([1.0, 1.0]))) transforms = ParameterTransforms.from_config(config) reversed_points = transforms.untransform(xformed_points) @@ -1245,11 +1301,20 @@ def test_build_transform(self): transforms = ParameterTransforms.from_config(config) - self.assertTrue(len(transforms.values()) == 1) + self.assertTrue(len(transforms.values()) == 3) tf = list(transforms.items())[0] - self.assertTrue(tf[0] == "signal2_Log10Plus") - self.assertTrue(isinstance(tf[1], Log10Plus)) + expected_names = [ + "signal1_NormalizeScale", + "signal2_Log10Plus", + "signal2_NormalizeScale", + ] + expected_transforms = [NormalizeScale, Log10Plus, NormalizeScale] + for tf, name, transform in zip( + transforms.items(), expected_names, expected_transforms + ): + self.assertTrue(tf[0] == name) + self.assertTrue(isinstance(tf[1], transform)) if __name__ == "__main__": diff --git a/tests/test_strategy.py b/tests/test_strategy.py index 2bbb12adc..e826e2897 100644 --- a/tests/test_strategy.py +++ b/tests/test_strategy.py @@ -15,6 +15,12 @@ from aepsych.models.gp_classification import GPClassificationModel from aepsych.models.monotonic_rejection_gp import MonotonicRejectionGP from aepsych.strategy import SequentialStrategy, Strategy +from aepsych.transforms import ( + ParameterTransformedGenerator, + ParameterTransformedModel, + ParameterTransforms, +) +from aepsych.transforms.parameters import Normalize class TestSequenceGenerators(unittest.TestCase): @@ -22,20 +28,29 @@ def setUp(self): seed = 1 torch.manual_seed(seed) np.random.seed(seed) - lb = [-1, -1] - ub = [1, 1] + lb = torch.tensor([-1, -1]) + ub = torch.tensor([1, 1]) extra_acqf_args = {"target": 0.75, "beta": 1.96} + transforms = ParameterTransforms( + normalize=Normalize(d=2, bounds=torch.stack([lb, ub])) + ) + self.strat = Strategy( - model=MonotonicRejectionGP( + model=ParameterTransformedModel( + MonotonicRejectionGP, + transforms=transforms, + dim=2, lb=lb, ub=ub, - dim=2, monotonic_idxs=[1], ), - generator=MonotonicRejectionGenerator( - acqf=MonotonicMCLSE, acqf_kwargs=extra_acqf_args + generator=ParameterTransformedGenerator( + MonotonicRejectionGenerator, + transforms=transforms, + acqf=MonotonicMCLSE, + acqf_kwargs=extra_acqf_args, ), min_asks=50, lb=lb, diff --git a/tests/test_transforms.py b/tests/test_transforms.py index c6fd9a35d..7beed7a5c 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -17,7 +17,7 @@ ParameterTransformedModel, ParameterTransforms, ) -from aepsych.transforms.parameters import Log10Plus +from aepsych.transforms.parameters import Log10Plus, NormalizeScale class TransformsConfigTest(unittest.TestCase): @@ -45,7 +45,7 @@ def setUp(self): [init_strat] min_total_tells = 10 generator = SobolGenerator - + [SobolGenerator] seed = 12345 @@ -121,10 +121,15 @@ def test_transforms_in_strategy(self): class TransformsLog10Test(unittest.TestCase): - def test_transform_reshape(self): - x = torch.rand(4, 3, 2) + 1.0 - - transforms = ParameterTransforms(Log10Plus=Log10Plus(indices=[0, 1, 2])) + def test_transform_reshape3D(self): + lb = torch.tensor([-1, 0, 10]) + ub = torch.tensor([-1e-6, 9, 99]) + x = SobolGenerator(lb=lb, ub=ub, stimuli_per_trial=2).gen(4) + + transforms = ParameterTransforms( + log10=Log10Plus(indices=[0, 1, 2], constant=2), + normalize=NormalizeScale(d=3, bounds=torch.stack([lb, ub])), + ) transformed_x = transforms.transform(x) untransformed_x = transforms.untransform(transformed_x) @@ -143,12 +148,14 @@ def test_log_transform(self): lower_bound = -10 upper_bound = 10 log_scale = false + normalize_scale = no [signal2] par_type = continuous lower_bound = 1 upper_bound = 100 log_scale = true + normalize_scale = off """ config = Config() config.update(config_str=config_str) @@ -239,3 +246,35 @@ def test_log_model(self): est_max = x[np.argmin((zhat - target) ** 2)] diff = np.abs(est_max / 100 - target) self.assertTrue(diff < 0.15, f"Diff = {diff}") + + +class TransformsNormalize(unittest.TestCase): + def test_normalize_scale(self): + config_str = """ + [common] + parnames = [signal1, signal2] + stimuli_per_trial = 1 + outcome_types = [binary] + + [signal1] + par_type = continuous + lower_bound = -10 + upper_bound = 10 + normalize_scale = false + + [signal2] + par_type = continuous + lower_bound = 0 + upper_bound = 100 + """ + config = Config() + config.update(config_str=config_str) + + transforms = ParameterTransforms.from_config(config) + + values = torch.tensor([[-5.0, 20.0], [20.0, 1.0]]) + expected = torch.tensor([[-5.0, 0.2], [20.0, 0.01]]) + transformed = transforms.transform(values) + + self.assertTrue(torch.allclose(transformed, expected)) + self.assertTrue(torch.allclose(transforms.untransform(transformed), values))