CEEDNet: CAUEEG End-to-end Deep neural Network for automatic early detection of dementia based on the CAUEEG dataset.
This repository is the official implementation of "Deep learning-based EEG analysis to classify mild cognitive impairment for early detection of dementia: algorithms and benchmarks" from the CNIR (CAU NeuroImaging Research) team.
CAUEEG-Dementia benchmark

Model | #Params | Model size (MiB) | TTA | Throughput (EEG/s) | Test accuracy | Link 1 | Link 2 |
---|---|---|---|---|---|---|---|
k-Nearest Neighbors (k=5) | - | 11848.7 |  | 52.42 | 36.80% |  |  |
Random Forests (#trees=2000) | - | 2932.8 |  | 808.76 | 46.62% |  |  |
Linear SVM | 0.1M | 0.5 |  | 10393.40 | 52.33% |  |  |
Ieracitano-CNN (Ieracitano et al., 2019) | 3.5M | 13.2 |  | 8172.36 | 54.27% |  |  |
CEEDNet (1D-VGG-19) | 20.2M | 77.2 |  | 6593.80 | 64.00% | 1vc80n1f | 1vc80n1f |
CEEDNet (1D-VGG-19) | 20.2M | 77.2 | ✔ | 931.64 | 67.11% | 1vc80n1f | 1vc80n1f |
CEEDNet (1D-ResNet-18) | 11.4M | 43.6 | ✔ | 1249.90 | 68.75% | 2s1700lg | 2s1700lg |
CEEDNet (1D-ResNet-50) | 26.2M | 100.2 | ✔ | 1013.18 | 67.00% | gvqyvmrj | gvqyvmrj |
CEEDNet (1D-ResNeXt-50) | 25.7M | 98.2 | ✔ | 702.64 | 68.54% | v301o425 | v301o425 |
CEEDNet (2D-VGG-19) | 20.2M | 77.1 | ✔ | 296.84 | 70.18% | lo88puq7 | lo88puq7 |
CEEDNet (2D-ResNet-18) | 11.4M | 43.7 | ✔ | 399.40 | 65.88% | xci5svkl | xci5svkl |
CEEDNet (2D-ResNet-50) | 25.7M | 98.5 | ✔ | 182.20 | 67.21% | syrx7bmk | syrx7bmk |
CEEDNet (2D-ResNeXt-50) | 25.9M | 99.1 | ✔ | 161.25 | 67.91% | 1sl7ipca | 1sl7ipca |
CEEDNet (ViT-B-16) | 90.1M | 343.6 | ✔ | 47.31 | 66.18% | gjkysllw | gjkysllw |
CEEDNet (Ensemble) | 256.7M | 981.1 | ✔ | 22.74 | 74.66% |  |  |
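
The TTA column marks rows evaluated with test-time augmentation. This README does not specify the augmentation or how predictions are aggregated, so the following is only a generic sketch, assuming that TTA averages class probabilities over several randomly cropped views of the same EEG and that the Ensemble row combines models in the same way (its #Params is approximately the sum of the nine CEEDNet rows). `predict_with_tta`, `random_crop`, and `toy_model` are hypothetical names, not part of this repository.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one vector of class logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def average_probabilities(logit_sets: Sequence[Sequence[float]]) -> List[float]:
    """Average softmax probabilities over several predictions of the same EEG
    (augmented views for TTA, or different models for an ensemble)."""
    probs = [softmax(logits) for logits in logit_sets]
    n_classes = len(probs[0])
    return [sum(p[c] for p in probs) / len(probs) for c in range(n_classes)]


def predict_with_tta(model: Callable[[List[float]], List[float]],
                     signal: List[float],
                     augment: Callable[[List[float]], List[float]],
                     n_views: int = 8) -> Tuple[int, List[float]]:
    """Run `model` on `n_views` augmented copies of `signal` (hypothetical helper)
    and return the arg-max class with the averaged probabilities."""
    views = [augment(signal) for _ in range(n_views)]
    probs = average_probabilities([model(view) for view in views])
    return max(range(len(probs)), key=probs.__getitem__), probs


if __name__ == "__main__":
    rng = random.Random(0)
    eeg = [math.sin(0.1 * t) for t in range(1000)]  # toy single-channel signal

    def random_crop(x: List[float], length: int = 200) -> List[float]:
        # Assumed augmentation: a random contiguous window of the signal.
        start = rng.randrange(len(x) - length + 1)
        return x[start:start + length]

    def toy_model(x: List[float]) -> List[float]:
        # Stand-in for a trained network: three "class logits" from the mean.
        mean = sum(x) / len(x)
        return [mean, -mean, 0.0]

    label, probs = predict_with_tta(toy_model, eeg, random_crop)
    print(label, [round(p, 3) for p in probs])
```

Because every view costs one forward pass, this kind of TTA multiplies inference work, which is consistent with the much lower throughput of the ✔ rows.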

CAUEEG-Abnormal benchmark

Model | #Params | Model size (MiB) | TTA | Throughput (EEG/s) | Test accuracy | Link 1 | Link 2 |
---|---|---|---|---|---|---|---|
k-Nearest Neighbors (k=7) | - | 14015.3 |  | 41.19 | 51.42% |  |  |
Random Forests (#trees=2000) | - | 1930.5 |  | 830.80 | 72.63% |  |  |
Linear SVM | 0.1M | 0.3 |  | 10363.76 | 68.00% |  |  |
Ieracitano-CNN (Ieracitano et al., 2019) | 3.5M | 13.2 |  | 8293.08 | 65.98% |  |  |
CEEDNet (1D-VGG-19) | 20.2M | 77.2 |  | 7660.22 | 72.45% | nemy8ikm | nemy8ikm |
CEEDNet (1D-VGG-19) | 20.2M | 77.2 | ✔ | 998.54 | 74.28% | nemy8ikm | nemy8ikm |
CEEDNet (1D-ResNet-18) | 11.4M | 43.5 | ✔ | 844.65 | 74.85% | 4439k9pg | 4439k9pg |
CEEDNet (1D-ResNet-50) | 26.3M | 100.7 | ✔ | 837.66 | 76.37% | q1hhkmik | q1hhkmik |
CEEDNet (1D-ResNeXt-50) | 25.7M | 98.2 | ✔ | 800.49 | 77.32% | tp7qn5hd | tp7qn5hd |
CEEDNet (2D-VGG-19) | 20.2M | 77.2 | ✔ | 447.81 | 75.39% | ruqd8r7g | ruqd8r7g |
CEEDNet (2D-ResNet-18) | 11.5M | 43.8 | ✔ | 410.44 | 75.19% | dn10a6bv | dn10a6bv |
CEEDNet (2D-ResNet-50) | 25.7M | 98.5 | ✔ | 187.30 | 74.96% | atbhqdgg | atbhqdgg |
CEEDNet (2D-ResNeXt-50) | 25.9M | 99.1 | ✔ | 201.01 | 75.85% | 0svudowu | 0svudowu |
CEEDNet (ViT-B-16) | 86.9M | 331.6 | ✔ | 63.99 | 72.70% | 1cdws3t5 | 1cdws3t5 |
CEEDNet (Ensemble) | 253.8M | 969.9 | ✔ | 26.40 | 79.16% |  |  |
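
The Model size column appears consistent with storing each parameter as a 4-byte float32 value: 20.2M parameters × 4 bytes ≈ 77.1 MiB, against 77.2 MiB reported for 1D-VGG-19 (the gap comes from the rounded parameter count). This is a reading of the tables, not a statement from the authors. A minimal sketch of the arithmetic, with `params_to_mib` as a hypothetical helper:

```python
MIB = 2 ** 20          # bytes per mebibyte
BYTES_PER_FLOAT32 = 4  # assumption: weights stored as 32-bit floats


def params_to_mib(n_params: float, bytes_per_param: int = BYTES_PER_FLOAT32) -> float:
    """Estimate the storage size in MiB of a model with n_params weights."""
    return n_params * bytes_per_param / MIB


if __name__ == "__main__":
    # Parameter counts and reported sizes taken from the CAUEEG-Dementia table.
    for name, n_params, reported in [("1D-VGG-19", 20.2e6, 77.2),
                                     ("1D-ResNet-18", 11.4e6, 43.6),
                                     ("ViT-B-16", 90.1e6, 343.6)]:
        print(f"{name}: estimated {params_to_mib(n_params):.1f} MiB, reported {reported} MiB")
```

For an actual PyTorch module, the parameter count can be obtained with `sum(p.numel() for p in model.parameters())` and passed to the same function.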
- Conda installation (refer to https://www.anaconda.com/products/distribution)
- NVIDIA GPU with CUDA support
Note: we tested the code in the following environments.

OS | Python | PyTorch | CUDA
---|---|---|---
Windows 10 | 3.9.12 | 1.11.0 | 11.3
Ubuntu 18.04 | 3.8.11 | 1.10.0 | 11.3
Ubuntu 20.04 | 3.9.12 | 1.11.0 | 11.3
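
To compare a local setup against the tested environments, a short sketch like the following reports the OS, Python, PyTorch, and CUDA versions and whether a GPU is visible. It relies only on the standard `platform` module and the `torch.__version__`, `torch.version.cuda`, and `torch.cuda.is_available()` attributes; `environment_report` is a hypothetical helper, not part of the repository.

```python
import platform


def environment_report() -> dict:
    """Collect version information comparable to the tested-environments table."""
    report = {
        "os": platform.platform(),
        "python": platform.python_version(),
        "pytorch": None,
        "cuda": None,
        "gpu_available": False,
    }
    try:
        import torch
    except ImportError:
        return report  # PyTorch not installed yet
    report["pytorch"] = torch.__version__
    report["cuda"] = torch.version.cuda  # None for CPU-only builds
    report["gpu_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```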
(optional) Create and activate a Conda environment.

```bash
conda create -n caueeg python=3.9
conda activate caueeg
```

Install the PyTorch library (refer to https://pytorch.org/get-started/locally/).

```bash
conda install pytorch torchvision cudatoolkit=11.3 -c pytorch
```

Install the other required libraries.

```bash
pip install -r requirements.txt
```
Preparation of the CAUEEG dataset
❗ Note: The CAUEEG dataset may be used only for academic and research purposes 👩‍🎓👨🏼‍🎓.
- For full access to the CAUEEG dataset, follow the instructions at https://github.com/ipis-mjkim/caueeg-dataset.
- Download and unzip the dataset, then move all of its files into `local/datasets/`.
💡 Note: We provide `caueeg-dataset-test-only` at [link 1] or [link 2] for testing our research. `caueeg-dataset-test-only` contains the 'real' test splits of the two benchmarks (CAUEEG-Dementia and CAUEEG-Abnormal) but only 'fake' train and validation splits.
Train a CEEDNet model from scratch on the training set of CAUEEG-Dementia using the following command:

```bash
python run_train.py data=caueeg-dementia model=1D-ResNet-18 train=base_train
```

Similarly, train a model from scratch on the training set of CAUEEG-Abnormal using:

```bash
python run_train.py data=caueeg-abnormal model=1D-ResNet-18 train=base_train
```

Alternatively, you can use this Jupyter notebook.
If you encounter a GPU memory allocation error, or want to adjust the trade-off between memory usage and training speed, you can set the minibatch size by adding the `++model.minibatch=INTEGER_NUMBER` option to the command, as shown below:

```bash
python run_train.py data=caueeg-dementia model=1D-ResNet-18 train=base_train ++model.minibatch=32
python run_train.py data=caueeg-abnormal model=1D-ResNet-18 train=base_train ++model.minibatch=32
```
Thanks to Hydra, the model, hyperparameters, and other training details can easily be tuned by overriding options on the command line or by modifying the config files. For example, to train a 2D-VGG-19 model:

```bash
python run_train.py data=caueeg-dementia model=2D-VGG-19 train=base_train
```
For a speed-up, we recommend converting the data to the PyArrow `feather` file format instead of reading the `EDF` files directly, which can be done as follows:

```bash
python ./datasets/convert_file_format.py  # it takes a few minutes
python run_train.py data=caueeg-dementia model=2D-VGG-19 train=base_train ++data.file_format=feather
```
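
When training several models in a row (for example, each CEEDNet variant in the results tables), the `run_train.py` arguments used in the commands above can be assembled programmatically. The sketch below only reproduces those overrides; `build_train_command` and `run_sweep` are hypothetical helpers, and launching through `subprocess` from the repository root is an assumption.

```python
import subprocess
import sys
from typing import Iterable, List, Optional


def build_train_command(data: str, model: str, train: str = "base_train",
                        minibatch: Optional[int] = None,
                        file_format: Optional[str] = None) -> List[str]:
    """Assemble a run_train.py invocation with Hydra-style overrides."""
    cmd = [sys.executable, "run_train.py", f"data={data}", f"model={model}", f"train={train}"]
    if minibatch is not None:
        cmd.append(f"++model.minibatch={minibatch}")  # lower it on GPU memory errors
    if file_format is not None:
        cmd.append(f"++data.file_format={file_format}")  # e.g. "feather" after conversion
    return cmd


def run_sweep(data: str, models: Iterable[str], dry_run: bool = True, **kwargs) -> None:
    """Train each model in turn; with dry_run=True, only print the commands."""
    for model in models:
        cmd = build_train_command(data, model, **kwargs)
        if dry_run:
            print(" ".join(cmd))
        else:
            subprocess.run(cmd, check=True)  # run from the repository root


if __name__ == "__main__":
    run_sweep("caueeg-dementia", ["1D-ResNet-18", "2D-VGG-19"],
              minibatch=32, file_format="feather")
```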
Evaluation can be conducted using this Jupyter notebook (or another notebook for the `caueeg-dataset-test-only` case).
To use a pre-trained model, download the checkpoint file from here and move it to the `local/checkpoint` directory (e.g., `local/checkpoint/1vc80n1f/checkpoint.pt` for the 1D-VGG-19 model on the CAUEEG-Dementia benchmark).
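
The run IDs in the Link columns of the results tables double as checkpoint directory names, following the `local/checkpoint/1vc80n1f/checkpoint.pt` example. A minimal sketch that maps a (benchmark, model) pair to that path and loads the file onto the CPU; `CHECKPOINT_IDS`, `checkpoint_path`, and `load_checkpoint` are hypothetical, only a subset of the IDs is listed, and the internal structure of the loaded checkpoint is not documented here.

```python
from pathlib import Path
from typing import Dict, Tuple

# Run IDs copied from the "Link 1" column of the results tables (a subset).
CHECKPOINT_IDS: Dict[Tuple[str, str], str] = {
    ("caueeg-dementia", "1D-VGG-19"): "1vc80n1f",
    ("caueeg-dementia", "1D-ResNet-18"): "2s1700lg",
    ("caueeg-dementia", "2D-VGG-19"): "lo88puq7",
    ("caueeg-abnormal", "1D-ResNet-18"): "4439k9pg",
    ("caueeg-abnormal", "1D-ResNeXt-50"): "tp7qn5hd",
}


def checkpoint_path(benchmark: str, model: str, root: str = "local/checkpoint") -> Path:
    """Return the expected location of a downloaded checkpoint file."""
    run_id = CHECKPOINT_IDS[(benchmark, model)]
    return Path(root) / run_id / "checkpoint.pt"


def load_checkpoint(path: Path):
    """Load a checkpoint onto the CPU; its contents are not documented here."""
    if not path.is_file():
        raise FileNotFoundError(f"{path} not found; download the checkpoint first")
    import torch  # deferred so path lookup works without PyTorch installed
    # PyTorch >= 2.6 defaults to weights_only=True; a fully pickled checkpoint from a
    # trusted source may need weights_only=False (absent in PyTorch 1.10/1.11).
    return torch.load(path, map_location="cpu")


if __name__ == "__main__":
    path = checkpoint_path("caueeg-dementia", "1D-VGG-19")
    print(path, "exists" if path.is_file() else "missing")
```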
If you find this repository helpful, please cite the paper below.
```bibtex
@article{kim2023deep,
  title={Deep learning-based EEG analysis to classify normal, mild cognitive impairment, and dementia: algorithms and dataset},
  author={Kim, Min-jae and Youn, Young Chul and Paik, Joonki},
  journal={NeuroImage},
  pages={120054},
  year={2023},
  publisher={Elsevier}
}
```