Skip to content

Commit

Permalink
use expm1
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Oct 11, 2023
1 parent a2df40c commit 8524286
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/stream_ml/pytorch/_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ def forward(self, data: Data[Array], /) -> Array:
ln_weight = ( # (N, 1)
ln_weights[:, i]
if name != BACKGROUND_KEY
else self.xp.log(1 - self.xp.sum(self.xp.exp(ln_weights), 1))
else self.xp.log(
-self.xp.expm1(self.xp.special.logsumexp(ln_weights, 1))
)
)[:, None]
wgt_is[i] = counter

Expand All @@ -170,7 +172,7 @@ def forward(self, data: Data[Array], /) -> Array:
# Ensure that the background weight is 1 - sum(weights)
if self._includes_bkg:
out[:, wgt_is[-1]] = self.xp.log(
1 - self.xp.sum(self.xp.exp(out[:, wgt_is[:-1]]), 1)
-self.xp.expm1(self.xp.special.logsumexp(out[:, wgt_is[:-1]], 1))
)

return out

0 comments on commit 8524286

Please sign in to comment.