Skip to content

Commit

Permalink
WIP fix bug where variational GPs wouldn't use correct likelihoods; u…
Browse files Browse the repository at this point in the history
…se botorch inducing point selection; use botorch default mean/covar; update botorch/ax versions (#323)

Summary:

This ensures variational GPs will use correct likelihoods, and it moves some logic out of AEPsych and into botorch. Anecdotally, the botorch priors seem to produce better models, and since we are so resource-strapped on the AEPsych side, we should rely on botorch logic as much as possible.

Differential Revision: D48891019
  • Loading branch information
Craig Sanders authored and facebook-github-bot committed Dec 8, 2023
1 parent e79ef71 commit c77e206
Show file tree
Hide file tree
Showing 27 changed files with 530 additions and 277 deletions.
1 change: 1 addition & 0 deletions aepsych/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,4 +372,5 @@ def from_config(cls, config: Config, name: Optional[str] = None):
Config.register_module(gpytorch.likelihoods)
Config.register_module(gpytorch.kernels)
Config.register_module(botorch.acquisition)
Config.register_module(botorch.acquisition.multi_objective)
Config.registered_names["None"] = None
2 changes: 0 additions & 2 deletions aepsych/generators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from .manual_generator import ManualGenerator
from .monotonic_rejection_generator import MonotonicRejectionGenerator
from .monotonic_thompson_sampler_generator import MonotonicThompsonSamplerGenerator
from .multi_outcome_generator import MultiOutcomeOptimizationGenerator
from .optimize_acqf_generator import AxOptimizeAcqfGenerator, OptimizeAcqfGenerator
from .pairwise_optimize_acqf_generator import PairwiseOptimizeAcqfGenerator
from .pairwise_sobol_generator import PairwiseSobolGenerator
Expand All @@ -33,7 +32,6 @@
"AxOptimizeAcqfGenerator",
"AxSobolGenerator",
"IntensityAwareSemiPGenerator",
"MultiOutcomeOptimizationGenerator",
"AxRandomGenerator",
]

Expand Down
27 changes: 0 additions & 27 deletions aepsych/generators/multi_outcome_generator.py

This file was deleted.

7 changes: 2 additions & 5 deletions aepsych/generators/optimize_acqf_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from aepsych.config import Config, ConfigurableMixin
from aepsych.generators.base import AEPsychGenerationStep, AEPsychGenerator
from aepsych.models.base import ModelProtocol
from aepsych.models.surrogate import AEPsychSurrogate
from ax.models.torch.botorch_modular.surrogate import Surrogate
from aepsych.utils_logging import getLogger
from ax.modelbridge import Models
from ax.modelbridge.registry import Cont_X_trans
Expand Down Expand Up @@ -152,14 +152,11 @@ def get_config_options(cls, config: Config, name: str) -> Dict:
acqf_options = cls._get_acqf_options(acqf_cls, config)
gen_options = cls._get_gen_options(config)

max_fit_time = model_options["max_fit_time"]

model_kwargs = {
"surrogate": AEPsychSurrogate(
"surrogate": Surrogate(
botorch_model_class=model_class,
mll_class=model_class.get_mll_class(),
model_options=model_options,
max_fit_time=max_fit_time,
),
"acquisition_class": AEPsychAcquisition,
"botorch_acqf_class": acqf_cls,
Expand Down
10 changes: 8 additions & 2 deletions aepsych/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .exact_gp import ContinuousRegressionGP, ExactGP
from .gp_classification import GPBetaRegressionModel, GPClassificationModel
from .gp_regression import GPRegressionModel
from .model_list import AEPsychModelListGP
from .monotonic_projection_gp import MonotonicProjectionGP
from .monotonic_rejection_gp import MonotonicRejectionGP
from .multitask_regression import IndependentMultitaskGPRModel, MultitaskGPRModel
Expand All @@ -21,8 +22,12 @@
semi_p_posterior_transform,
SemiParametricGPModel,
)
from .variational_gp import BetaRegressionGP, BinaryClassificationGP, OrdinalGP, VariationalGP

from .variational_gp import (
BetaRegressionGP,
BinaryClassificationGP,
OrdinalGP,
VariationalGP,
)


__all__ = [
Expand All @@ -44,6 +49,7 @@
"semi_p_posterior_transform",
"OrdinalGP",
"GPBetaRegressionModel",
"AEPsychModelListGP",
]

Config.register_module(sys.modules[__name__])
Loading

0 comments on commit c77e206

Please sign in to comment.