From 099fed6fb8e14e4dcfe0a13981d10ad063b59b99 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Thu, 14 Sep 2023 15:48:20 -0400 Subject: [PATCH] update for deprecated in1d --- phylib/io/array.py | 2 +- phylib/io/model.py | 4 ++-- phylib/io/tests/test_array.py | 2 +- phylib/stats/tests/test_clusters.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/phylib/io/array.py b/phylib/io/array.py index 4ad9829..8c3fb94 100644 --- a/phylib/io/array.py +++ b/phylib/io/array.py @@ -328,7 +328,7 @@ def _spikes_in_clusters(spike_clusters, clusters): """Return the ids of all spikes belonging to the specified clusters.""" if len(spike_clusters) == 0 or len(clusters) == 0: return np.array([], dtype=int) - return np.nonzero(np.in1d(spike_clusters, clusters))[0] + return np.nonzero(np.isin(spike_clusters, clusters))[0] def _spikes_per_cluster(spike_clusters, spike_ids=None): diff --git a/phylib/io/model.py b/phylib/io/model.py index 778244b..260ea85 100644 --- a/phylib/io/model.py +++ b/phylib/io/model.py @@ -90,8 +90,8 @@ def from_sparse(data, cols, channel_ids): # NOTE: we ensure here that `col` contains integers. c = cols.flatten().astype(np.int32) # Remove columns that do not belong to the specified channels. - c[~np.in1d(c, channel_ids)] = -1 - assert np.all(np.in1d(c, np.r_[channel_ids, -1])) + c[~np.isin(c, channel_ids)] = -1 + assert np.all(np.isin(c, np.r_[channel_ids, -1])) # Convert column indices to relative indices given the specified # channel_ids. cols_loc = _index_of(c, np.r_[channel_ids, -1]).reshape(cols.shape) diff --git a/phylib/io/tests/test_array.py b/phylib/io/tests/test_array.py index 5d50370..594e9fd 100644 --- a/phylib/io/tests/test_array.py +++ b/phylib/io/tests/test_array.py @@ -288,7 +288,7 @@ def test_spikes_in_clusters(): assert np.all(spike_clusters[_spikes_in_clusters(spike_clusters, [i])] == i) clusters = [1, 2, 3] - assert np.all(np.in1d( + assert np.all(np.isin( spike_clusters[_spikes_in_clusters(spike_clusters, clusters)], clusters)) diff --git a/phylib/stats/tests/test_clusters.py b/phylib/stats/tests/test_clusters.py index f54f903..5ed097a 100644 --- a/phylib/stats/tests/test_clusters.py +++ b/phylib/stats/tests/test_clusters.py @@ -103,7 +103,7 @@ def test_sorted_main_channels(masks): mean_masks = mean(masks) channels = get_sorted_main_channels(mean_masks, get_unmasked_channels(mean_masks)) - assert np.all(np.in1d(channels, [5, 7])) + assert np.all(np.isin(channels, [5, 7])) def test_waveform_amplitude(masks, waveforms):