Skip to content

Commit

Permalink
add docs for distributions base class and refactor for readability
Browse files Browse the repository at this point in the history
  • Loading branch information
gorold committed Jul 9, 2024
1 parent c449083 commit 07656c4
Showing 1 changed file with 64 additions and 17 deletions.
81 changes: 64 additions & 17 deletions src/uni2ts/distribution/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

import abc
from collections.abc import Callable
from collections.abc import Callable, Sequence
from typing import Any, Optional

import torch
Expand All @@ -32,13 +32,15 @@
def tree_map_multi(
func: Callable, tree: PyTree[Any, "T"], *other: PyTree[Any, "T"]
) -> PyTree[Any, "T"]:
"""Tree map with function requiring multiple inputs, where other inputs are from a PyTree too."""
leaves, treespec = tree_flatten(tree)
other_leaves = [tree_flatten(o)[0] for o in other]
return_leaves = [func(*leaf) for leaf in zip(leaves, *other_leaves)]
return tree_unflatten(return_leaves, treespec)


def convert_to_module(tree: PyTree[nn.Module, "T"]) -> PyTree[nn.Module, "T"]:
"""Convert a simple container PyTree into an nn.Module PyTree"""
if isinstance(tree, dict):
return nn.ModuleDict(
{key: convert_to_module(child) for key, child in tree.items()}
Expand All @@ -49,6 +51,7 @@ def convert_to_module(tree: PyTree[nn.Module, "T"]) -> PyTree[nn.Module, "T"]:


def convert_to_container(tree: PyTree[nn.Module, "T"]) -> PyTree[nn.Module, "T"]:
"""Convert an nn.Module PyTree into a simple container PyTree"""
if isinstance(tree, nn.ModuleDict):
return {key: convert_to_container(child) for key, child in tree.items()}
if isinstance(tree, nn.ModuleList):
Expand All @@ -57,6 +60,10 @@ def convert_to_container(tree: PyTree[nn.Module, "T"]) -> PyTree[nn.Module, "T"]


class DistrParamProj(nn.Module):
"""
Projection layer from representations to distribution parameters.
"""

def __init__(
self,
in_features: int,
Expand All @@ -66,26 +73,41 @@ def __init__(
proj_layer: Callable[..., nn.Module] = MultiOutSizeLinear,
**kwargs: Any,
):
"""
:param in_features: size of representation
:param out_features: size multiplier of distribution parameters
:param args_dim: dimensionality of distribution parameters
:param domain_map: mapping for distribution parameters
:param proj_layer: projection layer
:param kwargs: additional kwargs for proj_layer
"""
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.args_dim = args_dim
self.domain_map = domain_map
self.proj = convert_to_module(
tree_map(
lambda dim: (
proj_layer(in_features, dim * out_features, **kwargs)
if isinstance(out_features, int)
else proj_layer(
in_features,
tuple(dim * of for of in out_features),
dim=dim,
**kwargs,
)
),
args_dim,

if isinstance(out_features, int):

def proj(dim):
proj_layer(in_features, dim * out_features, **kwargs)

elif isinstance(out_features, Sequence):

def proj(dim):
return proj_layer(
in_features,
tuple(dim * of for of in out_features),
dim=dim,
**kwargs,
)

else:
raise ValueError(
f"out_features must be int or sequence of ints, got invalid type: {type(out_features)}"
)
)

self.proj = convert_to_module(tree_map(proj, args_dim))
self.out_size = (
out_features if isinstance(out_features, int) else max(out_features)
)
Expand Down Expand Up @@ -160,11 +182,27 @@ def _distribution(

@property
@abc.abstractmethod
def args_dim(self) -> PyTree[int, "T"]: ...
def args_dim(self) -> PyTree[int, "T"]:
"""
Returns the dimensionality of the distribution parameters in the form of a pytree.
For simple distributions, this will be a simple dictionary:
e.g. for a univariate normal distribution, the args_dim should return {"loc": 1, "scale": 1}.
For more complex distributions, this could be an arbitrarily complex pytree.
:return: pytree of integers representing the dimensionality of the distribution parameters
"""
...

@property
@abc.abstractmethod
def domain_map(self) -> PyTree[Callable[[torch.Tensor], torch.Tensor], "T"]: ...
def domain_map(self) -> PyTree[Callable[[torch.Tensor], torch.Tensor], "T"]:
"""
Returns a pytree of callables that maps the unconstrained distribution parameters
to the range required by their distributions.
:return: callables in the same PyTree format as args_dim
"""
...

def get_param_proj(
self,
Expand All @@ -173,6 +211,15 @@ def get_param_proj(
proj_layer: Callable[..., nn.Module] = MultiOutSizeLinear,
**kwargs: Any,
) -> nn.Module:
"""
Get a projection layer mapping representations to distribution parameters.
:param in_features: input feature dimension
:param out_features: size multiplier of distribution parameters
:param proj_layer: projection layer
:param kwargs: additional kwargs for proj_layer
:return: distribution parameter projection layer
"""
return DistrParamProj(
in_features=in_features,
out_features=out_features,
Expand Down

0 comments on commit 07656c4

Please sign in to comment.