From 8ed95b44b83dde4e04ece91ea87c0cceb70983df Mon Sep 17 00:00:00 2001 From: Jason Chow Date: Thu, 14 Nov 2024 22:38:57 -0800 Subject: [PATCH] add support for binary parameters (#454) Summary: Binary parameters are just secretly discrete parameters bounded at [0, 1]. Config will accept binary as a par_type and do the necessary work to support it in modeling. Differential Revision: D65954134 --- aepsych/config.py | 9 ++++-- aepsych/transforms/parameters.py | 8 +++-- docs/parameters.md | 21 +++++++++++++ tests/test_transforms.py | 54 +++++++++++++++++++++++++++++++- 4 files changed, 86 insertions(+), 6 deletions(-) diff --git a/aepsych/config.py b/aepsych/config.py index e5b0ca18d..45231b012 100644 --- a/aepsych/config.py +++ b/aepsych/config.py @@ -182,8 +182,8 @@ def update( # Validate the parameter-specific block self._check_param_settings(par_name) - lb[i] = self[par_name]["lower_bound"] - ub[i] = self[par_name]["upper_bound"] + lb[i] = self[par_name].get("lower_bound", fallback="0") + ub[i] = self[par_name].get("upper_bound", fallback="1") self["common"]["lb"] = f"[{', '.join(lb)}]" self["common"]["ub"] = f"[{', '.join(ub)}]" @@ -276,6 +276,11 @@ def _check_param_settings(self, param_name: str) -> None: and self.getint(param_name, "upper_bound") % 1 == 0 ): raise ValueError(f"Parameter {param_name} has non-integer bounds.") + elif param_block["par_type"] == "binary": + if "lower_bound" in param_block or "upper_bound" in param_block: + raise ValueError( + f"Parameter {param_name} is binary and shouldn't have bounds." + ) else: raise ValueError( f"Parameter {param_name} has an unsupported parameter type {param_block['par_type']}." diff --git a/aepsych/transforms/parameters.py b/aepsych/transforms/parameters.py index 0aad41d85..0d3b29e2b 100644 --- a/aepsych/transforms/parameters.py +++ b/aepsych/transforms/parameters.py @@ -171,8 +171,8 @@ def get_config_options( except KeyError: # Probably because par doesn't have its own section par_type = "continuous" - # Integer variable - if par_type == "integer": + # Integer or binary variable + if par_type in ["integer", "binary"]: round = Round.from_config( config=config, name=par, options=transform_options ) @@ -196,7 +196,9 @@ def get_config_options( transform_dict[f"{par}_Log10Plus"] = log10 # Normalize scale (defaults true) - if config.getboolean(par, "normalize_scale", fallback=True): + if config.getboolean( + par, "normalize_scale", fallback=True + ) and par_type not in ["discrrete", "binary"]: normalize = NormalizeScale.from_config( config=config, name=par, options=transform_options ) diff --git a/docs/parameters.md b/docs/parameters.md index f93526b66..9bdd0a539 100644 --- a/docs/parameters.md +++ b/docs/parameters.md @@ -15,6 +15,7 @@ what parameter types are used and whatever transformations are used. Currently, we only support continuous parameters. More parameter types soon to come!

Continuous

+ ```ini [parameter] par_type = continuous @@ -28,6 +29,7 @@ include negative values (e.g., lower bound = -1, upper bound = 1) or have very l ranges (e.g., lower bound = 0, upper bound = 1,000,000).

Integer

+ ```ini [parameter] par_type = integer @@ -40,6 +42,25 @@ and necessity of bounds. However, integer parameters will use continuous relaxat allow the models and generators to handle integer input/outputs. For example, this could represent the number of lights are on for a detection threshold experiment. +

Binary

+ +```ini +[parameter] +par_type = binary +``` + +Binary parameters could be useful for modeling parameters like whether a distractor is +on or not during a detection task. Binary parameters are implemented as a special case +of a integer parameter. No bounds should be set. It will be treated as a integer +parameter that will either be 0 or 1. Binary parameters are equivalent to this: + +```ini +[parameter] +par_type = discrete +lower_bound = 0 +upper_bound = 1 +``` +

Parameter Transformations

Currently, we only support a log scale transformation to parameters. More parameter transformations to come! In general, you can define your parameters in the raw diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 146da3a48..97d30c498 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -17,7 +17,7 @@ ParameterTransformedModel, ParameterTransforms, ) -from aepsych.transforms.ops import Log10Plus, NormalizeScale +from aepsych.transforms.ops import Discretize, Log10Plus, NormalizeScale class TransformsConfigTest(unittest.TestCase): @@ -448,3 +448,55 @@ def test_integer_model(self): est_max = x[np.argmin((zhat - target) ** 2)] diff = np.abs(est_max / 100 - target) self.assertTrue(diff < 0.15, f"Diff = {diff}") + + def test_binary(self): + config_str = """ + [common] + parnames = [signal1] + stimuli_per_trial = 1 + outcome_types = [binary] + strategy_names = [init_strat] + + [signal1] + par_type = binary + + [init_strat] + generator = SobolGenerator + min_asks = 1 + """ + config = Config() + config.update(config_str=config_str) + + strat = SequentialStrategy.from_config(config) + + transforms = strat.transforms + + self.assertTrue(len(transforms) == 1) + self.assertTrue(isinstance(list(transforms.values())[0], Discretize)) + self.assertTrue( + torch.all(config.gettensor("common", "lb") == torch.tensor([0])) + ) + self.assertTrue( + torch.all(config.gettensor("common", "ub") == torch.tensor([1])) + ) + + bad_config_str = """ + [common] + parnames = [signal1] + stimuli_per_trial = 1 + outcome_types = [binary] + strategy_names = [init_strat] + + [signal1] + par_type = binary + lower_bound = 0 + upper_bound = 1 + + [init_strat] + generator = SobolGenerator + min_asks = 1 + """ + config = Config() + + with self.assertRaises(ValueError): + config.update(config_str=bad_config_str)