Skip to content

Latest commit

 

History

History
192 lines (142 loc) · 14.2 KB

TableShift_README.md

File metadata and controls

192 lines (142 loc) · 14.2 KB

status status status

tableshift logo

TableShift

TableShift is a benchmarking library for machine learning with tabular data under distribution shift.

You can read more about TableShift at tableshift.org or read the full paper (published in NeurIPS 2023 Datasets & Benchmarks Track) on arxiv. If you use the benchmark in your research, please cite the paper:

@article{gardner2023tableshift,
  title={Benchmarking Distribution Shift in Tabular Data with TableShift},
  author={Gardner, Josh and Popovic, Zoran and Schmidt, Ludwig},
  journal={Advances in Neural Information Processing Systems},
  year={2023}
}

If you find an issue, please file a GitHub issue.

Quickstart

Environment setup: We recommend the use of docker with TableShift. Our dataset construction and model pipelines have a diverse set of dependencies that included non-Python files required to make some libraries work. As a result, we recommend you use the provided Docker image for using the benchmark, and suggest forking this Docker image for your own development.

# fetch the docker image
docker pull ghcr.io/jpgard/tableshift:latest

# run it to test your setup; this automatically launches examples/run_expt.py
docker run ghcr.io/jpgard/tableshift:latest --model xgb

# optionally, use the container interactively
docker run -it --entrypoint=/bin/bash ghcr.io/jpgard/tableshift:latest

Conda: We recommend using Docker with TableShift when running training or using any of the pretrained modeling code, as the libraries used for training contain a complex and subtle set of dependencies that can be difficult to configure outside Docker. However, Conda might provide a more lightweight environment for basic development and exploration with TableShift, so we describe how to set up Conda here.

To create a conda environment, simply clone this repo, enter the root directory, and run the following commands to create and test a local execution environment:

# set up the environment
conda env create -f environment.yml
conda activate tableshift
# test the install by running the training script
python examples/run_expt.py

The final line above will print some detailed logging output as the script executes. When you see training completed! test accuracy: 0.6221 your environment is ready to go! (Accuracy may vary slightly due to randomness.)

Accessing datasets: If you simply want to load and use a standard version of one of the public TableShift datasets, it's as simple as:

from tableshift import get_dataset

dataset_name = "diabetes_readmission"
dset = get_dataset(dataset_name)

The full list of identifiers for all available datasets is below; simply swap any of these for dataset_name to access the relevant data.

If you would like to use a dataset without a domain split, replace get_dataset() with get_iid_dataset().

The call to get_dataset() returns a TabularDataset that you can use to easily load tabular data in several formats, including Pandas DataFrame and PyTorch DataLoaders:

# Fetch a pandas DataFrame of the training set
X_tr, y_tr, _, _ = dset.get_pandas("train")

# Fetch and use a pytorch DataLoader
train_loader = dset.get_dataloader("train", batch_size=1024)

for X, y, _, _ in train_loader:
    ...

For all TableShift datasets, the following splits are available: train, validation, id_test, ood_validation, ood_test.

For IID datasets (those without a domain split) these splits are available: train, validation, test.

There is a complete example of a training script in examples/run_expt.py.

Benchmark Dataset Availability

tl;dr: if you want to get started exploring ASAP, use datasets marked as " public" below.

All of the datasets used in the TableShift benchmark are either publicly available or provide open credentialized access. The datasets with open credentialized access require signing a data use agreement; as a result, some datasets must be manually fetched and stored locally. TableShift makes this process as simple as possible.

A list of datasets, their names in TableShift, and the corresponding access levels are below. The string identifier is the value that should be passed as the experiment parameter to get_dataset() or the --experiment flag of run_expt.py and other training scripts.

Dataset String Identifier Availability Source
Voting anes Public Credentialized Access (source) American National Election Studies (ANES)
ASSISTments assistments Public Kaggle
Childhood Lead nhanes_lead Public National Health and Nutrition Examination Survey (NHANES)
College Scorecard college_scorecard Public College Scorecard
Diabetes brfss_diabetes Public Behavioral Risk Factor Surveillance System (BRFSS)
Food Stamps acsfoodstamps Public American Community Survey (via folktables
HELOC heloc Public Credentialized Access (source) FICO
Hospital Readmission diabetes_readmission Public UCI
Hypertension brfss_blood_pressure Public Behavioral Risk Factor Surveillance System (BRFSS)
ICU Length of Stay mimic_extract_los_3 Public Credentialized Access (source) MIMIC-iii via MIMIC-Extract
ICU Mortality mimic_extract_mort_hosp Public Credentialized Access (source) MIMIC-iii via MIMIC-Extract
Income acsincome Public American Community Survey (via folktables
Public Health Insurance acspubcov Public American Community Survey (via folktables
Sepsis physionet Public Physionet
Unemployment acsunemployment Public American Community Survey (via folktables

Note that details on the data source, which files to load, and the feature codings are provided in the TableShift source code for each dataset and data source (see data_sources.py and the tableshift.datasets module).

For additional, non-benchmark datasets (possibly with only IID splits, not a distribution shift), see tableshift.configs.non_benchmark.configs.py

Dataset Details

More information about the tasks, datasets, splitting variables, data sources, and motivation are available in the TableShift paper; we provide a summary below.

Task Target Shift Domain Baseline Total Observations
ASSISTments Next Answer Correct School -34.5% 2,667,776
College Scorecard Low Degree Completion Rate Institution Type -11.2% 124,699
ICU Mortality ICU patient expires in hospital during current visit Insurance Type -6.3% 23,944
Hospital Readmission 30-day readmission of diabetic hospital patients Admission source -5.9% 99,493
Diabetes Diabetes diagnosis Race -4.5% 1,444,176
ICU Length of Stay Length of stay >= 3 hrs in ICU Insurance Type -3.4% 23,944
Voting Voted in U.S. presidential election Geographic Region -2.6% 8280
Food Stamps Food stamp recipiency in past year for households with child Geographic Region -2.4% 840,582
Unemployment Unemployment for non-social security-eligible adults Education Level -1.3% 1,795,434
Income Income >= 56k for employed adults Geographic Region -1.3% 1,664,500
HELOC Repayment of Home Equity Line of Credit loan Est. third-party risk level -22.6% 10,459
Public Health Insurance Coverage of non-Medicare eligible low-income individuals Disability Status -14.5% 5,916,565
Sepsis Sepsis onset within next 6hrs for hospital patients Length of Stay -6.0% 1,552,210
Childhood Lead Blood lead levels above CDC Blood Level Reference Value Poverty level -5.1% 27,499
Hypertension Hypertension diagnosis for high-risk age (50+) BMI Category -4.4% 846,761

A Self-Contained Training Example

A sample training script is located at examples/run_expt.py. However, training a scikit-learn model is as simple as:

from tableshift import get_dataset
from sklearn.ensemble import GradientBoostingClassifier

dset = get_dataset("diabetes_readmission")
X_train, y_train, _, _ = dset.get_pandas("train")

# Train
estimator = GradientBoostingClassifier()
trained_estimator = estimator.fit(X_train, y_train)

# Test
for split in ('id_test', 'ood_test'):
    X, y, _, _ = dset.get_pandas(split)
    preds = estimator.predict(X)
    acc = (preds == y).mean()
    print(f'accuracy on split {split} is: {acc:.3f}')

The code should output the following:

accuracy on split id_test is: 0.655
accuracy on split ood_test is: 0.619

Now, please close that domain gap!

Non-benchmark datasets

We also have several tabular datasets available in TableShift which are not part of the official TableShift benchmark, but which still may be useful for tabular data research. We are continuously adding datasets to the package. These datasets support all of the same functionality provided for the TableShift benchmark datasets, but we did not include these as an official part of the TableShift benchmark -- they are not an official part of the TableShift package and are mostly intended for convenience and for our own internal use.

For a list of the non-benchmark datasets, see the file tableshift.configs.non_benchmark_configs.py.