From adb6a1087574769f42e1d46c3a2cac97d7c5414e Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 18 May 2026 16:31:24 -0400 Subject: [PATCH 01/18] dev: Error message --- src/dartsort/util/spiketorch.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/dartsort/util/spiketorch.py b/src/dartsort/util/spiketorch.py index bd6bab9d..db8b38fd 100644 --- a/src/dartsort/util/spiketorch.py +++ b/src/dartsort/util/spiketorch.py @@ -1475,7 +1475,11 @@ def get_relative_index(source_channel_index, target_channel_index): """ n_chans, n_source_chans = source_channel_index.shape n_chans_, n_target_chans = target_channel_index.shape - assert n_chans == n_chans_ + if n_chans != n_chans_: + raise ValueError( + f"source/target shapes mismatch: {source_channel_index.shape=} " + f"{target_channel_index.shape=}." + ) relative_index = torch.full_like(target_channel_index, n_source_chans) for c in range(n_chans): row = source_channel_index[c] From e4700142180ac9c646e1a2a1428854d498fc3875 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 18 May 2026 16:31:50 -0400 Subject: [PATCH 02/18] dev: move import --- src/dartsort/clustering/kmeans.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dartsort/clustering/kmeans.py b/src/dartsort/clustering/kmeans.py index ec688c33..b9fcfcce 100644 --- a/src/dartsort/clustering/kmeans.py +++ b/src/dartsort/clustering/kmeans.py @@ -22,7 +22,6 @@ sparse_centroid_distsq, ) from ..util.spiketorch import spawn_torch_rg -from .density import guess_mode logger = get_logger(__name__) @@ -58,6 +57,8 @@ def kmeanspp( closest = torch.cdist(X, X.mean(0, keepdim=True)).argmax() centroid_ixs[0] = closest.item() elif kmeanspp_initial == "mode": + from .density import guess_mode + Xm = X if Xm.shape[1] > mode_dim: q = min(mode_dim + 10, *Xm.shape) From c930c13dcc3067df253d998d3e247cae435db109 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 18 May 2026 16:32:06 -0400 Subject: [PATCH 03/18] main: more docstring --- src/dartsort/main.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/dartsort/main.py b/src/dartsort/main.py index 6888a1fa..43597c5c 100644 --- a/src/dartsort/main.py +++ b/src/dartsort/main.py @@ -86,21 +86,25 @@ def dartsort( dredge_motion_est: MotionEstimate | None = None, overwrite=False, ): - """This function runs a spike sorter called dartsort. + """This function runs a spike sorter called *dartsort*. Parameters --------- recording : BaseRecording - A SpikeInterface `BaseRecording` object + A SpikeInterface `BaseRecording` output_dir : str or Path - Folder where outputs are stored + Folder where outputs are stored. See the `work_in_tmpdir` and `tmpdir_parent` + configuration options to store intermediate data in a scratch folder and + then only save the final outputs here. cfg : DARTsortUserConfig or DARTsortInternalConfig or str or Path Your settings. Either create a `DARTsortUserConfig` directly in code, or you can pass a string or Path pointing to a .toml file here. si_motion : spikeinterface.core.Motion, optional - Allows users to pass their own external motion estimate. + Allows users to pass their own external motion estimate. If this is given, + the do_motion_estimation configuration flag is ignored and this object is + used. dredge_motion_est : dredge.MotionEstimate, optional - Allows users to pass their own external motion estimate. + As in `si_motion`. overwrite : bool Ignore and overwrite stored results, if any. Otherwise, dartsort will try to resume from the last step that ran, or if it had finished then From 8eb435db04bc2ae47825bb62b42ad36dbfc6068e Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 18 May 2026 16:46:29 -0400 Subject: [PATCH 04/18] main: rename return type --- docs/main_api.md | 2 +- src/dartsort/__init__.py | 2 +- src/dartsort/main.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/main_api.md b/docs/main_api.md index 2bc88f9a..0ad6b90c 100644 --- a/docs/main_api.md +++ b/docs/main_api.md @@ -17,7 +17,7 @@ For details on parameters you should think about before running the sorter, see show_signature: true separate_signature: true -The return value from the `dartsort()` function is a DARTsortReturn object, which is a dictionary containing spike trains and motion information: +The return value from the `dartsort()` function is a `DARTsortResult` object, which is a dictionary containing spike trains and motion information: ::: dartsort.DARTsortReturn options: diff --git a/src/dartsort/__init__.py b/src/dartsort/__init__.py index d812f8c9..a479adb1 100644 --- a/src/dartsort/__init__.py +++ b/src/dartsort/__init__.py @@ -24,7 +24,7 @@ from .main import ( ObjectiveUpdateTemplateMatchingPeeler, SubtractionPeeler, - DARTsortReturn, + DARTsortResult, check_recording, cluster, dartsort, diff --git a/src/dartsort/main.py b/src/dartsort/main.py index 43597c5c..7503a0a4 100644 --- a/src/dartsort/main.py +++ b/src/dartsort/main.py @@ -68,7 +68,7 @@ logger = get_logger(__name__) -class DARTsortReturn(TypedDict): +class DARTsortResult(TypedDict): sorting: DARTsortSorting """Output spike trains.""" motion: MotionInfo @@ -112,7 +112,7 @@ def dartsort( Returns ------- - results : DARTsortReturn + results : DARTsortResult Dictionary of sorting results, with keys: - "sorting": `DARTsortSorting` From c2b99a09487c5f504efb68851f422bc34211826e Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 19 May 2026 11:53:37 -0400 Subject: [PATCH 05/18] dev: resolve_path -> ensure_path --- src/dartsort/__init__.py | 2 +- src/dartsort/evaluate/hybrid_util.py | 10 +++--- src/dartsort/evaluate/simkit.py | 18 ++++++---- src/dartsort/main.py | 16 ++++----- src/dartsort/peel/matching_util/pairwise.py | 4 +-- src/dartsort/peel/reduction_template.py | 4 +-- src/dartsort/templates/postprocess_util.py | 4 +-- src/dartsort/transform/pipeline.py | 6 ++-- .../transform/single_channel_denoiser.py | 4 +-- src/dartsort/util/data_util.py | 35 +++++++++++-------- src/dartsort/util/internal_config.py | 4 +-- src/dartsort/util/main_util.py | 16 ++++----- src/dartsort/util/motion.py | 6 ++-- src/dartsort/util/peel_util.py | 4 +-- src/dartsort/util/py_util.py | 16 ++++++--- src/dartsort/vis/colors.py | 4 +-- src/dartsort/vis/mixture.py | 6 ++-- tests/conftest.py | 12 +++---- tests/test_alignment.py | 2 +- tests/test_matching.py | 2 +- 20 files changed, 97 insertions(+), 78 deletions(-) diff --git a/src/dartsort/__init__.py b/src/dartsort/__init__.py index a479adb1..229269c0 100644 --- a/src/dartsort/__init__.py +++ b/src/dartsort/__init__.py @@ -78,7 +78,7 @@ from .util.motion import MotionInfo, get_motion_info, try_load_motion_info from .util.noise_util import EmbeddedNoise from .util.preprocess_util import preprocess -from .util.py_util import databag, resolve_path +from .util.py_util import databag, ensure_path from .util.waveform_util import full_channel_index, make_channel_index __version__ = importlib.metadata.version("dartsort") diff --git a/src/dartsort/evaluate/hybrid_util.py b/src/dartsort/evaluate/hybrid_util.py index 36831396..ed9206c1 100644 --- a/src/dartsort/evaluate/hybrid_util.py +++ b/src/dartsort/evaluate/hybrid_util.py @@ -23,7 +23,7 @@ from ..util.internal_config import ComputationConfig, unshifted_raw_template_cfg from ..util.logging_util import progbar from ..util.motion import MotionInfo -from ..util.py_util import resolve_path +from ..util.py_util import ensure_path from . import analysis, comparison, simkit logger = getLogger(__name__) @@ -406,7 +406,7 @@ def load_dartsort_step_sortings( use, although its not a guarantee... h5 locking... need to figure it out. """ mtime_dt = mtime_gap_minutes * 60 if mtime_gap_minutes else 0 - sorting_dir = resolve_path(sorting_dir, strict=True) + sorting_dir = ensure_path(sorting_dir, strict=True) if detection_h5_path is None: for dh5n in detection_h5_names: detection_h5_path = cast(Path, sorting_dir / dh5n) @@ -416,12 +416,12 @@ def load_dartsort_step_sortings( age = time.time() - detection_h5_path.stat().st_mtime if age < mtime_dt: continue - h5s = [resolve_path(detection_h5_path)] + h5s = [ensure_path(detection_h5_path)] break else: h5s = [] else: - h5s = [resolve_path(detection_h5_path)] + h5s = [ensure_path(detection_h5_path)] for j in range(1, 100): mh5 = sorting_dir / f"matching{j}.h5" @@ -430,7 +430,7 @@ def load_dartsort_step_sortings( if mtime_dt: if time.time() - mh5.stat().st_mtime < mtime_dt: break - h5s.append(resolve_path(mh5)) + h5s.append(ensure_path(mh5)) # let's check that there is at least something to do... labels_npys = sorting_dir.glob("*_labels.npy") diff --git a/src/dartsort/evaluate/simkit.py b/src/dartsort/evaluate/simkit.py index ec6a1211..36a14f58 100644 --- a/src/dartsort/evaluate/simkit.py +++ b/src/dartsort/evaluate/simkit.py @@ -20,8 +20,8 @@ from ..util.data_util import ( DARTsortSorting, divide_randomly, + ensure_path, extract_random_snips, - resolve_path, ) from ..util.job_util import ensure_computation_config from ..util.logging_util import get_logger, progbar @@ -108,7 +108,7 @@ def generate_simulation( pass if noise_recording_folder is not None: - noise_recording_folder = resolve_path(noise_recording_folder) + noise_recording_folder = ensure_path(noise_recording_folder) else: assert noise_in_memory duration_samples = int(duration_seconds * sampling_frequency) @@ -146,7 +146,7 @@ def generate_simulation( return if folder is not None: - folder = resolve_path(folder) + folder = ensure_path(folder) else: assert no_save @@ -214,7 +214,7 @@ def generate_simulation( def load_simulation(folder): - folder = resolve_path(folder, strict=True) + folder = ensure_path(folder, strict=True) recording_dir = folder / "recording" templates_npz = folder / "templates.npz" sorting_h5 = folder / "dartsort_sorting.h5" @@ -252,7 +252,9 @@ def __init__(self, recording, **simulation_kwargs): **simulation_kwargs, ) ) - self.segment = cast(InjectSpikesPreprocessorSegment, self._recording_segments[0]) + self.segment = cast( + InjectSpikesPreprocessorSegment, self._recording_segments[0] + ) def basic_sorting(self) -> DARTsortSorting: return self.segment.basic_sorting() @@ -406,7 +408,9 @@ def save_features_to_hdf5( # residual snippets if n_residual_snips: - nrs_dset = h5.create_dataset("n_residuals", data=np.zeros((), dtype=np.int64)) + nrs_dset = h5.create_dataset( + "n_residuals", data=np.zeros((), dtype=np.int64) + ) residual = h5.create_dataset( "residual", shape=(n_residual_snips, *self.segment.wf_shape), @@ -486,7 +490,7 @@ def save_simulation( save_collidedness=False, chunk_len_s=0.5, ): - folder = resolve_path(folder) + folder = ensure_path(folder) folder.mkdir(exist_ok=True) recording_dir = folder / "recording" templates_npz = folder / "templates.npz" diff --git a/src/dartsort/main.py b/src/dartsort/main.py index 7503a0a4..852d9dd2 100644 --- a/src/dartsort/main.py +++ b/src/dartsort/main.py @@ -62,7 +62,7 @@ from .util.noise_util import SpatialWhitener from .util.peel_util import run_peeler from .util.preprocess_util import preprocess -from .util.py_util import dartcopytree, resolve_path, timer +from .util.py_util import dartcopytree, ensure_path, timer from .util.torch_util import cleanup_and_log_gpu_usage logger = get_logger(__name__) @@ -118,7 +118,7 @@ def dartsort( - "sorting": `DARTsortSorting` - "motion": MotionInfo """ - output_dir = resolve_path(output_dir) + output_dir = ensure_path(output_dir) output_dir.mkdir(exist_ok=True) # convert cfg to internal format and store it for posterity @@ -135,7 +135,7 @@ def dartsort( if cfg.work_in_tmpdir: with TemporaryDirectory(prefix="dartsort", dir=cfg.tmpdir_parent) as work_dir: # copy files and possibly recording to temporary directory - work_dir = resolve_path(work_dir) + work_dir = ensure_path(work_dir) logger.dartsortdebug(f"Working in {work_dir}, outputs to {output_dir}.") recording, work_dir = ds_all_to_workdir( internal_cfg=cfg, @@ -397,7 +397,7 @@ def _dartsort_impl( # finally handle scratch directory and delete intermediate files if requested if work_dir is not None: - orig_h5_path = resolve_path(sorting.parent_h5_path, strict=True) + orig_h5_path = ensure_path(sorting.parent_h5_path, strict=True) final_h5_path = output_dir / orig_h5_path.name assert final_h5_path.exists() sorting.parent_h5_path = final_h5_path @@ -490,7 +490,7 @@ def subtract( hdf5_filename="subtraction.h5", model_subdir="subtraction_models", ) -> DARTsortSorting | None: - output_dir = resolve_path(output_dir) + output_dir = ensure_path(output_dir) computation_cfg = ensure_computation_config(computation_cfg) check_recording(recording) subtraction_peeler = SubtractionPeeler.from_config( @@ -551,7 +551,7 @@ def match( template_denoising_tsvd=None, whitener: SpatialWhitener | None = None, ) -> DARTsortSorting: - output_dir = resolve_path(output_dir) + output_dir = ensure_path(output_dir) model_dir = output_dir / model_subdir computation_cfg = ensure_computation_config(computation_cfg) @@ -639,7 +639,7 @@ def grab( model_subdir="grab_models", computation_cfg: ComputationConfig | None = None, ) -> DARTsortSorting: - output_dir = resolve_path(output_dir) + output_dir = ensure_path(output_dir) grabber = GrabAndFeaturize.from_config( sorting=sorting, recording=recording, @@ -679,7 +679,7 @@ def threshold( model_subdir="threshold_models", computation_cfg: ComputationConfig | None = None, ) -> DARTsortSorting: - output_dir = resolve_path(output_dir) + output_dir = ensure_path(output_dir) computation_cfg = ensure_computation_config(computation_cfg) thresholder = Threshold.from_config( recording=recording, diff --git a/src/dartsort/peel/matching_util/pairwise.py b/src/dartsort/peel/matching_util/pairwise.py index 94f7897a..773b4a1d 100644 --- a/src/dartsort/peel/matching_util/pairwise.py +++ b/src/dartsort/peel/matching_util/pairwise.py @@ -10,7 +10,7 @@ from ...templates.templates import TemplateData from ...util import job_util from ...util.motion import MotionInfo -from ...util.py_util import resolve_path +from ...util.py_util import ensure_path from .matching_base import PconvBase from .pairwise_util import compressed_convolve_to_h5 @@ -149,7 +149,7 @@ def from_template_data( if computation_cfg is None: computation_cfg = job_util.get_global_computation_config() - hdf5_filename = resolve_path(hdf5_filename) + hdf5_filename = ensure_path(hdf5_filename) hdf5_filename.parent.mkdir(exist_ok=True) # TODO: rewrite. diff --git a/src/dartsort/peel/reduction_template.py b/src/dartsort/peel/reduction_template.py index f0368181..1cbc8a72 100644 --- a/src/dartsort/peel/reduction_template.py +++ b/src/dartsort/peel/reduction_template.py @@ -32,7 +32,7 @@ from ..util.logging_util import get_logger from ..util.motion import MotionInfo from ..util.noise_util import SpatialWhitener -from ..util.py_util import resolve_path +from ..util.py_util import ensure_path from ..util.waveform_util import full_channel_index from .grab import GrabAndFeaturize @@ -95,7 +95,7 @@ def _from_config( ignore_cleanup_errors=True, dir=computation_cfg.tmpdir_parent, ) as tdir: - tdir = resolve_path(tdir) + tdir = ensure_path(tdir) h5p = tdir / "tmp.h5" p.load_or_fit_and_save_models(tdir / "models") if template_cfg.denoising_method == "none" and not template_cfg.use_svd: diff --git a/src/dartsort/templates/postprocess_util.py b/src/dartsort/templates/postprocess_util.py index fdab53b7..4e73f3bc 100644 --- a/src/dartsort/templates/postprocess_util.py +++ b/src/dartsort/templates/postprocess_util.py @@ -26,7 +26,7 @@ from ..util.logging_util import get_logger from ..util.motion import MotionInfo from ..util.noise_util import SpatialWhitener -from ..util.py_util import resolve_path +from ..util.py_util import ensure_path from ..util.spiketorch import ptp from . import TemplateData, realign from .templib import fit_tsvd, pca_from_templates, quick_mean_templates @@ -59,7 +59,7 @@ def estimate_template_library( ) -> tuple[DARTsortSorting, TemplateData]: """Postprocess spike train and estimate a TemplateData.""" if template_npz_path is not None: - template_npz_path = resolve_path(template_npz_path) + template_npz_path = ensure_path(template_npz_path) if template_npz_path.exists(): return sorting, TemplateData.from_npz(template_npz_path) diff --git a/src/dartsort/transform/pipeline.py b/src/dartsort/transform/pipeline.py index 750690cd..d28602d0 100644 --- a/src/dartsort/transform/pipeline.py +++ b/src/dartsort/transform/pipeline.py @@ -13,7 +13,7 @@ WaveformConfig, ) from ..util.logging_util import get_logger -from ..util.py_util import resolve_path +from ..util.py_util import ensure_path from ..util.waveform_util import assert_all_finite_in_probe from .transform_base import BaseWaveformFeaturizer, BaseWaveformModule @@ -297,7 +297,7 @@ def fit( del waveforms if hdf5_filename is not None: - hdf5_filename = resolve_path(hdf5_filename, strict=True) + hdf5_filename = ensure_path(hdf5_filename, strict=True) if self.safe: assert torch.is_tensor(features["waveforms"]) @@ -377,7 +377,7 @@ def transform_to_disk( if up_to_index == 0: return - hdf5_filename = resolve_path(hdf5_filename, strict=True) + hdf5_filename = ensure_path(hdf5_filename, strict=True) dev = self.device # use sorting as a way to load all 1d features, which transformers diff --git a/src/dartsort/transform/single_channel_denoiser.py b/src/dartsort/transform/single_channel_denoiser.py index 666f7410..6ee95663 100644 --- a/src/dartsort/transform/single_channel_denoiser.py +++ b/src/dartsort/transform/single_channel_denoiser.py @@ -1,7 +1,7 @@ import torch from torch import nn -from ..util.py_util import resolve_path +from ..util.py_util import ensure_path from ..util.waveform_util import get_channels_in_probe, set_channels_in_probe from .transform_base import BaseWaveformDenoiser @@ -108,7 +108,7 @@ def forward(self, x): return self.out(x) def load(self, pretrained_path=default_pretrained_path): - pretrained_path = resolve_path(pretrained_path) + pretrained_path = ensure_path(pretrained_path) checkpoint = torch.load(pretrained_path, map_location="cpu", weights_only=True) self.load_state_dict(checkpoint) self.eval() diff --git a/src/dartsort/util/data_util.py b/src/dartsort/util/data_util.py index 6c4002d7..9d1b32e5 100644 --- a/src/dartsort/util/data_util.py +++ b/src/dartsort/util/data_util.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: from .motion import MotionInfo from .job_util import ensure_computation_config -from .py_util import resolve_path +from .py_util import ensure_path from .waveform_util import make_channel_index logger = get_logger(__name__) @@ -63,7 +63,7 @@ def __init__( """ self.n_spikes = times_samples.shape[0] if parent_h5_path is not None: - parent_h5_path = resolve_path(parent_h5_path) + parent_h5_path = ensure_path(parent_h5_path) self.parent_h5_path = parent_h5_path self.sampling_frequency = float(sampling_frequency) @@ -362,7 +362,7 @@ def from_peeling_hdf5( or multi-channel PCA features. load_all_features : bool """ - h5_path = resolve_path(h5_path, strict=True) + h5_path = ensure_path(h5_path, strict=True) if load_feature_names is None: _lfn = [] else: @@ -435,7 +435,7 @@ def save(self, sorting_npz: str | Path): This is done by saving to .npz, with a pointer (like a relative symlink) to the .h5 file if it exists. """ - sorting_npz = resolve_path(sorting_npz) + sorting_npz = ensure_path(sorting_npz) logger.dartsortdebug(f"Saving {self} to {sorting_npz}.") data = dict( times_samples=self.times_samples, @@ -448,7 +448,7 @@ def save(self, sorting_npz: str | Path): have_hdf5 = self.parent_h5_path is not None if have_hdf5: # path needs to be relative to npz path's parent in case user moves stuff - h5p = resolve_path(self.parent_h5_path, strict=True) + h5p = ensure_path(self.parent_h5_path, strict=True) try: h5p = h5p.relative_to(sorting_npz.parent, walk_up=True) # type: ignore except TypeError: @@ -478,7 +478,7 @@ def load( load_persistent_feature_names=None, ) -> Self: """Load from npz (usually dartsort_sorting.npz).""" - sorting_npz = resolve_path(sorting_npz, strict=True) + sorting_npz = ensure_path(sorting_npz, strict=True) with np.load(sorting_npz) as data: times_samples = data["times_samples"] channels = data["channels"] @@ -499,7 +499,7 @@ def load( parent_h5_path = parent_h5_path.item() assert isinstance(parent_h5_path, str) parent_h5_path = sorting_npz.parent / Path(parent_h5_path) - parent_h5_path = resolve_path(parent_h5_path, strict=True) + parent_h5_path = ensure_path(parent_h5_path, strict=True) if additional_persistent_features: loaded_persistent_features = set( loaded_persistent_features + additional_persistent_features @@ -722,7 +722,7 @@ def slice_feature_by_name( def load(f: str | Path, labels_stem: str | None = None) -> DARTsortSorting: """Load a spike train from h5, npz, or folder.""" - f = resolve_path(f, strict=True) + f = ensure_path(f, strict=True) if f.name.endswith(".h5"): st = DARTsortSorting.from_peeling_hdf5(h5_path=f) @@ -734,9 +734,16 @@ def load(f: str | Path, labels_stem: str | None = None) -> DARTsortSorting: raise ValueError(f"Not sure how to load '{f}'.") if labels_stem: - labels_npy = f.parent / f"{labels_stem}.npy" + print(f"{f=}") + if f.is_dir(): + labels_npy = f / f"{labels_stem}.npy" + else: + labels_npy = f.parent / f"{labels_stem}.npy" + if labels_npy.exists(): + print(f"{labels_npy.name=} {np.load(labels_npy)[:5]=}") st = st.ephemeral_replace(labels=np.load(labels_npy)) + print(f"{st.labels[:5]=}") else: logger.info(f"{labels_npy} did not exist.") @@ -746,7 +753,7 @@ def load(f: str | Path, labels_stem: str | None = None) -> DARTsortSorting: def try_get_model_dir(sorting: DARTsortSorting) -> Path | None: if sorting.parent_h5_path is None: return None - h5_path = resolve_path(sorting.parent_h5_path) + h5_path = ensure_path(sorting.parent_h5_path) model_dir = h5_path.parent / f"{h5_path.stem}_models" if model_dir.exists(): assert model_dir.is_dir() @@ -787,7 +794,7 @@ def _get_featurization_loading_meta(sorting): if sorting.parent_h5_path is None: raise ValueError("Can't load featurization pipeline.") - h5_path = resolve_path(sorting.parent_h5_path) + h5_path = ensure_path(sorting.parent_h5_path) base_dir = h5_path.parent stem = h5_path.stem @@ -965,8 +972,8 @@ def sorting_from_spikeinterface( def filter_link_h5(in_h5_path: str | Path, out_h5_path: str | Path, keep_filter): - in_h5_path = resolve_path(in_h5_path, strict=True) - out_h5_path = resolve_path(out_h5_path) + in_h5_path = ensure_path(in_h5_path, strict=True) + out_h5_path = ensure_path(out_h5_path) assert not out_h5_path.exists() with h5py.File(in_h5_path, "r", locking=False) as h5in: @@ -1537,7 +1544,7 @@ def subsample_waveforms( need_open = h5 is None if need_open and hdf5_filename is not None: - hdf5_filename = resolve_path(hdf5_filename, strict=True) + hdf5_filename = ensure_path(hdf5_filename, strict=True) h5 = h5py.File(hdf5_filename) elif need_open: raise ValueError("Need h5 or hdf5_filename.") diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py index ee30f0d9..872670c6 100644 --- a/src/dartsort/util/internal_config.py +++ b/src/dartsort/util/internal_config.py @@ -8,7 +8,7 @@ import torch from .cli_util import argfield, dataclass_from_toml -from .py_util import cfg_dataclass, resolve_path +from .py_util import cfg_dataclass, ensure_path try: from importlib.resources import files @@ -933,7 +933,7 @@ def to_internal_config(cfg) -> DARTsortInternalConfig: cfg0 = cfg try: - cfg = resolve_path(cfg, strict=True) + cfg = ensure_path(cfg, strict=True) except OSError as e: raise ValueError(f"Configuration file {cfg0} does not exist.") from e diff --git a/src/dartsort/util/main_util.py b/src/dartsort/util/main_util.py index 06d5b068..72f31fe2 100644 --- a/src/dartsort/util/main_util.py +++ b/src/dartsort/util/main_util.py @@ -18,7 +18,7 @@ ) from ..util.logging_util import get_logger from ..util.motion import MotionInfo, try_load_motion_info -from ..util.py_util import dartcopy2, dartcopytree, resolve_path +from ..util.py_util import dartcopy2, dartcopytree, ensure_path logger = get_logger(__name__) @@ -34,11 +34,11 @@ def ds_save_intermediate_sorting( return if output_dir is None: return - output_dir = resolve_path(output_dir, strict=True) + output_dir = ensure_path(output_dir, strict=True) if work_dir is None: store_dir = output_dir else: - store_dir = resolve_path(work_dir, strict=True) + store_dir = ensure_path(work_dir, strict=True) step_npz = store_dir / f"{step_name}.npz" logger.info(f"Saving {step_name} labels to {step_npz}") @@ -63,11 +63,11 @@ def ds_save_intermediate_labels( return if output_dir is None: return - output_dir = resolve_path(output_dir, strict=True) + output_dir = ensure_path(output_dir, strict=True) if work_dir is None: store_dir = output_dir else: - store_dir = resolve_path(work_dir, strict=True) + store_dir = ensure_path(work_dir, strict=True) step_labels_npy = store_dir / f"{step_name}_labels.npy" logger.info(f"Saving {step_name} labels to {step_labels_npy}") @@ -164,7 +164,7 @@ def ds_handle_link_from(cfg: DARTsortInternalConfig, output_dir: Path): if cfg.link_from is None: return - link_from = resolve_path(cfg.link_from, strict=True) + link_from = ensure_path(cfg.link_from, strict=True) assert link_from.is_dir() link_patterns = [] @@ -214,7 +214,7 @@ def ds_save_features( # find h5 and models and copy assert sorting.parent_h5_path is not None - h5_path = resolve_path(sorting.parent_h5_path) + h5_path = ensure_path(sorting.parent_h5_path) assert h5_path.exists() models_path = h5_path.parent / f"{h5_path.stem}_models" @@ -245,7 +245,7 @@ def ds_handle_delete_intermediate_features( # find all non-final h5s, models and delete them assert final_sorting.parent_h5_path is not None - final_h5 = resolve_path(final_sorting.parent_h5_path) + final_h5 = ensure_path(final_sorting.parent_h5_path) assert final_h5.exists() assert final_h5.parent == output_dir diff --git a/src/dartsort/util/motion.py b/src/dartsort/util/motion.py index a80ebc25..e2e55929 100644 --- a/src/dartsort/util/motion.py +++ b/src/dartsort/util/motion.py @@ -25,7 +25,7 @@ ) from .job_util import ensure_computation_config from .logging_util import get_logger -from .py_util import databag, resolve_path +from .py_util import databag, ensure_path from .registration_util import dredge_estimate_motion, dredge_to_si logger = get_logger(__name__) @@ -365,7 +365,7 @@ def pitch_shifts( def try_load( cls, output_directory: Path | str, filename="motion.pkl" ) -> Self | None: - fn = resolve_path(output_directory) / filename + fn = ensure_path(output_directory) / filename if not fn.exists(): return None with open(fn, "rb") as jar: @@ -378,7 +378,7 @@ def save( filename="motion.pkl", overwrite: bool = False, ): - fn = resolve_path(output_directory) / filename + fn = ensure_path(output_directory) / filename if not overwrite and fn.exists(): return v = dict( diff --git a/src/dartsort/util/peel_util.py b/src/dartsort/util/peel_util.py index ef7b1d71..2e6af147 100644 --- a/src/dartsort/util/peel_util.py +++ b/src/dartsort/util/peel_util.py @@ -11,7 +11,7 @@ from .data_util import DARTsortSorting from .internal_config import ComputationConfig, FeaturizationConfig from .job_util import ensure_computation_config -from .py_util import resolve_path, timer +from .py_util import ensure_path, timer def run_peeler( @@ -33,7 +33,7 @@ def run_peeler( shuffle: bool = False, localization_dataset_name="point_source_localizations", ): - output_directory = resolve_path(output_directory) + output_directory = ensure_path(output_directory) output_directory.mkdir(exist_ok=True) model_dir = output_directory / model_subdir output_hdf5_filename = output_directory / hdf5_filename diff --git a/src/dartsort/util/py_util.py b/src/dartsort/util/py_util.py index 1a5deb98..c3eded32 100644 --- a/src/dartsort/util/py_util.py +++ b/src/dartsort/util/py_util.py @@ -1,6 +1,7 @@ import contextlib import dataclasses import os +from os.path import normpath import shutil import signal import subprocess @@ -124,8 +125,12 @@ def __exit__(self, type, value, traceback): # files and paths -def resolve_path( - p: str | Path | Traversable | None, strict=False, mkdir=False, parents=False +def ensure_path( + p: str | Path | Traversable | None, + strict=False, + mkdir=False, + parents=False, + resolve=False, ) -> Path: if p is None: raise ValueError("Can't resolve path None.") @@ -134,7 +139,10 @@ def resolve_path( p = Path(p) p = p.expanduser() p = p.absolute() - p = p.resolve(strict=strict) + if resolve: + p = p.resolve(strict=strict) + elif strict: + assert p.exists() if mkdir: p.mkdir(parents=parents, exist_ok=True) return p @@ -200,7 +208,7 @@ def dartcopytree(icfg, src, dest): def _rsync(src, dest, archive=True, follow_symlinks=False, excludes=None, vp=False): archive_flags = ["-a" + ("vP" if vp else "")] if archive else [] link_flags = ["--no-links", "-L"] if follow_symlinks else [] - exclude_flags = [f'--exclude={ex}' for ex in (excludes or [])] + exclude_flags = [f"--exclude={ex}" for ex in (excludes or [])] cmd = ["rsync", *archive_flags, *link_flags, *exclude_flags, str(src), str(dest)] if vp: logger.info(" ".join(cmd)) diff --git a/src/dartsort/vis/colors.py b/src/dartsort/vis/colors.py index a7929467..c7326477 100644 --- a/src/dartsort/vis/colors.py +++ b/src/dartsort/vis/colors.py @@ -1,6 +1,6 @@ import numpy as np -from ..util.py_util import resolve_path +from ..util.py_util import ensure_path try: from importlib.resources import files @@ -11,7 +11,7 @@ raise ValueError("Need python>=3.10 or pip install importlib_resources.") data_dir = files("dartsort.pretrained") -glasbey1024_npz = resolve_path(data_dir.joinpath("glasbey1024.npz")) +glasbey1024_npz = ensure_path(data_dir.joinpath("glasbey1024.npz")) with np.load(glasbey1024_npz) as npz: glasbey1024 = npz["glasbey1024"] diff --git a/src/dartsort/vis/mixture.py b/src/dartsort/vis/mixture.py index 6bbfaec3..4be659c3 100644 --- a/src/dartsort/vis/mixture.py +++ b/src/dartsort/vis/mixture.py @@ -26,7 +26,7 @@ ) from ..transform import TemporalPCA from ..util import spiketorch -from ..util.data_util import DARTsortSorting, get_tpca, resolve_path +from ..util.data_util import DARTsortSorting, get_tpca, ensure_path from ..util.internal_config import ( ClusteringFeaturesConfig, ComputationConfig, @@ -1301,7 +1301,7 @@ def fit_mixture_and_visualize_all_components( **other_global_params, ): computation_cfg = ensure_computation_config(computation_cfg) - save_folder = resolve_path(save_folder) + save_folder = ensure_path(save_folder) if unit_ids is None: unit_ids = sorting.unit_ids if n_units is not None and n_units < len(unit_ids): @@ -1478,7 +1478,7 @@ def make_mixture_summaries( seed=0, **other_global_params, ): - save_folder = resolve_path(save_folder) + save_folder = ensure_path(save_folder) if unit_ids is None: unit_ids = mix_data.tmm.unit_ids.numpy(force=True).tolist() if n_units is not None and n_units < len(unit_ids): diff --git a/tests/conftest.py b/tests/conftest.py index a7f1b8fa..aa00fa42 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import pytest from dartsort.evaluate import simkit, config_grid from dartsort.util.logging_util import get_logger -from dartsort import resolve_path +from dartsort import ensure_path logger = get_logger(__name__) @@ -37,7 +37,7 @@ def mini_simulations(pytestconfig, tmp_path_factory): for sim_name, kw in sim_settings.items(): cache_key = f"dartsort/{sim_name}" if (p := pytestconfig.cache.get(cache_key, None)) is not None: - p = resolve_path(p) + p = ensure_path(p) if p.exists(): try: sims[sim_name] = simkit.load_simulation(p / "sim") @@ -47,7 +47,7 @@ def mini_simulations(pytestconfig, tmp_path_factory): pass p = tmp_path_factory.mktemp(f"simdata_{sim_name}") - p = resolve_path(p) + p = ensure_path(p) pytestconfig.cache.set(cache_key, str(p)) sims[sim_name] = simkit.generate_simulation(p / "sim", p / "noise", **kw) @@ -73,7 +73,7 @@ def simulations(pytestconfig, tmp_path_factory, mini_simulations): for sim_name, kw in sim_settings.items(): cache_key = f"dartsort/{sim_name}" if (p := pytestconfig.cache.get(cache_key, None)) is not None: - p = resolve_path(p) + p = ensure_path(p) if p.exists(): try: sims[sim_name] = simkit.load_simulation(p / "sim") @@ -81,9 +81,9 @@ def simulations(pytestconfig, tmp_path_factory, mini_simulations): continue except FileNotFoundError: pass - + p = tmp_path_factory.mktemp(f"simdata_{sim_name}") - p = resolve_path(p) + p = ensure_path(p) pytestconfig.cache.set(cache_key, str(p)) sims[sim_name] = simkit.generate_simulation(p / "sim", p / "noise", **kw) diff --git a/tests/test_alignment.py b/tests/test_alignment.py index 8178181b..fd405e80 100644 --- a/tests/test_alignment.py +++ b/tests/test_alignment.py @@ -157,7 +157,7 @@ def test_denoiser_alignment(align_sim, align_templates): for rtd in (False, True) ] with tempfile.TemporaryDirectory() as tdir: - tdir = dartsort.resolve_path(tdir) + tdir = dartsort.ensure_path(tdir) st0, st1 = sts = [ dartsort.DARTsortSorting.from_peeling_hdf5( p.peel(tdir / "hi.h5", overwrite=True) diff --git a/tests/test_matching.py b/tests/test_matching.py index d9dad1ce..67b4c20b 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -46,7 +46,7 @@ def refractory_sim(request, tmp_path_factory): upsampling, scaling, nc = request.param p = tmp_path_factory.mktemp(f"refsim_{upsampling}_{scaling}_{nc}") - p = dartsort.resolve_path(p) + p = dartsort.ensure_path(p) sim = simkit.generate_simulation( p / "sim", p / "noise", From ee390426258a56f68e0e32ca0d09a0a2b5ebf09f Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 19 May 2026 11:55:01 -0400 Subject: [PATCH 06/18] prints --- src/dartsort/util/cli_util.py | 1 - src/dartsort/util/data_util.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/src/dartsort/util/cli_util.py b/src/dartsort/util/cli_util.py index 6bd20e3c..66bffc27 100644 --- a/src/dartsort/util/cli_util.py +++ b/src/dartsort/util/cli_util.py @@ -167,7 +167,6 @@ def dataclass_to_argparse(cls, parser=None, prefix="", skipnames=None): raise ValueError(f"Need type or arg_type for {fld}.") if typing.get_origin(type_) == typing.Annotated: type_, *annots = typing.get_args(type_) - print(f"{annots=}") for annot in annots: if isinstance(annot, Doc): assert not doc diff --git a/src/dartsort/util/data_util.py b/src/dartsort/util/data_util.py index 9d1b32e5..1d215d3d 100644 --- a/src/dartsort/util/data_util.py +++ b/src/dartsort/util/data_util.py @@ -734,16 +734,13 @@ def load(f: str | Path, labels_stem: str | None = None) -> DARTsortSorting: raise ValueError(f"Not sure how to load '{f}'.") if labels_stem: - print(f"{f=}") if f.is_dir(): labels_npy = f / f"{labels_stem}.npy" else: labels_npy = f.parent / f"{labels_stem}.npy" if labels_npy.exists(): - print(f"{labels_npy.name=} {np.load(labels_npy)[:5]=}") st = st.ephemeral_replace(labels=np.load(labels_npy)) - print(f"{st.labels[:5]=}") else: logger.info(f"{labels_npy} did not exist.") From 256cfcdf38c4003a4cb188a7723c1d994e73c569 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 20 May 2026 13:56:04 -0400 Subject: [PATCH 07/18] gmm: nz static; kmeans prop --- src/dartsort/clustering/mixture.py | 71 ++++++++++++++---------------- 1 file changed, 33 insertions(+), 38 deletions(-) diff --git a/src/dartsort/clustering/mixture.py b/src/dartsort/clustering/mixture.py index b5756652..c53ee462 100644 --- a/src/dartsort/clustering/mixture.py +++ b/src/dartsort/clustering/mixture.py @@ -4052,6 +4052,10 @@ def get_truncated_datasets( device=device, rgeom=prgeom[:-1], ) + + gc.collect() + torch.cuda.empty_cache() + assert isinstance(noise, EmbeddedNoise) noise.to(device=device) assert noise.rank == feature_rank @@ -4154,6 +4158,9 @@ def get_truncated_datasets( else: assert False + gc.collect() + torch.cuda.empty_cache() + return neighb_cov, erp, train_data, val_data, full_data, noise, train_ixs, val_ixs @@ -5199,7 +5206,7 @@ def try_kmeans( feature_rank: int, n_iter: int = 100, with_proportions: bool = True, - drop_prop: float = 0.025, + drop_prop: float = 0.0, kmeanspp_initial="random", n_kmeans_tries: int = 25, n_kmeanspp_tries: int = 25, @@ -5253,10 +5260,7 @@ def evaluate_group_demolitions( train_scores: Scores, eval_scores: Scores, ) -> GroupDemolition: - # grab relevant sections of train and eval scores group_ = group.to(train_scores.candidates) - in_group_train = torch.isin(train_scores.candidates, group_).any(dim=1) - (in_group_train,) = in_group_train.nonzero(as_tuple=True) # select which units in the group are candidates for demolition if mean_train_resp is None: @@ -6061,36 +6065,30 @@ def _count_candidates(candidates, batch_candidate_counts, batch_size): if TORCH_IS_OLD: - @torch.jit.script - def _combine_similar_resps(resps: Tensor, keep_mask: Tensor, n_keep: int) -> Tensor: - n_discard = resps.shape[1] - n_keep - assert n_discard + n_keep == resps.shape[1] - discard_mask = torch.logical_not(keep_mask) - discard_ix = discard_mask.nonzero()[:, 0] - keep_ix = keep_mask.nonzero()[:, 0] - assert keep_ix.numel() + discard_ix.numel() == resps.shape[1] - kept_resp = resps[:, keep_ix] - discard_resp = resps[:, discard_ix] - sim = discard_resp.T @ kept_resp - match = sim.argmax(1) - kept_resp[:, match] += discard_resp - return kept_resp + def _nonzero_static(x: Tensor, size: int): + nz = x.nonzero() + assert nz.numel() == size + return nz else: - @torch.jit.script - def _combine_similar_resps(resps: Tensor, keep_mask: Tensor, n_keep: int) -> Tensor: - n_discard = resps.shape[1] - n_keep - assert n_discard + n_keep == resps.shape[1] - discard_mask = torch.logical_not(keep_mask) - discard_ix = discard_mask.nonzero_static(size=n_discard)[:, 0] - keep_ix = keep_mask.nonzero_static(size=n_keep)[:, 0] - assert keep_ix.numel() + discard_ix.numel() == resps.shape[1] - kept_resp = resps[:, keep_ix] - discard_resp = resps[:, discard_ix] - sim = discard_resp.T @ kept_resp - match = sim.argmax(1) - kept_resp[:, match] += discard_resp - return kept_resp + def _nonzero_static(x: Tensor, size: int): + return x.nonzero_static(size=size) + + +@torch.jit.script +def _combine_similar_resps(resps: Tensor, keep_mask: Tensor, n_keep: int) -> Tensor: + n_discard = resps.shape[1] - n_keep + assert n_discard + n_keep == resps.shape[1] + discard_mask = torch.logical_not(keep_mask) + discard_ix = _nonzero_static(discard_mask, size=n_discard)[:, 0] + keep_ix = _nonzero_static(keep_mask, size=n_keep)[:, 0] + assert keep_ix.numel() + discard_ix.numel() == resps.shape[1] + kept_resp = resps[:, keep_ix] + discard_resp = resps[:, discard_ix] + sim = discard_resp.T @ kept_resp + match = sim.argmax(1) + kept_resp[:, match] += discard_resp + return kept_resp def concatenate_scores(scoress: list[Scores]) -> Scores: @@ -6203,10 +6201,7 @@ def mean_responsibilities( rsum_batch.zero_() nc = int(ncand[bix].item()) - if TORCH_IS_OLD: - cii, cjj = (cand[i0:i1] >= 0).nonzero().T - else: - cii, cjj = (cand[i0:i1] >= 0).nonzero_static(size=nc).T + cii, cjj = _nonzero_static(cand[i0:i1] >= 0, size=nc).T c = cand[i0:i1][cii, cjj] r = resp[i0:i1][cii, cjj].double() @@ -6493,8 +6488,8 @@ def _sparsify_candidates( assert cpos.any(dim=1).all() if pnoid and static_size is not None: assert cpos.sum() == static_size - if static_size is not None and not TORCH_IS_OLD: - spike_ixs, candidate_ixs = cpos.nonzero_static(size=static_size).T + if static_size is not None: + spike_ixs, candidate_ixs = _nonzero_static(cpos, size=static_size).T else: spike_ixs, candidate_ixs = cpos.nonzero(as_tuple=True) neighb_ixs = neighborhood_ids[spike_ixs] From 9d1a5ab0512865f088c8efe7a854a4d856eb0f03 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 20 May 2026 13:57:27 -0400 Subject: [PATCH 08/18] dev/vis: cleanups, gmm demolish vis --- src/dartsort/evaluate/hybrid_util.py | 4 +- src/dartsort/transform/reduction.py | 2 +- src/dartsort/util/main_util.py | 9 +- src/dartsort/util/noise_util.py | 7 +- src/dartsort/util/py_util.py | 1 - src/dartsort/util/spiketorch.py | 16 +++- src/dartsort/vis/mixture.py | 133 ++++++++++++++++++++------- 7 files changed, 123 insertions(+), 49 deletions(-) diff --git a/src/dartsort/evaluate/hybrid_util.py b/src/dartsort/evaluate/hybrid_util.py index ed9206c1..6edea65a 100644 --- a/src/dartsort/evaluate/hybrid_util.py +++ b/src/dartsort/evaluate/hybrid_util.py @@ -464,9 +464,9 @@ def load_dartsort_step_sortings( ) if no_npys and st0.labels is not None: if hasattr(st0, "template_inds"): - yield( + yield ( name_formatter(f"{h5.stem}_template"), - st0.ephemeral_replace(labels=st0.template_inds) + st0.ephemeral_replace(labels=st0.template_inds), ) yield name_formatter(h5.stem), st0 continue diff --git a/src/dartsort/transform/reduction.py b/src/dartsort/transform/reduction.py index e94cb9df..b1950532 100644 --- a/src/dartsort/transform/reduction.py +++ b/src/dartsort/transform/reduction.py @@ -139,7 +139,7 @@ def reduction_results( dev = computation_cfg.actual_device() n_jobs, Executor, context, *_ = pool_from_cfg( - computation_cfg, check_local=True, small=True + computation_cfg, check_local=True, small=True, cpu=dev.type == 'cpu' ) with Executor( max_workers=n_jobs, diff --git a/src/dartsort/util/main_util.py b/src/dartsort/util/main_util.py index 72f31fe2..01d112ba 100644 --- a/src/dartsort/util/main_util.py +++ b/src/dartsort/util/main_util.py @@ -164,7 +164,7 @@ def ds_handle_link_from(cfg: DARTsortInternalConfig, output_dir: Path): if cfg.link_from is None: return - link_from = ensure_path(cfg.link_from, strict=True) + link_from = ensure_path(cfg.link_from, strict=True, resolve=True) assert link_from.is_dir() link_patterns = [] @@ -177,7 +177,12 @@ def ds_handle_link_from(cfg: DARTsortInternalConfig, output_dir: Path): link_patterns.extend(["subtraction_models/*denoising_pipeline.pt"]) if link_detection: link_patterns.extend( - ["subtraction.h5", "motion.pkl", "motionthreshold.h5", "subtraction_models"] + [ + "subtraction.h5", + "motion.pkl", + "motionthreshold.h5", + "subtraction_models/featurization_pipeline.pt", + ] ) if link_refined0: link_patterns.extend(["initial*.npy", "refined0*.npy"]) diff --git a/src/dartsort/util/noise_util.py b/src/dartsort/util/noise_util.py index dd38da0a..49a793e4 100644 --- a/src/dartsort/util/noise_util.py +++ b/src/dartsort/util/noise_util.py @@ -976,10 +976,11 @@ def estimate( x_spatial = torch.from_numpy(xx).to(x_spatial) valid = slice(None) init_kw["cov_kind"] = cov_kind.removesuffix("noise") - - cov = spiketorch.nancov(x_spatial[:, valid].double(), force_posdef=True) + cov = torch.cov(x_spatial.T.double()) + cov = spiketorch.enforce_posdef(cov, eps=eps) + else: + cov = spiketorch.nancov(x_spatial[:, valid].double(), force_posdef=True, eps=eps) assert torch.is_tensor(cov) - cov.diagonal().add_(eps) if shrinkage: cov = F.softshrink(cov, shrinkage) diff --git a/src/dartsort/util/py_util.py b/src/dartsort/util/py_util.py index c3eded32..4d41b493 100644 --- a/src/dartsort/util/py_util.py +++ b/src/dartsort/util/py_util.py @@ -1,7 +1,6 @@ import contextlib import dataclasses import os -from os.path import normpath import shutil import signal import subprocess diff --git a/src/dartsort/util/spiketorch.py b/src/dartsort/util/spiketorch.py index db8b38fd..5004175a 100644 --- a/src/dartsort/util/spiketorch.py +++ b/src/dartsort/util/spiketorch.py @@ -854,6 +854,15 @@ def convolve_lowrank( return out +def enforce_posdef(a, eps=0.0): + if eps: + a.diagonal(dim1=-2, dim2=-1).add_(eps) + vals, vecs = torch.linalg.eigh(a) + good = vals > 0 + a = (vecs[:, good] * vals[good]) @ vecs[:, good].T + return a + + def nancov( x, weights=None, @@ -888,14 +897,11 @@ def nancov( denom = nobs - correction denom[denom <= 0] = 1 cov = xtx / denom + cov = torch.asarray(cov) if force_posdef: try: - if eps: - np.fill_diagonal(cov, np.diagonal(cov) + eps) - vals, vecs = torch.linalg.eigh(cov) - good = vals > 0 - cov = (vecs[:, good] * vals[good]) @ vecs[:, good].T + cov = enforce_posdef(cov, eps=eps) except Exception as e: if not cov.isfinite().all(): raise e diff --git a/src/dartsort/vis/mixture.py b/src/dartsort/vis/mixture.py index 4be659c3..244d5141 100644 --- a/src/dartsort/vis/mixture.py +++ b/src/dartsort/vis/mixture.py @@ -17,16 +17,18 @@ StreamingSpikeData, TruncatedMixtureModel, TruncatedSpikeData, + evaluate_group_demolitions, instantiate_and_bootstrap_tmm, labels_from_scores, labels_from_scores_, + mean_responsibilities, run_merge, run_split, try_kmeans, ) from ..transform import TemporalPCA from ..util import spiketorch -from ..util.data_util import DARTsortSorting, get_tpca, ensure_path +from ..util.data_util import DARTsortSorting, ensure_path, get_tpca from ..util.internal_config import ( ClusteringFeaturesConfig, ComputationConfig, @@ -64,6 +66,8 @@ class MixtureVisData: full_scores: Scores eval_scores: Scores eval_labels: torch.Tensor + mean_train_resp: torch.Tensor + mean_eval_resp: torch.Tensor train_times: np.ndarray train_labels: np.ndarray train_ixs: np.ndarray @@ -467,6 +471,7 @@ def draw(self, panel, mix_data: MixtureVisData, unit_id: int): _, neighbors = mix_data.friends(unit_id, count=self.count) neighbors = neighbors[neighbors != unit_id][::-1] colors = np.array(glasbey1024)[neighbors % len(glasbey1024)] + empty = np.array([], dtype=int) axes = panel.subplots( squeeze=False, @@ -486,8 +491,8 @@ def draw(self, panel, mix_data: MixtureVisData, unit_id: int): in_nid = mix_data.full_inunits[nid] elif split == "eval": ssco = mix_data.eval_scores - in_nid = mix_data.eval_inunits[nid] - in_unit_id = mix_data.eval_inunits[unit_id] + in_nid = mix_data.eval_inunits.get(nid, empty) + in_unit_id = mix_data.eval_inunits.get(unit_id, empty) else: assert False @@ -596,46 +601,60 @@ def draw(self, panel, mix_data: MixtureVisData, unit_id: int): ax.set_xlabel("my ll - their ll") +class DemolishView(MixtureComponentPlot): + kind = "block" + width = 2 + height = 0.5 + + def compute(self, mix_data: MixtureVisData, unit_id: int): + _, group, _ = _get_my_merge_group(mix_data, unit_id) + group_res = evaluate_group_demolitions( + mm=mix_data.tmm, + group=group, + mean_train_resp=mix_data.mean_train_resp, + mean_eval_resp=mix_data.mean_eval_resp, + train_scores=mix_data.train_scores, + eval_scores=mix_data.eval_scores, + cur_crit=None, + ) + return group_res + + def draw(self, panel, mix_data: MixtureVisData, unit_id: int): + print(f"{unit_id=}") + demo_res = self.compute(mix_data, unit_id) + + ax = panel.subplots() + ax.axis("off") + + us = ",".join([str(uu.item()) for uu in demo_res.unit_ids.cpu()]) + ims = f"imp={demo_res.improvement}" + if demo_res.demolished is None: + ds = "no" + else: + ds = ",".join([str(uu.item())[:1] for uu in demo_res.demolished.cpu()]) + msg = f"units: {us}\n{ims}\ndemo: {ds}" + print(f"{msg=}") + + ax.text( + 0.5, + 0.5, + msg, + ha="center", + va="center", + fontsize="small", + transform=ax.transAxes, + ) + + class MergeView(MixtureComponentPlot): kind = "block" width = 2 height = 2.5 - def __init__(self): - pass - def compute(self, mix_data: MixtureVisData, unit_id: int): - # -- get actual group used during merge - # start by getting local distance matrix D for neighbors within merge distance - d0 = mix_data.inf_diag_unit_distance_matrix[unit_id] - (neighbors,) = (d0 < mix_data.tmm.p.merge_max_distance).nonzero(as_tuple=True) - me = neighbors.new_full((1,), unit_id) - neighbors = torch.cat([neighbors, me]).sort().values - D = mix_data.inf_diag_unit_distance_matrix[neighbors][:, neighbors].clone() - D.nan_to_num_(posinf=1000.0) - - # find my complete linkage cluster within D - D = D.fill_diagonal_(0.0).numpy(force=True) - if D.shape[0] > 1: - pd = D[np.triu_indices(D.shape[0], k=1)] - Z = linkage(pd, method="complete") - groups = maximal_leaf_groups( - Z, - distances=D, - max_distance=mix_data.tmm.p.merge_max_distance, - max_group_size=mix_data.tmm.p.max_group_size, - ) - groups = [g for g in groups if unit_id in neighbors[list(g)].tolist()] - assert len(groups) == 1 - group_ix = list(groups[0]) - group = neighbors[group_ix] - else: - group_ix = np.arange(neighbors.shape[0]) - group = neighbors - del neighbors + _, group, D = _get_my_merge_group(mix_data, unit_id) # get pair mask - D = D[group_ix][:, group_ix] pair_mask = torch.asarray(D < mix_data.tmm.p.merge_max_distance) # run it @@ -1273,6 +1292,7 @@ def default_mixture_plots(): NeighborMeans(), NeighborDistances(), MergeView(), + DemolishView(), MeanView(), CovarianceView(), SplitView(), @@ -1371,6 +1391,9 @@ def fit_mixture_for_vis( full_proposal_view=True, ) train_labels = labels_from_scores_(train_scores) + mean_train_resp = mean_responsibilities( + scores=train_scores, n_units=mix_data.tmm.n_units + ) full_scores = mix_data.tmm.soft_assign( data=mix_data.full_data, needs_bootstrap=False, @@ -1380,6 +1403,7 @@ def fit_mixture_for_vis( if mix_data.val_data is None: eval_scores = train_scores eval_labels = train_labels + mean_eval_resp = mean_train_resp else: eval_scores = mix_data.tmm.soft_assign( data=mix_data.val_data, @@ -1387,6 +1411,9 @@ def fit_mixture_for_vis( full_proposal_view=True, ) eval_labels = labels_from_scores_(eval_scores) + mean_eval_resp = mean_responsibilities( + scores=eval_scores, n_units=mix_data.tmm.n_units + ) dists = mix_data.tmm.unit_distance_matrix().cpu().clone() dists.diagonal().fill_(torch.inf) @@ -1419,6 +1446,8 @@ def fit_mixture_for_vis( train_scores=train_scores, full_scores=full_scores, eval_scores=eval_scores, + mean_train_resp=mean_train_resp, + mean_eval_resp=mean_eval_resp, train_times=times_s[mix_data.train_ixs], train_ixs=train_ixs, val_ixs=val_ixs, @@ -1611,6 +1640,40 @@ def _summary_job(unit_id): if tmp_out is not None and tmp_out.exists(): tmp_out.unlink() +# -- lib + + +def _get_my_merge_group(mix_data: MixtureVisData, unit_id: int): + # -- get actual group used during merge + # start by getting local distance matrix D for neighbors within merge distance + d0 = mix_data.inf_diag_unit_distance_matrix[unit_id] + (neighbors,) = (d0 < mix_data.tmm.p.merge_max_distance).nonzero(as_tuple=True) + me = neighbors.new_full((1,), unit_id) + neighbors = torch.cat([neighbors, me]).sort().values + D = mix_data.inf_diag_unit_distance_matrix[neighbors][:, neighbors].clone() + D.nan_to_num_(posinf=1000.0) + + # find my complete linkage cluster within D + D = D.fill_diagonal_(0.0).numpy(force=True) + if D.shape[0] > 1: + pd = D[np.triu_indices(D.shape[0], k=1)] + Z = linkage(pd, method="complete") + groups = maximal_leaf_groups( + Z, + distances=D, + max_distance=mix_data.tmm.p.merge_max_distance, + max_group_size=mix_data.tmm.p.max_group_size, + ) + groups = [g for g in groups if unit_id in neighbors[list(g)].tolist()] + assert len(groups) == 1 + group_ix = list(groups[0]) + group = neighbors[group_ix] + else: + group_ix = np.arange(neighbors.shape[0]) + group = neighbors + del neighbors + D = D[group_ix][:, group_ix] + return group_ix, group, D # -- one-offs From 5fac522805dd8b51d8be532f5c7f6279ea65d698 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 20 May 2026 14:18:45 -0400 Subject: [PATCH 09/18] geom: depth_only regular geom --- src/dartsort/util/waveform_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dartsort/util/waveform_util.py b/src/dartsort/util/waveform_util.py index f35ab3d0..f8e55879 100644 --- a/src/dartsort/util/waveform_util.py +++ b/src/dartsort/util/waveform_util.py @@ -316,7 +316,7 @@ def make_filled_channel_index( return channel_index -def make_regular_channel_index(geom, radius, p=2, to_torch=False, depth_only=False): +def make_regular_channel_index(geom, radius, p=2, to_torch=False, depth_only=True): """Channel index for multi-channel models In this channel index, the layout of channels around the max channel is From 7823891073ced4ca0075d7e400ce30e8e7b4dbdc Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 20 May 2026 14:39:26 -0400 Subject: [PATCH 10/18] preprocess: log --- src/dartsort/util/preprocess_util.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/dartsort/util/preprocess_util.py b/src/dartsort/util/preprocess_util.py index 2a908c87..2e18b42f 100644 --- a/src/dartsort/util/preprocess_util.py +++ b/src/dartsort/util/preprocess_util.py @@ -3,6 +3,9 @@ from spikeinterface.core import BaseRecording from .internal_config import PreprocessingStrategy +from .logging_util import get_logger + +logger = get_logger(__name__) preprocessing_strategies = {} @@ -87,4 +90,5 @@ def preprocess( strategy: PreprocessingStrategy = "none", dtype: str = "float32", ) -> BaseRecording: + logger.info("applying preprocessing: %s", strategy) return preprocessing_strategies[strategy](rec, dtype) From c81d94ecde9eeb96434e45b42acef35100195e4a Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 21 May 2026 11:02:52 -0400 Subject: [PATCH 11/18] dev: fix a log0 warning --- src/dartsort/util/data_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dartsort/util/data_util.py b/src/dartsort/util/data_util.py index 1d215d3d..c6475ea9 100644 --- a/src/dartsort/util/data_util.py +++ b/src/dartsort/util/data_util.py @@ -1642,7 +1642,7 @@ def fit_reweighting( if log_voltages: sign = np.sign(v) - v = sign * np.log(np.abs(v)) + v = sign * np.log(np.abs(v) + 1e-5) v = np.nan_to_num(v) sigma = 1.06 * v.std() * np.power(len(v), -0.2) assert np.isfinite(sigma) From 52616ff7af7dae28654919ad2da08c70f3a03e86 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 21 May 2026 11:03:09 -0400 Subject: [PATCH 12/18] vis: figsize in scatters --- src/dartsort/vis/scatterplots.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/dartsort/vis/scatterplots.py b/src/dartsort/vis/scatterplots.py index 061fe838..61116667 100644 --- a/src/dartsort/vis/scatterplots.py +++ b/src/dartsort/vis/scatterplots.py @@ -22,6 +22,7 @@ def scatter_spike_features( figure=None, axes=None, width_ratios=(1, 1, 3), + figsize=(15, 10), semilog_amplitudes=True, show_geom=True, geom_scatter_kw=dict(s=5, marker="s", color="k", lw=0), @@ -57,8 +58,10 @@ def scatter_spike_features( if axes is not None: assert axes.size == 3 figure = axes.flat[0].figure - if figure is None: + if figure is None and len(plt.get_fignums()): figure = plt.gcf() + elif figure is None: + figure = plt.figure(figsize=figsize) if extra_features is None: extra_features = {} else: From c168d9bbb5627e87eda787ca499e7b10ed48d900 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 21 May 2026 12:09:27 -0400 Subject: [PATCH 13/18] loggin: be respectful, no basicConfig --- src/dartsort/util/logging_util.py | 36 ++++++++++++++++--------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/src/dartsort/util/logging_util.py b/src/dartsort/util/logging_util.py index a2a4dbf8..481e3c0e 100644 --- a/src/dartsort/util/logging_util.py +++ b/src/dartsort/util/logging_util.py @@ -5,7 +5,6 @@ INFO, NOTSET, addLevelName, - basicConfig, getLevelNamesMapping, getLogger, getLoggerClass, @@ -52,21 +51,20 @@ def dartsortdebugthunk(self, msg, *args, **kwargs): setLoggerClass(DARTsortLogger) -logger = getLogger(__name__) -assert isinstance(logger, DARTsortLogger) +# shouts out to sinclairtarget.com +package_logger = getLogger(__package__) +assert isinstance(package_logger, DARTsortLogger) # set to environment-defined log level if present if "LOG_LEVEL" in os.environ: - level = os.environ["LOG_LEVEL"] - try: - basicConfig(level=level) - except ValueError: + level = os.environ["LOG_LEVEL"].strip() + if not level.strip("0123456789"): ilevel = int(level) - basicConfig(level=ilevel) else: - ilevel = getLevelNamesMapping()[level] - logger.log(ilevel, f"Log level set to {level} ({ilevel}).") + ilevel = getLevelNamesMapping()[level.upper()] + package_logger.setLevel(ilevel) + package_logger.log(ilevel, f"Log level set to {level} ({ilevel}).") def warn_with_traceback(message, category, filename, lineno, file=None, line=None): @@ -80,8 +78,8 @@ def warn_with_traceback(message, category, filename, lineno, file=None, line=Non # override warnings to show tracebacks when debugging -if logger.isEnabledFor(DARTSORTVERBOSE): - logger.dartsortdebug("Setting warnings.showwarning to print tracebacks.") +if package_logger.isEnabledFor(DARTSORTVERBOSE): + package_logger.dartsortdebug("Setting warnings.showwarning to print tracebacks.") warnings.showwarning = warn_with_traceback # type: ignore @@ -91,9 +89,9 @@ def get_logger(*args, **kwargs) -> DARTsortLogger: return logger -logger.dartsortdebug( - f"Logger is enabled for: DARTSORTDEBUG={logger.isEnabledFor(DARTSORTDEBUG)}, " - f"DARTSORTVERBOSE={logger.isEnabledFor(DARTSORTVERBOSE)}." +package_logger.dartsortdebug( + f"Logger is enabled for: DARTSORTDEBUG={package_logger.isEnabledFor(DARTSORTDEBUG)}, " + f"DARTSORTVERBOSE={package_logger.isEnabledFor(DARTSORTVERBOSE)}." ) @@ -101,7 +99,7 @@ class logress: def __init__( self, iterable, - logger=logger, + logger=package_logger, miniters=100, mininterval=60.0, desc=None, @@ -121,6 +119,7 @@ def __init__( self.mininterval = mininterval self.logger = logger self.unit = unit + self.closed = False try: self.total = len(iterable) self.miniters = min( @@ -136,7 +135,7 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): - pass + self.close() def __iter__(self): it = self.iterable @@ -186,7 +185,10 @@ def write(self, s): self.logger.log(self.level, s) def close(self): + if self.closed: + return self._print(check=False) + self.closed = True def _print(self, t=None, check=True): if t is None: From 60c91b7d50642193a7987d8109dab5818455f507 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 21 May 2026 12:09:39 -0400 Subject: [PATCH 14/18] amend logging --- src/dartsort/util/logging_util.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/dartsort/util/logging_util.py b/src/dartsort/util/logging_util.py index 481e3c0e..f427b6e0 100644 --- a/src/dartsort/util/logging_util.py +++ b/src/dartsort/util/logging_util.py @@ -57,8 +57,13 @@ def dartsortdebugthunk(self, msg, *args, **kwargs): # set to environment-defined log level if present -if "LOG_LEVEL" in os.environ: - level = os.environ["LOG_LEVEL"].strip() +if (level := os.getenv("LOGLEVEL")) is not None: + pass +elif (level := os.getenv("LOG_LEVEL")) is not None: + pass + +if level: + level = level.strip() if not level.strip("0123456789"): ilevel = int(level) else: From 422005d1a429a6bcf4f9ea449a811812ef36b108 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 22 May 2026 11:41:21 -0400 Subject: [PATCH 15/18] jobs: be kind to memory on cluster jobs on 3.12 --- src/dartsort/util/multiprocessing_util.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/dartsort/util/multiprocessing_util.py b/src/dartsort/util/multiprocessing_util.py index 815cbd97..bd94227b 100644 --- a/src/dartsort/util/multiprocessing_util.py +++ b/src/dartsort/util/multiprocessing_util.py @@ -151,6 +151,15 @@ def handle_negative_jobs(n_jobs: int): try: n_cores = os.process_cpu_count() # type: ignore except AttributeError: + try: + my_cores = os.sched_getaffinity(0) + if my_cores: + n_cores = len(my_cores) + else: + n_cores = None + except Exception: + n_cores = None + if n_cores is None: n_cores = multiprocessing.cpu_count() if n_jobs < 0: n_jobs = n_cores + (n_jobs + 1) From 110bbc59561ddc39ac445c0425b1d4065cbaaf12 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 22 May 2026 11:41:38 -0400 Subject: [PATCH 16/18] vis: add a longer isi to defaults --- src/dartsort/vis/unit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/dartsort/vis/unit.py b/src/dartsort/vis/unit.py index a4dc8924..bc82eb37 100644 --- a/src/dartsort/vis/unit.py +++ b/src/dartsort/vis/unit.py @@ -847,6 +847,7 @@ def default_plots(sorting_analysis=None): UnitTextInfo(), ACG(), ISIHistogram(), + ISIHistogram(bin_ms=0.25, max_ms=50.0), XZScatter(), TimeAmpScatter(), RawWaveformPlot(), From ce2ad94b8d487c7cc4dea68e34931e6c4a5597e2 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 22 May 2026 11:41:55 -0400 Subject: [PATCH 17/18] mix: don't crash when split doesn't fill out neighbs --- src/dartsort/clustering/mixture.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/dartsort/clustering/mixture.py b/src/dartsort/clustering/mixture.py index c53ee462..ff310196 100644 --- a/src/dartsort/clustering/mixture.py +++ b/src/dartsort/clustering/mixture.py @@ -1319,10 +1319,13 @@ def erase_candidates(self): self.candidates.fill_(-1) def bootstrap_candidates( - self, distances: Tensor, un_adj_lut: NeighborhoodLUT | None = None + self, + distances: Tensor, + un_adj_lut: NeighborhoodLUT | None = None, + allow_uncovered: bool = False, ) -> NeighborhoodLUT: self.update_adjacency(n_units=distances.shape[0], un_adj_lut=un_adj_lut) - self._fill_missing(distances.shape[0]) + self._fill_missing(distances.shape[0], allow_uncovered=allow_uncovered) # fill in candidates[:, 1:n_candidates] at random obeying un_adj # choosing not to use distances here, since they get used in search sets @@ -1861,7 +1864,7 @@ def update_from_split( # have to do a full bootstrap, bc it's hard to figure out what to do with # spikes whose candidates contain the units that were split. this way, the # lut invariants are maintained, and at least the top labels are the same. - return self.bootstrap_candidates(distances) + return self.bootstrap_candidates(distances, allow_uncovered=True) def full_proposal_view(self, un_adj_lut: NeighborhoodLUT): return FullProposalDataView.from_truncated_spike_data(self, un_adj_lut) From a3cbf284fdc6252774b4b0765730227af85891f0 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Sat, 23 May 2026 18:58:53 -0400 Subject: [PATCH 18/18] loc: add some eps --- src/dartsort/transform/amortized_localization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dartsort/transform/amortized_localization.py b/src/dartsort/transform/amortized_localization.py index 59316f5f..4cba56e2 100644 --- a/src/dartsort/transform/amortized_localization.py +++ b/src/dartsort/transform/amortized_localization.py @@ -192,7 +192,7 @@ def get_alphas(self, obs_amps, pred_amps_alpha1, masks, return_pred=False): a0 = masks * pred_amps_alpha1 numer = a0.mul(obs_amps).sum(dim=1) denom = a0.square().sum(dim=1) - alphas = numer.div_(denom) + alphas = numer.div_(denom + 1e-6) if return_pred: return alphas, alphas.unsqueeze(1) * pred_amps_alpha1 return alphas @@ -203,13 +203,13 @@ def point_source_model(self, z, obs_amps, masks, channels): if self.localization_model == "gaussian": pred_amps_alpha1 = dists.square().mul(-2).exp() else: - pred_amps_alpha1 = 1.0 / dists + pred_amps_alpha1 = 1.0 / (dists + 1e-6) alphas, pred_amps = self.get_alphas( obs_amps, pred_amps_alpha1, masks, return_pred=True ) else: alphas = F.softplus(z[:, 3]) - pred_amps = alphas.unsqueeze(1) / dists + pred_amps = alphas.unsqueeze(1) / (dists + 1e-6) return alphas, pred_amps