From 1b09650a21b1c65ee4d0d8dbcb12fd853af1153a Mon Sep 17 00:00:00 2001 From: maffettone Date: Thu, 3 Oct 2024 11:19:06 -0400 Subject: [PATCH] enh: add normalization and standardization to workflow --- pdf_agents/scientific_value.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/pdf_agents/scientific_value.py b/pdf_agents/scientific_value.py index a394fd5..f881601 100644 --- a/pdf_agents/scientific_value.py +++ b/pdf_agents/scientific_value.py @@ -9,6 +9,7 @@ from botorch.acquisition import UpperConfidenceBound, qUpperConfidenceBound from botorch.models import SingleTaskGP from botorch.optim import optimize_acqf # noqa: F401 +from botorch.utils.transforms import normalize, standardize, unnormalize from gpytorch.mlls import ExactMarginalLogLikelihood from scipy.spatial import distance_matrix @@ -132,7 +133,9 @@ def ask(self, batch_size: int = 1): train_x = torch.tensor(self.independent_cache, dtype=torch.double, device=self.device) if train_x.dim() == 1: train_x = train_x.view(-1, 1) - train_y = torch.tensor(value, dtype=torch.double, device=self.device) + norm_bounds = torch.stack([train_x.min(dim=0).values, train_x.max(dim=0).values]) + train_x = normalize(train_x, norm_bounds) + train_y = standardize(torch.tensor(value, dtype=torch.double, device=self.device)) gp = SingleTaskGP(train_x, train_y).to(self.device) mll = ExactMarginalLogLikelihood(gp.likelihood, gp).to(self.device) fit_gpytorch_mll(mll) @@ -146,12 +149,15 @@ def ask(self, batch_size: int = 1): # candidates, acq_value = optimize_acqf( # acq, bounds=self.bounds, q=batch_size, num_restarts=self.num_restarts, raw_samples=self.raw_samples # ) - grid = torch.tensor(make_wafer_grid_list(*self.bounds.cpu().numpy().ravel(), step=self.motor_resolution))[ - :, None, : - ] + grid = normalize( + torch.tensor(make_wafer_grid_list(*self.bounds.cpu().numpy().ravel(), step=self.motor_resolution))[ + :, None, : + ], + norm_bounds, + ) acq_grid = acq(grid) top_indicies = 
torch.argsort(acq_grid, descending=True, dim=0)[:batch_size] - candidates = grid[top_indicies].squeeze(1) + candidates = unnormalize(grid, norm_bounds)[top_indicies].squeeze(1) acq_value = acq_grid[top_indicies] if batch_size == 1: