Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,11 @@ Figure 7 can be reproduce using:

Figure 8 can be reproduce using:

* TODO Chris
* TODO Chris
* `cd notebooks/real_data_figure` : move to the folder where the environment is defined
* `uv run sort_all_real_data.py` : Sort three datasets using four sorters
* `uv run generate_curation_data.py` : Use UnitRefine, Bombcell and SLAy to curate the sorting output
* `uv run make_drift_plots.py` : Make the drift and probe plots
* `figure.tpy` : Generate the plot


Supplementary figure can be reproduce using::
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
106 changes: 106 additions & 0 deletions notebooks/real_data_figure/figure.typ
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#set page("us-letter")
#set text(size: 9pt, font: "New Computer Modern")

#table(
columns: (1fr, 0.087fr, 0.55fr),
// Text fits content, SVGs split remaining space
align: center,
gutter: 0pt,
stroke: none,

box(width: 100%)[
#table(
columns: (1.3fr, 1fr, 1fr, 1fr, 1fr, 1fr, 1.2fr, 1fr, 1fr),
inset: 4pt,
stroke: 0.5pt + gray,
fill: (x, y) => {
if x == 2 or x == 3 {
green.lighten(85%)
} else if x == 4 or x == 5 or x == 6 {
orange.lighten(85%)
} else if x == 7 or x == 8 {
red.lighten(85%)
} else if x == 0 or x == 1 {
gray.lighten(80%)
}
},
[Sorter], [Tot units], [BC good], [UR sua], [BC mua], [UR mua], [SLAy merges], [BC noise], [UR noise],
)
],
align()[],
align()[],

// --- Row 1 ---

box(width: 100%)[
*Lebedeva et. al., Chronic, NP2.0, 38 mins*
#table(
columns: (1.3fr, 1fr, 1fr, 1fr, 1fr, 1fr, 1.2fr, 1fr, 1fr),
inset: 4pt,
stroke: 0.5pt + gray,
fill: (x, y) => {
if x == 2 or x == 3 {
green.lighten(85%)
} else if x == 4 or x == 5 or x == 6 {
orange.lighten(85%)
} else if x == 7 or x == 8 {
red.lighten(85%)
}
},
[KS4], [808], [246], [288], [489], [312], [5], [73], [208],
[Lupin], [683], [216], [259], [460], [338], [2], [7], [86],
[TCD2], [617], [119], [246], [491], [369], [13], [7], [2],
[SC2], [270], [102], [133], [166], [137], [1], [2], [0],
)
],
align()[#image("drift_maps_and_probes/ucl_probe.png", height: 12.6%)],
image("drift_maps_and_probes/ucl_drift.svg"),
// --- Row 2 ---
box(width: 100%)[
*IBL, Acute, NP1.0, 67 mins*
#table(
columns: (1.3fr, 1fr, 1fr, 1fr, 1fr, 1fr, 1.2fr, 1fr, 1fr),
inset: 4pt,
stroke: 0.5pt + gray,
fill: (x, y) => {
if x == 2 or x == 3 {
green.lighten(85%)
} else if x == 4 or x == 5 or x == 6 {
orange.lighten(85%)
} else if x == 7 or x == 8 {
red.lighten(85%)
}
},
[KS4], [1050], [210], [459], [673], [354], [24], [167], [237],
[Lupin], [864], [209], [379], [601], [278], [6], [54], [207],
[TDC2], [954], [124], [417], [778], [504], [33], [52], [33],
[SC2], [458], [97], [170], [333], [271], [0], [28], [17],
)
],
align()[#image("drift_maps_and_probes/IBL_probe.png", height: 11.9%)],
image("drift_maps_and_probes/IBL_drift.svg"),
// --- Row 3 ---
box(width: 100%)[
*Duszkiewicz et. al., Chronic, CN 156H5, 211 mins*
#table(
columns: (1.3fr, 1fr, 1fr, 1fr, 1fr, 1fr, 1.2fr, 1fr, 1fr),
inset: 4pt,
stroke: 0.5pt + gray,
fill: (x, y) => {
if x == 2 or x == 3 {
green.lighten(85%)
} else if x == 4 or x == 5 or x == 6 {
orange.lighten(85%)
} else if x == 7 or x == 8 {
red.lighten(85%)
}
},
[KS4], [174], [41], [71], [98], [68], [2], [35], [35],
[Lupin], [162], [56], [96], [103], [63], [4], [3], [3],
[TDC2], [191], [11], [60], [180], [128], [9], [0], [3],
[SC2], [58], [4], [9], [53], [47], [0], [1], [2],
)
],
align()[#image("drift_maps_and_probes/Duszkiewicz_probe.png", height: 12.3%)],
image("drift_maps_and_probes/Duszkiewicz_drift.svg"),
)
73 changes: 73 additions & 0 deletions notebooks/real_data_figure/generate_curation_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
Generates curation results from computed analyzers.

Once you have the analyzers, run this code by `cd`ing into the `real_data_figure`
folder, then running
>>> uv run generate_curation_data.py

"""

import spikeinterface.full as si
import numpy as np
import pandas as pd
from pathlib import Path

repo_folder = Path("/home/nolanlab/fromgit/sorting_components_benchmark_paper/")
real_data_figure_folder = repo_folder / "notebooks/real_data_figure"
analyzers_folder = real_data_figure_folder / "analyzers"

dataset_protocols = {
'IBL': ['kilosort4_motion_correction', 'lupin_motion_correction', 'tridesclous2_motion_correction','spykingcircus2_motion_correction'],
'ucl': ['kilosort4_no_motion_correction', 'lupin_no_motion_correction', 'tridesclous2_no_motion_correction','spykingcircus2_no_motion_correction'],
'Duszkiewicz': ['kilosort4_no_motion_correction', 'lupin_no_motion_correction', 'tridesclous2_no_motion_correction','spykingcircus2_no_motion_correction'],
}

bombcell_labels = ['good', 'mua', 'noise', 'non_soma_good', 'non_soma_mua']
unitrefine_labels = ['sua', 'mua', 'noise']
merge_presets = ['slay']

for dataset_name, protocols in dataset_protocols.items():
bombcell_results = []
unitrefine_results = []
all_protocols_data = []
for protocol in protocols:

analyzer_path = analyzers_folder / f"{dataset_name}_{protocol}_analyzer"
if analyzer_path.is_dir():
analyzer = si.load_sorting_analyzer(analyzer_path)
else:
analyzer = si.load_sorting_analyzer(str(analyzer_path) + '.zarr')

bombcell_unit_label = si.bombcell_label_units(analyzer, split_non_somatic_good_mua=True)['bombcell_label'].values
bombcell_results = {label: np.sum(bombcell_unit_label == label) for label in bombcell_labels}

# You need to donwload the UnitRefine models `noise_neural_classifier_lightweight` and `sua_mua_classifier_lightweight` from
# https://huggingface.co/AnoushkaJain3
unitrefine_unit_label = si.unitrefine_label_units(analyzer, noise_neural_classifier='/home/nolanlab/Downloads/noise_neural_classifier_lightweight', sua_mua_classifier='/home/nolanlab/Downloads/sua_mua_classifier_lightweight')
unitrefine_results = {label: np.sum(unitrefine_unit_label['unitrefine_label'] == label) for label in unitrefine_labels}

merge_results = {merge_preset: len(si.compute_merge_unit_groups(analyzer, preset=merge_preset)) for merge_preset in merge_presets}

protocol_data = [
protocol,
analyzer.get_num_units(),
bombcell_results['good'] + bombcell_results['non_soma_good'],
unitrefine_results['sua'],
bombcell_results['mua'] + bombcell_results['non_soma_mua'],
unitrefine_results['mua'],
merge_results['slay'],
bombcell_results['noise'],
unitrefine_results['noise'],
]

all_protocols_data.append(protocol_data)

results = pd.DataFrame(all_protocols_data, columns=["sorter", "total units", "bombcell good", "unitrefine sua", "bombcell mua", "unitrefine mua", "# slay merges", "bombcell noise", "unitrefine noise"], index=None)

results.to_csv(real_data_figure_folder / f"curation_results/{dataset_name}_results.csv", index=False)

# render for typst rendering
for row in results.iterrows():
for cell in row[1]:
print(f"[{cell}], ", end="")
print("")
111 changes: 111 additions & 0 deletions notebooks/real_data_figure/make_drift_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import spikeinterface.full as si

from pathlib import Path

repo_folder = Path("/home/nolanlab/fromgit/sorting_components_benchmark_paper/")
real_data_figure_folder = repo_folder / "notebooks/real_data_figure"
analyzers_folder = real_data_figure_folder / "analyzers"
drift_maps_folder = real_data_figure_folder / "drift_maps_and_probes"

bombcell_labels = ['good', 'mua', 'noise', 'non_soma_good', 'non_soma_mua']

protocol = 'no_motion_correction'

FONT_SIZE = 18

plotting_settings = {
'ucl': {
'protocol': 'no_motion_correction',
'vmin': -600,
'scatter_decimate': 20,
'cbar_ticks': [-600,-500,-400,-300,-200,-100,0],
'cbar_ticklabels': ['600','','400','','200','','0'],
'yticklabels': ['','2.9', '', '3.1', '', '3.3', '', '3.5'],
'xticks_s': [0,600,1200,1800],
},
'IBL': {
'protocol': 'motion_correction',
'vmin': -457.829994,
'scatter_decimate': 20,
'cbar_ticks': [-600,-500,-400,-300,-200,-100,0],
'cbar_ticklabels': ['600','','400','','200','','0'],
'yticklabels': ['', '0', '', '1', '', '2', '', '3', ''],
'xticks_s': [0,900,1800,2700,3600],
},
'Duszkiewicz': {
'protocol': 'no_motion_correction',
'vmin': -380,
'scatter_decimate': 5,
'cbar_ticks': [-400,-300,-200,-100,0],
'cbar_ticklabels': [400,300,200,100,0],
'yticklabels': ['', '', '0.2', '', '0.4', '', '0.6', '', '0.8'],
'xticks_s': [0,3000,6000,9000,12000],
}
}

for dataset_name, dataset_settings in plotting_settings.items():

protocol = dataset_settings['protocol']

analyzer_path = analyzers_folder / f'{dataset_name}_kilosort4_{protocol}_analyzer'
if analyzer_path.is_dir():
analyzer = si.load_sorting_analyzer(analyzer_path)
else:
analyzer = si.load_sorting_analyzer(str(analyzer_path) + '.zarr')

print(analyzer.get_total_duration())

bombcell_unit_labels = si.bombcell_label_units(analyzer, split_non_somatic_good_mua=True)['bombcell_label'].values
good_units = analyzer.unit_ids[bombcell_unit_labels == 'good']
analyzer_good = analyzer.select_units(good_units)

cmap_name = 'inferno'

fig = si.plot_drift_raster_map(
sorting_analyzer=analyzer_good,
cmap=cmap_name,
alpha=0.10,
scatter_decimate=dataset_settings['scatter_decimate'],
figsize=(8,4.5)
)
# 1. Define your parameters
vmin = dataset_settings['vmin']
vmax = 0

# 2. Create the Normalization and Mappable objects
norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
sm = cm.ScalarMappable(cmap=plt.get_cmap(cmap_name), norm=norm)
sm.set_array([]) # Required for the colorbar to initialize correctly

# 3. Access your existing figure/axes and add the colorbar
# Assuming 'fig' is your figure object
ax = fig.figure.get_axes()[0]
cbar = fig.figure.colorbar(sm, ax=ax)

# 2. Find the scatter plot and rasterize it
# In Matplotlib, scatter plots are usually 'PathCollection' objects
for artist in ax.get_children():
if isinstance(artist, plt.matplotlib.collections.PathCollection):
artist.set_rasterized(True)

# 4. (Optional) Add a label
cbar_ticks = dataset_settings['cbar_ticks']
#cbar_ticks = [-80,-60,-40,-20,0]
cbar.set_label('Abs peak amplitude [uV]', fontsize=FONT_SIZE)
cbar.set_ticklabels(dataset_settings['cbar_ticklabels'])
cbar.ax.tick_params(labelsize=FONT_SIZE) # Font size for colorbar ticks

ax.set_ylabel('Depth [mm]', fontsize=FONT_SIZE)
ax.set_yticklabels(dataset_settings['yticklabels'], fontsize=FONT_SIZE)

xticks_s = dataset_settings['xticks_s']
ax.set_xticks(xticks_s)
ax.set_xticklabels([int(xtick/60) for xtick in xticks_s], fontsize=FONT_SIZE)
ax.set_xlabel('Time [min]', fontsize=FONT_SIZE)

ax.set_title(label=None)

fig.figure.savefig(drift_maps_folder / f'{dataset_name}_drift.svg', bbox_inches='tight')
24 changes: 24 additions & 0 deletions notebooks/real_data_figure/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[project]
name = "real-data-figure"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"h5py>=3.16.0",
"hdbscan>=0.8.42",
"ipykernel>=7.2.0",
"kilosort==4.1.2",
"matplotlib>=3.10.9",
"numba>=0.65.1",
"pandas>=3.0.2",
"scikit-learn==1.6",
"scipy>=1.17.1",
"skops>=0.14.0",
"spikeinterface>=0.104.3",
]

[dependency-groups]
dev = [
"ruff>=0.15.12",
]
Loading