diff --git a/aepsych/models/monotonic_rejection_gp.py b/aepsych/models/monotonic_rejection_gp.py index 86eceb304..a541c60fe 100644 --- a/aepsych/models/monotonic_rejection_gp.py +++ b/aepsych/models/monotonic_rejection_gp.py @@ -29,7 +29,6 @@ from gpytorch.models import ApproximateGP from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy from scipy.stats import norm -from torch import Tensor class MonotonicRejectionGP(AEPsychMixin, ApproximateGP): @@ -148,12 +147,12 @@ def __init__( self.fixed_prior_mean = fixed_prior_mean self.inducing_points = inducing_points - def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None: + def fit(self, train_x: torch.Tensor, train_y: torch.Tensor, **kwargs) -> None: """Fit the model Args: - train_x (Tensor): Training x points - train_y (Tensor): Training y points. Should be (n x 1). + train_x (torch.Tensor): Training x points + train_y (torch.Tensor): Training y points. Should be (n x 1). """ self.set_train_data(train_x, train_y) @@ -168,18 +167,18 @@ def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None: def _set_model( self, - train_x: Tensor, - train_y: Tensor, - model_state_dict: Optional[Dict[str, Tensor]] = None, - likelihood_state_dict: Optional[Dict[str, Tensor]] = None, + train_x: torch.Tensor, + train_y: torch.Tensor, + model_state_dict: Optional[Dict[str, torch.Tensor]] = None, + likelihood_state_dict: Optional[Dict[str, torch.Tensor]] = None, ) -> None: """Sets the model with the given data and state dicts. Args: - train_x (Tensor): Training x points - train_y (Tensor): Training y points. Should be (n x 1). - model_state_dict (Dict[str, Tensor], optional): State dict for the model - likelihood_state_dict (Dict[str, Tensor], optional): State dict for the likelihood + train_x (torch.Tensor): Training x points + train_y (torch.Tensor): Training y points. Should be (n x 1). + model_state_dict (Dict[str, torch.Tensor], optional): State dict for the model + likelihood_state_dict (Dict[str, torch.Tensor], optional): State dict for the likelihood """ train_x_aug = self._augment_with_deriv_index(train_x, 0) self.set_train_data(train_x_aug, train_y) @@ -195,15 +194,15 @@ def _set_model( ) mll = fit_gpytorch_mll(mll) - def update(self, train_x: Tensor, train_y: Tensor, warmstart: bool = True) -> None: + def update(self, train_x: torch.Tensor, train_y: torch.Tensor, warmstart: bool = True) -> None: """ Update the model with new data. Expects the full set of data, not the incremental new data. Args: - train_x (Tensor): Train X. - train_y (Tensor): Train Y. Should be (n x 1). + train_x (torch.Tensor): Train X. + train_y (torch.Tensor): Train Y. Should be (n x 1). warmstart (bool): If True, warm-start model fitting with current parameters. Defaults to True. """ if warmstart: @@ -221,14 +220,14 @@ def update(self, train_x: Tensor, train_y: Tensor, warmstart: bool = True) -> No def sample( self, - x: Tensor, + x: torch.Tensor, num_samples: Optional[int] = None, num_rejection_samples: Optional[int] = None, ) -> torch.Tensor: """Sample from monotonic GP Args: - x (Tensor): tensor of n points at which to sample + x (torch.Tensor): tensor of n points at which to sample num_samples (int, optional): how many points to sample. Default is self.num_samples. num_rejection_samples (int): how many samples to use for rejection sampling. Default is self.num_rejection_samples. @@ -268,15 +267,16 @@ def sample( return samples_f def predict( - self, x: Tensor, probability_space: bool = False - ) -> Tuple[Tensor, Tensor]: + self, x: torch.Tensor, probability_space: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: """Predict Args: x (torch.Tensor): tensor of n points at which to predict. probability_space (bool): whether to return in probability space. Defaults to False. - Returns: tuple (f, var) where f is (n,) and var is (n,) + Returns: + Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at query points. """ samples_f = self.sample(x) mean = torch.mean(samples_f, dim=0).squeeze() @@ -303,22 +303,22 @@ def predict_probability( """ return self.predict(x, probability_space=True) - def _augment_with_deriv_index(self, x: Tensor, indx: int) -> Tensor: + def _augment_with_deriv_index(self, x: torch.Tensor, indx: int) -> torch.Tensor: """Augment input with derivative index Args: - x (Tensor): Input tensor + x (torch.Tensor): Input tensor indx (int): Derivative index Returns: - Tensor: Augmented tensor + torch.Tensor: Augmented tensor """ return torch.cat( (x, indx * torch.ones(x.shape[0], 1)), dim=1, ) - def _get_deriv_constraint_points(self) -> Tensor: + def _get_deriv_constraint_points(self) -> torch.Tensor: """Get derivative constraint points""" deriv_cp = torch.tensor([]) for i in self.monotonic_idxs: