Code, logs, and final models for SIGSPATIAL SpatialEpi '22: Spatiotemporal Disease Case Prediction using Contrastive Predictive Coding.

ajzliu/SP-CPC

Model Architecture

Spatial Probabilistic Contrastive Predictive Coding

This repository contains the code to reproduce the results of our paper Spatiotemporal Disease Case Prediction using Contrastive Predictive Coding, presented at the 2022 Spatial Epidemiology workshop at the 30th ACM SIGSPATIAL conference in Seattle, WA from November 1-4, 2022.

Citation

If you find our paper or code useful, please cite it:

ACM Reference Format: Anish Susarla, Austin Liu, Duy Hoang Thai, Minh Tri Le, and Andreas Züfle. 2022. Spatiotemporal Disease Case Prediction using Contrastive Predictive Coding. In The 3rd ACM SIGSPATIAL International Workshop on Spatial Computing for Epidemiology (SpatialEpi ’22), November 1, 2022, Seattle, WA, USA. ACM, New York, NY, USA, 9 pages. https://doi.org/10.1145/3557995.3566122

About The Project

Time series prediction models have played a vital role in guiding effective policymaking and response during the COVID-19 pandemic by predicting future cases and deaths at the country, state, and county levels. However, for emerging diseases, there is not sufficient historic data to fit traditional supervised prediction models. In addition, such models do not consider human mobility between regions. To mitigate the need for supervised models and to include human mobility data in the prediction, we propose Spatial Probabilistic Contrastive Predictive Coding (SP-CPC) which leverages Contrastive Predictive Coding (CPC), an unsupervised time-series representation learning approach. We augment CPC to incorporate a covariate mobility matrix into the loss function, representing the relative number of individuals traveling between each county on a given day. The proposal distribution learned by the algorithm is then sampled by the Metropolis-Hastings algorithm to give a final prediction of the number of COVID-19 cases. We find that the model applied to COVID-19 data can make accurate short-term predictions, more accurate than ARIMA and simple time-series extrapolation methods, one day into the future. However, for longer-term prediction windows of seven or more days into the future, we find that our predictions are not as competitive and require future research.
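
To make the sampling step above more concrete, here is a minimal sketch of a Metropolis-Hastings sampler in plain Python. It is not the implementation in model/models.py: the target here is a toy Gaussian log-density standing in for the distribution learned by the model, and the function names, step size, and burn-in length are all assumptions chosen for illustration. It uses a symmetric Gaussian random-walk proposal, which is the simplest special case of Metropolis-Hastings.

```python
import math
import random


def metropolis_hastings(log_target, x0, n_samples, step=1.0, burn_in=500, seed=0):
    """Draw samples from an unnormalized log-density with a random-walk proposal."""
    rng = random.Random(seed)
    x = x0
    log_p = log_target(x)
    samples = []
    for i in range(burn_in + n_samples):
        # Symmetric Gaussian proposal centred on the current state.
        candidate = x + rng.gauss(0.0, step)
        log_p_candidate = log_target(candidate)
        # Accept with probability min(1, p(candidate) / p(x)).
        accept_prob = math.exp(min(0.0, log_p_candidate - log_p))
        if rng.random() < accept_prob:
            x, log_p = candidate, log_p_candidate
        # Keep the current state (moved or not) once burn-in is over.
        if i >= burn_in:
            samples.append(x)
    return samples


def toy_log_density(x, mean=100.0, std=10.0):
    """Hypothetical stand-in for the learned distribution over case counts."""
    return -0.5 * ((x - mean) / std) ** 2


samples = metropolis_hastings(toy_log_density, x0=100.0, n_samples=5000, step=10.0)
prediction = sum(samples) / len(samples)  # point estimate from the sample mean
```

Taking the sample mean as the point prediction is one possible choice; a median or mode estimate would work with the same sampler.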

Datasets

Old versions of the datasets used for training the model in the paper can be found in the data/ directory. Newer versions of the datasets can be found at the links below.

us-counties-2020.csv, us-counties-2021.csv, and us-counties-2022.csv can be downloaded from nytimes/covid-19-data. daily_county2county_2021_04_15_int.csv can be downloaded from GeoDS/COVID19USFlows; it is the newest file available. mainland_fips_master.csv can be downloaded from kjhealy/fips-codes.

Getting Started

A minimum of Python 3.7 is required to run this project, because the code relies on dictionaries preserving insertion order.

First, install the requirements for this project. If you are using your base environment, run pip install -r requirements.txt; if you are using a Conda environment, run conda install --file requirements.txt. The requirements list may be incomplete, so install any additional packages that your terminal reports as missing.

To obtain the mobility dictionary for all counties, first run gen_all_mobility_mat.py. Then, to obtain a smaller, renormalized sample of the mobility dictionary, update the fips_codes list to contain the FIPS codes of the counties you want to train on, and run renormalize_mob_custom.py. Both scripts are in the mobility/ directory. Make sure the fips_codes list matches dataset.counties in config.yaml.
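
As a rough illustration of what renormalizing for a subset of counties can mean, the sketch below restricts a mobility dictionary to a chosen list of FIPS codes and rescales each origin's outgoing flows so they sum to 1. This is an assumption-laden sketch, not the logic of renormalize_mob_custom.py: the nested-dictionary layout, the row-wise normalization, the function name renormalize_mobility, and the flow values are all hypothetical.

```python
def renormalize_mobility(mobility, fips_codes):
    """Keep only the chosen counties and rescale each origin's flows to sum to 1."""
    keep = set(fips_codes)
    result = {}
    for origin in fips_codes:
        # Drop flows to destinations outside the chosen subset.
        row = {
            dest: flow
            for dest, flow in mobility.get(origin, {}).items()
            if dest in keep
        }
        total = sum(row.values())
        # Guard against origins with no remaining flow.
        result[origin] = {
            dest: (flow / total if total > 0 else 0.0)
            for dest, flow in row.items()
        }
    return result


# Made-up flows between three counties, keyed by FIPS code (assumed layout).
mobility = {
    "51059": {"51059": 80.0, "51013": 15.0, "11001": 5.0},
    "51013": {"51059": 10.0, "51013": 70.0, "11001": 20.0},
    "11001": {"51059": 5.0, "51013": 5.0, "11001": 90.0},
}
fips_codes = ["51059", "51013"]
subset = renormalize_mobility(mobility, fips_codes)
```

Whatever the actual normalization is, the key point from the instructions above carries over: the counties kept here must be the same ones listed in dataset.counties.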

If you choose to use Comet, replace the API key in config.yaml and the project_name in main_multicounty.py. Comet is optional and can be disabled by setting has_comet to false in main_multicounty.py and relevant files. Then, run the command below to begin training.

python3 main_multicounty.py
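
The optional Comet logging described above could be gated on a flag along the following lines. This is only a sketch of the pattern, not the code in main_multicounty.py: the helper names make_experiment and log_metric are hypothetical, and only the has_comet flag, the API key, and project_name come from this README. The comet_ml import is deferred so that the package is needed only when logging is enabled.

```python
def make_experiment(has_comet, api_key=None, project_name=None):
    """Return a Comet Experiment when logging is enabled, otherwise None."""
    if not has_comet:
        return None
    # Imported lazily so that comet_ml stays an optional dependency.
    from comet_ml import Experiment

    return Experiment(api_key=api_key, project_name=project_name)


def log_metric(experiment, name, value, step):
    """Log a metric only if an experiment exists; do nothing otherwise."""
    if experiment is not None:
        experiment.log_metric(name, value, step=step)


# With has_comet disabled, training code can call log_metric unconditionally.
experiment = make_experiment(has_comet=False)
log_metric(experiment, "train_loss", 0.42, step=1)
```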

After training, run whichever other script you need, as listed under Repository Structure. When running a file that is not in the main directory, first change the terminal's working directory to the folder containing that file.

Repository Structure

The log files contain the model structures and hyperparameters used for training in that experiment. The final_models folder contains the final trained models.

| Path | Description |
| --- | --- |
| existing/arima_7day.py | Trains and tests a baseline ARIMA model predicting the next 7 days of new case data using 21 days of past data. |
| existing/arima_1day.py | Trains and tests a baseline ARIMA model predicting the next 1 day of new case data using 13 days of past data. |
| existing/constant_interpolation_7day.py | Calculates the mean absolute percentage error of a constant baseline that predicts the case values for the next seven days from the case values seven days prior. |
| existing/constant_interpolation_1day.py | Calculates the mean absolute percentage error of a constant baseline that predicts the case value for the next day from the case values seven days prior. |
| logs/nospatial-30counties-dim120.log | Log file for training a CPC model with no spatial data on 30 counties with an encoder/autoregressive dimension of 120. |
| logs/spatial-30counties-dim120.log | Log file for training the Spatial CPC model with spatial data on 30 counties with an encoder/autoregressive dimension of 120. |
| logs/spatial-30counties-dim120-leakyrelu.log | Log file for training the Spatial CPC model with spatial data on 30 counties with an encoder/autoregressive dimension of 120 and a LeakyReLU in place of the ReLU used in the previous two experiments. |
| logs/spatial-30counties-dim60.log | Log file for training the Spatial CPC model with spatial data on 30 counties with an encoder/autoregressive dimension of 60. |
| logs/spatial-30counties-dim60-rnn.log | Log file for training the Spatial CPC model with spatial data on 30 counties with an encoder/autoregressive dimension of 60 and an RNN in place of a GRU. |
| logs/spatial-30counties-dim60-linear.log | Log file for training the Spatial CPC model with spatial data on 30 counties with an encoder/autoregressive dimension of 60 and a single Linear layer in place of a Conv1D layer in the encoder. |
| logs/spatial-30counties-dim60-nonoverlapping.log | Log file for training the Spatial CPC model with spatial data on 30 counties with an encoder/autoregressive dimension of 60 and nonoverlapping data. |
| final_models/... | Final trained models produced by the experiments with the corresponding log name. |
| config.yaml | Config file for adjusting hyperparameters. |
| main_multicounty.py | Python script to train Spatial CPC. Change the imported dataset at the top of the file to change which dataset the model uses. |
| README.md | README file. |
| requirements.txt | Requirements for running this repository. |
| test_predict.py | Python script to generate predictions using Metropolis-Hastings. |
| utils/dataset_nonoverlap.py | Contains dataset classes for the nonoverlapping (Experiment 2) dataset and mobility data in training Spatial CPC. |
| utils/dataset_overlap_scpc.py | Contains dataset classes for the overlapping (Experiment 1) dataset and mobility data in training Spatial CPC. |
| utils/dataset_overlap_mh.py | Contains dataset classes for the overlapping dataset and mobility data required for evaluating Metropolis-Hastings. |
| utils/logger.py | Contains helper functions for logging during training and validation. |
| utils/seed.py | Contains helper functions for setting the random seeds for all libraries. |
| utils/train.py | Contains helper functions for training the model. |
| utils/validation.py | Contains helper functions for validating the model during training. |
| model/models.py | Contains classes for the PyTorch models described in the paper and the code for Metropolis-Hastings. |
| mobility/gen_all_mobility_mat.py | Python script to obtain the mobility dictionary for all 3145 US FIPS codes. |
| mobility/renormalize_mob_custom.py | Python script to obtain a renormalized mobility dictionary for any set of US FIPS codes. |
| figs/... | Contains figures required for this README. |
| data/... | Contains datasets described in the Datasets section. |
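
For readers unfamiliar with the constant baseline listed under existing/, the sketch below shows one way such a score can be computed: each day's case count is predicted to equal the count from seven days earlier, and the result is scored with the mean absolute percentage error. This is a simplified illustration, not the code in existing/constant_interpolation_1day.py; the function names, the handling of zero-case days, and the case counts are assumptions made for the example.

```python
def mape(actual, predicted):
    """Mean absolute percentage error, skipping days with zero actual cases."""
    pairs = [(a, p) for a, p in zip(actual, predicted) if a != 0]
    if not pairs:
        raise ValueError("MAPE is undefined when every actual value is zero")
    return 100.0 * sum(abs((a - p) / a) for a, p in pairs) / len(pairs)


def lagged_baseline_mape(cases, lag=7):
    """Score a baseline that predicts each day as the value `lag` days earlier."""
    if len(cases) <= lag:
        raise ValueError("need more observations than the lag")
    actual = cases[lag:]
    predicted = cases[:-lag]
    return mape(actual, predicted)


# Made-up daily new-case counts for a single county over two weeks.
cases = [100, 110, 120, 130, 125, 90, 80, 105, 121, 126, 130, 120, 99, 84]
score = lagged_baseline_mape(cases, lag=7)
```

A seven-day lag is a natural choice for daily case data because reporting often follows a weekly cycle, so the same weekday one week earlier tends to be a stronger reference than the previous day.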

Issues and Contact

For any issues pertaining to the code, please use the Issues tab. For questions regarding the research paper, please email either of the first authors of this paper (Anish Susarla or Austin Liu), and we will get back to you to address your questions.

Acknowledgements

We would like to express our deepest gratitude to our research advisors Professor Andreas Züfle at Emory University and Dr. Duy Hoang Thai at George Mason University for their constant support, ideas, and feedback. Many thanks also to the Aspiring Scientists Summer Internship Program at George Mason University which provided us with the opportunity to conduct this research project this summer.

Funding

This material is based upon work supported by the National Science Foundation under Grant No. DEB-2109647 for "Data-Driven Modeling to Improve Understanding of Human Behavior, Mobility, and Disease Spread." Any opinions, findings, conclusions, or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the National Science Foundation. This research was additionally supported by the Aspiring Scientists Summer Internship Program (ASSIP) at George Mason University.

References

Parts of this repository are based on vgaraujov/CPC-NLP-PyTorch.