Tsinghua University, Columbia University
`WhyShift` is a Python package that provides a benchmark with various specified distribution shift patterns on real-world tabular data. Our testbed highlights the importance of future research that builds an understanding of how distributions differ. For more details, please refer to our paper.
If you find this repository useful in your research, please cite the following paper:
@inproceedings{liu2023need,
title={On the Need for a Language Describing Distribution Shifts: Illustrations on Tabular Datasets},
author={Jiashuo Liu and Tianyu Wang and Peng Cui and Hongseok Namkoong},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
year={2023}
}
- Dataset Access
- Python Package: `whyshift`
- Different Distribution Shift Patterns
- Implemented Algorithms
- Algorithm for Identifying Risk Region
- Degradation Decomposition (DISDE)
- License and terms of use
- References
Here we provide the access links for the 5 datasets used in our benchmark.
- ACS Income
  - The task is to predict whether an individual’s income is above $50,000.
  - Access link: https://github.com/socialfoundations/folktables
  - Reference: Ding, F., Hardt, M., Miller, J., & Schmidt, L. (2021). Retiring adult: New datasets for fair machine learning. Advances in neural information processing systems, 34, 6478-6490.
  - License: MIT License
- ACS Public Coverage
  - The task is to predict whether an individual has public health insurance.
  - Access link: https://github.com/socialfoundations/folktables
  - Reference: Ding, F., Hardt, M., Miller, J., & Schmidt, L. (2021). Retiring adult: New datasets for fair machine learning. Advances in neural information processing systems, 34, 6478-6490.
  - License: MIT License
- ACS Mobility
  - The task is to predict whether an individual had the same residential address one year ago.
  - Access link: https://github.com/socialfoundations/folktables
  - Reference: Ding, F., Hardt, M., Miller, J., & Schmidt, L. (2021). Retiring adult: New datasets for fair machine learning. Advances in neural information processing systems, 34, 6478-6490.
  - License: MIT License
- Taxi
  - The task is to predict whether the total ride duration exceeds 30 minutes, based on location and temporal features.
  - Access link:
  - License: CC BY-SA 4.0
- US Accident
  - The task is to predict whether an accident is severe (long delay) or not (short delay), based on weather features and road condition features.
  - Access link: https://www.kaggle.com/datasets/sobhanmoosavi/us-accidents
  - License: CC BY-SA 4.0
Here we provide the scripts to get data in our proposed settings.
pip3 install whyshift
- `get_data(task, state, need_preprocess, root_dir, year)` function (argument order as used in the examples below)
- `task` values: 'income', 'pubcov', 'mobility'
- examples:

      from whyshift import get_data
      # for ACS Income
      X, y, feature_names = get_data("income", "CA", True, './datasets/acs/', 2018)
      # for ACS Public Coverage
      X, y, feature_names = get_data("pubcov", "CA", True, './datasets/acs/', 2018)
      # for ACS Mobility
      X, y, feature_names = get_data("mobility", "CA", True, './datasets/acs/', 2018)

- supported `state` values:
  - ['AL', 'AK', 'AZ', 'AR', 'CA', 'CO', 'CT', 'DE', 'FL', 'GA', 'HI', 'ID', 'IL', 'IN', 'IA', 'KS', 'KY', 'LA', 'ME', 'MD', 'MA', 'MI', 'MN', 'MS', 'MO', 'MT', 'NE', 'NV', 'NH', 'NJ', 'NM', 'NY', 'NC', 'ND', 'OH', 'OK', 'OR', 'PA', 'RI', 'SC', 'SD', 'TN', 'TX', 'UT', 'VT', 'VA', 'WA', 'WV', 'WI', 'WY', 'PR']
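Because the examples above pass their arguments positionally, a typo in the task name or state code only surfaces once the package starts loading data. A minimal sketch of a pre-check built from the values listed above follows; `check_acs_args` is a hypothetical helper for illustration and is not part of `whyshift`.

```python
# Supported values copied from the lists in this README.
ACS_TASKS = ("income", "pubcov", "mobility")
ACS_STATES = (
    "AL", "AK", "AZ", "AR", "CA", "CO", "CT", "DE", "FL", "GA",
    "HI", "ID", "IL", "IN", "IA", "KS", "KY", "LA", "ME", "MD",
    "MA", "MI", "MN", "MS", "MO", "MT", "NE", "NV", "NH", "NJ",
    "NM", "NY", "NC", "ND", "OH", "OK", "OR", "PA", "RI", "SC",
    "SD", "TN", "TX", "UT", "VT", "VA", "WA", "WV", "WI", "WY",
    "PR",
)


def check_acs_args(task, state):
    """Return (task, state) unchanged if both are supported, else raise ValueError.

    Hypothetical helper for illustration; it is not part of whyshift.
    """
    if task not in ACS_TASKS:
        raise ValueError(f"unknown ACS task {task!r}; expected one of {ACS_TASKS}")
    if state not in ACS_STATES:
        raise ValueError(f"unknown state code {state!r}")
    return task, state
```

With this in place, a call such as `get_data(*check_acs_args("income", "CA"), True, './datasets/acs/', 2018)` would reject `"Income"` or `"ca"` before any file is read.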
- download data files:
  - US Accident: https://www.kaggle.com/datasets/sobhanmoosavi/us-accidents
  - Taxi: https://www.kaggle.com/competitions/nyc-taxi-trip-duration
- put data files in dir `./datasets/`
  - accident: `./datasets/Accident/US_Accidents_Dec21_updated.csv`
  - taxi: `./datasets/Taxi/{city}_clean.csv`
- pass the path to the data file to the `get_data` function (in place of the directory used for the ACS tasks)
- example:

      from whyshift import get_data
      # for US Accident
      X, y, _ = get_data("accident", "CA", True, './datasets/Accident/US_Accidents_Dec21_updated.csv')
      # for Taxi
      X, y, _ = get_data("taxi", "nyc", True, './datasets/Taxi/train.csv')

- supported `state` values:
  - for US Accident: ['CA', 'TX', 'FL', 'OR', 'MN', 'VA', 'SC', 'NY', 'PA', 'NC', 'TN', 'MI', 'MO']
  - for Taxi: ['nyc', 'bog', 'uio', 'mex']
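The directory layout and supported values above can be captured in a small path builder, sketched below. `expected_data_file` is a hypothetical helper, not part of `whyshift`, and the taxi file name follows the `{city}_clean.csv` pattern from the directory listing, which is an assumption: the example above passes `train.csv` for `nyc`, so adjust the name to match the file you actually saved.

```python
from pathlib import Path

# Supported values copied from the lists in this README.
ACCIDENT_STATES = ("CA", "TX", "FL", "OR", "MN", "VA", "SC", "NY", "PA", "NC", "TN", "MI", "MO")
TAXI_CITIES = ("nyc", "bog", "uio", "mex")


def expected_data_file(task, state, root="./datasets"):
    """Build the data-file path for a task, following the layout listed above.

    Hypothetical helper for illustration; it is not part of whyshift.
    """
    root = Path(root)
    if task == "accident":
        if state not in ACCIDENT_STATES:
            raise ValueError(f"unsupported US Accident state {state!r}")
        return root / "Accident" / "US_Accidents_Dec21_updated.csv"
    if task == "taxi":
        if state not in TAXI_CITIES:
            raise ValueError(f"unsupported Taxi city {state!r}")
        # Assumed naming pattern: one cleaned file per city.
        return root / "Taxi" / f"{state}_clean.csv"
    raise ValueError(f"unknown task {task!r}; expected 'accident' or 'taxi'")
```

The returned path can be checked with `.exists()` before it is passed on, which gives a clearer error than a failed read inside the loader.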
Based on our `whyshift` package, one could design various source-target pairs with different distribution shift patterns. Here we list some of them for reference:
#ID | Dataset | Type | #Features | Outcome | Source | #Train Samples | #Test Domains | Dom. Ratio |
---|---|---|---|---|---|---|---|---|
1 | ACS Income | Spatial | 9 | Income≥50k | California | 195,665 | 50 | |
2 | ACS Income | Spatial | 9 | Income≥50k | Connecticut | 19,785 | 50 | |
3 | ACS Income | Spatial | 9 | Income≥50k | Massachusetts | 40,114 | 50 | |
4 | ACS Income | Spatial | 9 | Income≥50k | South Dakota | 4,899 | 50 | |
5 | ACS Mobility | Spatial | 21 | Residential Address | Mississippi | 5,318 | 50 | |
6 | ACS Mobility | Spatial | 21 | Residential Address | New York | 40,463 | 50 | |
7 | ACS Mobility | Spatial | 21 | Residential Address | California | 80,329 | 50 | |
8 | ACS Mobility | Spatial | 21 | Residential Address | Pennsylvania | 23,918 | 50 | |
9 | Taxi | Spatial | 7 | Duration time≥30 min | Bogotá | 3,063 | 3 | |
10 | Taxi | Spatial | 7 | Duration time≥30 min | New York City | 1,458,646 | 3 | |
11 | ACS Pub.Cov | Spatial | 18 | Public Ins. Coverage | Nebraska | 6,332 | 50 | |
12 | ACS Pub.Cov | Spatial | 18 | Public Ins. Coverage | Florida | 71,297 | 50 | |
13 | ACS Pub.Cov | Spatial | 18 | Public Ins. Coverage | Texas | 98,928 | 50 | |
14 | ACS Pub.Cov | Spatial | 18 | Public Ins. Coverage | Indiana | 24,330 | 50 | |
15 | US Accident | Spatial | 47 | Severity of Accident | Texas | 26,664 | 13 | |
16 | US Accident | Spatial | 47 | Severity of Accident | California | 64,909 | 13 | X: 22/31 |
17 | US Accident | Spatial | 47 | Severity of Accident | Florida | 32,278 | 13 | X: 5/7 |
18 | US Accident | Spatial | 47 | Severity of Accident | Minnesota | 8,927 | 13 | X: 8/11 |
19 | ACS Pub.Cov | Temporal | 18 | Public Ins. Coverage | Year 2010 (NY) | 73,208 | 3 | X: 2/2 |
20 | ACS Pub.Cov | Temporal | 18 | Public Ins. Coverage | Year 2010 (CA) | 149,441 | 3 | X: 2/2 |
21 | ACS Income | Synthetic | 9 | Income≥50k | Younger People (80%) | 20,000 | 1 | X: 1/1 |
22 | ACS Income | Synthetic | 9 | Income≥50k | Younger People (90%) | 20,000 | 1 | X: 1/1 |
In our benchmark, each setting has multiple target domains (except the last setting). In the main body of our paper, we select only one target domain for each setting. We report the `Dom. Ratio` to represent the dominant ratio of
In our `whyshift` package, we also implement several algorithms for tabular data classification, including `Logistic Regression`, `MLP`, `SVM`, `Random Forest`, `XGBoost`, `LightGBM`, `GBM`, `DRO/DORO`, `Group DRO`, `Simple-Reweighting`, `JTT`, `Fairness-In/Postprocess` and `DWR` methods.
    # fetch_model is assumed to be exported at the package top level, like get_data
    from whyshift import fetch_model

    # use the implemented methods
    algo = fetch_model(method_name)

Note that the supported method names are:

    method_name_list = ['lr','svm','xgb', 'lightgbm', 'rf', 'dwr', 'jtt','suby', 'subg', 'rwy', 'rwg', 'FairPostprocess_exp','FairInprocess_dp', 'FairPostprocess_threshold', 'FairInprocess_eo', 'FairInprocess_error_parity','chi_dro', 'chi_doro','cvar_dro','cvar_doro','group_dro']
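To compare methods across a source-target pair, one typically fits on source data and reports the drop in accuracy on the target. The sketch below shows that bookkeeping in plain Python. It assumes the model object exposes scikit-learn-style `fit(X, y)` and `predict(X)` methods, which is an assumption about what `fetch_model` returns (check `Example.ipynb`); `accuracy` and `transfer_gap` are hypothetical helpers, not part of `whyshift`.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the labels."""
    y_true, y_pred = list(y_true), list(y_pred)
    if not y_true:
        raise ValueError("cannot compute accuracy on an empty set")
    return sum(int(t == p) for t, p in zip(y_true, y_pred)) / len(y_true)


def transfer_gap(model, source_train, source_test, target):
    """Fit on source data, then compare held-out source accuracy with target accuracy.

    Each of source_train, source_test and target is an (X, y) pair.
    `model` is assumed to expose fit(X, y) and predict(X).
    """
    model.fit(*source_train)
    source_acc = accuracy(source_test[1], model.predict(source_test[0]))
    target_acc = accuracy(target[1], model.predict(target[0]))
    return {
        "source_acc": source_acc,
        "target_acc": target_acc,
        "gap": source_acc - target_acc,  # positive gap means degradation on the target
    }
```

Looping this over several method names and target states gives a table of gaps per method; the held-out source split matters, since comparing training accuracy with target accuracy would overstate the degradation.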
In our `whyshift` package, we implement the risk region identification algorithm (Algorithm 1 in our paper). The function is `risk_region`.
Here is an example of how to use it:
    import xgboost as xgb
    from whyshift import risk_region

    # XGBoost classifiers passed as the source and target models
    source_model = xgb.XGBClassifier()
    target_model = xgb.XGBClassifier()
    risk_region('xgb', source_model, target_model, 'income', 'CA', 'PR', "./datasets/acs")
In our `whyshift` package, we implement the DIstribution Shift DEcomposition (`DISDE`) method to attribute the performance degradation to Y|X-shifts and X-shifts. The function is `degradation_decomp`.
The parameters include:
- source_X, source_y, other_X_raw, other_y_raw: data from the source and target distributions
- best_method: the model to be diagnosed
- data_sum: the number of samples in the target data
- K: the number of folds for K-fold training and testing
- domain_classifier: when not specified (None), XGBoost will be used
- draw_calibration: whether to draw the calibration curve to check the quality of the domain classifier
- save_calibration_png: the path to save the calibration figure
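To make the idea of a degradation decomposition concrete, the toy below computes a three-term split exactly on a two-valued feature. It assumes the usual DISDE form, in which a shared distribution over X with mass proportional to p(x)q(x)/(p(x)+q(x)) separates X shift from Y|X shift. This is an illustration under that assumption, not the code of `degradation_decomp`, which works from samples and a domain classifier. All names and numbers here are hypothetical.

```python
def shared_distribution(p_x, q_x):
    """Shared distribution over X with mass proportional to p(x) q(x) / (p(x) + q(x))."""
    weights = {}
    for x in set(p_x) | set(q_x):
        p, q = p_x.get(x, 0.0), q_x.get(x, 0.0)
        weights[x] = p * q / (p + q) if p + q > 0 else 0.0
    total = sum(weights.values())
    if total == 0:
        raise ValueError("source and target share no support over X")
    return {x: w / total for x, w in weights.items()}


def expectation(dist, values):
    """Expected value of values[x] when x is drawn from dist."""
    return sum(prob * values[x] for x, prob in dist.items() if prob > 0)


def decompose(p_x, q_x, risk_p, risk_q):
    """Split the loss gap between target Q and source P into three terms.

    risk_p[x] and risk_q[x] are the model's expected loss given X = x
    under the source and target conditional distributions of Y given X.
    """
    s_x = shared_distribution(p_x, q_x)
    terms = {
        "x_shift_source_to_shared": expectation(s_x, risk_p) - expectation(p_x, risk_p),
        "y_given_x_shift": expectation(s_x, risk_q) - expectation(s_x, risk_p),
        "x_shift_shared_to_target": expectation(q_x, risk_q) - expectation(s_x, risk_q),
    }
    terms["total"] = expectation(q_x, risk_q) - expectation(p_x, risk_p)
    return terms


# Toy numbers (hypothetical): the target has more "old" individuals,
# and the model's loss on "old" is also higher under the target.
p_x = {"young": 0.8, "old": 0.2}
q_x = {"young": 0.2, "old": 0.8}
risk_p = {"young": 0.1, "old": 0.3}
risk_q = {"young": 0.1, "old": 0.5}
result = decompose(p_x, q_x, risk_p, risk_q)
```

With these numbers the three terms come out to 0.06, 0.10 and 0.12, which sum to the total degradation of 0.28. The terms telescope by construction, so the identity holds for any choice of shared distribution; the choice only changes how the total is attributed.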
An example of how to use our `whyshift` package can be found in `Example.ipynb`.
Our benchmark is built upon `Folktables`. The license of `Folktables` is:
Folktables provides code to download data from the American Community Survey (ACS) Public Use Microdata Sample (PUMS) files managed by the US Census Bureau. The data itself is governed by the terms of use provided by the Census Bureau. For more information, see https://www.census.gov/data/developers/about/terms-of-service.html
The Adult reconstruction dataset is a subsample of the IPUMS CPS data available from https://cps.ipums.org/. The data are intended for replication purposes only. Individuals analyzing the data for other purposes must submit a separate data extract request directly via IPUMS CPS. Individuals are not to redistribute the data without permission. Contact [email protected] for redistribution requests.
In addition, for the US Accident and Taxi data from Kaggle, individuals should follow their respective licenses; see https://www.kaggle.com/datasets/sobhanmoosavi/us-accidents and https://www.kaggle.com/competitions/nyc-taxi-trip-duration/data.
[1] Ding, F., Hardt, M., Miller, J., & Schmidt, L. (2021). Retiring adult: New datasets for fair machine learning. Advances in neural information processing systems, 34, 6478-6490.
- We modified the `folktables` code to support `year` values before 2014 and include the revised version in our package.
- Part of the algorithm code is adapted from the codebase.