fix proposed changes to docstrings
yalsaffar committed Nov 8, 2024
1 parent 7c9752a commit 0f383f8
Showing 1 changed file with 24 additions and 24 deletions.
aepsych/models/monotonic_rejection_gp.py (48 changes: 24 additions & 24 deletions)
@@ -29,7 +29,6 @@
 from gpytorch.models import ApproximateGP
 from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy
 from scipy.stats import norm
-from torch import Tensor


 class MonotonicRejectionGP(AEPsychMixin, ApproximateGP):
@@ -148,12 +147,12 @@ def __init__(
         self.fixed_prior_mean = fixed_prior_mean
         self.inducing_points = inducing_points

-    def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None:
+    def fit(self, train_x: torch.Tensor, train_y: torch.Tensor, **kwargs) -> None:
         """Fit the model
         Args:
-            train_x (Tensor): Training x points
-            train_y (Tensor): Training y points. Should be (n x 1).
+            train_x (torch.Tensor): Training x points
+            train_y (torch.Tensor): Training y points. Should be (n x 1).
         """
         self.set_train_data(train_x, train_y)

@@ -168,18 +167,18 @@ def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None:

     def _set_model(
         self,
-        train_x: Tensor,
-        train_y: Tensor,
-        model_state_dict: Optional[Dict[str, Tensor]] = None,
-        likelihood_state_dict: Optional[Dict[str, Tensor]] = None,
+        train_x: torch.Tensor,
+        train_y: torch.Tensor,
+        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        likelihood_state_dict: Optional[Dict[str, torch.Tensor]] = None,
     ) -> None:
         """Sets the model with the given data and state dicts.
         Args:
-            train_x (Tensor): Training x points
-            train_y (Tensor): Training y points. Should be (n x 1).
-            model_state_dict (Dict[str, Tensor], optional): State dict for the model
-            likelihood_state_dict (Dict[str, Tensor], optional): State dict for the likelihood
+            train_x (torch.Tensor): Training x points
+            train_y (torch.Tensor): Training y points. Should be (n x 1).
+            model_state_dict (Dict[str, torch.Tensor], optional): State dict for the model
+            likelihood_state_dict (Dict[str, torch.Tensor], optional): State dict for the likelihood
         """
         train_x_aug = self._augment_with_deriv_index(train_x, 0)
         self.set_train_data(train_x_aug, train_y)
@@ -195,15 +194,15 @@ def _set_model(
         )
         mll = fit_gpytorch_mll(mll)

-    def update(self, train_x: Tensor, train_y: Tensor, warmstart: bool = True) -> None:
+    def update(self, train_x: torch.Tensor, train_y: torch.Tensor, warmstart: bool = True) -> None:
         """
         Update the model with new data.
         Expects the full set of data, not the incremental new data.
         Args:
-            train_x (Tensor): Train X.
-            train_y (Tensor): Train Y. Should be (n x 1).
+            train_x (torch.Tensor): Train X.
+            train_y (torch.Tensor): Train Y. Should be (n x 1).
             warmstart (bool): If True, warm-start model fitting with current parameters. Defaults to True.
         """
         if warmstart:
@@ -221,14 +220,14 @@ def update(self, train_x: Tensor, train_y: Tensor, warmstart: bool = True) -> None:

     def sample(
         self,
-        x: Tensor,
+        x: torch.Tensor,
         num_samples: Optional[int] = None,
         num_rejection_samples: Optional[int] = None,
     ) -> torch.Tensor:
         """Sample from monotonic GP
         Args:
-            x (Tensor): tensor of n points at which to sample
+            x (torch.Tensor): tensor of n points at which to sample
             num_samples (int, optional): how many points to sample. Default is self.num_samples.
             num_rejection_samples (int): how many samples to use for rejection sampling. Default is self.num_rejection_samples.
@@ -268,15 +267,16 @@ def sample(
         return samples_f

     def predict(
-        self, x: Tensor, probability_space: bool = False
-    ) -> Tuple[Tensor, Tensor]:
+        self, x: torch.Tensor, probability_space: bool = False
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Predict
         Args:
             x (torch.Tensor): tensor of n points at which to predict.
             probability_space (bool): whether to return in probability space. Defaults to False.
-        Returns: tuple (f, var) where f is (n,) and var is (n,)
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at query points.
         """
         samples_f = self.sample(x)
         mean = torch.mean(samples_f, dim=0).squeeze()
@@ -303,22 +303,22 @@ def predict_probability(
         """
         return self.predict(x, probability_space=True)

-    def _augment_with_deriv_index(self, x: Tensor, indx: int) -> Tensor:
+    def _augment_with_deriv_index(self, x: torch.Tensor, indx: int) -> torch.Tensor:
         """Augment input with derivative index
         Args:
-            x (Tensor): Input tensor
+            x (torch.Tensor): Input tensor
             indx (int): Derivative index
         Returns:
-            Tensor: Augmented tensor
+            torch.Tensor: Augmented tensor
         """
         return torch.cat(
             (x, indx * torch.ones(x.shape[0], 1)),
             dim=1,
         )

-    def _get_deriv_constraint_points(self) -> Tensor:
+    def _get_deriv_constraint_points(self) -> torch.Tensor:
         """Get derivative constraint points"""
         deriv_cp = torch.tensor([])
         for i in self.monotonic_idxs:
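
For orientation only, below is a minimal, hypothetical sketch of how the re-annotated methods might be called once this change lands. Only the method signatures visible in the diff (fit, update, sample, predict) are taken from the source; the constructor arguments (monotonic_idxs, lb, ub) and the toy data are assumptions, so treat this as an illustration rather than the library's documented usage.

# Hypothetical usage sketch (not part of this commit). Method signatures follow
# the diff above; the constructor arguments are assumptions and may differ.
import torch

from aepsych.models.monotonic_rejection_gp import MonotonicRejectionGP

# Assumed constructor: monotonicity in dimension 1, on the unit square.
model = MonotonicRejectionGP(
    monotonic_idxs=[1],
    lb=torch.zeros(2),
    ub=torch.ones(2),
)

# Training data per the docstrings: x is (n, d), y is (n, 1).
train_x = torch.rand(20, 2)
train_y = torch.randint(0, 2, (20, 1)).float()

model.fit(train_x, train_y)

# update() expects the full dataset (old + new), warm-starting by default.
new_x = torch.rand(5, 2)
new_y = torch.randint(0, 2, (5, 1)).float()
model.update(torch.cat([train_x, new_x]), torch.cat([train_y, new_y]), warmstart=True)

# Posterior summaries at query points, in latent or probability space.
x_query = torch.rand(10, 2)
f_mean, f_var = model.predict(x_query)
p_mean, p_var = model.predict(x_query, probability_space=True)

# Rejection samples of the latent function at the query points.
samples = model.sample(x_query, num_samples=100)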
