Merge pull request #73 from normal-computing/add-cg-fvp
Add support for conjugate gradient solver through callable matrix-vector products
SamDuffield authored Apr 18, 2024
2 parents cef163d + d40775a commit 8e66a50
Showing 3 changed files with 189 additions and 1 deletion.
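For orientation, a minimal sketch (not part of the commit) of driving the new solver through a callable matrix-vector product; the dense matrix A_mat and the callable mvp below are illustrative only, since cg itself only ever evaluates the callable:

```python
import torch
from posteriors import cg  # exposed at the top level by this PR

A_mat = torch.tensor([[4.0, 1.0], [1.0, 3.0]])  # symmetric positive definite
b = torch.tensor([1.0, 2.0])

def mvp(v):
    # cg only sees this callable, never the dense matrix
    return A_mat @ v

x, info = cg(mvp, b, tol=1e-6)
assert torch.allclose(x, torch.linalg.solve(A_mat, b), rtol=1e-4)
print(info["niter"], info["converged"])
```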
1 change: 1 addition & 0 deletions posteriors/__init__.py
@@ -10,6 +10,7 @@
from posteriors.utils import hvp
from posteriors.utils import fvp
from posteriors.utils import empirical_fisher
from posteriors.utils import cg
from posteriors.utils import diag_normal_log_prob
from posteriors.utils import diag_normal_sample
from posteriors.utils import tree_size
134 changes: 133 additions & 1 deletion posteriors/utils.py
@@ -1,12 +1,14 @@
from typing import Callable, Any, Tuple, Sequence
import operator
from functools import partial
import contextlib
import torch
from torch.func import grad, jvp, vjp, functional_call, jacrev
from torch.distributions import Normal
from optree import tree_map, tree_map_, tree_reduce, tree_flatten
from optree import tree_map, tree_map_, tree_reduce, tree_flatten, tree_leaves
from optree.integration.torch import tree_ravel


from posteriors.types import TensorTree, ForwardFn, Tensor


@@ -228,6 +230,136 @@ def fisher(*args, **kwargs):
    return fisher


def _vdot_real_part(x: Tensor, y: Tensor) -> float:
    """Vector dot-product guaranteed to have a real valued result despite
    possibly complex input. Thus neglects the real-imaginary cross-terms.

    Args:
        x: First tensor in the dot product.
        y: Second tensor in the dot product.

    Returns:
        The resulting dot product, a real float.
    """
    # all our uses of vdot() in CG are for computing an operator of the form
    # z^H M z
    # where M is positive definite and Hermitian, so the result is
    # real valued:
    # https://en.wikipedia.org/wiki/Definiteness_of_a_matrix#Definitions_for_complex_matrices
    real_part = torch.vdot(x.real.flatten(), y.real.flatten())
    if torch.is_complex(x) or torch.is_complex(y):
        imag_part = torch.vdot(x.imag.flatten(), y.imag.flatten())
        return real_part + imag_part
    return real_part


def _vdot_real_tree(x, y) -> TensorTree:
    return sum(tree_leaves(tree_map(_vdot_real_part, x, y)))

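As a quick sketch (not in the diff), assuming the two helpers above are in scope: the tree-level helper returns the real part of the full complex inner product, summed over all leaves of the TensorTree.

```python
import torch

x = {"a": torch.tensor([1.0 + 2.0j, 3.0 - 1.0j]), "b": torch.tensor([0.5, -2.0])}
y = {"a": torch.tensor([2.0 - 1.0j, 1.0 + 4.0j]), "b": torch.tensor([1.0, 3.0])}

# Real part of the full complex inner product, summed over all leaves
expected = sum(torch.vdot(x[k].flatten(), y[k].flatten()).real for k in x)
assert torch.allclose(_vdot_real_tree(x, y), expected)
```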

def _mul(scalar, tree) -> TensorTree:
    return tree_map(partial(operator.mul, scalar), tree)


_add = partial(tree_map, operator.add)
_sub = partial(tree_map, operator.sub)


def _identity(x):
    return x


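The helpers above lift elementary arithmetic to TensorTrees leaf by leaf; a small sketch (not in the diff), assuming the definitions above and illustrative trees u and v:

```python
import torch

u = {"w": torch.tensor([1.0, 2.0]), "b": torch.tensor(3.0)}
v = {"w": torch.tensor([0.5, 0.5]), "b": torch.tensor(1.0)}

print(_add(u, v))    # leaf-wise sum: w -> [1.5, 2.5], b -> 4.0
print(_sub(u, v))    # leaf-wise difference: w -> [0.5, 1.5], b -> 2.0
print(_mul(2.0, v))  # leaf-wise scaling: w -> [1.0, 1.0], b -> 2.0
```
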
def cg(
    A: Callable,
    b: TensorTree,
    x0: TensorTree = None,
    *,
    maxiter: int = None,
    damping: float = 0.0,
    tol: float = 1e-5,
    atol: float = 0.0,
    M: Callable = _identity,
) -> Tuple[TensorTree, Any]:
    """Use Conjugate Gradient iteration to solve ``Ax = b``.
    ``A`` is supplied as a function instead of a matrix.

    Adapted from [`jax.scipy.sparse.linalg.cg`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.sparse.linalg.cg.html).

    Args:
        A: Callable that calculates the linear map (matrix-vector
            product) ``Ax`` when called like ``A(x)``. ``A`` must represent
            a Hermitian, positive definite matrix, and must return array(s) with the
            same structure and shape as its argument.
        b: Right hand side of the linear system representing a single vector.
        x0: Starting guess for the solution. Must have the same structure as ``b``.
        maxiter: Maximum number of iterations. Iteration will stop after maxiter
            steps even if the specified tolerance has not been achieved.
        damping: Damping term added to the matrix-vector product,
            i.e. ``A(x) + damping * x``. Acts as regularization.
        tol: Tolerance for convergence.
        atol: Tolerance for convergence. ``norm(residual) <= max(tol*norm(b), atol)``.
            The behaviour will differ from SciPy unless you explicitly pass
            ``atol`` to SciPy's ``cg``.
        M: Preconditioner for ``A``.
            See [the preconditioned CG method](https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method).

    Returns:
        x: The converged solution. Has the same structure as ``b``.
        info: Dictionary with convergence information
            (keys ``error``, ``converged`` and ``niter``).
    """
    if x0 is None:
        x0 = tree_map(torch.zeros_like, b)

    if maxiter is None:
        maxiter = 10 * tree_size(b)  # copied from scipy

    tol *= torch.tensor([1.0])
    atol *= torch.tensor([1.0])

    # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg
    bs = _vdot_real_tree(b, b)
    atol2 = torch.maximum(torch.square(tol) * bs, torch.square(atol))

    def A_damped(p):
        return _add(A(p), _mul(damping, p))

    def cond_fun(value):
        _, r, gamma, _, k = value
        rs = gamma.real if M is _identity else _vdot_real_tree(r, r)
        return (rs > atol2) & (k < maxiter)

    def body_fun(value):
        x, r, gamma, p, k = value
        Ap = A_damped(p)
        alpha = gamma / _vdot_real_tree(p, Ap)
        x_ = _add(x, _mul(alpha, p))
        r_ = _sub(r, _mul(alpha, Ap))
        z_ = M(r_)
        gamma_ = _vdot_real_tree(r_, z_)
        beta_ = gamma_ / gamma
        p_ = _add(z_, _mul(beta_, p))
        return x_, r_, gamma_, p_, k + 1

    r0 = _sub(b, A_damped(x0))
    p0 = z0 = r0
    gamma0 = _vdot_real_tree(r0, z0)
    initial_value = (x0, r0, gamma0, p0, 0)

    value = initial_value

    while cond_fun(value):
        value = body_fun(value)

    x_final, r, gamma, _, k = value

    # compute the final error and whether it has converged.
    rs = gamma if M is _identity else _vdot_real_tree(r, r)
    converged = rs <= atol2

    # additional info output structure
    info = {"error": rs, "converged": converged, "niter": k}

    return x_final, info

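None of the tests below exercise the M argument; a hedged sketch (not part of the commit) of supplying a Jacobi (diagonal) preconditioner, with A_mat, mvp and jacobi as illustrative names:

```python
import torch
from posteriors import cg

diag = torch.logspace(0, 3, 20, dtype=torch.float64)
A_mat = torch.diag(diag) + 0.01  # symmetric positive definite, badly conditioned
b = torch.ones(20, dtype=torch.float64)

def mvp(v):
    return A_mat @ v

def jacobi(r):
    # Crude preconditioner: apply the inverse of A's diagonal
    return r / torch.diagonal(A_mat)

x_plain, info_plain = cg(mvp, b, tol=1e-10)
x_prec, info_prec = cg(mvp, b, M=jacobi, tol=1e-10)

assert torch.allclose(x_prec, torch.linalg.solve(A_mat, b), rtol=1e-6)
# The preconditioned run typically reaches the tolerance in fewer iterations
print(info_plain["niter"], info_prec["niter"])
```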

def diag_normal_log_prob(
    x: TensorTree,
    mean: float | TensorTree = 0.0,
55 changes: 55 additions & 0 deletions tests/test_utils.py
@@ -10,6 +10,7 @@
hvp,
fvp,
empirical_fisher,
cg,
diag_normal_log_prob,
diag_normal_sample,
tree_size,
@@ -260,6 +261,60 @@ def f_aux(params, batch):
    assert torch.allclose(fvp_result, fisher_fvp, rtol=1e-5)


def test_cg():
    # simple function with tensor parameters
    def func(x):
        return torch.stack([(x**5).sum(), (x**3).sum()])

    def partial_fvp(v):
        return fvp(func, (x,), (v,), normalize=False)[1]

    x = torch.arange(1.0, 6.0)
    v = torch.ones_like(x)

    jac = torch.func.jacrev(func)(x)
    fisher = jac.T @ jac
    damping = 100

    sol = torch.linalg.solve(fisher + damping * torch.eye(fisher.shape[0]), v)
    sol_cg, _ = cg(partial_fvp, v, x0=None, damping=damping, maxiter=10000, tol=1e-10)
    assert torch.allclose(sol, sol_cg, rtol=1e-3)

    # simple complex number example
    A = torch.tensor([[0, -1j], [1j, 0]])

    def mvp(x):
        return A @ x

    b = torch.randn(2, dtype=torch.cfloat)

    sol = torch.linalg.solve(A, b)
    sol_cg, _ = cg(mvp, b, x0=None, tol=1e-10)

    assert torch.allclose(sol, sol_cg, rtol=1e-1)

    # function with parameters as a TensorTree
    model = TestModel()

    func_model = model_to_function(model)
    f_per_sample = torch.vmap(func_model, in_dims=(None, 0))

    xs = torch.randn(100, 10)

    def partial_fvp(v):
        return fvp(lambda p: func_model(p, xs), (params,), (v,), normalize=False)[1]

    params = dict(model.named_parameters())
    fisher = empirical_fisher(lambda p: f_per_sample(p, xs), normalize=False)(params)
    damping = 0

    v, _ = tree_ravel(params)
    sol = torch.linalg.solve(fisher + damping * torch.eye(fisher.shape[0]), v)
    sol_cg, _ = cg(partial_fvp, params, x0=None, damping=damping, tol=1e-10)

    assert torch.allclose(sol, tree_ravel(sol_cg)[0], rtol=1e-3)


def test_diag_normal_log_prob():
    # Test tree mean and tree sd
    mean = {"a": torch.tensor([1.0, 2.0]), "b": torch.tensor([3.0, 4.0])}
