Skip to content

Commit

Permalink
revert PR for new fix, fix strata labeller to handle np types
Browse files Browse the repository at this point in the history
  • Loading branch information
CamDavidsonPilon committed Oct 29, 2024
1 parent c9d6738 commit 1e0a3d5
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
12 changes: 10 additions & 2 deletions lifelines/fitters/coxph_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2991,7 +2991,11 @@ def __init__(self, strata, strata_values, n_baseline_knots=1, knots=None, *args,

@staticmethod
def _strata_labeler(stratum, i):
return "s%s_phi%d_" % (stratum, i)
try:
return "s%s_phi%d_" % (tuple(str(s) for s in stratum), i)
except:
# singleton
return "s%s_phi%d_" % (stratum, i)

@property
def _fitted_parameter_names(self):
Expand Down Expand Up @@ -3112,7 +3116,11 @@ def __init__(self, strata, strata_values, breakpoints, *args, **kwargs):

@staticmethod
def _strata_labeler(stratum, i):
return "s%s_lambda%d_" % (stratum, i)
try:
return "s%s_lambda%d_" % (tuple(str(s) for s in stratum), i)
except:
# singleton
return "s%s_lambda%d_" % (stratum, i)

@property
def _fitted_parameter_names(self):
Expand Down
8 changes: 4 additions & 4 deletions lifelines/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,14 @@ def print_summary(self, decimals=2, style=None, **kwargs):
"""
if style is not None:
self._print_specific_style(style)
self._print_specific_style(style, decimals=decimals, **kwargs)
else:
try:
from IPython.display import display

display(self)
except ImportError:
self._ascii_print()
self._ascii_print(decimals=decimals, **kwargs)

def _html_print(self, decimals=2, **kwargs):
print(self.to_html(decimals, **kwargs))
Expand Down Expand Up @@ -835,7 +835,7 @@ def multivariate_logrank_test(
assert abs(Z_j.sum()) < 10e-8, "Sum is not zero." # this should move to a test eventually.

# compute covariance matrix
factor = (((n_i - d_i) / (n_i - 1)).replace([np.inf, np.nan], 1)) * d_i / n_i ** 2
factor = (((n_i - d_i) / (n_i - 1)).replace([np.inf, np.nan], 1)) * d_i / n_i**2
n_ij["_"] = n_i.values
V_ = (n_ij.mul(w_i, axis=0)).mul(np.sqrt(factor), axis="index").fillna(0) # weighted V_
V = -np.dot(V_.T, V_)
Expand Down Expand Up @@ -923,7 +923,7 @@ def proportional_hazard_test(
def compute_statistic(times, resids, n_deaths):
demeaned_times = times - times.mean()
T = (demeaned_times.values[:, None] * resids.values).sum(0) ** 2 / (
n_deaths * (fitted_cox_model.standard_errors_ ** 2) * (demeaned_times ** 2).sum()
n_deaths * (fitted_cox_model.standard_errors_**2) * (demeaned_times**2).sum()
)
return T

Expand Down

0 comments on commit 1e0a3d5

Please sign in to comment.