Add preprocessing pipeline and associated configs.
JoeZiminski committed Nov 17, 2023
1 parent 2b8b0c6 commit b4c21d1
Showing 7 changed files with 674 additions and 1 deletion.
70 changes: 70 additions & 0 deletions spikewrap/configs/configs.py
@@ -0,0 +1,70 @@
import glob
import os
from pathlib import Path
from typing import Dict, Tuple

import yaml

from ..utils import utils


def get_configs(name: str) -> Tuple[Dict, Dict, Dict]:
"""
Loads the config yaml file in the same folder
(spikewrap/configs) containing preprocessing (pp)
and sorter options.
Once loaded, the list containing preprocesser name
and kwargs is cast to tuple. This keeps the type
checker happy while not requiring a tuple
in the .yaml which require ugly tags.
Parameters
----------
name: name of the configs to load. Should not include the
.yaml suffix.
Returns
-------
pp_steps : a dictionary containing the preprocessing
step order (keys) and a [pp_name, kwargs]
list containing the spikeinterface preprocessing
step and keyword options.
sorter_options : a dictionary with sorter name (key) and
a dictionary of kwargs to pass to the
spikeinterface sorter class.
"""
config_dir = Path(os.path.dirname(os.path.realpath(__file__)))

available_files = glob.glob((config_dir / "*.yaml").as_posix())
available_files = [Path(path_).stem for path_ in available_files]

if name not in available_files: # then assume it is a full path
assert Path(name).is_file(), (
f"{name} is neither the name of an existing "
f"config or valid path to configuration file."
)

assert Path(name).suffix in [
".yaml",
".yml",
], f"{name} is not the path to a .yaml file"

config_filepath = Path(name)

else:
config_filepath = config_dir / f"{name}.yaml"

with open(config_filepath) as file:
config = yaml.full_load(file)

pp_steps = config["preprocessing"] if "preprocessing" in config else {}
sorter_options = config["sorting"] if "sorting" in config else {}
waveform_options = config["waveforms"] if "waveforms" in config else {}

utils.cast_pp_steps_values(pp_steps, "tuple")

return pp_steps, sorter_options, waveform_options
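
As a usage illustration, a minimal sketch of calling `get_configs` (the unpacking order follows the return signature above; the custom path is hypothetical):

pp_steps, sorter_options, waveform_options = get_configs("default")

# A full path to a custom config is also accepted (hypothetical path):
pp_steps, sorter_options, waveform_options = get_configs("/path/to/my_configs.yaml")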
37 changes: 37 additions & 0 deletions spikewrap/configs/default.yaml
@@ -0,0 +1,37 @@
'preprocessing':
'1':
- phase_shift
- {}
'2':
- bandpass_filter
- freq_min: 300
freq_max: 6000
'3':
- common_reference
- operator: median
reference: global

'sorting':
'kilosort2':
'car': False # common average referencing
'freq_min': 150 # highpass filter cutoff; neither False nor 0 turns this off (results in a Kilosort error)
'kilosort2_5':
'car': False
'freq_min': 150
'kilosort3':
'car': False
'freq_min': 300
'mountainsort5':
'scheme': '2'
'filter': False

'waveforms':
'ms_before': 2
'ms_after': 2
'max_spikes_per_unit': 500
'return_scaled': True
# Sparsity Options
'sparse': True
'peak_sign': "neg"
'method': "radius"
'radius_um': 75
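
For reference, a hedged sketch of what the 'preprocessing' section above yields once loaded by `get_configs` and cast to tuples via `utils.cast_pp_steps_values` (structure inferred from the config and the casting step):

pp_steps = {
    "1": ("phase_shift", {}),
    "2": ("bandpass_filter", {"freq_min": 300, "freq_max": 6000}),
    "3": ("common_reference", {"operator": "median", "reference": "global"}),
}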
110 changes: 109 additions & 1 deletion spikewrap/data_classes/preprocessing.py
@@ -1,6 +1,11 @@
import shutil
from dataclasses import dataclass
from typing import Dict

import spikeinterface

from spikewrap.utils import utils

from .base import BaseUserDict


@@ -12,7 +17,7 @@ class PreprocessingData(BaseUserDict):
Details on the preprocessing steps are held in the dictionary keys,
e.g. 0-raw, 1-raw-bandpass_filter, 2-raw_bandpass_filter-common_average,
and recording objects are held in the value. These are generated
by the `pipeline.preprocess.run_preprocessing()` function.
by the `pipeline.preprocess._preprocess_and_save_all_runs()` function.
The class manages paths to raw data and preprocessing output,
and defines methods to dump key information and the SpikeInterface
@@ -44,6 +49,19 @@ def __post_init__(self) -> None:
self.update_two_layer_dict(self, ses_name, run_name, {"0-raw": None})
self.update_two_layer_dict(self.sync, ses_name, run_name, None)

def set_pp_steps(self, pp_steps: Dict) -> None:
"""
Set the preprocessing steps (`pp_steps`) attribute
that defines the preprocessing steps and options.
Parameters
----------
pp_steps : Dict
Preprocessing steps to setp as class attribute. These are used
when `pipeline.preprocess._fill_run_data_with_preprocessed_recording()` function is called.
"""
self.pp_steps = pp_steps

def _validate_rawdata_inputs(self) -> None:
self._validate_inputs(
"rawdata",
@@ -52,3 +70,93 @@ def _validate_rawdata_inputs(self) -> None:
self.get_rawdata_ses_path,
self.get_rawdata_run_path,
)

# Saving preprocessed data ---------------------------------------------------------

def save_preprocessed_data(
self, ses_name: str, run_name: str, overwrite: bool = False
) -> None:
"""
Save the preprocessed output data to binary, as well
as important class attributes to a .yaml file.
Both are saved in a folder called 'preprocessed'
in derivatives/<sub_name>/<pp_run_name>
Parameters
----------
run_name : str
Run name corresponding to one of `self.preprocessing_run_names`.
overwrite : bool
If `True`, existing preprocessed output will be overwritten.
By default, SpikeInterface will error if a binary recording file
(`si_recording`) already exists where it is trying to write one.
In this case, an error should be raised before this function
is called.
"""
if overwrite:
if self.get_preprocessing_path(ses_name, run_name).is_dir():
shutil.rmtree(self.get_preprocessing_path(ses_name, run_name))

self._save_preprocessed_binary(ses_name, run_name)
self._save_sync_channel(ses_name, run_name)
self._save_preprocessing_info(ses_name, run_name)

def _save_preprocessed_binary(self, ses_name: str, run_name: str) -> None:
"""
Save the fully preprocessed data (i.e. last step in the
preprocessing chain) to binary file. This is required for sorting.
"""
recording, __ = utils.get_dict_value_from_step_num(
self[ses_name][run_name], "last"
)
recording.save(
folder=self._get_pp_binary_data_path(ses_name, run_name), chunk_memory="10M"
)

def _save_sync_channel(self, ses_name: str, run_name: str) -> None:
"""
Save the sync channel separately. In SpikeInterface, sorting
cannot proceed if the sync channel is loaded; this ensures it does
not interfere with sorting. As such, the sync channel is handled
separately here.
"""
utils.message_user(f"Saving sync channel for {ses_name} run {run_name}")

assert (
self.sync[ses_name][run_name] is not None
), f"Sync channel on PreprocessData session {ses_name} run {run_name} is None"

self.sync[ses_name][run_name].save( # type: ignore
folder=self._get_sync_channel_data_path(ses_name, run_name),
chunk_memory="10M",
)

def _save_preprocessing_info(self, ses_name: str, run_name: str) -> None:
"""
Save important details of the preprocessing for provenance.
Importantly, this is used to check that the preprocessing
file used for waveform extraction in `PostprocessData`
matches the preprocessing that was used for sorting.
"""
assert self.pp_steps is not None, "type narrow `pp_steps`."

utils.cast_pp_steps_values(self.pp_steps, "list")

preprocessing_info = {
"base_path": self.base_path.as_posix(),
"sub_name": self.sub_name,
"ses_name": ses_name,
"run_name": run_name,
"rawdata_path": self.get_rawdata_run_path(ses_name, run_name).as_posix(),
"pp_steps": self.pp_steps,
"spikeinterface_version": spikeinterface.__version__,
"spikewrap_version": utils.spikewrap_version(),
"datetime_written": utils.get_formatted_datetime(),
}

utils.dump_dict_to_yaml(
self.get_preprocessing_info_path(ses_name, run_name), preprocessing_info
)
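
As a usage illustration, a minimal sketch of the save flow, assuming `loaded_data` is the `PreprocessingData` instance returned by `pipeline.load_data.load_data` and preprocessing has already been run (session and run names are hypothetical):

# Removes any existing output for this run (overwrite=True), then writes the
# preprocessed binary, the sync channel, and the provenance .yaml in one call.
loaded_data.save_preprocessed_data("ses-001", "run-001", overwrite=True)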
File renamed without changes.
32 changes: 32 additions & 0 deletions spikewrap/examples/example_preprocess.py
@@ -0,0 +1,32 @@
from pathlib import Path

from spikewrap.pipeline.load_data import load_data
from spikewrap.pipeline.preprocess import run_preprocessing

base_path = Path(
r"/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/test_data/steve_multi_run/1119617/time-short-multises"
# r"C:\data\ephys\test_data\steve_multi_run\1119617\time-miniscule-multises"
)

sub_name = "sub-1119617"
sessions_and_runs = {
"ses-001": [
"run-001_1119617_LSE1_shank12_g0",
"run-002_made_up_g0",
],
"ses-002": [
"run-001_1119617_pretest1_shank12_g0",
],
"ses-003": [
"run-002_1119617_pretest1_shank12_g0",
],
}

loaded_data = load_data(base_path, sub_name, sessions_and_runs, data_format="spikeglx")

run_preprocessing(
loaded_data,
pp_steps="default",
handle_existing_data="overwrite",
log=True,
)
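
Assuming `run_preprocessing` resolves `pp_steps` through `configs.get_configs` (an assumption based on the config loader above), a custom config file could likewise be passed by path:

run_preprocessing(
    loaded_data,
    pp_steps="/path/to/my_custom_configs.yaml",  # hypothetical custom config
    handle_existing_data="overwrite",
    log=True,
)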