diff --git a/spikewrap/configs/configs.py b/spikewrap/configs/configs.py
new file mode 100644
index 0000000..3b0d82b
--- /dev/null
+++ b/spikewrap/configs/configs.py
@@ -0,0 +1,70 @@
+import glob
+import os
+from pathlib import Path
+from typing import Dict, Tuple
+
+import yaml
+
+from ..utils import utils
+
+
+def get_configs(name: str) -> Tuple[Dict, Dict, Dict]:
+    """
+    Loads the config yaml file in the same folder
+    (spikewrap/configs) containing preprocessing (pp),
+    sorter and waveform options.
+
+    Once loaded, the list containing the preprocessor name
+    and kwargs is cast to a tuple. This keeps the type
+    checker happy while not requiring a tuple
+    in the .yaml, which would require ugly tags.
+
+    Parameters
+    ----------
+
+    name : name of the config to load. Should not include the
+        .yaml suffix.
+
+    Returns
+    -------
+
+    pp_steps : a dictionary containing the preprocessing
+        step order (keys) and a [pp_name, kwargs]
+        list containing the spikeinterface preprocessing
+        step and keyword options.
+
+    sorter_options : a dictionary with sorter name (key) and
+        a dictionary of kwargs to pass to the
+        spikeinterface sorter class.
+    """
+    config_dir = Path(os.path.dirname(os.path.realpath(__file__)))
+
+    available_files = glob.glob((config_dir / "*.yaml").as_posix())
+    available_files = [Path(path_).stem for path_ in available_files]
+
+    if name not in available_files:  # then assume it is a full path
+        assert Path(name).is_file(), (
+            f"{name} is neither the name of an existing "
+            f"config nor a valid path to a configuration file."
+        )
+
+        assert Path(name).suffix in [
+            ".yaml",
+            ".yml",
+        ], f"{name} is not the path to a .yaml file"
+
+        config_filepath = Path(name)
+
+    else:
+        config_filepath = config_dir / f"{name}.yaml"
+
+    with open(config_filepath) as file:
+        config = yaml.full_load(file)
+
+    pp_steps = config["preprocessing"] if "preprocessing" in config else {}
+    sorter_options = config["sorting"] if "sorting" in config else {}
+    waveform_options = config["waveforms"] if "waveforms" in config else {}
+
+    utils.cast_pp_steps_values(pp_steps, "tuple")
+
+    return pp_steps, sorter_options, waveform_options
diff --git a/spikewrap/configs/default.yaml b/spikewrap/configs/default.yaml
new file mode 100644
index 0000000..cef8ae8
--- /dev/null
+++ b/spikewrap/configs/default.yaml
@@ -0,0 +1,37 @@
+'preprocessing':
+  '1':
+    - phase_shift
+    - {}
+  '2':
+    - bandpass_filter
+    - freq_min: 300
+      freq_max: 6000
+  '3':
+    - common_reference
+    - operator: median
+      reference: global
+
+'sorting':
+  'kilosort2':
+    'car': False  # common average referencing
+    'freq_min': 150  # highpass filter cutoff. Neither False nor 0 turns this off (both result in a KS error).
+  'kilosort2_5':
+    'car': False
+    'freq_min': 150
+  'kilosort3':
+    'car': False
+    'freq_min': 300
+  'mountainsort5':
+    'scheme': '2'
+    'filter': False
+
+'waveforms':
+  'ms_before': 2
+  'ms_after': 2
+  'max_spikes_per_unit': 500
+  'return_scaled': True
+  # Sparsity Options
+  'sparse': True
+  'peak_sign': "neg"
+  'method': "radius"
+  'radius_um': 75
diff --git a/spikewrap/data_classes/preprocessing.py b/spikewrap/data_classes/preprocessing.py
index 6f84c75..b11f66f 100644
--- a/spikewrap/data_classes/preprocessing.py
+++ b/spikewrap/data_classes/preprocessing.py
@@ -1,6 +1,11 @@
+import shutil
 from dataclasses import dataclass
 from typing import Dict
 
+import spikeinterface
+
+from spikewrap.utils import utils
+
 from .base import BaseUserDict
 
 
@@ -12,7 +17,7 @@ class PreprocessingData(BaseUserDict):
     Details on the preprocessing steps are held in the dictionary keys,
     e.g. 0-raw, 1-raw-bandpass_filter, 2-raw_bandpass_filter-common_average,
     and recording objects are held in the value. These are generated
-    by the `pipeline.preprocess.run_preprocessing()` function.
+    by the `pipeline.preprocess._preprocess_and_save_all_runs()` function.
 
     The class manages paths to raw data and preprocessing output,
     and defines methods to dump key information and the SpikeInterface
@@ -44,6 +49,19 @@ def __post_init__(self) -> None:
         self.update_two_layer_dict(self, ses_name, run_name, {"0-raw": None})
         self.update_two_layer_dict(self.sync, ses_name, run_name, None)
 
+    def set_pp_steps(self, pp_steps: Dict) -> None:
+        """
+        Set the preprocessing steps (`pp_steps`) attribute
+        that defines the preprocessing steps and options.
+
+        Parameters
+        ----------
+        pp_steps : Dict
+            Preprocessing steps to set as a class attribute. These are used
+            when `pipeline.preprocess._fill_run_data_with_preprocessed_recording()` is called.
+        """
+        self.pp_steps = pp_steps
+
     def _validate_rawdata_inputs(self) -> None:
         self._validate_inputs(
             "rawdata",
@@ -52,3 +70,93 @@ def _validate_rawdata_inputs(self) -> None:
             self.get_rawdata_ses_path,
             self.get_rawdata_run_path,
         )
+
+    # Saving preprocessed data ---------------------------------------------------------
+
+    def save_preprocessed_data(
+        self, ses_name: str, run_name: str, overwrite: bool = False
+    ) -> None:
+        """
+        Save the preprocessed output data to binary, as well
+        as important class attributes to a .yaml file.
+
+        Both are saved in a folder called 'preprocessed'
+        in derivatives//
+
+        Parameters
+        ----------
+        run_name : str
+            Run name corresponding to one of `self.preprocessing_run_names`.
+
+        overwrite : bool
+            If `True`, existing preprocessed output will be overwritten.
+            By default, SpikeInterface will error if a binary recording file
+            (`si_recording`) already exists where it is trying to write one.
+            In this case, an error should be raised before this function
+            is called.
+
+        """
+        if overwrite:
+            if self.get_preprocessing_path(ses_name, run_name).is_dir():
+                shutil.rmtree(self.get_preprocessing_path(ses_name, run_name))
+
+        self._save_preprocessed_binary(ses_name, run_name)
+        self._save_sync_channel(ses_name, run_name)
+        self._save_preprocessing_info(ses_name, run_name)
+
+    def _save_preprocessed_binary(self, ses_name: str, run_name: str) -> None:
+        """
+        Save the fully preprocessed data (i.e. the last step in the
+        preprocessing chain) to a binary file. This is required for sorting.
+        """
+        recording, __ = utils.get_dict_value_from_step_num(
+            self[ses_name][run_name], "last"
+        )
+        recording.save(
+            folder=self._get_pp_binary_data_path(ses_name, run_name), chunk_memory="10M"
+        )
+
+    def _save_sync_channel(self, ses_name: str, run_name: str) -> None:
+        """
+        Save the sync channel separately. In SI, sorting cannot proceed
+        if the sync channel is loaded, to ensure it does not interfere with
+        sorting. As such, the sync channel is handled separately here.
+        """
+        utils.message_user(f"Saving sync channel for {ses_name} run {run_name}")
+
+        assert (
+            self.sync[ses_name][run_name] is not None
+        ), f"Sync channel on PreprocessData session {ses_name} run {run_name} is None"
+
+        self.sync[ses_name][run_name].save(  # type: ignore
+            folder=self._get_sync_channel_data_path(ses_name, run_name),
+            chunk_memory="10M",
+        )
+
+    def _save_preprocessing_info(self, ses_name: str, run_name: str) -> None:
+        """
+        Save important details of the preprocessing for provenance.
+
+        Importantly, this is used to check that the preprocessing
+        file used for waveform extraction in `PostprocessData`
+        matches the preprocessing that was used for sorting.
+        """
+        assert self.pp_steps is not None, "type narrow `pp_steps`."
+
+        utils.cast_pp_steps_values(self.pp_steps, "list")
+
+        preprocessing_info = {
+            "base_path": self.base_path.as_posix(),
+            "sub_name": self.sub_name,
+            "ses_name": ses_name,
+            "run_name": run_name,
+            "rawdata_path": self.get_rawdata_run_path(ses_name, run_name).as_posix(),
+            "pp_steps": self.pp_steps,
+            "spikeinterface_version": spikeinterface.__version__,
+            "spikewrap_version": utils.spikewrap_version(),
+            "datetime_written": utils.get_formatted_datetime(),
+        }
+
+        utils.dump_dict_to_yaml(
+            self.get_preprocessing_info_path(ses_name, run_name), preprocessing_info
+        )
diff --git a/spikewrap/examples/load_data.py b/spikewrap/examples/example_load_data.py
similarity index 100%
rename from spikewrap/examples/load_data.py
rename to spikewrap/examples/example_load_data.py
diff --git a/spikewrap/examples/example_preprocess.py b/spikewrap/examples/example_preprocess.py
new file mode 100644
index 0000000..1bbae51
--- /dev/null
+++ b/spikewrap/examples/example_preprocess.py
@@ -0,0 +1,32 @@
+from pathlib import Path
+
+from spikewrap.pipeline.load_data import load_data
+from spikewrap.pipeline.preprocess import run_preprocessing
+
+base_path = Path(
+    r"/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/test_data/steve_multi_run/1119617/time-short-multises"
+    # r"C:\data\ephys\test_data\steve_multi_run\1119617\time-miniscule-multises"
+)
+
+sub_name = "sub-1119617"
+sessions_and_runs = {
+    "ses-001": [
+        "run-001_1119617_LSE1_shank12_g0",
+        "run-002_made_up_g0",
+    ],
+    "ses-002": [
+        "run-001_1119617_pretest1_shank12_g0",
+    ],
+    "ses-003": [
+        "run-002_1119617_pretest1_shank12_g0",
+    ],
+}
+
+loaded_data = load_data(base_path, sub_name, sessions_and_runs, data_format="spikeglx")
+
+run_preprocessing(
+    loaded_data,
+    pp_steps="default",
+    handle_existing_data="overwrite",
+    log=True,
+)
diff --git a/spikewrap/pipeline/preprocess.py b/spikewrap/pipeline/preprocess.py
new file mode 100644
index 0000000..f8784b5
--- /dev/null
+++ b/spikewrap/pipeline/preprocess.py
@@ -0,0 +1,420 @@
+import json
+from typing import Dict, List, Tuple
+
+import numpy as np
+import spikeinterface.preprocessing as spre
+
+from ..configs import configs
+from ..data_classes.preprocessing import PreprocessingData
+from ..utils import utils
+from ..utils.custom_types import HandleExisting
+
+# --------------------------------------------------------------------------------------
+# Public Functions
+# --------------------------------------------------------------------------------------
+
+
+def run_preprocessing(
+    preprocess_data: PreprocessingData,
+    pp_steps: str,
+    handle_existing_data: HandleExisting,
+    log: bool = True,
+):
+    """
+    Main entry function to run preprocessing and write to file. Preprocessed
+    lazy spikeinterface recordings will be added to all sessions / runs in
+    `preprocess_data` and written to file.
+
+    Parameters
+    ----------
+
+    preprocess_data : PreprocessingData
+        A preprocessing data object that has as attributes the
+        paths to rawdata. The pp_steps attribute is set on
+        this class during execution of this function.
+
+    pp_steps : The name of a valid preprocessing .yaml file (without the .yaml
+        extension), stored in spikewrap/configs.
+
+    handle_existing_data : custom_types.HandleExisting
+        Determines how existing preprocessed data (e.g.
+        from a prior pipeline run) is handled.
+            "overwrite" : Will overwrite any existing preprocessed data output.
+                          This will delete the 'preprocessed' folder. Therefore,
+                          never save derivative work there.
+            "skip_if_exists" : will search for existing data and skip preprocessing
+                               if it exists (sorting will run on existing
+                               preprocessed data).
+                               Otherwise, will preprocess and save the current run.
+            "fail_if_exists" : If existing preprocessed data is found, an error
+                               will be raised.
+    """
+    pp_steps_dict, _, _ = configs.get_configs(pp_steps)
+
+    #
+    _preprocess_and_save_all_runs(
+        preprocess_data, pp_steps_dict, handle_existing_data, log
+    )
+
+
+def fill_all_runs_with_preprocessed_recording(
+    preprocess_data: PreprocessingData, pp_steps: str
+) -> None:
+    """
+    Convenience function to fill all session and run entries in the
+    `preprocess_data` dictionary with preprocessed spikeinterface
+    recording objects.
+
+    preprocess_data : PreprocessingData
+        A preprocessing data object that has as attributes the
+        paths to rawdata. The pp_steps attribute is set on
+        this class during execution of this function.
+
+    pp_steps : The name of a valid preprocessing .yaml file (without the .yaml
+        extension), stored in spikewrap/configs.
+    """
+    pp_steps_dict, _, _ = configs.get_configs(pp_steps)
+
+    for ses_name, run_name in preprocess_data.flat_sessions_and_runs():
+        _fill_run_data_with_preprocessed_recording(
+            preprocess_data, ses_name, run_name, pp_steps_dict
+        )
+
+
+# --------------------------------------------------------------------------------------
+# Private Functions
+# --------------------------------------------------------------------------------------
+
+
+def _preprocess_and_save_all_runs(
+    preprocess_data: PreprocessingData,
+    pp_steps_dict: Dict,
+    handle_existing_data: HandleExisting,
+    log: bool = True,
+) -> None:
+    """
+    Handle the loading of existing preprocessed data.
+    See `run_preprocessing()` for details.
+
+    This function validates all input arguments and initialises logging.
+    Then, it will iterate over every run in `preprocess_data` and
+    check whether preprocessing needs to be run and saved based on the
+    `handle_existing_data` option. If so, it will fill the relevant run
+    with the preprocessed spikeinterface recording object and save to disk.
+    """
+    for ses_name, run_name in preprocess_data.flat_sessions_and_runs():
+        utils.message_user(f"Preprocessing run {run_name}...")
+
+        to_save, overwrite = _handle_existing_data_options(
+            preprocess_data, ses_name, run_name, handle_existing_data
+        )
+
+        if to_save:
+            _preprocess_and_save_single_run(
+                preprocess_data, ses_name, run_name, pp_steps_dict, overwrite
+            )
+
+
+def _preprocess_and_save_single_run(
+    preprocess_data: PreprocessingData,
+    ses_name: str,
+    run_name: str,
+    pp_steps_dict: Dict,
+    overwrite: bool,
+) -> None:
+    """
+    Given a single session and run, fill the entry for this run
+    on the `preprocess_data` object and write to disk.
+    """
+    _fill_run_data_with_preprocessed_recording(
+        preprocess_data,
+        ses_name,
+        run_name,
+        pp_steps_dict,
+    )
+
+    preprocess_data.save_preprocessed_data(ses_name, run_name, overwrite)
+
+
+def _handle_existing_data_options(
+    preprocess_data: PreprocessingData,
+    ses_name: str,
+    run_name: str,
+    handle_existing_data: HandleExisting,
+) -> Tuple[bool, bool]:
+    """
+    Determine whether preprocessing for this run needs to be performed based
+    on the `handle_existing_data` setting.
+    If preprocessing does not exist, preprocessing
+    is always run. Otherwise, if it already exists, the behaviour depends on
+    the `handle_existing_data` setting.
+
+    Returns
+    -------
+
+    to_save : bool
+        Whether the preprocessing needs to be run and saved.
+
+    overwrite : bool
+        If saving, set the `overwrite` flag to confirm existing data should
+        be overwritten.
+    """
+    preprocess_path = preprocess_data.get_preprocessing_path(ses_name, run_name)
+
+    if handle_existing_data == "skip_if_exists":
+        if preprocess_path.is_dir():
+            utils.message_user(
+                f"\nSkipping preprocessing, using file at "
+                f"{preprocess_path} for sorting.\n"
+            )
+            to_save = False
+            overwrite = False
+        else:
+            utils.message_user(
+                f"No data found at {preprocess_path}, saving preprocessed data."
+            )
+            to_save = True
+            overwrite = False
+
+    elif handle_existing_data == "overwrite":
+        if preprocess_path.is_dir():
+            utils.message_user(f"Removing existing file at {preprocess_path}\n")
+
+        utils.message_user(f"Saving preprocessed data to {preprocess_path}")
+        to_save = True
+        overwrite = True
+
+    elif handle_existing_data == "fail_if_exists":
+        if preprocess_path.is_dir():
+            raise FileExistsError(
+                f"Preprocessed binary already exists at "
+                f"{preprocess_path}. "
+                f"To overwrite, set 'handle_existing_data' to 'overwrite'"
+            )
+        to_save = True
+        overwrite = False
+
+    return to_save, overwrite
+
+
+def _fill_run_data_with_preprocessed_recording(
+    preprocess_data: PreprocessingData,
+    ses_name: str,
+    run_name: str,
+    pp_steps: Dict,
+) -> None:
+    """
+    For a particular run, fill the `preprocess_data` object entry with preprocessed
+    spikeinterface recording objects. For each preprocessing step, a separate
+    recording object will be stored. The name of the dict entry will be
+    a concatenation of all preprocessing steps that were performed.
+
+    e.g. "0-raw", "1-raw-phase_shift", "2-raw-phase_shift-bandpass_filter"
+    """
+    pp_funcs = _get_pp_funcs()
+
+    checked_pp_steps, pp_step_names = _check_and_sort_pp_steps(pp_steps, pp_funcs)
+
+    preprocess_data.set_pp_steps(pp_steps)
+
+    for step_num, pp_info in checked_pp_steps.items():
+        _perform_preprocessing_step(
+            step_num,
+            pp_info,
+            preprocess_data,
+            ses_name,
+            run_name,
+            pp_step_names,
+            pp_funcs,
+        )
+
+
+def _perform_preprocessing_step(
+    step_num: str,
+    pp_info: Tuple[str, Dict],
+    preprocess_data: PreprocessingData,
+    ses_name: str,
+    run_name: str,
+    pp_step_names: List[str],
+    pp_funcs: Dict,
+) -> None:
+    """
+    Given the preprocessing step and preprocess_data UserDict containing
+    spikeinterface BaseRecordings, apply a preprocessing step to the
+    last preprocessed recording and save the recording object to PreprocessingData.
+    For example, if step_num = "3", get the recording of the second
+    preprocessing step from preprocess_data and apply the 3rd preprocessing step
+    as specified in pp_info.
+
+    Parameters
+    ----------
+    step_num : str
+        Preprocessing step to run (e.g. "1", corresponds to the key in pp_dict).
+
+    pp_info : Tuple[str, Dict]
+        (Preprocessing name, preprocessing kwargs) tuple (the value from
+        the pp_dict).
+
+    preprocess_data : PreprocessingData
+        spikewrap PreprocessingData class (a UserDict in which key-values are
+        the preprocessing chain name : spikeinterface recording objects).
+
+    run_name : str
+        Name of the run to preprocess. This should correspond to a
+        run_name in `preprocess_data.preprocessing_run_names`.
+
+    pp_step_names : List[str]
+        Ordered list of preprocessing step names that are being
+        applied across the entire preprocessing chain.
+
+    pp_funcs : Dict
+        The canonical SpikeInterface preprocessing functions. The keys
+        are the function names and the values the function objects.
+    """
+    pp_name, pp_options = pp_info
+
+    utils.message_user(
+        f"Running preprocessing step: {pp_name} with options {pp_options}"
+    )
+
+    last_pp_step_output, __ = utils.get_dict_value_from_step_num(
+        preprocess_data[ses_name][run_name], step_num=str(int(step_num) - 1)
+    )
+
+    new_name = f"{step_num}-" + "-".join(["raw"] + pp_step_names[: int(step_num)])
+
+    _confidence_check_pp_func_name(pp_name, pp_funcs)
+
+    preprocess_data[ses_name][run_name][new_name] = pp_funcs[pp_name](
+        last_pp_step_output, **pp_options
+    )
+
+
+# Helpers for preprocessing steps dictionary -------------------------------------------
+
+
+def _check_and_sort_pp_steps(pp_steps: Dict, pp_funcs: Dict) -> Tuple[Dict, List[str]]:
+    """
+    Sort the preprocessing steps dictionary by order to be run (based on the
+    keys) and check the dictionary is valid.
+
+    Parameters
+    ----------
+    pp_steps : Dict
+        A dictionary with keys as numbers indicating the order that
+        preprocessing steps are run (starting at "1"). The values are a
+        (preprocessing name, preprocessing kwargs) tuple containing the
+        spikeinterface preprocessing function name, and kwargs to pass to it.
+
+    pp_funcs : Dict
+        A dictionary linking preprocessing step names to the underlying
+        SpikeInterface function objects that conduct the preprocessing.
+
+    Returns
+    -------
+    pp_steps : Dict
+        The checked preprocessing steps dictionary.
+
+    pp_step_names : List
+        Preprocessing step names (e.g. "bandpass_filter") in the order
+        that they are to be run.
+    """
+    _validate_pp_steps(pp_steps)
+    pp_step_names = [item[0] for item in pp_steps.values()]
+
+    # Check the preprocessing function names are valid and print steps used
+    canonical_step_names = list(pp_funcs.keys())
+
+    for user_passed_name in pp_step_names:
+        assert (
+            user_passed_name in canonical_step_names
+        ), f"{user_passed_name} not in allowed names: ({canonical_step_names})"
+
+    utils.message_user(
+        f"\nThe preprocessing options dictionary is "
+        f"{json.dumps(pp_steps, indent=4, sort_keys=True)}"
+    )
+
+    return pp_steps, pp_step_names
+
+
+def _validate_pp_steps(pp_steps: Dict):
+    """
+    Ensure the pp_steps dictionary of preprocessing steps
+    has number-order that makes sense. The preprocessing step numbers
+    should start at 1 and increase by 1 for each subsequent step.
+    """
+    assert all(
+        key.isdigit() for key in pp_steps.keys()
+    ), "pp_steps keys must be integers"
+
+    key_nums = [int(key) for key in pp_steps.keys()]
+
+    assert np.min(key_nums) == 1, "dict keys must start at 1"
+
+    if len(key_nums) > 1:
+        diffs = np.diff(key_nums)
+        assert np.unique(diffs).size == 1, "all dict keys must increase in steps of 1"
+        assert diffs[0] == 1, "all dict keys must increase in steps of 1"
+
+
+def _confidence_check_pp_func_name(pp_name: str, pp_funcs: Dict):
+    """
+    Ensure that the correct preprocessing function is retrieved. This
+    essentially checks the _get_pp_funcs dictionary is correct.
+
+    TODO
+    ----
+    This should be a standalone test, not incorporated into the package.
+ """ + func_name_to_class_name = "".join([word.lower() for word in pp_name.split("_")]) + + if pp_name == "silence_periods": + assert pp_funcs[pp_name].__name__ == "SilencedPeriodsRecording" # TODO: open PR + elif isinstance(pp_funcs[pp_name], type): + assert ( + func_name_to_class_name in pp_funcs[pp_name].__name__.lower() + ), "something is wrong in func dict" + + else: + assert pp_funcs[pp_name].__name__ == pp_name + + +def _get_pp_funcs() -> Dict: + """ + Get the spikeinterface preprocessing function + from its name. + + TODO + ----- + It should be possible to generate this on the fly from + SI __init__ rather than hard code like this + """ + pp_funcs = { + "phase_shift": spre.phase_shift, + "bandpass_filter": spre.bandpass_filter, + "common_reference": spre.common_reference, + "blank_saturation": spre.blank_staturation, + "center": spre.center, + "clip": spre.clip, + "correct_lsb": spre.correct_lsb, + "correct_motion": spre.correct_motion, + "depth_order": spre.depth_order, + "filter": spre.filter, + "gaussian_bandpass_filter": spre.gaussian_bandpass_filter, + "highpass_filter": spre.highpass_filter, + "interpolate_bad_channels": spre.interpolate_bad_channels, + "normalize_by_quantile": spre.normalize_by_quantile, + "notch_filter": spre.notch_filter, + "remove_artifacts": spre.remove_artifacts, + # "remove_channels": remove_channels, not sure how to handle at runtime + # "resample": spre.resample, leading to linAlg error + "scale": spre.scale, + "silence_periods": spre.silence_periods, + "whiten": spre.whiten, + "zscore": spre.zscore, + } + + return pp_funcs diff --git a/spikewrap/utils/custom_types.py b/spikewrap/utils/custom_types.py new file mode 100644 index 0000000..964e8b4 --- /dev/null +++ b/spikewrap/utils/custom_types.py @@ -0,0 +1,6 @@ +from typing import Literal, Tuple, Union + +HandleExisting = Literal["overwrite", "skip_if_exists", "fail_if_exists"] +DeleteIntermediate = Tuple[ + Union[Literal["recording.dat"], Literal["temp_wh.dat"], Literal["waveforms"]], ... +]