Skip to content

Commit

Permalink
wip for jax
Browse files Browse the repository at this point in the history
  • Loading branch information
CamDavidsonPilon committed Jun 17, 2024
1 parent 2bd0627 commit ba06aab
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 32 deletions.
58 changes: 34 additions & 24 deletions lifelines/fitters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from numpy.linalg import inv, pinv
import numpy as np

from autograd import hessian, value_and_grad, elementwise_grad as egrad, grad
from jax import hessian, value_and_grad, grad, vmap, vjp
from autograd.differential_operators import make_jvp_reversemode
from autograd.misc import flatten
import autograd.numpy as anp
from jax.flatten_util import ravel_pytree as flatten
import jax.numpy as jnp

from scipy.optimize import minimize, root_scalar
from scipy.integrate import trapz
Expand All @@ -39,6 +39,16 @@
]


def egrad(g):
    """Build a function returning the derivative of ``g`` w.r.t. its second argument.

    Stand-in for autograd's ``elementwise_grad(g, argnum=1)`` implemented with
    ``jax.vjp``. The returned callable has the same calling convention as ``g``:
    ``(params, times, *args, **kwargs)``. Only ``times`` is differentiated;
    ``params`` and any extra positional/keyword arguments (e.g. the covariate
    container passed by regression fitters) are forwarded to ``g`` unchanged.

    Parameters
    ----------
    g : callable
        Function called as ``g(params, times, *args, **kwargs)``.

    Returns
    -------
    callable
        ``wrapped(params, times, *args, **kwargs)`` returning the cotangent of
        ``times`` obtained by pulling back an all-ones cotangent through ``g``.
        This equals the elementwise derivative when each output element depends
        only on the matching element of ``times``.
    """

    def wrapped(params, times, *args, **kwargs):
        # Close over everything except `times` so vjp differentiates w.r.t. it only.
        y, g_vjp = vjp(lambda t: g(params, t, *args, **kwargs), times)
        # An all-ones cotangent sums each output's sensitivity to each input element.
        (x_bar,) = g_vjp(jnp.ones_like(y))
        return x_bar

    return wrapped


class BaseFitter:

weights: np.ndarray
Expand Down Expand Up @@ -384,30 +394,30 @@ def _buffer_bounds(self, bounds: list[tuple[t.Optional[float], t.Optional[float]
yield (lb + self._MIN_PARAMETER_VALUE, ub - self._MIN_PARAMETER_VALUE)

def _cumulative_hazard(self, params, times):
return -anp.log(self._survival_function(params, times))
return -jnp.log(self._survival_function(params, times))

def _hazard(self, *args, **kwargs):
# pylint: disable=no-value-for-parameter,unexpected-keyword-arg
return egrad(self._cumulative_hazard, argnum=1)(*args, **kwargs)
return egrad(self._cumulative_hazard)(*args, **kwargs)

def _density(self, *args, **kwargs):
# pylint: disable=no-value-for-parameter,unexpected-keyword-arg
return egrad(self._cumulative_density, argnum=1)(*args, **kwargs)
return egrad(self._cumulative_density)(*args, **kwargs)

def _survival_function(self, params, times):
return anp.exp(-self._cumulative_hazard(params, times))
return jnp.exp(-self._cumulative_hazard(params, times))

def _cumulative_density(self, params, times):
return 1 - self._survival_function(params, times)

def _log_hazard(self, params, times):
hz = self._hazard(params, times)
hz = anp.clip(hz, 1e-50, np.inf)
return anp.log(hz)
hz = jnp.clip(hz, 1e-50, np.inf)
return jnp.log(hz)

def _log_1m_sf(self, params, times):
# equal to log(cdf), but often easier to express with sf.
return anp.log1p(-self._survival_function(params, times))
return jnp.log1p(-self._survival_function(params, times))

def _negative_log_likelihood_left_censoring(self, params, Ts, E, entry, weights) -> float:
T = Ts[1]
Expand Down Expand Up @@ -449,8 +459,8 @@ def _negative_log_likelihood_interval_censoring(self, params, Ts, E, entry, weig
ll
+ (
censored_weights
* anp.log(
anp.clip(
* jnp.log(
jnp.clip(
self._survival_function(params, censored_starts) - self._survival_function(params, censored_stops),
1e-25,
1 - 1e-25,
Expand Down Expand Up @@ -1391,23 +1401,23 @@ def _check_values_pre_fitting(self, df, T, E, weights, entries):
utils.check_entry_times(T, entries)

def _cumulative_hazard(self, params, T, Xs):
return -anp.log(self._survival_function(params, T, Xs))
return -jnp.log(self._survival_function(params, T, Xs))

def _hazard(self, params, T, Xs):
return egrad(self._cumulative_hazard, argnum=1)(params, T, Xs) # pylint: disable=unexpected-keyword-arg
return egrad(self._cumulative_hazard)(params, T, Xs) # pylint: disable=unexpected-keyword-arg

def _log_hazard(self, params, T, Xs):
# can be overwritten to improve convergence, see example in WeibullAFTFitter
hz = self._hazard(params, T, Xs)
hz = anp.clip(hz, 1e-20, np.inf)
return anp.log(hz)
hz = jnp.clip(hz, 1e-20, np.inf)
return jnp.log(hz)

def _log_1m_sf(self, params, T, Xs):
# equal to log(cdf), but often easier to express with sf.
return anp.log1p(-self._survival_function(params, T, Xs))
return jnp.log1p(-self._survival_function(params, T, Xs))

def _survival_function(self, params, T, Xs):
return anp.clip(anp.exp(-self._cumulative_hazard(params, T, Xs)), 1e-12, 1 - 1e-12)
return jnp.clip(jnp.exp(-self._cumulative_hazard(params, T, Xs)), 1e-12, 1 - 1e-12)

def _log_likelihood_right_censoring(self, params, Ts: tuple, E, W, entries, Xs) -> float:

Expand All @@ -1422,7 +1432,7 @@ def _log_likelihood_right_censoring(self, params, Ts: tuple, E, W, entries, Xs)
ll = ll + (W * E * log_hz).sum()
ll = ll + -(W * cum_hz).sum()
ll = ll + (W[non_zero_entries] * delayed_entries).sum()
ll = ll / anp.sum(W)
ll = ll / jnp.sum(W)
return ll

def _log_likelihood_left_censoring(self, params, Ts, E, W, entries, Xs) -> float:
Expand All @@ -1438,16 +1448,16 @@ def _log_likelihood_left_censoring(self, params, Ts, E, W, entries, Xs) -> float
ll = 0
ll = (W * E * (log_hz - cum_haz - log_1m_sf)).sum() + (W * log_1m_sf).sum()
ll = ll + (W[non_zero_entries] * delayed_entries).sum()
ll = ll / anp.sum(W)
ll = ll / jnp.sum(W)
return ll

def _log_likelihood_interval_censoring(self, params, Ts, E, W, entries, Xs) -> float:

start, stop = Ts
non_zero_entries = entries > 0
observed_deaths = self._log_hazard(params, stop[E], Xs.filter(E)) - self._cumulative_hazard(params, stop[E], Xs.filter(E))
censored_interval_deaths = anp.log(
anp.clip(
censored_interval_deaths = jnp.log(
jnp.clip(
self._survival_function(params, start[~E], Xs.filter(~E))
- self._survival_function(params, stop[~E], Xs.filter(~E)),
1e-25,
Expand All @@ -1460,7 +1470,7 @@ def _log_likelihood_interval_censoring(self, params, Ts, E, W, entries, Xs) -> f
ll = ll + (W[E] * observed_deaths).sum()
ll = ll + (W[~E] * censored_interval_deaths).sum()
ll = ll + (W[non_zero_entries] * delayed_entries).sum()
ll = ll / anp.sum(W)
ll = ll / jnp.sum(W)
return ll

@utils.CensoringType.left_censoring
Expand Down Expand Up @@ -1885,7 +1895,7 @@ def _add_penalty(self, params: dict, neg_ll: float):
params_array = params_array[~self._cols_to_not_penalize]
if (isinstance(self.penalizer, np.ndarray) or self.penalizer > 0) and self.l1_ratio > 0:
penalty = (
self.l1_ratio * (self.penalizer * anp.abs(params_array)).sum()
self.l1_ratio * (self.penalizer * jnp.abs(params_array)).sum()
+ 0.5 * (1.0 - self.l1_ratio) * (self.penalizer * (params_array) ** 2).sum()
)

Expand Down
4 changes: 2 additions & 2 deletions lifelines/fitters/exponential_fitter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
import numpy as np
from autograd import numpy as anp
from jax import numpy as jnp
from lifelines.fitters import KnownModelParametricUnivariateFitter


Expand Down Expand Up @@ -77,4 +77,4 @@ def _cumulative_hazard(self, params, times):

def _log_hazard(self, params, times):
lambda_ = params[0]
return -anp.log(lambda_)
return -jnp.log(lambda_)
2 changes: 1 addition & 1 deletion lifelines/fitters/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def check_assumptions(
plot_n_bootstraps: int = 15,
columns: Optional[List[str]] = None,
raise_on_fail: bool = False,
) -> None:
) -> list:
"""
Use this function to test the proportional hazards assumption. See usage example at
https://lifelines.readthedocs.io/en/latest/jupyter_notebooks/Proportional%20hazard%20assumption.html
Expand Down
4 changes: 2 additions & 2 deletions lifelines/tests/test_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,8 +1149,8 @@ def test_against_reliability_software(self):

class TestExponentialFitter:
def test_fit_computes_correct_lambda_(self):
T = np.array([10, 10, 10, 10], dtype=float)
E = np.array([1, 1, 1, 0], dtype=float)
T = np.array([10, 20, 10, 10, 5, 10], dtype=float)
E = np.array([1, 1, 1, 0, 0, 1], dtype=float)
enf = ExponentialFitter()
enf.fit(T, E)
assert abs(enf.lambda_ - (T.sum() / E.sum())) < 1e-4
Expand Down
5 changes: 2 additions & 3 deletions reqs/base-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
numpy>=1.14.0,<2.0
numpy>=1.14.0
scipy>=1.2.0
pandas>=1.2.0
matplotlib>=3.0
autograd>=1.5
autograd-gamma>=0.3
formulaic>=0.2.2
jax[cpu]>=0.4.0

0 comments on commit ba06aab

Please sign in to comment.