From 78212a539c63647745e52e4c3fab1392e83d1475 Mon Sep 17 00:00:00 2001 From: kai juarez jimenez Date: Fri, 6 Dec 2024 14:47:34 -0800 Subject: [PATCH] meow pls lemme in its my chbmit tutorial --- .DS_Store | Bin 0 -> 10244 bytes examples/snnCHBMIT.ipynb | 1453 ++++++++++++++++++++++++++++++++++++++ snntorch/.DS_Store | Bin 0 -> 6148 bytes 3 files changed, 1453 insertions(+) create mode 100644 .DS_Store create mode 100644 examples/snnCHBMIT.ipynb create mode 100644 snntorch/.DS_Store diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..12942d778b9eda8d1153ea3d993b6d053c0b0884 GIT binary patch literal 10244 zcmeHM&2Jk;6n~RAWF7azN!_$i6|Lk8hmeLylvW&YoP^?1NED|fp&w>#Z=5CTU9-D( z5~3iV;ldw)ICJB`g@1s5fCDEaj@%H}p4#7=U8noe_Jl+PGt%yxo&CKxZ+p zM77_3i%2IT4HwJe8Qh*z_&r~lG8VPykOKNd0lDN+3wI}jRt<^)#eiZ!F`yVw44eiA z@Xh9;6+)@*iUGxdV&H@U&JQ{+mX$~jLn%`SZgL9%nM1QIsAC@>b@WJ9A~_7DRE5oO z_aGuw5p6LI4mcaqhq;5N1+mtdr82fR2R`n`9)Bv;^e$u@PPZvd0$bO?j@- zrLN>Dd~z^)?sc9H!9kbolr`onZY?jZ9b!~Hsgu!=oW(V!f#co{cpHEZo+lTc_%G^H zFq?4;1!~X$dLDEvc{0NzI!t-WO22@b2@|rt8|xQQc_Xv&{WQv9#XSLUMc6>9~iq*G2H4ZjUte^K=o@?9du5SgF>uf*p!F(++yESLWMdt&szSiSc z{yo$4AsfY%&epc!d4|(5tA`J5qiNb}PSp(px64SjhB2*sCR>tj)O$wDaysSr{tgK) zSJp?BH^NGh%3+15T)91zb*(gNilo?dT*u7k%asqK77fwAOm`3kpCi)0rSIqm`iXv_ zU+GVJEXKvOD2r?2hFB7}#T{{1JQ5Aj6MZod`=N!g(cHhD>`M&soSFMU-E(Z$;jG1a z1^NK7S4RY{AsV-zAAy{k6Ol3re{1l16;WHK)nr_8o=KoJ`Y6icoY4@8DLmF_0$%`r z36^X0K6+r^IHNi#Eo6m+);Z2RGM~C&(nKB$pvR}inM}x;BB47abpy^3k=8*v3(0Kp zRV-Uffv;7xuE5*Xm+%(XDam~M1#+|{YhCo{phq8hl)o4Q>2aoS$!Oe89vUX}!d&0N zmm>8w Detecting Seizure Activity in EEG Using Spiking Neural Networks with the CHB-MIT Dataset \n", + "\n", + "##### This tutorial provides a step-by-step demonstration of using Spiking Neural Networks (SNNs) for seizure detection in EEG data. The dataset, preprocessing steps, SNN model design, and evaluation metrics are discussed comprehensively." + ] + }, + { + "cell_type": "markdown", + "id": "fe8481aa-bef8-4803-b7eb-026a3d17b9ff", + "metadata": {}, + "source": [ + "### Author: Kai Juarez-Jimenez" + ] + }, + { + "cell_type": "markdown", + "id": "53556c39-7612-4fb6-b6f5-a97d482aa24d", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "--- \n" + ] + }, + { + "cell_type": "markdown", + "id": "c02340af-23b7-429b-90f4-548d921c1593", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "## I. Introduction \n", + "### Objective \n", + "#### Importance of Seizure Detection in EEG Data: Seizure detection plays a vital role in advancing research and innovation aimed at developing solutions for epilepsy, such as wearable devices or neuromorphic chips. It also supports the accurate diagnosis of epilepsy and aids in identifying the precise location of seizure activity, which is critical for effective treatment planning [2].\n", + "\n", + "#### How Do SNNs Address This?: Spiking Neural Networks (SNNs) leverage their ability to process temporal information and replicate firing patterns similar to actual neurons. This capability enables them to effectively identify changes in brain activity associated with seizures, providing a more biologically inspired approach to analysis.\n", + "\n", + "#### Tutorial's Aim: This tutorial demonstrates the end-to-end process of:\n", + "#### 1. Preprocessing EEG data from the CHB-MIT dataset.\n", + "#### 2. Training an SNN model tailored to detect seizures.\n", + "#### 3. Evaluating model performance using metrics like AUC-ROC and confusion matrices.\n" + ] + }, + { + "cell_type": "markdown", + "id": "9adf8102-b6cb-4e96-b5fb-2213f287c3a6", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "431cf8a7-299c-4218-a2c3-c735a4b5a880", + "metadata": {}, + "source": [ + "## II. Data and Preprocessing \n", + "### Dataset Overview \n", + "#### About the CHB-MIT Dataset:The CHB-MIT data comprises EEG recordings from 23 patients who experienced seizures, with a total of 24 cases, including one patient with two recordings taken 1.5 years apart. It contains 969 hours of scalp EEG data capturing 173 seizure events. The patients range in age from 3 to 22 years old [1].\n", + "\n", + "#### Channels: Each EEG recording contains data from multiple channels (23 channels by default in CHB-MIT), representing different electrodes placed on the scalp. Each channel records the electrical activity at a specific location on the brain [1].\n", + "\n", + "#### Data: The raw EEG signals are stored as continuous time-series data, often in digital formats such as .edf, a standard for storing biomedical signals. \n", + "\n", + "#### Frequency: The CHB-MIT dataset has a fixed sampling rate of 256 Hz, meaning that 256 data points are recorded per second for each channel. This high sampling rate is sufficient to capture fast-changing brain activity, including seizure-related patterns [1]. " + ] + }, + { + "cell_type": "markdown", + "id": "8ea7e6ba-8153-4f04-98e0-ffa7b822d506", + "metadata": {}, + "source": [ + "
\n", + " Before starting, download the dataset, which contains recordings across 23 channels. Note that the full dataset is approximately 43 GB in size, so ensure you have sufficient storage available. Dataset download available at: https://physionet.org/content/chbmit/1.0.0/#files-panel\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "5905555c-b461-40de-9a3e-f16b1a344e3d", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "95db441a-08b5-4fe4-8b01-f44316dd7f54", + "metadata": {}, + "source": [ + "### Preprocessing \n", + "#### Before training the model, the raw EEG data requires preprocessing to extract meaningful features.\n", + "#### EEG Segmentations: To analyze EEG recordings effectively, we segment the data into smaller time windows. This step ensures localized analysis of the signal and facilitates applying frequency-domain transformations. Segmenting the data also simplifies computational processing, making it feasible to apply machine learning techniques.\n", + "#### Fourier Transformation: Fourier Transformation is applied to each segment to convert the signal from the time domain to the frequency domain.\n" + ] + }, + { + "cell_type": "markdown", + "id": "1fabf1df-3eb8-42ef-993c-0aff8d1fa864", + "metadata": {}, + "source": [ + "
\n", + " Why Fourier Transformation? \n", + " It converts the raw EEG data into the frequency domain, indentifying specific frequency bands (gamma 30-50 Hz) that significantly increase in power during a seizure, which can be used as a key indicator for seizure detection [4].\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "id": "1d2833c1-c766-4ce6-acb8-c9e8e6cc9e8e", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "66e2cc75-99fe-48cb-907b-1d921ff36b27", + "metadata": {}, + "source": [ + "#### This code processes EEG signals stored in EDF files to extract meaningful features that can be used for seizure detection. It focuses on the gamma-band power (30–50 Hz), which is a known indicator of seizure activity. The extracted features are saved incrementally to a CSV file for easy access and analysis." + ] + }, + { + "cell_type": "code", + "execution_count": 122, + "id": "32373b8e-b4f6-4632-a2b7-8fa4b71bc523", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "resuming from existing file: preprocessed_features.csv\n", + "shape of the dataset features: (686, 24)\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import os\n", + "from scipy.signal import stft\n", + "import pyedflib\n", + "import glob\n", + "\n", + "def segment_and_transform(signal, fs=256, window_size=256):\n", + " \"\"\"segmenting signal into smaller windows and apply fourier transformation.\"\"\"\n", + " segments = [signal[i:i+window_size] for i in range(0, len(signal), window_size) if i+window_size <= len(signal)]\n", + " features = []\n", + " for segment in segments:\n", + " f, _, Zxx = stft(segment, fs=fs, nperseg=window_size)\n", + " gamma_band = (f >= 30) & (f <= 50)\n", + " gamma_power = np.abs(Zxx[gamma_band]).mean()\n", + " features.append(gamma_power)\n", + " return np.array(features).mean() # aggregate features\n", + "\n", + "def preprocess_eeg(filepath, fs=256, window_size=256):\n", + " \"\"\"preprocess all channels in a file.\"\"\"\n", + " try:\n", + " print(f\"Processing: {filepath}\")\n", + " with pyedflib.EdfReader(filepath) as f:\n", + " n = f.signals_in_file\n", + " signals = np.zeros((n, f.getNSamples()[0]))\n", + " for i in range(n):\n", + " signals[i, :] = f.readSignal(i)\n", + " features = [segment_and_transform(channel, fs, window_size) for channel in signals]\n", + " # ensure exactly 23 features\n", + " if len(features) > 23:\n", + " features = features[:23] # truncate to 23\n", + " elif len(features) < 23:\n", + " features += [0] * (23 - len(features)) # pad with zeros\n", + " return np.array(features)\n", + " except Exception as e:\n", + " print(f\"error processing file {filepath}: {e}\")\n", + " return None\n", + "\n", + "\n", + "# dir containing the dataset\n", + "data_dir = './CHBMIT'\n", + "file_paths = glob.glob(f'{data_dir}/**/*.edf', recursive=True)\n", + "\n", + "# csv file for saving features\n", + "output_file = 'preprocessed_features.csv'\n", + "\n", + "# initialize or load existing data\n", + "if os.path.exists(output_file):\n", + " print(f\"resuming from existing file: {output_file}\")\n", + " processed_files = [os.path.abspath(p) for p in pd.read_csv(output_file)['file_path'].tolist()]\n", + " all_features = pd.read_csv(output_file)\n", + "else:\n", + " processed_files = []\n", + " all_features = pd.DataFrame(columns=['file_path'] + [f'feature_{i+1}' for i in range(23)])\n", + "\n", + "# preprocess and save incrementally\n", + "for file_path in file_paths:\n", + " if os.path.abspath(file_path) in processed_files:\n", + " #print(f\"skipping already processed file: {file_path}\")\n", + " continue\n", + " print(f\"processing file: {file_path}\")\n", + " features = preprocess_eeg(file_path) # preprocess file\n", + " if features is None or len(features) == 0:\n", + " print(f\"skipping file due to invalid or empty features: {file_path}\")\n", + " continue\n", + " feature_row = [os.path.abspath(file_path)] + features.tolist()\n", + " new_row = pd.DataFrame([feature_row], columns=all_features.columns)\n", + " all_features = pd.concat([all_features, new_row], ignore_index=True) # append new row\n", + " all_features.to_csv(output_file, index=False) # save incrementally\n", + " print(f\"saved features for file: {file_path}\")\n", + "\n", + "# (verify) the saved features\n", + "print(f\"shape of the dataset features: {all_features.shape}\")\n", + "\n", + "# preprocessed_data = pd.read_csv('preprocessed_features.csv')\n", + "# print(preprocessed_data.head())\n" + ] + }, + { + "cell_type": "markdown", + "id": "bdd29c51-63f0-4175-8a53-d66abe02f3e2", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "f20feaf3-4a32-4987-b844-d4f34826826a", + "metadata": {}, + "source": [ + "### Load and Inspect the Preprocessed Features \n", + "#### The preprocessed features stored in the preprocessed_features.csv file contain critical information extracted from EEG signals, such as gamma-band power for each channel. Loading this file allows us to verify the data structure and ensure it aligns with our modeling requirements." + ] + }, + { + "cell_type": "markdown", + "id": "2f0780cf-91df-4edc-8b7b-72c86cb8ddec", + "metadata": {}, + "source": [ + "
\n", + " Why Is This Important?
\n", + " Data Integrity: Verifying the shape and structure ensures that the preprocessing pipeline worked as intended and the dataset contains the expected number of features for each EEG file.
\n", + " Preparation for Merging: This inspection is a crucial step before merging the features with labels, ensuring compatibility in downstream tasks like training and evaluation.\n", + "
\n" + ] + }, + { + "cell_type": "code", + "execution_count": 123, + "id": "3ed13674-2a87-413e-b652-b99762b1e978", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dataset shape: (686, 24)\n", + " file_path feature_1 feature_2 feature_3 \\\n", + "0 /Users/kjuarezj/CHBMIT/chb11/chb11_53.edf 0.335455 0.425711 0.319864 \n", + "1 /Users/kjuarezj/CHBMIT/chb11/chb11_92.edf 1.194302 1.326720 1.355358 \n", + "2 /Users/kjuarezj/CHBMIT/chb11/chb11_82.edf 0.948020 1.374275 1.305503 \n", + "3 /Users/kjuarezj/CHBMIT/chb11/chb11_55.edf 0.822861 1.004402 0.914919 \n", + "4 /Users/kjuarezj/CHBMIT/chb11/chb11_54.edf 0.302837 0.320896 0.251632 \n", + "\n", + " feature_4 feature_5 feature_6 feature_7 feature_8 feature_9 ... \\\n", + "0 0.467975 0.001104 0.390983 0.483227 0.351886 0.382416 ... \n", + "1 0.914141 0.001104 0.915612 0.531348 0.400397 0.460694 ... \n", + "2 0.501360 0.001104 0.826069 0.363019 0.249405 0.309681 ... \n", + "3 0.607224 0.001104 0.780986 0.471013 0.413295 0.492717 ... \n", + "4 0.326476 0.001104 0.345955 0.339471 0.260942 0.309337 ... \n", + "\n", + " feature_14 feature_15 feature_16 feature_17 feature_18 feature_19 \\\n", + "0 0.381853 0.433515 0.342898 0.463714 0.001104 0.364284 \n", + "1 1.029096 0.891290 0.831539 0.754172 0.001104 1.385414 \n", + "2 0.796286 0.440978 0.340114 0.417588 0.001104 1.015485 \n", + "3 0.680256 0.477175 0.366936 0.459407 0.001104 0.852010 \n", + "4 0.335248 0.324554 0.254700 0.329429 0.001104 0.326655 \n", + "\n", + " feature_20 feature_21 feature_22 feature_23 \n", + "0 0.379613 0.347219 0.379839 0.001104 \n", + "1 1.541060 1.255841 0.745652 0.001104 \n", + "2 1.353998 1.263604 0.454354 0.001104 \n", + "3 0.869236 0.742546 0.505805 0.001104 \n", + "4 0.304063 0.262912 0.301555 0.001104 \n", + "\n", + "[5 rows x 24 columns]\n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "# load preprocessed features\n", + "data_file = 'preprocessed_features.csv'\n", + "data = pd.read_csv(data_file)\n", + "print(f\"dataset shape: {data.shape}\") # should be (686, 24) rows, col\n", + "print(data.head()) # inspect the first few rows\n" + ] + }, + { + "cell_type": "markdown", + "id": "12a1662e-67f6-48a5-8515-088449e48e14", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "9e1d547c-3648-484b-bbad-cf081ce5d9b5", + "metadata": {}, + "source": [ + "### Data Splitting \n", + "#### This code prepares the dataset for machine learning by organizing EEG recordings into two labeled groups: \"seizure\" (1) and \"non-seizure\" (0). It ensures a balanced split of the data into training (80%) and testing (20%) sets while maintaining the proportion of seizure and non-seizure recordings. Finally, it saves the training and testing sets to separate CSV files for use in machine learning or analysis." + ] + }, + { + "cell_type": "markdown", + "id": "62768999-c2fc-4d75-9ad4-e3649632ff44", + "metadata": {}, + "source": [ + "
\n", + " Debugging: \n", + " Print statements are invaluable for ensuring your dataset is clean and correctly processed! <3 Verify the total file counts, inspect the splits, and ensure there are no missing or mislabeled files. Additionally, printing the first few rows of the labeled DataFrame can help confirm that the labeling process was executed as expected\n", + "
\n" + ] + }, + { + "cell_type": "code", + "execution_count": 104, + "id": "5f4a3de8-57ba-46ee-82b8-f6d9c56d21e6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "total files: 686\n", + "seizure files: 141\n", + "non-seizure files: 545\n", + "total labeled files: 686\n", + "first __ rows of the labeled data:\n", + " file_path label\n", + "0 ./CHBMIT/chb01/chb01_01.edf 0\n", + "1 ./CHBMIT/chb01/chb01_02.edf 0\n", + "2 ./CHBMIT/chb01/chb01_03.edf 1\n", + "3 ./CHBMIT/chb01/chb01_04.edf 1\n", + "4 ./CHBMIT/chb01/chb01_05.edf 0\n", + "labeled files saved to: /Users/kjuarezj/labeled_files.csv\n", + "\n", + "splitting the dataset into training and testing sets...\n", + "training set size: 548, Testing set size: 138\n", + "training labels: 113 seizures, 435 non-seizures\n", + "testing labels: 28 seizures, 110 non-seizures\n", + "training set saved to: /Users/kjuarezj/train_files.csv\n", + "testing set saved to: /Users/kjuarezj/test_files.csv\n" + ] + } + ], + "source": [ + "import os\n", + "import pandas as pd\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "# load seizure files with clean handling\n", + "seizure_files_path = \"./CHBMIT/RECORDS-WITH-SEIZURES\" # replace with actual path\n", + "with open(seizure_files_path, \"r\") as f:\n", + " seizure_files = [line.strip() for line in f if line.strip() and line.strip().endswith(\".edf\")]\n", + "\n", + "# add base path to seizure files\n", + "seizure_files = {os.path.join(\"./CHBMIT\", file) for file in seizure_files}\n", + "\n", + "# load all files with clean handling\n", + "all_files_path = \"./CHBMIT/RECORDS\" # replace with actual path\n", + "with open(all_files_path, \"r\") as f:\n", + " all_files = [line.strip() for line in f if line.strip() and line.strip().endswith(\".edf\")]\n", + "\n", + "# add base path to all files\n", + "all_files = {os.path.join(\"./CHBMIT\", file) for file in all_files}\n", + "\n", + "# debug step: ensuring all files and seizure files are clean\n", + "print(f\"total files: {len(all_files)}\") #686\n", + "print(f\"seizure files: {len(seizure_files)}\") #141\n", + "\n", + "# find non-seizure files\n", + "non_seizure_files = all_files - seizure_files\n", + "print(f\"non-seizure files: {len(non_seizure_files)}\") #545\n", + "\n", + "# create dataframes\n", + "seizure_df = pd.DataFrame({\"file_path\": list(seizure_files), \"label\": 1})\n", + "non_seizure_df = pd.DataFrame({\"file_path\": list(non_seizure_files), \"label\": 0})\n", + "\n", + "# combine into a single DataFrame\n", + "all_labels_df = pd.concat([seizure_df, non_seizure_df], ignore_index=True)\n", + "\n", + "\n", + "print(f\"total labeled files: {len(all_labels_df)}\") # should match total files in RECORDS (686)\n", + "\n", + "# sort the DataFrame for consistency\n", + "all_labels_df = all_labels_df.sort_values(\"file_path\").reset_index(drop=True)\n", + "\n", + "# display the first __ rows of the labeled data\n", + "print(\"first __ rows of the labeled data:\")\n", + "print(all_labels_df.head(5))\n", + "\n", + "# save the dataframe to a csv file\n", + "output_file = \"labeled_files.csv\"\n", + "all_labels_df.to_csv(output_file, index=False)\n", + "print(f\"labeled files saved to: {os.path.abspath(output_file)}\")\n", + "\n", + "# split the dataset into training and testing sets\n", + "print(\"\\nsplitting the dataset into training and testing sets...\")\n", + "\n", + "# extract file paths and labels\n", + "file_paths = all_labels_df[\"file_path\"].values\n", + "labels = all_labels_df[\"label\"].values\n", + "\n", + "# split into training and testing sets\n", + "train_files, test_files, train_labels, test_labels = train_test_split(\n", + " file_paths, labels, test_size=0.2, random_state=42, stratify=labels\n", + ")\n", + "\n", + "# debuging core: verifying the splits\n", + "print(f\"training set size: {len(train_files)}, Testing set size: {len(test_files)}\")\n", + "print(f\"training labels: {sum(train_labels)} seizures, {len(train_labels) - sum(train_labels)} non-seizures\")\n", + "print(f\"testing labels: {sum(test_labels)} seizures, {len(test_labels) - sum(test_labels)} non-seizures\")\n", + "\n", + "# save the splits to csv files for future reference (optional tbh) \n", + "train_df = pd.DataFrame({\"file_path\": train_files, \"label\": train_labels})\n", + "test_df = pd.DataFrame({\"file_path\": test_files, \"label\": test_labels})\n", + "\n", + "train_output_file = \"train_files.csv\"\n", + "test_output_file = \"test_files.csv\"\n", + "\n", + "train_df.to_csv(train_output_file, index=False)\n", + "test_df.to_csv(test_output_file, index=False)\n", + "\n", + "print(f\"training set saved to: {os.path.abspath(train_output_file)}\")\n", + "print(f\"testing set saved to: {os.path.abspath(test_output_file)}\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "e2de998a-afa8-4af7-bf4a-a1d6b0b3cbc6", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "6baf36f9-3c7c-4027-8072-f1c50575b97d", + "metadata": {}, + "source": [ + "### Merging the Datasets \n", + "#### The datasets must be successfully merged, ensuring file paths align correctly and that features are combined with their corresponding labels. This process creates a feature matrix X (features) and a label array y (labels), which are critical for model training." + ] + }, + { + "cell_type": "code", + "execution_count": 105, + "id": "4d0795bd-0d65-4353-99ed-b2d23503c492", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "First few rows of merged data:\n", + " file_path label feature_1 feature_2 \\\n", + "0 /Users/kjuarezj/CHBMIT/chb01/chb01_01.edf 0 1.904987 1.282136 \n", + "1 /Users/kjuarezj/CHBMIT/chb01/chb01_02.edf 0 1.023876 0.692827 \n", + "2 /Users/kjuarezj/CHBMIT/chb01/chb01_03.edf 1 0.461945 0.314066 \n", + "3 /Users/kjuarezj/CHBMIT/chb01/chb01_04.edf 1 0.547646 0.392939 \n", + "4 /Users/kjuarezj/CHBMIT/chb01/chb01_05.edf 0 1.508249 1.026032 \n", + "\n", + " feature_3 feature_4 feature_5 feature_6 feature_7 feature_8 ... \\\n", + "0 1.160174 0.660278 2.245718 1.012637 0.605064 0.827183 ... \n", + "1 0.688489 0.656426 1.179537 0.777356 0.439819 0.825834 ... \n", + "2 0.315250 0.430209 0.490375 0.572592 0.380799 0.585074 ... \n", + "3 0.360169 0.444842 0.581529 0.576234 0.382254 0.619699 ... \n", + "4 0.933145 0.472396 1.744471 0.703685 0.443001 0.561593 ... \n", + "\n", + " feature_14 feature_15 feature_16 feature_17 feature_18 feature_19 \\\n", + "0 1.329508 1.028440 0.808805 0.405869 0.359414 1.160210 \n", + "1 0.940571 0.561562 0.856556 0.394640 0.372359 0.688610 \n", + "2 0.666185 0.384251 0.678264 0.393871 0.386834 0.315372 \n", + "3 0.773006 0.473018 0.693563 0.344362 0.350670 0.360295 \n", + "4 1.123658 0.941383 0.678876 0.344669 0.349095 0.933210 \n", + "\n", + " feature_20 feature_21 feature_22 feature_23 \n", + "0 1.398481 1.323586 0.990157 1.028440 \n", + "1 1.058208 1.282028 0.619977 0.561562 \n", + "2 0.632324 0.908260 0.372807 0.384251 \n", + "3 0.724498 0.998886 0.493445 0.473018 \n", + "4 1.038105 1.031051 0.878390 0.941383 \n", + "\n", + "[5 rows x 25 columns]\n", + "\n", + "Merged data saved to 'merged_data.csv'\n", + "\n", + "X shape: (686, 23)\n", + "y shape: (686,)\n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "import os\n", + "\n", + "# oad both datasets\n", + "labeled_data = pd.read_csv(\"labeled_files.csv\")\n", + "features_data = pd.read_csv(\"preprocessed_features.csv\")\n", + "\n", + "\n", + "# step 2: convert relative paths to absolute paths in labeled data\n", + "labeled_data[\"file_path\"] = labeled_data[\"file_path\"].apply(os.path.abspath)\n", + "\n", + "\n", + "\n", + "# step 4: merge the datasets\n", + "merged_data = pd.merge(labeled_data, features_data, on=\"file_path\", how=\"inner\")\n", + "\n", + "# debugging merged data\n", + "print(\"\\nFirst few rows of merged data:\")\n", + "print(merged_data.head())\n", + "\n", + "\n", + "# step 5: save the merged data\n", + "merged_data.to_csv(\"merged_data.csv\", index=False)\n", + "print(\"\\nMerged data saved to 'merged_data.csv'\")\n", + "\n", + "# step 6: extract features and labels\n", + "X = merged_data.iloc[:, 2:25].values # features (23 columns)\n", + "y = merged_data[\"label\"].values # labels\n", + "\n", + "# verify extracted data\n", + "print(f\"\\nX shape: {X.shape}\") # should be (number of merged rows, 23)\n", + "print(f\"y shape: {y.shape}\") # should be (number of merged rows,)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "c087707b-89ce-4b4e-9b56-41a1161782f1", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "4d292d10-ea26-458a-9d2c-9836bcae04a7", + "metadata": {}, + "source": [ + "### Splitting the Data \n", + "#### To evaluate the performance of our model effectively, we must split the dataset into training and testing sets. The training set is used to train the model, while the testing set serves as unseen data to validate the model's ability to generalize." + ] + }, + { + "cell_type": "markdown", + "id": "2c7d5f39-36a4-4bb1-9ab1-9042149e272f", + "metadata": {}, + "source": [ + "
\n", + " Why Split the Data? \n", + " To evaluate the performance of our model effectively, we must split the dataset into training and testing sets. The training set is used to train the model, while the testing set serves as unseen data to validate the model's ability to generalize.\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 106, + "id": "ca7622c4-54cb-4c22-a23d-309e3ed3b532", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "X shape: (686, 23)\n", + "y shape: (686,)\n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "# load merged data\n", + "merged_data = pd.read_csv(\"merged_data.csv\") # replace with actual merged file path\n", + "\n", + "# extract features and labels\n", + "X = merged_data.iloc[:, 2:].values # features (columns 2–24)\n", + "y = merged_data[\"label\"].values # labels\n", + "\n", + "# verify shapes\n", + "print(f\"X shape: {X.shape}\") # should be (686, 23)\n", + "print(f\"y shape: {y.shape}\") # should be (686,)\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "4f555100-4623-40e4-9e86-2f701cfb9cc6", + "metadata": {}, + "source": [ + "
Debugging Tip: Always verify the shapes of your feature matrix (X) and label vector (y) after splitting to ensure the data is correctly prepared for training and evaluation.
" + ] + }, + { + "cell_type": "markdown", + "id": "8bcff7b8-9286-4ada-bb9a-912911b12735", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "79f555a7-1f59-4a21-8b22-8eb71e8c306f", + "metadata": {}, + "source": [ + "### Train-Test Split \n", + "#### Purpose of the Train-Test Split: The train-test split is crucial for evaluating a model's performance. By dividing the data:\n", + "#### Training Set: Used to train the model, enabling it to learn patterns from the EEG data.\n", + "#### Testing Set: Held out to validate the model's ability to generalize to unseen data, ensuring it isn't overfitting." + ] + }, + { + "cell_type": "code", + "execution_count": 107, + "id": "728b3f1c-3bd8-42f8-bcbf-ce1143c8057b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "training set size: 548\n", + "testing set size: 138\n", + "class distribution in training set: {0: 435, 1: 113}\n", + "class distribution in testing set: {0: 110, 1: 28}\n" + ] + } + ], + "source": [ + "# split into training and testing sets\n", + "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)\n", + "\n", + "# print splits for verification\n", + "print(f\"training set size: {len(X_train)}\")\n", + "print(f\"testing set size: {len(X_test)}\")\n", + "print(f\"class distribution in training set: {pd.Series(y_train).value_counts().to_dict()}\")\n", + "print(f\"class distribution in testing set: {pd.Series(y_test).value_counts().to_dict()}\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "a61543b7-3ba9-4e69-b135-9c4f6e4da4c7", + "metadata": {}, + "source": [ + "
Stratification: \n", + "Stratifying the split ensures the same proportion of \"seizure\" (1) and \"non-seizure\" (0) labels in both training and testing sets. This step is essential because imbalanced datasets can lead to biased models.
" + ] + }, + { + "cell_type": "markdown", + "id": "6699bb03-339f-4fbf-8032-2c2a95ae70bc", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "9627d5dd-5adc-4a4e-9a92-cbc27425fe82", + "metadata": {}, + "source": [ + "### Convert Data into PyTorch Tensors \n", + "#### Purpose: To train and evaluate the Spiking Neural Network (SNN) in PyTorch, the data must be converted into tensors. PyTorch tensors are the primary data structure used in PyTorch for computations on GPUs. The DataLoader is used to efficiently manage the data during training and testing [6]." + ] + }, + { + "cell_type": "code", + "execution_count": 124, + "id": "79d5ff37-2278-493f-b409-36f78b79324a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "X_train_tensor shape: torch.Size([548, 23]), y_train_tensor shape: torch.Size([548])\n", + "X_test_tensor shape: torch.Size([138, 23]), y_test_tensor shape: torch.Size([138])\n", + "Number of training batches: 18\n", + "Number of testing batches: 5\n", + "DataLoaders created successfully.\n" + ] + } + ], + "source": [ + "import torch\n", + "from torch.utils.data import DataLoader, TensorDataset\n", + "\n", + "# convert to tensors\n", + "X_train_tensor = torch.tensor(X_train, dtype=torch.float32)\n", + "y_train_tensor = torch.tensor(y_train, dtype=torch.long)\n", + "X_test_tensor = torch.tensor(X_test, dtype=torch.float32)\n", + "y_test_tensor = torch.tensor(y_test, dtype=torch.long)\n", + "\n", + "# create dataLoaders\n", + "train_loader = DataLoader(TensorDataset(X_train_tensor, y_train_tensor), batch_size=32, shuffle=True)\n", + "test_loader = DataLoader(TensorDataset(X_test_tensor, y_test_tensor), batch_size=32, shuffle=False)\n", + "\n", + "# check tensor shapes\n", + "print(f\"X_train_tensor shape: {X_train_tensor.shape}, y_train_tensor shape: {y_train_tensor.shape}\")\n", + "print(f\"X_test_tensor shape: {X_test_tensor.shape}, y_test_tensor shape: {y_test_tensor.shape}\")\n", + "\n", + "# check dataLoader length\n", + "print(f\"Number of training batches: {len(train_loader)}\")\n", + "print(f\"Number of testing batches: {len(test_loader)}\")\n", + "\n", + "print(\"DataLoaders created successfully.\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "51950149-bb95-4ad9-be1a-698263f4cb13", + "metadata": {}, + "source": [ + "
Why Use DataLoaders? DataLoaders simplify data handling during training by automating the batching and shuffling process. They ensure optimal GPU/CPU utilization by loading data efficiently in the background, allowing seamless training even with large datasets.
\n" + ] + }, + { + "cell_type": "markdown", + "id": "92941ca2-1d7a-48d5-9042-c020b3d2f269", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "4fd34dba-831d-43ac-9e4b-03667e7d4653", + "metadata": {}, + "source": [ + "## III. Model Architecture and Training " + ] + }, + { + "cell_type": "markdown", + "id": "c4713bab-b52b-420b-9c2e-befff9feed8f", + "metadata": {}, + "source": [ + "### Nework Design \n", + "#### Input Layer: 23 Neuron corresponding to the 23 preprocessed EEG channels. Responsible for feeding the preprocessed features into the network for analysis. \n", + "#### Hidden Layer: 1000 Leaky Integrate-and-Fire (LIF) Neurons designed to capture temporal dependencies and patterns in EEG data over time. Utilizes spiking dynamics for event-based processing, emulating biological neurons.\n", + "#### Output Layer: 2 Neurons designed for binary classification—detecting whether a seizure event is present (1) or absent (0).\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "7b099529-1890-411f-9541-4812907331b8", + "metadata": {}, + "source": [ + "### Leaky Integrate-and-Fire \n", + "#### The LIF model governs the behavior of the hidden layer, dynamically updating each neuron's membrane potential based on incoming input [5]. A spike is generated whenever the membrane potential exceeds a threshold, after which the potential resets.\n", + "\n", + "$$\n", + "\\tau_m \\frac{dV}{dt} = - (V - V_{rest}) + R_m I(t)\n", + "$$\n", + "where:\n", + "\n", + "- $\\tau_m$ is the membrane time constant, controlling the rate of decay of the membrane potential.\n", + "- $V$ is the membrane potential, representing the current state of the neuron.\n", + "- $V_{rest}$ is the resting membrane potential, the baseline state of the neuron.\n", + "- $R_m$ is the membrane resistance, determining how much the membrane potential changes in response to an input current.\n", + "- $I(t)$ is the input current, derived from the preceding layer or external input.\n" + ] + }, + { + "cell_type": "markdown", + "id": "77ca4035-758f-4a93-b512-489156cd1914", + "metadata": {}, + "source": [ + "#### The model is a SNN designed for binary classification of EEG data. It leverages the bio inspiration of spiking neurons to process temporal dependencies in the data. The architecture consists of input, hidden, and output layers, with (LIF) neurons modeling the hidden and output dynamics." + ] + }, + { + "cell_type": "code", + "execution_count": 168, + "id": "48b15dc9-ecb5-4f21-bcaf-692e62337321", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SNNModel(\n", + " (fc1): Linear(in_features=23, out_features=1000, bias=True)\n", + " (lif1): Leaky()\n", + " (fc2): Linear(in_features=1000, out_features=2, bias=True)\n", + " (lif2): Leaky()\n", + ")\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import snntorch as snn\n", + "\n", + "class SNNModel(nn.Module):\n", + " def __init__(self, input_size=23, hidden_size=1000, output_size=2, beta=0.9, num_steps=25):\n", + " super(SNNModel, self).__init__()\n", + " self.fc1 = nn.Linear(input_size, hidden_size)\n", + " self.lif1 = snn.Leaky(beta=beta)\n", + " self.fc2 = nn.Linear(hidden_size, output_size)\n", + " self.lif2 = snn.Leaky(beta=beta)\n", + " self.num_steps = num_steps\n", + "\n", + " def forward(self, x):\n", + " # initialize states for LIF neurons\n", + " mem1 = self.lif1.init_leaky()\n", + " mem2 = self.lif2.init_leaky()\n", + "\n", + " # record the output of the final layer across time steps\n", + " spk2_rec = []\n", + " mem2_rec = []\n", + "\n", + " for step in range(self.num_steps):\n", + " cur1 = self.fc1(x)\n", + " spk1, mem1 = self.lif1(cur1, mem1)\n", + " cur2 = self.fc2(spk1)\n", + " spk2, mem2 = self.lif2(cur2, mem2)\n", + " spk2_rec.append(spk2)\n", + " mem2_rec.append(mem2)\n", + "\n", + " # return spikes and membrane potentials across time steps\n", + " return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)\n", + "\n", + "# initialize the model\n", + "model = SNNModel()\n", + "print(model)\n" + ] + }, + { + "cell_type": "markdown", + "id": "14cbc10e-fe76-4680-ac02-717f7af1bf85", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "66f953b7-862a-4ccb-a927-eaa497c1635a", + "metadata": {}, + "source": [ + "
\n", + " Why LIF? \n", + " The use of LIF neurons allows the model to leverage temporal dynamics inherent in EEG signals[5].\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "78e5513c-8c16-4da1-a259-fb77d6d5bb3f", + "metadata": {}, + "source": [ + "
\n", + " Why Spiking Neural Networks? \n", + " It efficientaly processes temporal and event based data like EEG signals. Spike based computation mimics brain activity [7]. \n", + "
\n", + "\n", + "
\n", + " Why This Architexture? \n", + " This architecture is optimized for capturing the temporal dependencies inherent in EEG signals. By leveraging the event-based processing capabilities of Spiking Neural Networks (SNNs), the design provides a biologically inspired approach to seizure detection [8]. The LIF neurons in the hidden layer excel at detecting temporal patterns, making the network particularly effective for binary classification tasks like seizure detection.\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "id": "a79d89a2-7cdc-4260-860b-529722ca8144", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "246e64f1-99e4-4dd3-bd39-0b78fcdadbee", + "metadata": {}, + "source": [ + "### Define Training Parameters \n", + "#### Optimizer and Loss Function: To train the Spiking Neural Network (SNN), we use the following:\n", + "#### Optimizer: Adam, which adapts the learning rate for each parameter to optimize convergence [3].\n", + "#### Learning Rate: Set to a small value (0.001) for gradual learning, ensuring the model captures nuanced patterns in EEG data.\n", + "#### Loss Function: CrossEntropyLoss, a standard choice for binary classification, computes the difference between predicted probabilities and true labels [3].\n" + ] + }, + { + "cell_type": "code", + "execution_count": 180, + "id": "d3ed2a81-9fa9-4eb5-9753-2df8bd4d176c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Optimizer initialized with the following parameters:\n", + "Learning Rate: 0.001\n", + "Optimizer and loss function initialized and validated successfully.\n" + ] + } + ], + "source": [ + "from torch.optim import Adam\n", + "\n", + "# optimizer and loss function\n", + "optimizer = Adam(model.parameters(), lr=0.001)\n", + "loss_fn = nn.CrossEntropyLoss()\n", + "\n", + "# debug: check optimizer parameters\n", + "print(\"Optimizer initialized with the following parameters:\")\n", + "for param_group in optimizer.param_groups:\n", + " print(f\"Learning Rate: {param_group['lr']}\")\n", + "\n", + "\n", + "print(\"Optimizer and loss function initialized and validated successfully.\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "19ad5d45-532f-4bf6-bc9c-d375c5e2c903", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "9db95d26-1da8-4af2-b9e4-f5f1be1e5d8b", + "metadata": {}, + "source": [ + "### Training the Model \n", + "#### This section focuses on training the SNN using the training dataset. The training process leverages the cross-entropy loss function, a standard loss function for classification tasks, and gradient-based backpropagation with surrogate gradients. Surrogate gradients enable the training of spiking neurons by approximating gradients for non-differentiable spiking activation functions [3]." + ] + }, + { + "cell_type": "code", + "execution_count": 185, + "id": "5d4e26cd-1abb-4ef7-a01d-d7016bef848b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 1: 100%|████████████████████████████████| 18/18 [00:00<00:00, 73.06it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1, Loss: 10.321447372436523\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 2: 100%|████████████████████████████████| 18/18 [00:00<00:00, 79.01it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2, Loss: 9.645870526631674\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 3: 100%|████████████████████████████████| 18/18 [00:00<00:00, 81.47it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3, Loss: 9.829082594977486\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 4: 100%|████████████████████████████████| 18/18 [00:00<00:00, 83.35it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4, Loss: 10.852201249864367\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 5: 100%|████████████████████████████████| 18/18 [00:00<00:00, 83.50it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5, Loss: 11.142020596398247\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 6: 100%|████████████████████████████████| 18/18 [00:00<00:00, 69.05it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6, Loss: 10.889189614189995\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 7: 100%|████████████████████████████████| 18/18 [00:00<00:00, 81.75it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7, Loss: 11.645541508992514\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 8: 100%|████████████████████████████████| 18/18 [00:00<00:00, 84.28it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8, Loss: 10.80745177798801\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 9: 100%|████████████████████████████████| 18/18 [00:00<00:00, 83.86it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9, Loss: 9.705048031277126\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 10: 100%|███████████████████████████████| 18/18 [00:00<00:00, 81.38it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10, Loss: 9.561542722913954\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 11: 100%|███████████████████████████████| 18/18 [00:00<00:00, 81.78it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11, Loss: 8.944989654752943\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 12: 100%|███████████████████████████████| 18/18 [00:00<00:00, 84.16it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12, Loss: 9.114972670873007\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 13: 100%|███████████████████████████████| 18/18 [00:00<00:00, 83.45it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13, Loss: 8.85685912768046\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 14: 100%|███████████████████████████████| 18/18 [00:00<00:00, 82.87it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14, Loss: 9.71735077434116\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 15: 100%|███████████████████████████████| 18/18 [00:00<00:00, 84.23it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15, Loss: 9.51416187816196\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Testing: 100%|█████████████████████████████████| 5/5 [00:00<00:00, 141.23it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test Loss: 14.3072, Test Accuracy: 80.43\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from torch.optim import Adam\n", + "from tqdm import tqdm\n", + "import snntorch as snn\n", + "from sklearn.metrics import roc_auc_score, confusion_matrix, classification_report\n", + "from torch.nn.functional import softmax\n", + "\n", + "# parameters\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "num_steps = 25 # Number of simulation steps\n", + "\n", + "# assuming model, train_loader, test_loader are already defined\n", + "loss_fn = nn.CrossEntropyLoss()\n", + "optimizer = Adam(model.parameters(), lr=0.001)\n", + "\n", + "# training Loop\n", + "num_epochs = 15\n", + "for epoch in range(num_epochs):\n", + " model.train()\n", + " total_loss = 0\n", + " for batch_X, batch_y in tqdm(train_loader, desc=f\"Epoch {epoch + 1}\"):\n", + " batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + "\n", + " # forward pass\n", + " spk_rec, mem_rec = model(batch_X)\n", + "\n", + " # loss calculation\n", + " loss_val = torch.zeros((1), device=device)\n", + " for step in range(num_steps):\n", + " loss_val += loss_fn(mem_rec[step], batch_y)\n", + "\n", + " # backward pass\n", + " loss_val.backward()\n", + " optimizer.step()\n", + "\n", + " total_loss += loss_val.item()\n", + "\n", + " print(f\"Epoch {epoch + 1}, Loss: {total_loss / len(train_loader)}\")\n", + "\n", + "# testing Loop with AUC-ROC\n", + "model.eval()\n", + "test_loss = 0\n", + "correct = 0\n", + "total = 0\n", + "true_labels = []\n", + "predicted_probs = []\n", + "\n", + "with torch.no_grad():\n", + " for batch_X, batch_y in tqdm(test_loader, desc=\"Testing\"):\n", + " batch_X, batch_y = batch_X.to(device), batch_y.to(device)\n", + "\n", + " spk_rec, mem_rec = model(batch_X)\n", + "\n", + " # Summed loss over time\n", + " loss_val = torch.zeros((1), device=device)\n", + " for step in range(num_steps):\n", + " loss_val += loss_fn(mem_rec[step], batch_y)\n", + "\n", + " test_loss += loss_val.item()\n", + "\n", + " # Collect probabilities for AUC-ROC\n", + " probabilities = softmax(spk_rec.sum(dim=0), dim=1) # Sum over time and apply softmax\n", + " predicted_probs.extend(probabilities[:, 1].cpu().numpy()) # Probability for positive class\n", + " true_labels.extend(batch_y.cpu().numpy())\n", + "\n", + " # Accuracy Calculation\n", + " _, predicted = spk_rec.sum(dim=0).max(1)\n", + " correct += (predicted == batch_y).sum().item()\n", + " total += batch_y.size(0)\n", + "\n", + "print(f\"Test Loss: {test_loss / len(test_loader):.4f}, Test Accuracy: {test_accuracy:.2f}\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "57002045-8c5f-45ae-8c76-54d66533b547", + "metadata": {}, + "source": [ + "
\n", + " Why This Approach? \n", + " The gradient-based backpropagation with surrogate gradients allows the model to learn temporal patterns in EEG data. The spiking neuron architecture, with time-stepped simulation, ensures temporal dependencies are captured effectively.\n", + "\n", + " Temporal Learning: The spiking dynamics allow the network to analyze EEG signals in a biologically plausible way.
\n", + " Loss Summation: Accumulating the loss over num_steps ensures that the model learns from patterns across the entire time window.\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "id": "fb201057-47ae-4a37-9683-5ad62c402749", + "metadata": {}, + "source": [ + "## VI. Results and Evaluation \n", + "### AUC-ROC and Confusion Matrix \n", + "#### AUC-ROC: evaluates the model's ability to distinguish between classes across all decision thresholds. It is particularly valuable for seizure detection as EEG signals are inherently noisy and seizures may activate multiple regions simultaneously. AUC-ROC provides a balanced view of sensitivity (true positive rate) and specificity (true negative rate) over varying thresholds [9].\n", + "\n", + "#### Confusion Matrix: The confusion matrix visually represents:\n", + "#### True Positives (TP): Correctly detected seizures.\n", + "#### True Negatives (TN): Correctly identified non-seizure instances.\n", + "#### False Positives (FP): Non-seizures incorrectly classified as seizures.\n", + "#### False Negatives (FN): Missed seizure detections.\n" + ] + }, + { + "cell_type": "markdown", + "id": "c89ff5af-f14c-43c6-84ca-0d466782c4be", + "metadata": {}, + "source": [ + "
Why AUC-ROC? AUC-ROC (Area Under the Receiver Operating Characteristic Curve) is a robust metric for assessing the model’s performance across all thresholds, mitigating biases introduced by noisy EEG signals. It offers insights into trade-offs between sensitivity and specificity, particularly crucial in medical applications like seizure detection [9].
\n" + ] + }, + { + "cell_type": "code", + "execution_count": 189, + "id": "1e74fed4-41b5-429f-96e0-391a07b70190", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " AUC-ROC: 0.67\n", + "\n", + "confusion matrix:\n", + "[[109 1]\n", + " [ 24 4]]\n", + "\n", + "classification report:\n", + " precision recall f1-score support\n", + "\n", + " non-Seizure 0.82 0.99 0.90 110\n", + " seizure 0.80 0.14 0.24 28\n", + "\n", + " accuracy 0.82 138\n", + " macro avg 0.81 0.57 0.57 138\n", + "weighted avg 0.82 0.82 0.76 138\n", + "\n" + ] + } + ], + "source": [ + "# compute AUC-ROC\n", + "auc_score = roc_auc_score(true_labels, predicted_probs)\n", + "test_accuracy = 100.0 * correct / total\n", + "\n", + "# confusion Matrix\n", + "predicted_labels = [1 if prob > 0.5 else 0 for prob in predicted_probs]\n", + "conf_matrix = confusion_matrix(true_labels, predicted_labels)\n", + "report = classification_report(true_labels, predicted_labels, target_names=[\"non-Seizure\", \"seizure\"])\n", + "\n", + "print(f\" AUC-ROC: {auc_score:.2f}\")\n", + "print(\"\\nconfusion matrix:\")\n", + "print(conf_matrix)\n", + "print(\"\\nclassification report:\")\n", + "print(report)" + ] + }, + { + "cell_type": "markdown", + "id": "007f17de-a7ac-48eb-9217-d8c1a2ab63f4", + "metadata": {}, + "source": [ + "### Analysis \n", + "#### AUC-ROC: The value of 0.67 suggests moderate model performance, with room for improvement in distinguishing between seizure and non-seizure classes.\n", + "#### Confusion Matrix: The model effectively classifies non-seizures but struggles with sensitivity for seizures, as reflected in the high false negative rate (24 FN).\n", + "\n", + "#### Classification Report: Precision for seizures is relatively high (.80), but recall (0.14) indicates the model is missing a significant number of seizures. Balancing sensitivity and specificity will be critical in subsequent improvements.\n" + ] + }, + { + "cell_type": "markdown", + "id": "03921d78-be40-414a-9c75-cea4a48907e5", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "6a22998b-b824-4859-878c-39b3d14bf69d", + "metadata": {}, + "source": [ + "## V. Conclusion and Future Work \n", + "### Performance of the SNN model: \n", + "#### The Spiking Neural Network (SNN) model achieved a test accuracy of 80.43%, with an AUC-ROC score of .66, demonstrating its potential for automated seizure detection. While the model effectively identified non-seizure instances, challenges in sensitivity highlight areas for improvement in seizure detection.\n", + "### Challenges Encountered \n", + "\n", + "#### Noise in EEG Data: Background noise and artifacts in the EEG recordings introduced additional complexity, necessitating robust preprocessing techniques.\n", + "\n", + "#### Class Imbalance: The dataset's significant imbalance between seizure and non-seizure instances posed challenges for achieving high sensitivity.\n", + "\n", + "#### Biological Complexity: Seizure patterns are non-linear and often asynchronous, which can be challenging for binary classification.\n" + ] + }, + { + "cell_type": "markdown", + "id": "2c95843e-5a69-4dba-871e-e796b541322e", + "metadata": {}, + "source": [ + "## VI. References \n", + "#### [1] Guttag, John. \"CHB-MIT Scalp EEG Database\" (version 1.0.0). PhysioNet (2010), https://doi.org/10.13026/C2K01R. \n", + "\n", + "#### [2] Shoeibi A, Khodatars M, Ghassemi N, Jafari M, Moridian P, Alizadehsani R, Panahiazar M, Khozeimeh F, Zare A, Hosseini-Nejad H, Khosravi A, Atiya AF, Aminshahidi D, Hussain S, Rouhani M, Nahavandi S, Acharya UR. Epileptic Seizures Detection Using Deep Learning Techniques: A Review. Int J Environ Res Public Health. 2021 May 27;18(11):5780. doi: 10.3390/ijerph18115780. PMID: 34072232; PMCID: PMC8199071.\n", + "#### [3] Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. “Training Spiking Neural Networks Using Lessons From Deep Learning”. arXiv preprint arXiv:2109.12894, September 2021.\n", + "#### [4] Wang, Z., Mengoni, P. Seizure classification with selected frequency bands and EEG montages: a Natural Language Processing approach. Brain Inf. 9, 11 (2022). https://doi.org/10.1186/s40708-022-00159-3\n", + "#### [5] Lu, Sijia, and Feng Xu. \"Linear Leaky-Integrate-and-Fire Neuron Model-Based Spiking Neural Networks and Its Mapping Relationship to Deep Neural Networks.\" Frontiers in Neuroscience, vol. 16, 2022, https://doi.org/10.3389/fnins.2022.857513. \n", + "#### [6] PyTorch Documentation. “Tensors.” PyTorch. https://pytorch.org/docs/stable/tensors.html\n", + "#### [7] Yamazaki K, Vo-Ho VK, Bulsara D, Le N. Spiking Neural Networks and Their Applications: A Review. Brain Sci. 2022 Jun 30;12(7):863. doi: 10.3390/brainsci12070863. PMID: 35884670; PMCID: PMC9313413.\n", + "#### [8] Hussein R, Palangi H, Ward RK, Wang ZJ. Optimized deep neural network architecture for robust detection of epileptic seizures using EEG signals. Clin Neurophysiol. 2019 Jan;130(1):25-37. doi: 10.1016/j.clinph.2018.10.010. Epub 2018 Nov 15. PMID: 30472579.\n", + "#### [9] Dastgoshadeh M, Rabiei Z. Detection of epileptic seizures through EEG signals using entropy features and ensemble learning. Front Hum Neurosci. 2023 Feb 1;16:1084061. doi: 10.3389/fnhum.2022.1084061. PMID: 36875740; PMCID: PMC9976189." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/snntorch/.DS_Store b/snntorch/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..88e31ebf29f51e1e77ebf6c3b31c3158a2d459b8 GIT binary patch literal 6148 zcmeHK%}xR_82pMPZa@r2FDCm0B)-8)@Zi}8P(UQal_mT|ZhQ9~d2K~YM7ko9pb(oaR5iitER|^Zo(|C1d(31lxA^pAQDJL` zI-m~xH3#_aw(UDRlevWN@9sJ;#-qF#fx+eB5~cTH11*tbsB#&~PA0OHRm6Ua*a_Hk zFgM^9XhnQ;9`QM*t-CWuMFrQCXiVm%q8edDV|fIZc^0**_ej-2o@a0iXt;jShh?3z zF^!ZhdBBY3$hh{S97-NJLEl~j-#TDEeJ7jH!%*3l6?8MqZX#2ZdQa$Ns5%K$3p}v` zV?E?$s0*`gk#*%}xIa5XFGCfx@fp+oDW1g|1zJ||VpE(U&L|Tm> zI)FW!k+dCZtq!OI>cE!+Tpv;tV&<`QXqOID_6PuMqT2@A{Ift!>M`?JI>Z%(aZ;d@ z8h^wvPLBT6^D>X6LnjyG4Tbzt9tBloq-{r~Fl`F~%e z-_!wh;9oglnuD7`A5-#sYh`lWYg3dv6b5;vLn}e$w`1GjR(ytH1J_bEfSJeAA$kz{ ON5I;kl{)aN4txUA?#Olk literal 0 HcmV?d00001