diff --git a/spikeinterface_gui/basescatterview.py b/spikeinterface_gui/basescatterview.py index c9754f9..249ea65 100644 --- a/spikeinterface_gui/basescatterview.py +++ b/spikeinterface_gui/basescatterview.py @@ -56,7 +56,7 @@ def get_unit_data(self, unit_id, segment_index=0): return spike_times, spike_data, np.array([1]), np.array([ymin, ymax]), ymin, ymax, inds # avoid clear outliers in the plot and histogram by using percentiles - ymin, ymax = np.percentile(spike_data, [self.settings['display_low_percentiles'], self.settings['display_high_percentiles']]) + ymin, ymax = np.percentile(spike_data[~np.isnan(spike_data)], [self.settings['display_low_percentiles'], self.settings['display_high_percentiles']]) min_bin_size = np.min(np.diff(np.unique(spike_data))) bins = np.linspace(ymin, ymax, self.settings['num_bins']) # if bins are too small, adjust the number of bins to ensure a minimum bin size and avoid jumps in the histogram @@ -329,8 +329,8 @@ def _qt_refresh(self, set_scatter_range=False): # set x range to time range of the current segment for scatter, and max count for histogram # set y range to min and max of visible spike amplitudes if len(ymins) > 0 and (set_scatter_range or not self._first_refresh_done): - ymin = np.min(ymins) - ymax = np.max(ymaxs) + ymin = np.nanmin(ymins) + ymax = np.nanmax(ymaxs) t_start, t_stop = self.controller.get_t_start_t_stop() self.viewBox.setXRange(t_start, t_stop, padding = 0.0) self.viewBox.setYRange(ymin, ymax, padding = 0.0) diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index 6f3f60c..342368a 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -1,4 +1,5 @@ import time +from copy import deepcopy import numpy as np @@ -10,9 +11,13 @@ from spikeinterface import compute_sparsity from spikeinterface.core import get_template_extremum_channel, BaseEvent from spikeinterface.core.sorting_tools import spike_vector_to_indices -from spikeinterface.curation import validate_curation_dict +from spikeinterface.core.core_tools import check_json +from spikeinterface.curation import validate_curation_dict, apply_curation from spikeinterface.curation.curation_model import Curation from spikeinterface.widgets.utils import make_units_table_from_analyzer +from spikeinterface.widgets.utils import make_units_table_from_analyzer + +from .utils_global import add_new_unit_ids_to_curation_dict from .curation_tools import add_merge, default_label_definitions, empty_curation_data from .event_tools import parse_events @@ -25,7 +30,9 @@ _default_main_settings = dict( max_visible_units=10, color_mode='color_by_unit', - use_times=False + use_times=False, + merge_new_id_strategy = 'take_first', + split_new_id_strategy = 'append', ) from spikeinterface.widgets.sorting_summary import _default_displayed_unit_properties @@ -40,6 +47,7 @@ def __init__( verbose=False, save_on_compute=False, curation=False, + iterative_curation=False, curation_data=None, label_definitions=None, with_traces=True, @@ -60,6 +68,10 @@ def __init__( self.backend = backend self.disable_save_settings_button = disable_save_settings_button self.current_curation_saved = True + self.applied_curations = [] + + if extra_unit_properties is None: + self.extra_unit_properties_names = [] self.external_data = external_data if self.backend == "qt": @@ -72,19 +84,43 @@ def __init__( self.with_traces = with_traces - self.analyzer = analyzer - assert self.analyzer.get_extension("random_spikes") is not None - - self.return_in_uV = self.analyzer.return_in_uV self.save_on_compute = save_on_compute self.verbose = verbose - t0 = time.perf_counter() + self.original_analyzer = None self.main_settings = _default_main_settings.copy() if user_main_settings is not None: self.main_settings.update(user_main_settings) + self.set_analyzer_info(analyzer) + self.units_table = make_units_table_from_analyzer(self.analyzer, extra_properties=extra_unit_properties) + + self.set_curation_info(curation, iterative_curation, curation_data, label_definitions, curation_callback, curation_callback_kwargs) + + # parse events + self.events = None + if events is not None: + self.events = parse_events(events, self, verbose=verbose) + if len(self.events) == 0: + self.events = None + + if displayed_unit_properties is None: + displayed_unit_properties = list(_default_displayed_unit_properties) + if extra_unit_properties is not None: + self.extra_unit_properties_names = list(extra_unit_properties.keys()) + displayed_unit_properties += self.extra_unit_properties_names + displayed_unit_properties = [v for v in displayed_unit_properties if v in self.units_table.columns] + self.displayed_unit_properties = displayed_unit_properties + + def set_analyzer_info(self, analyzer): + + self.analyzer = analyzer + assert self.analyzer.get_extension("random_spikes") is not None + + self.return_in_uV = self.analyzer.return_in_uV + t0 = time.perf_counter() + self.num_channels = self.analyzer.get_num_channels() # this now private and should be access using function self._visible_unit_ids = [self.unit_ids[0]] @@ -98,7 +134,7 @@ def __init__( self.analyzer_sparsity = self.analyzer.sparsity # Mandatory extensions: computation forced - if verbose: + if self.verbose: print('\tLoading templates') temp_ext = self.analyzer.get_extension("templates") if temp_ext is None: @@ -112,7 +148,7 @@ def __init__( else: self.templates_std = None - if verbose: + if self.verbose: print('\tLoading unit_locations') ext = analyzer.get_extension('unit_locations') if ext is None: @@ -122,7 +158,7 @@ def __init__( self.unit_positions = ext.get_data()[:, :2] # Optional extensions : can be None or skipped - if verbose: + if self.verbose: print('\tLoading noise_levels') ext = analyzer.get_extension('noise_levels') if ext is None and self.has_extension('recording'): @@ -130,12 +166,12 @@ def __init__( ext = analyzer.compute_one_extension('noise_levels') self.noise_levels = ext.get_data() if ext is not None else None - if "quality_metrics" in skip_extensions: + if "quality_metrics" in self.skip_extensions: if self.verbose: print('\tSkipping quality_metrics') self.metrics = None else: - if verbose: + if self.verbose: print('\tLoading quality_metrics') qm_ext = analyzer.get_extension('quality_metrics') if qm_ext is not None: @@ -143,12 +179,12 @@ def __init__( else: self.metrics = None - if "spike_amplitudes" in skip_extensions: + if "spike_amplitudes" in self.skip_extensions: if self.verbose: print('\tSkipping spike_amplitudes') self.spike_amplitudes = None else: - if verbose: + if self.verbose: print('\tLoading spike_amplitudes') sa_ext = analyzer.get_extension('spike_amplitudes') if sa_ext is not None: @@ -156,12 +192,12 @@ def __init__( else: self.spike_amplitudes = None - if "amplitude_scalings" in skip_extensions: + if "amplitude_scalings" in self.skip_extensions: if self.verbose: print('\tSkipping amplitude_scalings') self.amplitude_scalings = None else: - if verbose: + if self.verbose: print('\tLoading amplitude_scalings') sa_ext = analyzer.get_extension('amplitude_scalings') if sa_ext is not None: @@ -169,12 +205,12 @@ def __init__( else: self.amplitude_scalings = None - if "spike_locations" in skip_extensions: + if "spike_locations" in self.skip_extensions: if self.verbose: print('\tSkipping spike_locations') self.spike_depths = None else: - if verbose: + if self.verbose: print('\tLoading spike_locations') sl_ext = analyzer.get_extension('spike_locations') if sl_ext is not None: @@ -182,13 +218,13 @@ def __init__( else: self.spike_depths = None - if "correlograms" in skip_extensions: + if "correlograms" in self.skip_extensions: if self.verbose: print('\tSkipping correlograms') self.correlograms = None self.correlograms_bins = None else: - if verbose: + if self.verbose: print('\tLoading correlograms') ccg_ext = analyzer.get_extension('correlograms') if ccg_ext is not None: @@ -196,13 +232,13 @@ def __init__( else: self.correlograms, self.correlograms_bins = None, None - if "isi_histograms" in skip_extensions: + if "isi_histograms" in self.skip_extensions: if self.verbose: print('\tSkipping isi_histograms') self.isi_histograms = None self.isi_bins = None else: - if verbose: + if self.verbose: print('\tLoading isi_histograms') isi_ext = analyzer.get_extension('isi_histograms') if isi_ext is not None: @@ -211,11 +247,11 @@ def __init__( self.isi_histograms, self.isi_bins = None, None self._similarity_by_method = {} - if "template_similarity" in skip_extensions: + if "template_similarity" in self.skip_extensions: if self.verbose: print('\tSkipping template_similarity') else: - if verbose: + if self.verbose: print('\tLoading template_similarity') ts_ext = analyzer.get_extension('template_similarity') if ts_ext is not None: @@ -228,12 +264,12 @@ def __init__( ts_ext = analyzer.compute_one_extension('template_similarity', method=method, save=save_on_compute) self._similarity_by_method[method] = ts_ext.get_data() - if "waveforms" in skip_extensions: + if "waveforms" in self.skip_extensions: if self.verbose: print('\tSkipping waveforms') self.waveforms_ext = None else: - if verbose: + if self.verbose: print('\tLoading waveforms') wf_ext = analyzer.get_extension('waveforms') if wf_ext is not None: @@ -241,12 +277,12 @@ def __init__( else: self.waveforms_ext = None self._pc_projections = None - if "principal_components" in skip_extensions: + if "principal_components" in self.skip_extensions: if self.verbose: print('\tSkipping principal_components') self.pc_ext = None else: - if verbose: + if self.verbose: print('\tLoading principal_components') pc_ext = analyzer.get_extension('principal_components') self.pc_ext = pc_ext @@ -262,15 +298,8 @@ def __init__( self.num_segments = self.analyzer.get_num_segments() self.sampling_frequency = self.analyzer.sampling_frequency - # parse events - self.events = None - if events is not None: - self.events = parse_events(events, self, verbose=verbose) - if len(self.events) == 0: - self.events = None - t1 = time.perf_counter() - if verbose: + if self.verbose: print('Loading extensions took', t1 - t0) t0 = time.perf_counter() @@ -332,7 +361,7 @@ def __init__( self._spike_index_by_units[unit_id] = np.concatenate(inds) t1 = time.perf_counter() - if verbose: + if self.verbose: print('Gathering all spikes took', t1 - t0) self._spike_visible_indices = np.array([], dtype='int64') @@ -341,22 +370,20 @@ def __init__( self._traces_cached = {} - self.units_table = make_units_table_from_analyzer(analyzer, extra_properties=extra_unit_properties) - - if displayed_unit_properties is None: - displayed_unit_properties = list(_default_displayed_unit_properties) - if extra_unit_properties is not None: - displayed_unit_properties += list(extra_unit_properties.keys()) - displayed_unit_properties = [v for v in displayed_unit_properties if v in self.units_table.columns] - self.displayed_unit_properties = displayed_unit_properties - # set default time info self.update_time_info() + + def set_curation_info(self, curation, iterative_curation, curation_data, label_definitions, curation_callback, curation_callback_kwargs): + self.iterative_curation = iterative_curation + if self.iterative_curation: + curation = True self.curation = curation self.curation_callback = curation_callback self.curation_callback_kwargs = curation_callback_kwargs + self._potential_merges = None + # TODO: Reload the dictionary if it already exists if self.curation: # rules: # * if user sends curation_data, then it is used @@ -375,6 +402,24 @@ def __init__( except Exception as e: raise ValueError(f"Invalid curation data.\nError: {e}") + if curation_data.get("merges") is None: + curation_data["merges"] = [] + else: + # here we reset the merges for better formatting (str) + existing_merges = curation_data["merges"] + new_merges = [] + for m in existing_merges: + if "unit_ids" not in m: + continue + if len(m["unit_ids"]) < 2: + continue + new_merges = add_merge(new_merges, m["unit_ids"]) + curation_data["merges"] = new_merges + if curation_data.get("splits") is None: + curation_data["splits"] = [] + if curation_data.get("removed") is None: + curation_data["removed"] = [] + elif self.analyzer.format == "binary_folder": json_file = self.analyzer.folder / "spikeinterface_gui" / "curation_data.json" if json_file.exists(): @@ -390,26 +435,23 @@ def __init__( if curation_data is None: curation_data = deepcopy(empty_curation_data) curation_data["unit_ids"] = self.unit_ids.tolist() + curation_data["label_definitions"] = default_label_definitions.copy() - if "label_definitions" not in curation_data: + self.curation_data = curation_data + + if "label_definitions" not in self.curation_data: if label_definitions is not None: - curation_data["label_definitions"] = label_definitions - else: - curation_data["label_definitions"] = default_label_definitions.copy() + self.curation_data["label_definitions"] = label_definitions - # This will enable the default shortcuts if has default quality labels self.has_default_quality_labels = False - if "quality" in curation_data["label_definitions"]: - curation_dict_quality_labels = curation_data["label_definitions"]["quality"]["label_options"] + if "quality" in self.curation_data["label_definitions"]: + curation_dict_quality_labels = self.curation_data["label_definitions"]["quality"]["label_options"] default_quality_labels = default_label_definitions["quality"]["label_options"] if set(curation_dict_quality_labels) == set(default_quality_labels): if self.verbose: print('Curation quality labels are the default ones') self.has_default_quality_labels = True - curation_data = Curation(**curation_data).model_dump() - self.curation_data = curation_data - def check_is_view_possible(self, view_name): from .viewlist import get_all_possible_views possible_class_views = get_all_possible_views() @@ -548,15 +590,22 @@ def get_information_txt(self): return txt - def refresh_colors(self): + def refresh_colors(self, existing_colors=None): if self.backend == "qt": self._cached_qcolors = {} elif self.backend == "panel": pass if self.main_settings['color_mode'] == 'color_by_unit': - self.colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar', - shuffle=True, seed=42) + unit_colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar', + shuffle=True, seed=42) + if existing_colors is None: + self.colors = unit_colors + else: + for unit_id, unit_color in unit_colors.items(): + if unit_id not in self.colors.keys(): + self.colors[unit_id] = unit_color + elif self.main_settings['color_mode'] == 'color_only_visible': unit_colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar', shuffle=True, seed=42) @@ -861,9 +910,6 @@ def compute_isi_histograms(self, window_ms, bin_ms): self.isi_histograms, self.isi_bins = ext.get_data() return self.isi_histograms, self.isi_bins - def get_units_table(self): - return self.units_table - def compute_auto_merge(self, **params): from spikeinterface.curation import compute_merge_unit_groups @@ -880,14 +926,58 @@ def compute_auto_merge(self, **params): def curation_can_be_saved(self): return self.analyzer.format != "memory" - def construct_final_curation(self): + def construct_final_curation(self, with_explicit_new_unit_ids=False): d = dict() d["format_version"] = "2" d["unit_ids"] = self.unit_ids.tolist() d.update(self.curation_data.copy()) + if with_explicit_new_unit_ids: + split_new_id_strategy = self.main_settings.get('split_new_id_strategy') + merge_new_id_strategy = self.main_settings.get('merge_new_id_strategy') + d = add_new_unit_ids_to_curation_dict(d, self.analyzer.sorting, split_new_id_strategy=split_new_id_strategy, merge_new_id_strategy=merge_new_id_strategy) + model = Curation(**d) return model + def apply_curation(self): + + if self.original_analyzer is None: + self.original_analyzer = deepcopy(self.analyzer) + self.original_analyzer.extensions = {} + + curation = self.construct_final_curation(with_explicit_new_unit_ids=True) + curated_analyzer = apply_curation(self.analyzer, curation) + + self.applied_curations.append(curation) + self.remove_curation(curated_analyzer) + + self.set_analyzer_info(curated_analyzer) + + # for now, don't show externally provided properties after curation + self.displayed_unit_properties = [displayed_property for displayed_property in self.displayed_unit_properties if displayed_property not in self.extra_unit_properties_names] + self.units_table = make_units_table_from_analyzer(self.analyzer) + self.refresh_colors(existing_colors=self.colors) + + for view in self.views: + view.reinitialize() + + def remove_curation(self, curated_analyzer): + """Removes curation from the controller, retaining quality labels.""" + + curation_data = deepcopy(empty_curation_data) + # retain label definitions and 'quality' label + label_definitioins = self.curation_data.get("label_definitions", None) + curation_data["label_definitions"] = label_definitioins + + if (quality_labels := curated_analyzer.get_sorting_property('quality')) is not None: + manual_labels = [] + for unit_id, quality_label in zip(curated_analyzer.unit_ids, quality_labels): + manual_labels.append({'unit_id': unit_id, 'labels': {'quality': [quality_label]}}) + + curation_data['manual_labels'] = manual_labels + + self.curation_data = curation_data + def set_curation_data(self, curation_data): print("Setting curation data") new_curation_data = empty_curation_data.copy() diff --git a/spikeinterface_gui/correlogramview.py b/spikeinterface_gui/correlogramview.py index 9ca6fa6..8e2d585 100644 --- a/spikeinterface_gui/correlogramview.py +++ b/spikeinterface_gui/correlogramview.py @@ -48,6 +48,10 @@ def _qt_make_layout(self): self.grid = pg.GraphicsLayoutWidget() self.layout.addWidget(self.grid) + def _reinitialize(self): + self.ccg, self.bins = self.controller.get_correlograms() + self.figure_cache = {} + self._refresh() def _qt_refresh(self): import pyqtgraph as pg diff --git a/spikeinterface_gui/curationview.py b/spikeinterface_gui/curationview.py index 21e66c7..2e5980f 100644 --- a/spikeinterface_gui/curationview.py +++ b/spikeinterface_gui/curationview.py @@ -1,10 +1,8 @@ -import json from pathlib import Path from .view_base import ViewBase -from spikeinterface.core.core_tools import check_json - +from spikeinterface.curation.curation_model import SequentialCuration class CurationView(ViewBase): id = "curation" @@ -70,10 +68,17 @@ def _qt_make_layout(self): but = QT.QPushButton("Save curation") tb.addWidget(but) but.clicked.connect(self.controller.save_curation_callback) - elif self.controller.curation_can_be_saved(): + + elif self.controller.curation_can_be_saved() and not self.controller.iterative_curation: but = QT.QPushButton("Save in analyzer") tb.addWidget(but) but.clicked.connect(self.controller.save_curation_in_analyzer) + + if self.controller.iterative_curation: + but_apply = QT.QPushButton("Apply curation") + tb.addWidget(but_apply) + but_apply.clicked.connect(self.apply_curation_to_analyzer) + but = QT.QPushButton("Export JSON") but.clicked.connect(self._qt_export_json) tb.addWidget(but) @@ -277,6 +282,10 @@ def _qt_on_unit_visibility_changed(self): def on_manual_curation_updated(self): self.refresh() + def apply_curation_to_analyzer(self): + with self.busy_cursor(): + self.controller.apply_curation() + def _qt_export_json(self): from .myqt import QT @@ -286,10 +295,20 @@ def _qt_export_json(self): fd.setViewMode(QT.QFileDialog.Detail) if fd.exec_(): json_file = Path(fd.selectedFiles()[0]) - curation_model = self.controller.construct_final_curation() - with json_file.open("w") as f: - f.write(curation_model.model_dump_json(indent=4)) - self.controller.current_curation_saved = True + if len(self.controller.applied_curations) == 0: + curation_model = self.controller.construct_final_curation() + with json_file.open("w") as f: + f.write(curation_model.model_dump_json(indent=4)) + self.controller.current_curation_saved = True + else: + current_curation_model = self.controller.construct_final_curation() + applied_curations = self.controller.applied_curations + current_and_applied_curations = applied_curations + [current_curation_model] + + sequential_curation_model = SequentialCuration(curation_steps=current_and_applied_curations) + with json_file.open("w") as f: + f.write(sequential_curation_model.model_dump_json(indent=4)) + self.controller.current_curation_saved = True # PANEL def _panel_make_layout(self): @@ -350,22 +369,45 @@ def _panel_make_layout(self): ) # Create buttons + + save_buttons = [] + + if self.controller.iterative_curation: + apply_button = pn.widgets.Button( + name="Apply curation", + button_type="primary", + height=30 + ) + apply_button.on_click(self._panel_apply_curation_to_analyzer) + save_buttons.append(apply_button) + + if not self.controller.iterative_curation and self.controller.curation_callback is None: + save_button_name = "Save in analyzer" + save_button_callback = self._panel_save_in_analyzer + save_button = pn.widgets.Button( + name=save_button_name, + button_type="primary", + height=30 + ) + save_button.on_click(save_button_callback) + save_buttons.append(save_button) + if self.controller.curation_callback is not None: save_button_name = "Save curation" save_button_callback = self._panel_save_curation_callback - else: - save_button_name = "Save in analyzer" - save_button_callback = self._panel_save_in_analyzer - save_button = pn.widgets.Button( - name=save_button_name, - button_type="primary", - height=30 - ) - save_button.on_click(save_button_callback) + save_button = pn.widgets.Button( + name=save_button_name, + button_type="primary", + height=30 + ) + save_button.on_click(save_button_callback) + save_buttons.append(save_button) + download_button = pn.widgets.FileDownload( button_type="primary", filename="curation.json", callback=self._panel_generate_json, height=30 ) + save_buttons.append(download_button) restore_button = pn.widgets.Button(name="Restore", button_type="primary", height=30) restore_button.on_click(self._panel_restore_units) @@ -378,8 +420,7 @@ def _panel_make_layout(self): # Create layout buttons_save = pn.Row( - save_button, - download_button, + *save_buttons, sizing_mode="stretch_width", ) save_sections = pn.Column( @@ -450,7 +491,7 @@ def _panel_refresh(self): def _panel_ensure_save_warning_message(self): - if self.layout[0].name == "curation_save_warning": + if self.layout[0].name == "curation_save_warning" or self.layout[0].name == "busy...": return import panel as pn @@ -495,6 +536,9 @@ def _panel_restore_units(self, event): def _panel_unmerge(self, event): self.unmerge() + def _panel_apply_curation_to_analyzer(self, event): + self.apply_curation_to_analyzer() + def _panel_unsplit(self, event): self.unsplit() diff --git a/spikeinterface_gui/isiview.py b/spikeinterface_gui/isiview.py index f9fa293..c894ce0 100644 --- a/spikeinterface_gui/isiview.py +++ b/spikeinterface_gui/isiview.py @@ -25,6 +25,10 @@ def _on_settings_changed(self): self.isi_histograms, self.isi_bins = None, None self.refresh() + def _reinitialize(self): + self.isi_histograms, self.isi_bins = self.controller.get_isi_histograms() + self._refresh() + ## QT ## def _qt_make_layout(self): diff --git a/spikeinterface_gui/main.py b/spikeinterface_gui/main.py index 050b614..6e4e4a9 100644 --- a/spikeinterface_gui/main.py +++ b/spikeinterface_gui/main.py @@ -21,6 +21,7 @@ def run_mainwindow( mode: str = "desktop", with_traces: bool = True, curation: bool = False, + iterative_curation: bool = False, curation_dict: dict | None = None, label_definitions: dict | None = None, displayed_unit_properties: list | None=None, @@ -57,6 +58,8 @@ def run_mainwindow( If True, traces are displayed curation: bool, default: False If True, the curation panel is displayed + iterative_curation: bool, default: False + If True, a user can iteratively curate their analyzer curation_dict: dict | None, default: None The curation dictionary to start from an existing curation label_definitions: dict | None, default: None @@ -154,6 +157,7 @@ def run_mainwindow( backend=backend, verbose=verbose, curation=curation, + iterative_curation=iterative_curation, curation_data=curation_dict, label_definitions=label_definitions, with_traces=with_traces, @@ -313,6 +317,7 @@ def run_mainwindow_cli(): parser.add_argument('--mode', help='Mode desktop or web', default='desktop') parser.add_argument('--no-traces', help='Do not show traces', action='store_true', default=False) parser.add_argument('--curation', help='Enable curation panel', action='store_true', default=False) + parser.add_argument('--iterative_curation', help='Enable iterative curation', action='store_true', default=False) parser.add_argument('--recording', help='Path to a recording file (.json/.pkl) or folder that can be loaded with spikeinterface.load', default=None) parser.add_argument('--recording-base-folder', help='Base folder path for the recording (if .json/.pkl)', default=None) parser.add_argument('--verbose', help='Make the output verbose', action='store_true', default=False) @@ -384,6 +389,7 @@ def run_mainwindow_cli(): mode=args.mode, with_traces=not(args.no_traces), curation=args.curation, + iterative_curation = args.iterative_curation, recording=recording, skip_extensions=skip_extensions_list, verbose=args.verbose, @@ -414,4 +420,4 @@ def find_skippable_extensions(layout_dict): skippable_extensions = list(all_extensions.difference(set(needed_extensions))) - return skippable_extensions \ No newline at end of file + return skippable_extensions diff --git a/spikeinterface_gui/mainsettingsview.py b/spikeinterface_gui/mainsettingsview.py index 79a2638..abdeb9f 100644 --- a/spikeinterface_gui/mainsettingsview.py +++ b/spikeinterface_gui/mainsettingsview.py @@ -8,7 +8,9 @@ {'name': 'max_visible_units', 'type': 'int', 'value' : 10 }, {'name': 'color_mode', 'type': 'list', 'value' : 'color_by_unit', 'limits': ['color_by_unit', 'color_only_visible', 'color_by_visibility']}, - {'name': 'use_times', 'type': 'bool', 'value': False} + {'name': 'use_times', 'type': 'bool', 'value': False}, + {'name': 'merge_new_id_strategy', 'type': 'list', 'limits' : ['take_first', 'append', 'join'], 'value': 'take_first'}, + {'name': 'split_new_id_strategy', 'type': 'list', 'limits' : ['append', 'split'], 'value': 'append'}, ] @@ -51,6 +53,12 @@ def on_use_times(self): self.controller.update_time_info() self.notify_use_times_updated() + def on_merge_new_id_strategy(self): + self.controller.main_settings['merge_new_id_strategy'] = self.main_settings['merge_new_id_strategy'] + + def on_split_new_id_strategy(self): + self.controller.main_settings['split_new_id_strategy'] = self.main_settings['split_new_id_strategy'] + def save_current_settings(self, event=None): backend = self.controller.backend @@ -116,6 +124,8 @@ def _qt_make_layout(self): self.main_settings.param('max_visible_units').sigValueChanged.connect(self.on_max_visible_units_changed) self.main_settings.param('color_mode').sigValueChanged.connect(self.on_change_color_mode) self.main_settings.param('use_times').sigValueChanged.connect(self.on_use_times) + self.main_settings.param('merge_new_id_strategy').sigValueChanged.connect(self.on_merge_new_id_strategy) + self.main_settings.param('split_new_id_strategy').sigValueChanged.connect(self.on_split_new_id_strategy) def qt_make_settings_dict(self, view): """For a given view, return the current settings in a dict""" @@ -151,6 +161,8 @@ def _panel_make_layout(self): self.main_settings._parameterized.param.watch(self._panel_on_max_visible_units_changed, 'max_visible_units') self.main_settings._parameterized.param.watch(self._panel_on_change_color_mode, 'color_mode') self.main_settings._parameterized.param.watch(self._panel_on_use_times, 'use_times') + self.main_settings._parameterized.param.watch(self._panel_on_merge_new_id_strategy, 'merge_new_id_strategy') + self.main_settings._parameterized.param.watch(self._panel_on_split_new_id_strategy, 'split_new_id_strategy') self.layout = pn.Column(self.save_setting_button, self.main_settings_layout, sizing_mode="stretch_both") def panel_make_settings_dict(self, view): @@ -170,6 +182,12 @@ def _panel_on_max_visible_units_changed(self, event): def _panel_on_change_color_mode(self, event): self.on_change_color_mode() + def _panel_on_merge_new_id_strategy(self, event): + self.on_merge_new_id_strategy() + + def _panel_on_split_new_id_strategy(self, event): + self.on_split_new_id_strategy() + def _panel_on_use_times(self, event): self.on_use_times() diff --git a/spikeinterface_gui/maintemplateview.py b/spikeinterface_gui/maintemplateview.py index 0849a0b..5a7550d 100644 --- a/spikeinterface_gui/maintemplateview.py +++ b/spikeinterface_gui/maintemplateview.py @@ -92,19 +92,19 @@ def _qt_refresh(self): if peak_data is not None: # trough - peak_inds = peak_data[['trough_index']].values + peak_inds = peak_data[['trough_index']].values.astype(int) scatter = pg.ScatterPlotItem(x = times[peak_inds], y = template_high[peak_inds], size=10, pxMode = True, color="white", symbol="t") plot.addItem(scatter) names = ('peak_before', 'peak_after') - peak_inds = peak_data[[f'{k}_index' for k in names]].values + peak_inds = peak_data[[f'{k}_index' for k in names]].values.astype(int) scatter = pg.ScatterPlotItem(x = times[peak_inds], y = template_high[peak_inds], size=10, pxMode = True, color="white", symbol="t1") plot.addItem(scatter) all_names = ('trough', 'peak_before', 'peak_after') - peak_inds = peak_data[[f'{k}_index' for k in all_names]].values + peak_inds = peak_data[[f'{k}_index' for k in all_names]].values.astype(int) # Vertical dotted lines from peak to zero for ind in peak_inds: x = [times[ind], times[ind]] @@ -238,7 +238,7 @@ def _panel_refresh(self): if peak_data is not None: # Trough (downward triangle) - trough_inds = peak_data[['trough_index']].values + trough_inds = peak_data[['trough_index']].values.astype(int) fig.scatter( x=times[trough_inds].tolist(), y=template_high[trough_inds].tolist(), @@ -249,7 +249,7 @@ def _panel_refresh(self): # Peaks before/after (upward triangle) names = ('peak_before', 'peak_after') - peak_inds = peak_data[[f'{k}_index' for k in names]].values + peak_inds = peak_data[[f'{k}_index' for k in names]].values.astype(int) fig.scatter( x=times[peak_inds].tolist(), y=template_high[peak_inds].tolist(), @@ -260,7 +260,7 @@ def _panel_refresh(self): # Peaks before/after (upward triangle) all_names = ('trough', 'peak_before', 'peak_after') - peak_inds = peak_data[[f'{k}_index' for k in all_names]].values + peak_inds = peak_data[[f'{k}_index' for k in all_names]].values.astype(int) # Vertical dotted lines from peak to zero for ind in peak_inds: fig.line( @@ -324,4 +324,4 @@ def _panel_refresh(self): x-axis represents time and is in units of milliseconds. y-axis represents the electrical signal. The units depend on your preprocessing steps, but is usually in uV. -""" \ No newline at end of file +""" diff --git a/spikeinterface_gui/mergeview.py b/spikeinterface_gui/mergeview.py index 66712ea..1196d4d 100644 --- a/spikeinterface_gui/mergeview.py +++ b/spikeinterface_gui/mergeview.py @@ -160,6 +160,12 @@ def accept_group_merge(self, group_ids): self.notify_manual_curation_updated() self.refresh() + def _reinitialize(self): + self.proposed_merge_unit_groups_all = [] + self.proposed_merge_unit_groups = [] + self.merge_info = {} + self._refresh() + ### QT def _qt_get_selected_group_ids(self): inds = self.table.selectedIndexes() diff --git a/spikeinterface_gui/probeview.py b/spikeinterface_gui/probeview.py index d1eb4dd..9a459b9 100644 --- a/spikeinterface_gui/probeview.py +++ b/spikeinterface_gui/probeview.py @@ -150,6 +150,17 @@ def _qt_make_layout(self): self.roi_units.sigRegionChangeFinished.connect(self._qt_on_roi_units_changed) + def _qt_reinitialize(self): + import pyqtgraph as pg + + self.plot.removeItem(self.scatter) + unit_positions = self.controller.unit_positions + brush = [self.get_unit_color(u) for u in self.controller.unit_ids] + self.scatter = pg.ScatterPlotItem(pos=unit_positions, pxMode=False, size=10, brush=brush) + self.plot.addItem(self.scatter) + + self._qt_refresh() + def _qt_refresh(self): current_unit_positions = self.controller.unit_positions # if not np.array_equal(current_unit_positions, self._unit_positions): @@ -479,11 +490,14 @@ def _panel_make_layout(self): self.should_resize_unit_circle = None # Main layout - self.layout = pn.Column( - self.figure, - styles={"display": "flex", "flex-direction": "column"}, - sizing_mode="stretch_both", - ) + if self.layout is None: + self.layout = pn.Column( + self.figure, + styles={"display": "flex", "flex-direction": "column"}, + sizing_mode="stretch_both", + ) + else: + self.layout.objects = [self.figure] def _panel_refresh(self): import panel as pn @@ -555,6 +569,9 @@ def _panel_refresh(self): self.y_range.start = zoom_bounds[2] self.y_range.end = zoom_bounds[3] + def _panel_reinitialize(self): + self._panel_make_layout() + self._refresh() def _panel_compute_unit_glyph_patches(self): """Compute glyph patches without modifying Bokeh models.""" diff --git a/spikeinterface_gui/spikeamplitudeview.py b/spikeinterface_gui/spikeamplitudeview.py index ee4ac61..bfaf642 100644 --- a/spikeinterface_gui/spikeamplitudeview.py +++ b/spikeinterface_gui/spikeamplitudeview.py @@ -25,6 +25,10 @@ def __init__(self, controller=None, parent=None, backend="qt"): spike_data=spike_data, ) + def _reinitialize(self): + self.spike_data = self.controller.spike_amplitudes + self._refresh() + def _qt_make_layout(self): from .myqt import QT diff --git a/spikeinterface_gui/spikedepthview.py b/spikeinterface_gui/spikedepthview.py index 0bee9df..032e4b5 100644 --- a/spikeinterface_gui/spikedepthview.py +++ b/spikeinterface_gui/spikedepthview.py @@ -17,6 +17,9 @@ def __init__(self, controller=None, parent=None, backend="qt"): spike_data=spike_data, ) + def _reinitialize(self): + self.spike_data = self.controller.spike_depths + self._refresh() SpikeDepthView._gui_help_txt = """ diff --git a/spikeinterface_gui/traceview.py b/spikeinterface_gui/traceview.py index e79f29d..abe3926 100644 --- a/spikeinterface_gui/traceview.py +++ b/spikeinterface_gui/traceview.py @@ -458,7 +458,7 @@ def _panel_on_event_type_changed(self): self.refresh() def _panel_add_event_lines(self, t1, t2): - if self.event_line is not None: + if self.controller.has_extension("events") and self.event_line is not None: event_samples = self.controller.get_events(self.event_key) segment_index = self.controller.get_time()[1] start_sample, end_sample = self.controller.get_chunk_indices(t1, t2, segment_index) @@ -583,6 +583,10 @@ def _panel_add_event_line(self): yspan = [fig.y_range.start, fig.y_range.end] self.event_source.data = {"xs": [[evt_time, evt_time]], "ys": [yspan]} + def _panel_remove_event_line(self): + if self.controller.has_extension("events"): + self.event_source.data = {"xs": [], "ys": []} + # TODO: pan behavior like Qt? # def _panel_on_pan_start(self, event): # self.drag_state["x_start"] = event.x diff --git a/spikeinterface_gui/unitlistview.py b/spikeinterface_gui/unitlistview.py index c501dc8..eee4d21 100644 --- a/spikeinterface_gui/unitlistview.py +++ b/spikeinterface_gui/unitlistview.py @@ -43,8 +43,6 @@ def notify_unit_and_channel_visibility_changed(self): def _qt_make_layout(self): from .myqt import QT - import pyqtgraph as pg - self.menu = None self.layout = QT.QVBoxLayout() @@ -54,21 +52,7 @@ def _qt_make_layout(self): but.clicked.connect(self._qt_select_columns) tb.addWidget(but) - - visible_cols = [] - for col in self.controller.units_table.columns: - visible_cols.append( - {'name': str(col), 'type': 'bool', 'value': col in self.controller.displayed_unit_properties, 'default': True} - ) - self.visible_columns = pg.parametertree.Parameter.create( name='visible columns', type='group', children=visible_cols) - self.tree_visible_columns = pg.parametertree.ParameterTree(parent=self.qt_widget) - self.tree_visible_columns.header().hide() - self.tree_visible_columns.setParameters(self.visible_columns, showTop=True) - # self.tree_visible_columns.setWindowTitle(u'visible columns') - # self.tree_visible_columns.setWindowFlags(QT.Qt.Window) - self.visible_columns.sigTreeStateChanged.connect(self._qt_on_visible_columns_changed) - self.layout.addWidget(self.tree_visible_columns) - self.tree_visible_columns.hide() + self._qt_set_up_visible_columns() # h = QT.QHBoxLayout() # self.layout.addLayout(h) @@ -134,6 +118,28 @@ def _qt_make_layout(self): self.shortcut_noise.setKey(QT.QKeySequence('n')) self.shortcut_noise.activated.connect(lambda: self._qt_set_default_label('noise')) + def _qt_set_up_visible_columns(self): + + import pyqtgraph as pg + visible_cols = [] + for col in self.controller.units_table.columns: + visible_cols.append( + {'name': str(col), 'type': 'bool', 'value': col in self.controller.displayed_unit_properties, 'default': True} + ) + self.visible_columns = pg.parametertree.Parameter.create( name='visible columns', type='group', children=visible_cols) + self.tree_visible_columns = pg.parametertree.ParameterTree(parent=self.qt_widget) + self.tree_visible_columns.header().hide() + self.tree_visible_columns.setParameters(self.visible_columns, showTop=True) + + self.visible_columns.sigTreeStateChanged.connect(self._qt_on_visible_columns_changed) + self.layout.addWidget(self.tree_visible_columns) + self.tree_visible_columns.hide() + + def _qt_reinitialize(self): + + self._qt_set_up_visible_columns() + self._qt_full_table_refresh() + self._qt_refresh() def _qt_on_column_moved(self, logical_index, old_visual_index, new_visual_index): # Update stored column order @@ -228,7 +234,6 @@ def _qt_full_table_refresh(self): self.table.clear() - internal_column_names = ['unit_id', 'visible', 'channel_id'] # internal labels @@ -589,16 +594,22 @@ def _panel_make_layout(self): shortcuts_component = KeyboardShortcuts(shortcuts=shortcuts) shortcuts_component.on_msg(self._panel_handle_shortcut) - self.layout = pn.Column( - pn.Row( - self.info_text, - ), - buttons, - sizing_mode="stretch_width", - ) + if self.layout is None: + self.layout = pn.Column( + pn.Row( + self.info_text, + ), + buttons, + sizing_mode="stretch_width", + ) - self.layout.append(self.table) - self.layout.append(shortcuts_component) + self.layout.append(self.table) + self.layout.append(shortcuts_component) + else: + self.layout[0][0] = self.info_text + self.layout[1] = buttons + self.layout[2] = self.table + self.layout[3] = shortcuts_component self.table.tabulator.on_edit(self._panel_on_edit) self.refresh_button.on_click(self._panel_refresh_click) @@ -653,6 +664,10 @@ def _panel_refresh(self): # refresh header self._panel_refresh_header() + def _panel_reinitialize(self): + self._panel_make_layout() + self._panel_refresh() + def _panel_refresh_header(self): unit_ids = self.controller.unit_ids n1 = len(unit_ids) diff --git a/spikeinterface_gui/utils_global.py b/spikeinterface_gui/utils_global.py index 23fc61d..20174d6 100644 --- a/spikeinterface_gui/utils_global.py +++ b/spikeinterface_gui/utils_global.py @@ -1,6 +1,10 @@ import numpy as np from pathlib import Path import os +from copy import copy + +from spikeinterface.core.sorting_tools import generate_unit_ids_for_split, generate_unit_ids_for_merge_group +from spikeinterface.curation.curation_model import Curation def get_config_folder() -> Path: """Get the config folder for spikeinterface-gui settings files. @@ -58,3 +62,33 @@ def get_present_zones_in_half_of_layout(layout_zone, shift): is_present = [views is not None and len(views) > 0 for views in half_dict.values()] present_zones = set(np.array(list(half_dict.keys()))[np.array(is_present)]) return present_zones + + +def add_new_unit_ids_to_curation_dict(curation_dict, sorting, split_new_id_strategy, merge_new_id_strategy): + """ + Explicitly adds the new unit ids to `curation_dict` based on the split and merge new id strategies. + These *should* be the ids that would have been generated during `apply_curation` with these strategies. + """ + curation_model = Curation(**curation_dict) + old_unit_ids = copy(curation_model.unit_ids) + + if len(curation_model.splits) > 0: + unit_splits = {split.unit_id: split.get_full_spike_indices(sorting) for split in curation_model.splits} + new_split_unit_ids = generate_unit_ids_for_split(old_unit_ids, unit_splits, new_unit_ids=None, new_id_strategy=split_new_id_strategy) + + all_new_unit_ids = [] + for split_index, new_unit_ids in enumerate(new_split_unit_ids): + curation_dict['splits'][split_index]['new_unit_ids'] = new_unit_ids + all_new_unit_ids = all_new_unit_ids + new_unit_ids + + # update old unit ids with the newly split units + old_unit_ids = np.setdiff1d(old_unit_ids, np.array(list(unit_splits.keys()))) + old_unit_ids = np.concat([old_unit_ids, all_new_unit_ids]) + + if len(curation_model.merges) > 0: + merge_unit_groups = [m.unit_ids for m in curation_model.merges] + new_merge_unit_ids = generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_ids=None, new_id_strategy=merge_new_id_strategy) + for merge_index, new_unit_id in enumerate(new_merge_unit_ids): + curation_dict['merges'][merge_index]['new_unit_id'] = new_unit_id + + return curation_dict diff --git a/spikeinterface_gui/view_base.py b/spikeinterface_gui/view_base.py index 25c36b0..e4f12f5 100644 --- a/spikeinterface_gui/view_base.py +++ b/spikeinterface_gui/view_base.py @@ -41,6 +41,7 @@ def __init__(self, controller=None, parent=None, backend="qt"): self.notifier = SignalNotifier(view=self) self.busy = pn.indicators.LoadingSpinner(value=True, size=20, name='busy...') + self.layout = None make_layout() if self._settings is not None: listen_setting_changes(self) @@ -115,6 +116,14 @@ def refresh(self, **kwargs): t1 = time.perf_counter() print(f"Refresh {self.__class__.__name__} took {t1 - t0:.3f} seconds", flush=True) + def reinitialize(self, **kwargs): + if self.controller.verbose: + t0 = time.perf_counter() + self._reinitialize(**kwargs) + if self.controller.verbose: + t1 = time.perf_counter() + print(f"Reinitialize {self.__class__.__name__} took {t1 - t0:.3f} seconds", flush=True) + def compute(self, event=None): with self.busy_cursor(): self._compute() @@ -130,6 +139,12 @@ def _refresh(self, **kwargs): import panel as pn pn.state.execute(lambda: self._panel_refresh(**kwargs), schedule=True) + def _reinitialize(self, **kwargs): + if self.backend == "qt": + self._qt_reinitialize(**kwargs) + elif self.backend == "panel": + self._panel_reinitialize(**kwargs) + def warning(self, warning_msg): if self.backend == "qt": self._qt_insert_warning(warning_msg) @@ -264,6 +279,9 @@ def _qt_make_layout(self): def _qt_refresh(self): raise (NotImplementedError) + + def _qt_reinitialize(self): + self._qt_refresh() def _qt_on_spike_selection_changed(self): pass @@ -324,6 +342,9 @@ def _panel_make_layout(self): def _panel_refresh(self): raise (NotImplementedError) + + def _panel_reinitialize(self): + self._panel_refresh() def _panel_on_spike_selection_changed(self): pass