-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add additional test functions and psychophysics task and dataset from…
… Letham et al. 2022 Summary: Additional high-dimensional test functions and real psychophysics task are added to problem.py for benchmarking performance of acquistions functions or GP models. The code and dataset are obtained from https://github.com/facebookresearch/bernoulli_lse/blob/main/problems.py. Reviewed By: crasanders Differential Revision: D58164941
- Loading branch information
1 parent
722123e
commit 3f48106
Showing
4 changed files
with
1,155 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. | ||
import os | ||
import numpy as np | ||
import torch | ||
from aepsych.models import GPClassificationModel | ||
from aepsych.benchmark.test_functions import ( | ||
modified_hartmann6, | ||
discrim_highdim, | ||
novel_discrimination_testfun, | ||
) | ||
from aepsych.benchmark.problem import LSEProblemWithEdgeLogging | ||
|
||
|
||
"""The DiscrimLowDim, DiscrimHighDim, ContrastSensitivity6d, and Hartmann6Binary classes | ||
are copied from bernoulli_lse github repository (https://github.com/facebookresearch/bernoulli_lse) | ||
by Letham et al. 2022.""" | ||
|
||
|
||
class DiscrimLowDim(LSEProblemWithEdgeLogging): | ||
name = "discrim_lowdim" | ||
bounds = torch.tensor([[-1, 1], [-1, 1]], dtype=torch.double).T | ||
threshold = 0.75 | ||
|
||
def f(self, x: torch.Tensor) -> torch.Tensor: | ||
return torch.tensor(novel_discrimination_testfun(x), dtype=torch.double) | ||
|
||
|
||
class DiscrimHighDim(LSEProblemWithEdgeLogging): | ||
name = "discrim_highdim" | ||
threshold = 0.75 | ||
bounds = torch.tensor( | ||
[ | ||
[-1, 1], | ||
[-1, 1], | ||
[0.5, 1.5], | ||
[0.05, 0.15], | ||
[0.05, 0.2], | ||
[0, 0.9], | ||
[0, 3.14 / 2], | ||
[0.5, 2], | ||
], | ||
dtype=torch.double, | ||
).T | ||
|
||
def f(self, x: torch.Tensor) -> torch.Tensor: | ||
return torch.tensor(discrim_highdim(x), dtype=torch.double) | ||
|
||
|
||
class Hartmann6Binary(LSEProblemWithEdgeLogging): | ||
name = "hartmann6_binary" | ||
threshold = 0.5 | ||
bounds = torch.stack( | ||
( | ||
torch.zeros(6, dtype=torch.double), | ||
torch.ones(6, dtype=torch.double), | ||
) | ||
) | ||
|
||
def f(self, X: torch.Tensor) -> torch.Tensor: | ||
y = torch.tensor([modified_hartmann6(x) for x in X], dtype=torch.double) | ||
f = 3 * y - 2.0 | ||
return f | ||
|
||
|
||
class ContrastSensitivity6d(LSEProblemWithEdgeLogging): | ||
""" | ||
Uses a surrogate model fit to real data from a constrast sensitivity study. | ||
""" | ||
|
||
name = "contrast_sensitivity_6d" | ||
threshold = 0.75 | ||
bounds = torch.tensor( | ||
[[-1.5, 0], [-1.5, 0], [0, 20], [0.5, 7], [1, 10], [0, 10]], | ||
dtype=torch.double, | ||
).T | ||
|
||
def __init__(self): | ||
|
||
# Load the data | ||
self.data = np.loadtxt( | ||
os.path.join("..", "..", "dataset", "csf_dataset.csv"), | ||
delimiter=",", | ||
skiprows=1, | ||
) | ||
y = torch.LongTensor(self.data[:, 0]) | ||
x = torch.Tensor(self.data[:, 1:]) | ||
|
||
# Fit a model, with a large number of inducing points | ||
self.m = GPClassificationModel( | ||
lb=self.bounds[0], | ||
ub=self.bounds[1], | ||
inducing_size=100, | ||
inducing_point_method="kmeans++", | ||
) | ||
|
||
self.m.fit( | ||
x, | ||
y, | ||
) | ||
|
||
def f(self, X: torch.Tensor) -> torch.Tensor: | ||
# clamp f to 0 since we expect p(x) to be lower-bounded at 0.5 | ||
return torch.clamp(self.m.predict(torch.tensor(X))[0], min=0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.