diff --git a/causallib/__init__.py b/causallib/__init__.py
index e94731c0..f8c6ac7f 100644
--- a/causallib/__init__.py
+++ b/causallib/__init__.py
@@ -1 +1 @@
-__version__ = "0.9.4"
+__version__ = "0.9.5"
diff --git a/causallib/contrib/sklearn_scorer_wrapper/__init__.py b/causallib/contrib/sklearn_scorer_wrapper/__init__.py
new file mode 100644
index 00000000..f7b14608
--- /dev/null
+++ b/causallib/contrib/sklearn_scorer_wrapper/__init__.py
@@ -0,0 +1 @@
+from .sklearn_scorer_wrapper import SKLearnScorerWrapper
diff --git a/causallib/contrib/sklearn_scorer_wrapper/sklearn_scorer_wrapper.py b/causallib/contrib/sklearn_scorer_wrapper/sklearn_scorer_wrapper.py
new file mode 100644
index 00000000..3bf25210
--- /dev/null
+++ b/causallib/contrib/sklearn_scorer_wrapper/sklearn_scorer_wrapper.py
@@ -0,0 +1,28 @@
+from causallib.metrics.scorers import PropensityScorerBase
+
+
+class SKLearnScorerWrapper(PropensityScorerBase):
+    def __init__(self, score_func, sign=None, **kwargs):
+        super().__init__(
+            score_func=score_func,
+            sign=1,  # This keeps original scorer sign
+            **kwargs
+        )
+
+    def _score(self, estimator, X, a, y=None, sample_weight=None, **kwargs):
+        learner = self._extract_sklearn_estimator(estimator)
+        score = self._score_func(learner, X, a, sample_weight=sample_weight)
+        return score
+
+    @staticmethod
+    def _extract_sklearn_estimator(estimator):
+        if hasattr(estimator, "best_estimator_"):
+            # Causallib's wrapper around GridSearchCV
+            return estimator.best_estimator_.learner
+        if hasattr(estimator, "learner"):
+            return estimator.learner
+        raise AttributeError(
+            f"Could not extract an sklearn estimator from {estimator}, "
+            f"which has the following attributes:\n"
+            f"{list(estimator.__dict__.keys())}"
+        )
\ No newline at end of file
diff --git a/causallib/contrib/tests/test_sklearn_scorer_wrapper.py b/causallib/contrib/tests/test_sklearn_scorer_wrapper.py
new file mode 100644
index 00000000..da12e056
--- /dev/null
+++ b/causallib/contrib/tests/test_sklearn_scorer_wrapper.py
@@ -0,0 +1,75 @@
+import unittest
+
+import pandas as pd
+
+from sklearn.linear_model import LogisticRegression
+from sklearn.datasets import make_classification
+from sklearn.utils import Bunch
+from sklearn.metrics import get_scorer
+
+from causallib.estimation import IPW
+from causallib.model_selection import GridSearchCV
+
+from causallib.contrib.sklearn_scorer_wrapper import SKLearnScorerWrapper
+
+
+class TestSKLearnScorerWrapper(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        N = 500
+        X, a = make_classification(
+            n_samples=N,
+            n_features=5,
+            n_informative=5,
+            n_redundant=0,
+            random_state=42,
+        )
+        X = pd.DataFrame(X)
+        a = pd.Series(a)
+        cls.data = Bunch(X=X, a=a, y=a)
+
+        learner = LogisticRegression()
+        ipw = IPW(learner)
+        ipw.fit(X, a)
+        # cls.learner = learner
+        cls.estimator = ipw
+
+    def test_agreement_with_sklearn(self):
+        scorer_names = [
+            "accuracy",
+            "average_precision",
+            "neg_brier_score",
+            "f1",
+            "neg_log_loss",
+            "precision",
+            "recall",
+            "roc_auc",
+        ]
+        for scorer_name in scorer_names:
+            with self.subTest(f"Test scorer {scorer_name}"):
+                scorer = get_scorer(scorer_name)
+                score = scorer(self.estimator.learner, self.data.X, self.data.a)
+
+                causallib_adapted_scorer = SKLearnScorerWrapper(scorer)
+                causallib_score = causallib_adapted_scorer(
+                    self.estimator, self.data.X, self.data.a, self.data.y
+                )
+
+                self.assertAlmostEqual(causallib_score, score)
+
+    def test_hyperparameter_search_model(self):
+        scorer = SKLearnScorerWrapper(get_scorer("roc_auc"))
+        param_grid = dict(
+            clip_min=[0.2, 0.3],
+            learner__C=[0.1, 1],
+        )
+        model = GridSearchCV(
+            self.estimator,
+            param_grid=param_grid,
+            scoring=scorer,
+            cv=3,
+        )
+        model.fit(self.data.X, self.data.a, self.data.y)
+
+        score = scorer(model, self.data.X, self.data.a, self.data.y)
+        self.assertGreaterEqual(score, model.best_score_)
diff --git a/causallib/datasets/data_loader.py b/causallib/datasets/data_loader.py
index 14295565..a12f87c5 100644
--- a/causallib/datasets/data_loader.py
+++ b/causallib/datasets/data_loader.py
@@ -128,6 +128,7 @@ def load_nhefs_survival(augment=True, onehot=True):
     nhefs_all = load_nhefs(raw=True)[0]
     t = (nhefs_all["yrdth"] - 83) * 12 + nhefs_all["modth"]
     t = t.fillna(120)
+    t = t.rename("longevity")
     y = nhefs_all["death"]
 
     nhefs = load_nhefs(augment=augment, onehot=onehot, restrict=False)
diff --git a/causallib/metrics/__init__.py b/causallib/metrics/__init__.py
index 5ac08697..86fe5f3e 100644
--- a/causallib/metrics/__init__.py
+++ b/causallib/metrics/__init__.py
@@ -2,6 +2,7 @@
 from .propensity_metrics import weighted_roc_curve_error, expected_roc_curve_error
 from .propensity_metrics import ici_error
 from .weight_metrics import covariate_balancing_error
+from .weight_metrics import covariate_imbalance_count_error
 from .outcome_metrics import balanced_residuals_error
 
 from .scorers import get_scorer, get_scorer_names
diff --git a/causallib/metrics/outcome_metrics.py b/causallib/metrics/outcome_metrics.py
index dc21f484..44128fcb 100644
--- a/causallib/metrics/outcome_metrics.py
+++ b/causallib/metrics/outcome_metrics.py
@@ -45,6 +45,7 @@ def balanced_residuals_error(
     y_true, y_pred, a_true,
     distance_metric=abs_standardized_mean_difference,
     distance_metric_kwargs=None,
+    **kwargs,
 ):
     """Computes how different is the residuals distribution of the control group
     from that of the treatment group.
diff --git a/causallib/metrics/propensity_metrics.py b/causallib/metrics/propensity_metrics.py
index e465d136..898be3f1 100644
--- a/causallib/metrics/propensity_metrics.py
+++ b/causallib/metrics/propensity_metrics.py
@@ -14,7 +14,7 @@
 import statsmodels.api as sm
 
 
-def weighted_roc_auc_error(y_true, y_pred, sample_weight):
+def weighted_roc_auc_error(y_true, y_pred, sample_weight, **kwargs):
     """Compute the squared error between
     the balanced (e.g. IP-weighted) ROC AUC to the diagonal, i.e. AUC=0.5.
@@ -64,7 +64,7 @@ def expected_roc_auc_error(y_true, y_pred, **kwargs):
     return score
 
 
-def weighted_roc_curve_error(y_true, y_pred, sample_weight, agg=np.max):
+def weighted_roc_curve_error(y_true, y_pred, sample_weight, agg=np.max, **kwargs):
     """Compute the absolute differences between
     the balanced (e.g. IP-weighted) ROC curve and the diagonal x=y curve.
     Since difference in curves results in a multiple values (each point along the curve),
diff --git a/causallib/metrics/scorers.py b/causallib/metrics/scorers.py
index 04596c32..5923e4b3 100644
--- a/causallib/metrics/scorers.py
+++ b/causallib/metrics/scorers.py
@@ -92,8 +92,13 @@ def _score(self, estimator, X, a, y, sample_weight=None, **kwargs):
 
 
 covariate_balancing_error_scorer = WeightScorerBase(
     weight_metrics.covariate_balancing_error, -1,
 )
+covariate_imbalance_count_error_scorer = WeightScorerBase(
+    weight_metrics.covariate_imbalance_count_error, -1,
+)
+
 _WEIGHT_SCORERS = dict(
     covariate_balancing_error=covariate_balancing_error_scorer,
+    covariate_imbalance_count_error=covariate_imbalance_count_error_scorer,
 )
diff --git a/causallib/metrics/weight_metrics.py b/causallib/metrics/weight_metrics.py
index f152c081..e151bce6 100644
--- a/causallib/metrics/weight_metrics.py
+++ b/causallib/metrics/weight_metrics.py
@@ -101,8 +101,7 @@ calculate_distribution_distance_for_single_feature(
         return distribution_distance
 
-
-def covariate_balancing_error(X, a, sample_weight, agg=max):
+def covariate_balancing_error(X, a, sample_weight, agg=max, **kwargs):
     """Computes the weighted (i.e. balanced) absolute standardized mean difference
     of every covariate in X.
@@ -120,3 +119,15 @@ def covariate_balancing_error(X, a, sample_weight, agg=max):
     weighted_asmds = asmds["weighted"]
     score = agg(weighted_asmds)
     return score
+
+
+def covariate_imbalance_count_error(
+    X, a, sample_weight, threshold=0.1, fraction=True
+) -> float:
+    asmds = calculate_covariate_balance(X, a, sample_weight, metric="abs_smd")
+    weighted_asmds = asmds["weighted"]
+    is_violating = weighted_asmds > threshold
+    score = sum(is_violating)
+    if fraction:
+        score /= is_violating.shape[0]
+    return score
diff --git a/causallib/tests/test_metrics.py b/causallib/tests/test_metrics.py
index c7efe4f2..2a57418f 100644
--- a/causallib/tests/test_metrics.py
+++ b/causallib/tests/test_metrics.py
@@ -8,6 +8,7 @@
 from causallib.metrics import weighted_roc_curve_error, expected_roc_curve_error
 from causallib.metrics import ici_error
 from causallib.metrics import covariate_balancing_error
+from causallib.metrics import covariate_imbalance_count_error
 from causallib.metrics import balanced_residuals_error
 
 
@@ -275,6 +276,28 @@ def test_covariate_balancing(self):
         expected /= 2  # Two features, the second has 0 ASMD
         self.assertAlmostEqual(score, expected, places=4)
 
+    def test_covariate_imbalance_count(self):
+        with self.subTest("High violation threshold"):
+            score = covariate_imbalance_count_error(
+                self.data["X"], self.data["a"], self.data["w"],
+                threshold=10,
+            )
+            self.assertEqual(score, 0)
+
+        with self.subTest("Low violation threshold"):
+            score = covariate_imbalance_count_error(
+                self.data["X"], self.data["a"], self.data["w"],
+                threshold=-0.1, fraction=False,
+            )
+            self.assertEqual(score, self.data["X"].shape[1])
+
+        with self.subTest("Fraction violation threshold"):
+            score = covariate_imbalance_count_error(
+                self.data["X"], self.data["a"], self.data["w"],
+                threshold=0.1, fraction=True,
+            )
+            self.assertEqual(score, 1/2)
+
 
 class TestOutcomeMetrics(unittest.TestCase):
     def test_balanced_residuals(self):
diff --git a/causallib/tests/test_scorers.py b/causallib/tests/test_scorers.py
index 0cab497a..e36cea7c 100644
--- a/causallib/tests/test_scorers.py
+++ b/causallib/tests/test_scorers.py
@@ -141,6 +141,33 @@ def test_scoring_with_kwargs(self):
         )
         self.assertLess(-score_agg_min, -score_default)  # Default is max. Scores are negative metrics values
+    def test_covariate_imbalance_count_error(self):
+        X = pd.DataFrame(
+            {
+                "imbalanced": [5, 5, 5, 5, 4, 6, 0, 0, 0, 0, -1, 1],
+                "balanced": [5, 5, 5, 5, 4, 6, 5, 5, 5, 5, 4, 6],
+            }
+        )
+        a = pd.Series([1] * 6 + [0] * 6)
+        ipw = IPW(LogisticRegression())
+        ipw.fit(X, a)
+
+        with self.subTest("Count score"):
+            scorer = get_scorer("covariate_imbalance_count_error")
+            score = scorer(ipw, X, a, y_true=None, fraction=False)
+            self.assertEqual(score, -1)
+
+        with self.subTest("Fractional score"):
+            scorer = get_scorer("covariate_imbalance_count_error")
+            score = scorer(ipw, X, a, y_true=None, fraction=True)
+            self.assertEqual(score, -0.5)
+
+        with self.subTest("Non-default threshold"):
+            threshold = 10  # Should result in no violating features
+            scorer = get_scorer("covariate_imbalance_count_error")
+            score = scorer(ipw, X, a, y_true=None, threshold=threshold)
+            self.assertEqual(score, 0)
+
 
 class TestOutcomeScorer(BaseTestScorer):
     @classmethod
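
Usage sketch for the new contrib SKLearnScorerWrapper, adapted from test_hyperparameter_search_model above; the synthetic data and parameter grid are illustrative assumptions, not part of the patch itself.

import pandas as pd
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import get_scorer

from causallib.estimation import IPW
from causallib.model_selection import GridSearchCV
from causallib.contrib.sklearn_scorer_wrapper import SKLearnScorerWrapper

# Simulated covariates X and binary treatment assignment a
X, a = make_classification(n_samples=500, n_features=5, n_informative=5,
                           n_redundant=0, random_state=42)
X, a = pd.DataFrame(X), pd.Series(a)

# Wrap a stock sklearn scorer so it can score causallib's IPW estimator
roc_auc = SKLearnScorerWrapper(get_scorer("roc_auc"))

model = GridSearchCV(
    IPW(LogisticRegression()),
    param_grid={"clip_min": [0.2, 0.3], "learner__C": [0.1, 1]},
    scoring=roc_auc,
    cv=3,
)
model.fit(X, a, a)  # causallib's GridSearchCV.fit takes (X, a, y); here y is the treatment
print(roc_auc(model, X, a, a))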
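A second illustrative sketch: calling the new covariate_imbalance_count_error metric directly and through its registered scorer. The toy frame follows the added test in causallib/tests/test_scorers.py; obtaining the weights via IPW.compute_weights is an assumption about the surrounding causallib API.

import pandas as pd
from sklearn.linear_model import LogisticRegression

from causallib.estimation import IPW
from causallib.metrics import covariate_imbalance_count_error, get_scorer

X = pd.DataFrame({
    "imbalanced": [5, 5, 5, 5, 4, 6, 0, 0, 0, 0, -1, 1],
    "balanced": [5, 5, 5, 5, 4, 6, 5, 5, 5, 5, 4, 6],
})
a = pd.Series([1] * 6 + [0] * 6)

ipw = IPW(LogisticRegression())
ipw.fit(X, a)
w = ipw.compute_weights(X, a)  # assumed way to obtain the IP weights

# Direct metric: fraction of covariates whose weighted |SMD| exceeds `threshold`
# (0.5 for this toy data, per the added test)
print(covariate_imbalance_count_error(X, a, sample_weight=w, threshold=0.1, fraction=True))

# Scorer interface: same quantity with a negative sign, so higher is better
scorer = get_scorer("covariate_imbalance_count_error")
print(scorer(ipw, X, a, y_true=None, fraction=False))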