Skip to content

Commit

Permalink
Enhance Docstrings in aepsych/factory (#428)
Browse files Browse the repository at this point in the history
Summary:
One of several PRs addressing issue #366 to improve docstring coverage.

Improves documentation in `aepsych/factory` for better clarity and consistency.
- Adds missing docstrings to functions and methods across factory modules.
- Updates existing docstrings with refined type hints and a unified structure.

Pull Request resolved: #428

Reviewed By: crasanders

Differential Revision: D65950795

Pulled By: JasonKChow

fbshipit-source-id: 0a18cf0ca76d05574718f11352b8b0587d54f912
  • Loading branch information
yalsaffar authored and facebook-github-bot committed Nov 20, 2024
1 parent d096c6a commit 10100e4
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 3 deletions.
23 changes: 22 additions & 1 deletion aepsych/factory/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def default_mean_covar_factory(
config details).
dim (int, optional): Dimensionality of the parameter space. Must be provided
if config is None.
stimuli_per_trial (int): Number of stimuli per trial. Defaults to 1.
Returns:
Tuple[gpytorch.means.Mean, gpytorch.kernels.Kernel]: Instantiated
Expand Down Expand Up @@ -68,6 +69,14 @@ def default_mean_covar_factory(
def _get_default_mean_function(
config: Optional[Config] = None,
) -> gpytorch.means.ConstantMean:
"""Creates a default mean function for Gaussian Processes.
Args:
config (Config, optional): Configuration object.
Returns:
gpytorch.means.ConstantMean: An instantiated mean function with appropriate priors based on the configuration.
"""
# default priors
fixed_mean = False
mean = gpytorch.means.ConstantMean()
Expand All @@ -93,6 +102,18 @@ def _get_default_cov_function(
stimuli_per_trial: int,
active_dims: Optional[List[int]] = None,
) -> gpytorch.kernels.Kernel:
"""Creates a default covariance function for Gaussian Processes.
Args:
config (Config, optional): Configuration object.
dim (int): Dimensionality of the parameter space.
stimuli_per_trial (int): Number of stimuli per trial.
active_dims (List[int], optional): List of dimensions to use in the covariance function. Defaults to None.
Returns:
gpytorch.kernels.Kernel: An instantiated kernel with appropriate priors based on the configuration.
"""

# default priors
lengthscale_prior = "lognormal" if stimuli_per_trial == 1 else "gamma"
ls_loc = torch.tensor(math.sqrt(2.0), dtype=torch.float64)
Expand Down Expand Up @@ -178,4 +199,4 @@ def _get_default_cov_function(
outputscale_prior=os_prior,
outputscale_constraint=gpytorch.constraints.GreaterThan(1e-4),
)
return covar
return covar
12 changes: 11 additions & 1 deletion aepsych/factory/ordinal.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@
def ordinal_mean_covar_factory(
config: Config,
) -> Tuple[gpytorch.means.ConstantMean, gpytorch.kernels.ScaleKernel]:
""" Create a mean and covariance function for ordinal GPs.
Args:
config (Config): Config object containing bounds.
Returns:
Tuple[gpytorch.means.ConstantMean, gpytorch.kernels.ScaleKernel]: A tuple containing
the mean function (ConstantMean) and the covariance function (ScaleKernel).
"""

try:
base_factory = config.getobj("ordinal_mean_covar_factory", "base_factory")
except NoOptionError:
Expand All @@ -31,4 +41,4 @@ def ordinal_mean_covar_factory(
else:
covar = base_covar

return mean, covar
return mean, covar
10 changes: 9 additions & 1 deletion aepsych/factory/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@
def pairwise_mean_covar_factory(
config: Config,
) -> Tuple[gpytorch.means.ConstantMean, gpytorch.kernels.ScaleKernel]:
""" Creates a mean and covariance function for pairwise GPs.
Args:
config (Config): Config object containing bounds.
Returns:
Tuple[gpytorch.means.ConstantMean, gpytorch.kernels.ScaleKernel]: A tuple containing
the mean function (ConstantMean) and the covariance function (ScaleKernel)."""
lb = config.gettensor("common", "lb")
ub = config.gettensor("common", "ub")
assert lb.shape[0] == ub.shape[0], "bounds shape mismatch!"
Expand Down Expand Up @@ -65,4 +73,4 @@ def pairwise_mean_covar_factory(
cov = _get_default_cov_function(config, config_dim // 2, stimuli_per_trial=1)
covar = PairwiseKernel(cov)

return mean, covar
return mean, covar

0 comments on commit 10100e4

Please sign in to comment.