Skip to content

Commit

Permalink
Monotonic rejection model and generator (#458)
Browse files Browse the repository at this point in the history
Summary:

monotonic rejection model GPU support, since they're tied to the generator, we also ensure the generators are gpu ready as well.

Differential Revision: D65638150
  • Loading branch information
JasonKChow authored and facebook-github-bot committed Nov 18, 2024
1 parent d096c6a commit 26ec708
Show file tree
Hide file tree
Showing 11 changed files with 283 additions and 47 deletions.
2 changes: 1 addition & 1 deletion aepsych/generators/monotonic_rejection_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def gen(
)

# Augment bounds with deriv indicator
bounds = torch.cat((model.bounds_, torch.zeros(2, 1)), dim=1)
bounds = torch.cat((model.bounds_, torch.zeros(2, 1).to(model.device)), dim=1)
# Fix deriv indicator to 0 during optimization
fixed_features = {(bounds.shape[1] - 1): 0.0}
# Fix explore features to random values
Expand Down
2 changes: 1 addition & 1 deletion aepsych/means/constant_partial_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
idx = input[..., -1].to(dtype=torch.long) > 0
mean_fit = super(ConstantMeanPartialObsGrad, self).forward(input[..., ~idx, :])
sz = mean_fit.shape[:-1] + torch.Size([input.shape[-2]])
mean = torch.zeros(sz)
mean = torch.zeros(sz).to(input)
mean[~idx] = mean_fit
return mean
7 changes: 6 additions & 1 deletion aepsych/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,12 @@ def set_train_data(self, inputs=None, targets=None, strict=False):
def device(self) -> torch.device:
# We assume all models have some parameters and all models will only use one device
# notice that this has no setting, don't let users set device, use .to().
return next(self.parameters()).device
try:
return next(self.parameters()).device
except (
AttributeError
): # Fallback for cases where we need device before we have params
return torch.device("cpu")

@property
def train_inputs(self) -> Optional[Tuple[torch.Tensor]]:
Expand Down
4 changes: 3 additions & 1 deletion aepsych/models/derivative_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torch
from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad
from aepsych.means.constant_partial_grad import ConstantMeanPartialObsGrad
from aepsych.models.base import AEPsychModelDeviceMixin
from botorch.models.gpytorch import GPyTorchModel
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import Kernel
Expand All @@ -22,7 +23,7 @@
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy


class MixedDerivativeVariationalGP(gpytorch.models.ApproximateGP, GPyTorchModel):
class MixedDerivativeVariationalGP(gpytorch.models.ApproximateGP, AEPsychModelDeviceMixin, GPyTorchModel):
"""A variational GP with mixed derivative observations.
For more on GPs with derivative observations, see e.g. Riihimaki & Vehtari 2010.
Expand Down Expand Up @@ -99,6 +100,7 @@ def __init__(
self._num_outputs = 1
self.train_inputs = (train_x,)
self.train_targets = train_y
self.to(self.device)
self(train_x) # Necessary for CholeskyVariationalDistribution

def forward(self, x: torch.Tensor) -> MultivariateNormal:
Expand Down
16 changes: 10 additions & 6 deletions aepsych/models/monotonic_projection_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,17 @@ def posterior(
# using numpy because torch doesn't support vectorized linspace,
# pytorch/issues/61292
grid: Union[np.ndarray, torch.Tensor] = np.linspace(
self.lb[dim],
X[:, dim].numpy(),
self.lb[dim].cpu().numpy(),
X[:, dim].cpu().numpy(),
s + 1,
) # (s+1 x n)
grid = torch.tensor(grid[:-1, :], dtype=X.dtype) # Drop x; (s x n)
X_aug[(1 + i * s) : (1 + (i + 1) * s), :, dim] = grid
# X_aug[0, :, :] is X, and then subsequent indices are points in the grids
# Predict marginal distributions on X_aug

X = X.to(self.device)
X_aug = X_aug.to(self.device)
with torch.no_grad():
post_aug = super().posterior(X=X_aug)
mu_aug = post_aug.mean.squeeze() # (m*s+1 x n)
Expand All @@ -158,12 +161,13 @@ def posterior(
# Adjust the whole covariance matrix to accomadate the projected marginals
with torch.no_grad():
post = super().posterior(X=X)
R = cov2corr(post.distribution.covariance_matrix.squeeze().numpy())
S_proj = torch.tensor(corr2cov(R, sigma_proj.numpy()), dtype=X.dtype)
R = cov2corr(post.distribution.covariance_matrix.squeeze().cpu().numpy())
S_proj = torch.tensor(corr2cov(R, sigma_proj.cpu().numpy()), dtype=X.dtype)
mvn_proj = gpytorch.distributions.MultivariateNormal(
mu_proj.unsqueeze(0),
S_proj.unsqueeze(0),
mu_proj.unsqueeze(0).to(self.device),
S_proj.unsqueeze(0).to(self.device),
)

return GPyTorchPosterior(mvn_proj)

def sample(self, x: torch.Tensor, num_samples: int) -> torch.Tensor:
Expand Down
25 changes: 14 additions & 11 deletions aepsych/models/monotonic_rejection_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from aepsych.factory.monotonic import monotonic_mean_covar_factory
from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad
from aepsych.means.constant_partial_grad import ConstantMeanPartialObsGrad
from aepsych.models.base import AEPsychMixin
from aepsych.models.base import AEPsychModelDeviceMixin
from aepsych.models.utils import select_inducing_points
from aepsych.utils import _process_bounds, promote_0d
from botorch.fit import fit_gpytorch_mll
Expand All @@ -32,7 +32,7 @@
from torch import Tensor


class MonotonicRejectionGP(AEPsychMixin, ApproximateGP):
class MonotonicRejectionGP(AEPsychModelDeviceMixin, ApproximateGP):
"""A monotonic GP using rejection sampling.
This takes the same insight as in e.g. Riihimäki & Vehtari 2010 (that the derivative of a GP
Expand Down Expand Up @@ -83,15 +83,15 @@ def __init__(
objective (Optional[MCAcquisitionObjective], optional): Transformation of GP to apply before computing acquisition function. Defaults to identity transform for gaussian likelihood, probit transform for probit-bernoulli.
extra_acqf_args (Optional[Dict[str, object]], optional): Additional arguments to pass into the acquisition function. Defaults to None.
"""
self.lb, self.ub, self.dim = _process_bounds(lb, ub, dim)
lb, ub, self.dim = _process_bounds(lb, ub, dim)
if likelihood is None:
likelihood = BernoulliLikelihood()

self.inducing_size = num_induc
self.inducing_point_method = inducing_point_method
inducing_points = select_inducing_points(
inducing_size=self.inducing_size,
bounds=self.bounds,
bounds=torch.stack((lb, ub)),
method="sobol",
)

Expand Down Expand Up @@ -134,7 +134,9 @@ def __init__(

super().__init__(variational_strategy)

self.bounds_ = torch.stack([self.lb, self.ub])
self.register_buffer("lb", lb)
self.register_buffer("ub", ub)
self.register_buffer("bounds_", torch.stack([self.lb, self.ub]))
self.mean_module = mean_module
self.covar_module = covar_module
self.likelihood = likelihood
Expand All @@ -144,7 +146,7 @@ def __init__(
self.num_samples = num_samples
self.num_rejection_samples = num_rejection_samples
self.fixed_prior_mean = fixed_prior_mean
self.inducing_points = inducing_points
self.register_buffer("inducing_points", inducing_points)

def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None:
"""Fit the model
Expand All @@ -161,7 +163,7 @@ def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None:
X=self.train_inputs[0],
bounds=self.bounds,
method=self.inducing_point_method,
)
).to(self.device)
self._set_model(train_x, train_y)

def _set_model(
Expand Down Expand Up @@ -284,13 +286,14 @@ def predict_probability(
return self.predict(x, probability_space=True)

def _augment_with_deriv_index(self, x: Tensor, indx) -> Tensor:
x = x.to(self.device)
return torch.cat(
(x, indx * torch.ones(x.shape[0], 1)),
(x, indx * torch.ones(x.shape[0], 1).to(self.device)),
dim=1,
)

def _get_deriv_constraint_points(self) -> Tensor:
deriv_cp = torch.tensor([])
deriv_cp = torch.tensor([]).to(self.device)
for i in self.monotonic_idxs:
induc_i = self._augment_with_deriv_index(self.inducing_points, i + 1)
deriv_cp = torch.cat((deriv_cp, induc_i), dim=0)
Expand All @@ -299,8 +302,8 @@ def _get_deriv_constraint_points(self) -> Tensor:
@classmethod
def from_config(cls, config: Config) -> MonotonicRejectionGP:
classname = cls.__name__
num_induc = config.gettensor(classname, "num_induc", fallback=25)
num_samples = config.gettensor(classname, "num_samples", fallback=250)
num_induc = config.getint(classname, "num_induc", fallback=25)
num_samples = config.getint(classname, "num_samples", fallback=250)
num_rejection_samples = config.getint(
classname, "num_rejection_samples", fallback=5000
)
Expand Down
17 changes: 3 additions & 14 deletions aepsych/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,6 @@ class Strategy(object):

_n_eval_points: int = 1000

no_gpu_acqfs = (
MonotonicMCAcquisition,
MonotonicBernoulliMCMutualInformation,
MonotonicMCPosteriorVariance,
MonotonicMCLSE,
)

def __init__(
self,
generator: Union[AEPsychGenerator, ParameterTransformedGenerator],
Expand Down Expand Up @@ -182,13 +175,7 @@ def __init__(
)
self.generator_device = torch.device("cpu")
else:
if hasattr(generator, "acqf") and generator.acqf in self.no_gpu_acqfs:
warnings.warn(
f"GPU requested for acquistion function {type(generator.acqf).__name__}, but this acquisiton function does not support GPU! Using CPU instead.",
UserWarning,
)
self.generator_device = torch.device("cpu")
elif not torch.cuda.is_available():
if not torch.cuda.is_available():
warnings.warn(
f"GPU requested for generator {type(generator).__name__}, but no GPU found! Using CPU instead.",
UserWarning,
Expand Down Expand Up @@ -283,9 +270,11 @@ def normalize_inputs(
x = x[None, :]

if self.x is not None:
x = x.to(self.x)
x = torch.cat((self.x, x), dim=0)

if self.y is not None:
y = y.to(self.y)
y = torch.cat((self.y, y), dim=0)

# Ensure the correct dtype
Expand Down
6 changes: 6 additions & 0 deletions tests_gpu/acquisition/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) Meta, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
51 changes: 51 additions & 0 deletions tests_gpu/acquisition/test_monotonic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from aepsych.acquisition.monotonic_rejection import MonotonicMCLSE
from aepsych.acquisition.objective import ProbitObjective
from aepsych.models.derivative_gp import MixedDerivativeVariationalGP
from botorch.acquisition.objective import IdentityMCObjective
from botorch.utils.testing import BotorchTestCase


class TestMonotonicAcq(BotorchTestCase):
def test_monotonic_acq_gpu(self):
# Init
train_X_aug = torch.tensor(
[[0.0, 0.0, 0.0], [1.0, 1.0, 0.0], [2.0, 2.0, 0.0]]
).cuda()
deriv_constraint_points = torch.tensor(
[[0.0, 0.0, 1.0], [1.0, 1.0, 1.0], [2.0, 2.0, 1.0]]
).cuda()
train_Y = torch.tensor([[1.0], [2.0], [3.0]]).cuda()

m = MixedDerivativeVariationalGP(
train_x=train_X_aug, train_y=train_Y, inducing_points=train_X_aug
).cuda()
acq = MonotonicMCLSE(
model=m,
deriv_constraint_points=deriv_constraint_points,
num_samples=5,
num_rejection_samples=8,
target=1.9,
)
self.assertTrue(isinstance(acq.objective, IdentityMCObjective))
acq = MonotonicMCLSE(
model=m,
deriv_constraint_points=deriv_constraint_points,
num_samples=5,
num_rejection_samples=8,
target=1.9,
objective=ProbitObjective(),
).cuda()
# forward
acq(train_X_aug)
Xfull = torch.cat((train_X_aug, acq.deriv_constraint_points), dim=0)
posterior = m.posterior(Xfull)
samples = acq.sampler(posterior)
self.assertEqual(samples.shape, torch.Size([5, 6, 1]))
Loading

0 comments on commit 26ec708

Please sign in to comment.