This repository contains the codebase for the foundational EEG model developed during Google Summer of Code (GSoC) 2024 with the Department of Biomedical Informatics at Emory University.
The primary objective of this project was to enhance the NeuroGPT codebase, a framework for processing and analyzing EEG data. Key enhancements include:
- Data Loaders and Model Architecture: Adaptations for handling EEG data, especially in `.edf` and `.fif` formats.
- SafeTensors Support: Added support for loading models in SafeTensors format, particularly useful in distributed environments.
- DistilledGPT Architecture: Integration of a lightweight variant of GPT-2 for reduced computational load and faster training.
- DeepLIFT Implementation: Added to improve model interpretability by identifying which EEG channels contributed most to the model's decisions.
- Early Stopping Callback: Implemented to prevent overfitting by halting training when performance improvements stagnate.
- LOOCV and K-Fold Cross-Validation: Implemented for proper and flexible cross-validation of the model.
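The DeepLIFT attribution listed above can be illustrated in its simplest case: for a purely linear scorer, DeepLIFT's rescale rule reduces to `weight * (input - baseline)` per feature, and the contributions sum to the difference between the model's output and its output on the baseline. This is a minimal sketch only; the channel weights and values below are hypothetical and not taken from the repository's model.

```python
# Hedged sketch: DeepLIFT attribution for a linear scorer. For linear models the
# rescale rule reduces to weight * (input - baseline); contributions sum exactly
# to f(inputs) - f(baseline). Weights and inputs below are illustrative only.

def deeplift_linear(weights, inputs, baseline):
    """Per-channel contributions; they sum to f(inputs) - f(baseline)."""
    return [w * (x - b) for w, x, b in zip(weights, inputs, baseline)]

weights  = [0.8, -0.3, 0.5]   # hypothetical per-channel weights
inputs   = [1.2, 0.4, -0.6]   # one hypothetical EEG feature vector
baseline = [0.0, 0.0, 0.0]    # all-zero reference input

contribs = deeplift_linear(weights, inputs, baseline)
score_diff = sum(w * x for w, x in zip(weights, inputs))  # f(baseline) is 0 here
assert abs(sum(contribs) - score_diff) < 1e-9
```

Ranking channels by `abs(contribution)` then gives the "which channels mattered most" view the DeepLIFT feature provides; for the real (nonlinear) model, a library such as Captum computes the same decomposition layer by layer.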
To set up the project locally, follow these steps:
1. Clone the repository:

   ```bash
   git clone https://github.com/zhreyu/eeg-foundation-model.git
   cd eeg-foundation-model
   ```

2. Install dependencies:

   ```bash
   pip install -r requirements.txt
   ```

3. Run the model:

   ```bash
   torchrun --nproc_per_node=<num_gpu> src/train_gpt.py \
     --training-type=epoch --num-epochs=25 --log-every-n-steps=10000 \
     --per-device-training-batch-size=8 --per-device-validation-batch-size=8 \
     --num-workers=0 --num_chunks=2 --chunk_len=500 --chunk_ovlp=100 \
     --num-hidden-layers=6 --num-encoder-layers=6 --training-style='CSM_causal' \
     --embedding-dim=1024 --data=EEG --num_channels=20 --train-data-path='' \
     --fp16=True --sub_list='sub_list.csv' --architecture=DistilledGPT2 \
     --no_evaluation=True
   ```
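The `--kfold` flag described below switches between regular training, leave-one-out cross-validation, and k-fold cross-validation. The sketch below shows one way such split selection over subjects could work; the function and variable names are illustrative, not the repository's actual implementation.

```python
# Hedged sketch of the split selection the --kfold flag implies:
#   False -> single training run, 0 -> LOOCV (one fold per subject),
#   k >= 2 -> k-fold. Illustrative only; the repo's code may differ.

def make_folds(subjects, kfold):
    """Yield (train, held_out) subject lists for the requested scheme."""
    if kfold is False:
        yield subjects, []                      # regular training, nothing held out
        return
    k = len(subjects) if kfold == 0 else kfold  # 0 means leave-one-out
    for i in range(k):
        held_out = subjects[i::k]               # every k-th subject starting at i
        train = [s for s in subjects if s not in held_out]
        yield train, held_out

subs = ["sub-01", "sub-02", "sub-03", "sub-04"]
folds = list(make_folds(subs, 2))   # 2-fold: each subject held out exactly once
loocv = list(make_folds(subs, 0))   # LOOCV: four folds of one subject each
```

Each subject appears in exactly one held-out set, so every fold's validation score comes from data the model never trained on in that fold.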
- `--data`: Type of data to use. Options are `EEG` (default) or `EEG2`.
- `--architecture`: Select the architecture to use. Options include `GPT`, `PretrainedGPT2`, and `DistilledGPT2`.
- `--kfold`: Set to `false` for regular training, `0` for LOOCV, or any integer for k-fold cross-validation.
- `--training-type`: Specify the training type. Options are `steps` (default) or `epoch`. If `--training-type=epoch`, use the `--num-epochs` argument to specify the number of epochs.
- `--num-epochs`: Number of epochs to train when `--training-type=epoch` is selected (default is `3`).
- `--sub_list`: Path to the CSV file containing the subject list for training.
- `--no_evaluation`: Set to `True` to disable evaluation during training (default is `False`).
- `--early-stopping-patience`: Number of evaluations with no improvement before stopping training (default is `5`).
- `--early-stopping-threshold`: Minimum change in the monitored quantity to qualify as an improvement (default is `0.00`).
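The two early-stopping flags above interact as follows: training halts once `patience` consecutive evaluations fail to improve the best validation loss by more than `threshold`. A minimal standalone sketch of that logic (not the repository's actual callback, whose class and method names may differ):

```python
# Hedged sketch of an early-stopping check matching the flags above: stop after
# `patience` evaluations whose improvement over the best loss seen so far is
# at most `threshold`. Illustrative only, not the repo's implementation.

class EarlyStopping:
    def __init__(self, patience=5, threshold=0.0):
        self.patience, self.threshold = patience, threshold
        self.best = float("inf")
        self.stale = 0          # evaluations without sufficient improvement

    def step(self, val_loss):
        """Record one evaluation; return True when training should stop."""
        if self.best - val_loss > self.threshold:
            self.best, self.stale = val_loss, 0
        else:
            self.stale += 1
        return self.stale >= self.patience

stopper = EarlyStopping(patience=2, threshold=0.01)
decisions = [stopper.step(loss) for loss in [1.0, 0.8, 0.799, 0.798]]
# the loss stalls after 0.8, so the second stale evaluation triggers a stop
```

A nonzero `threshold` filters out noise-level improvements, so the counter is not reset by tiny fluctuations in the validation loss.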
This project builds upon the NeuroGPT codebase developed by Wenhui Cui, Woojae Jeong, Philipp Thölke, Takfarinas Medani, Karim Jerbi, Anand A. Joshi, and Richard M. Leahy. Their contributions have been instrumental to the success of this project.
For more details, refer to their arXiv paper.
I’ve had a great time during these 2-3 months working on this project. The experience has been incredibly rewarding, and I’ve gained a lot of valuable knowledge and skills. Special thanks to my mentor, Dr. Mahmoud Zeydabadinezhad, for his constant support and guidance throughout GSoC 2024.
This project is licensed under the BSD 3-Clause "New" or "Revised" License. See the `LICENSE` file for more details.