From 48b0fe0739b426363f525562e49719ad5178a4f0 Mon Sep 17 00:00:00 2001
From: Sam Duffield
Date: Wed, 24 Apr 2024 15:19:16 +0100
Subject: [PATCH] Robustify per_samplify

---
 posteriors/utils.py | 18 ++++++++++--------
 tests/test_utils.py | 35 +++++++++++++++++++++++++++++++----
 2 files changed, 41 insertions(+), 12 deletions(-)

diff --git a/posteriors/utils.py b/posteriors/utils.py
index 4ff4a1b7..7ae769a0 100644
--- a/posteriors/utils.py
+++ b/posteriors/utils.py
@@ -1,6 +1,6 @@
 from typing import Callable, Any, Tuple, Sequence
 import operator
-from functools import partial
+from functools import partial, wraps
 import contextlib
 import torch
 from torch.func import grad, jvp, vjp, functional_call, jacrev, jacfwd
@@ -8,7 +8,6 @@
 from optree import tree_map, tree_map_, tree_reduce, tree_flatten, tree_leaves
 from optree.integration.torch import tree_ravel
 
-
 from posteriors.types import TensorTree, ForwardFn, Tensor
 
 
@@ -987,9 +986,8 @@ def flexi_tree_map(
 def per_samplify(
     f: Callable[[TensorTree, TensorTree], Any],
 ) -> Callable[[TensorTree, TensorTree], Any]:
-    """Converts a function that takes params and batch and averages over the batch in
-    its output into one that provides an output for each batch sample
-    (i.e. no averaging).
+    """Converts a function that takes params and batch into one that provides an output
+    for each batch sample.
 
     ```
     output = f(params, batch)
@@ -999,8 +997,8 @@
     For more info see [per_sample_grads.html](https://pytorch.org/tutorials/intermediate/per_sample_grads.html)
 
     Args:
-        f: A function that takes params and batch and averages over the batch in its
-            output.
+        f: A function that takes params and batch and provides an output with size
+            independent of batch size (i.e. averaged).
 
     Returns:
         A new function that provides an output for each batch sample.
@@ -1012,7 +1010,11 @@ def f_per_sample(params, batch):
         batch = tree_map(lambda x: x.unsqueeze(0), batch)
         return f(params, batch)
 
-    return f_per_sample
+    @wraps(f)
+    def f_per_sample_ensure_no_kwargs(params, batch):
+        return f_per_sample(params, batch)  # vmap in_dims requires no kwargs
+
+    return f_per_sample_ensure_no_kwargs
 
 
 def is_scalar(x: Any) -> bool:
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 645a9004..063cc065 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -209,11 +209,11 @@ def func_aux(x):
     batch_labels = torch.randint(2, (3,)).unsqueeze(-1)
     batch_spec = {"inputs": batch_inputs, "labels": batch_labels}
 
-    def log_likelihood_per_sample(params, batch):
+    def log_likelihood(params, batch):
         output = torch.func.functional_call(model, params, batch["inputs"])
-        return -torch.nn.BCEWithLogitsLoss(reduction="none")(
-            output, batch["labels"].float()
-        )
+        return -torch.nn.BCEWithLogitsLoss()(output, batch["labels"].float())
+
+    log_likelihood_per_sample = per_samplify(log_likelihood)
 
     v = tree_map(lambda x: torch.randn_like(x), params)
 
@@ -872,6 +872,33 @@ def func(p, b):
     assert torch.allclose(ra, expected_a)
     assert torch.allclose(rb, expected_b)
 
+    # Test model
+    model = TestModel()  # From tests.scenarios
+    params = dict(model.named_parameters())
+    batch_inputs = torch.randn(3, 10)
+    batch_labels = torch.randint(2, (3, 1))
+    batch_spec = {"inputs": batch_inputs, "labels": batch_labels}
+
+    def log_likelihood(params, batch):
+        output = torch.func.functional_call(model, params, batch["inputs"])
+        return -torch.nn.BCEWithLogitsLoss()(output, batch["labels"].float())
+
+    log_likelihood_per_sample = per_samplify(log_likelihood)
+
+    expected = torch.tensor(
+        [
+            log_likelihood(
+                params, {"inputs": inp.unsqueeze(0), "labels": lab.unsqueeze(0)}
+            )
+            for inp, lab in zip(batch_inputs, batch_labels)
+        ]
+    )
+
+    eval = log_likelihood_per_sample(params, batch_spec)
+    eval_p = partial(log_likelihood_per_sample, batch=batch_spec)(params)
+    assert torch.allclose(expected, eval)
+    assert torch.allclose(expected, eval_p)
+
 
 def test_is_scalar():
     assert is_scalar(1)
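For reviewers, a minimal sketch (not part of the patch) of the behaviour this change guards. It assumes `per_samplify` is importable from the package top level; the toy `mean_loss` function and its tensor shapes are illustrative only:

```python
# Illustrative sketch, not part of the patch. Assumes per_samplify is exposed
# at the posteriors top level; mean_loss is a hypothetical toy function.
import torch
from posteriors import per_samplify

params = {"weight": torch.ones(4)}
batch = {"inputs": torch.randn(3, 4)}

def mean_loss(params, batch):
    # Averages over the batch, so the output size is independent of batch size
    return torch.mean(batch["inputs"] @ params["weight"])

per_sample_loss = per_samplify(mean_loss)

print(mean_loss(params, batch).shape)        # torch.Size([]): one averaged value
print(per_sample_loss(params, batch).shape)  # torch.Size([3]): one value per sample

# torch.vmap's in_dims maps positional arguments only, so the raw vmapped
# function rejects keyword calls. The plain wrapper added in this patch
# forwards keyword calls positionally, so this now also works (as the new
# partial(..., batch=batch_spec) test exercises):
print(per_sample_loss(params, batch=batch).shape)  # torch.Size([3])
```

As a side benefit, `@wraps(f)` carries `f`'s name and docstring over to the returned function, which a bare `torch.vmap` output would not preserve.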