From 0a71918328cd34ca19099f3a46d8b84f85a0f076 Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Thu, 16 Nov 2023 22:40:00 +0000
Subject: [PATCH 1/3] Add derivatives and preprocessing path getters to `base.py`

---
 spikewrap/data_classes/base.py | 39 ++++++++++++++++++++++++++++++++++
 1 file changed, 39 insertions(+)

diff --git a/spikewrap/data_classes/base.py b/spikewrap/data_classes/base.py
index c41b57c..d21e0c0 100644
--- a/spikewrap/data_classes/base.py
+++ b/spikewrap/data_classes/base.py
@@ -5,6 +5,8 @@
 from pathlib import Path
 from typing import Callable, Dict, List, Literal, Tuple
 
+from ..utils import utils
+
 
 @dataclass
 class BaseUserDict(UserDict):
@@ -140,6 +142,43 @@ def get_rawdata_ses_path(self, ses_name: str) -> Path:
     def get_rawdata_run_path(self, ses_name: str, run_name: str) -> Path:
         return self.get_rawdata_ses_path(ses_name) / "ephys" / run_name
 
+    # Derivatives Paths --------------------------------------------------------------
+
+    def get_derivatives_top_level_path(self) -> Path:
+        return self.base_path / "derivatives" / "spikewrap"
+
+    def get_derivatives_sub_path(self) -> Path:
+        return self.get_derivatives_top_level_path() / self.sub_name
+
+    def get_derivatives_ses_path(self, ses_name: str) -> Path:
+        return self.get_derivatives_sub_path() / ses_name
+
+    def get_derivatives_run_path(self, ses_name: str, run_name: str) -> Path:
+        return self.get_derivatives_ses_path(ses_name) / run_name
+
+    # Preprocessing Paths ------------------------------------------------------------
+
+    def get_preprocessing_path(self, ses_name: str, run_name: str) -> Path:
+        """
+        Get the folder tree where preprocessing output will be
+        saved. This path is canonical and should not change.
+        """
+        preprocessed_output_path = (
+            self.get_derivatives_run_path(ses_name, run_name) / "preprocessing"
+        )
+        return preprocessed_output_path
+
+    def _get_pp_binary_data_path(self, ses_name: str, run_name: str) -> Path:
+        return self.get_preprocessing_path(ses_name, run_name) / "si_recording"
+
+    def _get_sync_channel_data_path(self, ses_name: str, run_name: str) -> Path:
+        return self.get_preprocessing_path(ses_name, run_name) / "sync_channel"
+
+    def get_preprocessing_info_path(self, ses_name: str, run_name: str) -> Path:
+        return self.get_preprocessing_path(ses_name, run_name) / utils.canonical_names(
+            "preprocessed_yaml"
+        )
+
     @staticmethod
     def update_two_layer_dict(dict_, ses_name, run_name, value):
         """

From 2f8798269eb2b0daa35221f86659598a1071f7b8 Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Thu, 16 Nov 2023 22:31:52 +0000
Subject: [PATCH 2/3] Add preprocessing pipeline and associated configs.
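
This patch adds `run_preprocessing()` as the public entry point for the
preprocessing pipeline, driven by a named config .yaml in `spikewrap/configs`
(see `default.yaml`). A minimal usage sketch, condensed from
`examples/example_preprocess.py` added in this patch; the base path, subject
and session/run names below are placeholders rather than real data:

    from pathlib import Path

    from spikewrap.pipeline.load_data import load_data
    from spikewrap.pipeline.preprocess import run_preprocessing

    # Placeholder rawdata layout: <base_path>/rawdata/<sub>/<ses>/ephys/<run>
    loaded_data = load_data(
        Path("/path/to/project"),
        "sub-001",
        {"ses-001": ["run-001_g0"]},
        data_format="spikeglx",
    )

    # Preprocess every session/run and save binary output under derivatives/spikewrap.
    run_preprocessing(
        loaded_data,
        pp_steps="default",                     # name of a config in spikewrap/configs
        handle_existing_data="fail_if_exists",  # or "overwrite" / "skip_if_exists"
        log=True,
    )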
--- spikewrap/configs/configs.py | 70 +++ spikewrap/configs/default.yaml | 37 ++ spikewrap/data_classes/preprocessing.py | 110 ++++- .../{load_data.py => example_load_data.py} | 0 spikewrap/examples/example_preprocess.py | 32 ++ spikewrap/pipeline/preprocess.py | 420 ++++++++++++++++++ spikewrap/utils/custom_types.py | 6 + 7 files changed, 674 insertions(+), 1 deletion(-) create mode 100644 spikewrap/configs/configs.py create mode 100644 spikewrap/configs/default.yaml rename spikewrap/examples/{load_data.py => example_load_data.py} (100%) create mode 100644 spikewrap/examples/example_preprocess.py create mode 100644 spikewrap/pipeline/preprocess.py create mode 100644 spikewrap/utils/custom_types.py diff --git a/spikewrap/configs/configs.py b/spikewrap/configs/configs.py new file mode 100644 index 0000000..3b0d82b --- /dev/null +++ b/spikewrap/configs/configs.py @@ -0,0 +1,70 @@ +import glob +import os +from pathlib import Path +from typing import Dict, Tuple + +import yaml + +from ..utils import utils + + +def get_configs(name: str) -> Tuple[Dict, Dict, Dict]: + """ + Loads the config yaml file in the same folder + (spikewrap/configs) containing preprocessing (pp) + and sorter options. + + Once loaded, the list containing preprocesser name + and kwargs is cast to tuple. This keeps the type + checker happy while not requiring a tuple + in the .yaml which require ugly tags. + + Parameters + ---------- + + name: name of the configs to load. Should not include the + .yaml suffix. + + Returns + ------- + + pp_steps : a dictionary containing the preprocessing + step order (keys) and a [pp_name, kwargs] + list containing the spikeinterface preprocessing + step and keyword options. + + sorter_options : a dictionary with sorter name (key) and + a dictionary of kwargs to pass to the + spikeinterface sorter class. + """ + config_dir = Path(os.path.dirname(os.path.realpath(__file__))) + + available_files = glob.glob((config_dir / "*.yaml").as_posix()) + available_files = [Path(path_).stem for path_ in available_files] + + if name not in available_files: # then assume it is a full path + assert Path(name).is_file(), ( + f"{name} is neither the name of an existing " + f"config or valid path to configuration file." + ) + + assert Path(name).suffix in [ + ".yaml", + ".yml", + ], f"{name} is not the path to a .yaml file" + + config_filepath = Path(name) + + else: + config_filepath = config_dir / f"{name}.yaml" + + with open(config_filepath) as file: + config = yaml.full_load(file) + + pp_steps = config["preprocessing"] if "preprocessing" in config else {} + sorter_options = config["sorting"] if "sorting" in config else {} + waveform_options = config["waveforms"] if "waveforms" in config else {} + + utils.cast_pp_steps_values(pp_steps, "tuple") + + return pp_steps, sorter_options, waveform_options diff --git a/spikewrap/configs/default.yaml b/spikewrap/configs/default.yaml new file mode 100644 index 0000000..cef8ae8 --- /dev/null +++ b/spikewrap/configs/default.yaml @@ -0,0 +1,37 @@ +'preprocessing': + '1': + - phase_shift + - {} + '2': + - bandpass_filter + - freq_min: 300 + freq_max: 6000 + '3': + - common_reference + - operator: median + reference: global + +'sorting': + 'kilosort2': + 'car': False # common average referencing + 'freq_min': 150 # highpass filter cutoff, False nor 0 does not work to turn off. 
(results in KS error) + 'kilosort2_5': + 'car': False + 'freq_min': 150 + 'kilosort3': + 'car': False + 'freq_min': 300 + 'mountainsort5': + 'scheme': '2' + 'filter': False + +'waveforms': + 'ms_before': 2 + 'ms_after': 2 + 'max_spikes_per_unit': 500 + 'return_scaled': True + # Sparsity Options + 'sparse': True + 'peak_sign': "neg" + 'method': "radius" + 'radius_um': 75 diff --git a/spikewrap/data_classes/preprocessing.py b/spikewrap/data_classes/preprocessing.py index 6f84c75..b11f66f 100644 --- a/spikewrap/data_classes/preprocessing.py +++ b/spikewrap/data_classes/preprocessing.py @@ -1,6 +1,11 @@ +import shutil from dataclasses import dataclass from typing import Dict +import spikeinterface + +from spikewrap.utils import utils + from .base import BaseUserDict @@ -12,7 +17,7 @@ class PreprocessingData(BaseUserDict): Details on the preprocessing steps are held in the dictionary keys e.g. e.g. 0-raw, 1-raw-bandpass_filter, 2-raw_bandpass_filter-common_average and recording objects are held in the value. These are generated - by the `pipeline.preprocess.run_preprocessing()` function. + by the `pipeline.preprocess._preprocess_and_save_all_runs()` function. The class manages paths to raw data and preprocessing output, as defines methods to dump key information and the SpikeInterface @@ -44,6 +49,19 @@ def __post_init__(self) -> None: self.update_two_layer_dict(self, ses_name, run_name, {"0-raw": None}) self.update_two_layer_dict(self.sync, ses_name, run_name, None) + def set_pp_steps(self, pp_steps: Dict) -> None: + """ + Set the preprocessing steps (`pp_steps`) attribute + that defines the preprocessing steps and options. + + Parameters + ---------- + pp_steps : Dict + Preprocessing steps to setp as class attribute. These are used + when `pipeline.preprocess._fill_run_data_with_preprocessed_recording()` function is called. + """ + self.pp_steps = pp_steps + def _validate_rawdata_inputs(self) -> None: self._validate_inputs( "rawdata", @@ -52,3 +70,93 @@ def _validate_rawdata_inputs(self) -> None: self.get_rawdata_ses_path, self.get_rawdata_run_path, ) + + # Saving preprocessed data --------------------------------------------------------- + + def save_preprocessed_data( + self, ses_name: str, run_name: str, overwrite: bool = False + ) -> None: + """ + Save the preprocessed output data to binary, as well + as important class attributes to a .yaml file. + + Both are saved in a folder called 'preprocessed' + in derivatives// + + Parameters + ---------- + run_name : str + Run name corresponding to one of `self.preprocessing_run_names`. + + overwrite : bool + If `True`, existing preprocessed output will be overwritten. + By default, SpikeInterface will error if a binary recording file + (`si_recording`) already exists where it is trying to write one. + In this case, an error should be raised before this function + is called. + + """ + if overwrite: + if self.get_preprocessing_path(ses_name, run_name).is_dir(): + shutil.rmtree(self.get_preprocessing_path(ses_name, run_name)) + + self._save_preprocessed_binary(ses_name, run_name) + self._save_sync_channel(ses_name, run_name) + self._save_preprocessing_info(ses_name, run_name) + + def _save_preprocessed_binary(self, ses_name: str, run_name: str) -> None: + """ + Save the fully preprocessed data (i.e. last step in the + preprocessing chain) to binary file. This is required for sorting. 
+ """ + recording, __ = utils.get_dict_value_from_step_num( + self[ses_name][run_name], "last" + ) + recording.save( + folder=self._get_pp_binary_data_path(ses_name, run_name), chunk_memory="10M" + ) + + def _save_sync_channel(self, ses_name: str, run_name: str) -> None: + """ + Save the sync channel separately. In SI, sorting cannot proceed + if the sync channel is loaded to ensure it does not interfere with + sorting. As such, the sync channel is handled separately here. + """ + utils.message_user(f"Saving sync channel for {ses_name} run {run_name}") + + assert ( + self.sync[ses_name][run_name] is not None + ), f"Sync channel on PreprocessData session {ses_name} run {run_name} is None" + + self.sync[ses_name][run_name].save( # type: ignore + folder=self._get_sync_channel_data_path(ses_name, run_name), + chunk_memory="10M", + ) + + def _save_preprocessing_info(self, ses_name: str, run_name: str) -> None: + """ + Save important details of the postprocessing for provenance. + + Importantly, this is used to check that the preprocessing + file used for waveform extraction in `PostprocessData` + matches the preprocessing that was used for sorting. + """ + assert self.pp_steps is not None, "type narrow `pp_steps`." + + utils.cast_pp_steps_values(self.pp_steps, "list") + + preprocessing_info = { + "base_path": self.base_path.as_posix(), + "sub_name": self.sub_name, + "ses_name": ses_name, + "run_name": run_name, + "rawdata_path": self.get_rawdata_run_path(ses_name, run_name).as_posix(), + "pp_steps": self.pp_steps, + "spikeinterface_version": spikeinterface.__version__, + "spikewrap_version": utils.spikewrap_version(), + "datetime_written": utils.get_formatted_datetime(), + } + + utils.dump_dict_to_yaml( + self.get_preprocessing_info_path(ses_name, run_name), preprocessing_info + ) diff --git a/spikewrap/examples/load_data.py b/spikewrap/examples/example_load_data.py similarity index 100% rename from spikewrap/examples/load_data.py rename to spikewrap/examples/example_load_data.py diff --git a/spikewrap/examples/example_preprocess.py b/spikewrap/examples/example_preprocess.py new file mode 100644 index 0000000..1bbae51 --- /dev/null +++ b/spikewrap/examples/example_preprocess.py @@ -0,0 +1,32 @@ +from pathlib import Path + +from spikewrap.pipeline.load_data import load_data +from spikewrap.pipeline.preprocess import run_preprocessing + +base_path = Path( + r"/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/test_data/steve_multi_run/1119617/time-short-multises" + # r"C:\data\ephys\test_data\steve_multi_run\1119617\time-miniscule-multises" +) + +sub_name = "sub-1119617" +sessions_and_runs = { + "ses-001": [ + "run-001_1119617_LSE1_shank12_g0", + "run-002_made_up_g0", + ], + "ses-002": [ + "run-001_1119617_pretest1_shank12_g0", + ], + "ses-003": [ + "run-002_1119617_pretest1_shank12_g0", + ], +} + +loaded_data = load_data(base_path, sub_name, sessions_and_runs, data_format="spikeglx") + +run_preprocessing( + loaded_data, + pp_steps="default", + handle_existing_data="overwrite", + log=True, +) diff --git a/spikewrap/pipeline/preprocess.py b/spikewrap/pipeline/preprocess.py new file mode 100644 index 0000000..f8784b5 --- /dev/null +++ b/spikewrap/pipeline/preprocess.py @@ -0,0 +1,420 @@ +import json +from typing import Dict, List, Tuple + +import numpy as np +import spikeinterface.preprocessing as spre + +from ..configs import configs +from ..data_classes.preprocessing import PreprocessingData +from ..utils import utils +from ..utils.custom_types import HandleExisting + +# 
-------------------------------------------------------------------------------------- +# Public Functions +# -------------------------------------------------------------------------------------- + + +def run_preprocessing( + preprocess_data: PreprocessingData, + pp_steps: str, + handle_existing_data: HandleExisting, + log: bool = True, +): + """ + Main entry function to run preprocessing and write to file. Preprocessed + lazy spikeinterface recordings will be added to all sessions / runs in + `preprocess_data` and written to file. + + Parameters + ---------- + + preprocess_data : PreprocessingData + A preprocessing data object that has as attributes the + paths to rawdata. The pp_steps attribute is set on + this class during execution of this function. + + pp_steps: The name of valid preprocessing .yaml file (without the yaml extension). + stored in spikewrap/configs. + + existing_preprocessed_data : custom_types.HandleExisting + Determines how existing preprocessed data (e.g. from a prior pipeline run) + is handled. + "overwrite" : Will overwrite any existing preprocessed data output. + This will delete the 'preprocessed' folder. Therefore, + never save derivative work there. + "skip_if_exists" : will search for existing data and skip preprocesing + if it exists (sorting will run on existing + preprocessed data). + Otherwise, will preprocess and save the current run. + "fail_if_exists" : If existing preprocessed data is found, an error + will be raised. + + slurm_batch : Union[bool, Dict] + see `run_full_pipeline()` for details. + """ + pp_steps_dict, _, _ = configs.get_configs(pp_steps) + + # + _preprocess_and_save_all_runs( + preprocess_data, pp_steps_dict, handle_existing_data, log + ) + + +def fill_all_runs_with_preprocessed_recording( + preprocess_data: PreprocessingData, pp_steps: str +) -> None: + """ + Convenience function to fill all session and run entries in the + `preprocess_data` dictionary with preprocessed spikeinterface + recording objects. + + preprocess_data : PreprocessingData + A preprocessing data object that has as attributes the + paths to rawdata. The pp_steps attribute is set on + this class during execution of this function. + + pp_steps: The name of valid preprocessing .yaml file (without the yaml extension). + stored in spikewrap/configs. + """ + pp_steps_dict, _, _ = configs.get_configs(pp_steps) + + for ses_name, run_name in preprocess_data.flat_sessions_and_runs(): + _fill_run_data_with_preprocessed_recording( + preprocess_data, ses_name, run_name, pp_steps_dict + ) + + +# -------------------------------------------------------------------------------------- +# Private Functions +# -------------------------------------------------------------------------------------- + + +def _preprocess_and_save_all_runs( + preprocess_data: PreprocessingData, + pp_steps_dict: Dict, + handle_existing_data: HandleExisting, + log: bool = True, +) -> None: + """ + Handle the loading of existing preprocessed data. + See `run_preprocessing()` for details. + + This function validates all input arguments and initialises logging. + Then, it will iterate over every run in `preprocess_data` and + check whether preprocessing needs to be run and saved based on the + `handle_existing_data` option. If so, it will fill the relevant run + with the preprocessed spikeinterface recording object and save to disk. 
+ """ + for ses_name, run_name in preprocess_data.flat_sessions_and_runs(): + utils.message_user(f"Preprocessing run {run_name}...") + + to_save, overwrite = _handle_existing_data_options( + preprocess_data, ses_name, run_name, handle_existing_data + ) + + if to_save: + _preprocess_and_save_single_run( + preprocess_data, ses_name, run_name, pp_steps_dict, overwrite + ) + + +def _preprocess_and_save_single_run( + preprocess_data: PreprocessingData, + ses_name: str, + run_name: str, + pp_steps_dict: Dict, + overwrite: bool, +) -> None: + """ + Given a single session and run, fill the entry for this run + on the `preprocess_data` object and write to disk. + """ + _fill_run_data_with_preprocessed_recording( + preprocess_data, + ses_name, + run_name, + pp_steps_dict, + ) + + preprocess_data.save_preprocessed_data(ses_name, run_name, overwrite) + + +def _handle_existing_data_options( + preprocess_data: PreprocessingData, + ses_name: str, + run_name: str, + handle_existing_data: HandleExisting, +) -> Tuple[bool, bool]: + """ + Determine whether preprocesing for this run needs to be performed based + on the `handle_existing_data setting`. If preprocessing does not exist, preprocessing + is always run. Otherwise, if it already exists, the behaviour depends on + the `handle_existing_data` setting. + + Returns + ------- + + to_save : bool + Whether the preprocessing needs to be run and saved. + + to_overwrite : bool + If saving, set the `overwrite` flag to confirm existing data should + be overwritten. + """ + preprocess_path = preprocess_data.get_preprocessing_path(ses_name, run_name) + + if handle_existing_data == "skip_if_exists": + if preprocess_path.is_dir(): + utils.message_user( + f"\nSkipping preprocessing, using file at " + f"{preprocess_path} for sorting.\n" + ) + to_save = False + overwrite = False + else: + utils.message_user( + f"No data found at {preprocess_path}, saving preprocessed data." + ) + to_save = True + overwrite = False + + elif handle_existing_data == "overwrite": + if preprocess_path.is_dir(): + utils.message_user(f"Removing existing file at {preprocess_path}\n") + + utils.message_user(f"Saving preprocessed data to {preprocess_path}") + to_save = True + overwrite = True + + elif handle_existing_data == "fail_if_exists": + if preprocess_path.is_dir(): + raise FileExistsError( + f"Preprocessed binary already exists at " + f"{preprocess_path}. " + f"To overwrite, set 'existing_preprocessed_data' to 'overwrite'" + ) + to_save = True + overwrite = False + + return to_save, overwrite + + +def _fill_run_data_with_preprocessed_recording( + preprocess_data: PreprocessingData, + ses_name: str, + run_name: str, + pp_steps: Dict, +) -> None: + """ + For a particular run, fill the `preprocess_data` object entry with preprocessed + spikeinterface recording objects. For each preprocessing step, a separate + recording object will be stored. The name of the dict entry will be + a concatenenation of all preprocessing steps that were performed. + + e.g. 
"0-raw", "0-raw_1-phase_shift_2-bandpass_filter" + """ + pp_funcs = _get_pp_funcs() + + checked_pp_steps, pp_step_names = _check_and_sort_pp_steps(pp_steps, pp_funcs) + + preprocess_data.set_pp_steps(pp_steps) + + for step_num, pp_info in checked_pp_steps.items(): + _perform_preprocessing_step( + step_num, + pp_info, + preprocess_data, + ses_name, + run_name, + pp_step_names, + pp_funcs, + ) + + +def _perform_preprocessing_step( + step_num: str, + pp_info: Tuple[str, Dict], + preprocess_data: PreprocessingData, + ses_name: str, + run_name: str, + pp_step_names: List[str], + pp_funcs: Dict, +) -> None: + """ + Given the preprocessing step and preprocess_data UserDict containing + spikeinterface BaseRecordings, apply a preprocessing step to the + last preprocessed recording and save the recording object to PreprocessingData. + For example, if step_num = "3", get the recording of the second + preprocessing step from preprocess_data and apply the 3rd preprocessing step + as specified in pp_info. + + Parameters + ---------- + step_num : str + Preprocessing step to run (e.g. "1", corresponds to the key in pp_dict). + + pp_info : Tuple[str, Dict] + Preprocessing name, preprocessing kwargs) tuple (the value from + the pp_dict). + + preprocess_data : PreprocessingData + spikewrap PreprocessingData class (a UserDict in which key-values are + the preprocessing chain name : spikeinterface recording objects). + + run_name: str + Name of the run to preprocess. This should correspond to a + run_name in `preprocess_data.preprocessing_run_names`. + + pp_step_names : List[str] + Ordered list of preprocessing step names that are being + applied across the entire preprocessing chain. + + pp_funcs : Dict + The canonical SpikeInterface preprocessing functions. The key + are the function name and value the function object. + """ + pp_name, pp_options = pp_info + + utils.message_user( + f"Running preprocessing step: {pp_name} with options {pp_options}" + ) + + last_pp_step_output, __ = utils.get_dict_value_from_step_num( + preprocess_data[ses_name][run_name], step_num=str(int(step_num) - 1) + ) + + new_name = f"{step_num}-" + "-".join(["raw"] + pp_step_names[: int(step_num)]) + + _confidence_check_pp_func_name(pp_name, pp_funcs) + + preprocess_data[ses_name][run_name][new_name] = pp_funcs[pp_name]( + last_pp_step_output, **pp_options + ) + + +# Helpers for preprocessing steps dictionary ------------------------------------------- + + +def _check_and_sort_pp_steps(pp_steps: Dict, pp_funcs: Dict) -> Tuple[Dict, List[str]]: + """ + Sort the preprocessing steps dictionary by order to be run (based on the + keys) and check the dictionary is valid. + + Parameters + ---------- + pp_steps : Dict + A dictionary with keys as numbers indicating the order that + preprocessing steps are run (starting at "1"). The values are a + (preprocessing name, preprocessing kwargs) tuple containing the + spikeinterface preprocessing function name, and kwargs to pass to it. + + pp_funcs : Dict + A dictionary linking preprocessing step names to the underlying + SpikeInterface function objects that conduct the preprocessing. + + Returns + ------- + pp_steps :Dict + The checked preprocessing steps dictionary. + + pp_step_names : List + Preprocessing step names (e.g. "bandpass_filter") in order + that they are to be run. 
+ """ + _validate_pp_steps(pp_steps) + pp_step_names = [item[0] for item in pp_steps.values()] + + # Check the preprocessing function names are valid and print steps used + canonical_step_names = list(pp_funcs.keys()) + + for user_passed_name in pp_step_names: + assert ( + user_passed_name in canonical_step_names + ), f"{user_passed_name} not in allowed names: ({canonical_step_names}" + + utils.message_user( + f"\nThe preprocessing options dictionary is " + f"{json.dumps(pp_steps, indent=4, sort_keys=True)}" + ) + + return pp_steps, pp_step_names + + +def _validate_pp_steps(pp_steps: Dict): + """ + Ensure the pp_steps dictionary of preprocessing steps + has number-order that makes sense. The preprocessing step numbers + should start 1 at, and increase by 1 for each subsequent step. + """ + assert all( + key.isdigit() for key in pp_steps.keys() + ), "pp_steps keys must be integers" + + key_nums = [int(key) for key in pp_steps.keys()] + + assert np.min(key_nums) == 1, "dict keys must start at 1" + + if len(key_nums) > 1: + diffs = np.diff(key_nums) + assert np.unique(diffs).size == 1, "all dict keys must increase in steps of 1" + assert diffs[0] == 1, "all dict keys must increase in steps of 1" + + +def _confidence_check_pp_func_name(pp_name: str, pp_funcs: Dict): + """ + Ensure that the correct preprocessing function is retrieved. This + essentially checks the _get_pp_funcs dictionary is correct. + + TODO + ---- + This should be a standalone test, not incorporated into the package. + """ + func_name_to_class_name = "".join([word.lower() for word in pp_name.split("_")]) + + if pp_name == "silence_periods": + assert pp_funcs[pp_name].__name__ == "SilencedPeriodsRecording" # TODO: open PR + elif isinstance(pp_funcs[pp_name], type): + assert ( + func_name_to_class_name in pp_funcs[pp_name].__name__.lower() + ), "something is wrong in func dict" + + else: + assert pp_funcs[pp_name].__name__ == pp_name + + +def _get_pp_funcs() -> Dict: + """ + Get the spikeinterface preprocessing function + from its name. + + TODO + ----- + It should be possible to generate this on the fly from + SI __init__ rather than hard code like this + """ + pp_funcs = { + "phase_shift": spre.phase_shift, + "bandpass_filter": spre.bandpass_filter, + "common_reference": spre.common_reference, + "blank_saturation": spre.blank_staturation, + "center": spre.center, + "clip": spre.clip, + "correct_lsb": spre.correct_lsb, + "correct_motion": spre.correct_motion, + "depth_order": spre.depth_order, + "filter": spre.filter, + "gaussian_bandpass_filter": spre.gaussian_bandpass_filter, + "highpass_filter": spre.highpass_filter, + "interpolate_bad_channels": spre.interpolate_bad_channels, + "normalize_by_quantile": spre.normalize_by_quantile, + "notch_filter": spre.notch_filter, + "remove_artifacts": spre.remove_artifacts, + # "remove_channels": remove_channels, not sure how to handle at runtime + # "resample": spre.resample, leading to linAlg error + "scale": spre.scale, + "silence_periods": spre.silence_periods, + "whiten": spre.whiten, + "zscore": spre.zscore, + } + + return pp_funcs diff --git a/spikewrap/utils/custom_types.py b/spikewrap/utils/custom_types.py new file mode 100644 index 0000000..964e8b4 --- /dev/null +++ b/spikewrap/utils/custom_types.py @@ -0,0 +1,6 @@ +from typing import Literal, Tuple, Union + +HandleExisting = Literal["overwrite", "skip_if_exists", "fail_if_exists"] +DeleteIntermediate = Tuple[ + Union[Literal["recording.dat"], Literal["temp_wh.dat"], Literal["waveforms"]], ... 
+] From 89ea9e7cae7339426b7389b43de7b9a540f86483 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 16 Nov 2023 22:41:42 +0000 Subject: [PATCH 3/3] Update utils with new functions for preprocessing. --- spikewrap/utils/utils.py | 174 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 171 insertions(+), 3 deletions(-) diff --git a/spikewrap/utils/utils.py b/spikewrap/utils/utils.py index 4ec7e80..8b6bac7 100644 --- a/spikewrap/utils/utils.py +++ b/spikewrap/utils/utils.py @@ -1,14 +1,18 @@ from __future__ import annotations -import copy -import os from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Literal, Tuple, Union import numpy as np import yaml +if TYPE_CHECKING: + from spikeinterface.core import BaseRecording + + from ..data_classes.preprocessing import PreprocessingData + from ..data_classes.sorting import SortingData + def update(dict_, ses_name, run_name, value): try: @@ -16,6 +20,7 @@ def update(dict_, ses_name, run_name, value): except KeyError: dict_[ses_name] = {run_name: value} + def message_user(message: str) -> None: """ Method to interact with user. @@ -26,3 +31,166 @@ def message_user(message: str) -> None: Message to print. """ print(f"\n{message}") + + +def cast_pp_steps_values( + pp_steps: Dict, list_or_tuple: Literal["list", "tuple"] +) -> None: + """ + The settings in the pp_steps dictionary that defines the options + for preprocessing should be stored in Tuple as they are not to + be edited. However, when dumping Tuple to .yaml, there are tags + displayed on the .yaml file which are very ugly. + + These are not shown when storing list, so this function serves + to convert Tuple and List values in the preprocessing dict when + loading / saving the preprocessing dict to .yaml files. This + function converts `pp_steps` in place. + + Parameters + ---------- + pp_steps : Dict + The dictionary indicating the preprocessing steps to perform. + + list_or_tuple : Literal["list", "tuple"] + The direction to convert (i.e. if "tuple", will convert to Tuple). + """ + assert list_or_tuple in ["list", "tuple"], "Must cast to `list` or `tuple`." + func = tuple if list_or_tuple == "tuple" else list + + for key in pp_steps.keys(): + pp_steps[key] = func(pp_steps[key]) + + +def canonical_names(name: str) -> str: + """ + Store the canonical names e.g. filenames, tags + that are used throughout the project. + + Parameters + ---------- + name : str + short-hand name of the full name of interest. + + Returns + ------- + filenames[name] : str + The full name of interest e.g. filename. + + """ + filenames = { + "preprocessed_yaml": "preprocessing_info.yaml", + "sorting_yaml": "sorting_info.yaml", + } + return filenames[name] + + +def get_dict_value_from_step_num( + data: Union[PreprocessingData, SortingData], step_num: str +) -> Tuple[BaseRecording, str]: + """ + Get the value of the PreprocessingData dict from the preprocessing step number. + + PreprocessingData contain keys indicating the preprocessing steps, + starting with the preprocessing step number. + e.g. 0-raw, 1-raw-bandpass_filter, 2-raw_bandpass_filter-common_average + + Return the value of the dict (spikeinterface recording object) + from the dict using only the step number. + + Parameters + ---------- + data : Union[PreprocessingData, SortingData] + spikewrap PreprocessingData class holding filepath information. + + step_num : str + The preprocessing step number to get the value (i.e. 
recording object) + from. + + Returns + ------- + dict_value : BaseRecording + The SpikeInterface recording stored in the dict at the + given preprocessing step number. + + pp_key : str + The key of the preprocessing dict at the given + step number. + """ + if step_num == "last": + pp_key_nums = get_keys_first_char(data, as_int=True) + + # Complete overkill as a check but this is critical. + step_num = str(int(np.max(pp_key_nums))) + assert ( + int(step_num) == len(data.keys()) - 1 + ), "the last key has been taken incorrectly" + + select_step_pp_key = [key for key in data.keys() if key.split("-")[0] == step_num] + + assert len(select_step_pp_key) == 1, "pp_key must always have unique first char" + + pp_key: str = select_step_pp_key[0] + dict_value = data[pp_key] + + return dict_value, pp_key + + +def spikewrap_version(): + """ + If the package is installd with `pip install -e .` then + .__version__ will not work. + """ + try: + import spikewrap + + spikewrap_version = spikewrap.__version__ + except AttributeError: + spikewrap_version = "not found." + + return spikewrap_version + + +def get_keys_first_char( + data: Union[PreprocessingData, SortingData], as_int: bool = False +) -> Union[List[str], List[int]]: + """ + Get the first character of all keys in a dictionary. Expected + that the first characters are integers (as str type). + + Parameters + ---------- + data : Union[PreprocessingData, SortingData] + spikewrap PreprocessingData class holding filepath information. + + as_int : bool + If True, the first character of the keys are cast to + integer type. + + Returns + ------- + list_of_numbers : Union[List[str], List[int]] + A list of numbers of string or integer type, that are + the first numbers of the Preprocessing / Sorting Data + .data dictionary keys. + """ + list_of_numbers = [ + int(key.split("-")[0]) if as_int else key.split("-")[0] for key in data.keys() + ] + return list_of_numbers + + +def get_formatted_datetime() -> str: + return datetime.now().strftime("%Y-%m-%d_%H%M%S") + + +def dump_dict_to_yaml(filepath: Union[Path, str], dict_: Dict) -> None: + """ + Save a dictionary to Yaml file. Note that keys are + not sorted and will be saved in the dictionary order. + """ + with open( + filepath, + "w", + ) as file_to_save: + yaml.dump(dict_, file_to_save, sort_keys=False)
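
A closing note on the `pp_steps` structure used throughout patches 2 and 3: a
config .yaml loads into a dict keyed by step number, each value a
[name, kwargs] pair, and `cast_pp_steps_values()` switches those values
between list and tuple form in place. A small illustrative sketch (the step
choices simply mirror `default.yaml`):

    from spikewrap.utils.utils import cast_pp_steps_values

    pp_steps = {
        "1": ["phase_shift", {}],
        "2": ["bandpass_filter", {"freq_min": 300, "freq_max": 6000}],
        "3": ["common_reference", {"operator": "median", "reference": "global"}],
    }

    cast_pp_steps_values(pp_steps, "tuple")  # values become tuples, as used at runtime
    cast_pp_steps_values(pp_steps, "list")   # back to lists, e.g. before dumping to .yaml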