make manual generators require tensors and support higher-dimensional tensors (facebookresearch#455)

Summary:


This removes the numpy dependency from the manual generators and slightly changes the logic so that higher-dimensional tensors of points are supported, which in turn enables pairwise experiments.
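
As a rough illustration (not part of the commit itself), the shapes involved look like this; the (n_points, dim, stimuli_per_trial) layout and the config-side ordering are inferred from the diff and tests below:

import torch

# Single-stimulus trials: one row per point, one column per parameter.
points_2d = torch.tensor([[-1.5, 1.0], [-1.0, 1.25]])  # (n_points, dim)

# Pairwise/n-wise trials, as naturally written in a config:
# (n_points, stimuli_per_trial, dim)
config_points = torch.tensor(
    [
        [[-1.5, 1.0], [-1.0, 1.25], [-2.0, 1.75]],
        [[-1.25, 1.25], [-1.75, 1.5], [-1.0, 2.0]],
    ]
)  # shape (2, 3, 2)

# What the generators work with internally: (n_points, dim, stimuli_per_trial)
generator_points = config_points.swapaxes(-1, -2)  # shape (2, 2, 3)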

Reviewed By: crasanders

Differential Revision: D64607853
JasonKChow authored and facebook-github-bot committed Nov 20, 2024
1 parent 9615419 commit 3e90e45
Showing 6 changed files with 150 additions and 22 deletions.
57 changes: 38 additions & 19 deletions aepsych/generators/manual_generator.py
@@ -6,9 +6,8 @@
# LICENSE file in the root directory of this source tree.

import warnings
from typing import Dict, Optional, Union
from typing import Dict, Optional

import numpy as np
import torch
from aepsych.config import Config
from aepsych.generators.base import AEPsychGenerator
@@ -26,7 +25,7 @@ def __init__(
self,
lb: torch.Tensor,
ub: torch.Tensor,
points: Union[np.ndarray, torch.Tensor],
points: torch.Tensor,
dim: Optional[int] = None,
shuffle: bool = True,
seed: Optional[int] = None,
@@ -35,16 +34,18 @@
Args:
lb torch.Tensor: Lower bounds of each parameter.
ub torch.Tensor: Upper bounds of each parameter.
points (Union[np.ndarray, torch.Tensor]): The points that will be generated.
points torch.Tensor: The points that will be generated.
dim (int, optional): Dimensionality of the parameter space. If None, it is inferred from lb and ub.
shuffle (bool): Whether or not to shuffle the order of the points. True by default.
"""
self.seed = seed
self.lb, self.ub, self.dim = _process_bounds(lb, ub, dim)
self.points = points
if shuffle:
np.random.seed(self.seed)
np.random.shuffle(points)
self.points = torch.tensor(points)
if seed is not None:
torch.manual_seed(seed)
self.points = points[torch.randperm(len(points))]

self.max_asks = len(self.points)
self._idx = 0
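
For reference, a minimal sketch (outside the diff) of the torch-only shuffle used above; it keeps the optional seeding and permutes only the leading points axis, so it also works for higher-dimensional point tensors:

import torch

points = torch.rand(5, 2, 3)  # (n_points, dim, stimuli_per_trial)
seed = 123

if seed is not None:
    torch.manual_seed(seed)
shuffled = points[torch.randperm(len(points))]  # shuffles points, keeps each point intact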

@@ -82,10 +83,14 @@ def get_config_options(cls, config: Config, name: Optional[str] = None) -> Dict:
lb = config.gettensor(name, "lb")
ub = config.gettensor(name, "ub")
dim = config.getint(name, "dim", fallback=None)
points = config.getarray(name, "points")
points = config.gettensor(name, "points")
shuffle = config.getboolean(name, "shuffle", fallback=True)
seed = config.getint(name, "seed", fallback=None)

if len(points.shape) == 3:
# Configs have a reasonable natural input method that produces incorrect tensors
points = points.swapaxes(-1, -2)

options = {
"lb": lb,
"ub": ub,
@@ -107,10 +112,10 @@ class SampleAroundPointsGenerator(ManualGenerator):

def __init__(
self,
lb: Union[np.ndarray, torch.Tensor],
ub: Union[np.ndarray, torch.Tensor],
window: Union[np.ndarray, torch.Tensor],
points: Union[np.ndarray, torch.Tensor],
lb: torch.Tensor,
ub: torch.Tensor,
window: torch.Tensor,
points: torch.Tensor,
samples_per_point: int,
dim: Optional[int] = None,
shuffle: bool = True,
@@ -120,24 +125,38 @@ def __init__(
Args:
lb (Union[np.ndarray, torch.Tensor]): Lower bounds of each parameter.
ub (Union[np.ndarray, torch.Tensor]): Upper bounds of each parameter.
window (Union[np.ndarray, torch.Tensor]): How far away to sample from the reference point along each dimension.
window (Union[np.ndarray, torch.Tensor]): How far away to sample from the
reference point along each dimension. If the parameters are transformed,
the proportion of the range (based on ub/lb given) covered by the window
will be preserved (and not the absolute distance from the reference points).
points (Union[np.ndarray, torch.Tensor]): The points that will be generated.
samples_per_point (int): How many samples around each point to take.
dim (int, optional): Dimensionality of the parameter space. If None, it is inferred from lb and ub.
shuffle (bool): Whether or not to shuffle the order of the points. True by default.
seed (int, optional): Random seed.
"""
lb, ub, dim = _process_bounds(lb, ub, dim)
points = torch.Tensor(points)
self.engine = SobolEngine(dimension=dim, scramble=True, seed=seed)
generated = []
gen_points = []
if len(points.shape) > 2:
# We need to determine how many stimuli there are per trial to maintain the proper tensor shape
n_draws = points.shape[-1]
else:
n_draws = 1
for point in points:
if len(points.shape) > 2:
point = point.T
p_lb = torch.max(point - window, lb)
p_ub = torch.min(point + window, ub)
grid = self.engine.draw(samples_per_point)
grid = p_lb + (p_ub - p_lb) * grid
generated.append(grid)
generated = torch.Tensor(np.vstack(generated)) # type: ignore
for _ in range(samples_per_point):
grid = self.engine.draw(n_draws)
grid = p_lb + (p_ub - p_lb) * grid
gen_points.append(grid)
if len(points.shape) > 2:
generated = torch.stack(gen_points)
generated = generated.swapaxes(-2, -1)
else:
generated = torch.vstack(gen_points)

super().__init__(lb, ub, generated, dim, shuffle, seed) # type: ignore
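
A self-contained sketch (not from the commit) of the per-point sampling logic for the plain 2D case: the window is clipped to the parameter bounds and quasi-random Sobol draws are rescaled into that box:

import torch
from torch.quasirandom import SobolEngine

lb = torch.tensor([-2.0, 1.0])
ub = torch.tensor([-1.0, 2.0])
window = torch.tensor([0.25, 0.1])
point = torch.tensor([-1.5, 1.0])
samples_per_point = 2

engine = SobolEngine(dimension=2, scramble=True, seed=123)
p_lb = torch.max(point - window, lb)  # window clipped at the lower bound
p_ub = torch.min(point + window, ub)  # window clipped at the upper bound
samples = p_lb + (p_ub - p_lb) * engine.draw(samples_per_point)  # (2, 2), inside [p_lb, p_ub]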

2 changes: 1 addition & 1 deletion aepsych/strategy.py
@@ -238,7 +238,7 @@ def __init__(
if self.stimuli_per_trial == 1:
self.event_shape: Tuple[int, ...] = (self.dim,)

if self.stimuli_per_trial == 2:
if self.stimuli_per_trial > 1:
self.event_shape = (self.dim, self.stimuli_per_trial)

self.model = model
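
To make the effect concrete (an inference from the tests, not something stated in the diff): any multi-stimulus trial now gets a trailing stimulus axis, not just pairwise ones:

import torch

dim, stimuli_per_trial = 2, 3
event_shape = (dim, stimuli_per_trial)  # previously only produced when stimuli_per_trial == 2

trial = torch.rand((1,) + event_shape)  # one generated trial: (1, dim, stimuli_per_trial)
par1_values = trial[0, 0, :]            # the first parameter's value for each of the 3 stimuli
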
18 changes: 18 additions & 0 deletions aepsych/transforms/parameters.py
@@ -807,6 +807,24 @@ def transform_options(

if option in ["ub", "lb"]:
value = transforms.transform_bounds(value, bound=option)
elif option == "points" and len(value.shape) == 3:
value = value.swapaxes(-2, -1)
value = transforms.transform(value)
value = value.swapaxes(-1, -2)
elif option == "window":
# Get proportion of range covered in raw space
raw_lb = config.gettensor("common", "lb")
raw_ub = config.gettensor("common", "ub")
raw_range = raw_ub - raw_lb
window_prop = (value * 2) / raw_range

# Calculate transformed range
transformed_lb = transforms.transform_bounds(raw_lb, bound="lb")
transformed_ub = transforms.transform_bounds(raw_ub, bound="ub")
transformed_range = transformed_ub - transformed_lb

# Transformed window covers the same proportion of range
value = window_prop * transformed_range / 2
else:
value = transforms.transform(value)
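
A worked sketch of the window rescaling above: the window is re-expressed as the fraction of each parameter's raw range that it covers, and that fraction is reapplied to the transformed range. The [0, 1] normalization here is a stand-in for whatever transform is actually configured:

import torch

raw_lb = torch.tensor([-3.0, 1.0])
raw_ub = torch.tensor([-1.0, 3.0])
window = torch.tensor([0.25, 0.1])

window_prop = (window * 2) / (raw_ub - raw_lb)  # fraction of each raw range covered

transformed_lb = torch.zeros(2)  # hypothetical normalized bounds
transformed_ub = torch.ones(2)
transformed_range = transformed_ub - transformed_lb

transformed_window = window_prop * transformed_range / 2  # tensor([0.1250, 0.0500])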

43 changes: 43 additions & 0 deletions tests/generators/test_manual_generator.py
@@ -102,6 +102,49 @@ def test_sample_around_points_generator(self):

self.assertTrue(gen.finished)

def test_sample_around_points_generator_high_dim(self):
points = [
[[-1.5, 1], [-1, 1.25], [-2, 1.75]],
[[-1.25, 1.25], [-1.75, 1.5], [-1.0, 2]],
]
window = [0.25, 0.1]
samples_per_point = 2
lb = [-2, 1]
ub = [-1, 2]
config_str = f"""
[common]
lb = {lb}
ub = {ub}
parnames = [par1, par2]
stimuli_per_trial = 3
[SampleAroundPointsGenerator]
points = {points}
window = {window}
samples_per_point = {samples_per_point}
seed = 123
"""
config = Config()
config.update(config_str=config_str)
gen = SampleAroundPointsGenerator.from_config(config)
npt.assert_equal(gen.lb.numpy(), np.array(lb))
npt.assert_equal(gen.ub.numpy(), np.array(ub))
self.assertEqual(gen.max_asks, len(points * samples_per_point))
self.assertEqual(gen.seed, 123)
self.assertFalse(gen.finished)

points = gen.gen(gen.max_asks)
for i in range(len(window)):
npt.assert_array_less(points[:, i, :], points[:, i, :] + window[i])
npt.assert_array_less(
np.ones(points[:, i, :].shape) * lb[i], points[:, i, :]
)
npt.assert_array_less(
points[:, i, :], np.ones(points[:, i, :].shape) * ub[i]
)

self.assertTrue(gen.finished)


if __name__ == "__main__":
unittest.main()
2 changes: 0 additions & 2 deletions tests/test_config.py
@@ -1274,10 +1274,8 @@ def test_clone_transform_options(self):

transforms = ParameterTransforms.from_config(config)
reversed_points = transforms.untransform(xformed_points)
reversed_window = transforms.untransform(xformed_window)

self.assertTrue(torch.allclose(reversed_points, torch.tensor(points)))
self.assertTrue(torch.allclose(reversed_window, torch.tensor(window)))

def test_build_transform(self):
config_str = """
50 changes: 50 additions & 0 deletions tests/test_transforms.py
@@ -151,6 +151,56 @@ def test_options_override(self):
self.assertTrue(transform.constant == 5)
self.assertTrue(transform.indices[0] == 0)

def test_transform_manual_generator(self):
base_points = [
[[-1.5, 1], [-1, 1.25], [-2, 1.75]],
[[-1.25, 1.25], [-1.75, 1.5], [-1.0, 2]],
]
window = [0.25, 0.1]
samples_per_point = 2
lb = [-3, 1]
ub = [-1, 3]
config_str = f"""
[common]
parnames = [par1, par2]
stimuli_per_trial = 3
outcome_types = [binary]
strategy_names = [init_strat]
[par1]
par_type = continuous
lower_bound = {lb[0]}
upper_bound = {ub[0]}
[par2]
par_type = continuous
lower_bound = {lb[1]}
upper_bound = {ub[1]}
[init_strat]
generator = SampleAroundPointsGenerator
[SampleAroundPointsGenerator]
points = {base_points}
window = {window}
samples_per_point = {samples_per_point}
seed = 123
"""
config = Config()
config.update(config_str=config_str)

strat = SequentialStrategy.from_config(config)

nPoints = 0
while not strat.finished:
points = strat.gen()
strat.add_data(points, torch.tensor(1.0))
self.assertTrue(torch.all(points[0, 0, :] < 0))
self.assertTrue(torch.all(points[0, 1, :] > 0))
nPoints += 1

self.assertTrue(nPoints == len(base_points) * samples_per_point)


class TransformsLog10Test(unittest.TestCase):
def test_transform_reshape3D(self):
