
This repository is the official implementation of "Deep learning-based EEG analysis to classify mild cognitive impairment for early detection of dementia: algorithms and benchmarks" from the CNIR (CAU NeuroImaging Research) team.


caueeg-ceednet

CEEDNet: CAUEEG End-to-end Deep neural Network for automatic early detection of dementia based on the CAUEEG dataset.


[Figure: graphical abstract]


Model summary

CAUEEG-Dementia dataset

| Model | #Params | Model size (MiB) | TTA | Throughput (EEG/s) | Test accuracy | Link 1 | Link 2 |
| --- | --- | --- | --- | --- | --- | --- | --- |
| k-Nearest Neighbors (k=5) | - | 11848.7 | | 52.42 | 36.80% | | |
| Random Forests (#trees=2000) | - | 2932.8 | | 808.76 | 46.62% | | |
| Linear SVM | 0.1M | 0.5 | | 10393.40 | 52.33% | | |
| Ieracitano-CNN (Ieracitano et al., 2019) | 3.5M | 13.2 | | 8172.36 | 54.27% | | |
| CEEDNet (1D-VGG-19) | 20.2M | 77.2 | | 6593.80 | 64.00% | 1vc80n1f | 1vc80n1f |
| CEEDNet (1D-VGG-19) | 20.2M | 77.2 | | 931.64 | 67.11% | 1vc80n1f | 1vc80n1f |
| CEEDNet (1D-ResNet-18) | 11.4M | 43.6 | | 1249.90 | 68.75% | 2s1700lg | 2s1700lg |
| CEEDNet (1D-ResNet-50) | 26.2M | 100.2 | | 1013.18 | 67.00% | gvqyvmrj | gvqyvmrj |
| CEEDNet (1D-ResNeXt-50) | 25.7M | 98.2 | | 702.64 | 68.54% | v301o425 | v301o425 |
| CEEDNet (2D-VGG-19) | 20.2M | 77.1 | | 296.84 | 70.18% | lo88puq7 | lo88puq7 |
| CEEDNet (2D-ResNet-18) | 11.4M | 43.7 | | 399.40 | 65.88% | xci5svkl | xci5svkl |
| CEEDNet (2D-ResNet-50) | 25.7M | 98.5 | | 182.20 | 67.21% | syrx7bmk | syrx7bmk |
| CEEDNet (2D-ResNeXt-50) | 25.9M | 99.1 | | 161.25 | 67.91% | 1sl7ipca | 1sl7ipca |
| CEEDNet (ViT-B-16) | 90.1M | 343.6 | | 47.31 | 66.18% | gjkysllw | gjkysllw |
| CEEDNet (Ensemble) | 256.7M | 981.1 | | 22.74 | 74.66% | | |
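Two relationships in the table above can be checked with simple arithmetic. The Model size column appears consistent with storing each parameter as a 32-bit float (4 bytes): 20.2M parameters take about 77.1 MiB, close to the 77.2 MiB listed, with the small gap likely due to rounding in #Params. The Ensemble row's parameter count also roughly equals the sum of the nine distinct CEEDNet models, which suggests the ensemble combines all of them. Both points are our inference from the numbers, not statements from the authors. The sketch below only reproduces the arithmetic.

```python
# Sanity-check the CAUEEG-Dementia table, ASSUMING float32 weights (4 bytes each).
MIB = 2 ** 20  # bytes per mebibyte


def params_to_mib(n_params: float, bytes_per_param: int = 4) -> float:
    """Approximate storage size in MiB for n_params parameters."""
    return n_params * bytes_per_param / MIB


# #Params (millions) of the nine distinct CEEDNet models in the table above.
# 1D-VGG-19 appears in two rows (same checkpoint) but is counted once here.
dementia_params_m = {
    "1D-VGG-19": 20.2,
    "1D-ResNet-18": 11.4,
    "1D-ResNet-50": 26.2,
    "1D-ResNeXt-50": 25.7,
    "2D-VGG-19": 20.2,
    "2D-ResNet-18": 11.4,
    "2D-ResNet-50": 25.7,
    "2D-ResNeXt-50": 25.9,
    "ViT-B-16": 90.1,
}

if __name__ == "__main__":
    # Compare with 77.2 MiB listed for 1D-VGG-19.
    print(f"1D-VGG-19 size: {params_to_mib(20.2e6):.1f} MiB")
    # Compare with 256.7M listed for the Ensemble row.
    print(f"Sum of member #Params: {sum(dementia_params_m.values()):.1f}M")
```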

[Figure: ROC curves on CAUEEG-Dementia]

[Figure: confusion matrix on CAUEEG-Dementia]

[Figure: class-wise metrics on CAUEEG-Dementia]
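Class-wise metrics follow from a confusion matrix in a standard way. For each class, recall divides the diagonal entry by its row sum, and precision divides it by its column sum. The following plain-Python sketch assumes rows are true labels and columns are predictions. It uses the normal / MCI / dementia classes from the paper title. The counts are made up for illustration and are not results from the paper.

```python
from typing import Dict, List


def class_wise_metrics(cm: List[List[int]], names: List[str]) -> Dict[str, Dict[str, float]]:
    """Per-class precision, recall and F1.

    Convention assumed here: cm[i][j] counts samples of true class i predicted as class j.
    """
    n = len(cm)
    metrics: Dict[str, Dict[str, float]] = {}
    for k in range(n):
        tp = cm[k][k]
        row = sum(cm[k])                       # samples whose true class is k
        col = sum(cm[i][k] for i in range(n))  # samples predicted as k
        recall = tp / row if row else 0.0
        precision = tp / col if col else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        metrics[names[k]] = {"precision": precision, "recall": recall, "f1": f1}
    return metrics


def accuracy(cm: List[List[int]]) -> float:
    """Overall accuracy: sum of the diagonal divided by the total count."""
    total = sum(sum(row) for row in cm)
    return sum(cm[k][k] for k in range(len(cm))) / total


if __name__ == "__main__":
    # Hypothetical counts, NOT taken from the paper.
    cm = [[40, 8, 2],
          [10, 30, 10],
          [3, 7, 40]]
    print(f"accuracy = {accuracy(cm):.3f}")
    for name, m in class_wise_metrics(cm, ["normal", "MCI", "dementia"]).items():
        print(name, {key: round(value, 3) for key, value in m.items()})
```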

CAUEEG-Abnormal dataset

| Model | #Params | Model size (MiB) | TTA | Throughput (EEG/s) | Test accuracy | Link 1 | Link 2 |
| --- | --- | --- | --- | --- | --- | --- | --- |
| k-Nearest Neighbors (k=7) | - | 14015.3 | | 41.19 | 51.42% | | |
| Random Forests (#trees=2000) | - | 1930.5 | | 830.80 | 72.63% | | |
| Linear SVM | 0.1M | 0.3 | | 10363.76 | 68.00% | | |
| Ieracitano-CNN (Ieracitano et al., 2019) | 3.5M | 13.2 | | 8293.08 | 65.98% | | |
| CEEDNet (1D-VGG-19) | 20.2M | 77.2 | | 7660.22 | 72.45% | nemy8ikm | nemy8ikm |
| CEEDNet (1D-VGG-19) | 20.2M | 77.2 | | 998.54 | 74.28% | nemy8ikm | nemy8ikm |
| CEEDNet (1D-ResNet-18) | 11.4M | 43.5 | | 844.65 | 74.85% | 4439k9pg | 4439k9pg |
| CEEDNet (1D-ResNet-50) | 26.3M | 100.7 | | 837.66 | 76.37% | q1hhkmik | q1hhkmik |
| CEEDNet (1D-ResNeXt-50) | 25.7M | 98.2 | | 800.49 | 77.32% | tp7qn5hd | tp7qn5hd |
| CEEDNet (2D-VGG-19) | 20.2M | 77.2 | | 447.81 | 75.39% | ruqd8r7g | ruqd8r7g |
| CEEDNet (2D-ResNet-18) | 11.5M | 43.8 | | 410.44 | 75.19% | dn10a6bv | dn10a6bv |
| CEEDNet (2D-ResNet-50) | 25.7M | 98.5 | | 187.30 | 74.96% | atbhqdgg | atbhqdgg |
| CEEDNet (2D-ResNeXt-50) | 25.9M | 99.1 | | 201.01 | 75.85% | 0svudowu | 0svudowu |
| CEEDNet (ViT-B-16) | 86.9M | 331.6 | | 63.99 | 72.70% | 1cdws3t5 | 1cdws3t5 |
| CEEDNet (Ensemble) | 253.8M | 969.9 | | 26.40 | 79.16% | | |
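The paired 1D-VGG-19 rows share a checkpoint but differ in throughput and accuracy, which is consistent with test-time augmentation (TTA). The Ensemble row combines several models. This README does not describe how either is implemented, so the following is only a generic sketch of both ideas. TTA is shown as averaging class probabilities over several random crops of one recording. Ensembling is shown as soft voting, which averages each model's TTA probabilities. The crop length, number of crops, and the toy "models" are assumptions for illustration.

```python
import math
import random
from typing import Callable, List, Sequence

# In this sketch a "model" is any function mapping a 1-D signal crop to class logits.
Model = Callable[[Sequence[float]], List[float]]


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def tta_predict(model: Model, signal: Sequence[float], crop_len: int,
                n_crops: int, rng: random.Random) -> List[float]:
    """Generic TTA: average class probabilities over n_crops random crops."""
    if crop_len > len(signal):
        raise ValueError("crop_len must not exceed the signal length")
    probs_sum: List[float] = []
    for _ in range(n_crops):
        start = rng.randrange(len(signal) - crop_len + 1)
        p = softmax(model(signal[start:start + crop_len]))
        probs_sum = p if not probs_sum else [a + b for a, b in zip(probs_sum, p)]
    return [v / n_crops for v in probs_sum]


def ensemble_predict(models: Sequence[Model], signal: Sequence[float], crop_len: int,
                     n_crops: int, seed: int = 0) -> List[float]:
    """Soft voting: average the TTA probabilities of every model."""
    rng = random.Random(seed)
    per_model = [tta_predict(m, signal, crop_len, n_crops, rng) for m in models]
    return [sum(col) / len(models) for col in zip(*per_model)]


if __name__ == "__main__":
    signal = [math.sin(0.01 * t) for t in range(2000)]
    # Two toy two-class "models" (hypothetical): mean amplitude and mean energy as logits.
    mean_model = lambda x: [sum(x) / len(x), 0.0]
    energy_model = lambda x: [0.0, sum(v * v for v in x) / len(x)]
    probs = ensemble_predict([mean_model, energy_model], signal, crop_len=500, n_crops=8)
    print([round(p, 3) for p in probs])
```

In this sketch every extra crop costs one more forward pass per model, and every extra member multiplies that cost again. This is the general reason TTA and ensembling trade throughput for accuracy.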

[Figure: ROC curves on CAUEEG-Abnormal]

[Figure: confusion matrix on CAUEEG-Abnormal]

[Figure: class-wise metrics on CAUEEG-Abnormal]
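CAUEEG-Abnormal is a binary task. For a binary task, the area under the ROC curve equals the probability that a randomly chosen positive (abnormal) recording scores higher than a randomly chosen negative (normal) one, with ties counted as one half. The sketch below computes AUC directly from that pairwise definition in plain Python. The labels and scores are made up. In practice a library routine such as scikit-learn's `roc_auc_score` would normally be used instead.

```python
from typing import Sequence


def roc_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """AUC = P(score_pos > score_neg) + 0.5 * P(tie); labels: 1 = abnormal, 0 = normal.

    This O(n_pos * n_neg) pairwise version is for illustration and is slow on large sets.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


if __name__ == "__main__":
    # Hypothetical labels and model scores.
    labels = [0, 0, 1, 1, 0, 1]
    scores = [0.1, 0.4, 0.35, 0.8, 0.2, 0.9]
    print(f"AUC = {roc_auc(labels, scores):.3f}")
```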


Getting started

Requirements

Note: we tested the code in the following environments.

| OS | Python | PyTorch | CUDA |
| --- | --- | --- | --- |
| Windows 10 | 3.9.12 | 1.11.0 | 11.3 |
| Ubuntu 18.04 | 3.8.11 | 1.10.0 | 11.3 |
| Ubuntu 20.04 | 3.9.12 | 1.11.0 | 11.3 |

Installation

(optional) Create and activate a Conda environment.

conda create -n caueeg python=3.9
conda activate caueeg

Install the PyTorch library (see https://pytorch.org/get-started/locally/).

conda install pytorch torchvision cudatoolkit=11.3 -c pytorch

Install the other required libraries.

pip install -r requirements.txt
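After installation, you may want to confirm that your versions resemble one of the tested environments and that PyTorch can see a CUDA device. The following small sketch relies only on standard attributes: `platform.python_version`, `torch.__version__`, `torch.version.cuda`, and `torch.cuda.is_available`. If PyTorch is not installed, the PyTorch fields stay empty instead of raising an error. The version-matching rule, which compares major.minor and ignores the OS, is our own assumption.

```python
import platform
from typing import Dict, Optional

# (OS, Python, PyTorch, CUDA) combinations from the Requirements table.
TESTED = [
    ("Windows 10", "3.9.12", "1.11.0", "11.3"),
    ("Ubuntu 18.04", "3.8.11", "1.10.0", "11.3"),
    ("Ubuntu 20.04", "3.9.12", "1.11.0", "11.3"),
]


def _major_minor(version: Optional[str]) -> Optional[str]:
    """'3.9.12' -> '3.9'; None stays None."""
    return ".".join(version.split(".")[:2]) if version else None


def environment_report() -> Dict[str, object]:
    """Collect Python/PyTorch/CUDA versions; PyTorch fields stay None if it is missing."""
    info: Dict[str, object] = {
        "os": platform.platform(),
        "python": platform.python_version(),
        "pytorch": None,
        "cuda": None,
        "cuda_available": False,
    }
    try:
        import torch
    except ImportError:
        return info
    info["pytorch"] = str(torch.__version__).split("+")[0]  # drop suffixes like "+cu113"
    info["cuda"] = torch.version.cuda  # None for CPU-only builds
    info["cuda_available"] = torch.cuda.is_available()
    return info


def matches_tested(info: Dict[str, object]) -> bool:
    """True if Python, PyTorch (major.minor) and CUDA match a tested row; OS is ignored."""
    return any(
        _major_minor(info["python"]) == _major_minor(py)
        and _major_minor(info["pytorch"]) == _major_minor(pt)
        and info["cuda"] == cu
        for _, py, pt, cu in TESTED
    )


if __name__ == "__main__":
    report = environment_report()
    for key, value in report.items():
        print(f"{key:>14}: {value}")
    print("matches a tested environment:", matches_tested(report))
```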

Preparation of the CAUEEG dataset

❗ Note: The CAUEEG dataset may be used only for academic and research purposes 👩‍🎓👨🏼‍🎓.

💡 Note: So that others can test our research, we provide caueeg-dataset-test-only at [link 1] or [link 2]. It contains the real test splits of both benchmarks (CAUEEG-Dementia and CAUEEG-Abnormal), but its train and validation splits are fake.


Usage

Train

Train a CEEDNet model on the training set of CAUEEG-Dementia from scratch using the following command:

python run_train.py data=caueeg-dementia model=1D-ResNet-18 train=base_train

Similarly, train a model on the training set of CAUEEG-Abnormal from scratch using:

python run_train.py data=caueeg-abnormal model=1D-ResNet-18 train=base_train

Alternatively, you can use this Jupyter notebook.
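To train the same architecture on both benchmarks in one go, you can script the two commands shown earlier. This is a minimal sketch using `subprocess`. It assumes `run_train.py` is in the current directory and that the Hydra config names are exactly those used in the commands (`caueeg-dementia`, `caueeg-abnormal`, `1D-ResNet-18`, `base_train`). By default it only prints the commands. Pass `dry_run=False` to launch training.

```python
import subprocess
import sys
from typing import List, Sequence

BENCHMARKS = ("caueeg-dementia", "caueeg-abnormal")


def build_command(data: str, model: str, train: str = "base_train") -> List[str]:
    """Assemble a run_train.py call with Hydra key=value overrides."""
    # sys.executable reuses the current interpreter (e.g. the activated caueeg env).
    return [sys.executable, "run_train.py", f"data={data}", f"model={model}", f"train={train}"]


def train_on_benchmarks(model: str, benchmarks: Sequence[str] = BENCHMARKS,
                        dry_run: bool = True) -> List[List[str]]:
    """Print (and, unless dry_run, execute) one training run per benchmark."""
    commands = [build_command(data, model) for data in benchmarks]
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # raises CalledProcessError if a run fails
    return commands


if __name__ == "__main__":
    train_on_benchmarks("1D-ResNet-18")  # dry run: only prints the commands
```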

If you encounter a GPU memory allocation error, or want to adjust the trade-off between memory usage and training speed, you can set the minibatch size by adding the ++model.minibatch=INTEGER_NUMBER option to the command, as shown below:

python run_train.py data=caueeg-dementia model=1D-ResNet-18 train=base_train ++model.minibatch=32
python run_train.py data=caueeg-abnormal model=1D-ResNet-18 train=base_train ++model.minibatch=32
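If you are unsure which `++model.minibatch` value fits in GPU memory, one option is to retry with successively halved values until a run succeeds. The helper below is a hypothetical convenience wrapper and is not part of the repository. It treats any non-zero exit code as failure, so it cannot tell an out-of-memory error from other errors. The start value of 256 and the floor of 8 are arbitrary choices.

```python
import subprocess
import sys
from typing import Callable, Iterator, List, Optional


def halving(start: int, floor: int = 1) -> Iterator[int]:
    """Yield start, start // 2, ... while the value is at least floor (floor >= 1)."""
    floor = max(floor, 1)  # guard against an endless loop at 0
    size = start
    while size >= floor:
        yield size
        size //= 2


def _run(cmd: List[str]) -> int:
    """Default runner: launch the command and return its exit code."""
    return subprocess.run(cmd).returncode


def train_with_fallback(data: str, model: str, start: int = 256, floor: int = 8,
                        runner: Callable[[List[str]], int] = _run) -> Optional[int]:
    """Return the first minibatch size whose run exits with code 0, or None if all fail."""
    for size in halving(start, floor):
        cmd = [sys.executable, "run_train.py", f"data={data}", f"model={model}",
               "train=base_train", f"++model.minibatch={size}"]
        print("trying:", " ".join(cmd))
        if runner(cmd) == 0:
            return size
    return None
```

For example, `train_with_fallback("caueeg-dementia", "1D-ResNet-18")` tries 256, 128, and so on down to 8. A successful attempt is a complete training run, so the call returns only after that training finishes. The `runner` parameter exists so the retry logic can be tested without launching training.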

Thanks to Hydra, the model, hyperparameters, and other training details can easily be changed by selecting or modifying config files. For example, to train a 2D-VGG-19 model instead:

python run_train.py data=caueeg-dementia model=2D-VGG-19 train=base_train

For a speed-up, we recommend using the PyArrow Feather file format rather than reading the EDF files directly. Convert the files once, then pass ++data.file_format=feather:

python ./datasets/convert_file_format.py  # it takes a few minutes
python run_train.py data=caueeg-dementia model=2D-VGG-19 train=base_train ++data.file_format=feather

Evaluation

Evaluation can be performed using this Jupyter notebook (or a separate notebook for the caueeg-dataset-test-only case).

To use a pre-trained model, download its checkpoint file from here and move it to the local/checkpoint directory (e.g., local/checkpoint/1vc80n1f/checkpoint.pt for the 1D-VGG-19 model on the CAUEEG-Dementia benchmark).
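Checkpoints are stored at `local/checkpoint/<run ID>/checkpoint.pt`. The run IDs are the ones in the Link columns of the model summary tables, and the CAUEEG-Dementia IDs are used below. The sketch builds these paths and optionally loads a file on the CPU with `torch.load`. This README does not document the contents of `checkpoint.pt`, so the sketch only lists its top-level keys and assumes no particular structure.

```python
from pathlib import Path

# Run IDs from the Link columns of the CAUEEG-Dementia table.
DEMENTIA_RUNS = {
    "1D-VGG-19": "1vc80n1f",
    "1D-ResNet-18": "2s1700lg",
    "1D-ResNet-50": "gvqyvmrj",
    "1D-ResNeXt-50": "v301o425",
    "2D-VGG-19": "lo88puq7",
    "2D-ResNet-18": "xci5svkl",
    "2D-ResNet-50": "syrx7bmk",
    "2D-ResNeXt-50": "1sl7ipca",
    "ViT-B-16": "gjkysllw",
}


def checkpoint_path(model: str, root: str = "local/checkpoint") -> Path:
    """local/checkpoint/<run ID>/checkpoint.pt for a CAUEEG-Dementia model."""
    return Path(root) / DEMENTIA_RUNS[model] / "checkpoint.pt"


def inspect_checkpoint(path: Path):
    """Load a checkpoint on CPU and list its top-level keys (structure not assumed)."""
    import torch  # lazy import so the path helper works without PyTorch

    # Note: PyTorch 2.6+ defaults to weights_only=True; checkpoints holding
    # non-tensor objects may then need weights_only=False (trusted files only).
    ckpt = torch.load(path, map_location="cpu")
    return sorted(ckpt.keys()) if isinstance(ckpt, dict) else type(ckpt).__name__
```

For example, after downloading the 1D-VGG-19 checkpoint, `inspect_checkpoint(checkpoint_path("1D-VGG-19"))` shows what the file contains before you wire it into evaluation code.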


Citation

If you find this repository helpful, please cite the following paper.

@article{kim2023deep,
  title={Deep learning-based EEG analysis to classify normal, mild cognitive impairment, and dementia: algorithms and dataset},
  author={Kim, Min-jae and Youn, Young Chul and Paik, Joonki},
  journal={NeuroImage},
  pages={120054},
  year={2023},
  publisher={Elsevier}
}
