Support optimizer options to be modified in models, starting from config (#448)

Summary:
Pull Request resolved: #448

Optimizer options can be used to control the SciPy minimize function during model fitting. We now allow these options to be set at model initialization (including from config).

Every model now accepts these options.
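
A rough usage sketch (not part of this diff; the bounds, option values, and training data below are placeholders) showing how the options travel from the model constructor to SciPy's L-BFGS-B:

import torch
from aepsych.models.gp_classification import GPClassificationModel

# Hypothetical option values; anything SciPy's L-BFGS-B accepts in its
# `options` dict should pass through the same way.
model = GPClassificationModel(
    lb=torch.tensor([0.0]),
    ub=torch.tensor([1.0]),
    optimizer_options={"maxfun": 1000, "ftol": 1e-6},
)

train_x = torch.rand(20, 1)
train_y = torch.bernoulli(torch.full((20,), 0.5))  # binary responses
model.fit(train_x, train_y)  # options are forwarded to fit_gpytorch_mll / SciPy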

Reviewed By: crasanders

Differential Revision: D65641684

fbshipit-source-id: 4328889313d05c129aaf9d7e32b881c8a24b0807
JasonKChow authored and facebook-github-bot committed Nov 20, 2024
1 parent 10100e4 commit 572374f
Showing 9 changed files with 173 additions and 20 deletions.
6 changes: 5 additions & 1 deletion aepsych/models/base.py
@@ -364,14 +364,18 @@ def _fit_mll(
optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs.copy()
max_fit_time = kwargs.pop("max_fit_time", self.max_fit_time)
if max_fit_time is not None:
if "options" not in optimizer_kwargs:
optimizer_kwargs["options"] = {}

# figure out how long evaluating a single samp
starttime = time.time()
_ = mll(self(train_x), train_y)
single_eval_time = (
time.time() - starttime + 1e-6
) # add an epsilon to avoid divide by zero
n_eval = int(max_fit_time / single_eval_time)
optimizer_kwargs["options"] = {"maxfun": n_eval}

optimizer_kwargs["options"]["maxfun"] = n_eval
logger.info(f"fit maxfun is {n_eval}")

starttime = time.time()
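The timing logic above budgets optimizer evaluations from max_fit_time; below is a standalone sketch of the same idea (the function and argument names are illustrative, not the library API):

import time
from typing import Any, Callable, Dict

def budget_maxfun(
    eval_once: Callable[[], Any],
    max_fit_time: float,
    optimizer_kwargs: Dict[str, Any],
) -> Dict[str, Any]:
    """Estimate how many objective evaluations fit in max_fit_time and merge
    the result into optimizer_kwargs["options"] without overwriting other options."""
    optimizer_kwargs = optimizer_kwargs.copy()
    options = dict(optimizer_kwargs.get("options", {}))

    start = time.time()
    eval_once()  # one MLL evaluation, used to estimate per-evaluation cost
    single_eval_time = time.time() - start + 1e-6  # epsilon avoids divide-by-zero

    options["maxfun"] = int(max_fit_time / single_eval_time)
    optimizer_kwargs["options"] = options
    return optimizer_kwargs

# e.g. kwargs = budget_maxfun(lambda: mll(model(train_x), train_y), 30.0, {})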
22 changes: 18 additions & 4 deletions aepsych/models/gp_classification.py
@@ -8,7 +8,7 @@

import warnings
from copy import deepcopy
from typing import Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple

import gpytorch
import numpy as np
@@ -17,7 +17,7 @@
from aepsych.factory.default import default_mean_covar_factory
from aepsych.models.base import AEPsychModelDeviceMixin
from aepsych.models.utils import select_inducing_points
from aepsych.utils import _process_bounds, promote_0d
from aepsych.utils import _process_bounds, get_optimizer_options, promote_0d
from aepsych.utils_logging import getLogger
from gpytorch.likelihoods import BernoulliLikelihood, BetaLikelihood, Likelihood
from gpytorch.models import ApproximateGP
@@ -57,6 +57,7 @@ def __init__(
inducing_size: Optional[int] = None,
max_fit_time: Optional[float] = None,
inducing_point_method: str = "auto",
optimizer_options: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize the GP Classification model
@@ -78,12 +79,17 @@
If "pivoted_chol", selects points based on the pivoted Cholesky heuristic.
If "kmeans++", selects points by performing kmeans++ clustering on the training data.
If "auto", tries to determine the best method automatically.
optimizer_options (Dict[str, Any], optional): Optimizer options to pass to the SciPy optimizer during
fitting. Assumes we are using L-BFGS-B.
"""
lb, ub, self.dim = _process_bounds(lb, ub, dim)

self.max_fit_time = max_fit_time
self.inducing_size = inducing_size or 99

self.optimizer_options = (
{"options": optimizer_options} if optimizer_options else {"options": {}}
)

if self.inducing_size >= 100:
logger.warning(
(
@@ -174,6 +180,8 @@ def from_config(cls, config: Config) -> GPClassificationModel:
else:
likelihood = None # fall back to __init__ default

optimizer_options = get_optimizer_options(config, classname)

return cls(
lb=lb,
ub=ub,
@@ -184,6 +192,7 @@ def from_config(cls, config: Config) -> GPClassificationModel:
max_fit_time=max_fit_time,
inducing_point_method=inducing_point_method,
likelihood=likelihood,
optimizer_options=optimizer_options,
)

def _reset_hyperparameters(self) -> None:
@@ -251,7 +260,10 @@ def fit(
n = train_y.shape[0]
mll = gpytorch.mlls.VariationalELBO(self.likelihood, self, n)

self._fit_mll(mll, **kwargs)
if "optimizer_kwargs" in kwargs:
self._fit_mll(mll, **kwargs)
else:
self._fit_mll(mll, optimizer_kwargs=self.optimizer_options, **kwargs)

def sample(self, x: torch.Tensor, num_samples: int) -> torch.Tensor:
"""Sample from underlying model.
@@ -335,6 +347,7 @@ def __init__(
inducing_size: Optional[int] = None,
max_fit_time: Optional[float] = None,
inducing_point_method: str = "auto",
optimizer_options: Optional[Dict[str, Any]] = None,
) -> None:
if likelihood is None:
likelihood = BetaLikelihood()
@@ -348,4 +361,5 @@
inducing_size=inducing_size,
max_fit_time=max_fit_time,
inducing_point_method=inducing_point_method,
optimizer_options=optimizer_options,
)
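
The fit wrapper above prefers an explicitly passed optimizer_kwargs over the model-level optimizer_options; a hedged sketch of that precedence (all values are placeholders):

import torch
from aepsych.models.gp_classification import GPClassificationModel

model = GPClassificationModel(
    lb=torch.tensor([0.0]),
    ub=torch.tensor([1.0]),
    optimizer_options={"maxfun": 500},  # stored internally as {"options": {"maxfun": 500}}
)
train_x = torch.rand(20, 1)
train_y = torch.bernoulli(torch.full((20,), 0.5))

# No optimizer_kwargs given: the model-level options set at construction are used.
model.fit(train_x, train_y)

# An explicit optimizer_kwargs bypasses optimizer_options for this call.
model.fit(train_x, train_y, optimizer_kwargs={"options": {"maxfun": 50}})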
16 changes: 13 additions & 3 deletions aepsych/models/gp_regression.py
@@ -7,15 +7,15 @@
from __future__ import annotations

from copy import deepcopy
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

import gpytorch
import numpy as np
import torch
from aepsych.config import Config
from aepsych.factory.default import default_mean_covar_factory
from aepsych.models.base import AEPsychModelDeviceMixin
from aepsych.utils import _process_bounds, promote_0d
from aepsych.utils import _process_bounds, get_optimizer_options, promote_0d
from aepsych.utils_logging import getLogger
from gpytorch.likelihoods import GaussianLikelihood, Likelihood
from gpytorch.models import ExactGP
@@ -40,6 +40,7 @@ def __init__(
covar_module: Optional[gpytorch.kernels.Kernel] = None,
likelihood: Optional[Likelihood] = None,
max_fit_time: Optional[float] = None,
optimizer_options: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize the GP regression model
@@ -55,6 +56,8 @@
Gaussian likelihood.
max_fit_time (float, optional): The maximum amount of time, in seconds, to spend fitting the model. If None,
there is no limit to the fitting time.
optimizer_options (Dict[str, Any], optional): Optimizer options to pass to the SciPy optimizer during
fitting. Assumes we are using L-BFGS-B.
"""
if likelihood is None:
likelihood = GaussianLikelihood()
@@ -64,6 +67,10 @@
lb, ub, self.dim = _process_bounds(lb, ub, dim)
self.max_fit_time = max_fit_time

self.optimizer_options = (
{"options": optimizer_options} if optimizer_options else {"options": {}}
)

if mean_module is None or covar_module is None:
default_mean, default_covar = default_mean_covar_factory(
dim=self.dim, stimuli_per_trial=self.stimuli_per_trial
@@ -105,6 +112,8 @@ def construct_inputs(cls, config: Config) -> Dict:

max_fit_time = config.getfloat(classname, "max_fit_time", fallback=None)

optimizer_options = get_optimizer_options(config, classname)

return {
"lb": lb,
"ub": ub,
@@ -113,6 +122,7 @@ def construct_inputs(cls, config: Config) -> Dict:
"covar_module": covar,
"likelihood": likelihood,
"max_fit_time": max_fit_time,
"optimizer_options": optimizer_options,
}

@classmethod
@@ -142,7 +152,7 @@ def fit(self, train_x: torch.Tensor, train_y: torch.Tensor, **kwargs) -> None:
"""
self.set_train_data(train_x, train_y)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self)
return self._fit_mll(mll, **kwargs)
return self._fit_mll(mll, self.optimizer_options, **kwargs)

def sample(self, x: torch.Tensor, num_samples: int) -> torch.Tensor:
"""Sample from underlying model.
8 changes: 7 additions & 1 deletion aepsych/models/monotonic_projection_gp.py
@@ -7,14 +7,15 @@

from __future__ import annotations

from typing import Any, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import gpytorch
import numpy as np
import torch
from aepsych.config import Config
from aepsych.factory.default import default_mean_covar_factory
from aepsych.models.gp_classification import GPClassificationModel
from aepsych.utils import get_optimizer_options
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch.likelihoods import Likelihood
from statsmodels.stats.moment_helpers import corr2cov, cov2corr
@@ -104,6 +105,7 @@ def __init__(
inducing_size: Optional[int] = None,
max_fit_time: Optional[float] = None,
inducing_point_method: str = "auto",
optimizer_options: Optional[Dict[str, Any]] = None,
) -> None:
assert len(monotonic_dims) > 0
self.monotonic_dims = [int(d) for d in monotonic_dims]
@@ -119,6 +121,7 @@
inducing_size=inducing_size,
max_fit_time=max_fit_time,
inducing_point_method=inducing_point_method,
optimizer_options=optimizer_options,
)

def posterior(
@@ -222,6 +225,8 @@ def from_config(cls, config: Config) -> MonotonicProjectionGP:
)
min_f_val = config.getfloat(classname, "min_f_val", fallback=None)

optimizer_options = get_optimizer_options(config, classname)

return cls(
lb=lb,
ub=ub,
@@ -235,4 +240,5 @@
monotonic_dims=monotonic_dims,
monotonic_grid_size=monotonic_grid_size,
min_f_val=min_f_val,
optimizer_options=optimizer_options,
)
15 changes: 12 additions & 3 deletions aepsych/models/monotonic_rejection_gp.py
@@ -8,7 +8,7 @@
from __future__ import annotations

import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import gpytorch
import numpy as np
@@ -20,7 +20,7 @@
from aepsych.means.constant_partial_grad import ConstantMeanPartialObsGrad
from aepsych.models.base import AEPsychMixin
from aepsych.models.utils import select_inducing_points
from aepsych.utils import _process_bounds, promote_0d
from aepsych.utils import _process_bounds, get_optimizer_options, promote_0d
from botorch.fit import fit_gpytorch_mll
from gpytorch.kernels import Kernel
from gpytorch.likelihoods import BernoulliLikelihood, Likelihood
@@ -63,6 +63,7 @@ def __init__(
num_samples: int = 250,
num_rejection_samples: int = 5000,
inducing_point_method: str = "auto",
optimizer_options: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize MonotonicRejectionGP.
@@ -82,6 +83,8 @@
acqf (MonotonicMCAcquisition, optional): Acquisition function to use for querying points. Defaults to MonotonicMCLSE.
objective (Optional[MCAcquisitionObjective], optional): Transformation of GP to apply before computing acquisition function. Defaults to identity transform for gaussian likelihood, probit transform for probit-bernoulli.
extra_acqf_args (Optional[Dict[str, object]], optional): Additional arguments to pass into the acquisition function. Defaults to None.
optimizer_options (Dict[str, Any], optional): Optimizer options to pass to the SciPy optimizer during
fitting. Assumes we are using L-BFGS-B.
"""
self.lb, self.ub, self.dim = _process_bounds(lb, ub, dim)
if likelihood is None:
@@ -145,6 +148,9 @@ def __init__(
self.num_rejection_samples = num_rejection_samples
self.fixed_prior_mean = fixed_prior_mean
self.inducing_points = inducing_points
self.optimizer_options = (
{"options": optimizer_options} if optimizer_options else {"options": {}}
)

def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None:
"""Fit the model
@@ -183,7 +189,7 @@ def _set_model(
mll = VariationalELBO(
likelihood=self.likelihood, model=self, num_data=train_y.numel()
)
mll = fit_gpytorch_mll(mll)
mll = fit_gpytorch_mll(mll, optimizer_kwargs=self.optimizer_options)

def update(self, train_x: Tensor, train_y: Tensor, warmstart: bool = True) -> None:
"""
@@ -319,6 +325,8 @@ def from_config(cls, config: Config) -> MonotonicRejectionGP:
classname, "monotonic_idxs", fallback=[-1]
)

optimizer_options = get_optimizer_options(config, classname)

return cls(
monotonic_idxs=monotonic_idxs,
lb=lb,
Expand All @@ -329,6 +337,7 @@ def from_config(cls, config: Config) -> MonotonicRejectionGP:
num_rejection_samples=num_rejection_samples,
mean_module=mean,
covar_module=covar,
optimizer_options=optimizer_options,
)

def forward(self, x: torch.Tensor) -> gpytorch.distributions.MultivariateNormal:
34 changes: 28 additions & 6 deletions aepsych/models/pairwise_probit.py
@@ -5,15 +5,14 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import time
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple

import gpytorch
import numpy as np
import torch
from aepsych.config import Config
from aepsych.factory import default_mean_covar_factory
from aepsych.models.base import AEPsychMixin
from aepsych.utils import _process_bounds, promote_0d
from aepsych.utils import _process_bounds, get_optimizer_options, promote_0d
from aepsych.utils_logging import getLogger
from botorch.fit import fit_gpytorch_mll
from botorch.models import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood
@@ -64,6 +63,7 @@ def __init__(
dim: Optional[int] = None,
covar_module: Optional[gpytorch.kernels.Kernel] = None,
max_fit_time: Optional[float] = None,
optimizer_options: Optional[Dict[str, Any]] = None,
) -> None:
self.lb, self.ub, dim = _process_bounds(lb, ub, dim)

@@ -93,6 +93,9 @@
)

self.dim = dim # The Pairwise constructor sets self.dim = None.
self.optimizer_options = (
{"options": optimizer_options} if optimizer_options else {"options": {}}
)

def fit(
self,
@@ -101,6 +104,12 @@ def fit(
optimizer_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> None:
if optimizer_kwargs is not None:
if not "optimizer_kwargs" in optimizer_kwargs:
optimizer_kwargs = optimizer_kwargs.copy()
optimizer_kwargs.update(self.optimizer_options)
else:
optimizer_kwargs = {"options": self.optimizer_options}
self.train()
mll = PairwiseLaplaceMarginalLogLikelihood(self.likelihood, self)
datapoints, comparisons = self._pairs_to_comparisons(train_x, train_y)
@@ -109,17 +118,21 @@
optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs.copy()
max_fit_time = kwargs.pop("max_fit_time", self.max_fit_time)
if max_fit_time is not None:
if "options" not in optimizer_kwargs:
optimizer_kwargs["options"] = {}

# figure out how long evaluating a single samp
starttime = time.time()
_ = mll(self(datapoints), comparisons)
single_eval_time = time.time() - starttime
n_eval = int(max_fit_time / single_eval_time)
optimizer_kwargs["maxfun"] = n_eval

optimizer_kwargs["options"]["maxfun"] = n_eval
logger.info(f"fit maxfun is {n_eval}")

logger.info("Starting fit...")
starttime = time.time()
fit_gpytorch_mll(mll, **kwargs, **optimizer_kwargs)
fit_gpytorch_mll(mll, optimizer_kwargs=optimizer_kwargs, **kwargs)
logger.info(f"Fit done, time={time.time()-starttime}")

def update(
Expand Down Expand Up @@ -209,4 +222,13 @@ def from_config(cls, config: Config) -> "PairwiseProbitModel":

max_fit_time = config.getfloat(classname, "max_fit_time", fallback=None)

return cls(lb=lb, ub=ub, dim=dim, covar_module=covar, max_fit_time=max_fit_time)
optimizer_options = get_optimizer_options(config, classname)

return cls(
lb=lb,
ub=ub,
dim=dim,
covar_module=covar,
max_fit_time=max_fit_time,
optimizer_options=optimizer_options,
)