From d5b27a95ad7229f65121619b9c6ab434966510f6 Mon Sep 17 00:00:00 2001 From: jcqcai Date: Mon, 27 May 2024 03:52:31 -0700 Subject: [PATCH 1/7] Added dense fisher EKF! --- docs/api/ekf/dense_fisher.md | 7 ++ mkdocs.yml | 1 + posteriors/ekf/__init__.py | 1 + posteriors/ekf/dense_fisher.py | 198 +++++++++++++++++++++++++++++++++ tests/ekf/test_dense_fisher.py | 73 ++++++++++++ 5 files changed, 280 insertions(+) create mode 100644 docs/api/ekf/dense_fisher.md create mode 100644 posteriors/ekf/dense_fisher.py create mode 100644 tests/ekf/test_dense_fisher.py diff --git a/docs/api/ekf/dense_fisher.md b/docs/api/ekf/dense_fisher.md new file mode 100644 index 00000000..4f6d63d2 --- /dev/null +++ b/docs/api/ekf/dense_fisher.md @@ -0,0 +1,7 @@ +--- +title: EKF Dense Fisher +--- + +# EKF Dense Fisher + +::: posteriors.ekf.dense_fisher \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index e28518d5..d2d4e575 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -69,6 +69,7 @@ nav: - api/index.md - EKF: - Diagonal Fisher: api/ekf/diag_fisher.md + - Dense Fisher: api/ekf/dense_fisher.md - Laplace: - Dense Fisher: api/laplace/dense_fisher.md - Dense GGN: api/laplace/dense_ggn.md diff --git a/posteriors/ekf/__init__.py b/posteriors/ekf/__init__.py index 248a49df..89cea23a 100644 --- a/posteriors/ekf/__init__.py +++ b/posteriors/ekf/__init__.py @@ -1 +1,2 @@ from posteriors.ekf import diag_fisher +from posteriors.ekf import dense_fisher diff --git a/posteriors/ekf/dense_fisher.py b/posteriors/ekf/dense_fisher.py new file mode 100644 index 00000000..4360fe26 --- /dev/null +++ b/posteriors/ekf/dense_fisher.py @@ -0,0 +1,198 @@ +from typing import Any +from functools import partial +import torch +from torch.func import jacrev +from optree import tree_map +from dataclasses import dataclass +from optree.integration.torch import tree_ravel + +from posteriors.tree_utils import tree_size + +from posteriors.types import TensorTree, Transform, LogProbFn, TransformState +from posteriors.utils import ( + per_samplify, + empirical_fisher, + is_scalar, + CatchAuxError, +) + + +def build( + log_likelihood: LogProbFn, + lr: float, + transition_cov: torch.Tensor | float = 0.0, + per_sample: bool = False, + init_cov: torch.Tensor | float = 1.0, +) -> Transform: + """Builds a transform to implement an extended Kalman Filter update. + + EKF applies an online update to a Gaussian posterior over the parameters. + + The approximate Bayesian update is based on the linearization + $$ + \\log p(θ | y) ≈ \\log p(θ) + ε g(μ)ᵀ(θ - μ) + \\frac12 ε (θ - μ)^T F(μ) (θ - μ) + $$ + where $μ$ is the mean of the prior distribution, $ε$ is the learning rate + (or equivalently the likelihood inverse temperature), + $g(μ)$ is the gradient of the log likelihood at μ and $F(μ)$ is the + empirical Fisher information matrix at $μ$ for data $y$. + + For more information on extended Kalman filtering as well as an equivalence + to (online) natural gradient descent see [Ollivier, 2019](https://arxiv.org/abs/1703.00209). + + Args: + log_likelihood: Function that takes parameters and input batch and + returns the log-likelihood value as well as auxiliary information, + e.g. from the model call. + lr: Inverse temperature of the update, which behaves like a learning rate. + transition_cov: Covariance of the transition noise, to additively + inflate the covariance before the update. + per_sample: If True, then log_likelihood is assumed to return a vector of + log likelihoods for each sample in the batch. 
If False, then log_likelihood + is assumed to return a scalar log likelihood for the whole batch, in this + case torch.func.vmap will be called, this is typically slower than + directly writing log_likelihood to be per sample. + init_cov: Initial covariance of the Normal distribution. Can be torch.Tensor or scalar. + + Returns: + EKF transform instance. + """ + init_fn = partial(init, init_cov=init_cov) + update_fn = partial( + update, + log_likelihood=log_likelihood, + lr=lr, + transition_cov=transition_cov, + per_sample=per_sample, + ) + return Transform(init_fn, update_fn) + + +@dataclass +class EKFDenseState(TransformState): + """State encoding a Normal distribution over parameters. + + Args: + params: Mean of the Normal distribution. + cov: Covariance matrix of the + Normal distribution. + log_likelihood: Log likelihood of the data given the parameters. + aux: Auxiliary information from the log_likelihood call. + """ + + params: TensorTree + cov: torch.Tensor + log_likelihood: float = 0 + aux: Any = None + + +def init( + params: TensorTree, + init_cov: TensorTree | float = 1.0, +) -> EKFDenseState: + """Initialise Multivariate Normal distribution over parameters. + + Args: + params: Initial mean of the Normal distribution. + init_cov: Initial covariance matrix + of the Multivariate Normal distribution. Can be tree like params or scalar. + + Returns: + Initial EKFDenseState. + """ + if is_scalar(init_cov): + num_params = tree_size(params) + init_cov = init_cov * torch.eye(num_params, requires_grad=False) + + return EKFDenseState(params, init_cov) + + +def update( + state: EKFDenseState, + batch: Any, + log_likelihood: LogProbFn, + lr: float, + transition_cov: torch.Tensor | float = 0.0, + per_sample: bool = False, + inplace: bool = False, +) -> EKFDenseState: + """Applies an extended Kalman Filter update to the Multivariate Normal distribution. + The approximate Bayesian update is based on the linearization + $$ + \\log p(θ | y) ≈ \\log p(θ) + ε g(μ)ᵀ(θ - μ) + \\frac12 ε (θ - μ)^T F(μ) (θ - μ) + $$ + where $μ$ is the mean of the prior distribution, $ε$ is the learning rate + (or equivalently the likelihood inverse temperature), + $g(μ)$ is the gradient of the log likelihood at μ and $F(μ)$ is the + empirical Fisher information matrix at $μ$ for data $y$. + + Args: + state: Current state. + batch: Input data to log_likelihood. + log_likelihood: Function that takes parameters and input batch and + returns the log-likelihood value as well as auxiliary information, + e.g. from the model call. + lr: Inverse temperature of the update, which behaves like a learning rate. + transition_cov: Covariance of the transition noise, to additively + inflate the covariance before the update. + per_sample: If True, then log_likelihood is assumed to return a vector of + log likelihoods for each sample in the batch. If False, then log_likelihood + is assumed to return a scalar log likelihood for the whole batch, in this + case torch.func.vmap will be called, this is typically slower than + directly writing log_likelihood to be per sample. + inplace: Whether to update the state parameters in-place. + + Returns: + Updated EKFDenseState. 
+ """ + if not per_sample: + log_likelihood = per_samplify(log_likelihood) + + with torch.no_grad(), CatchAuxError(): + log_liks, aux = log_likelihood(state.params, batch) + jac, _ = jacrev(log_likelihood, has_aux=True)(state.params, batch) + grad = tree_map(lambda x: x.mean(0), jac) + fisher, _ = empirical_fisher( + lambda p: log_likelihood(p, batch), has_aux=True, normalize=True + )(state.params) + + predict_cov = state.cov + transition_cov + predict_cov_inv = torch.cholesky_inverse(torch.linalg.cholesky(predict_cov)) + update_cov_inv = predict_cov_inv - lr * fisher + update_cov = torch.cholesky_inverse(torch.linalg.cholesky(update_cov_inv)) + + mu_raveled, mu_unravel_f = tree_ravel(state.params) + update_mean = mu_raveled + lr * update_cov @ tree_ravel(grad)[0] + update_mean = mu_unravel_f(update_mean) + + if inplace: + state.params = update_mean + state.cov = update_cov + state.log_likelihood = log_liks.mean().detach() + state.aux = aux + return state + return EKFDenseState(update_mean, update_cov, log_liks.mean().detach(), aux) + + +def sample( + state: EKFDenseState, sample_shape: torch.Size = torch.Size([]) +) -> TensorTree: + """Single sample from Multivariate Normal distribution over parameters. + + Args: + state: State encoding mean and covariance. + sample_shape: Shape of the desired samples. + + Returns: + Sample(s) from Multivariate Normal distribution. + """ + mean_flat, unravel_func = tree_ravel(state.params) + + samples = torch.distributions.MultivariateNormal( + loc=mean_flat, + covariance_matrix=state.cov, + validate_args=False, + ).sample(sample_shape) + + samples = torch.vmap(unravel_func)(samples) + return samples diff --git a/tests/ekf/test_dense_fisher.py b/tests/ekf/test_dense_fisher.py new file mode 100644 index 00000000..e959625d --- /dev/null +++ b/tests/ekf/test_dense_fisher.py @@ -0,0 +1,73 @@ +import torch +from optree import tree_map +from torch.distributions import MultivariateNormal +from optree.integration.torch import tree_ravel +from posteriors.tree_utils import tree_size +from posteriors import ekf + + +def test_ekf_dense(): + torch.manual_seed(42) + target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)} + num_params = tree_size(target_mean) + A = torch.randn(num_params, num_params) + target_cov = torch.mm(A.t(), A) + + dist = MultivariateNormal(tree_ravel(target_mean)[0], covariance_matrix=target_cov) + + def log_prob(p, b): + return dist.log_prob(tree_ravel(p)[0]).sum(), torch.Tensor([]) + + init_mean = tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean) + batch = torch.arange(3).reshape(-1, 1) + n_steps = 1000 + transform = ekf.dense_fisher.build(log_prob, lr=1e-1) + + # Test inplace = False + state = transform.init(init_mean) + log_liks = [] + for _ in range(n_steps): + state = transform.update(state, batch, inplace=False) + log_liks.append(state.log_likelihood.item()) + + assert log_liks[0] < log_liks[-1] + + for key in state.params: + assert torch.allclose(state.params[key], target_mean[key], atol=1e-1) + assert not torch.allclose(state.params[key], init_mean[key]) + + # Test inplace = True + state = transform.init(init_mean) + log_liks = [] + for _ in range(n_steps): + state = transform.update(state, batch, inplace=True) + log_liks.append(state.log_likelihood.item()) + + for key in state.params: + assert torch.allclose(state.params[key], target_mean[key], atol=1e-1) + assert not torch.allclose(state.params[key], init_mean[key]) + + # Test sample + num_samples = 1000 + samples = ekf.dense_fisher.sample(state, 
(num_samples,)) + + # Reshaping the samples after raveling them does not return the correct order. This is because, + # on raveling, all the samples under "a" are added to the list left to right, only then followed by + # the samples under parameter "b". Reshaping therefore has whole rows of "a" samples before "b" samples begin. + # Really, we want each 2D "a" sample followed by a 1D "b" sample, each 3-column row containing a sample of each param in this way. + # Therefore, we have to carefully column stack the result of raveling each parameter and reshaping them. + # This is all to aid in computing the sample covariance. + k = list(samples.keys())[0] # get first sample to start column stacking + samples_copy = tree_ravel(samples[k])[0].reshape(num_samples, samples[k].shape[1]) + for k in list(samples.keys())[1:]: + v = tree_ravel(samples[k])[0].reshape(num_samples, samples[k].shape[1]) + samples_copy = torch.column_stack((samples_copy, v)) + samples_cov = torch.cov(samples_copy.T) + + mean_copy = tree_map(lambda x: x.clone(), state.params) + samples_mean = tree_map(lambda x: x.mean(dim=0), samples) + + assert torch.allclose(samples_cov, state.cov, atol=1e-1) + for key in samples_mean: + assert torch.allclose(samples_mean[key], state.params[key], atol=1e-1) + assert not torch.allclose(samples_mean[key], mean_copy[key]) From 8064349a20d90ab0149e3a48bb5071912dcea48c Mon Sep 17 00:00:00 2001 From: jcqcai Date: Mon, 27 May 2024 23:16:34 -0700 Subject: [PATCH 2/7] Changed init covariance matrix to be tensor, not tensortree. --- posteriors/ekf/dense_fisher.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/posteriors/ekf/dense_fisher.py b/posteriors/ekf/dense_fisher.py index 4360fe26..acf68338 100644 --- a/posteriors/ekf/dense_fisher.py +++ b/posteriors/ekf/dense_fisher.py @@ -88,14 +88,14 @@ class EKFDenseState(TransformState): def init( params: TensorTree, - init_cov: TensorTree | float = 1.0, + init_cov: torch.Tensor | float = 1.0, ) -> EKFDenseState: """Initialise Multivariate Normal distribution over parameters. Args: params: Initial mean of the Normal distribution. - init_cov: Initial covariance matrix - of the Multivariate Normal distribution. Can be tree like params or scalar. + init_cov: Initial covariance matrix of the Multivariate Normal distribution. + If it is a float, it is defined as an identity matrix scaled by that float. Returns: Initial EKFDenseState. From e6433ca5a395836ffe89f88cf4144d9bba338159 Mon Sep 17 00:00:00 2001 From: jcqcai Date: Mon, 27 May 2024 23:20:48 -0700 Subject: [PATCH 3/7] Small documentation fix --- posteriors/ekf/dense_fisher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/posteriors/ekf/dense_fisher.py b/posteriors/ekf/dense_fisher.py index acf68338..dfb5202f 100644 --- a/posteriors/ekf/dense_fisher.py +++ b/posteriors/ekf/dense_fisher.py @@ -95,7 +95,7 @@ def init( Args: params: Initial mean of the Normal distribution. init_cov: Initial covariance matrix of the Multivariate Normal distribution. - If it is a float, it is defined as an identity matrix scaled by that float. + If it is a float, it is defined as an identity matrix scaled by that float. Returns: Initial EKFDenseState. 
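For context on how the new transform composes, here is a minimal usage sketch of the API added by these patches (`build`, `init`, `update` and `sample` are from the patch itself; the regression model, data and hyperparameters below are purely illustrative and not part of the PR):

```python
import torch
import posteriors

# Hypothetical model: linear regression with unit observation noise.
# Returns a scalar log-likelihood for the whole batch plus auxiliary info,
# matching the default per_sample=False convention in build().
def log_likelihood(params, batch):
    x, y = batch
    preds = x @ params["weights"] + params["bias"]
    log_lik = torch.distributions.Normal(preds, 1.0).log_prob(y).sum()
    return log_lik, preds

params = {"weights": torch.zeros(3), "bias": torch.zeros(1)}

transform = posteriors.ekf.dense_fisher.build(log_likelihood, lr=1e-1)
state = transform.init(params)  # dense covariance starts as the identity (init_cov=1.0)

for _ in range(10):
    x = torch.randn(16, 3)
    y = x @ torch.ones(3) + 0.1 * torch.randn(16)
    state = transform.update(state, (x, y))

# state.params is the posterior mean, state.cov the dense covariance matrix.
posterior_samples = posteriors.ekf.dense_fisher.sample(state, (100,))
```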
From 05af1bdfb4e385ef9a31a880d85a69f204ce4c08 Mon Sep 17 00:00:00 2001 From: John Crossman Date: Fri, 31 May 2024 23:50:55 -0700 Subject: [PATCH 4/7] Update posteriors/ekf/dense_fisher.py Co-authored-by: SamDuffield <34280297+SamDuffield@users.noreply.github.com> --- posteriors/ekf/dense_fisher.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/posteriors/ekf/dense_fisher.py b/posteriors/ekf/dense_fisher.py index dfb5202f..aa62a0ef 100644 --- a/posteriors/ekf/dense_fisher.py +++ b/posteriors/ekf/dense_fisher.py @@ -149,9 +149,13 @@ def update( log_likelihood = per_samplify(log_likelihood) with torch.no_grad(), CatchAuxError(): - log_liks, aux = log_likelihood(state.params, batch) - jac, _ = jacrev(log_likelihood, has_aux=True)(state.params, batch) - grad = tree_map(lambda x: x.mean(0), jac) + def log_likelihood_reduced(params, batch): + per_samp_log_lik, internal_aux = log_likelihood(params, batch) + return per_samp_log_lik.mean(), internal_aux + + grad, (log_liks, aux) = grad_and_value(log_likelihood_reduced, has_aux=True)( + state.params, batch + ) fisher, _ = empirical_fisher( lambda p: log_likelihood(p, batch), has_aux=True, normalize=True )(state.params) From 75c3f13845e09fc90c40a7d0bc1616f35af55c77 Mon Sep 17 00:00:00 2001 From: John Crossman Date: Fri, 31 May 2024 23:51:09 -0700 Subject: [PATCH 5/7] Update tests/ekf/test_dense_fisher.py Co-authored-by: SamDuffield <34280297+SamDuffield@users.noreply.github.com> --- tests/ekf/test_dense_fisher.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/tests/ekf/test_dense_fisher.py b/tests/ekf/test_dense_fisher.py index e959625d..a55e9186 100644 --- a/tests/ekf/test_dense_fisher.py +++ b/tests/ekf/test_dense_fisher.py @@ -51,18 +51,8 @@ def log_prob(p, b): num_samples = 1000 samples = ekf.dense_fisher.sample(state, (num_samples,)) - # Reshaping the samples after raveling them does not return the correct order. This is because, - # on raveling, all the samples under "a" are added to the list left to right, only then followed by - # the samples under parameter "b". Reshaping therefore has whole rows of "a" samples before "b" samples begin. - # Really, we want each 2D "a" sample followed by a 1D "b" sample, each 3-column row containing a sample of each param in this way. - # Therefore, we have to carefully column stack the result of raveling each parameter and reshaping them. - # This is all to aid in computing the sample covariance. 
- k = list(samples.keys())[0] # get first sample to start column stacking - samples_copy = tree_ravel(samples[k])[0].reshape(num_samples, samples[k].shape[1]) - for k in list(samples.keys())[1:]: - v = tree_ravel(samples[k])[0].reshape(num_samples, samples[k].shape[1]) - samples_copy = torch.column_stack((samples_copy, v)) - samples_cov = torch.cov(samples_copy.T) + flat_samples = torch.vmap(lambda s: tree_ravel(s)[0])(samples) + samples_cov = torch.cov(flat_samples.T) mean_copy = tree_map(lambda x: x.clone(), state.params) samples_mean = tree_map(lambda x: x.mean(dim=0), samples) From 7ade15d636c2742fe765ea6b0ee28beced7360bd Mon Sep 17 00:00:00 2001 From: jcqcai Date: Sat, 1 Jun 2024 00:02:35 -0700 Subject: [PATCH 6/7] Update documentation (add to index and fix alphabetical order) and fix imports --- docs/api/index.md | 4 ++++ mkdocs.yml | 2 +- posteriors/ekf/dense_fisher.py | 4 ++-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/api/index.md b/docs/api/index.md index 94b8ece2..f1549b85 100644 --- a/docs/api/index.md +++ b/docs/api/index.md @@ -1,6 +1,10 @@ # API ### Extended Kalman filter (EKF) +- [`ekf.dense_fisher`](ekf/dense_fisher.md) applies an online Bayesian update based +on a Taylor approximation of the log-likelihood. Uses the empirical Fisher +information matrix as a positive-definite alternative to the Hessian. +Natural gradient descent equivalence following [Ollivier, 2019](https://arxiv.org/abs/1703.00209). - [`ekf.diag_fisher`](ekf/diag_fisher.md) applies an online Bayesian update based on a Taylor approximation of the log-likelihood. Uses the diagonal empirical Fisher information matrix as a positive-definite alternative to the Hessian. diff --git a/mkdocs.yml b/mkdocs.yml index d2d4e575..7a5ea641 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -68,8 +68,8 @@ nav: - API: - api/index.md - EKF: - - Diagonal Fisher: api/ekf/diag_fisher.md - Dense Fisher: api/ekf/dense_fisher.md + - Diagonal Fisher: api/ekf/diag_fisher.md - Laplace: - Dense Fisher: api/laplace/dense_fisher.md - Dense GGN: api/laplace/dense_ggn.md diff --git a/posteriors/ekf/dense_fisher.py b/posteriors/ekf/dense_fisher.py index aa62a0ef..a99120b2 100644 --- a/posteriors/ekf/dense_fisher.py +++ b/posteriors/ekf/dense_fisher.py @@ -1,8 +1,7 @@ from typing import Any from functools import partial import torch -from torch.func import jacrev -from optree import tree_map +from torch.func import grad_and_value from dataclasses import dataclass from optree.integration.torch import tree_ravel @@ -149,6 +148,7 @@ def update( log_likelihood = per_samplify(log_likelihood) with torch.no_grad(), CatchAuxError(): + def log_likelihood_reduced(params, batch): per_samp_log_lik, internal_aux = log_likelihood(params, batch) return per_samp_log_lik.mean(), internal_aux From 5e6ca0ba45289c597691e40b28951a29b9ff2e28 Mon Sep 17 00:00:00 2001 From: jcqcai Date: Mon, 3 Jun 2024 01:09:58 -0700 Subject: [PATCH 7/7] Condense index documentation --- docs/api/index.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/docs/api/index.md b/docs/api/index.md index f1549b85..77bcf95a 100644 --- a/docs/api/index.md +++ b/docs/api/index.md @@ -5,10 +5,8 @@ on a Taylor approximation of the log-likelihood. Uses the empirical Fisher information matrix as a positive-definite alternative to the Hessian. Natural gradient descent equivalence following [Ollivier, 2019](https://arxiv.org/abs/1703.00209). 
-- [`ekf.diag_fisher`](ekf/diag_fisher.md) applies an online Bayesian update based
-on a Taylor approximation of the log-likelihood. Uses the diagonal empirical Fisher
-information matrix as a positive-definite alternative to the Hessian.
-Natural gradient descent equivalence following [Ollivier, 2019](https://arxiv.org/abs/1703.00209).
+- [`ekf.diag_fisher`](ekf/diag_fisher.md) same as `ekf.dense_fisher` but
+uses the diagonal of the empirical Fisher information matrix instead.
 
 ### Laplace approximation
 - [`laplace.dense_fisher`](laplace/dense_fisher.md) calculates the empirical Fisher