Change TransformState to NamedTuple #106

Merged · 4 commits · Jul 24, 2024
2 changes: 1 addition & 1 deletion docs/getting_started.md
@@ -64,7 +64,7 @@ Here:
- `build` is a function that loads `config_args` into the `init` and `update` functions
and stores them within the `transform` instance. The `init` and `update`
functions then conform to a preset signature allowing for easy switching between algorithms.
-- `state` is a [`dataclass`](https://docs.python.org/3/library/dataclasses.html)
+- `state` is a [`NamedTuple`](https://docs.python.org/3/library/typing.html#typing.NamedTuple)
encoding the state of the algorithm, including `params` and `aux` attributes.
- `init` constructs the iteration-varying `state` based on the model parameters `params`.
- `update` updates the `state` based on a new `batch` of data.
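
For orientation, the loop this API implies looks roughly like the following sketch (the `ekf.diag_fisher` transform and its `build` arguments are illustrative placeholders, as are `log_likelihood`, `params`, and `batches`):

```python
import posteriors

# build() bakes config_args into the init and update functions
transform = posteriors.ekf.diag_fisher.build(log_likelihood, lr=1e-2)

# Construct the iteration-varying state from the model parameters
state = transform.init(params)

# Update the state on successive batches of data
for batch in batches:
    state = transform.update(state, batch)
```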
23 changes: 23 additions & 0 deletions docs/gotchas.md
@@ -70,6 +70,29 @@ state2 = transform.update(state, batch, inplace=True)
# state is updated and state2 is a pointer to state
```

When adding a new algorithm, in-place support can be achieved by modifying `TensorTree`s
via the [`flexi_tree_map`](https://normal-computing.github.io/posteriors/api/tree_utils/#posteriors.tree_utils.flexi_tree_map) function:

```python
from posteriors.tree_utils import flexi_tree_map

new_state = flexi_tree_map(lambda x: x + 1, state, inplace=True)
```

Because `posteriors` transform states are immutable `NamedTuple`s, in-place modification of
`TensorTree` leaves is instead achieved by writing to the tensor data directly with [`tree_insert_`](https://normal-computing.github.io/posteriors/api/tree_utils/#posteriors.tree_utils.tree_insert_):

```python
from posteriors.tree_utils import tree_insert_

tree_insert_(state.log_posterior, log_post.detach())
```

However, the `aux` component of the `TransformState` is not guaranteed to be a `TensorTree`,
and so in-place modification of `aux` is not supported. Instead, `state._replace(aux=aux)`
returns a new state whose `TensorTree` leaves point to the same memory as the input `state`,
but with a new `aux` component (the `aux` of the input `state` object is left unmodified).
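
A short sketch of that pattern (here `aux` stands for whatever auxiliary object the latest `log_posterior` call produced):

```python
new_state = state._replace(aux=aux)

# The TensorTree fields of new_state share memory with state;
# only the aux field differs between the two states.
assert new_state.params is state.params
```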


## `torch.tensor` with autograd

3 changes: 1 addition & 2 deletions docs/tutorials/lightning_autoencoder.md
@@ -16,7 +16,6 @@ from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L
import torchopt
-from dataclasses import asdict

import posteriors

@@ -100,7 +99,7 @@ class LitAutoEncoderUQ(L.LightningModule):
# it is independent of forward
self.state = self.transform.update(self.state, batch, inplace=True)
# Logging to TensorBoard (if installed) by default
-for k, v in asdict(self.state).items():
+for k, v in self.state._asdict().items():
if isinstance(v, float) or (isinstance(v, torch.Tensor) and v.numel() == 1):
self.log(k, v)

3 changes: 1 addition & 2 deletions examples/lightning_autoencoder.py
@@ -5,7 +5,6 @@
from torchvision.transforms import ToTensor
import lightning as L
import torchopt
-from dataclasses import asdict

import posteriors

@@ -54,7 +53,7 @@ def training_step(self, batch, batch_idx):
# it is independent of forward
self.state = self.transform.update(self.state, batch, inplace=True)
# Logging to TensorBoard (if installed) by default
-for k, v in asdict(self.state).items():
+for k, v in self.state._asdict().items():
if isinstance(v, float) or (isinstance(v, torch.Tensor) and v.numel() == 1):
self.log(k, v)

24 changes: 11 additions & 13 deletions posteriors/ekf/dense_fisher.py
@@ -1,13 +1,12 @@
-from typing import Any
+from typing import Any, NamedTuple
from functools import partial
import torch
from torch.func import grad_and_value
-from dataclasses import dataclass
from optree.integration.torch import tree_ravel

-from posteriors.tree_utils import tree_size
+from posteriors.tree_utils import tree_size, tree_insert_

-from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
+from posteriors.types import TensorTree, Transform, LogProbFn
from posteriors.utils import (
per_samplify,
empirical_fisher,
@@ -67,11 +66,10 @@ def build(
return Transform(init_fn, update_fn)


-@dataclass
-class EKFDenseState(TransformState):
+class EKFDenseState(NamedTuple):
"""State encoding a Normal distribution over parameters.

-Args:
+Attributes:
params: Mean of the Normal distribution.
cov: Covariance matrix of the
Normal distribution.
@@ -81,7 +79,7 @@ class EKFDenseState(TransformState):

params: TensorTree
cov: torch.Tensor
-log_likelihood: float = 0
+log_likelihood: torch.Tensor = torch.tensor([])
aux: Any = None


@@ -170,11 +168,11 @@ def log_likelihood_reduced(params, batch):
update_mean = mu_unravel_f(update_mean)

if inplace:
-state.params = update_mean
-state.cov = update_cov
-state.log_likelihood = log_liks.mean().detach()
-state.aux = aux
-return state
+tree_insert_(state.params, update_mean)
+tree_insert_(state.cov, update_cov)
+tree_insert_(state.log_likelihood, log_liks.mean().detach())
+return state._replace(aux=aux)

return EKFDenseState(update_mean, update_cov, log_liks.mean().detach(), aux)


20 changes: 9 additions & 11 deletions posteriors/ekf/diag_fisher.py
@@ -1,12 +1,11 @@
-from typing import Any
+from typing import Any, NamedTuple
from functools import partial
import torch
from torch.func import jacrev
from optree import tree_map
-from dataclasses import dataclass

-from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
-from posteriors.tree_utils import flexi_tree_map
+from posteriors.types import TensorTree, Transform, LogProbFn
+from posteriors.tree_utils import flexi_tree_map, tree_insert_
from posteriors.utils import (
diag_normal_sample,
per_samplify,
@@ -68,11 +67,10 @@ def build(
return Transform(init_fn, update_fn)


-@dataclass
-class EKFDiagState(TransformState):
+class EKFDiagState(NamedTuple):
"""State encoding a diagonal Normal distribution over parameters.

-Args:
+Attributes:
params: Mean of the Normal distribution.
sd_diag: Square-root diagonal of the covariance matrix of the
Normal distribution.
@@ -82,7 +80,7 @@ class EKFDiagState(TransformState):

params: TensorTree
sd_diag: TensorTree
-log_likelihood: float = 0
+log_likelihood: torch.Tensor = torch.tensor([])
aux: Any = None


@@ -176,9 +174,9 @@ def update(
)

if inplace:
-state.log_likelihood = log_liks.mean().detach()
-state.aux = aux
-return state
+tree_insert_(state.log_likelihood, log_liks.mean().detach())
+return state._replace(aux=aux)

return EKFDiagState(update_mean, update_sd_diag, log_liks.mean().detach(), aux)


15 changes: 6 additions & 9 deletions posteriors/laplace/dense_fisher.py
@@ -1,11 +1,10 @@
-from typing import Any
-from dataclasses import dataclass
+from typing import Any, NamedTuple
from functools import partial
import torch
from optree import tree_map
from optree.integration.torch import tree_ravel

-from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
+from posteriors.types import TensorTree, Transform, LogProbFn
from posteriors.tree_utils import tree_size
from posteriors.utils import (
per_samplify,
@@ -55,12 +54,11 @@ def build(
return Transform(init_fn, update_fn)


-@dataclass
-class DenseLaplaceState(TransformState):
+class DenseLaplaceState(NamedTuple):
"""State encoding a Normal distribution over parameters,
with a dense precision matrix

-Args:
+Attributes:
params: Mean of the Normal distribution.
prec: Precision matrix of the Normal distribution.
aux: Auxiliary information from the log_posterior call.
@@ -130,9 +128,8 @@ def update(
)(state.params)

if inplace:
-state.prec += fisher
-state.aux = aux
-return state
+state.prec.data += fisher
+return state._replace(aux=aux)
else:
return DenseLaplaceState(state.params, state.prec + fisher, aux)

14 changes: 5 additions & 9 deletions posteriors/laplace/dense_ggn.py
@@ -1,16 +1,14 @@
from functools import partial
-from typing import Any
+from typing import Any, NamedTuple
import torch
from optree import tree_map
-from dataclasses import dataclass
from optree.integration.torch import tree_ravel

from posteriors.types import (
TensorTree,
Transform,
ForwardFn,
OuterLogProbFn,
-TransformState,
)
from posteriors.utils import (
tree_size,
@@ -67,12 +65,11 @@ def build(
return Transform(init_fn, update_fn)


-@dataclass
-class DenseLaplaceState(TransformState):
+class DenseLaplaceState(NamedTuple):
"""State encoding a Normal distribution over parameters,
with a dense precision matrix

-Args:
+Attributes:
params: Mean of the Normal distribution.
prec: Precision matrix of the Normal distribution.
aux: Auxiliary information from the log_posterior call.
@@ -145,9 +142,8 @@ def outer_loss(z, batch):
)(state.params)

if inplace:
-state.prec += ggn_batch
-state.aux = aux
-return state
+state.prec.data += ggn_batch
+return state._replace(aux=aux)
else:
return DenseLaplaceState(state.params, state.prec + ggn_batch, aux)

13 changes: 5 additions & 8 deletions posteriors/laplace/diag_fisher.py
@@ -1,11 +1,10 @@
from functools import partial
-from typing import Any
+from typing import Any, NamedTuple
import torch
from torch.func import jacrev
from optree import tree_map
-from dataclasses import dataclass

-from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
+from posteriors.types import TensorTree, Transform, LogProbFn
from posteriors.tree_utils import flexi_tree_map
from posteriors.utils import (
diag_normal_sample,
@@ -54,11 +53,10 @@ def build(
return Transform(init_fn, update_fn)


-@dataclass
-class DiagLaplaceState(TransformState):
+class DiagLaplaceState(NamedTuple):
"""State encoding a diagonal Normal distribution over parameters.

-Args:
+Attributes:
params: Mean of the Normal distribution.
prec_diag: Diagonal of the precision matrix of the Normal distribution.
aux: Auxiliary information from the log_posterior call.
@@ -134,8 +132,7 @@ def update_func(x, y):
)

if inplace:
-state.aux = aux
-return state
+return state._replace(aux=aux)
return DiagLaplaceState(state.params, prec_diag, aux)


12 changes: 4 additions & 8 deletions posteriors/laplace/diag_ggn.py
@@ -1,15 +1,13 @@
from functools import partial
-from typing import Any
+from typing import Any, NamedTuple
import torch
from optree import tree_map
-from dataclasses import dataclass

from posteriors.types import (
TensorTree,
Transform,
ForwardFn,
OuterLogProbFn,
-TransformState,
)
from posteriors.tree_utils import flexi_tree_map
from posteriors.utils import (
@@ -66,11 +64,10 @@ def build(
return Transform(init_fn, update_fn)


-@dataclass
-class DiagLaplaceState(TransformState):
+class DiagLaplaceState(NamedTuple):
"""State encoding a diagonal Normal distribution over parameters.

-Args:
+Attributes:
params: Mean of the Normal distribution.
prec_diag: Diagonal of the precision matrix of the Normal distribution.
aux: Auxiliary information from the log_posterior call.
@@ -149,8 +146,7 @@ def update_func(x, y):
)

if inplace:
-state.aux = aux
-return state
+return state._replace(aux=aux)
return DiagLaplaceState(state.params, prec_diag, aux)


18 changes: 8 additions & 10 deletions posteriors/optim.py
@@ -1,10 +1,10 @@
-from typing import Type, Any
+from typing import Type, Any, NamedTuple
from functools import partial
import torch
-from dataclasses import dataclass

-from posteriors.types import TensorTree, Transform, LogProbFn, TransformState
+from posteriors.types import TensorTree, Transform, LogProbFn
from posteriors.utils import CatchAuxError
+from posteriors.tree_utils import tree_insert_


def build(
@@ -36,11 +36,10 @@ def build(
return Transform(init_fn, update_fn)


-@dataclass
-class OptimState(TransformState):
+class OptimState(NamedTuple):
"""State of an optimizer from [torch.optim](https://pytorch.org/docs/stable/optim.html).

-Args:
+Attributes:
params: Parameters to be optimized.
optimizer: torch.optim optimizer instance.
loss: Loss value.
@@ -49,7 +48,7 @@ class OptimState(TransformState):

params: TensorTree
optimizer: torch.optim.Optimizer
-loss: torch.tensor = None
+loss: torch.tensor = torch.tensor([])
aux: Any = None


@@ -104,6 +103,5 @@ def update(
loss, aux = loss_fn(state.params, batch)
loss.backward()
state.optimizer.step()
-state.loss = loss
-state.aux = aux
-return state
+tree_insert_(state.loss, loss.detach())
+return state._replace(aux=aux)