Skip to content

Commit

Permalink
implement pairwisekernel (facebookresearch#371)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#371

This implements Houlsby et al. (2011)'s Pairwise kernel, which can turn any other model into a pairwise one.

Differential Revision: D59697559
  • Loading branch information
crasanders authored and facebook-github-bot committed Aug 16, 2024
1 parent 683ad8a commit e994d05
Show file tree
Hide file tree
Showing 18 changed files with 641 additions and 149 deletions.
12 changes: 10 additions & 2 deletions aepsych/factory/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from aepsych.config import Config
from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad
from aepsych.means.constant_partial_grad import ConstantMeanPartialObsGrad
from aepsych.utils import get_dim
from scipy.stats import norm

from .pairwisekernel import PairwiseKernel

"""AEPsych factory functions.
These functions generate a gpytorch Mean and Kernel objects from
aepsych.config.Config configurations, including setting lengthscale
Expand All @@ -36,7 +37,9 @@


def default_mean_covar_factory(
config: Optional[Config] = None, dim: Optional[int] = None
config: Optional[Config] = None,
dim: Optional[int] = None,
stimuli_per_trial: int = 1,
) -> Tuple[gpytorch.means.ConstantMean, gpytorch.kernels.ScaleKernel]:
"""Default factory for generic GP models
Expand All @@ -55,6 +58,8 @@ def default_mean_covar_factory(
dim is not None
), "Either config or dim must be provided!"

assert stimuli_per_trial in (1, 2), "stimuli_per_trial must be 1 or 2!"

fixed_mean = False
lengthscale_prior = "gamma"
outputscale_prior = "box"
Expand Down Expand Up @@ -136,6 +141,9 @@ def default_mean_covar_factory(
outputscale_prior=os_prior,
)

if stimuli_per_trial == 2:
covar = PairwiseKernel(covar)

return mean, covar


Expand Down
85 changes: 85 additions & 0 deletions aepsych/factory/pairwisekernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import torch
from gpytorch.kernels import Kernel
from gpytorch.lazy import lazify


class PairwiseKernel(Kernel):
"""
Wrapper to convert a kernel K on R^k to a kernel K' on R^{2k}, modeling
functions of the form g(a, b) = f(a) - f(b), where f ~ GP(mu, K).
Since g is a linear combination of Gaussians, it follows that g ~ GP(0, K')
where K'((a,b), (c,d)) = K(a,c) - K(a, d) - K(b, c) + K(b, d).
"""

def __init__(self, latent_kernel, is_partial_obs=False, **kwargs):
super(PairwiseKernel, self).__init__(**kwargs)

self.latent_kernel = latent_kernel
self.is_partial_obs = is_partial_obs

def forward(self, x1, x2, diag=False, **params):
r"""
TODO: make last_batch_dim work properly
d must be 2*k for integer k, k is the dimension of the latent space
Args:
:attr:`x1` (Tensor `n x d` or `b x n x d`):
First set of data
:attr:`x2` (Tensor `m x d` or `b x m x d`):
Second set of data
:attr:`diag` (bool):
Should the Kernel compute the whole kernel, or just the diag?
Returns:
:class:`Tensor` or :class:`gpytorch.lazy.LazyTensor`.
The exact size depends on the kernel's evaluation mode:
* `full_covar`: `n x m` or `b x n x m`
* `diag`: `n` or `b x n`
"""
if self.is_partial_obs:
d = x1.shape[-1] - 1
assert d == x2.shape[-1] - 1, "tensors not the same dimension"
assert d % 2 == 0, "dimension must be even"

k = int(d / 2)

# special handling for kernels that (also) do funky
# things with the input dimension
deriv_idx_1 = x1[..., -1][:, None]
deriv_idx_2 = x2[..., -1][:, None]

a = torch.cat((x1[..., :k], deriv_idx_1), dim=1)
b = torch.cat((x1[..., k:-1], deriv_idx_1), dim=1)
c = torch.cat((x2[..., :k], deriv_idx_2), dim=1)
d = torch.cat((x2[..., k:-1], deriv_idx_2), dim=1)

else:
d = x1.shape[-1]

assert d == x2.shape[-1], "tensors not the same dimension"
assert d % 2 == 0, "dimension must be even"

k = int(d / 2)

a = x1[..., :k]
b = x1[..., k:]
c = x2[..., :k]
d = x2[..., k:]

if not diag:
return (
lazify(self.latent_kernel(a, c, diag=diag, **params))
+ lazify(self.latent_kernel(b, d, diag=diag, **params))
- lazify(self.latent_kernel(b, c, diag=diag, **params))
- lazify(self.latent_kernel(a, d, diag=diag, **params))
)
else:
return (
self.latent_kernel(a, c, diag=diag, **params)
+ self.latent_kernel(b, d, diag=diag, **params)
- self.latent_kernel(b, c, diag=diag, **params)
- self.latent_kernel(a, d, diag=diag, **params)
)
4 changes: 0 additions & 4 deletions aepsych/generators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
from .monotonic_rejection_generator import MonotonicRejectionGenerator
from .monotonic_thompson_sampler_generator import MonotonicThompsonSamplerGenerator
from .optimize_acqf_generator import OptimizeAcqfGenerator
from .pairwise_optimize_acqf_generator import PairwiseOptimizeAcqfGenerator
from .pairwise_sobol_generator import PairwiseSobolGenerator
from .random_generator import RandomGenerator
from .semi_p import IntensityAwareSemiPGenerator
from .sobol_generator import SobolGenerator
Expand All @@ -28,8 +26,6 @@
"SobolGenerator",
"EpsilonGreedyGenerator",
"ManualGenerator",
"PairwiseOptimizeAcqfGenerator",
"PairwiseSobolGenerator",
"IntensityAwareSemiPGenerator",
"AcqfThompsonSamplerGenerator"
]
Expand Down
7 changes: 3 additions & 4 deletions aepsych/generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@
# LICENSE file in the root directory of this source tree.
import abc
from inspect import signature
from typing import Any, Dict, Generic, Protocol, runtime_checkable, TypeVar, Optional
from typing import Any, Dict, Generic, Optional, Protocol, runtime_checkable, TypeVar
import re

import torch
from aepsych.config import Config
from aepsych.models.base import AEPsychMixin
from botorch.acquisition import (
AcquisitionFunction,
NoisyExpectedImprovement,
qNoisyExpectedImprovement,
LogNoisyExpectedImprovement,
NoisyExpectedImprovement,
qLogNoisyExpectedImprovement,
qNoisyExpectedImprovement,
)


Expand All @@ -40,7 +40,6 @@ class AEPsychGenerator(abc.ABC, Generic[AEPsychModelType]):
qLogNoisyExpectedImprovement,
LogNoisyExpectedImprovement,
]
stimuli_per_trial = 1
max_asks: Optional[int] = None

def __init__(
Expand Down
24 changes: 5 additions & 19 deletions aepsych/generators/optimize_acqf_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
from aepsych.generators.base import AEPsychGenerator
from aepsych.models.base import ModelProtocol
from aepsych.utils_logging import getLogger
from botorch.acquisition.preference import AnalyticExpectedUtilityOfBestOption
from botorch.optim import optimize_acqf
from botorch.acquisition import (
AcquisitionFunction,
NoisyExpectedImprovement,
qNoisyExpectedImprovement,
LogNoisyExpectedImprovement,
NoisyExpectedImprovement,
qLogNoisyExpectedImprovement,
qNoisyExpectedImprovement,
)
from botorch.acquisition.preference import AnalyticExpectedUtilityOfBestOption
from botorch.optim import optimize_acqf

logger = getLogger()

Expand All @@ -44,7 +44,6 @@ def __init__(
restarts: int = 10,
samps: int = 1000,
max_gen_time: Optional[float] = None,
stimuli_per_trial: int = 1,
) -> None:
"""Initialize OptimizeAcqfGenerator.
Args:
Expand All @@ -63,7 +62,6 @@ def __init__(
self.restarts = restarts
self.samps = samps
self.max_gen_time = max_gen_time
self.stimuli_per_trial = stimuli_per_trial

def _instantiate_acquisition_fn(self, model: ModelProtocol):
if self.acqf == AnalyticExpectedUtilityOfBestOption:
Expand All @@ -83,17 +81,7 @@ def gen(self, num_points: int, model: ModelProtocol, **gen_options) -> torch.Ten
np.ndarray: Next set of point(s) to evaluate, [num_points x dim].
"""

if self.stimuli_per_trial == 2:
qbatch_points = self._gen(
num_points=num_points * 2, model=model, **gen_options
)

# output of super() is (q, dim) but the contract is (num_points, dim, 2)
# so we need to split q into q and pairs and then move the pair dim to the end
return qbatch_points.reshape(num_points, 2, -1).swapaxes(-1, -2)

else:
return self._gen(num_points=num_points, model=model, **gen_options)
return self._gen(num_points=num_points, model=model, **gen_options)

def _gen(
self, num_points: int, model: ModelProtocol, **gen_options
Expand Down Expand Up @@ -124,7 +112,6 @@ def from_config(cls, config: Config):
classname = cls.__name__
acqf = config.getobj(classname, "acqf", fallback=None)
extra_acqf_args = cls._get_acqf_options(acqf, config)
stimuli_per_trial = config.getint(classname, "stimuli_per_trial")
restarts = config.getint(classname, "restarts", fallback=10)
samps = config.getint(classname, "samps", fallback=1000)
max_gen_time = config.getfloat(classname, "max_gen_time", fallback=None)
Expand All @@ -135,5 +122,4 @@ def from_config(cls, config: Config):
restarts=restarts,
samps=samps,
max_gen_time=max_gen_time,
stimuli_per_trial=stimuli_per_trial,
)
25 changes: 0 additions & 25 deletions aepsych/generators/pairwise_optimize_acqf_generator.py

This file was deleted.

26 changes: 0 additions & 26 deletions aepsych/generators/pairwise_sobol_generator.py

This file was deleted.

25 changes: 6 additions & 19 deletions aepsych/generators/sobol_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def __init__(
ub: Union[np.ndarray, torch.Tensor],
dim: Optional[int] = None,
seed: Optional[int] = None,
stimuli_per_trial: int = 1,
):
"""Iniatialize SobolGenerator.
Args:
Expand All @@ -38,13 +37,8 @@ def __init__(
seed (int, optional): Random seed.
"""
self.lb, self.ub, self.dim = _process_bounds(lb, ub, dim)
self.lb = self.lb.repeat(stimuli_per_trial)
self.ub = self.ub.repeat(stimuli_per_trial)
self.stimuli_per_trial = stimuli_per_trial
self.seed = seed
self.engine = SobolEngine(
dimension=self.dim * stimuli_per_trial, scramble=True, seed=self.seed
)
self.engine = SobolEngine(dimension=self.dim, scramble=True, seed=self.seed)

def gen(
self,
Expand All @@ -59,16 +53,7 @@ def gen(
"""
grid = self.engine.draw(num_points)
grid = self.lb + (self.ub - self.lb) * grid
if self.stimuli_per_trial == 1:
return grid

return torch.tensor(
np.moveaxis(
grid.reshape(num_points, self.stimuli_per_trial, -1).numpy(),
-1,
-self.stimuli_per_trial,
)
)
return grid

@classmethod
def from_config(cls, config: Config):
Expand All @@ -78,8 +63,10 @@ def from_config(cls, config: Config):
ub = config.gettensor(classname, "ub")
dim = config.getint(classname, "dim", fallback=None)
seed = config.getint(classname, "seed", fallback=None)
stimuli_per_trial = config.getint(classname, "stimuli_per_trial")

return cls(
lb=lb, ub=ub, dim=dim, seed=seed, stimuli_per_trial=stimuli_per_trial
lb=lb,
ub=ub,
dim=dim,
seed=seed,
)
Loading

0 comments on commit e994d05

Please sign in to comment.