Add support for conjugate gradient solver through callable matrix-vector products #73

Merged · 13 commits · Apr 18, 2024
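
This PR adds a `cg` utility to `posteriors.utils` (re-exported from the top-level `posteriors` namespace) that solves ``Ax = b`` with ``A`` supplied only as a callable matrix-vector product. Below is a minimal usage sketch based on the signature in the diff; the matrix, right-hand side and tolerances are illustrative and not taken from the PR:

import torch

from posteriors import cg

# Illustrative symmetric positive definite system; in practice A would
# typically be an implicit Hessian- or Fisher-vector product.
A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])
b = torch.tensor([1.0, -1.0])


def mvp(v):
    # The solver only ever sees this matrix-vector product, never A itself.
    return A @ v


x, info = cg(mvp, b, tol=1e-6)
# info reports the final squared residual, a convergence flag and the
# number of iterations performed.
assert torch.allclose(A @ x, b, atol=1e-4)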
1 change: 1 addition & 0 deletions posteriors/__init__.py
@@ -10,6 +10,7 @@
from posteriors.utils import hvp
from posteriors.utils import fvp
from posteriors.utils import empirical_fisher
from posteriors.utils import cg
from posteriors.utils import diag_normal_log_prob
from posteriors.utils import diag_normal_sample
from posteriors.utils import tree_size
134 changes: 133 additions & 1 deletion posteriors/utils.py
@@ -1,12 +1,14 @@
from typing import Callable, Any, Tuple, Sequence
import operator
from functools import partial
import contextlib
import torch
from torch.func import grad, jvp, vjp, functional_call, jacrev
from torch.distributions import Normal
from optree import tree_map, tree_map_, tree_reduce, tree_flatten
from optree import tree_map, tree_map_, tree_reduce, tree_flatten, tree_leaves
from optree.integration.torch import tree_ravel


from posteriors.types import TensorTree, ForwardFn, Tensor


@@ -228,6 +230,136 @@ def fisher(*args, **kwargs):
    return fisher


def _vdot_real_part(x: Tensor, y: Tensor) -> float:
    """Vector dot product guaranteed to have a real-valued result despite
    possibly complex input. It therefore neglects the real-imaginary cross-terms.

    Args:
        x: First tensor in the dot product.
        y: Second tensor in the dot product.

    Returns:
        The result of the vector dot product, a real float.
    """
    # all our uses of vdot() in CG are for computing an operator of the form
    # z^H M z
    # where M is positive definite and Hermitian, so the result is
    # real valued:
    # https://en.wikipedia.org/wiki/Definiteness_of_a_matrix#Definitions_for_complex_matrices
    real_part = torch.vdot(x.real.flatten(), y.real.flatten())
    if torch.is_complex(x) or torch.is_complex(y):
        imag_part = torch.vdot(x.imag.flatten(), y.imag.flatten())
        return real_part + imag_part
    return real_part


def _vdot_real_tree(x, y) -> TensorTree:
    return sum(tree_leaves(tree_map(_vdot_real_part, x, y)))


def _mul(scalar, tree) -> TensorTree:
    return tree_map(partial(operator.mul, scalar), tree)


_add = partial(tree_map, operator.add)
_sub = partial(tree_map, operator.sub)


def _identity(x):
    return x


def cg(
    A: Callable,
    b: TensorTree,
    x0: TensorTree = None,
    *,
    maxiter: int = None,
    damping: float = 0.0,
    tol: float = 1e-5,
    atol: float = 0.0,
    M: Callable = _identity,
) -> Tuple[TensorTree, Any]:
    """Use Conjugate Gradient iteration to solve ``Ax = b``.
    ``A`` is supplied as a function instead of a matrix.

    Adapted from [`jax.scipy.sparse.linalg.cg`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.sparse.linalg.cg.html).

    Args:
        A: Callable that calculates the linear map (matrix-vector
            product) ``Ax`` when called like ``A(x)``. ``A`` must represent
            a Hermitian, positive definite matrix, and must return array(s) with the
            same structure and shape as its argument.
        b: Right hand side of the linear system representing a single vector.
        x0: Starting guess for the solution. Must have the same structure as ``b``.
        maxiter: Maximum number of iterations. Iteration will stop after maxiter
            steps even if the specified tolerance has not been achieved.
        damping: Damping term added to the matrix-vector product. Acts as regularization.
        tol: Tolerance for convergence.
        atol: Tolerance for convergence. ``norm(residual) <= max(tol*norm(b), atol)``.
            The behaviour will differ from SciPy unless you explicitly pass
            ``atol`` to SciPy's ``cg``.
        M: Preconditioner for A.
            See [the preconditioned CG method](https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method).

    Returns:
        x: The converged solution. Has the same structure as ``b``.
        info: Dictionary containing convergence information
            (``error``, ``converged``, ``niter``).
    """
    if x0 is None:
        x0 = tree_map(torch.zeros_like, b)

    if maxiter is None:
        maxiter = 10 * tree_size(b)  # copied from scipy

    tol *= torch.tensor([1.0])
    atol *= torch.tensor([1.0])

    # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg
    bs = _vdot_real_tree(b, b)
    atol2 = torch.maximum(torch.square(tol) * bs, torch.square(atol))

    def A_damped(p):
        return _add(A(p), _mul(damping, p))

    def cond_fun(value):
        _, r, gamma, _, k = value
        rs = gamma.real if M is _identity else _vdot_real_tree(r, r)
        return (rs > atol2) & (k < maxiter)

    def body_fun(value):
        x, r, gamma, p, k = value
        Ap = A_damped(p)
        alpha = gamma / _vdot_real_tree(p, Ap)
        x_ = _add(x, _mul(alpha, p))
        r_ = _sub(r, _mul(alpha, Ap))
        z_ = M(r_)
        gamma_ = _vdot_real_tree(r_, z_)
        beta_ = gamma_ / gamma
        p_ = _add(z_, _mul(beta_, p))
        return x_, r_, gamma_, p_, k + 1

    r0 = _sub(b, A_damped(x0))
    p0 = z0 = r0
    gamma0 = _vdot_real_tree(r0, z0)
    initial_value = (x0, r0, gamma0, p0, 0)

    value = initial_value

    while cond_fun(value):
        value = body_fun(value)

    x_final, r, gamma, _, k = value
    # compute the final error and whether it has converged.
    rs = gamma if M is _identity else _vdot_real_tree(r, r)
    converged = rs <= atol2

    # additional info output structure
    info = {"error": rs, "converged": converged, "niter": k}

    return x_final, info


def diag_normal_log_prob(
    x: TensorTree,
    mean: float | TensorTree = 0.0,
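The ``M`` argument documented above takes a preconditioner callable. Here is a minimal sketch of a Jacobi (diagonal) preconditioner, using an explicit matrix purely for illustration; none of these names come from the PR:

import torch

from posteriors import cg

# Illustrative SPD matrix with a poorly scaled diagonal.
A = torch.tensor([[100.0, 1.0], [1.0, 2.0]])
b = torch.tensor([1.0, 2.0])

# Jacobi preconditioner: approximate A^-1 by the inverse of its diagonal.
diag_inv = 1.0 / torch.diagonal(A)


def jacobi(r):
    return diag_inv * r


x, info = cg(lambda v: A @ v, b, M=jacobi, tol=1e-4)
assert torch.allclose(A @ x, b, atol=1e-3)

With ``M`` left at its default, the call reduces to unpreconditioned CG; a preconditioner mainly pays off on badly conditioned systems, where it can cut the number of iterations substantially.
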
55 changes: 55 additions & 0 deletions tests/test_utils.py
@@ -10,6 +10,7 @@
    hvp,
    fvp,
    empirical_fisher,
    cg,
    diag_normal_log_prob,
    diag_normal_sample,
    tree_size,
@@ -260,6 +261,60 @@ def f_aux(params, batch):
    assert torch.allclose(fvp_result, fisher_fvp, rtol=1e-5)


def test_cg():
    # simple function with tensor parameters
    def func(x):
        return torch.stack([(x**5).sum(), (x**3).sum()])

    def partial_fvp(v):
        return fvp(func, (x,), (v,), normalize=False)[1]

    x = torch.arange(1.0, 6.0)
    v = torch.ones_like(x)

    jac = torch.func.jacrev(func)(x)
    fisher = jac.T @ jac
    damping = 100

    sol = torch.linalg.solve(fisher + damping * torch.eye(fisher.shape[0]), v)
    sol_cg, _ = cg(partial_fvp, v, x0=None, damping=damping, maxiter=10000, tol=1e-10)
    assert torch.allclose(sol, sol_cg, rtol=1e-3)

    # simple complex number example
    A = torch.tensor([[0, -1j], [1j, 0]])

    def mvp(x):
        return A @ x

    b = torch.randn(2, dtype=torch.cfloat)

    sol = torch.linalg.solve(A, b)
    sol_cg, _ = cg(mvp, b, x0=None, tol=1e-10)

    assert torch.allclose(sol, sol_cg, rtol=1e-1)

    # function with parameters as a TensorTree
    model = TestModel()

    func_model = model_to_function(model)
    f_per_sample = torch.vmap(func_model, in_dims=(None, 0))

    xs = torch.randn(100, 10)

    def partial_fvp(v):
        return fvp(lambda p: func_model(p, xs), (params,), (v,), normalize=False)[1]

    params = dict(model.named_parameters())
    fisher = empirical_fisher(lambda p: f_per_sample(p, xs), normalize=False)(params)
    damping = 0

    v, _ = tree_ravel(params)
    sol = torch.linalg.solve(fisher + damping * torch.eye(fisher.shape[0]), v)
    sol_cg, _ = cg(partial_fvp, params, x0=None, damping=damping, tol=1e-10)

    assert torch.allclose(sol, tree_ravel(sol_cg)[0], rtol=1e-3)


def test_diag_normal_log_prob():
    # Test tree mean and tree sd
    mean = {"a": torch.tensor([1.0, 2.0]), "b": torch.tensor([3.0, 4.0])}
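
The ``test_cg`` additions above check the solver against a dense solve of the damped Fisher system ``(F + damping * I) x = v``. As a standalone sketch of the same damping semantics with a TensorTree right-hand side, using a hypothetical diagonal operator (none of these names are from the PR):

import torch
from optree import tree_map

from posteriors import cg

# Hypothetical diagonal operator acting leaf-wise on a TensorTree.
scales = {"w": torch.tensor([2.0, 4.0]), "b": torch.tensor([1.0])}


def mvp(tree):
    return tree_map(torch.mul, scales, tree)


rhs = {"w": torch.tensor([1.0, 1.0]), "b": torch.tensor([3.0])}
damping = 0.5

# cg solves (A + damping * I) x = rhs, so each leaf of the solution
# should equal rhs / (scales + damping).
x, info = cg(mvp, rhs, damping=damping, tol=1e-6)
expected = tree_map(lambda r, s: r / (s + damping), rhs, scales)
assert torch.allclose(x["w"], expected["w"], atol=1e-5)
assert torch.allclose(x["b"], expected["b"], atol=1e-5)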