Skip to content

Commit

Permalink
Imbalanced covariates metrics and sklearn scorer wrapper (#59)
Browse files Browse the repository at this point in the history
* Add arbitrary `kwargs` to metrics to more easily align signatures

Signed-off-by: Ehud-Karavani <[email protected]>

* Add count/fraction of imbalanced covariates metric+scorer

Signed-off-by: Ehud-Karavani <[email protected]>

* Add scikit-learn scorer wrapper for propensity models

Signed-off-by: Ehud-Karavani <[email protected]>

* Add name to time-variable (pd.Series) in NHEFS survival data

Signed-off-by: Ehud-Karavani <[email protected]>

* Bump version: 0.9.5

Signed-off-by: Ehud-Karavani <[email protected]>

---------

Signed-off-by: Ehud-Karavani <[email protected]>
  • Loading branch information
ehudkr authored Jun 22, 2023
1 parent 9aea4dc commit a494221
Show file tree
Hide file tree
Showing 12 changed files with 178 additions and 5 deletions.
2 changes: 1 addition & 1 deletion causallib/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.9.4"
__version__ = "0.9.5"
1 change: 1 addition & 0 deletions causallib/contrib/sklearn_scorer_wrapper/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .sklearn_scorer_wrapper import SKLearnScorerWrapper
28 changes: 28 additions & 0 deletions causallib/contrib/sklearn_scorer_wrapper/sklearn_scorer_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from causallib.metrics.scorers import PropensityScorerBase


class SKLearnScorerWrapper(PropensityScorerBase):
def __init__(self, score_func, sign=None, **kwargs):
super().__init__(
score_func=score_func,
sign=1, # This keeps original scorer sign
**kwargs
)

def _score(self, estimator, X, a, y=None, sample_weight=None, **kwargs):
learner = self._extract_sklearn_estimator(estimator)
score = self._score_func(learner, X, a, sample_weight=sample_weight)
return score

@staticmethod
def _extract_sklearn_estimator(estimator):
if hasattr(estimator, "best_estimator_"):
# Causallib's wrapper around GridSearchCV
return estimator.best_estimator_.learner
if hasattr(estimator, "learner"):
return estimator.learner
raise AttributeError(
f"Could not extract an sklearn estimator from {estimator},"
f"which has the following attributes:\n"
f"{list(estimator.__dict__.keys())}"
)
75 changes: 75 additions & 0 deletions causallib/contrib/tests/test_sklearn_scorer_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import unittest

import pandas as pd

from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification
from sklearn.utils import Bunch
from sklearn.metrics import get_scorer

from causallib.estimation import IPW
from causallib.model_selection import GridSearchCV

from causallib.contrib.sklearn_scorer_wrapper import SKLearnScorerWrapper


class TestSKLearnScorerWrapper(unittest.TestCase):
@classmethod
def setUpClass(cls):
N = 500
X, a = make_classification(
n_samples=N,
n_features=5,
n_informative=5,
n_redundant=0,
random_state=42,
)
X = pd.DataFrame(X)
a = pd.Series(a)
cls.data = Bunch(X=X, a=a, y=a)

learner = LogisticRegression()
ipw = IPW(learner)
ipw.fit(X, a)
# cls.learner = learner
cls.estimator = ipw

def test_agreement_with_sklearn(self):
scorer_names = [
"accuracy",
"average_precision",
"neg_brier_score",
"f1",
"neg_log_loss",
"precision",
"recall",
"roc_auc",
]
for scorer_name in scorer_names:
with self.subTest(f"Test scorer {scorer_name}"):
scorer = get_scorer(scorer_name)
score = scorer(self.estimator.learner, self.data.X, self.data.a)

causallib_adapted_scorer = SKLearnScorerWrapper(scorer)
causallib_score = causallib_adapted_scorer(
self.estimator, self.data.X, self.data.a, self.data.y
)

self.assertAlmostEqual(causallib_score, score)

def test_hyperparameter_search_model(self):
scorer = SKLearnScorerWrapper(get_scorer("roc_auc"))
param_grid = dict(
clip_min=[0.2, 0.3],
learner__C=[0.1, 1],
)
model = GridSearchCV(
self.estimator,
param_grid=param_grid,
scoring=scorer,
cv=3,
)
model.fit(self.data.X, self.data.a, self.data.y)

score = scorer(model, self.data.X, self.data.a, self.data.y)
self.assertGreaterEqual(score, model.best_score_)
1 change: 1 addition & 0 deletions causallib/datasets/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def load_nhefs_survival(augment=True, onehot=True):
nhefs_all = load_nhefs(raw=True)[0]
t = (nhefs_all["yrdth"] - 83) * 12 + nhefs_all["modth"]
t = t.fillna(120)
t = t.rename("longevity")
y = nhefs_all["death"]

nhefs = load_nhefs(augment=augment, onehot=onehot, restrict=False)
Expand Down
1 change: 1 addition & 0 deletions causallib/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .propensity_metrics import weighted_roc_curve_error, expected_roc_curve_error
from .propensity_metrics import ici_error
from .weight_metrics import covariate_balancing_error
from .weight_metrics import covariate_imbalance_count_error
from .outcome_metrics import balanced_residuals_error

from .scorers import get_scorer, get_scorer_names
1 change: 1 addition & 0 deletions causallib/metrics/outcome_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def balanced_residuals_error(
y_true, y_pred, a_true,
distance_metric=abs_standardized_mean_difference,
distance_metric_kwargs=None,
**kwargs,
):
"""Computes how different is the residuals distribution of the control group
from that of the treatment group.
Expand Down
4 changes: 2 additions & 2 deletions causallib/metrics/propensity_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import statsmodels.api as sm


def weighted_roc_auc_error(y_true, y_pred, sample_weight):
def weighted_roc_auc_error(y_true, y_pred, sample_weight, **kwargs):
"""
Compute the squared error between the balanced (e.g. IP-weighted) ROC AUC
to the diagonal, i.e. AUC=0.5.
Expand Down Expand Up @@ -64,7 +64,7 @@ def expected_roc_auc_error(y_true, y_pred, **kwargs):
return score


def weighted_roc_curve_error(y_true, y_pred, sample_weight, agg=np.max):
def weighted_roc_curve_error(y_true, y_pred, sample_weight, agg=np.max, **kwargs):
"""Compute the absolute differences between the balanced (e.g. IP-weighted) ROC curve
and the diagonal x=y curve.
Since difference in curves results in a multiple values (each point along the curve),
Expand Down
5 changes: 5 additions & 0 deletions causallib/metrics/scorers.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,13 @@ def _score(self, estimator, X, a, y, sample_weight=None, **kwargs):
weight_metrics.covariate_balancing_error, -1,
)

covariate_imbalance_count_error_scorer = WeightScorerBase(
weight_metrics.covariate_imbalance_count_error, -1,
)

_WEIGHT_SCORERS = dict(
covariate_balancing_error=covariate_balancing_error_scorer,
covariate_imbalance_count_error=covariate_imbalance_count_error_scorer,
)


Expand Down
15 changes: 13 additions & 2 deletions causallib/metrics/weight_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@ def calculate_distribution_distance_for_single_feature(
return distribution_distance



def covariate_balancing_error(X, a, sample_weight, agg=max):
def covariate_balancing_error(X, a, sample_weight, agg=max, **kwargs):
"""Computes the weighted (i.e. balanced) absolute standardized mean difference
of every covariate in X.
Expand All @@ -120,3 +119,15 @@ def covariate_balancing_error(X, a, sample_weight, agg=max):
weighted_asmds = asmds["weighted"]
score = agg(weighted_asmds)
return score


def covariate_imbalance_count_error(
X, a, sample_weight, threshold=0.1, fraction=True
) -> float:
asmds = calculate_covariate_balance(X, a, sample_weight, metric="abs_smd")
weighted_asmds = asmds["weighted"]
is_violating = weighted_asmds > threshold
score = sum(is_violating)
if fraction:
score /= is_violating.shape[0]
return score
23 changes: 23 additions & 0 deletions causallib/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from causallib.metrics import weighted_roc_curve_error, expected_roc_curve_error
from causallib.metrics import ici_error
from causallib.metrics import covariate_balancing_error
from causallib.metrics import covariate_imbalance_count_error
from causallib.metrics import balanced_residuals_error


Expand Down Expand Up @@ -275,6 +276,28 @@ def test_covariate_balancing(self):
expected /= 2 # Two features, the second has 0 ASMD
self.assertAlmostEqual(score, expected, places=4)

def test_covariate_imbalance_count(self):
with self.subTest("High violation threshold"):
score = covariate_imbalance_count_error(
self.data["X"], self.data["a"], self.data["w"],
threshold=10,
)
self.assertEqual(score, 0)

with self.subTest("Low violation threshold"):
score = covariate_imbalance_count_error(
self.data["X"], self.data["a"], self.data["w"],
threshold=-0.1, fraction=False,
)
self.assertEqual(score, self.data["X"].shape[1])

with self.subTest("Fraction violation threshold"):
score = covariate_imbalance_count_error(
self.data["X"], self.data["a"], self.data["w"],
threshold=0.1, fraction=True,
)
self.assertEqual(score, 1/2)


class TestOutcomeMetrics(unittest.TestCase):
def test_balanced_residuals(self):
Expand Down
27 changes: 27 additions & 0 deletions causallib/tests/test_scorers.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,33 @@ def test_scoring_with_kwargs(self):
)
self.assertLess(-score_agg_min, -score_default) # Default is max. Scores are negative metrics values

def test_covariate_imbalance_count_error(self):
X = pd.DataFrame(
{
"imbalanced": [5, 5, 5, 5, 4, 6, 0, 0, 0, 0, -1, 1],
"balanced": [5, 5, 5, 5, 4, 6, 5, 5, 5, 5, 4, 6],
}
)
a = pd.Series([1] * 6 + [0] * 6)
ipw = IPW(LogisticRegression())
ipw.fit(X, a)

with self.subTest("Count score"):
scorer = get_scorer("covariate_imbalance_count_error")
score = scorer(ipw, X, a, y_true=None, fraction=False)
self.assertEqual(score, -1)

with self.subTest("Fractional score"):
scorer = get_scorer("covariate_imbalance_count_error")
score = scorer(ipw, X, a, y_true=None, fraction=True)
self.assertEqual(score, -0.5)

with self.subTest("Non-default threshold"):
threshold = 10 # Should result in not violating features
scorer = get_scorer("covariate_imbalance_count_error")
score = scorer(ipw, X, a, y_true=None, threshold=threshold)
self.assertEqual(score, 0)


class TestOutcomeScorer(BaseTestScorer):
@classmethod
Expand Down

0 comments on commit a494221

Please sign in to comment.