Skip to content

Commit

Permalink
update for deprecated in1d
Browse files Browse the repository at this point in the history
  • Loading branch information
zm711 committed Sep 14, 2023
1 parent 2b567cd commit 099fed6
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion phylib/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def _spikes_in_clusters(spike_clusters, clusters):
"""Return the ids of all spikes belonging to the specified clusters."""
if len(spike_clusters) == 0 or len(clusters) == 0:
return np.array([], dtype=int)
return np.nonzero(np.in1d(spike_clusters, clusters))[0]
return np.nonzero(np.isin(spike_clusters, clusters))[0]


def _spikes_per_cluster(spike_clusters, spike_ids=None):
Expand Down
4 changes: 2 additions & 2 deletions phylib/io/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def from_sparse(data, cols, channel_ids):
# NOTE: we ensure here that `col` contains integers.
c = cols.flatten().astype(np.int32)
# Remove columns that do not belong to the specified channels.
c[~np.in1d(c, channel_ids)] = -1
assert np.all(np.in1d(c, np.r_[channel_ids, -1]))
c[~np.isin(c, channel_ids)] = -1
assert np.all(np.isin(c, np.r_[channel_ids, -1]))
# Convert column indices to relative indices given the specified
# channel_ids.
cols_loc = _index_of(c, np.r_[channel_ids, -1]).reshape(cols.shape)
Expand Down
2 changes: 1 addition & 1 deletion phylib/io/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def test_spikes_in_clusters():
assert np.all(spike_clusters[_spikes_in_clusters(spike_clusters, [i])] == i)

clusters = [1, 2, 3]
assert np.all(np.in1d(
assert np.all(np.isin(
spike_clusters[_spikes_in_clusters(spike_clusters, clusters)], clusters))


Expand Down
2 changes: 1 addition & 1 deletion phylib/stats/tests/test_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_sorted_main_channels(masks):
mean_masks = mean(masks)
channels = get_sorted_main_channels(mean_masks,
get_unmasked_channels(mean_masks))
assert np.all(np.in1d(channels, [5, 7]))
assert np.all(np.isin(channels, [5, 7]))


def test_waveform_amplitude(masks, waveforms):
Expand Down

0 comments on commit 099fed6

Please sign in to comment.