❗ This repository is heavily based on the TableShift GitHub repository. Our code builds on TableShift and on code from Ricardo Sandoval and Hardt & Kim (2023).
This repository contains code to reproduce the experiments in the paper:
Vivian Y. Nastl and Moritz Hardt. "Do causal predictors generalize better to new domains?", 2024.
Clone the repository, enter its root directory, and create a local conda environment:

```bash
git clone https://github.com/socialfoundations/causal-features.git
cd causal-features
# set up the environment
conda env create -f environment.yml
```
Run the following commands to test the local execution environment:

```bash
conda activate tableshift
# test the install by running the training script
python examples/run_expt.py
```
The final command above prints detailed logging output as the script executes. When you see `training completed! test accuracy: 0.6221`, your environment is ready to go. (Accuracy may vary slightly due to randomness.)
The training script we run is located at `experiments_causal/run_experiment.py`. It takes the following arguments:
- `experiment` (experiment to run)
- `model` (model to use)
- `cache_dir` (directory to cache raw data files to)
- `save_dir` (directory to save result files to)
The full list of model names is given below. For more details on each algorithm, see TableShift.
Model | Name in TableShift |
---|---|
XGBoost | xgb |
LightGBM | lightgbm |
SAINT | saint |
NODE | node |
Group DRO | group_dro |
MLP | mlp |
Tabular ResNet | resnet |
Adversarial Label DRO | aldro |
CORAL | deepcoral |
MMD | mmd |
DRO | dro |
DANN | dann |
TabTransformer | tabtransformer |
MixUp | mixup |
Label Group DRO | label_group_dro |
IRM | irm |
VREX | vrex |
FT-Transformer | ft_transformer |
IB-IRM | ib_irm |
CausIRL CORAL | causirl_coral |
CausIRL MMD | causirl_mmd |
AND-Mask | and_mask |
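As a rough illustration, the snippet below assembles one call to `experiments_causal/run_experiment.py` and checks the model name against the table above. Only the `--experiment` flag is documented in this README; the flags `--model`, `--cache_dir` and `--save_dir` are an assumption that they mirror the argument names listed earlier, and `build_command` is a hypothetical helper, not part of the repository.

```python
import shlex
import sys

# Model identifiers from the "Name in TableShift" column above.
MODEL_NAMES = {
    "xgb", "lightgbm", "saint", "node", "group_dro", "mlp", "resnet",
    "aldro", "deepcoral", "mmd", "dro", "dann", "tabtransformer", "mixup",
    "label_group_dro", "irm", "vrex", "ft_transformer", "ib_irm",
    "causirl_coral", "causirl_mmd", "and_mask",
}

SCRIPT = "experiments_causal/run_experiment.py"


def build_command(experiment, model, cache_dir, save_dir):
    """Return the argv list for one training run (hypothetical helper).

    ASSUMPTION: only --experiment is documented; --model, --cache_dir and
    --save_dir are assumed to mirror the argument names of the script.
    """
    if model not in MODEL_NAMES:
        raise ValueError(f"unknown model name: {model!r}")
    return [
        sys.executable, SCRIPT,
        "--experiment", experiment,
        "--model", model,
        "--cache_dir", cache_dir,
        "--save_dir", save_dir,
    ]


cmd = build_command("acsincome_causal", "xgb", "tmp/cache", "tmp/results")
print(shlex.join(cmd))
```

Under these assumptions, passing the returned list to `subprocess.run(cmd, check=True)` from the repository root would start a single training run.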
All experiments were run as jobs submitted to a centralized cluster running the open-source HTCondor scheduler. The script that launches the jobs is located at `experiments_causal/launch_experiments.py`.
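The repository's launch script is not reproduced here. As a hypothetical sketch of the same idea, the code below writes one HTCondor submit description per (experiment, model) pair. The experiment grid, the resource requests, the log paths, the shared-filesystem setup (`should_transfer_files = NO`) and all flags other than `--experiment` are placeholders or assumptions, not the configuration used for the paper.

```python
import sys
from itertools import product
from pathlib import Path

# HYPOTHETICAL grid; not the configuration used for the paper.
EXPERIMENTS = ["acsincome", "acsincome_causal", "acsincome_arguablycausal"]
MODELS = ["xgb", "mlp"]


def submit_description(experiment, model, cache_dir="tmp/cache", save_dir="tmp/results"):
    """Return an HTCondor submit description for one run.

    ASSUMPTIONS: a shared filesystem (no file transfer), placeholder
    resource requests, and flags other than --experiment that mirror
    the argument names of run_experiment.py.
    """
    name = f"{experiment}__{model}"
    args = (
        f"experiments_causal/run_experiment.py --experiment {experiment} "
        f"--model {model} --cache_dir {cache_dir} --save_dir {save_dir}"
    )
    lines = [
        f"executable = {sys.executable}",  # Python interpreter of the current env
        f'arguments = "{args}"',
        "should_transfer_files = NO",      # job runs in the submit directory
        f"output = logs/{name}.out",
        f"error = logs/{name}.err",
        f"log = logs/{name}.log",
        "request_cpus = 4",                # placeholder
        "request_memory = 16GB",           # placeholder
        "queue",
    ]
    return "\n".join(lines) + "\n"


def write_submit_files(out_dir="condor_jobs"):
    """Write one .sub file per (experiment, model) pair and return the paths."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    paths = []
    for experiment, model in product(EXPERIMENTS, MODELS):
        path = out / f"{experiment}__{model}.sub"
        path.write_text(submit_description(experiment, model))
        paths.append(path)
    return paths
```

Each file could then be submitted with `condor_submit <file>` from the repository root; the `logs/` directory referenced in the descriptions has to exist there beforehand.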
We provide the raw results of our experiments in the folder `experiments_causal/results/`. It contains a single JSON file for each combination of task, feature selection, and trained model.
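The snippet below is a minimal sketch for collecting these files with the standard library. The JSON schema is not described in this README, so the hypothetical helper `load_results` returns each file's parsed content unchanged, keyed by its path relative to the results folder.

```python
import json
from pathlib import Path


def load_results(results_dir="experiments_causal/results"):
    """Parse every JSON file below results_dir (hypothetical helper).

    Returns a dict mapping each file's path relative to results_dir to
    its parsed content; no particular schema is assumed.
    """
    root = Path(results_dir)
    results = {}
    # rglob also descends into subfolders, in case results are nested.
    for path in sorted(root.rglob("*.json")):
        with path.open(encoding="utf-8") as f:
            results[path.relative_to(root).as_posix()] = json.load(f)
    return results
```

Once the field names are known, the returned values can be filtered or turned into a table as needed.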
Use the following Python scripts and notebook to create the figures in the paper:
- Main result:
  - Figure in introduction: `experiments_causal/plot_paper_introduction_figure.py`
  - Figures in section "Empirical results": `experiments_causal/plot_paper_figures.py`
- Appendix:
  - Main results: `experiments_causal/plot_paper_appendix_figures.py`, `experiments_causal/plot_paper_appendix_figures_extra.py`, `experiments_causal/plot_paper_appendix_figures_extra2.py`
  - Anti-causal features: `experiments_causal/plot_paper_appendix_figures.py`
  - Causal machine learning: `experiments_causal/plot_add_on_causalml.py`
  - Causal discovery: `experiments_causal/plot_add_on_causal_discovery.py`
  - Random subsets: `experiments_causal/plot_add_on_random_subsets.py`
  - Ablation study: `experiments_causal/plot_experiment_ablation.py`
  - Empirical results across machine learning models: `experiments_causal/plot_add_on_models.py`
  - Synthetic experiments: `experiments_causal/synthetic_experiments.ipynb`
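To run everything in one go, a hypothetical driver like the following could be used. It assumes that each script runs without arguments from the repository root and that Jupyter's `nbconvert` is installed to execute the notebook; neither assumption is stated in this README. By default it only prints the commands.

```python
import subprocess
import sys

# Scripts from the list above; plot_paper_appendix_figures.py appears twice
# in that list but is only run once here.
PLOT_SCRIPTS = [
    "experiments_causal/plot_paper_introduction_figure.py",
    "experiments_causal/plot_paper_figures.py",
    "experiments_causal/plot_paper_appendix_figures.py",
    "experiments_causal/plot_paper_appendix_figures_extra.py",
    "experiments_causal/plot_paper_appendix_figures_extra2.py",
    "experiments_causal/plot_add_on_causalml.py",
    "experiments_causal/plot_add_on_causal_discovery.py",
    "experiments_causal/plot_add_on_random_subsets.py",
    "experiments_causal/plot_experiment_ablation.py",
    "experiments_causal/plot_add_on_models.py",
]
NOTEBOOK = "experiments_causal/synthetic_experiments.ipynb"


def plot_commands():
    """Return one argv list per script, plus one that executes the notebook."""
    cmds = [[sys.executable, script] for script in PLOT_SCRIPTS]
    # ASSUMPTION: Jupyter (nbconvert) is installed in the environment.
    cmds.append(["jupyter", "nbconvert", "--to", "notebook", "--execute", NOTEBOOK])
    return cmds


def run_all(dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in plot_commands():
        print(" ".join(cmd))
        if not dry_run:
            # ASSUMPTION: scripts take no arguments and run from the repo root.
            subprocess.run(cmd, check=True)


run_all(dry_run=True)
```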
The datasets in our paper are either publicly available or provide open credentialized access. The datasets with open credentialized access require signing a data use agreement. For the tasks ICU Mortality and ICU Length of Stay, you must also complete the CITI training "Data or Specimens Only Research", as these datasets contain sensitive personal information. Hence, the credentialized datasets must be fetched manually and stored locally.
Below is a list of the datasets, their string identifiers in our code, and the corresponding access levels. The string identifier is the value that should be passed to the `--experiment` flag of `experiments_causal/run_experiment.py`.
The causal, arguably causal, and anti-causal feature sets are obtained by appending `_causal`, `_arguablycausal`, and `_anticausal` to the string identifier. The combined causal and anti-causal features use the suffix `_causal_anticausal`. Where they exist, the estimated parents from causal discovery algorithms are obtained by appending the algorithm's abbreviation in lowercase, for example `acsincome_pc`. Random subsets are indexed from 0 to 500 and are selected via the suffix `_random_test_{index}`.
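The naming scheme above can be captured in a few helper functions. This is a hypothetical sketch that only builds strings: it does not check whether a given feature selection or causal-discovery result actually exists for a task, and it assumes that "from 0 to 500" is inclusive.

```python
# Feature-set suffixes described above (appended as "_<suffix>").
FEATURE_SETS = ("causal", "arguablycausal", "anticausal", "causal_anticausal")
MAX_RANDOM_INDEX = 500  # ASSUMPTION: "from 0 to 500" is read as inclusive.


def feature_set_id(task, feature_set=None):
    """Bare task identifier, or the identifier with a feature-set suffix."""
    if feature_set is None:
        return task
    if feature_set not in FEATURE_SETS:
        raise ValueError(f"unknown feature set: {feature_set!r}")
    return f"{task}_{feature_set}"


def discovery_id(task, algorithm):
    """Estimated parents from a causal discovery algorithm, e.g. 'PC' -> '_pc'.

    Does not check whether results exist for this task/algorithm pair.
    """
    return f"{task}_{algorithm.lower()}"


def random_subset_id(task, index):
    """Identifier of random feature subset number `index`."""
    if not isinstance(index, int) or not 0 <= index <= MAX_RANDOM_INDEX:
        raise ValueError(f"random subset index out of range: {index!r}")
    return f"{task}_random_test_{index}"


print(feature_set_id("acsincome", "causal"))  # acsincome_causal
print(discovery_id("acsincome", "PC"))        # acsincome_pc
print(random_subset_id("acsincome", 42))      # acsincome_random_test_42
```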
Tasks | String Identifier | Availability | Source | Preprocessing |
---|---|---|---|---|
Voting | anes | Public Credentialized Access (source) | American National Election Studies (ANES) | TableShift |
ASSISTments | assistments | Public | Kaggle | TableShift |
Childhood Lead | nhanes_lead | Public | National Health and Nutrition Examination Survey (NHANES) | TableShift |
College Scorecard | college_scorecard | Public | College Scorecard | TableShift |
Diabetes | brfss_diabetes | Public | Behavioral Risk Factor Surveillance System (BRFSS) | TableShift |
Food Stamps | acsfoodstamps | Public | American Community Survey (via folktables) | |
Hospital Readmission | diabetes_readmission | Public | UCI | TableShift |
Hypertension | brfss_blood_pressure | Public | Behavioral Risk Factor Surveillance System (BRFSS) | TableShift |
ICU Length of Stay | mimic_extract_los_3 | Public Credentialized Access (source) | MIMIC-III via MIMIC-Extract | TableShift |
ICU Mortality | mimic_extract_mort_hosp | Public Credentialized Access (source) | MIMIC-III via MIMIC-Extract | TableShift |
Income | acsincome | Public | American Community Survey (via folktables) | TableShift |
Public Health Insurance | acspubcov | Public | American Community Survey (via folktables) | TableShift |
Sepsis | physionet | Public | PhysioNet | TableShift |
Unemployment | acsunemployment | Public | American Community Survey (via folktables) | TableShift |
Utilization | meps | Public (source) | Medical Expenditure Panel Survey | Hardt & Kim (2023) |
Poverty | sipp | Public (source, source) | Survey of Income and Program Participation | Hardt & Kim (2023) |
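Because several string identifiers themselves contain underscores (e.g. `mimic_extract_los_3`), splitting an experiment name back into task and feature-set suffix is easiest by matching the longest known task prefix. The sketch below does this with the identifiers copied from the table; `split_experiment` and `is_credentialized` are hypothetical helpers, not part of the repository.

```python
# String identifiers from the table above.
TASKS = {
    "anes", "assistments", "nhanes_lead", "college_scorecard",
    "brfss_diabetes", "acsfoodstamps", "diabetes_readmission",
    "brfss_blood_pressure", "mimic_extract_los_3", "mimic_extract_mort_hosp",
    "acsincome", "acspubcov", "physionet", "acsunemployment", "meps", "sipp",
}
# Tasks marked "Public Credentialized Access" in the table.
CREDENTIALIZED = {"anes", "mimic_extract_los_3", "mimic_extract_mort_hosp"}


def split_experiment(experiment):
    """Split an experiment name into (task, suffix), e.g.
    'mimic_extract_los_3_causal' -> ('mimic_extract_los_3', 'causal').
    """
    # Try longer identifiers first so the most specific task prefix wins.
    for task in sorted(TASKS, key=lambda t: (-len(t), t)):
        if experiment == task:
            return task, ""
        if experiment.startswith(task + "_"):
            return task, experiment[len(task) + 1:]
    raise ValueError(f"no known task identifier in {experiment!r}")


def is_credentialized(experiment):
    """True if the underlying dataset requires credentialized access."""
    task, _ = split_experiment(experiment)
    return task in CREDENTIALIZED


print(split_experiment("mimic_extract_los_3_causal"))  # ('mimic_extract_los_3', 'causal')
print(is_credentialized("anes_anticausal"))            # True
```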
TableShift's implementation includes the preprocessing of the data files. For the tasks Utilization and Poverty, follow the instructions provided by Hardt & Kim (2023) in `backward_predictor/README.md`.
We list below which files and folders we changed for our experiments:
- created folder `experiments_causal` with Python scripts to run experiments, launch experiments on a cluster, and plot figures for the paper
- created folder `backward_prediction` with preprocessing files adapted from Hardt & Kim (2023), including `backward_predictor/sipp/data/data_cleaning.ipynb` (© Ricardo Sandoval, 2024)
- added tasks `meps` and `sipp`, as well as feature selections of all tasks, in their respective Python scripts in the folder `tableshift/datasets`
- added data sources for `meps` and `sipp` in `tableshift/core/data_source.py`
- added tasks `meps` and `sipp`, as well as feature selections of all tasks, in `tableshift/core/tasks.py`
- added configurations for tasks and their feature selections in `tableshift/configs/non_benchmark_configs.py`
- added models `ib_erm`, `ib_irm`, `causirl_coral`, `causirl_mmd`, and `and_mask` in `tableshift/models`, adapted from Gulrajani & Lopez-Paz (2021)
- added hyperparameter configurations for the added models in `tableshift/configs/hparams.py`
- added computation of balanced accuracy in `tableshift/models/torchutils.py` and adapted `tableshift/models/compat.py` accordingly
- minor fixes in `tableshift/core/features.py`, `tableshift/core/tabular_dataset.py`, and `tableshift/models/training.py`
- added the packages `paretoset==1.2.3` and `seaborn==0.13.0` in `requirements.txt`
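For reference, balanced accuracy is the mean of the per-class recalls, which makes it insensitive to class imbalance. The following is a minimal pure-Python sketch of that standard definition (it matches scikit-learn's `balanced_accuracy_score` with the default `adjusted=False`); it is not the implementation in `tableshift/models/torchutils.py`.

```python
from collections import Counter


def balanced_accuracy(y_true, y_pred):
    """Mean per-class recall over the classes that occur in y_true."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if len(y_true) == 0:
        raise ValueError("y_true must not be empty")
    support = Counter(y_true)  # number of examples per true class
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)  # correct per class
    recalls = [hits[c] / n for c, n in support.items()]
    return sum(recalls) / len(recalls)


y_true = [0, 0, 0, 1]
y_pred = [0, 0, 0, 0]  # always predict the majority class
print(balanced_accuracy(y_true, y_pred))  # 0.5
```

In this example, plain accuracy would be 0.75, while balanced accuracy drops to 0.5 because the minority class is never predicted correctly.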
This repository contains code and supplementary materials for the following preprint:

```bibtex
@misc{nastl2024predictors,
      title={Do causal predictors generalize better to new domains?},
      author={Vivian Y. Nastl and Moritz Hardt},
      year={2024},
      eprint={2402.09891},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```