From ba06aaba98fe9457d7ca57572de852b72ccad788 Mon Sep 17 00:00:00 2001 From: CamDavidsonPilon Date: Sun, 16 Jun 2024 21:01:20 -0400 Subject: [PATCH] wip for jax --- lifelines/fitters/__init__.py | 58 +++++++++++++++---------- lifelines/fitters/exponential_fitter.py | 4 +- lifelines/fitters/mixins.py | 2 +- lifelines/tests/test_estimation.py | 4 +- reqs/base-requirements.txt | 5 +-- 5 files changed, 41 insertions(+), 32 deletions(-) diff --git a/lifelines/fitters/__init__.py b/lifelines/fitters/__init__.py index 964101ec2..00b5e8551 100644 --- a/lifelines/fitters/__init__.py +++ b/lifelines/fitters/__init__.py @@ -12,10 +12,10 @@ from numpy.linalg import inv, pinv import numpy as np -from autograd import hessian, value_and_grad, elementwise_grad as egrad, grad +from jax import hessian, value_and_grad, grad, vmap, vjp from autograd.differential_operators import make_jvp_reversemode -from autograd.misc import flatten -import autograd.numpy as anp +from jax.flatten_util import ravel_pytree as flatten +import jax.numpy as jnp from scipy.optimize import minimize, root_scalar from scipy.integrate import trapz @@ -39,6 +39,16 @@ ] +def egrad(g): + # assumes grad w.r.t argnum=1 + def wrapped(params, times): + y, g_vjp = vjp(lambda times: g(params, times), times) + (x_bar,) = g_vjp(jnp.ones_like(y)) + return x_bar + + return wrapped + + class BaseFitter: weights: np.ndarray @@ -384,30 +394,30 @@ def _buffer_bounds(self, bounds: list[tuple[t.Optional[float], t.Optional[float] yield (lb + self._MIN_PARAMETER_VALUE, ub - self._MIN_PARAMETER_VALUE) def _cumulative_hazard(self, params, times): - return -anp.log(self._survival_function(params, times)) + return -jnp.log(self._survival_function(params, times)) def _hazard(self, *args, **kwargs): # pylint: disable=no-value-for-parameter,unexpected-keyword-arg - return egrad(self._cumulative_hazard, argnum=1)(*args, **kwargs) + return egrad(self._cumulative_hazard)(*args, **kwargs) def _density(self, *args, **kwargs): # pylint: disable=no-value-for-parameter,unexpected-keyword-arg - return egrad(self._cumulative_density, argnum=1)(*args, **kwargs) + return egrad(self._cumulative_density)(*args, **kwargs) def _survival_function(self, params, times): - return anp.exp(-self._cumulative_hazard(params, times)) + return jnp.exp(-self._cumulative_hazard(params, times)) def _cumulative_density(self, params, times): return 1 - self._survival_function(params, times) def _log_hazard(self, params, times): hz = self._hazard(params, times) - hz = anp.clip(hz, 1e-50, np.inf) - return anp.log(hz) + hz = jnp.clip(hz, 1e-50, np.inf) + return jnp.log(hz) def _log_1m_sf(self, params, times): # equal to log(cdf), but often easier to express with sf. - return anp.log1p(-self._survival_function(params, times)) + return jnp.log1p(-self._survival_function(params, times)) def _negative_log_likelihood_left_censoring(self, params, Ts, E, entry, weights) -> float: T = Ts[1] @@ -449,8 +459,8 @@ def _negative_log_likelihood_interval_censoring(self, params, Ts, E, entry, weig ll + ( censored_weights - * anp.log( - anp.clip( + * jnp.log( + jnp.clip( self._survival_function(params, censored_starts) - self._survival_function(params, censored_stops), 1e-25, 1 - 1e-25, @@ -1391,23 +1401,23 @@ def _check_values_pre_fitting(self, df, T, E, weights, entries): utils.check_entry_times(T, entries) def _cumulative_hazard(self, params, T, Xs): - return -anp.log(self._survival_function(params, T, Xs)) + return -jnp.log(self._survival_function(params, T, Xs)) def _hazard(self, params, T, Xs): - return egrad(self._cumulative_hazard, argnum=1)(params, T, Xs) # pylint: disable=unexpected-keyword-arg + return egrad(self._cumulative_hazard)(params, T, Xs) # pylint: disable=unexpected-keyword-arg def _log_hazard(self, params, T, Xs): # can be overwritten to improve convergence, see example in WeibullAFTFitter hz = self._hazard(params, T, Xs) - hz = anp.clip(hz, 1e-20, np.inf) - return anp.log(hz) + hz = jnp.clip(hz, 1e-20, np.inf) + return jnp.log(hz) def _log_1m_sf(self, params, T, Xs): # equal to log(cdf), but often easier to express with sf. - return anp.log1p(-self._survival_function(params, T, Xs)) + return jnp.log1p(-self._survival_function(params, T, Xs)) def _survival_function(self, params, T, Xs): - return anp.clip(anp.exp(-self._cumulative_hazard(params, T, Xs)), 1e-12, 1 - 1e-12) + return jnp.clip(jnp.exp(-self._cumulative_hazard(params, T, Xs)), 1e-12, 1 - 1e-12) def _log_likelihood_right_censoring(self, params, Ts: tuple, E, W, entries, Xs) -> float: @@ -1422,7 +1432,7 @@ def _log_likelihood_right_censoring(self, params, Ts: tuple, E, W, entries, Xs) ll = ll + (W * E * log_hz).sum() ll = ll + -(W * cum_hz).sum() ll = ll + (W[non_zero_entries] * delayed_entries).sum() - ll = ll / anp.sum(W) + ll = ll / jnp.sum(W) return ll def _log_likelihood_left_censoring(self, params, Ts, E, W, entries, Xs) -> float: @@ -1438,7 +1448,7 @@ def _log_likelihood_left_censoring(self, params, Ts, E, W, entries, Xs) -> float ll = 0 ll = (W * E * (log_hz - cum_haz - log_1m_sf)).sum() + (W * log_1m_sf).sum() ll = ll + (W[non_zero_entries] * delayed_entries).sum() - ll = ll / anp.sum(W) + ll = ll / jnp.sum(W) return ll def _log_likelihood_interval_censoring(self, params, Ts, E, W, entries, Xs) -> float: @@ -1446,8 +1456,8 @@ def _log_likelihood_interval_censoring(self, params, Ts, E, W, entries, Xs) -> f start, stop = Ts non_zero_entries = entries > 0 observed_deaths = self._log_hazard(params, stop[E], Xs.filter(E)) - self._cumulative_hazard(params, stop[E], Xs.filter(E)) - censored_interval_deaths = anp.log( - anp.clip( + censored_interval_deaths = jnp.log( + jnp.clip( self._survival_function(params, start[~E], Xs.filter(~E)) - self._survival_function(params, stop[~E], Xs.filter(~E)), 1e-25, @@ -1460,7 +1470,7 @@ def _log_likelihood_interval_censoring(self, params, Ts, E, W, entries, Xs) -> f ll = ll + (W[E] * observed_deaths).sum() ll = ll + (W[~E] * censored_interval_deaths).sum() ll = ll + (W[non_zero_entries] * delayed_entries).sum() - ll = ll / anp.sum(W) + ll = ll / jnp.sum(W) return ll @utils.CensoringType.left_censoring @@ -1885,7 +1895,7 @@ def _add_penalty(self, params: dict, neg_ll: float): params_array = params_array[~self._cols_to_not_penalize] if (isinstance(self.penalizer, np.ndarray) or self.penalizer > 0) and self.l1_ratio > 0: penalty = ( - self.l1_ratio * (self.penalizer * anp.abs(params_array)).sum() + self.l1_ratio * (self.penalizer * jnp.abs(params_array)).sum() + 0.5 * (1.0 - self.l1_ratio) * (self.penalizer * (params_array) ** 2).sum() ) diff --git a/lifelines/fitters/exponential_fitter.py b/lifelines/fitters/exponential_fitter.py index 0f001edef..e06c24d6b 100644 --- a/lifelines/fitters/exponential_fitter.py +++ b/lifelines/fitters/exponential_fitter.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- import numpy as np -from autograd import numpy as anp +from jax import numpy as jnp from lifelines.fitters import KnownModelParametricUnivariateFitter @@ -77,4 +77,4 @@ def _cumulative_hazard(self, params, times): def _log_hazard(self, params, times): lambda_ = params[0] - return -anp.log(lambda_) + return -jnp.log(lambda_) diff --git a/lifelines/fitters/mixins.py b/lifelines/fitters/mixins.py index caa35628c..eb27de608 100644 --- a/lifelines/fitters/mixins.py +++ b/lifelines/fitters/mixins.py @@ -30,7 +30,7 @@ def check_assumptions( plot_n_bootstraps: int = 15, columns: Optional[List[str]] = None, raise_on_fail: bool = False, - ) -> None: + ) -> list: """ Use this function to test the proportional hazards assumption. See usage example at https://lifelines.readthedocs.io/en/latest/jupyter_notebooks/Proportional%20hazard%20assumption.html diff --git a/lifelines/tests/test_estimation.py b/lifelines/tests/test_estimation.py index b7c9c2969..35450493c 100644 --- a/lifelines/tests/test_estimation.py +++ b/lifelines/tests/test_estimation.py @@ -1149,8 +1149,8 @@ def test_against_reliability_software(self): class TestExponentialFitter: def test_fit_computes_correct_lambda_(self): - T = np.array([10, 10, 10, 10], dtype=float) - E = np.array([1, 1, 1, 0], dtype=float) + T = np.array([10, 20, 10, 10, 5, 10], dtype=float) + E = np.array([1, 1, 1, 0, 0, 1], dtype=float) enf = ExponentialFitter() enf.fit(T, E) assert abs(enf.lambda_ - (T.sum() / E.sum())) < 1e-4 diff --git a/reqs/base-requirements.txt b/reqs/base-requirements.txt index 816cf6453..102f45235 100644 --- a/reqs/base-requirements.txt +++ b/reqs/base-requirements.txt @@ -1,7 +1,6 @@ -numpy>=1.14.0,<2.0 +numpy>=1.14.0 scipy>=1.2.0 pandas>=1.2.0 matplotlib>=3.0 -autograd>=1.5 -autograd-gamma>=0.3 formulaic>=0.2.2 +jax[cpu]>=0.4.0