Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/main_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ For details on parameters you should think about before running the sorter, see
show_signature: true
separate_signature: true

The return value from the `dartsort()` function is a DARTsortReturn object, which is a dictionary containing spike trains and motion information:
The return value from the `dartsort()` function is a `DARTsortResult` object, which is a dictionary containing spike trains and motion information:

::: dartsort.DARTsortReturn
options:
Expand Down
4 changes: 2 additions & 2 deletions src/dartsort/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .main import (
ObjectiveUpdateTemplateMatchingPeeler,
SubtractionPeeler,
DARTsortReturn,
DARTsortResult,
check_recording,
cluster,
dartsort,
Expand Down Expand Up @@ -78,7 +78,7 @@
from .util.motion import MotionInfo, get_motion_info, try_load_motion_info
from .util.noise_util import EmbeddedNoise
from .util.preprocess_util import preprocess
from .util.py_util import databag, resolve_path
from .util.py_util import databag, ensure_path
from .util.waveform_util import full_channel_index, make_channel_index

__version__ = importlib.metadata.version("dartsort")
3 changes: 2 additions & 1 deletion src/dartsort/clustering/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
sparse_centroid_distsq,
)
from ..util.spiketorch import spawn_torch_rg
from .density import guess_mode

logger = get_logger(__name__)

Expand Down Expand Up @@ -58,6 +57,8 @@ def kmeanspp(
closest = torch.cdist(X, X.mean(0, keepdim=True)).argmax()
centroid_ixs[0] = closest.item()
elif kmeanspp_initial == "mode":
from .density import guess_mode

Xm = X
if Xm.shape[1] > mode_dim:
q = min(mode_dim + 10, *Xm.shape)
Expand Down
80 changes: 39 additions & 41 deletions src/dartsort/clustering/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -1319,10 +1319,13 @@ def erase_candidates(self):
self.candidates.fill_(-1)

def bootstrap_candidates(
self, distances: Tensor, un_adj_lut: NeighborhoodLUT | None = None
self,
distances: Tensor,
un_adj_lut: NeighborhoodLUT | None = None,
allow_uncovered: bool = False,
) -> NeighborhoodLUT:
self.update_adjacency(n_units=distances.shape[0], un_adj_lut=un_adj_lut)
self._fill_missing(distances.shape[0])
self._fill_missing(distances.shape[0], allow_uncovered=allow_uncovered)

# fill in candidates[:, 1:n_candidates] at random obeying un_adj
# choosing not to use distances here, since they get used in search sets
Expand Down Expand Up @@ -1861,7 +1864,7 @@ def update_from_split(
# have to do a full bootstrap, bc it's hard to figure out what to do with
# spikes whose candidates contain the units that were split. this way, the
# lut invariants are maintained, and at least the top labels are the same.
return self.bootstrap_candidates(distances)
return self.bootstrap_candidates(distances, allow_uncovered=True)

def full_proposal_view(self, un_adj_lut: NeighborhoodLUT):
return FullProposalDataView.from_truncated_spike_data(self, un_adj_lut)
Expand Down Expand Up @@ -4052,6 +4055,10 @@ def get_truncated_datasets(
device=device,
rgeom=prgeom[:-1],
)

gc.collect()
torch.cuda.empty_cache()

assert isinstance(noise, EmbeddedNoise)
noise.to(device=device)
assert noise.rank == feature_rank
Expand Down Expand Up @@ -4154,6 +4161,9 @@ def get_truncated_datasets(
else:
assert False

gc.collect()
torch.cuda.empty_cache()

return neighb_cov, erp, train_data, val_data, full_data, noise, train_ixs, val_ixs


Expand Down Expand Up @@ -5199,7 +5209,7 @@ def try_kmeans(
feature_rank: int,
n_iter: int = 100,
with_proportions: bool = True,
drop_prop: float = 0.025,
drop_prop: float = 0.0,
kmeanspp_initial="random",
n_kmeans_tries: int = 25,
n_kmeanspp_tries: int = 25,
Expand Down Expand Up @@ -5253,10 +5263,7 @@ def evaluate_group_demolitions(
train_scores: Scores,
eval_scores: Scores,
) -> GroupDemolition:
# grab relevant sections of train and eval scores
group_ = group.to(train_scores.candidates)
in_group_train = torch.isin(train_scores.candidates, group_).any(dim=1)
(in_group_train,) = in_group_train.nonzero(as_tuple=True)

# select which units in the group are candidates for demolition
if mean_train_resp is None:
Expand Down Expand Up @@ -6061,36 +6068,30 @@ def _count_candidates(candidates, batch_candidate_counts, batch_size):

if TORCH_IS_OLD:

@torch.jit.script
def _combine_similar_resps(resps: Tensor, keep_mask: Tensor, n_keep: int) -> Tensor:
n_discard = resps.shape[1] - n_keep
assert n_discard + n_keep == resps.shape[1]
discard_mask = torch.logical_not(keep_mask)
discard_ix = discard_mask.nonzero()[:, 0]
keep_ix = keep_mask.nonzero()[:, 0]
assert keep_ix.numel() + discard_ix.numel() == resps.shape[1]
kept_resp = resps[:, keep_ix]
discard_resp = resps[:, discard_ix]
sim = discard_resp.T @ kept_resp
match = sim.argmax(1)
kept_resp[:, match] += discard_resp
return kept_resp
def _nonzero_static(x: Tensor, size: int):
nz = x.nonzero()
assert nz.numel() == size
return nz
else:

@torch.jit.script
def _combine_similar_resps(resps: Tensor, keep_mask: Tensor, n_keep: int) -> Tensor:
n_discard = resps.shape[1] - n_keep
assert n_discard + n_keep == resps.shape[1]
discard_mask = torch.logical_not(keep_mask)
discard_ix = discard_mask.nonzero_static(size=n_discard)[:, 0]
keep_ix = keep_mask.nonzero_static(size=n_keep)[:, 0]
assert keep_ix.numel() + discard_ix.numel() == resps.shape[1]
kept_resp = resps[:, keep_ix]
discard_resp = resps[:, discard_ix]
sim = discard_resp.T @ kept_resp
match = sim.argmax(1)
kept_resp[:, match] += discard_resp
return kept_resp
def _nonzero_static(x: Tensor, size: int):
return x.nonzero_static(size=size)


@torch.jit.script
def _combine_similar_resps(resps: Tensor, keep_mask: Tensor, n_keep: int) -> Tensor:
n_discard = resps.shape[1] - n_keep
assert n_discard + n_keep == resps.shape[1]
discard_mask = torch.logical_not(keep_mask)
discard_ix = _nonzero_static(discard_mask, size=n_discard)[:, 0]
keep_ix = _nonzero_static(keep_mask, size=n_keep)[:, 0]
assert keep_ix.numel() + discard_ix.numel() == resps.shape[1]
kept_resp = resps[:, keep_ix]
discard_resp = resps[:, discard_ix]
sim = discard_resp.T @ kept_resp
match = sim.argmax(1)
kept_resp[:, match] += discard_resp
return kept_resp


def concatenate_scores(scoress: list[Scores]) -> Scores:
Expand Down Expand Up @@ -6203,10 +6204,7 @@ def mean_responsibilities(
rsum_batch.zero_()

nc = int(ncand[bix].item())
if TORCH_IS_OLD:
cii, cjj = (cand[i0:i1] >= 0).nonzero().T
else:
cii, cjj = (cand[i0:i1] >= 0).nonzero_static(size=nc).T
cii, cjj = _nonzero_static(cand[i0:i1] >= 0, size=nc).T

c = cand[i0:i1][cii, cjj]
r = resp[i0:i1][cii, cjj].double()
Expand Down Expand Up @@ -6493,8 +6491,8 @@ def _sparsify_candidates(
assert cpos.any(dim=1).all()
if pnoid and static_size is not None:
assert cpos.sum() == static_size
if static_size is not None and not TORCH_IS_OLD:
spike_ixs, candidate_ixs = cpos.nonzero_static(size=static_size).T
if static_size is not None:
spike_ixs, candidate_ixs = _nonzero_static(cpos, size=static_size).T
else:
spike_ixs, candidate_ixs = cpos.nonzero(as_tuple=True)
neighb_ixs = neighborhood_ids[spike_ixs]
Expand Down
14 changes: 7 additions & 7 deletions src/dartsort/evaluate/hybrid_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ..util.internal_config import ComputationConfig, unshifted_raw_template_cfg
from ..util.logging_util import progbar
from ..util.motion import MotionInfo
from ..util.py_util import resolve_path
from ..util.py_util import ensure_path
from . import analysis, comparison, simkit

logger = getLogger(__name__)
Expand Down Expand Up @@ -406,7 +406,7 @@ def load_dartsort_step_sortings(
use, although its not a guarantee... h5 locking... need to figure it out.
"""
mtime_dt = mtime_gap_minutes * 60 if mtime_gap_minutes else 0
sorting_dir = resolve_path(sorting_dir, strict=True)
sorting_dir = ensure_path(sorting_dir, strict=True)
if detection_h5_path is None:
for dh5n in detection_h5_names:
detection_h5_path = cast(Path, sorting_dir / dh5n)
Expand All @@ -416,12 +416,12 @@ def load_dartsort_step_sortings(
age = time.time() - detection_h5_path.stat().st_mtime
if age < mtime_dt:
continue
h5s = [resolve_path(detection_h5_path)]
h5s = [ensure_path(detection_h5_path)]
break
else:
h5s = []
else:
h5s = [resolve_path(detection_h5_path)]
h5s = [ensure_path(detection_h5_path)]

for j in range(1, 100):
mh5 = sorting_dir / f"matching{j}.h5"
Expand All @@ -430,7 +430,7 @@ def load_dartsort_step_sortings(
if mtime_dt:
if time.time() - mh5.stat().st_mtime < mtime_dt:
break
h5s.append(resolve_path(mh5))
h5s.append(ensure_path(mh5))

# let's check that there is at least something to do...
labels_npys = sorting_dir.glob("*_labels.npy")
Expand Down Expand Up @@ -464,9 +464,9 @@ def load_dartsort_step_sortings(
)
if no_npys and st0.labels is not None:
if hasattr(st0, "template_inds"):
yield(
yield (
name_formatter(f"{h5.stem}_template"),
st0.ephemeral_replace(labels=st0.template_inds)
st0.ephemeral_replace(labels=st0.template_inds),
)
yield name_formatter(h5.stem), st0
continue
Expand Down
18 changes: 11 additions & 7 deletions src/dartsort/evaluate/simkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from ..util.data_util import (
DARTsortSorting,
divide_randomly,
ensure_path,
extract_random_snips,
resolve_path,
)
from ..util.job_util import ensure_computation_config
from ..util.logging_util import get_logger, progbar
Expand Down Expand Up @@ -108,7 +108,7 @@ def generate_simulation(
pass

if noise_recording_folder is not None:
noise_recording_folder = resolve_path(noise_recording_folder)
noise_recording_folder = ensure_path(noise_recording_folder)
else:
assert noise_in_memory
duration_samples = int(duration_seconds * sampling_frequency)
Expand Down Expand Up @@ -146,7 +146,7 @@ def generate_simulation(
return

if folder is not None:
folder = resolve_path(folder)
folder = ensure_path(folder)
else:
assert no_save

Expand Down Expand Up @@ -214,7 +214,7 @@ def generate_simulation(


def load_simulation(folder):
folder = resolve_path(folder, strict=True)
folder = ensure_path(folder, strict=True)
recording_dir = folder / "recording"
templates_npz = folder / "templates.npz"
sorting_h5 = folder / "dartsort_sorting.h5"
Expand Down Expand Up @@ -252,7 +252,9 @@ def __init__(self, recording, **simulation_kwargs):
**simulation_kwargs,
)
)
self.segment = cast(InjectSpikesPreprocessorSegment, self._recording_segments[0])
self.segment = cast(
InjectSpikesPreprocessorSegment, self._recording_segments[0]
)

def basic_sorting(self) -> DARTsortSorting:
return self.segment.basic_sorting()
Expand Down Expand Up @@ -406,7 +408,9 @@ def save_features_to_hdf5(

# residual snippets
if n_residual_snips:
nrs_dset = h5.create_dataset("n_residuals", data=np.zeros((), dtype=np.int64))
nrs_dset = h5.create_dataset(
"n_residuals", data=np.zeros((), dtype=np.int64)
)
residual = h5.create_dataset(
"residual",
shape=(n_residual_snips, *self.segment.wf_shape),
Expand Down Expand Up @@ -486,7 +490,7 @@ def save_simulation(
save_collidedness=False,
chunk_len_s=0.5,
):
folder = resolve_path(folder)
folder = ensure_path(folder)
folder.mkdir(exist_ok=True)
recording_dir = folder / "recording"
templates_npz = folder / "templates.npz"
Expand Down
Loading
Loading