
linearized_forward_diag API #57

Open
SamDuffield opened this issue Apr 16, 2024 · 1 comment
Labels
enhancement New feature or request (beyond just a new method)

Comments

@SamDuffield
Contributor

Currently we have an API like

vals, chol, aux = linearized_forward_diag(f, params, batch, sd_diag)

but perhaps an API like

vals, chol, aux = linearized_forward_diag(f, sd_diag)(params, batch)

might be cleaner, as it provides a new function that retains the required signature of f. It also fits better with the torch.func API.
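The curried form can be sketched as follows (a minimal illustration only — the internal names and the linearization itself are placeholders, not the actual posteriors implementation):

```python
# Hypothetical sketch of the proposed curried API. The outer call binds
# f and sd_diag; the returned function keeps the (params, batch)
# signature of f, mirroring torch.func-style transforms.
def linearized_forward_diag(f, sd_diag):
    def forward(params, batch):
        # A real implementation would linearize f around params and
        # propagate sd_diag; here we only illustrate the call shape.
        vals, aux = f(params, batch)
        chol = sd_diag  # placeholder for the diagonal Cholesky factor
        return vals, chol, aux

    return forward


# Usage: the returned callable has the same signature as f, so it can be
# passed anywhere a (params, batch) function is expected.
f = lambda params, batch: (params * batch, None)
vals, chol, aux = linearized_forward_diag(f, sd_diag=0.1)(params=2.0, batch=3.0)
```

This is the design benefit being argued for: the transform returns a function rather than a result, so it composes with other function transforms.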

@SamDuffield added the enhancement label Apr 16, 2024
@SamDuffield
Contributor Author

Additional considerations:

  • Should we add an additional observation_noise_sd argument (the name is open to discussion) corresponding to $\sigma$ in the predictive distribution $N( f \mid f^*, J^T \Sigma J + \sigma^2 \mathbb{I})$? Details in https://arxiv.org/abs/1906.11537
  • Should we remove the internal no_grad to give the user more flexibility? Although in the majority of cases the function would still need to be wrapped in no_grad to avoid the memory cost of undesired gradients.
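For the first point, the role of observation_noise_sd can be sketched numerically (a hypothetical NumPy illustration; the Jacobian, diagonal covariance, and argument names are assumptions, not the library's API):

```python
import numpy as np

# Predictive covariance J^T Sigma J + sigma^2 I from the linearized
# predictive N(f | f*, J^T Sigma J + sigma^2 I), with a diagonal
# posterior covariance Sigma = diag(sigma_diag).
def predictive_cov(jac, sigma_diag, observation_noise_sd):
    # jac: (n_params, out_dim) Jacobian of f at the fitted params
    # sigma_diag: (n_params,) diagonal of the posterior covariance
    # observation_noise_sd: scalar sigma added to every output dimension
    lin_cov = jac.T @ np.diag(sigma_diag) @ jac
    return lin_cov + observation_noise_sd**2 * np.eye(jac.shape[1])


jac = np.array([[1.0, 0.0], [0.0, 2.0]])
cov = predictive_cov(jac, np.array([0.5, 0.5]), observation_noise_sd=0.1)
```

Without the observation_noise_sd term the returned covariance is purely epistemic (J^T Sigma J); the extra sigma^2 I adds the aleatoric noise floor to the diagonal.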
