Skip to content

Commit

Permalink
fully revert tensor changes in dim_grid (#460)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #460

dim_grid tensor changes were only half reverted for some reason. This fully reverts it and fixes a bug where dim_grid was producing double the number of points when slicing.

Reviewed By: crasanders

Differential Revision: D66374663

fbshipit-source-id: c2226b59c6320bedf74c3374f6ae1347f47e4e26
  • Loading branch information
JasonKChow authored and facebook-github-bot committed Nov 25, 2024
1 parent deb8ce1 commit aa36126
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
10 changes: 3 additions & 7 deletions aepsych/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,11 @@ def dim_grid(

for i in range(dim):
if i in slice_dims.keys():
mesh_vals.append(
torch.tensor([slice_dims[i] - 1e-10, slice_dims[i] + 1e-10])
)
mesh_vals.append(slice(slice_dims[i] - 1e-10, slice_dims[i] + 1e-10, 1))
else:
mesh_vals.append(torch.linspace(lower[i].item(), upper[i].item(), gridsize))
mesh_vals.append(slice(lower[i].item(), upper[i].item(), gridsize * 1j))

return torch.stack(torch.meshgrid(*mesh_vals, indexing="ij"), dim=-1).reshape(
-1, dim
)
return torch.Tensor(np.mgrid[mesh_vals].reshape(dim, -1).T)


def _process_bounds(
Expand Down
9 changes: 8 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
import torch
from aepsych.models import GPClassificationModel
from aepsych.utils import _process_bounds, make_scaled_sobol
from aepsych.utils import _process_bounds, dim_grid, make_scaled_sobol


class UtilsTestCase(unittest.TestCase):
Expand All @@ -35,6 +35,13 @@ def test_dim_grid_model_size(self):
grid = GPClassificationModel.dim_grid(mb, gridsize=gridsize)
self.assertEqual(grid.shape, torch.Size([10, 1]))

def test_dim_grid_slice(self):
lb = torch.tensor([0, 0, 0])
ub = torch.tensor([1, 1, 1])
grid = dim_grid(lb, ub, slice_dims={1: 0.5})

self.assertTrue(np.all(grid.shape == (900, 3)))

def test_process_bounds(self):
lb, ub, dim = _process_bounds(np.r_[0, 1], np.r_[2, 3], None)
self.assertTrue(torch.all(lb == torch.tensor([0.0, 1.0])))
Expand Down

0 comments on commit aa36126

Please sign in to comment.