Nonprojective dependency entropy #103

Open
wants to merge 9 commits into master
62 changes: 32 additions & 30 deletions torch_struct/distributions.py
@@ -61,11 +61,7 @@ def log_prob(self, value):

d = value.dim()
batch_dims = range(d - len(self.event_shape))
v = self._struct().score(
self.log_potentials,
value.type_as(self.log_potentials),
batch_dims=batch_dims,
)
v = self._struct().score(self.log_potentials, value.type_as(self.log_potentials), batch_dims=batch_dims,)

return v - self.partition

@@ -91,9 +87,7 @@ def cross_entropy(self, other):
cross entropy (*batch_shape*)
"""

return self._struct(CrossEntropySemiring).sum(
[self.log_potentials, other.log_potentials], self.lengths
)
return self._struct(CrossEntropySemiring).sum([self.log_potentials, other.log_potentials], self.lengths)

def kl(self, other):
"""
@@ -105,9 +99,7 @@ def kl(self, other):
Returns:
cross entropy (*batch_shape*)
"""
return self._struct(KLDivergenceSemiring).sum(
[self.log_potentials, other.log_potentials], self.lengths
)
return self._struct(KLDivergenceSemiring).sum([self.log_potentials, other.log_potentials], self.lengths)

@lazy_property
def max(self):
@@ -140,9 +132,7 @@ def kmax(self, k):
kmax (*k x batch_shape*)
"""
with torch.enable_grad():
return self._struct(KMaxSemiring(k)).sum(
self.log_potentials, self.lengths, _raw=True
)
return self._struct(KMaxSemiring(k)).sum(self.log_potentials, self.lengths, _raw=True)

def topk(self, k):
r"""
@@ -155,9 +145,7 @@ def topk(self, k):
kmax (*k x batch_shape x event_shape*)
"""
with torch.enable_grad():
return self._struct(KMaxSemiring(k)).marginals(
self.log_potentials, self.lengths, _raw=True
)
return self._struct(KMaxSemiring(k)).marginals(self.log_potentials, self.lengths, _raw=True)

@lazy_property
def mode(self):
@@ -186,9 +174,7 @@ def count(self):

def gumbel_crf(self, temperature=1.0):
with torch.enable_grad():
st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(
self.log_potentials, self.lengths
)
st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(self.log_potentials, self.lengths)
return st_gumbel

# @constraints.dependent_property
@@ -219,9 +205,7 @@ def sample(self, sample_shape=torch.Size()):
samples = []
for k in range(nsamples):
if k % 10 == 0:
sample = self._struct(MultiSampledSemiring).marginals(
self.log_potentials, lengths=self.lengths
)
sample = self._struct(MultiSampledSemiring).marginals(self.log_potentials, lengths=self.lengths)
sample = sample.detach()
tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % 10) + 1)
samples.append(tmp_sample)
@@ -301,9 +285,7 @@ def __init__(self, log_potentials, local=False, lengths=None, max_gap=None):
super().__init__(log_potentials, lengths)

def _struct(self, sr=None):
return self.struct(
sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap
)
return self.struct(sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap)


class HMM(StructDistribution):
@@ -440,9 +422,7 @@ def __init__(self, log_potentials, lengths=None):
event_shape = log_potentials[0].shape[1:]
self.log_potentials = log_potentials
self.lengths = lengths
super(StructDistribution, self).__init__(
batch_shape=batch_shape, event_shape=event_shape
)
super(StructDistribution, self).__init__(batch_shape=batch_shape, event_shape=event_shape)


class NonProjectiveDependencyCRF(StructDistribution):
@@ -504,4 +484,26 @@ def argmax(self):

@lazy_property
def entropy(self):
pass
r"""
Compute entropy efficiently using arc-factorization property.

Algorithm derivation:
..math::
{{
\begin{align}
H[p] &= E_{p(T)}[-\log p(T)]\\
&= -E_{p(T)}\big[ \log [\frac{1}{Z} \prod\limits_{(i,j) \in T} \exp\{\phi_{i,j}\}] \big]\\
&= -E_{p(T)}\big[ \sum\limits_{(i,j) \in T} \phi_{i,j} - \log Z \big]\\
&= \log Z -E_{p(T)}\big[\sum\limits_{(i,j) \in A} 1\{(i,j) \in T\} \phi_{i,j}\big]\\
&= \log Z - \sum\limits_{(i,j) \in A} p\big((i,j) \in T\big) \phi_{i,j}
\end{align}
}}

Returns:
entropy (*batch_shape)
"""
logZ = self.partition
Review comment (Collaborator):

@justinchiu points out that instead of grouping this with the non-projective class, we should just have it be the default entropy function for all of the models, materializing the marginals (nothing here is non-projective specific).

Reply:

Yes, it's a general property of all exponential family / sum-product models. See also this Twitter discussion: https://twitter.com/RanZmigrod/status/1300832956701970434?s=20
p = self.marginals
phi = self.log_potentials
H = logZ - (p * phi).reshape(phi.shape[0], -1).sum(-1)
return H
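
As a quick illustration of the new entropy property and of the point raised in the review comment above (the identity H = log Z - sum over arcs of marginal times potential is not specific to non-projective trees), here is a minimal usage sketch. It is not part of this diff; the tensor sizes, the random potentials, and the LinearChainCRF cross-check are assumptions for illustration only.

import torch
from torch_struct import LinearChainCRF, NonProjectiveDependencyCRF

# Entropy of a non-projective dependency CRF over random arc scores.
# log_potentials are batch x N x N head-to-dependent scores (placeholder values).
batch, N = 2, 5
arc_scores = torch.randn(batch, N, N)
dep = NonProjectiveDependencyCRF(arc_scores)
print(dep.entropy.shape)  # expected: (batch,)

# The same marginal-based identity applied to a model whose entropy is already
# computed with EntropySemiring, as a sanity check of the general property.
C = 4
phi = torch.randn(batch, N - 1, C, C)
chain = LinearChainCRF(phi)
H_marginal = chain.partition - (chain.marginals * phi).reshape(batch, -1).sum(-1)
print(torch.allclose(H_marginal, chain.entropy, atol=1e-4))  # expected: True

The second check is essentially what the inline comment suggests promoting to a default entropy for every StructDistribution that can materialize its marginals.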