Skip to content

Commit

Permalink
Tidy up, general checks.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Nov 14, 2024
1 parent 02ec55a commit dbcf818
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 59 deletions.
135 changes: 79 additions & 56 deletions src/spikeinterface/working/load_kilosort_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,25 @@

from scipy import stats

# TODO: spike_times -> spike_indexes
# TODO: spike_times -> spike_indices
"""
Notes
-----
- not everything is used for current purposes
- things might be useful in future for making a sorting analyzer - compute template amplitude as average of spike amplitude.
"""

########################################################################################################################
# Get Spike Data
########################################################################################################################


def compute_spike_amplitude_and_depth(
sorter_output: str | Path,
localised_spikes_only,
exclude_noise,
gain: float | None = None,
localised_spikes_channel_cutoff: int = None, # TODO
localised_spikes_channel_cutoff: int = None,
) -> tuple[np.ndarray, ...]:
"""
Compute the amplitude and depth of all detected spikes from the kilosort output.
Expand All @@ -46,8 +50,8 @@ def compute_spike_amplitude_and_depth(
Returns
-------
spike_indexes : np.ndarray
(num_spikes,) array of spike indexes.
spike_indices : np.ndarray
(num_spikes,) array of spike indices.
spike_amplitudes : np.ndarray
(num_spikes,) array of corresponding spike amplitudes.
spike_depths : np.ndarray
Expand All @@ -66,7 +70,7 @@ def compute_spike_amplitude_and_depth(
if isinstance(sorter_output, str):
sorter_output = Path(sorter_output)

params = _load_ks_dir(sorter_output, load_pcs=True, exclude_noise=exclude_noise)
params = load_ks_dir(sorter_output, load_pcs=True, exclude_noise=exclude_noise)

if localised_spikes_only:
localised_templates = []
Expand All @@ -81,10 +85,52 @@ def compute_spike_amplitude_and_depth(

localised_template_by_spike = np.isin(params["spike_templates"], localised_templates)

_strip_spikes(params, localised_template_by_spike)
params["spike_templates"] = params["spike_templates"][localised_template_by_spike]
params["spike_indices"] = params["spike_indices"][localised_template_by_spike]
params["spike_clusters"] = params["spike_clusters"][localised_template_by_spike]
params["temp_scaling_amplitudes"] = params["temp_scaling_amplitudes"][localised_template_by_spike]
params["pc_features"] = params["pc_features"][localised_template_by_spike]

spike_locations, spike_max_sites = _get_locations_from_pc_features(params)

# Amplitude is calculated for each spike as the template amplitude
# multiplied by the `template_scaling_amplitudes`.
template_amplitudes_unscaled, *_ = get_unwhite_template_info(
params["templates"],
params["whitening_matrix_inv"],
params["channel_positions"],
)
spike_amplitudes = template_amplitudes_unscaled[params["spike_templates"]] * params["temp_scaling_amplitudes"]

if gain is not None:
spike_amplitudes *= gain

compute_template_amplitudes_from_spikes(params["templates"], params["spike_templates"], spike_amplitudes)

if localised_spikes_only:
# Interpolate the channel ids to location.
# Remove spikes > 5 um from average position
# Above we already removed non-localized templates, but that on its own is insufficient.
# Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient
# TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere
# 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold.
# 3) just use depth. Probably go for that. check with others.
spike_depths = spike_locations[:, 1]
b = stats.linregress(spike_depths, spike_max_sites).slope
i = np.abs(spike_max_sites - b * spike_depths) <= 5

params["spike_indices"] = params["spike_indices"][i]
spike_amplitudes = spike_amplitudes[i]
spike_locations = spike_locations[i, :]
spike_max_sites = spike_max_sites[i]

return params["spike_indices"], spike_amplitudes, spike_locations, spike_max_sites


def _get_locations_from_pc_features(params):
""" """
# Compute spike depths
pc_features = params["pc_features"][:, 0, :] # Do this compute
pc_features = params["pc_features"][:, 0, :]
pc_features[pc_features < 0] = 0

# Some spikes do not load at all onto the first PC. To avoid biasing the
Expand All @@ -109,58 +155,28 @@ def compute_spike_amplitude_and_depth(
"to extend this code section to handle more components."
)

# Get the channel indexes corresponding to the 32 channels from the PC.
# Get the channel indices corresponding to the 32 channels from the PC.
spike_features_indices = params["pc_features_indices"][params["spike_templates"], :]

# Compute the spike locations as the center of mass of the PC scores
spike_feature_coords = params["channel_positions"][spike_features_indices, :]
norm_weights = pc_features / np.sum(pc_features, axis=1)[:, np.newaxis] # TOOD: see why they use square
norm_weights = (
pc_features / np.sum(pc_features, axis=1)[:, np.newaxis]
) # TOOD: discuss use of square. Probbaly do not use to keep in line with COM in SI.
spike_locations = spike_feature_coords * norm_weights[:, :, np.newaxis]
spike_locations = np.sum(spike_locations, axis=1)

# TODO: now max site per spike is computed from PCs, not as the channel max site as previous
spike_sites = spike_features_indices[np.arange(spike_features_indices.shape[0]), np.argmax(norm_weights, axis=1)]
spike_max_sites = spike_features_indices[
np.arange(spike_features_indices.shape[0]), np.argmax(norm_weights, axis=1)
]

# Amplitude is calculated for each spike as the template amplitude
# multiplied by the `template_scaling_amplitudes`.
template_amplitudes_unscaled, *_ = get_unwhite_template_info(
params["templates"],
params["whitening_matrix_inv"],
params["channel_positions"],
)
spike_amplitudes = template_amplitudes_unscaled[params["spike_templates"]] * params["temp_scaling_amplitudes"]

if gain is not None:
spike_amplitudes *= gain

if localised_spikes_only:
# Interpolate the channel ids to location.
# Remove spikes > 5 um from average position
# Above we already removed non-localized templates, but that on its own is insufficient.
# Note for IMEC probe adding a constant term kills the regression making the regressors rank deficient
# TODO: a couple of approaches. 1) do everything in 3D, draw a sphere around prediction, take spikes only within the sphere
# 2) do separate for x, y. But resolution will be much lower, making things noisier, also harder to determine threshold.
# 3) just use depth. Probably go for that. check with others.
spike_depths = spike_locations[:, 1]
b = stats.linregress(spike_depths, spike_sites).slope
i = np.abs(spike_sites - b * spike_depths) <= 5 # TODO: need to expose this
return spike_locations, spike_max_sites

params["spike_indexes"] = params["spike_indexes"][i]
spike_amplitudes = spike_amplitudes[i]
spike_locations = spike_locations[i, :]

return params["spike_indexes"], spike_amplitudes, spike_locations, spike_sites


def _strip_spikes_in_place(params, indices):
""" """
params["spike_templates"] = params["spike_templates"][
indices
] # TODO: make an function for this. because we do this a lot
params["spike_indexes"] = params["spike_indexes"][indices]
params["spike_clusters"] = params["spike_clusters"][indices]
params["temp_scaling_amplitudes"] = params["temp_scaling_amplitudes"][indices]
params["pc_features"] = params["pc_features"][indices] # TODO: be conciststetn! change indees to indices
########################################################################################################################
# Get Template Data
########################################################################################################################


def get_unwhite_template_info(
Expand Down Expand Up @@ -213,7 +229,7 @@ def get_unwhite_template_info(

template_amplitudes_unscaled = np.max(template_amplitudes_per_channel, axis=1)

# Zero any small channel amplitudes
# Zero any small channel amplitudes TODO: removed this.
# threshold_values = 0.3 * template_amplitudes_unscaled TODO: remove this to be more general. Agree?
# template_amplitudes_per_channel[template_amplitudes_per_channel < threshold_values[:, np.newaxis]] = 0

Expand Down Expand Up @@ -253,9 +269,11 @@ def get_unwhite_template_info(
)


def compute_template_amplitudes_from_spikes():
# Take the average of all spike amplitudes to get actual template amplitudes
# (since tempScalingAmps are equal mean for all templates)
def compute_template_amplitudes_from_spikes(templates, spike_templates, spike_amplitudes):
"""
Take the average of all spike amplitudes to get actual template amplitudes
(since tempScalingAmps are equal mean for all templates)
"""
num_indices = templates.shape[0]
sum_per_index = np.zeros(num_indices, dtype=np.float64)
np.add.at(sum_per_index, spike_templates, spike_amplitudes)
Expand All @@ -264,7 +282,12 @@ def compute_template_amplitudes_from_spikes():
return template_amplitudes


def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool = False) -> dict:
########################################################################################################################
# Load Parameters from KS Directory
########################################################################################################################


def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool = False) -> dict:
"""
Loads the output of Kilosort into a `params` dict.
Expand Down Expand Up @@ -300,7 +323,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool

params = read_python(sorter_output / "params.py")

spike_indexes = np.load(sorter_output / "spike_times.npy")
spike_indices = np.load(sorter_output / "spike_times.npy")
spike_templates = np.load(sorter_output / "spike_templates.npy")

if (clusters_path := sorter_output / "spike_clusters.csv").is_dir():
Expand Down Expand Up @@ -328,7 +351,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
noise_cluster_ids = cluster_ids[cluster_groups == 0]
not_noise_clusters_by_spike = ~np.isin(spike_clusters.ravel(), noise_cluster_ids)

spike_indexes = spike_indexes[not_noise_clusters_by_spike]
spike_indices = spike_indices[not_noise_clusters_by_spike]
spike_templates = spike_templates[not_noise_clusters_by_spike]
temp_scaling_amplitudes = temp_scaling_amplitudes[not_noise_clusters_by_spike]

Expand All @@ -343,7 +366,7 @@ def _load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
cluster_groups = 3 * np.ones(cluster_ids.size)

new_params = {
"spike_indexes": spike_indexes.squeeze(),
"spike_indices": spike_indices.squeeze(),
"spike_templates": spike_templates.squeeze(),
"spike_clusters": spike_clusters.squeeze(),
"pc_features": pc_features,
Expand Down
27 changes: 24 additions & 3 deletions src/spikeinterface/working/plot_kilosort_drift_map.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from pathlib import Path
from spikeinterface.widgets.base import BaseWidget, to_attr
import matplotlib.axis
import scipy.signal
from spikeinterface.core import read_python

# from spikeinterface.core import read_python
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from scipy import stats
import load_kilosort_utils

from spikeinterface.widgets.base import BaseWidget, to_attr


class KilosortDriftMapWidget(BaseWidget):
"""
Expand Down Expand Up @@ -399,5 +401,24 @@ def _filter_large_amplitude_spikes(
return spike_times, spike_amplitudes, spike_depths


KilosortDriftMapWidget(r"D:\data\New folder\CA_528_1\imec0_ks2")
KilosortDriftMapWidget(
"/Users/joeziminski/data/bombcelll/sorter_output",
only_include_large_amplitude_spikes=False,
localised_spikes_only=True,
)
plt.show()

"""
sorter_output: str | Path,
only_include_large_amplitude_spikes: bool = True,
decimate: None | int = None,
add_histogram_plot: bool = False,
add_histogram_peaks_and_boundaries: bool = True,
add_drift_events: bool = True,
weight_histogram_by_amplitude: bool = False,
localised_spikes_only: bool = False,
exclude_noise: bool = False,
gain: float | None = None,
large_amplitude_only_segment_size: float = 800.0,
localised_spikes_channel_cutoff: int = 20,
"""

0 comments on commit dbcf818

Please sign in to comment.