derivativeGP gpu support (#444)
Summary:

Add GPU support for the derivative GP.

I noticed that this model isn't like a normal model that can show up in a live experiment with a config, but it should still work on GPU. I did most of that, but it required some pretty arcane overriding of GPyTorch's underlying handling of train_inputs so that the cached training data follows the model's device, which in turn forced some equally arcane mypy annotations.
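The core of the train_inputs override is a device-aware property. Here is a minimal, self-contained sketch of the idea (DeviceAwareModel is a hypothetical stand-in written for illustration, not part of AEPsych; the real change lives in aepsych/models/base.py below):

    from typing import Optional, Tuple

    import torch


    class DeviceAwareModel(torch.nn.Module):
        """Toy stand-in for the mixin: cached train_inputs follow the module's device."""

        def __init__(self, train_x: torch.Tensor) -> None:
            super().__init__()
            # any parameter works; its device defines the model's device
            self.dummy = torch.nn.Parameter(torch.zeros(1))
            self._train_inputs: Optional[Tuple[torch.Tensor, ...]] = (train_x,)

        @property
        def device(self) -> torch.device:
            return next(self.parameters()).device

        @property
        def train_inputs(self) -> Optional[Tuple[torch.Tensor, ...]]:
            if self._train_inputs is None:
                return None
            # move each cached tensor to the current device and re-cache the tuple
            self._train_inputs = tuple(t.to(self.device) for t in self._train_inputs)
            return self._train_inputs


    model = DeviceAwareModel(torch.rand(4, 2))
    if torch.cuda.is_available():
        model = model.cuda()
    print(model.train_inputs[0].device)  # matches model.device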

Differential Revision: D65515631
JasonKChow authored and facebook-github-bot committed Nov 9, 2024
1 parent 3b6bae0 commit 3fe7a5b
Showing 8 changed files with 60 additions and 15 deletions.
2 changes: 1 addition & 1 deletion aepsych/config.py
@@ -16,13 +16,13 @@
Callable,
ClassVar,
Dict,
Dict,
List,
Mapping,
Optional,
Sequence,
TypeVar,
)

import botorch
import gpytorch
import numpy as np
8 changes: 4 additions & 4 deletions aepsych/likelihoods/bernoulli.py
@@ -19,7 +19,7 @@ class BernoulliObjectiveLikelihood(_OneDimensionalLikelihood):

def __init__(self, objective: Callable) -> None:
"""Initialize BernoulliObjectiveLikelihood.
Args:
objective (Callable): Objective function that maps function samples to probabilities."""
super().__init__()
@@ -42,13 +42,13 @@ def forward(
@classmethod
def from_config(cls, config: Config) -> "BernoulliObjectiveLikelihood":
"""Create an instance from a configuration object.
Args:
config (Config): Configuration object.
Returns:
BernoulliObjectiveLikelihood: BernoulliObjectiveLikelihood instance.
"""
objective_cls = config.getobj(cls.__name__, "objective")
objective = objective_cls.from_config(config)
return cls(objective=objective)
return cls(objective=objective)
6 changes: 3 additions & 3 deletions aepsych/likelihoods/semi_p.py
@@ -111,10 +111,10 @@ def expected_log_prob(
# modified, TODO fixme upstream (cc @bletham)
def log_prob_lambda(function_samples: torch.Tensor) -> torch.Tensor:
"""Lambda function to compute the log probability.
Args:
function_samples (torch.Tensor): Function samples.
Returns:
torch.Tensor: Log probability.
"""
@@ -142,4 +142,4 @@ def from_config(cls, config: Config) -> "LinearBernoulliLikelihood":
else:
objective = objective

return cls(objective=objective)
return cls(objective=objective)
2 changes: 1 addition & 1 deletion aepsych/means/constant_partial_grad.py
@@ -26,6 +26,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
idx = input[..., -1].to(dtype=torch.long) > 0
mean_fit = super(ConstantMeanPartialObsGrad, self).forward(input[..., ~idx, :])
sz = mean_fit.shape[:-1] + torch.Size([input.shape[-2]])
mean = torch.zeros(sz)
mean = torch.zeros(sz).to(input)
mean[~idx] = mean_fit
return mean
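This change relies on the tensor overload of Tensor.to: passing another tensor copies both its dtype and device, so the zero-initialized mean buffer now lands wherever the input lives. A tiny standalone illustration (made-up values, unrelated to the AEPsych code above):

    import torch

    x = torch.randn(4, 3, dtype=torch.float64)  # imagine this arrived on a GPU
    buf = torch.zeros(4).to(x)                  # buf now shares x's dtype and device
    print(buf.dtype, buf.device)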
12 changes: 8 additions & 4 deletions aepsych/models/base.py
@@ -116,7 +116,7 @@ class AEPsychMixin(GPyTorchModel):

extremum_solver = "Nelder-Mead"
outcome_types: List[str] = []
train_inputs: Optional[Tuple[torch.Tensor]]
train_inputs: Optional[Tuple[torch.Tensor, ...]]
train_targets: Optional[torch.Tensor]

@property
@@ -398,7 +398,7 @@ def p_below_threshold(


class AEPsychModelDeviceMixin(AEPsychMixin):
_train_inputs: Optional[Tuple[torch.Tensor]]
_train_inputs: Optional[Tuple[torch.Tensor, ...]]
_train_targets: Optional[torch.Tensor]

def set_train_data(self, inputs=None, targets=None, strict=False):
@@ -423,13 +423,17 @@ def device(self) -> torch.device:
return next(self.parameters()).device

@property
def train_inputs(self) -> Optional[Tuple[torch.Tensor]]:
def train_inputs(self) -> Optional[Tuple[torch.Tensor, ...]]:
if self._train_inputs is None:
return None

# makes sure the tensors are on the right device, move in place
_train_inputs = []
for input in self._train_inputs:
input.to(self.device)
_train_inputs.append(input.to(self.device))

_tuple_inputs: Tuple[torch.Tensor, ...] = tuple(_train_inputs)
self._train_inputs = _tuple_inputs

return self._train_inputs

5 changes: 4 additions & 1 deletion aepsych/models/derivative_gp.py
@@ -13,6 +13,7 @@
import torch
from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad
from aepsych.means.constant_partial_grad import ConstantMeanPartialObsGrad
from aepsych.models.base import AEPsychModelDeviceMixin
from botorch.models.gpytorch import GPyTorchModel
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import Kernel
@@ -22,7 +23,9 @@
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy


class MixedDerivativeVariationalGP(gpytorch.models.ApproximateGP, GPyTorchModel):
class MixedDerivativeVariationalGP(
gpytorch.models.ApproximateGP, AEPsychModelDeviceMixin, GPyTorchModel
):
"""A variational GP with mixed derivative observations.
For more on GPs with derivative observations, see e.g. Riihimaki & Vehtari 2010.
1 change: 0 additions & 1 deletion aepsych/plotting.py
@@ -10,7 +10,6 @@

import matplotlib.pyplot as plt
import numpy as np

import torch
from aepsych.strategy import Strategy
from aepsych.utils import get_lse_contour, get_lse_interval, make_scaled_sobol
39 changes: 39 additions & 0 deletions tests_gpu/models/test_derivative_gp.py
@@ -0,0 +1,39 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from aepsych import Config, SequentialStrategy
from aepsych.models.derivative_gp import MixedDerivativeVariationalGP
from botorch.fit import fit_gpytorch_mll
from botorch.utils.testing import BotorchTestCase
from gpytorch.likelihoods import BernoulliLikelihood
from gpytorch.mlls.variational_elbo import VariationalELBO


class TestDerivativeGP(BotorchTestCase):
def test_MixedDerivativeVariationalGP_gpu(self):
train_x = torch.cat(
(torch.tensor([1.0, 2.0, 3.0, 4.0]).unsqueeze(1), torch.zeros(4, 1)), dim=1
)
train_y = torch.tensor([1.0, 2.0, 3.0, 4.0])
m = MixedDerivativeVariationalGP(
train_x=train_x,
train_y=train_y,
inducing_points=train_x,
fixed_prior_mean=0.5,
).cuda()

self.assertEqual(m.mean_module.constant.item(), 0.5)
self.assertEqual(
m.covar_module.base_kernel.raw_lengthscale.shape, torch.Size([1, 1])
)
mll = VariationalELBO(
likelihood=BernoulliLikelihood(), model=m, num_data=train_y.numel()
).cuda()
mll = fit_gpytorch_mll(mll)
test_x = torch.tensor([[1.0, 0], [3.0, 1.0]]).cuda()
m(test_x)
