Skip to content

Commit

Permalink
Update documentation (add to index and fix alphabetical order) and fi…
Browse files Browse the repository at this point in the history
…x imports
  • Loading branch information
jcqcai committed Jun 1, 2024
1 parent 75c3f13 commit 7ade15d
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
4 changes: 4 additions & 0 deletions docs/api/index.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# API

### Extended Kalman filter (EKF)
- [`ekf.dense_fisher`](ekf/dense_fisher.md) applies an online Bayesian update based
on a Taylor approximation of the log-likelihood. Uses the empirical Fisher
information matrix as a positive-definite alternative to the Hessian.
Natural gradient descent equivalence following [Ollivier, 2019](https://arxiv.org/abs/1703.00209).
- [`ekf.diag_fisher`](ekf/diag_fisher.md) applies an online Bayesian update based
on a Taylor approximation of the log-likelihood. Uses the diagonal empirical Fisher
information matrix as a positive-definite alternative to the Hessian.
Expand Down
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ nav:
- API:
- api/index.md
- EKF:
- Diagonal Fisher: api/ekf/diag_fisher.md
- Dense Fisher: api/ekf/dense_fisher.md
- Diagonal Fisher: api/ekf/diag_fisher.md
- Laplace:
- Dense Fisher: api/laplace/dense_fisher.md
- Dense GGN: api/laplace/dense_ggn.md
Expand Down
4 changes: 2 additions & 2 deletions posteriors/ekf/dense_fisher.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Any
from functools import partial
import torch
from torch.func import jacrev
from optree import tree_map
from torch.func import grad_and_value
from dataclasses import dataclass
from optree.integration.torch import tree_ravel

Expand Down Expand Up @@ -149,6 +148,7 @@ def update(
log_likelihood = per_samplify(log_likelihood)

with torch.no_grad(), CatchAuxError():

def log_likelihood_reduced(params, batch):
per_samp_log_lik, internal_aux = log_likelihood(params, batch)
return per_samp_log_lik.mean(), internal_aux
Expand Down

0 comments on commit 7ade15d

Please sign in to comment.