add support for binary parameters (facebookresearch#454)
Summary:

Binary parameters are just secretly discrete parameters bounded at [0, 1].  Config will accept binary as a par_type and do the necessary work to support it in modeling.
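
As a quick illustration, the intended usage looks like this (a minimal sketch distilled from the new test in this diff; the strategy and generator settings are just the ones that test uses):

```python
from aepsych.config import Config

# A binary parameter declares no bounds; Config fills in [0, 1] itself.
config_str = """
[common]
parnames = [signal1]
stimuli_per_trial = 1
outcome_types = [binary]
strategy_names = [init_strat]

[signal1]
par_type = binary

[init_strat]
generator = SobolGenerator
min_asks = 1
"""

config = Config()
config.update(config_str=config_str)

print(config.gettensor("common", "lb"))  # expected: tensor([0.])
print(config.gettensor("common", "ub"))  # expected: tensor([1.])
```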

Differential Revision: D65954134
JasonKChow authored and facebook-github-bot committed Nov 15, 2024
1 parent 412e6ce commit 7127015
Showing 4 changed files with 86 additions and 6 deletions.
9 changes: 7 additions & 2 deletions aepsych/config.py
@@ -182,8 +182,8 @@ def update(
# Validate the parameter-specific block
self._check_param_settings(par_name)

lb[i] = self[par_name]["lower_bound"]
ub[i] = self[par_name]["upper_bound"]
lb[i] = self[par_name].get("lower_bound", fallback="0")
ub[i] = self[par_name].get("upper_bound", fallback="1")

self["common"]["lb"] = f"[{', '.join(lb)}]"
self["common"]["ub"] = f"[{', '.join(ub)}]"
@@ -276,6 +276,11 @@ def _check_param_settings(self, param_name: str) -> None:
and self.getint(param_name, "upper_bound") % 1 == 0
):
raise ValueError(f"Parameter {param_name} has non-integer bounds.")
elif param_block["par_type"] == "binary":
if "lower_bound" in param_block or "upper_bound" in param_block:
raise ValueError(
f"Parameter {param_name} is binary and shouldn't have bounds."
)
else:
raise ValueError(
f"Parameter {param_name} has an unsupported parameter type {param_block['par_type']}."
)
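
Conversely, the new `binary` branch above rejects explicit bounds at validation time; a small sketch of the expected failure mode (it mirrors the bad-config test at the bottom of this diff):

```python
from aepsych.config import Config

bad_config_str = """
[common]
parnames = [signal1]
stimuli_per_trial = 1
outcome_types = [binary]
strategy_names = [init_strat]

[signal1]
par_type = binary
lower_bound = 0
upper_bound = 1

[init_strat]
generator = SobolGenerator
min_asks = 1
"""

config = Config()
try:
    config.update(config_str=bad_config_str)
except ValueError as err:
    # Expected: "Parameter signal1 is binary and shouldn't have bounds."
    print(err)
```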
8 changes: 5 additions & 3 deletions aepsych/transforms/parameters.py
@@ -171,8 +171,8 @@ def get_config_options(
except KeyError: # Probably because par doesn't have its own section
par_type = "continuous"

# Integer variable
if par_type == "integer":
# Integer or binary variable
if par_type in ["integer", "binary"]:
round = Round.from_config(
config=config, name=par, options=transform_options
)
@@ -196,7 +196,9 @@ def get_config_options(
transform_dict[f"{par}_Log10Plus"] = log10

# Normalize scale (defaults true)
if config.getboolean(par, "normalize_scale", fallback=True):
if config.getboolean(
par, "normalize_scale", fallback=True
) and par_type not in ["discrete", "binary"]:
normalize = NormalizeScale.from_config(
config=config, name=par, options=transform_options
)
transform_dict[f"{par}_NormalizeScale"] = normalize
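
In short, integer and binary parameters now share the `Round` transform for continuous relaxation, while binary (like discrete) skips `NormalizeScale` because [0, 1] is already normalized. The relaxation itself amounts to rounding, roughly like this (plain `torch`, not the actual `Round` op):

```python
import torch

# The model and generator work in a continuous space...
proposals = torch.tensor([0.12, 0.61, 0.49, 0.93])

# ...and rounding snaps each proposal back to a valid binary value.
print(torch.round(proposals))  # tensor([0., 1., 0., 1.])
```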
21 changes: 21 additions & 0 deletions docs/parameters.md
@@ -15,6 +15,7 @@ what parameter types are used and whatever transformations are used.
Currently, we support continuous, integer, and binary parameters. More parameter types soon to come!

<h3>Continuous</h3>

```ini
[parameter]
par_type = continuous
@@ -28,6 +29,7 @@ include negative values (e.g., lower bound = -1, upper bound = 1) or have very l
ranges (e.g., lower bound = 0, upper bound = 1,000,000).

<h3>Integer</h3>

```ini
[parameter]
par_type = integer
@@ -40,6 +42,25 @@ and necessity of bounds. However, integer parameters will use continuous relaxat
allow the models and generators to handle integer inputs/outputs. For example, this
could represent the number of lights that are on in a detection threshold experiment.

<h3>Binary</h3>

```ini
[parameter]
par_type = binary
```

Binary parameters can be useful for modeling things like whether a distractor is
on or off during a detection task. Binary parameters are implemented as a special case
of an integer parameter: no bounds should be set, and the parameter will be treated as
an integer that is either 0 or 1. A binary parameter is equivalent to this:

```ini
[parameter]
par_type = integer
lower_bound = 0
upper_bound = 1
```
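
End to end, generated points for a binary parameter should come back as exact 0/1 values. A hedged sketch of the round trip (reusing the config above; `SequentialStrategy.from_config` and `gen` are the APIs this commit's tests use, and the exact shape of the returned tensor is an assumption):

```python
from aepsych.config import Config
from aepsych.strategy import SequentialStrategy

config = Config()
config.update(config_str="""
[common]
parnames = [signal1]
stimuli_per_trial = 1
outcome_types = [binary]
strategy_names = [init_strat]

[signal1]
par_type = binary

[init_strat]
generator = SobolGenerator
min_asks = 1
""")

strat = SequentialStrategy.from_config(config)
point = strat.gen()  # already rounded, e.g. tensor([[1.]])
print(point)
```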

<h2>Parameter Transformations</h2>
Currently, we only support a log scale transformation for parameters. More parameter
transformations to come! In general, you can define your parameters in the raw
parameter space and the transformations will be applied automatically based on the
configured parameter types.
54 changes: 53 additions & 1 deletion tests/test_transforms.py
@@ -17,7 +17,7 @@
ParameterTransformedModel,
ParameterTransforms,
)
from aepsych.transforms.ops import Log10Plus, NormalizeScale
from aepsych.transforms.ops import Log10Plus, NormalizeScale, Round


class TransformsConfigTest(unittest.TestCase):
@@ -448,3 +448,55 @@ def test_integer_model(self):
        est_max = x[np.argmin((zhat - target) ** 2)]
        diff = np.abs(est_max / 100 - target)
        self.assertTrue(diff < 0.15, f"Diff = {diff}")

    def test_binary(self):
        config_str = """
[common]
parnames = [signal1]
stimuli_per_trial = 1
outcome_types = [binary]
strategy_names = [init_strat]
[signal1]
par_type = binary
[init_strat]
generator = SobolGenerator
min_asks = 1
"""
        config = Config()
        config.update(config_str=config_str)

        strat = SequentialStrategy.from_config(config)

        transforms = strat.transforms

        self.assertTrue(len(transforms) == 1)
        self.assertTrue(isinstance(list(transforms.values())[0], Round))
        self.assertTrue(
            torch.all(config.gettensor("common", "lb") == torch.tensor([0]))
        )
        self.assertTrue(
            torch.all(config.gettensor("common", "ub") == torch.tensor([1]))
        )

        bad_config_str = """
[common]
parnames = [signal1]
stimuli_per_trial = 1
outcome_types = [binary]
strategy_names = [init_strat]
[signal1]
par_type = binary
lower_bound = 0
upper_bound = 1
[init_strat]
generator = SobolGenerator
min_asks = 1
"""
        config = Config()

        with self.assertRaises(ValueError):
            config.update(config_str=bad_config_str)
