Add where to flow #126

Merged
merged 1 commit on Aug 4, 2023
41 changes: 34 additions & 7 deletions src/stream_ml/pytorch/builtin/compat/flow.py
@@ -10,7 +10,7 @@

import torch as xp

from stream_ml.core.params.scaler import scale_params
from stream_ml.core.builtin._utils import WhereRequiredError

from stream_ml.pytorch._base import ModelBase
from stream_ml.pytorch.utils import names_intersect
@@ -30,7 +30,13 @@ class FlowModel(ModelBase):
with_grad: bool = True

def ln_likelihood(
self, mpars: Params[Array], /, data: Data[Array], **kwargs: Array
self,
mpars: Params[Array],
/,
data: Data[Array],
*,
where: Data[Array] | None = None,
**kwargs: Array,
) -> Array:
"""Log-likelihood of the array.

@@ -43,26 +49,47 @@ def ln_likelihood(
data : Data[Array]
Data (phi1, phi2).

where : Data[Array[(N,), bool]] | None, optional, keyword-only
Where to evaluate the log-likelihood. If not provided, the
log-likelihood is evaluated at all data points. ``where`` must
contain the fields in ``coord_names``, and each field must be a
boolean array of the same length as ``data``: ``True`` indicates
that the data point is available, ``False`` that it is not.

**kwargs : Array
Additional arguments.

Returns
-------
Array
"""
# TODO: support `where` argument.
# 'where' is used to indicate which data points are available. If
# 'where' is not provided, then all data points are assumed to be
# available.
where_: Array # (N, F)
if where is not None:
where_ = where[self.coord_names].array
elif self.require_where:
raise WhereRequiredError
else:
where_ = self.xp.ones((len(data), self.ndim), dtype=bool)
idx = where_.all(axis=1)
# TODO: allow for missing data in only some of the dimensions

data = self.data_scaler.transform(
data, names=names_intersect(data, self.data_scaler), xp=self.xp
)
mpars = scale_params(self, mpars)

out = self.xp.zeros(len(data), dtype=data.dtype)
with nullcontext() if self.with_grad else xp.no_grad():
return self.jacobian_logdet + self.net.log_prob(
inputs=data[self.coord_names].array,
context=data[self.indep_coord_names].array
out[idx] = self.jacobian_logdet + self.net.log_prob(
inputs=data[self.coord_names].array[idx],
context=data[self.indep_coord_names].array[idx]
if self.indep_coord_names is not None
else None,
)
return out

def forward(self, data: Data[Array]) -> Array:
"""Forward pass.
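For reviewers who want to see the masking idiom in isolation: the sketch below reproduces the pattern added in this diff (build an `(N, F)` boolean mask, reduce it with `all` over the coordinate axis, evaluate the log-probability only on rows with complete data, and leave zeros elsewhere) using plain PyTorch. The `Normal` distribution, tensor names, and shapes are illustrative stand-ins, not part of the stream_ml API or this PR.

```python
import torch

# Toy data: N points with F coordinates each (stand-in for Data[Array]).
N, F = 5, 2
data = torch.randn(N, F)

# Per-coordinate availability mask; mark row 1 as missing its second coordinate.
where = torch.ones(N, F, dtype=torch.bool)
where[1, 1] = False

# Rows the flow can use: every coordinate present (mirrors `where_.all(axis=1)`).
idx = where.all(dim=1)

# Stand-in for the flow's `net.log_prob`: independent standard normals, summed over F.
dist = torch.distributions.Normal(0.0, 1.0)

out = torch.zeros(N)
out[idx] = dist.log_prob(data[idx]).sum(dim=-1)
print(out)  # row 1 stays 0.0 because it has incomplete data
```

As in the diff, a row with any missing coordinate is skipped entirely; handling rows that are only partially missing is deferred to the `# TODO: allow for missing data in only some of the dimensions` note above.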