Merge pull request #96 from jcqcai/main
Dense Fisher EKF
Showing 6 changed files with 278 additions and 2 deletions.
@@ -0,0 +1,7 @@
---
title: EKF Dense Fisher
---

# EKF Dense Fisher

::: posteriors.ekf.dense_fisher
@@ -1 +1,2 @@
from posteriors.ekf import diag_fisher
from posteriors.ekf import dense_fisher
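With `dense_fisher` now exported alongside `diag_fisher`, the new transform plugs into the usual posteriors build/init/update loop. The sketch below is illustrative only: `toy_log_likelihood` and its parameter tree are invented for this example, but the `build`, `init` and `update` calls mirror the test added in this commit.

```python
import torch
from posteriors import ekf


def toy_log_likelihood(params, batch):
    # Hypothetical (params, batch) -> (scalar log-likelihood, aux) function,
    # matching the signature expected by build with the default per_sample=False.
    resid = params["w"] - batch
    return -0.5 * (resid**2).sum(), torch.tensor([])


params = {"w": torch.zeros(3)}
batch = torch.randn(5, 3)  # leading dimension is the per-sample dimension

transform = ekf.dense_fisher.build(toy_log_likelihood, lr=1e-1)
state = transform.init(params)  # identity initial covariance by default
state = transform.update(state, batch, inplace=False)  # one online EKF step
```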
@@ -0,0 +1,202 @@
from typing import Any
from functools import partial
import torch
from torch.func import grad_and_value
from dataclasses import dataclass
from optree.integration.torch import tree_ravel

from posteriors.tree_utils import tree_size

from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
from posteriors.utils import (
    per_samplify,
    empirical_fisher,
    is_scalar,
    CatchAuxError,
)


def build(
    log_likelihood: LogProbFn,
    lr: float,
    transition_cov: torch.Tensor | float = 0.0,
    per_sample: bool = False,
    init_cov: torch.Tensor | float = 1.0,
) -> Transform:
    """Builds a transform to implement an extended Kalman Filter update.

    EKF applies an online update to a Gaussian posterior over the parameters.

    The approximate Bayesian update is based on the linearization
    $$
    \\log p(θ | y) ≈ \\log p(θ) + ε g(μ)ᵀ(θ - μ) + \\frac12 ε (θ - μ)^T F(μ) (θ - μ)
    $$
    where $μ$ is the mean of the prior distribution, $ε$ is the learning rate
    (or equivalently the likelihood inverse temperature),
    $g(μ)$ is the gradient of the log likelihood at $μ$ and $F(μ)$ is the
    empirical Fisher information matrix at $μ$ for data $y$.

    For more information on extended Kalman filtering as well as an equivalence
    to (online) natural gradient descent see [Ollivier, 2019](https://arxiv.org/abs/1703.00209).

    Args:
        log_likelihood: Function that takes parameters and input batch and
            returns the log-likelihood value as well as auxiliary information,
            e.g. from the model call.
        lr: Inverse temperature of the update, which behaves like a learning rate.
        transition_cov: Covariance of the transition noise, to additively
            inflate the covariance before the update.
        per_sample: If True, then log_likelihood is assumed to return a vector of
            log likelihoods for each sample in the batch. If False, then log_likelihood
            is assumed to return a scalar log likelihood for the whole batch; in this
            case torch.func.vmap will be called, which is typically slower than
            directly writing log_likelihood to be per sample.
        init_cov: Initial covariance of the Normal distribution. Can be torch.Tensor or scalar.

    Returns:
        EKF transform instance.
    """
    init_fn = partial(init, init_cov=init_cov)
    update_fn = partial(
        update,
        log_likelihood=log_likelihood,
        lr=lr,
        transition_cov=transition_cov,
        per_sample=per_sample,
    )
    return Transform(init_fn, update_fn)
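# ---------------------------------------------------------------------------
# Illustrative note (added commentary, not part of the module): with
# per_sample=True the supplied log_likelihood should map a batch of B samples
# directly to a length-B tensor of log-likelihoods, avoiding the internal
# torch.func.vmap call mentioned in the docstring. A hypothetical example,
# assuming a `model` nn.Module and a (inputs, targets) batch structure:
#
#     def log_likelihood(params, batch):
#         inputs, targets = batch
#         preds = torch.func.functional_call(model, params, inputs)
#         return torch.distributions.Normal(preds, 1.0).log_prob(targets).sum(-1), preds
# ---------------------------------------------------------------------------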


@dataclass
class EKFDenseState(TransformState):
    """State encoding a Normal distribution over parameters.

    Args:
        params: Mean of the Normal distribution.
        cov: Covariance matrix of the Normal distribution.
        log_likelihood: Log likelihood of the data given the parameters.
        aux: Auxiliary information from the log_likelihood call.
    """

    params: TensorTree
    cov: torch.Tensor
    log_likelihood: float = 0
    aux: Any = None


def init(
    params: TensorTree,
    init_cov: torch.Tensor | float = 1.0,
) -> EKFDenseState:
    """Initialise Multivariate Normal distribution over parameters.

    Args:
        params: Initial mean of the Normal distribution.
        init_cov: Initial covariance matrix of the Multivariate Normal distribution.
            If it is a float, it is defined as an identity matrix scaled by that float.

    Returns:
        Initial EKFDenseState.
    """
    if is_scalar(init_cov):
        num_params = tree_size(params)
        init_cov = init_cov * torch.eye(num_params, requires_grad=False)

    return EKFDenseState(params, init_cov)


def update(
    state: EKFDenseState,
    batch: Any,
    log_likelihood: LogProbFn,
    lr: float,
    transition_cov: torch.Tensor | float = 0.0,
    per_sample: bool = False,
    inplace: bool = False,
) -> EKFDenseState:
    """Applies an extended Kalman Filter update to the Multivariate Normal distribution.

    The approximate Bayesian update is based on the linearization
    $$
    \\log p(θ | y) ≈ \\log p(θ) + ε g(μ)ᵀ(θ - μ) + \\frac12 ε (θ - μ)^T F(μ) (θ - μ)
    $$
    where $μ$ is the mean of the prior distribution, $ε$ is the learning rate
    (or equivalently the likelihood inverse temperature),
    $g(μ)$ is the gradient of the log likelihood at $μ$ and $F(μ)$ is the
    empirical Fisher information matrix at $μ$ for data $y$.

    Args:
        state: Current state.
        batch: Input data to log_likelihood.
        log_likelihood: Function that takes parameters and input batch and
            returns the log-likelihood value as well as auxiliary information,
            e.g. from the model call.
        lr: Inverse temperature of the update, which behaves like a learning rate.
        transition_cov: Covariance of the transition noise, to additively
            inflate the covariance before the update.
        per_sample: If True, then log_likelihood is assumed to return a vector of
            log likelihoods for each sample in the batch. If False, then log_likelihood
            is assumed to return a scalar log likelihood for the whole batch; in this
            case torch.func.vmap will be called, which is typically slower than
            directly writing log_likelihood to be per sample.
        inplace: Whether to update the state parameters in-place.

    Returns:
        Updated EKFDenseState.
    """
    if not per_sample:
        log_likelihood = per_samplify(log_likelihood)

    with torch.no_grad(), CatchAuxError():

        def log_likelihood_reduced(params, batch):
            per_samp_log_lik, internal_aux = log_likelihood(params, batch)
            return per_samp_log_lik.mean(), internal_aux

        grad, (log_liks, aux) = grad_and_value(log_likelihood_reduced, has_aux=True)(
            state.params, batch
        )
        fisher, _ = empirical_fisher(
            lambda p: log_likelihood(p, batch), has_aux=True, normalize=True
        )(state.params)

        predict_cov = state.cov + transition_cov
        predict_cov_inv = torch.cholesky_inverse(torch.linalg.cholesky(predict_cov))
        update_cov_inv = predict_cov_inv - lr * fisher
        update_cov = torch.cholesky_inverse(torch.linalg.cholesky(update_cov_inv))

        mu_raveled, mu_unravel_f = tree_ravel(state.params)
        update_mean = mu_raveled + lr * update_cov @ tree_ravel(grad)[0]
        update_mean = mu_unravel_f(update_mean)

    if inplace:
        state.params = update_mean
        state.cov = update_cov
        state.log_likelihood = log_liks.mean().detach()
        state.aux = aux
        return state
    return EKFDenseState(update_mean, update_cov, log_liks.mean().detach(), aux)


def sample(
    state: EKFDenseState, sample_shape: torch.Size = torch.Size([])
) -> TensorTree:
    """Sample(s) from the Multivariate Normal distribution over parameters.

    Args:
        state: State encoding mean and covariance.
        sample_shape: Shape of the desired samples.

    Returns:
        Sample(s) from the Multivariate Normal distribution.
    """
    mean_flat, unravel_func = tree_ravel(state.params)

    samples = torch.distributions.MultivariateNormal(
        loc=mean_flat,
        covariance_matrix=state.cov,
        validate_args=False,
    ).sample(sample_shape)

    samples = torch.vmap(unravel_func)(samples)
    return samples
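Taken together, the `update` step above implements the following closed-form Gaussian update. This is simply a restatement of the code in the docstring's notation, with $Q$ the transition covariance and $F(μ)$ the empirical Fisher term as returned by `posteriors.utils.empirical_fisher`:

$$
Σ_\text{pred} = Σ + Q, \qquad
Σ' = \left(Σ_\text{pred}^{-1} - ε F(μ)\right)^{-1}, \qquad
μ' = μ + ε\, Σ'\, g(μ)
$$

Both matrix inversions go through `torch.linalg.cholesky` followed by `torch.cholesky_inverse`, so the predicted covariance and the updated precision must be positive definite for the step to succeed.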
@@ -0,0 +1,63 @@
import torch
from optree import tree_map
from torch.distributions import MultivariateNormal
from optree.integration.torch import tree_ravel
from posteriors.tree_utils import tree_size
from posteriors import ekf


def test_ekf_dense():
    torch.manual_seed(42)
    target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
    num_params = tree_size(target_mean)
    A = torch.randn(num_params, num_params)
    target_cov = torch.mm(A.t(), A)

    dist = MultivariateNormal(tree_ravel(target_mean)[0], covariance_matrix=target_cov)

    def log_prob(p, b):
        return dist.log_prob(tree_ravel(p)[0]).sum(), torch.Tensor([])

    init_mean = tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)
    batch = torch.arange(3).reshape(-1, 1)
    n_steps = 1000
    transform = ekf.dense_fisher.build(log_prob, lr=1e-1)

    # Test inplace = False
    state = transform.init(init_mean)
    log_liks = []
    for _ in range(n_steps):
        state = transform.update(state, batch, inplace=False)
        log_liks.append(state.log_likelihood.item())

    assert log_liks[0] < log_liks[-1]

    for key in state.params:
        assert torch.allclose(state.params[key], target_mean[key], atol=1e-1)
        assert not torch.allclose(state.params[key], init_mean[key])

    # Test inplace = True
    state = transform.init(init_mean)
    log_liks = []
    for _ in range(n_steps):
        state = transform.update(state, batch, inplace=True)
        log_liks.append(state.log_likelihood.item())

    for key in state.params:
        assert torch.allclose(state.params[key], target_mean[key], atol=1e-1)
        assert not torch.allclose(state.params[key], init_mean[key])

    # Test sample
    num_samples = 1000
    samples = ekf.dense_fisher.sample(state, (num_samples,))

    flat_samples = torch.vmap(lambda s: tree_ravel(s)[0])(samples)
    samples_cov = torch.cov(flat_samples.T)

    mean_copy = tree_map(lambda x: x.clone(), state.params)
    samples_mean = tree_map(lambda x: x.mean(dim=0), samples)

    assert torch.allclose(samples_cov, state.cov, atol=1e-1)
    for key in samples_mean:
        assert torch.allclose(samples_mean[key], state.params[key], atol=1e-1)
        assert not torch.allclose(samples_mean[key], mean_copy[key])
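Because `sample` returns parameter trees with a leading sample dimension, downstream Monte Carlo estimates can reuse the same flattening trick as the test. A small follow-on sketch, reusing `state`, `dist` and the imports from the test above (illustrative only, not part of the test suite):

```python
# Draw fresh samples from the fitted Gaussian posterior, flatten each parameter
# tree, and average the target log-density over the posterior samples.
extra_samples = ekf.dense_fisher.sample(state, (100,))
flat = torch.vmap(lambda s: tree_ravel(s)[0])(extra_samples)  # (100, num_params)
print(dist.log_prob(flat).mean())
```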