Commit ed3c48b

Removed positive constraint on 'cholesky' diagonal, renamed to L_factor, updated documentation and variable and function names accordingly.

jcqcai committed Oct 26, 2024
1 parent f9282ac commit ed3c48b

Showing 5 changed files with 68 additions and 93 deletions.
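For orientation, a minimal standalone sketch (not part of the commit) of the reparameterization this change adopts: the dense variational covariance is built as cov = L Lᵀ from an unconstrained lower triangular factor, so the previous log/exp transform that kept the diagonal positive is unnecessary.

import torch

# Unconstrained lower triangular factor: its diagonal may be negative.
L = torch.tril(torch.randn(3, 3))

# cov = L @ L.T is symmetric positive semi-definite for any such L,
# so no positivity constraint on diag(L) is required.
cov = L @ L.T
assert torch.allclose(cov, cov.T)
assert (torch.linalg.eigvalsh(cov) >= -1e-6).all()

The trade-off is that L is no longer a canonical Cholesky factor (whose diagonal would be positive), but the optimizer can move freely through sign changes without a log-space detour.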
4 changes: 2 additions & 2 deletions posteriors/__init__.py

@@ -18,8 +18,8 @@
 from posteriors.utils import diag_normal_sample
 from posteriors.utils import per_samplify
 from posteriors.utils import is_scalar
-from posteriors.utils import cholesky_from_log_flat
-from posteriors.utils import cholesky_to_log_flat
+from posteriors.utils import L_from_flat
+from posteriors.utils import L_to_flat

 from posteriors.tree_utils import tree_size
 from posteriors.tree_utils import tree_extract
49 changes: 16 additions & 33 deletions posteriors/utils.py

@@ -892,51 +892,34 @@ def is_scalar(x: Any) -> bool:
     return isinstance(x, (int, float)) or (torch.is_tensor(x) and x.numel() == 1)


-def cholesky_from_log_flat(
-    log_chol_flat: torch.Tensor, num_params: int
-) -> torch.Tensor:
-    """Returns Cholesky matrix from a flat representation of its nonzero elements.
-    The input is assumed to have taken the log of the diagonal elements of the original Cholesky matrix.
+def L_from_flat(L_flat: torch.Tensor, num_params: int) -> torch.Tensor:
+    """Returns lower triangular matrix from a flat representation of its nonzero elements.

     Args:
-        log_chol_flat: Flat representation of nonzero Cholesky matrix elements.
-            The corresponding diagonal elements are the logs.
-        num_params: Width of the desired Cholesky matrix.
+        L_flat: Flat representation of nonzero lower triangular matrix elements.
+        num_params: Width of the desired lower triangular matrix.

     Returns:
-        Lower triangular Cholesky matrix.
+        Lower triangular matrix.
     """

     tril_indices = torch.tril_indices(num_params, num_params)
-    chol_exp = torch.zeros((num_params, num_params), device=log_chol_flat.device)
-    chol_exp[tril_indices[0], tril_indices[1]] = log_chol_flat
-    diag_indices = torch.arange(num_params)
-    # Exponentiate diagonal elements to ensure positivity.
-    chol_exp[diag_indices, diag_indices] = torch.exp(
-        chol_exp[diag_indices, diag_indices]
-    )
-    return chol_exp
+    L = torch.zeros((num_params, num_params), device=L_flat.device)
+    L[tril_indices[0], tril_indices[1]] = L_flat
+    return L


-def cholesky_to_log_flat(chol: torch.Tensor) -> torch.Tensor:
-    """Returns flat representation of the nonzero Cholesky matrix elements.
-    The logarithm of the diagonal of the input is taken.
+def L_to_flat(L: torch.Tensor) -> torch.Tensor:
+    """Returns flat representation of the nonzero elements of a lower triangular matrix.

     Args:
-        chol: Lower triangular Cholesky matrix.
+        L: Lower triangular matrix.

     Returns:
-        Flat representation of the nonzero Cholesky matrix elements
-            with the logs of the diagonal values.
+        Flat representation of the nonzero lower triangular matrix elements.
     """
-    num_params = chol.shape[0]
-    tril_indices = torch.tril_indices(num_params, num_params)
-    chol_flat = chol[tril_indices[0], tril_indices[1]].clone()
-
-    # The positions of the diagonal elements in L_params
-    diag_positions = torch.cumsum(torch.arange(1, num_params + 1), dim=0) - 1
-    chol_flat[diag_positions] = torch.log(
-        chol[torch.arange(num_params), torch.arange(num_params)]
-    )
-
-    return chol_flat
+    num_params = L.shape[0]
+    tril_indices = torch.tril_indices(num_params, num_params)
+    L_flat = L[tril_indices[0], tril_indices[1]].clone()
+    return L_flat
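A quick usage sketch of the renamed helpers (assuming a posteriors install that includes this change): the two functions are inverses on lower triangular input.

import torch
from posteriors.utils import L_from_flat, L_to_flat

num_params = 3
L = torch.tril(torch.randn(num_params, num_params))

# Flat vector holds the num_params * (num_params + 1) // 2 nonzero entries.
flat = L_to_flat(L)
assert flat.shape == (num_params * (num_params + 1) // 2,)

# Round trip recovers the original lower triangular matrix exactly.
assert torch.equal(L_from_flat(flat, num_params), L)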
56 changes: 27 additions & 29 deletions posteriors/vi/dense.py

@@ -11,8 +11,8 @@
 from posteriors.utils import (
     is_scalar,
     CatchAuxError,
-    cholesky_from_log_flat,
-    cholesky_to_log_flat,
+    L_from_flat,
+    L_to_flat,
 )


@@ -67,9 +67,8 @@ class VIDenseState(NamedTuple):
     Attributes:
         params: Mean of the variational distribution.
         cov: Covariance matrix of the variational distribution.
-        log_chol: Flat representation of the nonzero values of the Cholesky
-            of the covariance matrix of the variational distribution. The
-            log of the diagonal is taken.
+        L_factor: Flat representation of the nonzero values of the lower
+            triangular matrix $L$ satisfying $LL^T$ = cov.
         opt_state: TorchOpt state storing optimizer data for updating the
             variational parameters.
         nelbo: Negative evidence lower bound (lower is better).

@@ -78,7 +77,7 @@
     params: TensorTree
     cov: torch.Tensor
-    log_chol: torch.Tensor
+    L_factor: torch.Tensor
     opt_state: torchopt.typing.OptState
     nelbo: torch.tensor = torch.tensor([])
     aux: Any = None

@@ -117,11 +116,11 @@ def init(
     if is_scalar(init_cov):
         init_cov = init_cov * torch.eye(num_params, requires_grad=True)

-    init_chol = torch.linalg.cholesky(init_cov)
-    init_log_chol = cholesky_to_log_flat(init_chol)
+    init_L = torch.linalg.cholesky(init_cov)
+    init_L = L_to_flat(init_L)

-    opt_state = optimizer.init([params, init_log_chol])
-    return VIDenseState(params, init_cov, init_log_chol, opt_state)
+    opt_state = optimizer.init([params, init_L])
+    return VIDenseState(params, init_cov, init_L, opt_state)


@@ -153,37 +152,37 @@ def update(
         Updated DenseVIState.
     """

-    def nelbo_log_chol(m, chol):
-        return nelbo(m, chol, batch, log_posterior, temperature, n_samples, stl)
+    def nelbo_L_factor(m, L_flat):
+        return nelbo(m, L_flat, batch, log_posterior, temperature, n_samples, stl)

     with torch.no_grad(), CatchAuxError():
         nelbo_grads, (nelbo_val, aux) = grad_and_value(
-            nelbo_log_chol, argnums=(0, 1), has_aux=True
-        )(state.params, state.log_chol)
+            nelbo_L_factor, argnums=(0, 1), has_aux=True
+        )(state.params, state.L_factor)

     updates, opt_state = optimizer.update(
         nelbo_grads,
         state.opt_state,
-        params=[state.params, state.log_chol],
+        params=[state.params, state.L_factor],
         inplace=inplace,
     )
-    mean, log_chol = torchopt.apply_updates(
-        (state.params, state.log_chol), updates, inplace=inplace
+    mean, L_factor = torchopt.apply_updates(
+        (state.params, state.L_factor), updates, inplace=inplace
     )
-    chol = cholesky_from_log_flat(log_chol, state.cov.shape[0])
-    cov = chol @ chol.T
+    L = L_from_flat(L_factor, state.cov.shape[0])
+    cov = L @ L.T

     if inplace:
         tree_insert_(state.nelbo, nelbo_val.detach())
         tree_insert_(state.cov, cov)
         return state._replace(aux=aux)

-    return VIDenseState(mean, cov, log_chol, opt_state, nelbo_val.detach(), aux)
+    return VIDenseState(mean, cov, L_factor, opt_state, nelbo_val.detach(), aux)


 def nelbo(
     mean: dict,
-    log_chol: torch.Tensor,
+    L_flat: torch.Tensor,
     batch: Any,
     log_posterior: LogProbFn,
     temperature: float = 1.0,

@@ -211,9 +210,9 @@
     Args:
         mean: Mean of the variational distribution.
-        log_chol: Flat representation of the nonzero values of the Cholesky
-            of the covariance matrix of the variational distribution. The
-            log of the diagonal is taken.
+        L_factor: Flat representation of the nonzero values of the lower
+            triangular matrix $L$ satisfying $LL^T$ = cov, where cov
+            is the covariance matrix of the variational distribution.
         batch: Input data to log_posterior.
         log_posterior: Function that takes parameters and input batch and
             returns the log posterior (which can be unnormalised).

@@ -228,8 +227,8 @@
     mean_flat, unravel_func = tree_ravel(mean)
     num_params = mean_flat.shape[0]
-    chol = cholesky_from_log_flat(log_chol, num_params)
-    cov = chol @ chol.T
+    L = L_from_flat(L_flat, num_params)
+    cov = L @ L.T
     dist = torch.distributions.MultivariateNormal(
         loc=mean_flat,
         covariance_matrix=cov,

@@ -241,9 +240,8 @@
     if stl:
         mean_flat.detach()
-        chol = log_chol.detach()
-        chol = cholesky_from_log_flat(chol, num_params)
-        cov = chol @ chol.T
+        L = L_from_flat(L_flat.detach(), num_params)
+        cov = L @ L.T
         # Redefine distribution to sample from after stl
         dist = torch.distributions.MultivariateNormal(
             loc=mean_flat,
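update() and nelbo() above both rebuild the dense covariance from the flat factor; a standalone sketch of that reconstruction step (shapes and values here are hypothetical, not taken from the commit):

import torch
from posteriors.utils import L_from_flat

num_params = 4
# Flat factor as stored in VIDenseState.L_factor (hypothetical values).
L_factor = torch.randn(num_params * (num_params + 1) // 2)

L = L_from_flat(L_factor, num_params)
cov = L @ L.T  # covariance of the variational distribution

# As in nelbo(), the covariance parameterizes a dense multivariate normal.
dist = torch.distributions.MultivariateNormal(
    loc=torch.zeros(num_params), covariance_matrix=cov
)
samples = dist.sample((5,))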
34 changes: 14 additions & 20 deletions tests/test_utils.py

@@ -19,8 +19,8 @@
     diag_normal_sample,
     per_samplify,
     is_scalar,
-    cholesky_from_log_flat,
-    cholesky_to_log_flat,
+    L_from_flat,
+    L_to_flat,
 )
 from posteriors.utils import NO_AUX_ERROR_MSG, NON_TENSOR_AUX_ERROR_MSG
 from tests.scenarios import TestModel, TestLanguageModel

@@ -807,34 +807,28 @@ def test_is_scalar():
     assert not is_scalar(torch.ones(2))


-def test_cholesky_from_log_flat():
-    exp_0_0 = torch.exp(torch.tensor(1.0)).item()
-    exp_1_1 = torch.exp(torch.tensor(2.2)).item()
-    exp_2_2 = torch.exp(torch.tensor(5.5)).item()
+def test_L_from_flat():
     expected_L = torch.Tensor(
         [
-            [exp_0_0, 0.0, 0.0],
-            [-4.1, exp_1_1, 0.0],
-            [-1.7, 4.4, exp_2_2],
+            [1.0, 0.0, 0.0],
+            [-4.1, 2.2, 0.0],
+            [-1.7, 4.4, -5.5],
         ]
     )
-    L_flat = torch.tensor([1.0, -4.1, 2.2, -1.7, 4.4, 5.5])
+    L_flat = torch.tensor([1.0, -4.1, 2.2, -1.7, 4.4, -5.5])
     num_params = expected_L.shape[0]
-    L = cholesky_from_log_flat(L_flat, num_params)
+    L = L_from_flat(L_flat, num_params)
     assert torch.allclose(expected_L, L)


-def test_cholesky_to_log_flat():
-    exp_0_0 = torch.exp(torch.tensor(1.0)).item()
-    exp_1_1 = torch.exp(torch.tensor(2.2)).item()
-    exp_2_2 = torch.exp(torch.tensor(5.5)).item()
+def test_L_to_flat():
     L = torch.Tensor(
         [
-            [exp_0_0, 0.0, 0.0],
-            [-4.1, exp_1_1, 0.0],
-            [-1.7, 4.4, exp_2_2],
+            [1.0, 0.0, 0.0],
+            [-4.1, 2.2, 0.0],
+            [-1.7, 4.4, -5.5],
         ]
     )
-    expected_L_flat = torch.tensor([1.0, -4.1, 2.2, -1.7, 4.4, 5.5])
-    L_flat = cholesky_to_log_flat(L)
+    expected_L_flat = torch.tensor([1.0, -4.1, 2.2, -1.7, 4.4, -5.5])
+    L_flat = L_to_flat(L)
     assert torch.allclose(expected_L_flat, L_flat)
18 changes: 9 additions & 9 deletions tests/vi/test_dense.py

@@ -6,13 +6,13 @@

 from posteriors import vi
 from posteriors.tree_utils import tree_size
-from posteriors.utils import cholesky_to_log_flat
+from posteriors.utils import L_to_flat


 def test_nelbo():
     target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
     num_params = tree_size(target_mean)
-    L = torch.abs(torch.randn(num_params, num_params))
+    L = torch.randn(num_params, num_params)
     L = torch.tril(L)
     target_cov = L @ L.T

@@ -26,18 +26,18 @@ def log_prob(p, b):
     batch = torch.arange(10).reshape(-1, 1)
     target_nelbo_100, _ = vi.dense.nelbo(
         target_mean,
-        cholesky_to_log_flat(L),
+        L_to_flat(L),
         batch,
         log_prob,
         n_samples=100,
     )
     assert torch.isclose(target_nelbo_100, torch.tensor(0.0), atol=1e-6)

     bad_mean = tree_map(lambda x: torch.zeros_like(x), target_mean)
-    bad_chol = torch.tril(torch.eye(num_params))
+    bad_L = torch.tril(torch.eye(num_params))

     bad_nelbo_100, _ = vi.dense.nelbo(
-        bad_mean, cholesky_to_log_flat(bad_chol), batch, log_prob, n_samples=100
+        bad_mean, L_to_flat(bad_L), batch, log_prob, n_samples=100
     )
     assert bad_nelbo_100 > target_nelbo_100

@@ -46,7 +46,7 @@ def _test_vi_dense(optimizer_cls, stl):
     torch.manual_seed(43)
     target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
     num_params = tree_size(target_mean)
-    L = torch.abs(torch.randn(num_params, num_params, requires_grad=True))
+    L = torch.randn(num_params, num_params, requires_grad=True)
     L = torch.tril(L)
     target_cov = L @ L.T

@@ -65,21 +65,21 @@ def log_prob(p, b):

     state = vi.dense.init(init_mean, optimizer)

-    init_log_chol = state.log_chol
+    init_L_factor = state.L_factor

     batch = torch.arange(3).reshape(-1, 1)

     nelbo_init, _ = vi.dense.nelbo(
         state.params,
-        init_log_chol,
+        init_L_factor,
         batch,
         log_prob,
         n_samples=n_vi_samps_large,
     )

     nelbo_target, _ = vi.dense.nelbo(
         target_mean,
-        cholesky_to_log_flat(L),
+        L_to_flat(L),
         batch,
         log_prob,
         n_samples=n_vi_samps_large,
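A side note on why the tests could drop torch.abs: with the log transform gone, the sign of the factor's diagonal no longer matters, since flipping the sign of any column of L leaves L Lᵀ unchanged. A small sanity check (not from the commit):

import torch

num_params = 3
L = torch.tril(torch.randn(num_params, num_params))

# Flipping a column's sign multiplies L by an orthogonal diagonal
# matrix F, and (L F)(L F)^T = L F F^T L^T = L L^T.
F = torch.diag(torch.tensor([1.0, -1.0, 1.0]))
assert torch.allclose(L @ L.T, (L @ F) @ (L @ F).T)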
