Merge pull request #96 from jcqcai/main
Dense Fisher EKF
SamDuffield authored Jun 3, 2024
2 parents 1585144 + 5e6ca0b commit 24d79e8
Showing 6 changed files with 278 additions and 2 deletions.
7 changes: 7 additions & 0 deletions docs/api/ekf/dense_fisher.md
@@ -0,0 +1,7 @@
---
title: EKF Dense Fisher
---

# EKF Dense Fisher

::: posteriors.ekf.dense_fisher
6 changes: 4 additions & 2 deletions docs/api/index.md
@@ -1,10 +1,12 @@
# API

### Extended Kalman filter (EKF)
-- [`ekf.diag_fisher`](ekf/diag_fisher.md) applies an online Bayesian update based
-  on a Taylor approximation of the log-likelihood. Uses the diagonal empirical Fisher
+- [`ekf.dense_fisher`](ekf/dense_fisher.md) applies an online Bayesian update based
+  on a Taylor approximation of the log-likelihood. Uses the empirical Fisher
  information matrix as a positive-definite alternative to the Hessian.
  Natural gradient descent equivalence following [Ollivier, 2019](https://arxiv.org/abs/1703.00209).
+- [`ekf.diag_fisher`](ekf/diag_fisher.md) same as `ekf.dense_fisher` but
+  uses the diagonal of the empirical Fisher information matrix instead.

### Laplace approximation
- [`laplace.dense_fisher`](laplace/dense_fisher.md) calculates the empirical Fisher
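
For intuition on the natural gradient equivalence cited in the EKF entries above (a sketch derived from the new module's update step, not text from this commit): each call to `update` computes

$$
Σ_{new}^{-1} = Σ_{pred}^{-1} - ε F(μ), \qquad μ_{new} = μ + ε \, Σ_{new} \, g(μ),
$$

i.e. a gradient step on the log likelihood preconditioned by the freshly updated covariance; as the Fisher term comes to dominate the prior precision, the preconditioner tends towards a (scaled) inverse Fisher, recovering the online natural gradient interpretation of [Ollivier, 2019](https://arxiv.org/abs/1703.00209).
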
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -68,6 +68,7 @@ nav:
- API:
  - api/index.md
  - EKF:
+   - Dense Fisher: api/ekf/dense_fisher.md
    - Diagonal Fisher: api/ekf/diag_fisher.md
  - Laplace:
    - Dense Fisher: api/laplace/dense_fisher.md
1 change: 1 addition & 0 deletions posteriors/ekf/__init__.py
@@ -1 +1,2 @@
from posteriors.ekf import diag_fisher
+from posteriors.ekf import dense_fisher
202 changes: 202 additions & 0 deletions posteriors/ekf/dense_fisher.py
@@ -0,0 +1,202 @@
from typing import Any
from functools import partial
import torch
from torch.func import grad_and_value
from dataclasses import dataclass
from optree.integration.torch import tree_ravel

from posteriors.tree_utils import tree_size

from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
from posteriors.utils import (
    per_samplify,
    empirical_fisher,
    is_scalar,
    CatchAuxError,
)


def build(
    log_likelihood: LogProbFn,
    lr: float,
    transition_cov: torch.Tensor | float = 0.0,
    per_sample: bool = False,
    init_cov: torch.Tensor | float = 1.0,
) -> Transform:
    """Builds a transform to implement an extended Kalman Filter update.

    EKF applies an online update to a Gaussian posterior over the parameters.

    The approximate Bayesian update is based on the linearization
    $$
    \\log p(θ | y) ≈ \\log p(θ) + ε g(μ)^T (θ - μ) + \\frac12 ε (θ - μ)^T F(μ) (θ - μ)
    $$
    where $μ$ is the mean of the prior distribution, $ε$ is the learning rate
    (or equivalently the likelihood inverse temperature),
    $g(μ)$ is the gradient of the log likelihood at $μ$ and $F(μ)$ is the
    empirical Fisher information matrix at $μ$ for data $y$.

    For more information on extended Kalman filtering as well as an equivalence
    to (online) natural gradient descent see [Ollivier, 2019](https://arxiv.org/abs/1703.00209).

    Args:
        log_likelihood: Function that takes parameters and input batch and
            returns the log-likelihood value as well as auxiliary information,
            e.g. from the model call.
        lr: Inverse temperature of the update, which behaves like a learning rate.
        transition_cov: Covariance of the transition noise, to additively
            inflate the covariance before the update.
        per_sample: If True, then log_likelihood is assumed to return a vector of
            log likelihoods for each sample in the batch. If False, then
            log_likelihood is assumed to return a scalar log likelihood for the
            whole batch; in this case torch.func.vmap will be called, which is
            typically slower than directly writing log_likelihood to be per sample.
        init_cov: Initial covariance of the Normal distribution.
            Can be torch.Tensor or scalar.

    Returns:
        EKF transform instance.
    """
    init_fn = partial(init, init_cov=init_cov)
    update_fn = partial(
        update,
        log_likelihood=log_likelihood,
        lr=lr,
        transition_cov=transition_cov,
        per_sample=per_sample,
    )
    return Transform(init_fn, update_fn)
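
# A sketch of the two accepted log_likelihood signatures (names illustrative,
# not part of this module):
#
#     def whole_batch_log_lik(params, batch):   # per_sample=False (default)
#         return scalar_log_lik, aux            # one scalar for the whole batch
#
#     def per_sample_log_lik(params, batch):    # per_sample=True
#         return log_lik_vector, aux            # shape (batch_size,)
#
# With per_sample=False, update() wraps log_likelihood in per_samplify (which
# relies on torch.func.vmap), typically slower than writing the per-sample
# form directly.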


@dataclass
class EKFDenseState(TransformState):
    """State encoding a Normal distribution over parameters.

    Args:
        params: Mean of the Normal distribution.
        cov: Covariance matrix of the Normal distribution.
        log_likelihood: Log likelihood of the data given the parameters.
        aux: Auxiliary information from the log_likelihood call.
    """

    params: TensorTree
    cov: torch.Tensor
    log_likelihood: float = 0
    aux: Any = None


def init(
    params: TensorTree,
    init_cov: torch.Tensor | float = 1.0,
) -> EKFDenseState:
    """Initialise Multivariate Normal distribution over parameters.

    Args:
        params: Initial mean of the Normal distribution.
        init_cov: Initial covariance matrix of the Multivariate Normal distribution.
            If it is a float, it is defined as an identity matrix scaled by that float.

    Returns:
        Initial EKFDenseState.
    """
    if is_scalar(init_cov):
        num_params = tree_size(params)
        init_cov = init_cov * torch.eye(num_params, requires_grad=False)

    return EKFDenseState(params, init_cov)
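
# For example (hypothetical parameter tree, not from this commit):
#     state = init({"a": torch.zeros(2), "b": torch.zeros(1)}, init_cov=0.1)
# has tree_size 3, so state.cov == 0.1 * torch.eye(3).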


def update(
    state: EKFDenseState,
    batch: Any,
    log_likelihood: LogProbFn,
    lr: float,
    transition_cov: torch.Tensor | float = 0.0,
    per_sample: bool = False,
    inplace: bool = False,
) -> EKFDenseState:
    """Applies an extended Kalman Filter update to the Multivariate Normal distribution.

    The approximate Bayesian update is based on the linearization
    $$
    \\log p(θ | y) ≈ \\log p(θ) + ε g(μ)^T (θ - μ) + \\frac12 ε (θ - μ)^T F(μ) (θ - μ)
    $$
    where $μ$ is the mean of the prior distribution, $ε$ is the learning rate
    (or equivalently the likelihood inverse temperature),
    $g(μ)$ is the gradient of the log likelihood at $μ$ and $F(μ)$ is the
    empirical Fisher information matrix at $μ$ for data $y$.

    Args:
        state: Current state.
        batch: Input data to log_likelihood.
        log_likelihood: Function that takes parameters and input batch and
            returns the log-likelihood value as well as auxiliary information,
            e.g. from the model call.
        lr: Inverse temperature of the update, which behaves like a learning rate.
        transition_cov: Covariance of the transition noise, to additively
            inflate the covariance before the update.
        per_sample: If True, then log_likelihood is assumed to return a vector of
            log likelihoods for each sample in the batch. If False, then
            log_likelihood is assumed to return a scalar log likelihood for the
            whole batch; in this case torch.func.vmap will be called, which is
            typically slower than directly writing log_likelihood to be per sample.
        inplace: Whether to update the state parameters in-place.

    Returns:
        Updated EKFDenseState.
    """
    if not per_sample:
        log_likelihood = per_samplify(log_likelihood)

    with torch.no_grad(), CatchAuxError():

        def log_likelihood_reduced(params, batch):
            per_samp_log_lik, internal_aux = log_likelihood(params, batch)
            return per_samp_log_lik.mean(), internal_aux

        grad, (log_liks, aux) = grad_and_value(log_likelihood_reduced, has_aux=True)(
            state.params, batch
        )
        fisher, _ = empirical_fisher(
            lambda p: log_likelihood(p, batch), has_aux=True, normalize=True
        )(state.params)

        predict_cov = state.cov + transition_cov
        predict_cov_inv = torch.cholesky_inverse(torch.linalg.cholesky(predict_cov))
        update_cov_inv = predict_cov_inv - lr * fisher
        update_cov = torch.cholesky_inverse(torch.linalg.cholesky(update_cov_inv))

        mu_raveled, mu_unravel_f = tree_ravel(state.params)
        update_mean = mu_raveled + lr * update_cov @ tree_ravel(grad)[0]
        update_mean = mu_unravel_f(update_mean)

    if inplace:
        state.params = update_mean
        state.cov = update_cov
        state.log_likelihood = log_liks.mean().detach()
        state.aux = aux
        return state
    return EKFDenseState(update_mean, update_cov, log_liks.mean().detach(), aux)
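
# How the code above realises the linearization in the docstring: with Gaussian
# prior N(μ, Σ_predict), completing the square in
#     log p(θ) + ε g(μ)^T (θ - μ) + (ε/2) (θ - μ)^T F(μ) (θ - μ)
# gives a Gaussian posterior with
#     Σ_update^{-1} = Σ_predict^{-1} - ε F(μ)     -> update_cov_inv
#     μ_update      = μ + ε Σ_update g(μ)         -> update_mean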


def sample(
    state: EKFDenseState, sample_shape: torch.Size = torch.Size([])
) -> TensorTree:
    """Sample(s) from the Multivariate Normal distribution over parameters.

    Args:
        state: State encoding mean and covariance.
        sample_shape: Shape of the desired samples.

    Returns:
        Sample(s) from the Multivariate Normal distribution.
    """
    mean_flat, unravel_func = tree_ravel(state.params)

    samples = torch.distributions.MultivariateNormal(
        loc=mean_flat,
        covariance_matrix=state.cov,
        validate_args=False,
    ).sample(sample_shape)

    samples = torch.vmap(unravel_func)(samples)
    return samples
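
A minimal end-to-end sketch of the new transform (an illustration mirroring the test below; the toy Gaussian likelihood and all names here are assumptions, not part of the commit):

import torch
from posteriors import ekf

# Toy log-likelihood over a single parameter tensor; as in the test below,
# the batch contents are ignored.
def log_likelihood(params, batch):
    dist = torch.distributions.Normal(torch.zeros(2), 1.0)
    return dist.log_prob(params["a"]).sum(), torch.tensor([])

transform = ekf.dense_fisher.build(log_likelihood, lr=1e-1)

state = transform.init({"a": torch.ones(2)})
batch = torch.arange(3).reshape(-1, 1)

for _ in range(100):
    state = transform.update(state, batch)  # mean and dense covariance evolve jointly

samples = ekf.dense_fisher.sample(state, (5,))  # 5 draws with the same tree structure as params
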
63 changes: 63 additions & 0 deletions tests/ekf/test_dense_fisher.py
@@ -0,0 +1,63 @@
import torch
from optree import tree_map
from torch.distributions import MultivariateNormal
from optree.integration.torch import tree_ravel
from posteriors.tree_utils import tree_size
from posteriors import ekf


def test_ekf_dense():
    torch.manual_seed(42)
    target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
    num_params = tree_size(target_mean)
    A = torch.randn(num_params, num_params)
    target_cov = torch.mm(A.t(), A)

    dist = MultivariateNormal(tree_ravel(target_mean)[0], covariance_matrix=target_cov)

    def log_prob(p, b):
        return dist.log_prob(tree_ravel(p)[0]).sum(), torch.Tensor([])

    init_mean = tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)
    batch = torch.arange(3).reshape(-1, 1)
    n_steps = 1000
    transform = ekf.dense_fisher.build(log_prob, lr=1e-1)

    # Test inplace = False
    state = transform.init(init_mean)
    log_liks = []
    for _ in range(n_steps):
        state = transform.update(state, batch, inplace=False)
        log_liks.append(state.log_likelihood.item())

    assert log_liks[0] < log_liks[-1]

    for key in state.params:
        assert torch.allclose(state.params[key], target_mean[key], atol=1e-1)
        assert not torch.allclose(state.params[key], init_mean[key])

    # Test inplace = True
    state = transform.init(init_mean)
    log_liks = []
    for _ in range(n_steps):
        state = transform.update(state, batch, inplace=True)
        log_liks.append(state.log_likelihood.item())

    for key in state.params:
        assert torch.allclose(state.params[key], target_mean[key], atol=1e-1)
        assert not torch.allclose(state.params[key], init_mean[key])

    # Test sample
    num_samples = 1000
    samples = ekf.dense_fisher.sample(state, (num_samples,))

    flat_samples = torch.vmap(lambda s: tree_ravel(s)[0])(samples)
    samples_cov = torch.cov(flat_samples.T)

    mean_copy = tree_map(lambda x: x.clone(), state.params)
    samples_mean = tree_map(lambda x: x.mean(dim=0), samples)

    assert torch.allclose(samples_cov, state.cov, atol=1e-1)
    for key in samples_mean:
        assert torch.allclose(samples_mean[key], state.params[key], atol=1e-1)
        assert not torch.allclose(samples_mean[key], mean_copy[key])
