Skip to content

Commit

Permalink
implement parameter transforms as generator/model wrappers (facebookr…
Browse files Browse the repository at this point in the history
…esearch#401)

Summary:
Pull Request resolved: facebookresearch#401

Parameter transforms will be handled by wrapping generator and model objects.

The wrappers surfaces the base object API completely and even appears to be the wrapped object upon type inspection. Methods that requires the transformations are overridden by the wrapper to apply the required (un)transforms.

The wrappers expects transforms from BoTorch and new transforms should follow BoTorch's InputTransforms.

As a baseline a log10 transform is implemented.

Differential Revision: D64129439
  • Loading branch information
JasonKChow authored and facebook-github-bot committed Oct 31, 2024
1 parent eebd6d7 commit c1a45e4
Show file tree
Hide file tree
Showing 18 changed files with 992 additions and 78 deletions.
12 changes: 11 additions & 1 deletion aepsych/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,16 @@

from gpytorch.likelihoods import BernoulliLikelihood, GaussianLikelihood

from . import acquisition, config, factory, generators, models, strategy, utils
from . import (
acquisition,
config,
factory,
generators,
models,
strategy,
transforms,
utils,
)
from .config import Config
from .likelihoods import BernoulliObjectiveLikelihood
from .models import GPClassificationModel
Expand All @@ -26,6 +35,7 @@
"factory",
"models",
"strategy",
"transforms",
"utils",
"generators",
# classes
Expand Down
4 changes: 4 additions & 0 deletions aepsych/benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def make_benchmark_list(self, **bench_config) -> List[Dict[str, float]]:
List[dict[str, float]]: List of dictionaries, each of which can be passed
to aepsych.config.Config.
"""

# This could be a generator but then we couldn't
# know how many params we have, tqdm wouldn't work, etc,
# so we materialize the full list.
Expand Down Expand Up @@ -154,6 +155,9 @@ def run_experiment(
np.random.seed(seed)
config_dict["common"]["lb"] = str(problem.lb.tolist())
config_dict["common"]["ub"] = str(problem.ub.tolist())
config_dict["common"]["parnames"] = str(
[f"par{i}" for i in range(len(problem.ub.tolist()))]
)
config_dict["problem"] = problem.metadata
materialized_config = self.materialize_config(config_dict)

Expand Down
51 changes: 36 additions & 15 deletions aepsych/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,23 @@
# LICENSE file in the root directory of this source tree.
import abc
import ast
import re
import configparser
import json
import re
import warnings
from types import ModuleType
from typing import Any, ClassVar, Dict, List, Mapping, Optional, Sequence, TypeVar
from typing import (
Any,
Callable,
ClassVar,
Dict,
Dict,
List,
Mapping,
Optional,
Sequence,
TypeVar,
)
import botorch
import gpytorch
import numpy as np
Expand All @@ -21,6 +32,7 @@

_T = TypeVar("_T")


class Config(configparser.ConfigParser):

# names in these packages can be referred to by string name
Expand Down Expand Up @@ -75,7 +87,6 @@ def _get(
fallback=configparser._UNSET,
**kwargs,
):

"""
Override configparser to:
1. Return from common if a section doesn't exist. This comes
Expand Down Expand Up @@ -107,8 +118,8 @@ def _get(
)

# Convert config into a dictionary (eliminate duplicates from defaulted 'common' section.)
def to_dict(self, deduplicate: bool = True) -> dict:
_dict = {}
def to_dict(self, deduplicate: bool = True) -> Dict[str, Any]:
_dict: Dict[str, Any] = {}
for section in self:
_dict[section] = {}
for setting in self[section]:
Expand Down Expand Up @@ -160,8 +171,10 @@ def update(
warnings.warn(
"ub and lb have been defined in common section, ignoring parameter specific blocks, be very careful!"
)
elif "parnames" in self["common"]: # it's possible to pass no parnames
par_names = self.getlist("common", "parnames", element_type=str, fallback = [])
elif "parnames" in self["common"]: # it's possible to pass no parnames
par_names = self.getlist(
"common", "parnames", element_type=str, fallback=[]
)
lb = [None] * len(par_names)
ub = [None] * len(par_names)
for i, par_name in enumerate(par_names):
Expand All @@ -174,14 +187,15 @@ def update(
self["common"]["lb"] = f"[{', '.join(lb)}]"
self["common"]["ub"] = f"[{', '.join(ub)}]"


# Deprecation warning for "experiment" section
if "experiment" in self:
for i in self["experiment"]:
self["common"][i] = self["experiment"][i]
del self["experiment"]

def _str_to_list(self, v: str, element_type: _T = float) -> List[_T]:
def _str_to_list(
self, v: str, element_type: Callable[[_T], _T] = float
) -> List[_T]:
v = re.sub(r"\n ", ",", v)
v = re.sub(r"(?<!,)\s+", ",", v)
v = re.sub(r",]", "]", v)
Expand Down Expand Up @@ -223,18 +237,25 @@ def _check_param_settings(self, param_name: str) -> None:

# Checking if param_type is set
if "par_type" not in param_block:
raise ValueError(f"Parameter {param_name} is missing the param_type setting.")
raise ValueError(
f"Parameter {param_name} is missing the param_type setting."
)

# Each parameter type has a different set of required settings
if param_block['par_type'] == "continuous":
if param_block["par_type"] == "continuous":
# Check if bounds exist
if "lower_bound" not in param_block:
raise ValueError(f"Parameter {param_name} is missing the lower_bound setting.")
raise ValueError(
f"Parameter {param_name} is missing the lower_bound setting."
)
if "upper_bound" not in param_block:
raise ValueError(f"Parameter {param_name} is missing the upper_bound setting.")
raise ValueError(
f"Parameter {param_name} is missing the upper_bound setting."
)
else:
raise ValueError(f"Parameter {param_name} has an unsupported parameter type {param_block['par_type']}.")

raise ValueError(
f"Parameter {param_name} has an unsupported parameter type {param_block['par_type']}."
)

def __repr__(self) -> str:
return f"Config at {hex(id(self))}: \n {str(self)}"
Expand Down
37 changes: 35 additions & 2 deletions aepsych/config.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,24 @@

import abc
import configparser
from typing import Any, ClassVar, Dict, List, Mapping, Optional, TypeVar, Union
from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
Mapping,
Optional,
TypeVar,
Union,
)

import numpy as np
import torch
from botorch.models.transforms.input import (
ChainedInputTransform,
ReversibleInputTransform,
)

_T = TypeVar("_T")
_ET = TypeVar("_ET")
Expand Down Expand Up @@ -50,7 +64,7 @@ class Config(configparser.ConfigParser):
raw: bool = ...,
vars: Optional[Mapping[str, str]] = ...,
fallback: _T = ...,
element_type: _ET = ...,
element_type: Callable[[_ET], _ET] = ...,
) -> Union[_T, List[_ET]]: ...
def getarray(
self,
Expand All @@ -61,10 +75,29 @@ class Config(configparser.ConfigParser):
vars: Optional[Mapping[str, str]] = ...,
fallback: _T = ...,
) -> Union[np.ndarray, _T]: ...
def getboolean(
self,
section: str,
option: str,
*,
raw: bool = ...,
vars: Mapping[str, str] | None = ...,
fallback: _T = ...,
) -> bool | _T: ...
def getfloat(
self,
section: str,
option: str,
*,
raw: bool = ...,
vars: Mapping[str, str] | None = ...,
fallback: _T = ...,
) -> float | _T: ...
@classmethod
def register_module(cls: _T, module): ...
def jsonifyMetadata(self) -> str: ...
def jsonifyAll(self) -> str: ...
def to_dict(self, deduplicate: bool = ...) -> Dict[str, Any]: ...

class ConfigurableMixin(abc.ABC):
@classmethod
Expand Down
17 changes: 11 additions & 6 deletions aepsych/generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import abc
from inspect import signature
from typing import Any, Dict, Generic, Protocol, runtime_checkable, TypeVar, Optional
import re
from inspect import signature
from typing import Any, Dict, Generic, Optional, Protocol, runtime_checkable, TypeVar

import torch
from aepsych.config import Config
from aepsych.models.base import AEPsychMixin
from botorch.acquisition import (
AcquisitionFunction,
NoisyExpectedImprovement,
qNoisyExpectedImprovement,
LogNoisyExpectedImprovement,
NoisyExpectedImprovement,
qLogNoisyExpectedImprovement,
qNoisyExpectedImprovement,
)


Expand All @@ -43,6 +43,9 @@ class AEPsychGenerator(abc.ABC, Generic[AEPsychModelType]):
stimuli_per_trial = 1
max_asks: Optional[int] = None

acqf: AcquisitionFunction
acqf_kwargs: Dict[str, Any]

def __init__(
self,
) -> None:
Expand All @@ -58,7 +61,9 @@ def from_config(cls, config: Config) -> Any:
pass

@classmethod
def _get_acqf_options(cls, acqf: AcquisitionFunction, config: Config) -> Dict[str, Any]:
def _get_acqf_options(
cls, acqf: AcquisitionFunction, config: Config
) -> Dict[str, Any]:
if acqf is not None:
acqf_name = acqf.__name__

Expand All @@ -81,7 +86,7 @@ def _get_acqf_options(cls, acqf: AcquisitionFunction, config: Config) -> Dict[st
elif re.search(
r"^\[.*\]$", v, flags=re.DOTALL
): # use regex to check if the value is a list
extra_acqf_args[k] = config._str_to_list(v) # type: ignore
extra_acqf_args[k] = config._str_to_list(v) # type: ignore
else:
# otherwise try a float
try:
Expand Down
6 changes: 4 additions & 2 deletions aepsych/generators/manual_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@

import numpy as np
import torch
from torch.quasirandom import SobolEngine

from aepsych.config import Config
from aepsych.generators.base import AEPsychGenerator
from aepsych.models.base import AEPsychMixin
from aepsych.utils import _process_bounds
from torch.quasirandom import SobolEngine


class ManualGenerator(AEPsychGenerator):
Expand Down Expand Up @@ -70,7 +70,9 @@ def gen(
return points

@classmethod
def from_config(cls, config: Config, name: Optional[str] = None) -> 'ManualGenerator':
def from_config(
cls, config: Config, name: Optional[str] = None
) -> "ManualGenerator":
return cls(**cls.get_config_options(config, name))

@classmethod
Expand Down
12 changes: 7 additions & 5 deletions aepsych/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def bounds(self) -> torch.Tensor:
def dim(self) -> int:
pass

def posterior(self, x: torch.Tensor) -> GPyTorchPosterior:
def posterior(self, X: torch.Tensor) -> GPyTorchPosterior:
pass

def predict(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
Expand Down Expand Up @@ -103,7 +103,9 @@ def update(
) -> None:
pass

def p_below_threshold(self, x, f_thresh) -> torch.Tensor:
def p_below_threshold(
self, x: torch.Tensor, f_thresh: torch.Tensor
) -> torch.Tensor:
pass


Expand Down Expand Up @@ -374,11 +376,11 @@ def _fit_mll(
)
return res

def p_below_threshold(self, x: torch.Tensor, f_thresh: torch.Tensor) -> torch.Tensor:
def p_below_threshold(self, x: torch.Tensor, f_thresh: torch.Tensor) -> torch.Tensor:
f, var = self.predict(x)
f_thresh = f_thresh.reshape(-1, 1)
f = f.reshape(1, -1)
var = var.reshape(1, -1)

z = (f_thresh - f) / var.sqrt()
return torch.distributions.Normal(0, 1).cdf(z) # Use PyTorch's CDF equivalent
return torch.distributions.Normal(0, 1).cdf(z) # Use PyTorch's CDF equivalent
Loading

0 comments on commit c1a45e4

Please sign in to comment.