Merge pull request #79 from normal-computing/fix-per-samplify
Robustify per_samplify
SamDuffield authored Apr 24, 2024
2 parents ecfeff7 + 48b0fe0 commit 21f8c7f
Showing 2 changed files with 41 additions and 12 deletions.
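For orientation, a sketch of what `per_samplify` presumably looks like after this merge. The `torch.vmap` decorator sits in a collapsed region of the diff below, so that line is inferred from the visible fragments and the `# vmap in_dims requires no kwargs` comment rather than quoted verbatim:

```python
import torch
from functools import partial, wraps
from optree import tree_map


def per_samplify(f):
    # vmap over the leading batch dimension, with params unmapped;
    # this decorator is hidden in the collapsed diff, so it is assumed here
    @partial(torch.vmap, in_dims=(None, 0))
    def f_per_sample(params, batch):
        # each vmapped slice loses its batch dimension, so restore a
        # size-1 batch axis before handing it back to f
        batch = tree_map(lambda x: x.unsqueeze(0), batch)
        return f(params, batch)

    @wraps(f)
    def f_per_sample_ensure_no_kwargs(params, batch):
        return f_per_sample(params, batch)  # vmap in_dims requires no kwargs

    return f_per_sample_ensure_no_kwargs
```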
18 changes: 10 additions & 8 deletions posteriors/utils.py
@@ -1,14 +1,13 @@
 from typing import Callable, Any, Tuple, Sequence
 import operator
-from functools import partial
+from functools import partial, wraps
 import contextlib
 import torch
 from torch.func import grad, jvp, vjp, functional_call, jacrev, jacfwd
 from torch.distributions import Normal
 from optree import tree_map, tree_map_, tree_reduce, tree_flatten, tree_leaves
 from optree.integration.torch import tree_ravel
-
 
 from posteriors.types import TensorTree, ForwardFn, Tensor
 
 
@@ -987,9 +986,8 @@ def flexi_tree_map(
 def per_samplify(
     f: Callable[[TensorTree, TensorTree], Any],
 ) -> Callable[[TensorTree, TensorTree], Any]:
-    """Converts a function that takes params and batch and averages over the batch in
-    its output into one that provides an output for each batch sample
-    (i.e. no averaging).
+    """Converts a function that takes params and batch into one that provides an output
+    for each batch sample.
 
     ```
     output = f(params, batch)
@@ -999,8 +997,8 @@
     For more info see [per_sample_grads.html](https://pytorch.org/tutorials/intermediate/per_sample_grads.html)
 
     Args:
-        f: A function that takes params and batch and averages over the batch in its
-            output.
+        f: A function that takes params and batch and provides an output whose size
+            is independent of the batch size (i.e. averaged over the batch).
 
     Returns:
         A new function that provides an output for each batch sample.
@@ -1012,7 +1010,11 @@ def f_per_sample(params, batch):
         batch = tree_map(lambda x: x.unsqueeze(0), batch)
         return f(params, batch)
 
-    return f_per_sample
+    @wraps(f)
+    def f_per_sample_ensure_no_kwargs(params, batch):
+        return f_per_sample(params, batch)  # vmap in_dims requires no kwargs
+
+    return f_per_sample_ensure_no_kwargs
 
 
 def is_scalar(x: Any) -> bool:
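Why the extra wrapper: `torch.vmap` maps only positional arguments via `in_dims`, so the vmapped inner function breaks if `batch` arrives as a keyword argument; the plain `(params, batch)` wrapper rebinds any keyword call positionally before forwarding. A minimal sketch of the failure mode, with toy values that are assumptions rather than code from the PR:

```python
import torch

# vmap over the second positional argument only
f = torch.vmap(lambda params, batch: params * batch, in_dims=(None, 0))

params = torch.tensor(2.0)
batch = torch.arange(3.0)

print(f(params, batch))  # fine: tensor([0., 2., 4.])
# f(params, batch=batch)  # raises: in_dims no longer lines up with the
#                         # positional inputs, since kwargs are not mapped
```

This keyword-call path is exactly what the new test's `partial(log_likelihood_per_sample, batch=batch_spec)` exercises.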
35 changes: 31 additions & 4 deletions tests/test_utils.py
@@ -209,11 +209,11 @@ def func_aux(x):
     batch_labels = torch.randint(2, (3,)).unsqueeze(-1)
     batch_spec = {"inputs": batch_inputs, "labels": batch_labels}
 
-    def log_likelihood_per_sample(params, batch):
+    def log_likelihood(params, batch):
         output = torch.func.functional_call(model, params, batch["inputs"])
-        return -torch.nn.BCEWithLogitsLoss(reduction="none")(
-            output, batch["labels"].float()
-        )
+        return -torch.nn.BCEWithLogitsLoss()(output, batch["labels"].float())
+
+    log_likelihood_per_sample = per_samplify(log_likelihood)
 
     v = tree_map(lambda x: torch.randn_like(x), params)
 
@@ -872,6 +872,33 @@ def func(p, b):
     assert torch.allclose(ra, expected_a)
     assert torch.allclose(rb, expected_b)
 
+    # Test model
+    model = TestModel()  # From tests.scenarios
+    params = dict(model.named_parameters())
+    batch_inputs = torch.randn(3, 10)
+    batch_labels = torch.randint(2, (3, 1))
+    batch_spec = {"inputs": batch_inputs, "labels": batch_labels}
+
+    def log_likelihood(params, batch):
+        output = torch.func.functional_call(model, params, batch["inputs"])
+        return -torch.nn.BCEWithLogitsLoss()(output, batch["labels"].float())
+
+    log_likelihood_per_sample = per_samplify(log_likelihood)
+
+    expected = torch.tensor(
+        [
+            log_likelihood(
+                params, {"inputs": inp.unsqueeze(0), "labels": lab.unsqueeze(0)}
+            )
+            for inp, lab in zip(batch_inputs, batch_labels)
+        ]
+    )
+
+    eval = log_likelihood_per_sample(params, batch_spec)
+    eval_p = partial(log_likelihood_per_sample, batch=batch_spec)(params)
+    assert torch.allclose(expected, eval)
+    assert torch.allclose(expected, eval_p)
+
 
 def test_is_scalar():
     assert is_scalar(1)
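For completeness, a hedged usage sketch in the spirit of the per-sample-gradients tutorial linked from the docstring — the model, loss, and shapes below are illustrative assumptions, not code from the repository:

```python
import torch
from torch.func import grad
from posteriors.utils import per_samplify


def log_posterior(params, batch):
    # batch-averaged scalar output, the form per_samplify expects
    inputs, labels = batch
    preds = inputs @ params
    return -torch.mean((preds - labels) ** 2)


params = torch.randn(5)
batch = (torch.randn(8, 5), torch.randn(8))

# per-sample values of the otherwise averaged function, shape (8,)
per_sample_vals = per_samplify(log_posterior)(params, batch)

# per-sample gradients, shape (8, 5), by per_samplifying the grad function
per_sample_grads = per_samplify(grad(log_posterior))(params, batch)
```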
