From b4aa79535d8a45342428ad3a3db9d88bf315c531 Mon Sep 17 00:00:00 2001 From: Jason Chow Date: Thu, 14 Nov 2024 22:38:57 -0800 Subject: [PATCH] make manual generators require tensors and support higher-dimensional tensors (#455) Summary: This removes the numpy dependency from the manual generators and changes the logic slightly so that we can support higher dimensional tensors of points, allowing us to support pairwise experiments Differential Revision: D64607853 --- aepsych/generators/manual_generator.py | 57 +++++++++++++++-------- aepsych/strategy.py | 2 +- aepsych/transforms/parameters.py | 18 +++++++ tests/generators/test_manual_generator.py | 43 +++++++++++++++++ tests/test_config.py | 2 - tests/test_transforms.py | 50 ++++++++++++++++++++ 6 files changed, 150 insertions(+), 22 deletions(-) diff --git a/aepsych/generators/manual_generator.py b/aepsych/generators/manual_generator.py index 2a3195717..2d6e03210 100644 --- a/aepsych/generators/manual_generator.py +++ b/aepsych/generators/manual_generator.py @@ -6,9 +6,8 @@ # LICENSE file in the root directory of this source tree. import warnings -from typing import Dict, Optional, Union +from typing import Dict, Optional -import numpy as np import torch from aepsych.config import Config from aepsych.generators.base import AEPsychGenerator @@ -26,7 +25,7 @@ def __init__( self, lb: torch.Tensor, ub: torch.Tensor, - points: Union[np.ndarray, torch.Tensor], + points: torch.Tensor, dim: Optional[int] = None, shuffle: bool = True, seed: Optional[int] = None, @@ -35,16 +34,18 @@ def __init__( Args: lb torch.Tensor: Lower bounds of each parameter. ub torch.Tensor: Upper bounds of each parameter. - points (Union[np.ndarray, torch.Tensor]): The points that will be generated. + points torch.Tensor: The points that will be generated. dim (int, optional): Dimensionality of the parameter space. If None, it is inferred from lb and ub. shuffle (bool): Whether or not to shuffle the order of the points. True by default. """ self.seed = seed self.lb, self.ub, self.dim = _process_bounds(lb, ub, dim) + self.points = points if shuffle: - np.random.seed(self.seed) - np.random.shuffle(points) - self.points = torch.tensor(points) + if seed is not None: + torch.manual_seed(seed) + self.points = points[torch.randperm(len(points))] + self.max_asks = len(self.points) self._idx = 0 @@ -82,10 +83,14 @@ def get_config_options(cls, config: Config, name: Optional[str] = None) -> Dict: lb = config.gettensor(name, "lb") ub = config.gettensor(name, "ub") dim = config.getint(name, "dim", fallback=None) - points = config.getarray(name, "points") + points = config.gettensor(name, "points") shuffle = config.getboolean(name, "shuffle", fallback=True) seed = config.getint(name, "seed", fallback=None) + if len(points.shape) == 3: + # Configs have a reasonable natural input method that produces incorrect tensors + points = points.swapaxes(-1, -2) + options = { "lb": lb, "ub": ub, @@ -107,10 +112,10 @@ class SampleAroundPointsGenerator(ManualGenerator): def __init__( self, - lb: Union[np.ndarray, torch.Tensor], - ub: Union[np.ndarray, torch.Tensor], - window: Union[np.ndarray, torch.Tensor], - points: Union[np.ndarray, torch.Tensor], + lb: torch.Tensor, + ub: torch.Tensor, + window: torch.Tensor, + points: torch.Tensor, samples_per_point: int, dim: Optional[int] = None, shuffle: bool = True, @@ -120,7 +125,10 @@ def __init__( Args: lb (Union[np.ndarray, torch.Tensor]): Lower bounds of each parameter. ub (Union[np.ndarray, torch.Tensor]): Upper bounds of each parameter. - window (Union[np.ndarray, torch.Tensor]): How far away to sample from the reference point along each dimension. + window (Union[np.ndarray, torch.Tensor]): How far away to sample from the + reference point along each dimension. If the parameters are transformed, + the proportion of the range (based on ub/lb given) covered by the window + will be preserved (and not the absolute distance from the reference points). points (Union[np.ndarray, torch.Tensor]): The points that will be generated. samples_per_point (int): How many samples around each point to take. dim (int, optional): Dimensionality of the parameter space. If None, it is inferred from lb and ub. @@ -128,16 +136,27 @@ def __init__( seed (int, optional): Random seed. """ lb, ub, dim = _process_bounds(lb, ub, dim) - points = torch.Tensor(points) self.engine = SobolEngine(dimension=dim, scramble=True, seed=seed) - generated = [] + gen_points = [] + if len(points.shape) > 2: + # We need to determine how many stimuli there are per trial to maintain the proper tensor shape + n_draws = points.shape[-1] + else: + n_draws = 1 for point in points: + if len(points.shape) > 2: + point = point.T p_lb = torch.max(point - window, lb) p_ub = torch.min(point + window, ub) - grid = self.engine.draw(samples_per_point) - grid = p_lb + (p_ub - p_lb) * grid - generated.append(grid) - generated = torch.Tensor(np.vstack(generated)) # type: ignore + for _ in range(samples_per_point): + grid = self.engine.draw(n_draws) + grid = p_lb + (p_ub - p_lb) * grid + gen_points.append(grid) + if len(points.shape) > 2: + generated = torch.stack(gen_points) + generated = generated.swapaxes(-2, -1) + else: + generated = torch.vstack(gen_points) super().__init__(lb, ub, generated, dim, shuffle, seed) # type: ignore diff --git a/aepsych/strategy.py b/aepsych/strategy.py index aaa5da335..5e85d02d6 100644 --- a/aepsych/strategy.py +++ b/aepsych/strategy.py @@ -238,7 +238,7 @@ def __init__( if self.stimuli_per_trial == 1: self.event_shape: Tuple[int, ...] = (self.dim,) - if self.stimuli_per_trial == 2: + if self.stimuli_per_trial > 1: self.event_shape = (self.dim, self.stimuli_per_trial) self.model = model diff --git a/aepsych/transforms/parameters.py b/aepsych/transforms/parameters.py index 664c84405..0aad41d85 100644 --- a/aepsych/transforms/parameters.py +++ b/aepsych/transforms/parameters.py @@ -807,6 +807,24 @@ def transform_options( if option in ["ub", "lb"]: value = transforms.transform_bounds(value, bound=option) + elif option == "points" and len(value.shape) == 3: + value = value.swapaxes(-2, -1) + value = transforms.transform(value) + value = value.swapaxes(-1, -2) + elif option == "window": + # Get proportion of range covered in raw space + raw_lb = config.gettensor("common", "lb") + raw_ub = config.gettensor("common", "ub") + raw_range = raw_ub - raw_lb + window_prop = (value * 2) / raw_range + + # Calculate transformed range + transformed_lb = transforms.transform_bounds(raw_lb, bound="lb") + transformed_ub = transforms.transform_bounds(raw_ub, bound="ub") + transformed_range = transformed_ub - transformed_lb + + # Transformed window covers the same proportion of range + value = window_prop * transformed_range / 2 else: value = transforms.transform(value) diff --git a/tests/generators/test_manual_generator.py b/tests/generators/test_manual_generator.py index e517b2808..41f1dd9b4 100644 --- a/tests/generators/test_manual_generator.py +++ b/tests/generators/test_manual_generator.py @@ -102,6 +102,49 @@ def test_sample_around_points_generator(self): self.assertTrue(gen.finished) + def test_sample_around_points_generator_high_dim(self): + points = [ + [[-1.5, 1], [-1, 1.25], [-2, 1.75]], + [[-1.25, 1.25], [-1.75, 1.5], [-1.0, 2]], + ] + window = [0.25, 0.1] + samples_per_point = 2 + lb = [-2, 1] + ub = [-1, 2] + config_str = f""" + [common] + lb = {lb} + ub = {ub} + parnames = [par1, par2] + stimuli_per_trial = 3 + + [SampleAroundPointsGenerator] + points = {points} + window = {window} + samples_per_point = {samples_per_point} + seed = 123 + """ + config = Config() + config.update(config_str=config_str) + gen = SampleAroundPointsGenerator.from_config(config) + npt.assert_equal(gen.lb.numpy(), np.array(lb)) + npt.assert_equal(gen.ub.numpy(), np.array(ub)) + self.assertEqual(gen.max_asks, len(points * samples_per_point)) + self.assertEqual(gen.seed, 123) + self.assertFalse(gen.finished) + + points = gen.gen(gen.max_asks) + for i in range(len(window)): + npt.assert_array_less(points[:, i, :], points[:, i, :] + window[i]) + npt.assert_array_less( + np.ones(points[:, i, :].shape) * lb[i], points[:, i, :] + ) + npt.assert_array_less( + points[:, i, :], np.ones(points[:, i, :].shape) * ub[i] + ) + + self.assertTrue(gen.finished) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_config.py b/tests/test_config.py index 056778f08..397fb07a2 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1274,10 +1274,8 @@ def test_clone_transform_options(self): transforms = ParameterTransforms.from_config(config) reversed_points = transforms.untransform(xformed_points) - reversed_window = transforms.untransform(xformed_window) self.assertTrue(torch.allclose(reversed_points, torch.tensor(points))) - self.assertTrue(torch.allclose(reversed_window, torch.tensor(window))) def test_build_transform(self): config_str = """ diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 4a1ca50ec..146da3a48 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -151,6 +151,56 @@ def test_options_override(self): self.assertTrue(transform.constant == 5) self.assertTrue(transform.indices[0] == 0) + def test_transform_manual_generator(self): + base_points = [ + [[-1.5, 1], [-1, 1.25], [-2, 1.75]], + [[-1.25, 1.25], [-1.75, 1.5], [-1.0, 2]], + ] + window = [0.25, 0.1] + samples_per_point = 2 + lb = [-3, 1] + ub = [-1, 3] + config_str = f""" + [common] + parnames = [par1, par2] + stimuli_per_trial = 3 + outcome_types = [binary] + strategy_names = [init_strat] + + [par1] + par_type = continuous + lower_bound = {lb[0]} + upper_bound = {ub[0]} + + [par2] + par_type = continuous + lower_bound = {lb[1]} + upper_bound = {ub[1]} + + [init_strat] + generator = SampleAroundPointsGenerator + + [SampleAroundPointsGenerator] + points = {base_points} + window = {window} + samples_per_point = {samples_per_point} + seed = 123 + """ + config = Config() + config.update(config_str=config_str) + + strat = SequentialStrategy.from_config(config) + + nPoints = 0 + while not strat.finished: + points = strat.gen() + strat.add_data(points, torch.tensor(1.0)) + self.assertTrue(torch.all(points[0, 0, :] < 0)) + self.assertTrue(torch.all(points[0, 1, :] > 0)) + nPoints += 1 + + self.assertTrue(nPoints == len(base_points) * samples_per_point) + class TransformsLog10Test(unittest.TestCase): def test_transform_reshape3D(self):