Review preprocessing. #135

Open · wants to merge 3 commits into base: reviewed_code
70 changes: 70 additions & 0 deletions spikewrap/configs/configs.py
@@ -0,0 +1,70 @@
import glob
import os
from pathlib import Path
from typing import Dict, Tuple

import yaml

from ..utils import utils


def get_configs(name: str) -> Tuple[Dict, Dict, Dict]:
"""
Load the config yaml file in the same folder
(spikewrap/configs) containing preprocessing (pp),
sorter, and waveform options.

Once loaded, the list containing the preprocessor name
and kwargs is cast to a tuple. This keeps the type
checker happy while not requiring a tuple
in the .yaml, which would require ugly tags.

Parameters
----------

name : name of the config to load. Should not include the
.yaml suffix. If not the name of a known config, it is
assumed to be a full path to a .yaml config file.

Returns
-------

pp_steps : a dictionary containing the preprocessing
step order (keys) and a [pp_name, kwargs]
list containing the spikeinterface preprocessing
step and keyword options.

sorter_options : a dictionary with sorter name (key) and
a dictionary of kwargs to pass to the
spikeinterface sorter class.

waveform_options : a dictionary of kwargs passed to
spikeinterface waveform extraction.
"""
config_dir = Path(os.path.dirname(os.path.realpath(__file__)))

available_files = glob.glob((config_dir / "*.yaml").as_posix())
available_files = [Path(path_).stem for path_ in available_files]

if name not in available_files: # then assume it is a full path
[Member Author] Change this, see comment below on setting a config directory.

assert Path(name).is_file(), (
f"{name} is neither the name of an existing "
f"config nor a valid path to a configuration file."
)

assert Path(name).suffix in [
".yaml",
".yml",
], f"{name} is not the path to a .yaml file"

config_filepath = Path(name)

else:
config_filepath = config_dir / f"{name}.yaml"

with open(config_filepath) as file:
config = yaml.full_load(file)

pp_steps = config["preprocessing"] if "preprocessing" in config else {}
sorter_options = config["sorting"] if "sorting" in config else {}
waveform_options = config["waveforms"] if "waveforms" in config else {}

utils.cast_pp_steps_values(pp_steps, "tuple")

return pp_steps, sorter_options, waveform_options
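
As a quick usage sketch (hypothetical snippet, based only on the function above and the `default.yaml` shown below):

# Load the bundled "default" config by name; pp_steps values are
# (pp_name, kwargs) tuples after the cast_pp_steps_values call.
from spikewrap.configs.configs import get_configs

pp_steps, sorter_options, waveform_options = get_configs("default")
# e.g. pp_steps["2"] == ("bandpass_filter", {"freq_min": 300, "freq_max": 6000})

# A full path to a custom yaml file also works:
# pp_steps, sorter_options, waveform_options = get_configs("/path/to/my_config.yaml")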
37 changes: 37 additions & 0 deletions spikewrap/configs/default.yaml
@@ -0,0 +1,37 @@
'preprocessing':
'1':
- phase_shift
- {}
'2':
- bandpass_filter
- freq_min: 300
freq_max: 6000
'3':
- common_reference
- operator: median
reference: global

'sorting':
'kilosort2':
'car': False # common average referencing
'freq_min': 150 # highpass filter cutoff; neither False nor 0 works to turn it off (results in a KS error)
'kilosort2_5':
'car': False
'freq_min': 150
'kilosort3':
'car': False
'freq_min': 300
'mountainsort5':
'scheme': '2'
'filter': False

'waveforms':
'ms_before': 2
'ms_after': 2
'max_spikes_per_unit': 500
'return_scaled': True
# Sparsity Options
'sparse': True
'peak_sign': "neg"
'method': "radius"
'radius_um': 75
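
For intuition, each `pp_steps` entry maps by name onto a `spikeinterface.preprocessing` function. A minimal sketch of how such a config could be applied (illustrative only, not the actual apply loop in `pipeline.preprocess`):

import spikeinterface.preprocessing as spre

# pp_steps as parsed from the yaml above, after casting values to tuples
pp_steps = {
    "1": ("phase_shift", {}),
    "2": ("bandpass_filter", {"freq_min": 300, "freq_max": 6000}),
    "3": ("common_reference", {"operator": "median", "reference": "global"}),
}

def apply_pp_steps(recording, pp_steps):
    # Apply steps in numeric order, chaining SpikeInterface's lazy
    # preprocessing objects onto the recording.
    for step_num in sorted(pp_steps, key=int):
        pp_name, kwargs = pp_steps[step_num]
        recording = getattr(spre, pp_name)(recording, **kwargs)
    return recording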
39 changes: 39 additions & 0 deletions spikewrap/data_classes/base.py
@@ -5,6 +5,8 @@
from pathlib import Path
from typing import Callable, Dict, List, Literal, Tuple

from ..utils import utils


@dataclass
class BaseUserDict(UserDict):
@@ -140,6 +142,43 @@ def get_rawdata_ses_path(self, ses_name: str) -> Path:
def get_rawdata_run_path(self, ses_name: str, run_name: str) -> Path:
return self.get_rawdata_ses_path(ses_name) / "ephys" / run_name

# Derivatives Paths --------------------------------------------------------------

def get_derivatives_top_level_path(self) -> Path:
return self.base_path / "derivatives" / "spikewrap"

def get_derivatives_sub_path(self) -> Path:
return self.get_derivatives_top_level_path() / self.sub_name

def get_derivatives_ses_path(self, ses_name: str) -> Path:
return self.get_derivatives_sub_path() / ses_name

def get_derivatives_run_path(self, ses_name: str, run_name: str) -> Path:
return self.get_derivatives_ses_path(ses_name) / run_name

# Preprocessing Paths --------------------------------------------------------------

def get_preprocessing_path(self, ses_name: str, run_name: str) -> Path:
"""
Set the folder tree where preprocessing output will be
saved. This is canonical and should not change.
"""
preprocessed_output_path = (
self.get_derivatives_run_path(ses_name, run_name) / "preprocessing"
)
return preprocessed_output_path

def _get_pp_binary_data_path(self, ses_name: str, run_name: str) -> Path:
return self.get_preprocessing_path(ses_name, run_name) / "si_recording"

def _get_sync_channel_data_path(self, ses_name: str, run_name: str) -> Path:
return self.get_preprocessing_path(ses_name, run_name) / "sync_channel"

def get_preprocessing_info_path(self, ses_name: str, run_name: str) -> Path:
return self.get_preprocessing_path(ses_name, run_name) / utils.canonical_names(
"preprocessed_yaml"
)

@staticmethod
def update_two_layer_dict(dict_, ses_name, run_name, value):
"""
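Putting the derivatives helpers above together, the output tree they define is (folder names taken from this diff; `<preprocessed_yaml>` stands for whatever `utils.canonical_names("preprocessed_yaml")` returns):

<base_path>/derivatives/spikewrap/
└── <sub_name>/
    └── <ses_name>/
        └── <run_name>/
            └── preprocessing/
                ├── si_recording/    # preprocessed binary recording
                ├── sync_channel/
                └── <preprocessed_yaml>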
110 changes: 109 additions & 1 deletion spikewrap/data_classes/preprocessing.py
@@ -1,6 +1,11 @@
import shutil
from dataclasses import dataclass
from typing import Dict

import spikeinterface

from spikewrap.utils import utils

from .base import BaseUserDict


@@ -12,7 +17,7 @@ class PreprocessingData(BaseUserDict):
Details on the preprocessing steps are held in the dictionary keys,
e.g. 0-raw, 1-raw-bandpass_filter, 2-raw_bandpass_filter-common_average,
and recording objects are held in the value. These are generated
by the `pipeline.preprocess.run_preprocessing()` function.
by the `pipeline.preprocess._preprocess_and_save_all_runs()` function.

The class manages paths to raw data and preprocessing output,
and defines methods to dump key information and the SpikeInterface
@@ -44,6 +49,19 @@ def __post_init__(self) -> None:
self.update_two_layer_dict(self, ses_name, run_name, {"0-raw": None})
self.update_two_layer_dict(self.sync, ses_name, run_name, None)

def set_pp_steps(self, pp_steps: Dict) -> None:
[Member Author] Maybe a docstring is not even required for this function.

"""
Set the preprocessing steps (`pp_steps`) attribute
that defines the preprocessing steps and options.

Parameters
----------
pp_steps : Dict
Preprocessing steps to set as a class attribute. These are used
when `pipeline.preprocess._fill_run_data_with_preprocessed_recording()` is called.
"""
self.pp_steps = pp_steps

def _validate_rawdata_inputs(self) -> None:
self._validate_inputs(
"rawdata",
@@ -52,3 +70,93 @@ def _validate_rawdata_inputs(self) -> None:
self.get_rawdata_ses_path,
self.get_rawdata_run_path,
)

# Saving preprocessed data ---------------------------------------------------------

def save_preprocessed_data(
self, ses_name: str, run_name: str, overwrite: bool = False
) -> None:
"""
Save the preprocessed output data to binary, as well
as important class attributes to a .yaml file.

Both are saved in a folder called 'preprocessing'
in derivatives/spikewrap/<sub_name>/<ses_name>/<run_name>.

Parameters
----------
ses_name : str
Session name to which `run_name` belongs.

run_name : str
Run name corresponding to one of `self.preprocessing_run_names`.

overwrite : bool
If `True`, existing preprocessed output will be overwritten.
By default, SpikeInterface will error if a binary recording file
(`si_recording`) already exists where it is trying to write one.
In this case, an error should be raised before this function
is called.

"""
if overwrite:
if self.get_preprocessing_path(ses_name, run_name).is_dir():
shutil.rmtree(self.get_preprocessing_path(ses_name, run_name))

self._save_preprocessed_binary(ses_name, run_name)
self._save_sync_channel(ses_name, run_name)
self._save_preprocessing_info(ses_name, run_name)

def _save_preprocessed_binary(self, ses_name: str, run_name: str) -> None:
"""
Save the fully preprocessed data (i.e. last step in the
preprocessing chain) to binary file. This is required for sorting.
"""
recording, __ = utils.get_dict_value_from_step_num(
self[ses_name][run_name], "last"
)
recording.save(
folder=self._get_pp_binary_data_path(ses_name, run_name), chunk_memory="10M"
)

def _save_sync_channel(self, ses_name: str, run_name: str) -> None:
"""
Save the sync channel separately. In SI, sorting cannot proceed
if the sync channel is loaded to ensure it does not interfere with
sorting. As such, the sync channel is handled separately here.
"""
utils.message_user(f"Saving sync channel for {ses_name} run {run_name}")

assert (
self.sync[ses_name][run_name] is not None
), f"Sync channel on PreprocessData session {ses_name} run {run_name} is None"

self.sync[ses_name][run_name].save( # type: ignore
folder=self._get_sync_channel_data_path(ses_name, run_name),
chunk_memory="10M",
)

def _save_preprocessing_info(self, ses_name: str, run_name: str) -> None:
"""
Save important details of the preprocessing for provenance.

Importantly, this is used to check that the preprocessing
file used for waveform extraction in `PostprocessData`
matches the preprocessing that was used for sorting.
"""
assert self.pp_steps is not None, "type narrow `pp_steps`."

utils.cast_pp_steps_values(self.pp_steps, "list")

preprocessing_info = {
"base_path": self.base_path.as_posix(),
"sub_name": self.sub_name,
"ses_name": ses_name,
"run_name": run_name,
"rawdata_path": self.get_rawdata_run_path(ses_name, run_name).as_posix(),
"pp_steps": self.pp_steps,
"spikeinterface_version": spikeinterface.__version__,
"spikewrap_version": utils.spikewrap_version(),
"datetime_written": utils.get_formatted_datetime(),
}

utils.dump_dict_to_yaml(
self.get_preprocessing_info_path(ses_name, run_name), preprocessing_info
)
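
For reference, the provenance yaml written by `_save_preprocessing_info` would look roughly like this (all values illustrative; note that `pp_steps` values are cast back to lists before dumping):

base_path: /data/my_project
sub_name: sub-001
ses_name: ses-001
run_name: run-001_g0
rawdata_path: /data/my_project/rawdata/sub-001/ses-001/ephys/run-001_g0
pp_steps:
  '1': [phase_shift, {}]
  '2': [bandpass_filter, {freq_min: 300, freq_max: 6000}]
  '3': [common_reference, {operator: median, reference: global}]
spikeinterface_version: 0.98.2
spikewrap_version: 1.0.0
datetime_written: 2023-09-01 12:00:00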
32 changes: 32 additions & 0 deletions spikewrap/examples/example_preprocess.py
@@ -0,0 +1,32 @@
from pathlib import Path

from spikewrap.pipeline.load_data import load_data
from spikewrap.pipeline.preprocess import run_preprocessing

base_path = Path(
r"/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/test_data/steve_multi_run/1119617/time-short-multises"
# r"C:\data\ephys\test_data\steve_multi_run\1119617\time-miniscule-multises"
)

sub_name = "sub-1119617"
sessions_and_runs = {
"ses-001": [
"run-001_1119617_LSE1_shank12_g0",
"run-002_made_up_g0",
],
"ses-002": [
"run-001_1119617_pretest1_shank12_g0",
],
"ses-003": [
"run-002_1119617_pretest1_shank12_g0",
],
}

loaded_data = load_data(base_path, sub_name, sessions_and_runs, data_format="spikeglx")

run_preprocessing(
loaded_data,
pp_steps="default",
handle_existing_data="overwrite",
log=True,
)
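
Running this script should leave the preprocessed binary, sync channel, and provenance yaml for each run under `base_path/derivatives/spikewrap/sub-1119617/<ses_name>/<run_name>/preprocessing/`, per the path helpers added in `data_classes/base.py`.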