diff --git a/src/ezmsg/event/peak.py b/src/ezmsg/event/peak.py index 72d6743..f2debad 100644 --- a/src/ezmsg/event/peak.py +++ b/src/ezmsg/event/peak.py @@ -83,8 +83,8 @@ def _reset_state(self, message: AxisArray) -> None: self._state.refrac_width = int(self.settings.refrac_dur * fs) # We'll need the first sample (keep time dim!) for a few of our state initializations - data = np.moveaxis(message.data, ax_idx, -1) - first_samp = data[..., :1] + data = np.moveaxis(message.data, ax_idx, 0) + first_samp = data[:1] # Prepare optional state variables self._state.scaler = None @@ -102,7 +102,7 @@ def _reset_state(self, message: AxisArray) -> None: # Initialize the count of samples since last event for each feature. We initialize at refrac_width+1 # to ensure that even the first sample is eligible for events. - self._state.elapsed = np.zeros((np.prod(data.shape[:-1]),), dtype=int) + (self._state.refrac_width + 1) + self._state.elapsed = np.zeros((np.prod(data.shape[1:]),), dtype=int) + (self._state.refrac_width + 1) def _process(self, message: AxisArray) -> AxisArray: """ @@ -117,35 +117,41 @@ def _process(self, message: AxisArray) -> AxisArray: ax_idx = message.get_axis_idx("time") # If the time axis is not the last axis, we need to move it to the end. - if ax_idx != (message.data.ndim - 1): + if ax_idx != 0: message = replace( message, - data=np.moveaxis(message.data, ax_idx, -1), - dims=message.dims[:ax_idx] + message.dims[ax_idx + 1 :] + ["time"], + data=np.moveaxis(message.data, ax_idx, 0), + dims=["time"] + message.dims[:ax_idx] + message.dims[ax_idx + 1 :], ) # Take a copy of the raw data if needed and prepend to our state data_raw # This will only exist if we are autoscaling AND we need to capture the true peak value. if self._state.data_raw is not None: - self._state.data_raw = np.concatenate((self._state.data_raw, message.data), axis=-1) + self._state.data_raw = np.concatenate((self._state.data_raw, message.data), axis=0) # Run the message through the standard scaler if needed. Note: raw value is lost unless we copied it above. if self._state.scaler is not None: message = self._state.scaler(message) # Prepend the previous iteration's last (maybe z-scored) sample to the current (maybe z-scored) data. - data = np.concatenate((self._state.data, message.data), axis=-1) + data = np.concatenate((self._state.data, message.data), axis=0) # Take note of how many samples were prepended. We will need this later when we modify `overs`. - n_prepended = self._state.data.shape[-1] + n_prepended = self._state.data.shape[0] # Identify which data points are over threshold overs = data >= self.settings.threshold if self.settings.threshold >= 0 else data <= self.settings.threshold # Find threshold _crossing_: where sample k is over and sample k-1 is not over. - b_cross_over = np.logical_and(overs[..., 1:], ~overs[..., :-1]) + b_cross_over = np.logical_and(overs[1:], ~overs[:-1]) cross_idx = list(np.where(b_cross_over)) # List of indices into each dim # We ignored the first sample when looking for crosses so we increment the sample index by 1. - cross_idx[-1] += 1 + cross_idx[0] += 1 + # Sort events by feature first, then by time within each feature. + # np.where on a time-first array returns events sorted by time; we need them grouped by feature + # for the refractory period logic and elapsed tracking to work correctly. + if len(cross_idx[0]) > 0 and len(cross_idx) > 1: + sort_order = np.lexsort([cross_idx[0]] + cross_idx[1:][::-1]) + cross_idx = [_[sort_order] for _ in cross_idx] # Note: There is an assumption that the 0th sample only serves as a reference and is not part of the output; # this will be trimmed at the very end. For now the offset is useful for bookkeeping (peak finding, etc.). @@ -154,15 +160,15 @@ def _process(self, message: AxisArray) -> AxisArray: # TODO: This should go in its own transformer. # However, a general purpose refractory-period-enforcer would keep track of its own event history, # so we would probably do this step before prepending with historical samples. - if self._state.refrac_width > 2 and len(cross_idx[-1]) > 0: + if self._state.refrac_width > 2 and len(cross_idx[0]) > 0: # Find the unique set of features that have at least one cross-over, # and the indices of the first crossover for each. - ravel_feat_inds = np.ravel_multi_index(cross_idx[:-1], overs.shape[:-1]) + ravel_feat_inds = np.ravel_multi_index(cross_idx[1:], overs.shape[1:]) uq_feats, feat_splits = np.unique(ravel_feat_inds, return_index=True) # Calculate the inter-event intervals (IEIs) for each feature. First get all the IEIs. - ieis = np.diff(np.hstack(([cross_idx[-1][0] + 1], cross_idx[-1]))) + ieis = np.diff(np.hstack(([cross_idx[0][0] + 1], cross_idx[0]))) # Then reset the interval at feature boundaries. - ieis[feat_splits] = cross_idx[-1][feat_splits] + self._state.elapsed[uq_feats] + ieis[feat_splits] = cross_idx[0][feat_splits] + self._state.elapsed[uq_feats] b_drop = ieis <= self._state.refrac_width drop_idx = np.where(b_drop)[0] final_drop = [] @@ -183,33 +189,33 @@ def _process(self, message: AxisArray) -> AxisArray: cross_idx = [np.delete(_, final_drop) for _ in cross_idx] # Calculate the 'value' at each event. - hold_idx = overs.shape[-1] - 1 - if len(cross_idx[-1]) == 0: + hold_idx = overs.shape[0] - 1 + if len(cross_idx[0]) == 0: # No events; not values to calculate. result_val = np.ones( - cross_idx[-1].shape, + cross_idx[0].shape, dtype=data.dtype if self.settings.return_peak_val else bool, ) elif not (self._state.min_width > 1 or self.settings.align_on_peak or self.settings.return_peak_val): # No postprocessing required. TODO: Why is min_width <= 1 a requirement here? - result_val = np.ones(cross_idx[-1].shape, dtype=bool) + result_val = np.ones(cross_idx[0].shape, dtype=bool) else: # Do postprocessing of events: width-checking, align-on-peak, and/or include peak value in return. # Each of these requires finding the true peak, which requires pulling out a snippet around the # threshold crossing event. # We extract max_width-length vectors of `overs` values for each event. This might require some padding # if the event is near the end of the data. Pad with the last sample until the expected end of the event. - n_pad = max(0, max(cross_idx[-1]) + self._state.max_width - overs.shape[-1]) - pad_width = ((0, 0),) * (overs.ndim - 1) + ((0, n_pad),) + n_pad = max(0, max(cross_idx[0]) + self._state.max_width - overs.shape[0]) + pad_width = ((0, n_pad),) + ((0, 0),) * (overs.ndim - 1) overs_padded = np.pad(overs, pad_width, mode="edge") # Extract the segments for each event. # First we get the sample indices. This is 2-dimensional; first dim for offset and remaining for seg length. - s_idx = np.arange(self._state.max_width)[None, :] + cross_idx[-1][:, None] + s_idx = np.arange(self._state.max_width)[None, :] + cross_idx[0][:, None] # Combine feature indices and time indices to extract segments of overs. # Note: We had to expand each of our feature indices also be 2-dimensional # -> ndarray (eat dims ..., max_width) - ep_overs = overs_padded[tuple(_[:, None] for _ in cross_idx[:-1]) + (s_idx,)] + ep_overs = overs_padded[(s_idx,) + tuple(_[:, None] for _ in cross_idx[1:])] # Find the event lengths: i.e., the first non-over-threshold value for each event. # Warning: Values are invalid for events that don't cross back. @@ -229,10 +235,10 @@ def _process(self, message: AxisArray) -> AxisArray: # We are returning a sparse array and unfinished peaks must be buffered for the next iteration. # Find the earliest unfinished event. If none, we still buffer the final sample. b_unf = ~b_ev_crossback - hold_idx = cross_idx[-1][b_unf].min() if np.any(b_unf) else hold_idx + hold_idx = cross_idx[0][b_unf].min() if np.any(b_unf) else hold_idx # Trim events that are past the hold_idx. They will be processed next iteration. - b_pass_ev = cross_idx[-1] < hold_idx + b_pass_ev = cross_idx[0] < hold_idx cross_idx = [_[b_pass_ev] for _ in cross_idx] ev_len = ev_len[b_pass_ev] @@ -241,7 +247,7 @@ def _process(self, message: AxisArray) -> AxisArray: hold_idx = max(hold_idx - 1, 0) # If we are not returning peak values, we can just return bools at the event locations. - result_val = np.ones(cross_idx[-1].shape, dtype=bool) + result_val = np.ones(cross_idx[0].shape, dtype=bool) # For remaining _finished_ peaks, get the peak location -- for alignment or if returning its value. if self.settings.align_on_peak or self.settings.return_peak_val: @@ -252,8 +258,8 @@ def _process(self, message: AxisArray) -> AxisArray: uq_lens, len_grps = np.unique(ev_len, return_inverse=True) for len_idx, ep_len in enumerate(uq_lens): b_grp = len_grps == len_idx - ep_resamp = np.arange(ep_len)[None, :] + cross_idx[-1][b_grp, None] - ep_inds_tuple = tuple(_[b_grp, None] for _ in cross_idx[:-1]) + (ep_resamp,) + ep_resamp = np.arange(ep_len)[None, :] + cross_idx[0][b_grp, None] + ep_inds_tuple = (ep_resamp,) + tuple(_[b_grp, None] for _ in cross_idx[1:]) eps = data[ep_inds_tuple] if self.settings.threshold >= 0: pk_offset[b_grp] = np.argmax(eps, axis=1) @@ -262,26 +268,26 @@ def _process(self, message: AxisArray) -> AxisArray: if self.settings.align_on_peak: # We want to align on the peak, so add the peak offset. - cross_idx[-1] += pk_offset + cross_idx[0] += pk_offset if self.settings.return_peak_val: # We need the actual peak value. peak_inds_tuple = ( tuple(cross_idx) if self.settings.align_on_peak - else tuple(cross_idx[:-1]) + (cross_idx[-1] + pk_offset,) + else (cross_idx[0] + pk_offset,) + tuple(cross_idx[1:]) ) result_val = (self._state.data_raw if self._state.data_raw is not None else data)[peak_inds_tuple] # Save data for next iteration - self._state.data = data[..., hold_idx:] + self._state.data = data[hold_idx:] if self._state.data_raw is not None: # Likely because we are using the scaler, we need a separate copy of the raw data. - self._state.data_raw = self._state.data_raw[..., hold_idx:] + self._state.data_raw = self._state.data_raw[hold_idx:] # Clear out `elapsed` by adding the max number of samples since the last event. self._state.elapsed += hold_idx # Yet for features that actually had events, replace the elapsed time with the actual event time - self._state.elapsed[tuple(cross_idx[:-1])] = hold_idx - cross_idx[-1] + self._state.elapsed[tuple(cross_idx[1:])] = hold_idx - cross_idx[0] # Note: multiple-write to same index ^ is fine because it is sorted and the last value for each is correct. # Prepare sparse matrix output @@ -289,11 +295,11 @@ def _process(self, message: AxisArray) -> AxisArray: # Likewise, the first prepended sample on this iteration was part of the previous iteration's return. n_out_samps = hold_idx t0 = message.axes["time"].offset - (n_prepended - 1) * message.axes["time"].gain - cross_idx[-1] -= 1 # Discard first prepended sample. + cross_idx[0] -= 1 # Discard first prepended sample. result = sparse.COO( cross_idx, data=result_val, - shape=data.shape[:-1] + (n_out_samps,), + shape=(n_out_samps,) + data.shape[1:], ) msg_out = replace( message, diff --git a/tests/test_peak.py b/tests/test_peak.py index 66be1f2..7d1adb4 100644 --- a/tests/test_peak.py +++ b/tests/test_peak.py @@ -57,8 +57,12 @@ def test_threshold_crossing(return_peak_val: bool): exp_feat_inds = np.array(exp_feat_inds) exp_samp_inds = np.array(exp_samp_inds) - final_arr = sparse.concatenate([_.data for _ in msgs_out], axis=1) - feat_inds, samp_inds = final_arr.nonzero() + final_arr = sparse.concatenate([_.data for _ in msgs_out], axis=0) + samp_inds, feat_inds = final_arr.nonzero() + # Sort by feature then time to match the expected ordering (which is built channel-by-channel). + sort_order = np.lexsort((samp_inds, feat_inds)) + samp_inds = samp_inds[sort_order] + feat_inds = feat_inds[sort_order] """ # This block of code was used to debug some discrepancies that popped up when the last sample of the last chunk