From 709f37371b3f47f412edd0f26f88e8e0f37683f6 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 13 May 2026 18:12:52 +0200 Subject: [PATCH 1/4] implement cache spike vector for phy sorting reader --- .../extractors/phykilosortextractors.py | 57 ++++++++++++++----- 1 file changed, 44 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 0e5dd2694d..804825d359 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -14,6 +14,7 @@ SortingAnalyzer, ) from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.base import minimum_spike_dtype from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations from probeinterface import read_prb, Probe @@ -72,7 +73,7 @@ def __init__( raise ImportError(self.installation_mesg) phy_folder = Path(folder_path) - spike_times = np.load(phy_folder / "spike_times.npy").astype(int) + spike_times = np.load(phy_folder / "spike_times.npy").astype("int64") if (phy_folder / "spike_clusters.npy").is_file(): spike_clusters = np.load(phy_folder / "spike_clusters.npy") @@ -83,8 +84,8 @@ def __init__( spike_times = np.atleast_1d(spike_times.squeeze()) spike_clusters = np.atleast_1d(spike_clusters.squeeze()) - clust_id = np.unique(spike_clusters) - unique_unit_ids = [int(c) for c in clust_id] + unique_unit_ids = np.unique(spike_clusters).astype("int64") + params = read_python(str(phy_folder / "params.py")) sampling_frequency = params["sample_rate"] @@ -151,10 +152,15 @@ def __init__( cluster_info = cluster_info.query(f"cluster_id in {unique_unit_ids}") # update spike clusters and times values - bad_clusters = [clust for clust in clust_id if clust not in cluster_info["cluster_id"].values] - spike_clusters_clean_idxs = ~np.isin(spike_clusters, bad_clusters) - spike_clusters_clean = 
spike_clusters[spike_clusters_clean_idxs] - spike_times_clean = spike_times[spike_clusters_clean_idxs] + bad_clusters = [clust for clust in unique_unit_ids if clust not in cluster_info["cluster_id"].values] + if len(bad_clusters) > 0: + # if no bad cluster we avoid this data reduction which costs a lot for long datasets + spike_clusters_clean_idxs = ~np.isin(spike_clusters, bad_clusters) + spike_clusters_clean = spike_clusters[spike_clusters_clean_idxs] + spike_times_clean = spike_times[spike_clusters_clean_idxs] + else: + spike_clusters_clean = spike_clusters + spike_times_clean = spike_times if "si_unit_id" in cluster_info.columns: unit_ids = cluster_info["si_unit_id"].values @@ -180,7 +186,7 @@ def __init__( idx = np.searchsorted(from_values, spike_clusters_clean, sorter=sort_idx) spike_clusters_new = unit_ids[sort_idx][idx] - unit_ids = unit_ids.astype(int) + unit_ids = unit_ids.astype("int64") spike_clusters_clean = spike_clusters_new del cluster_info["si_unit_id"] else: @@ -223,21 +229,46 @@ def __init__( self.annotate(phy_folder=str(phy_folder.resolve())) self.add_sorting_segment(PhySortingSegment(spike_times_clean, spike_clusters_clean)) + + def _compute_and_cache_spike_vector(self) -> None: + # make the spike_vector fast using the internal spike_times/spike_clusters + # with a small mapping id to index + # the order for 2 units with the same sample_index is not garanty here but should be OK + + unit_ids = self.unit_ids + + # mapping unit_id to unit_index + mapping = -np.ones(np.max(unit_ids) +1, dtype="int64") + for unit_ind, unit_id in enumerate(unit_ids): + mapping[unit_id] = unit_ind + + spike_times = self.segments[0]._all_spike_times + spike_clusters = self.segments[0]._all_clusters + n = spike_times.size + spikes = np.zeros(n, dtype=minimum_spike_dtype) + spikes["sample_index"] = spike_times + spikes["unit_index"] = mapping[spike_clusters] + # This is useless because phy is always one segment + # spikes["segment_index"] = 0 + + self._cached_spike_vector 
= spikes + self._cached_spike_vector_segment_slices = np.zeros((1, 2), dtype="int64") + self._cached_spike_vector_segment_slices[0, 1] = n class PhySortingSegment(BaseSortingSegment): - def __init__(self, all_spikes, all_clusters): + def __init__(self, all_spike_times, all_clusters): BaseSortingSegment.__init__(self) - self._all_spikes = all_spikes + self._all_spike_times = all_spike_times self._all_clusters = all_clusters def get_unit_spike_train(self, unit_id, start_frame, end_frame): - start = 0 if start_frame is None else np.searchsorted(self._all_spikes, start_frame, side="left") + start = 0 if start_frame is None else np.searchsorted(self._all_spike_times, start_frame, side="left") end = ( - len(self._all_spikes) if end_frame is None else np.searchsorted(self._all_spikes, end_frame, side="left") + len(self._all_spike_times) if end_frame is None else np.searchsorted(self._all_spike_times, end_frame, side="left") ) # Exclude end frame - spike_times = self._all_spikes[start:end][self._all_clusters[start:end] == unit_id] + spike_times = self._all_spike_times[start:end][self._all_clusters[start:end] == unit_id] return np.atleast_1d(spike_times.copy().squeeze()) From 834ac84d7f5fa608b567baa64389354b5635156e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 May 2026 16:19:03 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../extractors/phykilosortextractors.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 804825d359..3b06525311 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -229,19 +229,19 @@ def __init__( self.annotate(phy_folder=str(phy_folder.resolve())) 
self.add_sorting_segment(PhySortingSegment(spike_times_clean, spike_clusters_clean)) - + def _compute_and_cache_spike_vector(self) -> None: # make the spike_vector fast using the internal spike_times/spike_clusters # with a small mapping id to index # the order for 2 units with the same sample_index is not garanty here but should be OK - + unit_ids = self.unit_ids # mapping unit_id to unit_index - mapping = -np.ones(np.max(unit_ids) +1, dtype="int64") + mapping = -np.ones(np.max(unit_ids) + 1, dtype="int64") for unit_ind, unit_id in enumerate(unit_ids): mapping[unit_id] = unit_ind - + spike_times = self.segments[0]._all_spike_times spike_clusters = self.segments[0]._all_clusters n = spike_times.size @@ -265,7 +265,9 @@ def __init__(self, all_spike_times, all_clusters): def get_unit_spike_train(self, unit_id, start_frame, end_frame): start = 0 if start_frame is None else np.searchsorted(self._all_spike_times, start_frame, side="left") end = ( - len(self._all_spike_times) if end_frame is None else np.searchsorted(self._all_spike_times, end_frame, side="left") + len(self._all_spike_times) + if end_frame is None + else np.searchsorted(self._all_spike_times, end_frame, side="left") ) # Exclude end frame spike_times = self._all_spike_times[start:end][self._all_clusters[start:end] == unit_id] From 124893af3303489304002d9db3898987a772f2ee Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 14 May 2026 12:45:54 +0200 Subject: [PATCH 3/4] fic: kilosort tests --- src/spikeinterface/extractors/phykilosortextractors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 3b06525311..e65e430a79 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -149,7 +149,7 @@ def __init__( del cluster_info["id"] if remove_empty_units: - cluster_info = 
cluster_info.query(f"cluster_id in {unique_unit_ids}") + cluster_info = cluster_info.query(f"cluster_id in {list(unique_unit_ids)}") # update spike clusters and times values bad_clusters = [clust for clust in unique_unit_ids if clust not in cluster_info["cluster_id"].values] From 60d032c2b69e165cb9aba46ab1b0ff42104e7931 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 14 May 2026 14:31:45 +0200 Subject: [PATCH 4/4] fix attempt 2 --- src/spikeinterface/extractors/phykilosortextractors.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index e65e430a79..2b43aff8f9 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -149,7 +149,8 @@ def __init__( del cluster_info["id"] if remove_empty_units: - cluster_info = cluster_info.query(f"cluster_id in {list(unique_unit_ids)}") + unique_unit_ids_list = [int(clust) for clust in unique_unit_ids] + cluster_info = cluster_info.query(f"cluster_id in {unique_unit_ids_list}") # update spike clusters and times values bad_clusters = [clust for clust in unique_unit_ids if clust not in cluster_info["cluster_id"].values]