From 152ee9bfb8e36ec66fdbd3caf65209af6c614398 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Thu, 16 Nov 2023 22:41:42 +0000 Subject: [PATCH] Update utils with new functions for preprocessing. --- spikewrap/__init__.py | 24 ++++++ spikewrap/utils/utils.py | 174 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 195 insertions(+), 3 deletions(-) create mode 100644 spikewrap/__init__.py diff --git a/spikewrap/__init__.py b/spikewrap/__init__.py new file mode 100644 index 0000000..781354f --- /dev/null +++ b/spikewrap/__init__.py @@ -0,0 +1,24 @@ +from importlib.metadata import PackageNotFoundError, version + +try: + __version__ = version("swc_epys") +except PackageNotFoundError: + # package is not installed + pass + +from .pipeline.full_pipeline import run_full_pipeline +from .pipeline.preprocess import _preprocess_and_save_all_runs +from .pipeline.sort import run_sorting +from .pipeline.postprocess import run_postprocess + +from .utils.checks import check_environment +from .utils.slurm import run_interactive_slurm + +__all__ = [ + "run_full_pipeline", + "_preprocess_and_save_all_runs", + "run_sorting", + "run_postprocess", + "check_environment", + "run_interactive_slurm", +] diff --git a/spikewrap/utils/utils.py b/spikewrap/utils/utils.py index 4ec7e80..8b6bac7 100644 --- a/spikewrap/utils/utils.py +++ b/spikewrap/utils/utils.py @@ -1,14 +1,18 @@ from __future__ import annotations -import copy -import os from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Literal, Tuple, Union import numpy as np import yaml +if TYPE_CHECKING: + from spikeinterface.core import BaseRecording + + from ..data_classes.preprocessing import PreprocessingData + from ..data_classes.sorting import SortingData + def update(dict_, ses_name, run_name, value): try: @@ -16,6 +20,7 @@ def update(dict_, ses_name, run_name, value): except KeyError: dict_[ses_name] = {run_name: value} + def message_user(message: str) -> None: """ Method to interact with user. @@ -26,3 +31,166 @@ def message_user(message: str) -> None: Message to print. """ print(f"\n{message}") + + +def cast_pp_steps_values( + pp_steps: Dict, list_or_tuple: Literal["list", "tuple"] +) -> None: + """ + The settings in the pp_steps dictionary that defines the options + for preprocessing should be stored in Tuple as they are not to + be edited. However, when dumping Tuple to .yaml, there are tags + displayed on the .yaml file which are very ugly. + + These are not shown when storing list, so this function serves + to convert Tuple and List values in the preprocessing dict when + loading / saving the preprocessing dict to .yaml files. This + function converts `pp_steps` in place. + + Parameters + ---------- + pp_steps : Dict + The dictionary indicating the preprocessing steps to perform. + + list_or_tuple : Literal["list", "tuple"] + The direction to convert (i.e. if "tuple", will convert to Tuple). + """ + assert list_or_tuple in ["list", "tuple"], "Must cast to `list` or `tuple`." + func = tuple if list_or_tuple == "tuple" else list + + for key in pp_steps.keys(): + pp_steps[key] = func(pp_steps[key]) + + +def canonical_names(name: str) -> str: + """ + Store the canonical names e.g. filenames, tags + that are used throughout the project. + + Parameters + ---------- + name : str + short-hand name of the full name of interest. + + Returns + ------- + filenames[name] : str + The full name of interest e.g. filename. + + """ + filenames = { + "preprocessed_yaml": "preprocessing_info.yaml", + "sorting_yaml": "sorting_info.yaml", + } + return filenames[name] + + +def get_dict_value_from_step_num( + data: Union[PreprocessingData, SortingData], step_num: str +) -> Tuple[BaseRecording, str]: + """ + Get the value of the PreprocessingData dict from the preprocessing step number. + + PreprocessingData contain keys indicating the preprocessing steps, + starting with the preprocessing step number. + e.g. 0-raw, 1-raw-bandpass_filter, 2-raw_bandpass_filter-common_average + + Return the value of the dict (spikeinterface recording object) + from the dict using only the step number. + + Parameters + ---------- + data : Union[PreprocessingData, SortingData] + spikewrap PreprocessingData class holding filepath information. + + step_num : str + The preprocessing step number to get the value (i.e. recording object) + from. + + Returns + ------- + dict_value : BaseRecording + The SpikeInterface recording stored in the dict at the + given preprocessing step number. + + pp_key : str + The key of the preprocessing dict at the given + step number. + """ + if step_num == "last": + pp_key_nums = get_keys_first_char(data, as_int=True) + + # Complete overkill as a check but this is critical. + step_num = str(int(np.max(pp_key_nums))) + assert ( + int(step_num) == len(data.keys()) - 1 + ), "the last key has been taken incorrectly" + + select_step_pp_key = [key for key in data.keys() if key.split("-")[0] == step_num] + + assert len(select_step_pp_key) == 1, "pp_key must always have unique first char" + + pp_key: str = select_step_pp_key[0] + dict_value = data[pp_key] + + return dict_value, pp_key + + +def spikewrap_version(): + """ + If the package is installd with `pip install -e .` then + .__version__ will not work. + """ + try: + import spikewrap + + spikewrap_version = spikewrap.__version__ + except AttributeError: + spikewrap_version = "not found." + + return spikewrap_version + + +def get_keys_first_char( + data: Union[PreprocessingData, SortingData], as_int: bool = False +) -> Union[List[str], List[int]]: + """ + Get the first character of all keys in a dictionary. Expected + that the first characters are integers (as str type). + + Parameters + ---------- + data : Union[PreprocessingData, SortingData] + spikewrap PreprocessingData class holding filepath information. + + as_int : bool + If True, the first character of the keys are cast to + integer type. + + Returns + ------- + list_of_numbers : Union[List[str], List[int]] + A list of numbers of string or integer type, that are + the first numbers of the Preprocessing / Sorting Data + .data dictionary keys. + """ + list_of_numbers = [ + int(key.split("-")[0]) if as_int else key.split("-")[0] for key in data.keys() + ] + return list_of_numbers + + +def get_formatted_datetime() -> str: + return datetime.now().strftime("%Y-%m-%d_%H%M%S") + + +def dump_dict_to_yaml(filepath: Union[Path, str], dict_: Dict) -> None: + """ + Save a dictionary to Yaml file. Note that keys are + not sorted and will be saved in the dictionary order. + """ + with open( + filepath, + "w", + ) as file_to_save: + yaml.dump(dict_, file_to_save, sort_keys=False)