Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 41 additions & 35 deletions src/ezmsg/event/peak.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def _reset_state(self, message: AxisArray) -> None:
self._state.refrac_width = int(self.settings.refrac_dur * fs)

# We'll need the first sample (keep time dim!) for a few of our state initializations
data = np.moveaxis(message.data, ax_idx, -1)
first_samp = data[..., :1]
data = np.moveaxis(message.data, ax_idx, 0)
first_samp = data[:1]

# Prepare optional state variables
self._state.scaler = None
Expand All @@ -102,7 +102,7 @@ def _reset_state(self, message: AxisArray) -> None:

# Initialize the count of samples since last event for each feature. We initialize at refrac_width+1
# to ensure that even the first sample is eligible for events.
self._state.elapsed = np.zeros((np.prod(data.shape[:-1]),), dtype=int) + (self._state.refrac_width + 1)
self._state.elapsed = np.zeros((np.prod(data.shape[1:]),), dtype=int) + (self._state.refrac_width + 1)

def _process(self, message: AxisArray) -> AxisArray:
"""
Expand All @@ -117,35 +117,41 @@ def _process(self, message: AxisArray) -> AxisArray:
ax_idx = message.get_axis_idx("time")

# If the time axis is not the last axis, we need to move it to the end.
if ax_idx != (message.data.ndim - 1):
if ax_idx != 0:
message = replace(
message,
data=np.moveaxis(message.data, ax_idx, -1),
dims=message.dims[:ax_idx] + message.dims[ax_idx + 1 :] + ["time"],
data=np.moveaxis(message.data, ax_idx, 0),
dims=["time"] + message.dims[:ax_idx] + message.dims[ax_idx + 1 :],
)

# Take a copy of the raw data if needed and prepend to our state data_raw
# This will only exist if we are autoscaling AND we need to capture the true peak value.
if self._state.data_raw is not None:
self._state.data_raw = np.concatenate((self._state.data_raw, message.data), axis=-1)
self._state.data_raw = np.concatenate((self._state.data_raw, message.data), axis=0)

# Run the message through the standard scaler if needed. Note: raw value is lost unless we copied it above.
if self._state.scaler is not None:
message = self._state.scaler(message)

# Prepend the previous iteration's last (maybe z-scored) sample to the current (maybe z-scored) data.
data = np.concatenate((self._state.data, message.data), axis=-1)
data = np.concatenate((self._state.data, message.data), axis=0)
# Take note of how many samples were prepended. We will need this later when we modify `overs`.
n_prepended = self._state.data.shape[-1]
n_prepended = self._state.data.shape[0]

# Identify which data points are over threshold
overs = data >= self.settings.threshold if self.settings.threshold >= 0 else data <= self.settings.threshold

# Find threshold _crossing_: where sample k is over and sample k-1 is not over.
b_cross_over = np.logical_and(overs[..., 1:], ~overs[..., :-1])
b_cross_over = np.logical_and(overs[1:], ~overs[:-1])
cross_idx = list(np.where(b_cross_over)) # List of indices into each dim
# We ignored the first sample when looking for crosses so we increment the sample index by 1.
cross_idx[-1] += 1
cross_idx[0] += 1
# Sort events by feature first, then by time within each feature.
# np.where on a time-first array returns events sorted by time; we need them grouped by feature
# for the refractory period logic and elapsed tracking to work correctly.
if len(cross_idx[0]) > 0 and len(cross_idx) > 1:
sort_order = np.lexsort([cross_idx[0]] + cross_idx[1:][::-1])
cross_idx = [_[sort_order] for _ in cross_idx]

# Note: There is an assumption that the 0th sample only serves as a reference and is not part of the output;
# this will be trimmed at the very end. For now the offset is useful for bookkeeping (peak finding, etc.).
Expand All @@ -154,15 +160,15 @@ def _process(self, message: AxisArray) -> AxisArray:
# TODO: This should go in its own transformer.
# However, a general purpose refractory-period-enforcer would keep track of its own event history,
# so we would probably do this step before prepending with historical samples.
if self._state.refrac_width > 2 and len(cross_idx[-1]) > 0:
if self._state.refrac_width > 2 and len(cross_idx[0]) > 0:
# Find the unique set of features that have at least one cross-over,
# and the indices of the first crossover for each.
ravel_feat_inds = np.ravel_multi_index(cross_idx[:-1], overs.shape[:-1])
ravel_feat_inds = np.ravel_multi_index(cross_idx[1:], overs.shape[1:])
uq_feats, feat_splits = np.unique(ravel_feat_inds, return_index=True)
# Calculate the inter-event intervals (IEIs) for each feature. First get all the IEIs.
ieis = np.diff(np.hstack(([cross_idx[-1][0] + 1], cross_idx[-1])))
ieis = np.diff(np.hstack(([cross_idx[0][0] + 1], cross_idx[0])))
# Then reset the interval at feature boundaries.
ieis[feat_splits] = cross_idx[-1][feat_splits] + self._state.elapsed[uq_feats]
ieis[feat_splits] = cross_idx[0][feat_splits] + self._state.elapsed[uq_feats]
b_drop = ieis <= self._state.refrac_width
drop_idx = np.where(b_drop)[0]
final_drop = []
Expand All @@ -183,33 +189,33 @@ def _process(self, message: AxisArray) -> AxisArray:
cross_idx = [np.delete(_, final_drop) for _ in cross_idx]

# Calculate the 'value' at each event.
hold_idx = overs.shape[-1] - 1
if len(cross_idx[-1]) == 0:
hold_idx = overs.shape[0] - 1
if len(cross_idx[0]) == 0:
# No events; not values to calculate.
result_val = np.ones(
cross_idx[-1].shape,
cross_idx[0].shape,
dtype=data.dtype if self.settings.return_peak_val else bool,
)
elif not (self._state.min_width > 1 or self.settings.align_on_peak or self.settings.return_peak_val):
# No postprocessing required. TODO: Why is min_width <= 1 a requirement here?
result_val = np.ones(cross_idx[-1].shape, dtype=bool)
result_val = np.ones(cross_idx[0].shape, dtype=bool)
else:
# Do postprocessing of events: width-checking, align-on-peak, and/or include peak value in return.
# Each of these requires finding the true peak, which requires pulling out a snippet around the
# threshold crossing event.
# We extract max_width-length vectors of `overs` values for each event. This might require some padding
# if the event is near the end of the data. Pad with the last sample until the expected end of the event.
n_pad = max(0, max(cross_idx[-1]) + self._state.max_width - overs.shape[-1])
pad_width = ((0, 0),) * (overs.ndim - 1) + ((0, n_pad),)
n_pad = max(0, max(cross_idx[0]) + self._state.max_width - overs.shape[0])
pad_width = ((0, n_pad),) + ((0, 0),) * (overs.ndim - 1)
overs_padded = np.pad(overs, pad_width, mode="edge")

# Extract the segments for each event.
# First we get the sample indices. This is 2-dimensional; first dim for offset and remaining for seg length.
s_idx = np.arange(self._state.max_width)[None, :] + cross_idx[-1][:, None]
s_idx = np.arange(self._state.max_width)[None, :] + cross_idx[0][:, None]
# Combine feature indices and time indices to extract segments of overs.
# Note: We had to expand each of our feature indices also be 2-dimensional
# -> ndarray (eat dims ..., max_width)
ep_overs = overs_padded[tuple(_[:, None] for _ in cross_idx[:-1]) + (s_idx,)]
ep_overs = overs_padded[(s_idx,) + tuple(_[:, None] for _ in cross_idx[1:])]

# Find the event lengths: i.e., the first non-over-threshold value for each event.
# Warning: Values are invalid for events that don't cross back.
Expand All @@ -229,10 +235,10 @@ def _process(self, message: AxisArray) -> AxisArray:
# We are returning a sparse array and unfinished peaks must be buffered for the next iteration.
# Find the earliest unfinished event. If none, we still buffer the final sample.
b_unf = ~b_ev_crossback
hold_idx = cross_idx[-1][b_unf].min() if np.any(b_unf) else hold_idx
hold_idx = cross_idx[0][b_unf].min() if np.any(b_unf) else hold_idx

# Trim events that are past the hold_idx. They will be processed next iteration.
b_pass_ev = cross_idx[-1] < hold_idx
b_pass_ev = cross_idx[0] < hold_idx
cross_idx = [_[b_pass_ev] for _ in cross_idx]
ev_len = ev_len[b_pass_ev]

Expand All @@ -241,7 +247,7 @@ def _process(self, message: AxisArray) -> AxisArray:
hold_idx = max(hold_idx - 1, 0)

# If we are not returning peak values, we can just return bools at the event locations.
result_val = np.ones(cross_idx[-1].shape, dtype=bool)
result_val = np.ones(cross_idx[0].shape, dtype=bool)

# For remaining _finished_ peaks, get the peak location -- for alignment or if returning its value.
if self.settings.align_on_peak or self.settings.return_peak_val:
Expand All @@ -252,8 +258,8 @@ def _process(self, message: AxisArray) -> AxisArray:
uq_lens, len_grps = np.unique(ev_len, return_inverse=True)
for len_idx, ep_len in enumerate(uq_lens):
b_grp = len_grps == len_idx
ep_resamp = np.arange(ep_len)[None, :] + cross_idx[-1][b_grp, None]
ep_inds_tuple = tuple(_[b_grp, None] for _ in cross_idx[:-1]) + (ep_resamp,)
ep_resamp = np.arange(ep_len)[None, :] + cross_idx[0][b_grp, None]
ep_inds_tuple = (ep_resamp,) + tuple(_[b_grp, None] for _ in cross_idx[1:])
eps = data[ep_inds_tuple]
if self.settings.threshold >= 0:
pk_offset[b_grp] = np.argmax(eps, axis=1)
Expand All @@ -262,38 +268,38 @@ def _process(self, message: AxisArray) -> AxisArray:

if self.settings.align_on_peak:
# We want to align on the peak, so add the peak offset.
cross_idx[-1] += pk_offset
cross_idx[0] += pk_offset

if self.settings.return_peak_val:
# We need the actual peak value.
peak_inds_tuple = (
tuple(cross_idx)
if self.settings.align_on_peak
else tuple(cross_idx[:-1]) + (cross_idx[-1] + pk_offset,)
else (cross_idx[0] + pk_offset,) + tuple(cross_idx[1:])
)
result_val = (self._state.data_raw if self._state.data_raw is not None else data)[peak_inds_tuple]

# Save data for next iteration
self._state.data = data[..., hold_idx:]
self._state.data = data[hold_idx:]
if self._state.data_raw is not None:
# Likely because we are using the scaler, we need a separate copy of the raw data.
self._state.data_raw = self._state.data_raw[..., hold_idx:]
self._state.data_raw = self._state.data_raw[hold_idx:]
# Clear out `elapsed` by adding the max number of samples since the last event.
self._state.elapsed += hold_idx
# Yet for features that actually had events, replace the elapsed time with the actual event time
self._state.elapsed[tuple(cross_idx[:-1])] = hold_idx - cross_idx[-1]
self._state.elapsed[tuple(cross_idx[1:])] = hold_idx - cross_idx[0]
# Note: multiple-write to same index ^ is fine because it is sorted and the last value for each is correct.

# Prepare sparse matrix output
# Note: The first of the held back samples for next iteration is part of this iteration's return.
# Likewise, the first prepended sample on this iteration was part of the previous iteration's return.
n_out_samps = hold_idx
t0 = message.axes["time"].offset - (n_prepended - 1) * message.axes["time"].gain
cross_idx[-1] -= 1 # Discard first prepended sample.
cross_idx[0] -= 1 # Discard first prepended sample.
result = sparse.COO(
cross_idx,
data=result_val,
shape=data.shape[:-1] + (n_out_samps,),
shape=(n_out_samps,) + data.shape[1:],
)
msg_out = replace(
message,
Expand Down
8 changes: 6 additions & 2 deletions tests/test_peak.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,12 @@ def test_threshold_crossing(return_peak_val: bool):
exp_feat_inds = np.array(exp_feat_inds)
exp_samp_inds = np.array(exp_samp_inds)

final_arr = sparse.concatenate([_.data for _ in msgs_out], axis=1)
feat_inds, samp_inds = final_arr.nonzero()
final_arr = sparse.concatenate([_.data for _ in msgs_out], axis=0)
samp_inds, feat_inds = final_arr.nonzero()
# Sort by feature then time to match the expected ordering (which is built channel-by-channel).
sort_order = np.lexsort((samp_inds, feat_inds))
samp_inds = samp_inds[sort_order]
feat_inds = feat_inds[sort_order]

"""
# This block of code was used to debug some discrepancies that popped up when the last sample of the last chunk
Expand Down