From 47d3359fcd8b2f85d6a612eb21f9a57151986cb5 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Mon, 5 Jan 2026 18:00:22 -0500 Subject: [PATCH 01/11] Use updated baseproc with BaseClockDrivenProducer --- pyproject.toml | 2 +- src/ezmsg/simbiophys/__init__.py | 16 +- src/ezmsg/simbiophys/dnss/synth.py | 2 +- src/ezmsg/simbiophys/eeg.py | 28 ++-- src/ezmsg/simbiophys/noise.py | 69 ++++---- src/ezmsg/simbiophys/oscillator.py | 45 ++--- tests/integration/test_dnss_system.py | 4 +- tests/integration/test_synth_system.py | 2 +- tests/unit/test_oscillator.py | 221 +++++++++++++++---------- 9 files changed, 213 insertions(+), 176 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bc87f32..2c40b34 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ requires-python = ">=3.10.15" dynamic = ["version"] dependencies = [ "ezmsg>=3.6.0", - "ezmsg-baseproc>=1.1.0", + "ezmsg-baseproc>=1.2.0", "ezmsg-event>=0.6.0", "ezmsg-sigproc>=2.8.0", "numpy>=1.26.0", diff --git a/src/ezmsg/simbiophys/__init__.py b/src/ezmsg/simbiophys/__init__.py index a03c88a..4a175a5 100644 --- a/src/ezmsg/simbiophys/__init__.py +++ b/src/ezmsg/simbiophys/__init__.py @@ -54,18 +54,20 @@ # Noise from .noise import ( PinkNoise, + PinkNoiseProducer, PinkNoiseSettings, - PinkNoiseTransformer, WhiteNoise, + WhiteNoiseProducer, WhiteNoiseSettings, - WhiteNoiseTransformer, + WhiteNoiseState, ) # Oscillator from .oscillator import ( SinGenerator, SinGeneratorSettings, - SinTransformer, + SinGeneratorState, + SinProducer, ) __all__ = [ @@ -84,14 +86,16 @@ # Oscillator "SinGenerator", "SinGeneratorSettings", - "SinTransformer", + "SinGeneratorState", + "SinProducer", # Noise "PinkNoise", + "PinkNoiseProducer", "PinkNoiseSettings", - "PinkNoiseTransformer", "WhiteNoise", + "WhiteNoiseProducer", "WhiteNoiseSettings", - "WhiteNoiseTransformer", + "WhiteNoiseState", # EEG "EEGSynth", "EEGSynthSettings", diff --git a/src/ezmsg/simbiophys/dnss/synth.py b/src/ezmsg/simbiophys/dnss/synth.py index 7000515..bd49017 100644 --- a/src/ezmsg/simbiophys/dnss/synth.py +++ b/src/ezmsg/simbiophys/dnss/synth.py @@ -127,7 +127,7 @@ def configure(self) -> None: def network(self) -> ez.NetworkDefinition: return ( # Clock drives Counter - (self.CLOCK.OUTPUT_SIGNAL, self.COUNTER.INPUT_SIGNAL), + (self.CLOCK.OUTPUT_SIGNAL, self.COUNTER.INPUT_CLOCK), # Counter fans out to both Spike and LFP generators (self.COUNTER.OUTPUT_SIGNAL, self.SPIKE.INPUT_SIGNAL), (self.COUNTER.OUTPUT_SIGNAL, self.LFP.INPUT_SIGNAL), diff --git a/src/ezmsg/simbiophys/eeg.py b/src/ezmsg/simbiophys/eeg.py index 6ff9a2d..be7bc8a 100644 --- a/src/ezmsg/simbiophys/eeg.py +++ b/src/ezmsg/simbiophys/eeg.py @@ -1,7 +1,7 @@ """EEG signal synthesis.""" import ezmsg.core as ez -from ezmsg.baseproc import Clock, ClockSettings, Counter, CounterSettings +from ezmsg.baseproc import Clock, ClockSettings from ezmsg.sigproc.math.add import Add from ezmsg.util.messages.axisarray import AxisArray @@ -29,11 +29,11 @@ class EEGSynth(ez.Collection): """ A Collection that generates synthetic EEG signals. - Combines pink noise with alpha oscillations using a diamond flow: - Clock -> Counter -> {Noise, Oscillator} -> Add -> Output + Combines white noise with alpha oscillations using a diamond flow: + Clock -> {Noise, Oscillator} -> Add -> Output Network flow: - Clock -> Counter -> {Noise, Oscillator} + Clock -> {Noise, Oscillator} Noise -> Add.A Oscillator -> Add.B Add -> OUTPUT @@ -44,7 +44,6 @@ class EEGSynth(ez.Collection): OUTPUT_SIGNAL = ez.OutputStream(AxisArray) CLOCK = Clock() - COUNTER = Counter() NOISE = WhiteNoise() OSC = SinGenerator() ADD = Add() @@ -53,15 +52,10 @@ def configure(self) -> None: dispatch_rate = self.SETTINGS.fs / self.SETTINGS.n_time self.CLOCK.apply_settings(ClockSettings(dispatch_rate=dispatch_rate)) - self.COUNTER.apply_settings( - CounterSettings( - fs=self.SETTINGS.fs, - n_time=self.SETTINGS.n_time, - ) - ) - self.NOISE.apply_settings( WhiteNoiseSettings( + fs=self.SETTINGS.fs, + n_time=self.SETTINGS.n_time, n_ch=self.SETTINGS.n_ch, scale=5.0, ) @@ -69,6 +63,8 @@ def configure(self) -> None: self.OSC.apply_settings( SinGeneratorSettings( + fs=self.SETTINGS.fs, + n_time=self.SETTINGS.n_time, n_ch=self.SETTINGS.n_ch, freq=self.SETTINGS.alpha_freq, ) @@ -76,11 +72,9 @@ def configure(self) -> None: def network(self) -> ez.NetworkDefinition: return ( - # Clock drives Counter - (self.CLOCK.OUTPUT_SIGNAL, self.COUNTER.INPUT_SIGNAL), - # Counter fans out to both Noise and Oscillator - (self.COUNTER.OUTPUT_SIGNAL, self.OSC.INPUT_SIGNAL), - (self.COUNTER.OUTPUT_SIGNAL, self.NOISE.INPUT_SIGNAL), + # Clock drives Noise and Oscillator directly + (self.CLOCK.OUTPUT_SIGNAL, self.OSC.INPUT_CLOCK), + (self.CLOCK.OUTPUT_SIGNAL, self.NOISE.INPUT_CLOCK), # Combine outputs (self.OSC.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_A), (self.NOISE.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_B), diff --git a/src/ezmsg/simbiophys/noise.py b/src/ezmsg/simbiophys/noise.py index a9c1140..b4b97da 100644 --- a/src/ezmsg/simbiophys/noise.py +++ b/src/ezmsg/simbiophys/noise.py @@ -1,11 +1,12 @@ """Noise signal generators.""" -import ezmsg.core as ez import numpy as np from ezmsg.baseproc import ( + BaseClockDrivenProducer, + BaseClockDrivenProducerUnit, BaseProcessor, - BaseStatefulTransformer, - BaseTransformerUnit, + ClockDrivenSettings, + ClockDrivenState, CompositeProcessor, processor_state, ) @@ -13,10 +14,10 @@ ButterworthFilterSettings, ButterworthFilterTransformer, ) -from ezmsg.util.messages.axisarray import AxisArray, replace +from ezmsg.util.messages.axisarray import AxisArray, LinearAxis, replace -class WhiteNoiseSettings(ez.Settings): +class WhiteNoiseSettings(ClockDrivenSettings): """Settings for white noise generators.""" n_ch: int = 1 @@ -30,30 +31,28 @@ class WhiteNoiseSettings(ez.Settings): @processor_state -class WhiteNoiseTransformerState: - """State for NoiseTransformer.""" +class WhiteNoiseState(ClockDrivenState): + """State for WhiteNoiseProducer.""" template: AxisArray | None = None -class WhiteNoiseTransformer( - BaseStatefulTransformer[WhiteNoiseSettings, AxisArray, AxisArray, WhiteNoiseTransformerState] -): +class WhiteNoiseProducer(BaseClockDrivenProducer[WhiteNoiseSettings, WhiteNoiseState]): """ - Transforms input AxisArray into white noise signal. + Generates white noise synchronized to clock ticks. - Takes timing information from input message (counter output) and - generates random data with the specified number of channels. + Each clock tick produces a block of Gaussian white noise based on the + sample rate (fs) and chunk size (n_time) settings. """ - def _reset_state(self, message: AxisArray) -> None: + def _reset_state(self, time_axis: LinearAxis) -> None: """Initialize template with channel axis.""" n_ch = self.settings.n_ch self._state.template = AxisArray( data=np.zeros((0, n_ch)), dims=["time", "ch"], axes={ - "time": message.axes["time"], + "time": time_axis, "ch": AxisArray.CoordinateAxis( data=np.arange(n_ch), dims=["ch"], @@ -61,14 +60,13 @@ def _reset_state(self, message: AxisArray) -> None: }, ) - def _process(self, message: AxisArray) -> AxisArray: - n_time = message.data.shape[0] - + def _produce(self, n_samples: int, time_axis: LinearAxis) -> AxisArray: + """Generate white noise for this chunk.""" # Generate random data random_data = np.random.normal( loc=self.settings.loc, scale=self.settings.scale, - size=(n_time, self.settings.n_ch), + size=(n_samples, self.settings.n_ch), ) # Create output using template @@ -76,18 +74,18 @@ def _process(self, message: AxisArray) -> AxisArray: self._state.template, data=random_data, axes={ - "time": message.axes["time"], - "ch": self._state.template.axes["ch"], + **self._state.template.axes, + "time": time_axis, }, ) -class WhiteNoise(BaseTransformerUnit[WhiteNoiseSettings, AxisArray, AxisArray, WhiteNoiseTransformer]): +class WhiteNoise(BaseClockDrivenProducerUnit[WhiteNoiseSettings, WhiteNoiseProducer]): """ - Transforms counter input into white noise signal. + Generates white noise synchronized to clock ticks. - Receives timing from INPUT_SIGNAL (AxisArray from Counter) and outputs - white noise AxisArray. + Receives timing from INPUT_CLOCK (LinearAxis from Clock) and outputs + white noise AxisArray on OUTPUT_SIGNAL. """ SETTINGS = WhiteNoiseSettings @@ -98,23 +96,26 @@ class PinkNoiseSettings(WhiteNoiseSettings): cutoff: float = 300.0 """ - Highpass Cutoff frequency (Hz). Lowpass corner of first order Butterworth filter. + Lowpass cutoff frequency (Hz) for the first-order Butterworth filter + that creates the 1/f characteristic. """ -class PinkNoiseTransformer(CompositeProcessor[PinkNoiseSettings, AxisArray, AxisArray]): +class PinkNoiseProducer(CompositeProcessor[PinkNoiseSettings, LinearAxis, AxisArray]): """ - Transforms input AxisArray into pink (1/f) noise signal. + Generates pink (1/f) noise synchronized to clock ticks. Pink noise is generated by filtering white noise with a first-order - lowpass Butterworth filter with cutoff at fs * 0.01 Hz. + lowpass Butterworth filter. """ @staticmethod def _initialize_processors(settings: PinkNoiseSettings) -> dict[str, BaseProcessor]: return { - "white_noise": WhiteNoiseTransformer( + "white_noise": WhiteNoiseProducer( WhiteNoiseSettings( + fs=settings.fs, + n_time=settings.n_time, n_ch=settings.n_ch, loc=settings.loc, scale=settings.scale, @@ -126,12 +127,12 @@ def _initialize_processors(settings: PinkNoiseSettings) -> dict[str, BaseProcess } -class PinkNoise(BaseTransformerUnit[PinkNoiseSettings, AxisArray, AxisArray, PinkNoiseTransformer]): +class PinkNoise(BaseClockDrivenProducerUnit[PinkNoiseSettings, PinkNoiseProducer]): """ - Transforms counter input into pink (1/f) noise signal. + Generates pink (1/f) noise synchronized to clock ticks. - Receives timing from INPUT_SIGNAL (AxisArray from Counter) and outputs - pink noise AxisArray. + Receives timing from INPUT_CLOCK (LinearAxis from Clock) and outputs + pink noise AxisArray on OUTPUT_SIGNAL. """ SETTINGS = PinkNoiseSettings diff --git a/src/ezmsg/simbiophys/oscillator.py b/src/ezmsg/simbiophys/oscillator.py index 475ab69..58417e5 100644 --- a/src/ezmsg/simbiophys/oscillator.py +++ b/src/ezmsg/simbiophys/oscillator.py @@ -1,17 +1,18 @@ """Oscillator/sinusoidal signal generators.""" -import ezmsg.core as ez import numpy as np import numpy.typing as npt from ezmsg.baseproc import ( - BaseStatefulTransformer, - BaseTransformerUnit, + BaseClockDrivenProducer, + BaseClockDrivenProducerUnit, + ClockDrivenSettings, + ClockDrivenState, processor_state, ) -from ezmsg.util.messages.axisarray import AxisArray, replace +from ezmsg.util.messages.axisarray import AxisArray, LinearAxis, replace -class SinGeneratorSettings(ez.Settings): +class SinGeneratorSettings(ClockDrivenSettings): """Settings for :obj:`SinGenerator`.""" n_ch: int = 1 @@ -28,8 +29,8 @@ class SinGeneratorSettings(ez.Settings): @processor_state -class SinTransformerState: - """State for SinTransformer.""" +class SinGeneratorState(ClockDrivenState): + """State for SinGenerator.""" template: AxisArray | None = None # Pre-computed arrays for efficient processing, shape (1, 1) or (1, n_ch) @@ -38,15 +39,15 @@ class SinTransformerState: phase: np.ndarray | None = None -class SinTransformer(BaseStatefulTransformer[SinGeneratorSettings, AxisArray, AxisArray, SinTransformerState]): +class SinProducer(BaseClockDrivenProducer[SinGeneratorSettings, SinGeneratorState]): """ - Transforms counter values into sinusoidal waveforms. + Generates sinusoidal waveforms synchronized to clock ticks. - Takes AxisArray with integer counter values and generates sinusoidal - output based on the time axis sample rate. + Each clock tick produces a block of sinusoidal data based on the + sample rate (fs) and chunk size (n_time) settings. """ - def _reset_state(self, message: AxisArray) -> None: + def _reset_state(self, time_axis: LinearAxis) -> None: """Initialize template and pre-compute parameter arrays.""" n_ch = self.settings.n_ch @@ -55,7 +56,7 @@ def _reset_state(self, message: AxisArray) -> None: data=np.zeros((0, n_ch)), dims=["time", "ch"], axes={ - "time": message.axes["time"], + "time": time_axis, "ch": AxisArray.CoordinateAxis( data=np.arange(n_ch), dims=["ch"], @@ -85,11 +86,11 @@ def _reset_state(self, message: AxisArray) -> None: self._state.amp = amp self._state.phase = phase - def _process(self, message: AxisArray) -> AxisArray: - """Transform input counter values into sinusoidal waveform.""" + def _produce(self, n_samples: int, time_axis: LinearAxis) -> AxisArray: + """Generate sinusoidal waveform for this chunk.""" # Calculate sinusoid: amp * sin(ang_freq*t + phase) # t shape: (n_time,) -> (n_time, 1) for broadcasting with (1, n_ch) - t = message.data[:, np.newaxis] * message.axes["time"].gain + t = (np.arange(n_samples) + self._state.counter)[:, np.newaxis] * time_axis.gain sin_data = self._state.amp * np.sin(self._state.ang_freq * t + self._state.phase) # Tile if all params were scalar but n_ch > 1 @@ -100,18 +101,18 @@ def _process(self, message: AxisArray) -> AxisArray: self._state.template, data=sin_data, axes={ - "time": message.axes["time"], - "ch": self._state.template.axes["ch"], + **self._state.template.axes, + "time": time_axis, }, ) -class SinGenerator(BaseTransformerUnit[SinGeneratorSettings, AxisArray, AxisArray, SinTransformer]): +class SinGenerator(BaseClockDrivenProducerUnit[SinGeneratorSettings, SinProducer]): """ - Transforms counter input into sinusoidal waveform. + Generates sinusoidal waveforms synchronized to clock ticks. - Receives timing from INPUT_SIGNAL (AxisArray from Counter) and outputs - sinusoidal AxisArray. + Receives timing from INPUT_CLOCK (LinearAxis from Clock) and outputs + sinusoidal AxisArray on OUTPUT_SIGNAL. """ SETTINGS = SinGeneratorSettings diff --git a/tests/integration/test_dnss_system.py b/tests/integration/test_dnss_system.py index df02537..341ad3d 100644 --- a/tests/integration/test_dnss_system.py +++ b/tests/integration/test_dnss_system.py @@ -52,7 +52,7 @@ def configure(self) -> None: def network(self) -> ez.NetworkDefinition: return ( - (self.CLOCK.OUTPUT_SIGNAL, self.COUNTER.INPUT_SIGNAL), + (self.CLOCK.OUTPUT_SIGNAL, self.COUNTER.INPUT_CLOCK), (self.COUNTER.OUTPUT_SIGNAL, self.LFP.INPUT_SIGNAL), (self.LFP.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE), (self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE), @@ -179,7 +179,7 @@ def configure(self) -> None: def network(self) -> ez.NetworkDefinition: return ( - (self.CLOCK.OUTPUT_SIGNAL, self.COUNTER.INPUT_SIGNAL), + (self.CLOCK.OUTPUT_SIGNAL, self.COUNTER.INPUT_CLOCK), (self.COUNTER.OUTPUT_SIGNAL, self.SPIKE.INPUT_SIGNAL), (self.SPIKE.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE), (self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE), diff --git a/tests/integration/test_synth_system.py b/tests/integration/test_synth_system.py index e1da9f8..f20ed2d 100644 --- a/tests/integration/test_synth_system.py +++ b/tests/integration/test_synth_system.py @@ -99,7 +99,7 @@ def configure(self) -> None: def network(self) -> ez.NetworkDefinition: return ( - (self.CLOCK.OUTPUT_SIGNAL, self.COUNTER.INPUT_SIGNAL), + (self.CLOCK.OUTPUT_SIGNAL, self.COUNTER.INPUT_CLOCK), (self.COUNTER.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE), (self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE), ) diff --git a/tests/unit/test_oscillator.py b/tests/unit/test_oscillator.py index 0a39552..a6ffd28 100644 --- a/tests/unit/test_oscillator.py +++ b/tests/unit/test_oscillator.py @@ -1,70 +1,63 @@ """Unit tests for ezmsg.simbiophys.oscillator module.""" import numpy as np +import pytest from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.simbiophys import SinGeneratorSettings, SinTransformer +from ezmsg.simbiophys import SinGeneratorSettings, SinProducer -def test_sin_transformer(freq: float = 1.0, amp: float = 1.0, phase: float = 0.0): - """Test SinTransformer via __call__.""" +def test_sin_generator_basic(freq: float = 1.0, amp: float = 1.0, phase: float = 0.0): + """Test SinProducer via __call__.""" n_ch = 1 srate = max(4.0 * freq, 1000.0) sim_dur = 30.0 n_samples = int(srate * sim_dur) n_msgs = min(n_samples, 10) - - # Create input messages with counter data (integer sample counts) - messages = [] - counter = 0 samples_per_msg = n_samples // n_msgs - for i in range(n_msgs): - n = samples_per_msg if i < n_msgs - 1 else n_samples - counter - sample_indices = np.arange(counter, counter + n) - _time_axis = AxisArray.TimeAxis(fs=srate, offset=counter / srate) - messages.append(AxisArray(sample_indices, dims=["time"], axes={"time": _time_axis})) - counter += n def f_test(t): return amp * np.sin(2 * np.pi * freq * t + phase) - # Create transformer - transformer = SinTransformer(SinGeneratorSettings(n_ch=n_ch, freq=freq, amp=amp, phase=phase)) + # Create producer with clock-driven settings + producer = SinProducer( + SinGeneratorSettings(fs=srate, n_time=samples_per_msg, n_ch=n_ch, freq=freq, amp=amp, phase=phase) + ) - # Process messages + # Process clock ticks results = [] - for msg in messages: - res = transformer(msg) + for i in range(n_msgs): + offset = i * samples_per_msg / srate + clock_tick = AxisArray.TimeAxis(fs=srate, offset=offset) + res = producer(clock_tick) + # Check output shape - assert res.data.shape == (len(msg.data), n_ch) - # Check values - t = msg.data / srate - expected = f_test(t)[:, np.newaxis] - assert np.allclose(res.data, expected) + expected_n = samples_per_msg if i < n_msgs - 1 else n_samples - i * samples_per_msg + # Note: With fixed n_time, all chunks have the same size + assert res.data.shape == (expected_n, n_ch) results.append(res) # Verify concatenated output concat_ax_arr = AxisArray.concatenate(*results, dim="time") - assert np.allclose(concat_ax_arr.data, f_test(np.arange(n_samples) / srate)[:, np.newaxis]) + total_samples = concat_ax_arr.data.shape[0] + t = np.arange(total_samples) / srate + expected = f_test(t)[:, np.newaxis] + np.testing.assert_allclose(concat_ax_arr.data, expected, rtol=1e-10) -def test_sin_transformer_multi_channel(): - """Test SinTransformer with multiple channels.""" +def test_sin_generator_multi_channel(): + """Test SinProducer with multiple channels.""" n_ch = 4 freq = 10.0 srate = 1000.0 n_samples = 100 - # Create input with counter data - sample_indices = np.arange(n_samples) - _time_axis = AxisArray.TimeAxis(fs=srate, offset=0.0) - msg = AxisArray(sample_indices, dims=["time"], axes={"time": _time_axis}) + # Create producer + producer = SinProducer(SinGeneratorSettings(fs=srate, n_time=n_samples, n_ch=n_ch, freq=freq)) - # Create transformer - transformer = SinTransformer(SinGeneratorSettings(n_ch=n_ch, freq=freq)) - - # Process - result = transformer(msg) + # Process single clock tick + clock_tick = AxisArray.TimeAxis(fs=srate, offset=0.0) + result = producer(clock_tick) # Check output shape assert result.data.shape == (n_samples, n_ch) @@ -75,8 +68,8 @@ def test_sin_transformer_multi_channel(): np.testing.assert_allclose(result.data[:, 0], result.data[:, ch]) -def test_sin_transformer_per_channel_freq(): - """Test SinTransformer with per-channel frequencies.""" +def test_sin_generator_per_channel_freq(): + """Test SinProducer with per-channel frequencies.""" n_ch = 3 freqs = [5.0, 10.0, 20.0] amp = 1.0 @@ -84,29 +77,27 @@ def test_sin_transformer_per_channel_freq(): srate = 1000.0 n_samples = 100 - # Create input with counter data - sample_indices = np.arange(n_samples) - _time_axis = AxisArray.TimeAxis(fs=srate, offset=0.0) - msg = AxisArray(sample_indices, dims=["time"], axes={"time": _time_axis}) - - # Create transformer with per-channel freqs - transformer = SinTransformer(SinGeneratorSettings(n_ch=n_ch, freq=freqs, amp=amp, phase=phase)) + # Create producer with per-channel freqs + producer = SinProducer( + SinGeneratorSettings(fs=srate, n_time=n_samples, n_ch=n_ch, freq=freqs, amp=amp, phase=phase) + ) # Process - result = transformer(msg) + clock_tick = AxisArray.TimeAxis(fs=srate, offset=0.0) + result = producer(clock_tick) # Check output shape assert result.data.shape == (n_samples, n_ch) # Verify each channel has correct frequency - t = sample_indices / srate + t = np.arange(n_samples) / srate for ch, freq in enumerate(freqs): expected = amp * np.sin(2 * np.pi * freq * t + phase) np.testing.assert_allclose(result.data[:, ch], expected, rtol=1e-10) -def test_sin_transformer_per_channel_all_params(): - """Test SinTransformer with per-channel freq, amp, and phase.""" +def test_sin_generator_per_channel_all_params(): + """Test SinProducer with per-channel freq, amp, and phase.""" n_ch = 4 freqs = [5.0, 10.0, 15.0, 20.0] amps = [1.0, 2.0, 0.5, 1.5] @@ -114,29 +105,27 @@ def test_sin_transformer_per_channel_all_params(): srate = 1000.0 n_samples = 200 - # Create input with counter data - sample_indices = np.arange(n_samples) - _time_axis = AxisArray.TimeAxis(fs=srate, offset=0.0) - msg = AxisArray(sample_indices, dims=["time"], axes={"time": _time_axis}) - - # Create transformer with all per-channel params - transformer = SinTransformer(SinGeneratorSettings(n_ch=n_ch, freq=freqs, amp=amps, phase=phases)) + # Create producer with all per-channel params + producer = SinProducer( + SinGeneratorSettings(fs=srate, n_time=n_samples, n_ch=n_ch, freq=freqs, amp=amps, phase=phases) + ) # Process - result = transformer(msg) + clock_tick = AxisArray.TimeAxis(fs=srate, offset=0.0) + result = producer(clock_tick) # Check output shape assert result.data.shape == (n_samples, n_ch) # Verify each channel (use atol for values near zero) - t = sample_indices / srate + t = np.arange(n_samples) / srate for ch in range(n_ch): expected = amps[ch] * np.sin(2 * np.pi * freqs[ch] * t + phases[ch]) np.testing.assert_allclose(result.data[:, ch], expected, rtol=1e-10, atol=1e-14) -def test_sin_transformer_mixed_scalar_array(): - """Test SinTransformer with mixed scalar and array params.""" +def test_sin_generator_mixed_scalar_array(): + """Test SinProducer with mixed scalar and array params.""" n_ch = 3 freqs = [5.0, 10.0, 20.0] # per-channel amp = 2.0 # scalar - same for all channels @@ -144,50 +133,43 @@ def test_sin_transformer_mixed_scalar_array(): srate = 1000.0 n_samples = 100 - # Create input with counter data - sample_indices = np.arange(n_samples) - _time_axis = AxisArray.TimeAxis(fs=srate, offset=0.0) - msg = AxisArray(sample_indices, dims=["time"], axes={"time": _time_axis}) - - # Create transformer - transformer = SinTransformer(SinGeneratorSettings(n_ch=n_ch, freq=freqs, amp=amp, phase=phase)) + # Create producer + producer = SinProducer( + SinGeneratorSettings(fs=srate, n_time=n_samples, n_ch=n_ch, freq=freqs, amp=amp, phase=phase) + ) # Process - result = transformer(msg) + clock_tick = AxisArray.TimeAxis(fs=srate, offset=0.0) + result = producer(clock_tick) # Check output shape assert result.data.shape == (n_samples, n_ch) # Verify each channel - t = sample_indices / srate + t = np.arange(n_samples) / srate for ch, freq in enumerate(freqs): expected = amp * np.sin(2 * np.pi * freq * t + phase) np.testing.assert_allclose(result.data[:, ch], expected, rtol=1e-10) -def test_sin_transformer_array_length_mismatch(): - """Test SinTransformer raises error when array length doesn't match n_ch.""" - import pytest - +def test_sin_generator_array_length_mismatch(): + """Test SinProducer raises error when array length doesn't match n_ch.""" n_ch = 4 freqs = [5.0, 10.0, 20.0] # length 3, but n_ch is 4 srate = 1000.0 + n_samples = 100 - # Create input - sample_indices = np.arange(100) - _time_axis = AxisArray.TimeAxis(fs=srate, offset=0.0) - msg = AxisArray(sample_indices, dims=["time"], axes={"time": _time_axis}) - - # Create transformer - should not raise here - transformer = SinTransformer(SinGeneratorSettings(n_ch=n_ch, freq=freqs)) + # Create producer - should not raise here + producer = SinProducer(SinGeneratorSettings(fs=srate, n_time=n_samples, n_ch=n_ch, freq=freqs)) - # Should raise ValueError when processing + # Should raise ValueError when processing (during _reset_state) + clock_tick = AxisArray.TimeAxis(fs=srate, offset=0.0) with pytest.raises(ValueError, match="freq has length 3 but n_ch is 4"): - transformer(msg) + producer(clock_tick) -def test_sin_transformer_numpy_array_input(): - """Test SinTransformer accepts numpy arrays for per-channel params.""" +def test_sin_generator_numpy_array_input(): + """Test SinProducer accepts numpy arrays for per-channel params.""" n_ch = 3 freqs = np.array([5.0, 10.0, 20.0]) amps = np.array([1.0, 2.0, 0.5]) @@ -195,22 +177,77 @@ def test_sin_transformer_numpy_array_input(): srate = 1000.0 n_samples = 100 - # Create input with counter data - sample_indices = np.arange(n_samples) - _time_axis = AxisArray.TimeAxis(fs=srate, offset=0.0) - msg = AxisArray(sample_indices, dims=["time"], axes={"time": _time_axis}) - - # Create transformer with numpy arrays - transformer = SinTransformer(SinGeneratorSettings(n_ch=n_ch, freq=freqs, amp=amps, phase=phases)) + # Create producer with numpy arrays + producer = SinProducer( + SinGeneratorSettings(fs=srate, n_time=n_samples, n_ch=n_ch, freq=freqs, amp=amps, phase=phases) + ) # Process - result = transformer(msg) + clock_tick = AxisArray.TimeAxis(fs=srate, offset=0.0) + result = producer(clock_tick) # Check output shape assert result.data.shape == (n_samples, n_ch) # Verify each channel (use atol for values near zero) - t = sample_indices / srate + t = np.arange(n_samples) / srate for ch in range(n_ch): expected = amps[ch] * np.sin(2 * np.pi * freqs[ch] * t + phases[ch]) np.testing.assert_allclose(result.data[:, ch], expected, rtol=1e-10, atol=1e-14) + + +def test_sin_generator_continuity_across_chunks(): + """Test that sine wave is continuous across multiple clock ticks.""" + n_ch = 1 + freq = 10.0 + srate = 1000.0 + n_samples_per_chunk = 50 + n_chunks = 4 + + # Create producer + producer = SinProducer(SinGeneratorSettings(fs=srate, n_time=n_samples_per_chunk, n_ch=n_ch, freq=freq)) + + # Process multiple clock ticks + results = [] + for i in range(n_chunks): + offset = i * n_samples_per_chunk / srate + clock_tick = AxisArray.TimeAxis(fs=srate, offset=offset) + results.append(producer(clock_tick)) + + # Concatenate and verify continuity + concat = AxisArray.concatenate(*results, dim="time") + total_samples = n_samples_per_chunk * n_chunks + t = np.arange(total_samples) / srate + expected = np.sin(2 * np.pi * freq * t)[:, np.newaxis] + np.testing.assert_allclose(concat.data, expected, rtol=1e-10) + + +def test_sin_generator_variable_chunk_mode(): + """Test SinProducer with variable chunk sizes (n_time=None).""" + n_ch = 1 + freq = 10.0 + srate = 1000.0 + + # Create producer without fixed n_time + producer = SinProducer(SinGeneratorSettings(fs=srate, n_time=None, n_ch=n_ch, freq=freq)) + + # Clock ticks with different gains -> different chunk sizes + clock_ticks = [ + AxisArray.LinearAxis(gain=0.1, offset=0.0), # 100 samples + AxisArray.LinearAxis(gain=0.05, offset=0.1), # 50 samples + AxisArray.LinearAxis(gain=0.2, offset=0.15), # 200 samples + ] + + results = [] + for tick in clock_ticks: + results.append(producer(tick)) + + assert results[0].data.shape == (100, n_ch) + assert results[1].data.shape == (50, n_ch) + assert results[2].data.shape == (200, n_ch) + + # Verify continuity + concat = AxisArray.concatenate(*results, dim="time") + t = np.arange(350) / srate + expected = np.sin(2 * np.pi * freq * t)[:, np.newaxis] + np.testing.assert_allclose(concat.data, expected, rtol=1e-10) From 9d61053b943560c87cb0c20c519f81b05544ea92 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Mon, 5 Jan 2026 18:12:04 -0500 Subject: [PATCH 02/11] Fix base class name --- pyproject.toml | 2 +- src/ezmsg/simbiophys/noise.py | 6 +++--- src/ezmsg/simbiophys/oscillator.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2c40b34..cd8a27d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ requires-python = ">=3.10.15" dynamic = ["version"] dependencies = [ "ezmsg>=3.6.0", - "ezmsg-baseproc>=1.2.0", + "ezmsg-baseproc>=1.2.1", "ezmsg-event>=0.6.0", "ezmsg-sigproc>=2.8.0", "numpy>=1.26.0", diff --git a/src/ezmsg/simbiophys/noise.py b/src/ezmsg/simbiophys/noise.py index b4b97da..b1eff63 100644 --- a/src/ezmsg/simbiophys/noise.py +++ b/src/ezmsg/simbiophys/noise.py @@ -3,7 +3,7 @@ import numpy as np from ezmsg.baseproc import ( BaseClockDrivenProducer, - BaseClockDrivenProducerUnit, + BaseClockDrivenUnit, BaseProcessor, ClockDrivenSettings, ClockDrivenState, @@ -80,7 +80,7 @@ def _produce(self, n_samples: int, time_axis: LinearAxis) -> AxisArray: ) -class WhiteNoise(BaseClockDrivenProducerUnit[WhiteNoiseSettings, WhiteNoiseProducer]): +class WhiteNoise(BaseClockDrivenUnit[WhiteNoiseSettings, WhiteNoiseProducer]): """ Generates white noise synchronized to clock ticks. @@ -127,7 +127,7 @@ def _initialize_processors(settings: PinkNoiseSettings) -> dict[str, BaseProcess } -class PinkNoise(BaseClockDrivenProducerUnit[PinkNoiseSettings, PinkNoiseProducer]): +class PinkNoise(BaseClockDrivenUnit[PinkNoiseSettings, PinkNoiseProducer]): """ Generates pink (1/f) noise synchronized to clock ticks. diff --git a/src/ezmsg/simbiophys/oscillator.py b/src/ezmsg/simbiophys/oscillator.py index 58417e5..515085e 100644 --- a/src/ezmsg/simbiophys/oscillator.py +++ b/src/ezmsg/simbiophys/oscillator.py @@ -4,7 +4,7 @@ import numpy.typing as npt from ezmsg.baseproc import ( BaseClockDrivenProducer, - BaseClockDrivenProducerUnit, + BaseClockDrivenUnit, ClockDrivenSettings, ClockDrivenState, processor_state, @@ -107,7 +107,7 @@ def _produce(self, n_samples: int, time_axis: LinearAxis) -> AxisArray: ) -class SinGenerator(BaseClockDrivenProducerUnit[SinGeneratorSettings, SinProducer]): +class SinGenerator(BaseClockDrivenUnit[SinGeneratorSettings, SinProducer]): """ Generates sinusoidal waveforms synchronized to clock ticks. From 9b2f326ad1f544660377dd0d5687f9ed5e4be359 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Mon, 5 Jan 2026 18:12:28 -0500 Subject: [PATCH 03/11] Move old tests for Clock and Counter from this package to ezmsg-baseproc --- tests/integration/test_synth_system.py | 218 ------------------------- 1 file changed, 218 deletions(-) diff --git a/tests/integration/test_synth_system.py b/tests/integration/test_synth_system.py index f20ed2d..769782d 100644 --- a/tests/integration/test_synth_system.py +++ b/tests/integration/test_synth_system.py @@ -1,239 +1,21 @@ """Integration tests for ezmsg.simbiophys signal generator systems.""" -import math import os from dataclasses import field import ezmsg.core as ez -import numpy as np -import pytest from ezmsg.util.messagecodec import message_log from ezmsg.util.messagelogger import MessageLogger, MessageLoggerSettings from ezmsg.util.messages.axisarray import AxisArray from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTotalSettings from ezmsg.simbiophys import ( - Clock, - ClockSettings, - Counter, - CounterSettings, EEGSynth, EEGSynthSettings, ) from tests.helpers.util import get_test_fn -class ClockTestSystemSettings(ez.Settings): - clock_settings: ClockSettings - log_settings: MessageLoggerSettings - term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings) - - -class ClockTestSystem(ez.Collection): - SETTINGS = ClockTestSystemSettings - - CLOCK = Clock() - LOG = MessageLogger() - TERM = TerminateOnTotal() - - def configure(self) -> None: - self.CLOCK.apply_settings(self.SETTINGS.clock_settings) - self.LOG.apply_settings(self.SETTINGS.log_settings) - self.TERM.apply_settings(self.SETTINGS.term_settings) - - def network(self) -> ez.NetworkDefinition: - return ( - (self.CLOCK.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE), - (self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE), - ) - - -@pytest.mark.parametrize("dispatch_rate", [math.inf, 2.0, 20.0]) -def test_clock_system( - dispatch_rate: float, - test_name: str | None = None, -): - run_time = 1.0 - n_target = 100 if math.isinf(dispatch_rate) else int(np.ceil(dispatch_rate * run_time)) - test_filename = get_test_fn(test_name) - ez.logger.info(test_filename) - settings = ClockTestSystemSettings( - clock_settings=ClockSettings(dispatch_rate=dispatch_rate), - log_settings=MessageLoggerSettings(output=test_filename), - term_settings=TerminateOnTotalSettings(total=n_target), - ) - system = ClockTestSystem(settings) - ez.run(SYSTEM=system) - - # Collect result - messages = list(message_log(test_filename)) - os.remove(test_filename) - - # Clock produces LinearAxis with gain and offset - assert all(isinstance(m, AxisArray.LinearAxis) for m in messages) - assert len(messages) >= n_target - - -class CounterTestSystemSettings(ez.Settings): - clock_settings: ClockSettings - counter_settings: CounterSettings - log_settings: MessageLoggerSettings - term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings) - - -class CounterTestSystem(ez.Collection): - """Counter must be driven by Clock in the new architecture.""" - - SETTINGS = CounterTestSystemSettings - - CLOCK = Clock() - COUNTER = Counter() - LOG = MessageLogger() - TERM = TerminateOnTotal() - - def configure(self) -> None: - self.CLOCK.apply_settings(self.SETTINGS.clock_settings) - self.COUNTER.apply_settings(self.SETTINGS.counter_settings) - self.LOG.apply_settings(self.SETTINGS.log_settings) - self.TERM.apply_settings(self.SETTINGS.term_settings) - - def network(self) -> ez.NetworkDefinition: - return ( - (self.CLOCK.OUTPUT_SIGNAL, self.COUNTER.INPUT_CLOCK), - (self.COUNTER.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE), - (self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE), - ) - - -@pytest.mark.parametrize( - "n_time, fs, dispatch_rate, mod", - [ - (1, 10.0, math.inf, None), # AFAP mode - (20, 1000.0, 50.0, None), # Realtime mode (50 Hz dispatch = 20 samples/tick @ 1000 Hz) - (1, 1000.0, 100.0, 2**3), # 100 Hz dispatch with mod - (10, 10.0, 10.0, 2**3), # 10 Hz dispatch with mod - ], -) -def test_counter_system( - n_time: int, - fs: float, - dispatch_rate: float, - mod: int | None, - test_name: str | None = None, -): - target_dur = 2.6 # 2.6 seconds per test - if math.isinf(dispatch_rate): - # AFAP mode - runs as fast as possible - target_messages = 100 # Fixed target for AFAP - else: - target_messages = int(target_dur * dispatch_rate) - - test_filename = get_test_fn(test_name) - ez.logger.info(test_filename) - settings = CounterTestSystemSettings( - clock_settings=ClockSettings(dispatch_rate=dispatch_rate), - counter_settings=CounterSettings( - n_time=n_time, - fs=fs, - mod=mod, - ), - log_settings=MessageLoggerSettings( - output=test_filename, - ), - term_settings=TerminateOnTotalSettings( - total=target_messages, - ), - ) - system = CounterTestSystem(settings) - ez.run(SYSTEM=system) - - # Collect result - messages: list[AxisArray] = [_ for _ in message_log(test_filename)] - os.remove(test_filename) - - if math.isinf(dispatch_rate): - # The number of messages depends on how fast the computer is - target_messages = len(messages) - # This should be an equivalence assertion (==) but the use of TerminateOnTotal does - # not guarantee that MessageLogger will exit before an additional message is received. - # Let's just clip the last message if we exceed the target messages. - if len(messages) > target_messages: - messages = messages[:target_messages] - assert len(messages) >= target_messages - - # Just do one quick data check (Counter now outputs 1D array) - agg = AxisArray.concatenate(*messages, dim="time") - target_samples = n_time * target_messages - expected_data = np.arange(target_samples) - if mod is not None: - expected_data = expected_data % mod - assert np.array_equal(agg.data, expected_data) - - -@pytest.mark.parametrize( - "clock_rate, fs, n_time", - [ - (10.0, 1000.0, 100), # 10 Hz clock, fs=1000, n_time=100 (fixed) - (20.0, 500.0, None), # 20 Hz clock, fs=500, n_time derived (25 samples per tick) - (5.0, 1000.0, None), # 5 Hz clock, fs=1000, n_time derived (200 samples per tick) - ], -) -def test_counter_with_external_clock( - clock_rate: float, - fs: float, - n_time: int | None, - test_name: str | None = None, -): - """Test Counter driven by external Clock (now the standard pattern).""" - target_messages = 20 - test_filename = get_test_fn(test_name) - ez.logger.info(test_filename) - - # This now uses the same CounterTestSystem since all counters need clocks - settings = CounterTestSystemSettings( - clock_settings=ClockSettings(dispatch_rate=clock_rate), - counter_settings=CounterSettings( - fs=fs, - n_time=n_time, - ), - log_settings=MessageLoggerSettings(output=test_filename), - term_settings=TerminateOnTotalSettings(total=target_messages), - ) - system = CounterTestSystem(settings) - ez.run(SYSTEM=system) - - # Collect result - messages: list[AxisArray] = [_ for _ in message_log(test_filename)] - os.remove(test_filename) - - assert len(messages) >= target_messages - - # Verify each message has correct sample rate (gain = 1/fs) - for msg in messages: - assert msg.axes["time"].gain == 1.0 / fs - - # Verify data continuity - messages = messages[:target_messages] # Trim to target - agg = AxisArray.concatenate(*messages, dim="time") - - # Expected samples per tick - if n_time is not None: - expected_samples_per_tick = n_time - else: - expected_samples_per_tick = int(fs / clock_rate) - - expected_total = expected_samples_per_tick * target_messages - # Allow for fractional sample accumulation variance - assert abs(len(agg.data) - expected_total) <= target_messages - - # Counter values should be sequential (0, 1, 2, ...) - expected_data = np.arange(len(agg.data)) - assert np.array_equal(agg.data, expected_data) - - -# TODO: test SinGenerator in a system. - - class EEGSynthSettingsTest(ez.Settings): synth_settings: EEGSynthSettings log_settings: MessageLoggerSettings From 051e10b6e1b331ed796f664259e98dce664a3e68 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Mon, 5 Jan 2026 19:16:52 -0500 Subject: [PATCH 04/11] Update DNSS to use Clock->BaseClockDriven instead of Clock->Counter-> --- src/ezmsg/simbiophys/__init__.py | 8 +- src/ezmsg/simbiophys/dnss/__init__.py | 16 +-- src/ezmsg/simbiophys/dnss/lfp.py | 50 ++++---- src/ezmsg/simbiophys/dnss/spike.py | 54 ++++----- src/ezmsg/simbiophys/dnss/synth.py | 30 ++--- tests/integration/test_dnss_system.py | 27 ++--- tests/unit/test_dnss_lfp.py | 88 +++++++------- tests/unit/test_dnss_spike.py | 160 +++++++++++--------------- 8 files changed, 195 insertions(+), 238 deletions(-) diff --git a/src/ezmsg/simbiophys/__init__.py b/src/ezmsg/simbiophys/__init__.py index 4a175a5..ef64894 100644 --- a/src/ezmsg/simbiophys/__init__.py +++ b/src/ezmsg/simbiophys/__init__.py @@ -26,12 +26,12 @@ # DNSS (Digital Neural Signal Simulator) from .dnss import ( # LFP + DNSSLFPProducer, DNSSLFPSettings, - DNSSLFPTransformer, DNSSLFPUnit, # Spike + DNSSSpikeProducer, DNSSSpikeSettings, - DNSSSpikeTransformer, DNSSSpikeUnit, ) @@ -113,11 +113,11 @@ "DynamicColoredNoiseUnit", "compute_kasdin_coefficients", # DNSS LFP + "DNSSLFPProducer", "DNSSLFPSettings", - "DNSSLFPTransformer", "DNSSLFPUnit", # DNSS Spike + "DNSSSpikeProducer", "DNSSSpikeSettings", - "DNSSSpikeTransformer", "DNSSSpikeUnit", ] diff --git a/src/ezmsg/simbiophys/dnss/__init__.py b/src/ezmsg/simbiophys/dnss/__init__.py index cd90df9..7f61616 100644 --- a/src/ezmsg/simbiophys/dnss/__init__.py +++ b/src/ezmsg/simbiophys/dnss/__init__.py @@ -7,9 +7,9 @@ LFP_PERIOD, LFP_TIME_SHIFTS, OTHER_PERIOD, + DNSSLFPProducer, DNSSLFPSettings, - DNSSLFPTransformer, - DNSSLFPTransformerState, + DNSSLFPState, DNSSLFPUnit, lfp_generator, ) @@ -22,9 +22,9 @@ N_SLOW_SPIKES, SAMPS_BURST, SAMPS_SLOW, + DNSSSpikeProducer, DNSSSpikeSettings, - DNSSSpikeTransformer, - DNSSSpikeTransformerState, + DNSSSpikeState, DNSSSpikeUnit, spike_event_generator, ) @@ -43,9 +43,9 @@ "LFP_TIME_SHIFTS", "OTHER_PERIOD", # LFP classes + "DNSSLFPProducer", "DNSSLFPSettings", - "DNSSLFPTransformer", - "DNSSLFPTransformerState", + "DNSSLFPState", "DNSSLFPUnit", "lfp_generator", # Spike constants @@ -58,9 +58,9 @@ "SAMPS_BURST", "SAMPS_SLOW", # Spike classes + "DNSSSpikeProducer", "DNSSSpikeSettings", - "DNSSSpikeTransformer", - "DNSSSpikeTransformerState", + "DNSSSpikeState", "DNSSSpikeUnit", "spike_event_generator", # Synth classes diff --git a/src/ezmsg/simbiophys/dnss/lfp.py b/src/ezmsg/simbiophys/dnss/lfp.py index fcebbee..feb7d2c 100644 --- a/src/ezmsg/simbiophys/dnss/lfp.py +++ b/src/ezmsg/simbiophys/dnss/lfp.py @@ -27,15 +27,16 @@ from typing import Generator -import ezmsg.core as ez import numpy as np import numpy.typing as npt from ezmsg.baseproc import ( - BaseStatefulTransformer, - BaseTransformerUnit, + BaseClockDrivenProducer, + BaseClockDrivenUnit, + ClockDrivenSettings, + ClockDrivenState, processor_state, ) -from ezmsg.util.messages.axisarray import AxisArray, replace +from ezmsg.util.messages.axisarray import AxisArray, LinearAxis, replace # Default sample rate for DNSS DEFAULT_FS = 30_000 @@ -196,8 +197,11 @@ def lfp_generator( # ============================================================================= -class DNSSLFPSettings(ez.Settings): - """Settings for DNSS LFP transformer.""" +class DNSSLFPSettings(ClockDrivenSettings): + """Settings for DNSS LFP producer.""" + + fs: float = DEFAULT_FS + """Sample rate in Hz. DNSS is fixed at 30kHz.""" n_ch: int = 256 """Number of channels.""" @@ -210,30 +214,28 @@ class DNSSLFPSettings(ez.Settings): @processor_state -class DNSSLFPTransformerState: - """State for DNSS LFP transformer.""" +class DNSSLFPState(ClockDrivenState): + """State for DNSS LFP producer.""" lfp_gen: Generator | None = None template: AxisArray | None = None -class DNSSLFPTransformer(BaseStatefulTransformer[DNSSLFPSettings, AxisArray, AxisArray, DNSSLFPTransformerState]): +class DNSSLFPProducer(BaseClockDrivenProducer[DNSSLFPSettings, DNSSLFPState]): """ - Transforms input AxisArray into DNSS LFP signal. + Produces DNSS LFP signal synchronized to clock ticks. - Takes timing information from input message and generates LFP data. + Each clock tick produces a block of LFP data based on the + sample rate (fs) and chunk size (n_time) settings. All channels receive identical LFP values. """ - def _reset_state(self, message: AxisArray) -> None: + def _reset_state(self, time_axis: LinearAxis) -> None: """Initialize the LFP generator.""" - # Get sample rate from input message (fs = 1/gain for LinearAxis) - time_axis = message.axes["time"] - fs = getattr(time_axis, "fs", 1.0 / time_axis.gain) self._state.lfp_gen = lfp_generator( pattern=self.settings.pattern, mode=self.settings.mode, - fs=fs, + fs=self.settings.fs, ) next(self._state.lfp_gen) @@ -242,7 +244,7 @@ def _reset_state(self, message: AxisArray) -> None: data=np.zeros((0, self.settings.n_ch), dtype=np.float64), dims=["time", "ch"], axes={ - "time": message.axes["time"], + "time": time_axis, "ch": AxisArray.CoordinateAxis( data=np.arange(self.settings.n_ch), dims=["ch"], @@ -250,10 +252,8 @@ def _reset_state(self, message: AxisArray) -> None: }, ) - def _process(self, message: AxisArray) -> AxisArray: - """Transform input into LFP signal.""" - n_samples = message.data.shape[0] - + def _produce(self, n_samples: int, time_axis: LinearAxis) -> AxisArray: + """Generate LFP signal for this chunk.""" # Generate LFP samples lfp_1d = self._state.lfp_gen.send(n_samples) @@ -267,13 +267,13 @@ def _process(self, message: AxisArray) -> AxisArray: self._state.template, data=lfp_data, axes={ - "time": message.axes["time"], - "ch": self._state.template.axes["ch"], + **self._state.template.axes, + "time": time_axis, }, ) -class DNSSLFPUnit(BaseTransformerUnit[DNSSLFPSettings, AxisArray, AxisArray, DNSSLFPTransformer]): - """Unit for generating DNSS LFP from counter input.""" +class DNSSLFPUnit(BaseClockDrivenUnit[DNSSLFPSettings, DNSSLFPProducer]): + """Unit for generating DNSS LFP from clock input.""" SETTINGS = DNSSLFPSettings diff --git a/src/ezmsg/simbiophys/dnss/spike.py b/src/ezmsg/simbiophys/dnss/spike.py index f44c69e..986c269 100644 --- a/src/ezmsg/simbiophys/dnss/spike.py +++ b/src/ezmsg/simbiophys/dnss/spike.py @@ -2,16 +2,17 @@ from typing import Generator -import ezmsg.core as ez import numpy as np import numpy.typing as npt import sparse from ezmsg.baseproc import ( - BaseStatefulTransformer, - BaseTransformerUnit, + BaseClockDrivenProducer, + BaseClockDrivenUnit, + ClockDrivenSettings, + ClockDrivenState, processor_state, ) -from ezmsg.util.messages.axisarray import AxisArray, replace +from ezmsg.util.messages.axisarray import AxisArray, LinearAxis, replace """ ## Spike Pattern @@ -258,8 +259,11 @@ def spike_event_generator( # ============================================================================= -class DNSSSpikeSettings(ez.Settings): - """Settings for DNSS spike transformer.""" +class DNSSSpikeSettings(ClockDrivenSettings): + """Settings for DNSS spike producer.""" + + fs: float = FS + """Sample rate in Hz. DNSS is fixed at 30kHz.""" n_ch: int = 256 """Number of channels.""" @@ -269,30 +273,28 @@ class DNSSSpikeSettings(ez.Settings): @processor_state -class DNSSSpikeTransformerState: - """State for DNSS spike transformer.""" +class DNSSSpikeState(ClockDrivenState): + """State for DNSS spike producer.""" spike_gen: Generator | None = None template: AxisArray | None = None -class DNSSSpikeTransformer(BaseStatefulTransformer[DNSSSpikeSettings, AxisArray, AxisArray, DNSSSpikeTransformerState]): +class DNSSSpikeProducer(BaseClockDrivenProducer[DNSSSpikeSettings, DNSSSpikeState]): """ - Transforms input AxisArray into DNSS spike signal. + Produces DNSS spike signal synchronized to clock ticks. - Takes timing information from input message and generates spike data as sparse COO arrays. + Each clock tick produces a block of spike data as sparse COO arrays + based on the sample rate (fs) and chunk size (n_time) settings. """ - def _reset_state(self, message: AxisArray) -> None: + def _reset_state(self, time_axis: LinearAxis) -> None: """Initialize the spike generator.""" # Verify sample rate is 30kHz - spike patterns are tied to this rate - time_axis = message.axes["time"] - fs = getattr(time_axis, "fs", 1.0 / time_axis.gain) - expected_gain = 1.0 / FS - if not np.isclose(time_axis.gain, expected_gain, rtol=1e-6): + if not np.isclose(self.settings.fs, FS, rtol=1e-6): raise ValueError( - f"DNSSSpikeTransformer requires fs={FS} Hz (gain={expected_gain:.6e}), " - f"but received fs={fs:.1f} Hz (gain={time_axis.gain:.6e}). " + f"DNSSSpikeProducer requires fs={FS} Hz, " + f"but settings.fs={self.settings.fs:.1f} Hz. " f"Spike patterns cannot be resampled to other rates." ) @@ -311,7 +313,7 @@ def _reset_state(self, message: AxisArray) -> None: ), dims=["time", "ch"], axes={ - "time": message.axes["time"], + "time": time_axis, "ch": AxisArray.CoordinateAxis( data=np.arange(self.settings.n_ch), dims=["ch"], @@ -319,10 +321,8 @@ def _reset_state(self, message: AxisArray) -> None: }, ) - def _process(self, message: AxisArray) -> AxisArray: - """Transform input into spike signal.""" - n_samples = message.data.shape[0] - + def _produce(self, n_samples: int, time_axis: LinearAxis) -> AxisArray: + """Generate spike signal for this chunk.""" # Generate spike events coords, waveform_ids = self._state.spike_gen.send(n_samples) @@ -337,13 +337,13 @@ def _process(self, message: AxisArray) -> AxisArray: self._state.template, data=spike_data, axes={ - "time": message.axes["time"], - "ch": self._state.template.axes["ch"], + **self._state.template.axes, + "time": time_axis, }, ) -class DNSSSpikeUnit(BaseTransformerUnit[DNSSSpikeSettings, AxisArray, AxisArray, DNSSSpikeTransformer]): - """Unit for generating DNSS spikes from counter input.""" +class DNSSSpikeUnit(BaseClockDrivenUnit[DNSSSpikeSettings, DNSSSpikeProducer]): + """Unit for generating DNSS spikes from clock input.""" SETTINGS = DNSSSpikeSettings diff --git a/src/ezmsg/simbiophys/dnss/synth.py b/src/ezmsg/simbiophys/dnss/synth.py index bd49017..46cc33f 100644 --- a/src/ezmsg/simbiophys/dnss/synth.py +++ b/src/ezmsg/simbiophys/dnss/synth.py @@ -2,14 +2,14 @@ import ezmsg.core as ez import numpy as np -from ezmsg.baseproc import Clock, ClockSettings, Counter, CounterSettings +from ezmsg.baseproc import Clock, ClockSettings from ezmsg.event.kernel import ArrayKernel, MultiKernel from ezmsg.event.kernel_insert import SparseKernelInserterSettings, SparseKernelInserterUnit from ezmsg.sigproc.math.add import Add from ezmsg.util.messages.axisarray import AxisArray -from .lfp import DNSSLFPSettings, DNSSLFPUnit -from .spike import FS, DNSSSpikeSettings, DNSSSpikeUnit +from .lfp import DEFAULT_FS, DNSSLFPSettings, DNSSLFPUnit +from .spike import DNSSSpikeSettings, DNSSSpikeUnit from .wfs import wf_orig @@ -61,7 +61,7 @@ class DNSSSynth(ez.Collection): The final output is the sum of spike waveforms and LFP signal. Network flow: - Clock -> Counter -> {SpikeGenerator, LFPGenerator} + Clock -> {SpikeGenerator, LFPGenerator} SpikeGenerator -> KernelInserter -> Add.A LFPGenerator -> Add.B Add -> OUTPUT @@ -74,9 +74,6 @@ class DNSSSynth(ez.Collection): # Clock produces timestamps at the block rate CLOCK = Clock() - # Counter converts timestamps to AxisArray with timing metadata - COUNTER = Counter() - # Spike path: produces sparse events, then inserts waveforms SPIKE = DNSSSpikeUnit() KERNEL_INSERT = SparseKernelInserterUnit() @@ -89,19 +86,13 @@ class DNSSSynth(ez.Collection): def configure(self) -> None: # Calculate dispatch rate for blocks (DNSS is fixed at 30kHz) - dispatch_rate = FS / self.SETTINGS.n_time + dispatch_rate = DEFAULT_FS / self.SETTINGS.n_time self.CLOCK.apply_settings(ClockSettings(dispatch_rate=dispatch_rate)) - self.COUNTER.apply_settings( - CounterSettings( - n_time=self.SETTINGS.n_time, - fs=FS, - ) - ) - self.SPIKE.apply_settings( DNSSSpikeSettings( + n_time=self.SETTINGS.n_time, n_ch=self.SETTINGS.n_ch, mode=self.SETTINGS.mode, ) @@ -118,6 +109,7 @@ def configure(self) -> None: self.LFP.apply_settings( DNSSLFPSettings( + n_time=self.SETTINGS.n_time, n_ch=self.SETTINGS.n_ch, pattern=self.SETTINGS.lfp_pattern, mode=self.SETTINGS.mode, @@ -126,11 +118,9 @@ def configure(self) -> None: def network(self) -> ez.NetworkDefinition: return ( - # Clock drives Counter - (self.CLOCK.OUTPUT_SIGNAL, self.COUNTER.INPUT_CLOCK), - # Counter fans out to both Spike and LFP generators - (self.COUNTER.OUTPUT_SIGNAL, self.SPIKE.INPUT_SIGNAL), - (self.COUNTER.OUTPUT_SIGNAL, self.LFP.INPUT_SIGNAL), + # Clock drives Spike and LFP generators directly + (self.CLOCK.OUTPUT_SIGNAL, self.SPIKE.INPUT_CLOCK), + (self.CLOCK.OUTPUT_SIGNAL, self.LFP.INPUT_CLOCK), # Spike path: insert waveforms (self.SPIKE.OUTPUT_SIGNAL, self.KERNEL_INSERT.INPUT_SIGNAL), # Combine spike waveforms and LFP diff --git a/tests/integration/test_dnss_system.py b/tests/integration/test_dnss_system.py index 341ad3d..f782268 100644 --- a/tests/integration/test_dnss_system.py +++ b/tests/integration/test_dnss_system.py @@ -10,7 +10,7 @@ from ezmsg.util.messages.axisarray import AxisArray from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTotalSettings -from ezmsg.simbiophys import Clock, ClockSettings, Counter, CounterSettings +from ezmsg.simbiophys import Clock, ClockSettings from ezmsg.simbiophys.dnss import ( DEFAULT_FS, LFP_GAINS, @@ -26,34 +26,30 @@ class DNSSLFPTestSystemSettings(ez.Settings): clock_settings: ClockSettings - counter_settings: CounterSettings lfp_settings: DNSSLFPSettings log_settings: MessageLoggerSettings term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings) class DNSSLFPTestSystem(ez.Collection): - """Test system for DNSS LFP: Clock -> Counter -> LFP.""" + """Test system for DNSS LFP: Clock -> LFP.""" SETTINGS = DNSSLFPTestSystemSettings CLOCK = Clock() - COUNTER = Counter() LFP = DNSSLFPUnit() LOG = MessageLogger() TERM = TerminateOnTotal() def configure(self) -> None: self.CLOCK.apply_settings(self.SETTINGS.clock_settings) - self.COUNTER.apply_settings(self.SETTINGS.counter_settings) self.LFP.apply_settings(self.SETTINGS.lfp_settings) self.LOG.apply_settings(self.SETTINGS.log_settings) self.TERM.apply_settings(self.SETTINGS.term_settings) def network(self) -> ez.NetworkDefinition: return ( - (self.CLOCK.OUTPUT_SIGNAL, self.COUNTER.INPUT_CLOCK), - (self.COUNTER.OUTPUT_SIGNAL, self.LFP.INPUT_SIGNAL), + (self.CLOCK.OUTPUT_SIGNAL, self.LFP.INPUT_CLOCK), (self.LFP.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE), (self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE), ) @@ -72,8 +68,8 @@ def test_dnss_lfp_unit(test_name: str | None = None): settings = DNSSLFPTestSystemSettings( clock_settings=ClockSettings(dispatch_rate=dispatch_rate), - counter_settings=CounterSettings(fs=fs, n_time=n_time), lfp_settings=DNSSLFPSettings( + n_time=n_time, n_ch=n_ch, pattern="spike", mode="hdmi", @@ -126,8 +122,8 @@ def test_dnss_lfp_unit_other_pattern(test_name: str | None = None): settings = DNSSLFPTestSystemSettings( clock_settings=ClockSettings(dispatch_rate=dispatch_rate), - counter_settings=CounterSettings(fs=fs, n_time=n_time), lfp_settings=DNSSLFPSettings( + n_time=n_time, n_ch=n_ch, pattern="other", mode="hdmi", @@ -153,34 +149,30 @@ def test_dnss_lfp_unit_other_pattern(test_name: str | None = None): class DNSSSpikeTestSystemSettings(ez.Settings): clock_settings: ClockSettings - counter_settings: CounterSettings spike_settings: DNSSSpikeSettings log_settings: MessageLoggerSettings term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings) class DNSSSpikeTestSystem(ez.Collection): - """Test system for DNSS Spike: Clock -> Counter -> Spike.""" + """Test system for DNSS Spike: Clock -> Spike.""" SETTINGS = DNSSSpikeTestSystemSettings CLOCK = Clock() - COUNTER = Counter() SPIKE = DNSSSpikeUnit() LOG = MessageLogger() TERM = TerminateOnTotal() def configure(self) -> None: self.CLOCK.apply_settings(self.SETTINGS.clock_settings) - self.COUNTER.apply_settings(self.SETTINGS.counter_settings) self.SPIKE.apply_settings(self.SETTINGS.spike_settings) self.LOG.apply_settings(self.SETTINGS.log_settings) self.TERM.apply_settings(self.SETTINGS.term_settings) def network(self) -> ez.NetworkDefinition: return ( - (self.CLOCK.OUTPUT_SIGNAL, self.COUNTER.INPUT_CLOCK), - (self.COUNTER.OUTPUT_SIGNAL, self.SPIKE.INPUT_SIGNAL), + (self.CLOCK.OUTPUT_SIGNAL, self.SPIKE.INPUT_CLOCK), (self.SPIKE.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE), (self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE), ) @@ -199,8 +191,8 @@ def test_dnss_spike_unit(test_name: str | None = None): settings = DNSSSpikeTestSystemSettings( clock_settings=ClockSettings(dispatch_rate=dispatch_rate), - counter_settings=CounterSettings(fs=fs, n_time=n_time), spike_settings=DNSSSpikeSettings( + n_time=n_time, n_ch=n_ch, mode="hdmi", ), @@ -236,7 +228,6 @@ def test_dnss_spike_unit(test_name: str | None = None): def test_dnss_spike_unit_burst_period(test_name: str | None = None): """Test DNSSSpikeUnit during burst period (more spikes).""" - fs = DEFAULT_FS n_time = 3000 # 100ms blocks n_ch = 4 # Run long enough to hit burst period (starts at sample 270000 = 9 seconds) @@ -249,8 +240,8 @@ def test_dnss_spike_unit_burst_period(test_name: str | None = None): settings = DNSSSpikeTestSystemSettings( clock_settings=ClockSettings(dispatch_rate=dispatch_rate), - counter_settings=CounterSettings(fs=fs, n_time=n_time), spike_settings=DNSSSpikeSettings( + n_time=n_time, n_ch=n_ch, mode="ideal", ), diff --git a/tests/unit/test_dnss_lfp.py b/tests/unit/test_dnss_lfp.py index 5f58862..e29272b 100644 --- a/tests/unit/test_dnss_lfp.py +++ b/tests/unit/test_dnss_lfp.py @@ -228,25 +228,20 @@ def test_other_mode_frequency_segments(self): assert freqs[peak_idx] == pytest.approx(10.0, abs=1.0) -class TestDNSSLFPTransformer: - """Tests for DNSSLFPTransformer.""" - - def _create_counter_input(self, n_time: int, offset: float = 0.0, counter_start: int = 0) -> "AxisArray": - """Create a counter AxisArray input.""" - data = np.arange(counter_start, counter_start + n_time) - return AxisArray( - data=data, - dims=["time"], - axes={"time": AxisArray.TimeAxis(fs=DEFAULT_FS, offset=offset)}, - ) - - def test_transformer_sync_call(self): - """Test synchronous transformer via __call__.""" - from ezmsg.simbiophys.dnss.lfp import DNSSLFPSettings, DNSSLFPTransformer - - transformer = DNSSLFPTransformer(DNSSLFPSettings(n_ch=4)) - input_msg = self._create_counter_input(n_time=600) - result = transformer(input_msg) +class TestDNSSLFPProducer: + """Tests for DNSSLFPProducer.""" + + def _create_clock_tick(self, n_time: int, offset: float = 0.0) -> "AxisArray.LinearAxis": + """Create a clock tick (LinearAxis).""" + return AxisArray.LinearAxis(gain=1.0 / DEFAULT_FS, offset=offset) + + def test_producer_sync_call(self): + """Test synchronous producer via __call__.""" + from ezmsg.simbiophys.dnss.lfp import DNSSLFPProducer, DNSSLFPSettings + + producer = DNSSLFPProducer(DNSSLFPSettings(n_time=600, n_ch=4)) + clock_tick = self._create_clock_tick(n_time=600) + result = producer(clock_tick) assert result is not None assert result.data.shape[1] == 4 # n_ch @@ -254,43 +249,45 @@ def test_transformer_sync_call(self): assert "time" in result.dims assert "ch" in result.dims - def test_transformer_output_shape(self): - """Transformer output has correct shape (time, ch).""" - from ezmsg.simbiophys.dnss.lfp import DNSSLFPSettings, DNSSLFPTransformer + def test_producer_output_shape(self): + """Producer output has correct shape (time, ch).""" + from ezmsg.simbiophys.dnss.lfp import DNSSLFPProducer, DNSSLFPSettings n_ch = 16 - transformer = DNSSLFPTransformer(DNSSLFPSettings(n_ch=n_ch)) - input_msg = self._create_counter_input(n_time=100) + n_time = 100 + producer = DNSSLFPProducer(DNSSLFPSettings(n_time=n_time, n_ch=n_ch)) + clock_tick = self._create_clock_tick(n_time=n_time) - result = transformer(input_msg) + result = producer(clock_tick) assert result.data.ndim == 2 assert result.data.shape[1] == n_ch - def test_transformer_channels_identical(self): + def test_producer_channels_identical(self): """All channels have identical LFP values.""" - from ezmsg.simbiophys.dnss.lfp import DNSSLFPSettings, DNSSLFPTransformer + from ezmsg.simbiophys.dnss.lfp import DNSSLFPProducer, DNSSLFPSettings - transformer = DNSSLFPTransformer(DNSSLFPSettings(n_ch=8)) - input_msg = self._create_counter_input(n_time=600) - result = transformer(input_msg) + n_time = 600 + producer = DNSSLFPProducer(DNSSLFPSettings(n_time=n_time, n_ch=8)) + clock_tick = self._create_clock_tick(n_time=n_time) + result = producer(clock_tick) # All columns should be identical for ch in range(1, result.data.shape[1]): np.testing.assert_allclose(result.data[:, 0], result.data[:, ch]) - def test_transformer_continuity(self): + def test_producer_continuity(self): """Multiple calls produce continuous data.""" - from ezmsg.simbiophys.dnss.lfp import DNSSLFPSettings, DNSSLFPTransformer + from ezmsg.simbiophys.dnss.lfp import DNSSLFPProducer, DNSSLFPSettings - transformer = DNSSLFPTransformer(DNSSLFPSettings(n_ch=4)) + n_time = 600 + producer = DNSSLFPProducer(DNSSLFPSettings(n_time=n_time, n_ch=4)) # Get multiple chunks results = [] - counter = 0 for i in range(5): - input_msg = self._create_counter_input(n_time=600, offset=counter / DEFAULT_FS, counter_start=counter) - results.append(transformer(input_msg)) - counter += 600 + offset = i * n_time / DEFAULT_FS + clock_tick = self._create_clock_tick(n_time=n_time, offset=offset) + results.append(producer(clock_tick)) # Concatenate first channel combined = np.concatenate([r.data[:, 0] for r in results]) @@ -302,17 +299,18 @@ def test_transformer_continuity(self): np.testing.assert_allclose(combined, expected) - def test_transformer_different_patterns(self): - """Transformer works with different patterns.""" - from ezmsg.simbiophys.dnss.lfp import DNSSLFPSettings, DNSSLFPTransformer + def test_producer_different_patterns(self): + """Producer works with different patterns.""" + from ezmsg.simbiophys.dnss.lfp import DNSSLFPProducer, DNSSLFPSettings - spike_transformer = DNSSLFPTransformer(DNSSLFPSettings(n_ch=4, pattern="spike")) - other_transformer = DNSSLFPTransformer(DNSSLFPSettings(n_ch=4, pattern="other")) + n_time = 600 + spike_producer = DNSSLFPProducer(DNSSLFPSettings(n_time=n_time, n_ch=4, pattern="spike")) + other_producer = DNSSLFPProducer(DNSSLFPSettings(n_time=n_time, n_ch=4, pattern="other")) - input_msg = self._create_counter_input(n_time=600) + clock_tick = self._create_clock_tick(n_time=n_time) - spike_result = spike_transformer(input_msg) - other_result = other_transformer(input_msg) + spike_result = spike_producer(clock_tick) + other_result = other_producer(clock_tick) # Both should produce valid output assert spike_result.data.shape[0] > 0 diff --git a/tests/unit/test_dnss_spike.py b/tests/unit/test_dnss_spike.py index ad4bb9e..49bfdc0 100644 --- a/tests/unit/test_dnss_spike.py +++ b/tests/unit/test_dnss_spike.py @@ -508,25 +508,20 @@ def test_different_n_chans(self): assert np.sum(burst_mask) == N_BURST_SPIKES * n_chans -class TestDNSSSpikeTransformer: - """Tests for DNSSSpikeTransformer.""" - - def _create_counter_input(self, n_time: int, offset: float = 0.0, counter_start: int = 0) -> "AxisArray": - """Create a counter AxisArray input.""" - data = np.arange(counter_start, counter_start + n_time) - return AxisArray( - data=data, - dims=["time"], - axes={"time": AxisArray.TimeAxis(fs=FS, offset=offset)}, - ) - - def test_transformer_sync_call(self): - """Test synchronous transformer via __call__.""" - from ezmsg.simbiophys.dnss.spike import DNSSSpikeSettings, DNSSSpikeTransformer - - transformer = DNSSSpikeTransformer(DNSSSpikeSettings(n_ch=4)) - input_msg = self._create_counter_input(n_time=600) - result = transformer(input_msg) +class TestDNSSSpikeProducer: + """Tests for DNSSSpikeProducer.""" + + def _create_clock_tick(self, n_time: int, offset: float = 0.0) -> "AxisArray.LinearAxis": + """Create a clock tick (LinearAxis).""" + return AxisArray.LinearAxis(gain=1.0 / FS, offset=offset) + + def test_producer_sync_call(self): + """Test synchronous producer via __call__.""" + from ezmsg.simbiophys.dnss.spike import DNSSSpikeProducer, DNSSSpikeSettings + + producer = DNSSSpikeProducer(DNSSSpikeSettings(n_time=600, n_ch=4)) + clock_tick = self._create_clock_tick(n_time=600) + result = producer(clock_tick) assert result is not None assert result.data.shape[1] == 4 # n_ch @@ -534,137 +529,120 @@ def test_transformer_sync_call(self): assert "time" in result.dims assert "ch" in result.dims - def test_transformer_output_is_sparse(self): - """Transformer output data is sparse.COO.""" + def test_producer_output_is_sparse(self): + """Producer output data is sparse.COO.""" import sparse - from ezmsg.simbiophys.dnss.spike import DNSSSpikeSettings, DNSSSpikeTransformer + from ezmsg.simbiophys.dnss.spike import DNSSSpikeProducer, DNSSSpikeSettings - transformer = DNSSSpikeTransformer(DNSSSpikeSettings(n_ch=4)) - input_msg = self._create_counter_input(n_time=600) - result = transformer(input_msg) + producer = DNSSSpikeProducer(DNSSSpikeSettings(n_time=600, n_ch=4)) + clock_tick = self._create_clock_tick(n_time=600) + result = producer(clock_tick) assert isinstance(result.data, sparse.COO) assert result.data.ndim == 2 - def test_transformer_spike_values(self): + def test_producer_spike_values(self): """Spike values are waveform IDs (1, 2, or 3).""" - from ezmsg.simbiophys.dnss.spike import DNSSSpikeSettings, DNSSSpikeTransformer + from ezmsg.simbiophys.dnss.spike import DNSSSpikeProducer, DNSSSpikeSettings - transformer = DNSSSpikeTransformer(DNSSSpikeSettings(n_ch=4)) + n_time = 600 + producer = DNSSSpikeProducer(DNSSSpikeSettings(n_time=n_time, n_ch=4)) # Collect multiple chunks to ensure we get spikes all_data = [] - counter = 0 - for _ in range(10): - input_msg = self._create_counter_input(n_time=600, offset=counter / FS, counter_start=counter) - result = transformer(input_msg) + for i in range(10): + offset = i * n_time / FS + clock_tick = self._create_clock_tick(n_time=n_time, offset=offset) + result = producer(clock_tick) if result.data.nnz > 0: all_data.extend(result.data.data.tolist()) - counter += 600 # All non-zero values should be 1, 2, or 3 assert len(all_data) > 0 assert set(all_data).issubset({1, 2, 3}) - def test_transformer_continuity(self): + def test_producer_continuity(self): """Multiple calls produce continuous spike pattern.""" - from ezmsg.simbiophys.dnss.spike import DNSSSpikeSettings, DNSSSpikeTransformer + from ezmsg.simbiophys.dnss.spike import DNSSSpikeProducer, DNSSSpikeSettings - transformer = DNSSSpikeTransformer(DNSSSpikeSettings(n_ch=4, mode="ideal")) + n_time = 600 + producer = DNSSSpikeProducer(DNSSSpikeSettings(n_time=n_time, n_ch=4, mode="ideal")) # Get multiple chunks all_coords = [] all_waveforms = [] - counter = 0 - for _ in range(20): - input_msg = self._create_counter_input(n_time=600, offset=counter / FS, counter_start=counter) - result = transformer(input_msg) + total_samples = 0 + for i in range(20): + offset = i * n_time / FS + clock_tick = self._create_clock_tick(n_time=n_time, offset=offset) + result = producer(clock_tick) if result.data.nnz > 0: coords = result.data.coords - all_coords.append(coords[0] + counter) # Adjust sample indices + all_coords.append(coords[0] + i * n_time) # Adjust sample indices all_waveforms.extend(result.data.data.tolist()) - counter += 600 + total_samples += n_time # Compare with generator output gen = spike_event_generator(mode="ideal", n_chans=4) next(gen) - expected_coords, expected_waveforms = gen.send(counter) + expected_coords, expected_waveforms = gen.send(total_samples) if len(all_coords) > 0: - transformer_samples = np.concatenate(all_coords) - np.testing.assert_array_equal(np.sort(transformer_samples), np.sort(expected_coords[0])) + producer_samples = np.concatenate(all_coords) + np.testing.assert_array_equal(np.sort(producer_samples), np.sort(expected_coords[0])) - def test_transformer_different_n_chans(self): - """Transformer works with different channel counts.""" - from ezmsg.simbiophys.dnss.spike import DNSSSpikeSettings, DNSSSpikeTransformer + def test_producer_different_n_chans(self): + """Producer works with different channel counts.""" + from ezmsg.simbiophys.dnss.spike import DNSSSpikeProducer, DNSSSpikeSettings + n_time = 600 for n_ch in [4, 8, 32, 256]: - transformer = DNSSSpikeTransformer(DNSSSpikeSettings(n_ch=n_ch)) - input_msg = self._create_counter_input(n_time=600) - result = transformer(input_msg) + producer = DNSSSpikeProducer(DNSSSpikeSettings(n_time=n_time, n_ch=n_ch)) + clock_tick = self._create_clock_tick(n_time=n_time) + result = producer(clock_tick) assert result.data.shape[1] == n_ch - def test_transformer_hdmi_vs_ideal_mode(self): + def test_producer_hdmi_vs_ideal_mode(self): """HDMI and ideal modes produce different spike patterns.""" - from ezmsg.simbiophys.dnss.spike import DNSSSpikeSettings, DNSSSpikeTransformer + from ezmsg.simbiophys.dnss.spike import DNSSSpikeProducer, DNSSSpikeSettings - hdmi_transformer = DNSSSpikeTransformer(DNSSSpikeSettings(n_ch=4, mode="hdmi")) - # ideal_transformer = DNSSSpikeTransformer(DNSSSpikeSettings(n_ch=4, mode="ideal")) + n_time = 600 + hdmi_producer = DNSSSpikeProducer(DNSSSpikeSettings(n_time=n_time, n_ch=4, mode="hdmi")) # Collect enough data to see differences hdmi_coords = [] ideal_coords = [] - counter = 0 - for _ in range(50): - input_msg = self._create_counter_input(n_time=600, offset=counter / FS, counter_start=counter) - hdmi_result = hdmi_transformer(input_msg) + for i in range(50): + offset = i * n_time / FS + clock_tick = self._create_clock_tick(n_time=n_time, offset=offset) + hdmi_result = hdmi_producer(clock_tick) - # Reset transformer for ideal (need fresh input) - ideal_transformer_fresh = DNSSSpikeTransformer(DNSSSpikeSettings(n_ch=4, mode="ideal")) - ideal_input = self._create_counter_input(n_time=600, offset=counter / FS, counter_start=counter) - ideal_result = ideal_transformer_fresh(ideal_input) + # Reset producer for ideal (need fresh input) + ideal_producer_fresh = DNSSSpikeProducer(DNSSSpikeSettings(n_time=n_time, n_ch=4, mode="ideal")) + ideal_result = ideal_producer_fresh(clock_tick) if hdmi_result.data.nnz > 0: hdmi_coords.append(hdmi_result.data.coords.copy()) if ideal_result.data.nnz > 0: ideal_coords.append(ideal_result.data.coords.copy()) - counter += 600 # Both should produce spikes assert len(hdmi_coords) > 0 assert len(ideal_coords) > 0 - def test_transformer_empty_chunks_handled(self): - """Transformer handles chunks with no spikes correctly.""" - from ezmsg.simbiophys.dnss.spike import DNSSSpikeSettings, DNSSSpikeTransformer + def test_producer_empty_chunks_handled(self): + """Producer handles chunks with no spikes correctly.""" + from ezmsg.simbiophys.dnss.spike import DNSSSpikeProducer, DNSSSpikeSettings - transformer = DNSSSpikeTransformer(DNSSSpikeSettings(n_ch=4)) + n_time = 600 + producer = DNSSSpikeProducer(DNSSSpikeSettings(n_time=n_time, n_ch=4)) # Even if no spikes, should return valid sparse array - counter = 0 - for _ in range(10): - input_msg = self._create_counter_input(n_time=600, offset=counter / FS, counter_start=counter) - result = transformer(input_msg) + for i in range(10): + offset = i * n_time / FS + clock_tick = self._create_clock_tick(n_time=n_time, offset=offset) + result = producer(clock_tick) assert result.data.shape[0] >= 0 assert result.data.shape[1] == 4 - counter += 600 - - def test_transformer_rejects_wrong_sample_rate(self): - """Transformer raises error for sample rates other than 30kHz.""" - from ezmsg.simbiophys.dnss.spike import DNSSSpikeSettings, DNSSSpikeTransformer - - transformer = DNSSSpikeTransformer(DNSSSpikeSettings(n_ch=4)) - - # Create input with wrong sample rate (e.g., 1000 Hz) - wrong_fs = 1000 - data = np.arange(100) - input_msg = AxisArray( - data=data, - dims=["time"], - axes={"time": AxisArray.TimeAxis(fs=wrong_fs, offset=0.0)}, - ) - - with pytest.raises(ValueError, match="requires fs=30000"): - transformer(input_msg) From 209e5ebc329c03b12bf5318fd821c55a46eb2993 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Wed, 7 Jan 2026 23:28:23 -0500 Subject: [PATCH 05/11] Replace cosine_tuning with more general cosine_encoder --- src/ezmsg/simbiophys/__init__.py | 32 ++- src/ezmsg/simbiophys/cosine_encoder.py | 249 +++++++++++++++++ src/ezmsg/simbiophys/cosine_tuning.py | 249 ----------------- src/ezmsg/simbiophys/system/velocity2spike.py | 28 +- tests/unit/test_cosine_encoder.py | 232 ++++++++++++++++ tests/unit/test_cosine_tuning.py | 254 ------------------ 6 files changed, 522 insertions(+), 522 deletions(-) create mode 100644 src/ezmsg/simbiophys/cosine_encoder.py delete mode 100644 src/ezmsg/simbiophys/cosine_tuning.py create mode 100644 tests/unit/test_cosine_encoder.py delete mode 100644 tests/unit/test_cosine_tuning.py diff --git a/src/ezmsg/simbiophys/__init__.py b/src/ezmsg/simbiophys/__init__.py index ef64894..b6d6e4f 100644 --- a/src/ezmsg/simbiophys/__init__.py +++ b/src/ezmsg/simbiophys/__init__.py @@ -14,13 +14,12 @@ from .__version__ import __version__ as __version__ -# Cosine Tuning -from .cosine_tuning import ( - CosineTuningParams, - CosineTuningSettings, - CosineTuningState, - CosineTuningTransformer, - CosineTuningUnit, +# Cosine Encoder +from .cosine_encoder import ( + CosineEncoderSettings, + CosineEncoderState, + CosineEncoderTransformer, + CosineEncoderUnit, ) # DNSS (Digital Neural Signal Simulator) @@ -68,6 +67,10 @@ SinGeneratorSettings, SinGeneratorState, SinProducer, + SpiralGenerator, + SpiralGeneratorSettings, + SpiralGeneratorState, + SpiralProducer, ) __all__ = [ @@ -88,6 +91,10 @@ "SinGeneratorSettings", "SinGeneratorState", "SinProducer", + "SpiralGenerator", + "SpiralGeneratorSettings", + "SpiralGeneratorState", + "SpiralProducer", # Noise "PinkNoise", "PinkNoiseProducer", @@ -99,12 +106,11 @@ # EEG "EEGSynth", "EEGSynthSettings", - # Cosine Tuning - "CosineTuningParams", - "CosineTuningSettings", - "CosineTuningState", - "CosineTuningTransformer", - "CosineTuningUnit", + # Cosine Encoder + "CosineEncoderSettings", + "CosineEncoderState", + "CosineEncoderTransformer", + "CosineEncoderUnit", # Dynamic Colored Noise "ColoredNoiseFilterState", "DynamicColoredNoiseSettings", diff --git a/src/ezmsg/simbiophys/cosine_encoder.py b/src/ezmsg/simbiophys/cosine_encoder.py new file mode 100644 index 0000000..808faee --- /dev/null +++ b/src/ezmsg/simbiophys/cosine_encoder.py @@ -0,0 +1,249 @@ +"""Generic cosine-tuning encoder for polar coordinates. + +This module provides a generalized cosine-tuning encoder that maps polar +coordinates (magnitude, angle) to multiple output channels with configurable +preferred directions, baseline, and modulation parameters. + +The encoding formula is: + output = baseline + modulation * magnitude * cos(angle - preferred_direction) + + speed_modulation * magnitude + +This implements the offset model from "Decoding arm speed during reaching" +(https://ncbi.nlm.nih.gov/pmc/articles/PMC6286377/) with generic terminology +suitable for various applications: + - Neural firing rate encoding (baseline=10Hz, modulation=20Hz) + - LFP spectral parameter modulation (baseline=1.0, modulation=0.5) + - Any other cosine-tuning based encoding + +Input: + Polar coordinates (magnitude, angle) as AxisArray with shape (n_samples, 2). + Use CoordinateSpaces(mode=CART2POL) upstream to convert from Cartesian. + +Output: + AxisArray with shape (n_samples, output_ch) containing encoded values. +""" + +from pathlib import Path + +import ezmsg.core as ez +import numpy as np +import numpy.typing as npt +from ezmsg.baseproc import ( + BaseStatefulTransformer, + BaseTransformerUnit, + processor_state, +) +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.messages.util import replace + + +class CosineEncoderSettings(ez.Settings): + """Settings for CosineEncoder. + + Either `model_file` OR the random generation parameters should be specified. + If `model_file` is provided, parameters are loaded from file. + Otherwise, parameters are randomly generated. + """ + + # File-based parameters + model_file: str | None = None + """Path to .npz file with encoder parameters (baseline, modulation, pd, speed_modulation). + Also supports legacy neural tuning files with keys (b0, m, pd, bs).""" + + # Random generation parameters + output_ch: int = 10 + """Number of output channels (used if model_file is None).""" + + baseline: float = 0.0 + """Baseline output value for all channels (used if model_file is None).""" + + modulation: float = 1.0 + """Directional modulation depth for all channels (used if model_file is None).""" + + speed_modulation: float = 0.0 + """Speed modulation (non-directional) for all channels (used if model_file is None).""" + + seed: int | None = None + """Random seed for reproducibility of preferred directions (used if model_file is None).""" + + +@processor_state +class CosineEncoderState: + """State for cosine encoder transformer. + + Holds the per-channel encoding parameters. All arrays have shape (1, output_ch) + for efficient broadcasting during processing. + + Attributes: + baseline: Baseline output value for each channel. + modulation: Directional modulation depth for each channel. + pd: Preferred direction (radians) for each channel. + speed_modulation: Speed modulation (non-directional) for each channel. + ch_axis: Pre-built channel axis for output messages. + """ + + baseline: npt.NDArray[np.floating] | None = None + modulation: npt.NDArray[np.floating] | None = None + pd: npt.NDArray[np.floating] | None = None + speed_modulation: npt.NDArray[np.floating] | None = None + ch_axis: AxisArray.CoordinateAxis | None = None + + @property + def output_ch(self) -> int: + """Number of output channels.""" + return self.baseline.shape[1] if self.baseline is not None else 0 + + def validate(self) -> None: + """Validate that all parameters have consistent shapes.""" + if any(x is None for x in [self.baseline, self.modulation, self.pd, self.speed_modulation]): + raise ValueError("All parameters must be set") + if not (self.baseline.shape == self.modulation.shape == self.pd.shape == self.speed_modulation.shape): + raise ValueError("All parameters must have the same shape") + if self.baseline.ndim != 2 or self.baseline.shape[0] != 1: + raise ValueError("Parameters must have shape (1, output_ch)") + if self.baseline.shape[1] < 1: + raise ValueError("Parameters must have at least 1 channel") + + def load_from_file( + self, + filepath: str | Path, + output_ch: int | None = None, + ) -> None: + """Load parameters from a .npz file. + + The file should contain arrays with keys matching the parameter names. + For backwards compatibility with neural tuning files, the following + key mappings are supported: + - 'b0' -> baseline + - 'm' -> modulation + - 'pd' -> pd (preferred direction) + - 'bs' -> speed_modulation + + Args: + filepath: Path to .npz file containing parameter arrays. + output_ch: Number of channels to use. If None, uses all in file. + """ + params = np.load(filepath) + + # Support both new names and legacy neural tuning names + baseline = np.asarray(params.get("baseline", params.get("b0")), dtype=np.float64).ravel() + modulation = np.asarray(params.get("modulation", params.get("m")), dtype=np.float64).ravel() + pd = np.asarray(params["pd"], dtype=np.float64).ravel() + speed_modulation = np.asarray(params.get("speed_modulation", params.get("bs")), dtype=np.float64).ravel() + + if output_ch is not None: + baseline = baseline[:output_ch] + modulation = modulation[:output_ch] + pd = pd[:output_ch] + speed_modulation = speed_modulation[:output_ch] + + # Reshape to (1, output_ch) for broadcasting + self.baseline = baseline[np.newaxis, :] + self.modulation = modulation[np.newaxis, :] + self.pd = pd[np.newaxis, :] + self.speed_modulation = speed_modulation[np.newaxis, :] + + # Create channel axis for output messages + ch_labels = np.array([f"ch{i}" for i in range(len(baseline))]) + self.ch_axis = AxisArray.CoordinateAxis(data=ch_labels, dims=["ch"]) + + self.validate() + + def init_random( + self, + output_ch: int, + baseline: float = 0.0, + modulation: float = 1.0, + speed_modulation: float = 0.0, + seed: int | None = None, + ) -> None: + """Initialize encoder parameters with random preferred directions. + + Args: + output_ch: Number of output channels. + baseline: Baseline value for all channels. + modulation: Directional modulation depth for all channels. + speed_modulation: Speed modulation (non-directional) for all channels. + seed: Random seed for reproducibility. + """ + rng = np.random.default_rng(seed) + + # Shape (1, output_ch) for efficient broadcasting + self.baseline = np.full((1, output_ch), baseline, dtype=np.float64) + self.modulation = np.full((1, output_ch), modulation, dtype=np.float64) + self.pd = rng.uniform(0.0, 2.0 * np.pi, size=(1, output_ch)).astype(np.float64) + self.speed_modulation = np.full((1, output_ch), speed_modulation, dtype=np.float64) + + # Create channel axis for output messages + ch_labels = np.array([f"ch{i}" for i in range(output_ch)]) + self.ch_axis = AxisArray.CoordinateAxis(data=ch_labels, dims=["ch"]) + + self.validate() + + +class CosineEncoderTransformer( + BaseStatefulTransformer[CosineEncoderSettings, AxisArray, AxisArray, CosineEncoderState] +): + """Transform polar coordinates to multi-channel encoded output. + + Input: AxisArray with shape (n_samples, 2) containing polar coordinates + (magnitude, angle) where magnitude is speed and angle is direction. + Output: AxisArray with shape (n_samples, output_ch) containing encoded values. + + The encoding formula is: + output = baseline + modulation * magnitude * cos(angle - pd) + + speed_modulation * magnitude + + This is a generic encoder suitable for various applications including: + - Neural firing rate encoding (baseline=10Hz, modulation=20Hz) + - LFP spectral parameter modulation (baseline=1.0, modulation=0.5) + - Any other cosine-tuning based encoding + """ + + def _reset_state(self, message: AxisArray) -> None: + """Initialize encoder parameters.""" + if self.settings.model_file is not None: + self.state.load_from_file( + self.settings.model_file, + output_ch=None, # Use all channels from file + ) + else: + self.state.init_random( + output_ch=self.settings.output_ch, + baseline=self.settings.baseline, + modulation=self.settings.modulation, + speed_modulation=self.settings.speed_modulation, + seed=self.settings.seed, + ) + + def _process(self, message: AxisArray) -> AxisArray: + """Transform polar coordinates to encoded output.""" + polar = np.asarray(message.data, dtype=np.float64) + + if polar.ndim != 2 or polar.shape[1] != 2: + raise ValueError(f"Expected polar coords with shape (n_samples, 2), got {polar.shape}") + + # Extract polar components (from CART2POL: magnitude, angle) + magnitude = polar[:, 0:1] # (n_samples, 1) + angle = polar[:, 1:2] # (n_samples, 1) + + # Compute output: baseline + modulation * magnitude * cos(angle - pd) + speed_mod * magnitude + # State arrays are pre-shaped to (1, output_ch) for broadcasting + output = ( + self.state.baseline + + self.state.modulation * magnitude * np.cos(angle - self.state.pd) + + self.state.speed_modulation * magnitude + ) + + return replace( + message, + data=output, + dims=["time", "ch"], + axes={**message.axes, "ch": self.state.ch_axis}, + ) + + +class CosineEncoderUnit(BaseTransformerUnit[CosineEncoderSettings, AxisArray, AxisArray, CosineEncoderTransformer]): + """Unit wrapper for CosineEncoderTransformer.""" + + SETTINGS = CosineEncoderSettings diff --git a/src/ezmsg/simbiophys/cosine_tuning.py b/src/ezmsg/simbiophys/cosine_tuning.py deleted file mode 100644 index 8b7b02e..0000000 --- a/src/ezmsg/simbiophys/cosine_tuning.py +++ /dev/null @@ -1,249 +0,0 @@ -"""Cosine tuning model for neural encoding of velocity/movement. - -Implements the offset model from "Decoding arm speed during reaching" -(https://ncbi.nlm.nih.gov/pmc/articles/PMC6286377/): - - firing_rate = b0 + m * |v| * cos(θ - θ_pd) + bs * |v| - -Where: - - b0: baseline firing rate - - m: directional modulation depth - - θ: velocity direction (angle) - - θ_pd: preferred direction - - bs: speed modulation (non-directional) - - |v|: velocity magnitude (speed) - -For spike generation from firing rates, use EventsFromRatesTransformer -from ezmsg-event. -""" - -from dataclasses import dataclass -from pathlib import Path - -import ezmsg.core as ez -import numpy as np -import numpy.typing as npt -from ezmsg.baseproc import ( - BaseStatefulTransformer, - BaseTransformerUnit, - processor_state, -) -from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.util.messages.util import replace - - -@dataclass -class CosineTuningParams: - """Parameters for cosine tuning model. - - All arrays must have the same shape (n_units,). - - Attributes: - b0: Baseline firing rate (Hz) for each unit. - m: Directional modulation depth for each unit. - pd: Preferred direction (radians) for each unit. - bs: Speed modulation (non-directional) for each unit. - """ - - b0: npt.NDArray[np.floating] - m: npt.NDArray[np.floating] - pd: npt.NDArray[np.floating] - bs: npt.NDArray[np.floating] - - def __post_init__(self): - """Validate that all parameters have consistent shapes.""" - if not (self.b0.shape == self.m.shape == self.pd.shape == self.bs.shape): - raise ValueError("All parameters must have the same shape") - if self.b0.ndim != 1: - raise ValueError("Parameters must be 1D arrays") - if len(self.b0) < 1: - raise ValueError("Parameters must have length >= 1") - - @property - def n_units(self) -> int: - """Number of neural units.""" - return len(self.b0) - - @classmethod - def from_file( - cls, - filepath: str | Path, - n_units: int | None = None, - weight_gain: float = 1.0, - ) -> "CosineTuningParams": - """Load parameters from a .npz file. - - Args: - filepath: Path to .npz file containing b0, m, pd, bs arrays. - n_units: Number of units to use. If None, uses all units in file. - weight_gain: Scaling factor applied to m and bs parameters. - - Returns: - CosineTuningParams instance. - """ - params = np.load(filepath) - - b0 = np.asarray(params["b0"], dtype=np.float64) - m = np.asarray(params["m"], dtype=np.float64) - pd = np.asarray(params["pd"], dtype=np.float64) - bs = np.asarray(params["bs"], dtype=np.float64) - - if n_units is not None: - b0 = b0[:n_units] - m = m[:n_units] - pd = pd[:n_units] - bs = bs[:n_units] - - m = m * weight_gain - bs = bs * weight_gain - - return cls(b0=b0, m=m, pd=pd, bs=bs) - - @classmethod - def from_random( - cls, - n_units: int, - baseline_hz: float = 10.0, - modulation_hz: float = 20.0, - speed_modulation_hz: float = 0.0, - seed: int | None = None, - ) -> "CosineTuningParams": - """Generate random tuning parameters. - - Args: - n_units: Number of neural units. - baseline_hz: Baseline firing rate (Hz) for all units. - modulation_hz: Directional modulation depth for all units. - speed_modulation_hz: Speed modulation (non-directional) for all units. - seed: Random seed for reproducibility. - - Returns: - CosineTuningParams instance with random preferred directions. - """ - rng = np.random.default_rng(seed) - - return cls( - b0=np.full(n_units, baseline_hz, dtype=np.float64), - m=np.full(n_units, modulation_hz, dtype=np.float64), - pd=rng.uniform(0.0, 2.0 * np.pi, size=n_units).astype(np.float64), - bs=np.full(n_units, speed_modulation_hz, dtype=np.float64), - ) - - -class CosineTuningSettings(ez.Settings): - """Settings for CosineTuningTransformer. - - Either `model_file` OR the random generation parameters should be specified. - If `model_file` is provided, parameters are loaded from file. - Otherwise, parameters are randomly generated. - """ - - # File-based parameters - model_file: str | None = None - """Path to .npz file with tuning parameters (b0, m, pd, bs).""" - - weight_gain: float = 1.0 - """Scaling factor for m and bs when loading from file.""" - - # Random generation parameters - n_units: int = 50 - """Number of neural units (used if model_file is None).""" - - baseline_hz: float = 10.0 - """Baseline firing rate in Hz (used if model_file is None).""" - - modulation_hz: float = 20.0 - """Directional modulation depth in Hz (used if model_file is None).""" - - speed_modulation_hz: float = 0.0 - """Speed modulation (non-directional) in Hz (used if model_file is None).""" - - seed: int | None = None - """Random seed for reproducibility (used if model_file is None).""" - - # Output settings - min_rate: float = 0.0 - """Minimum firing rate (Hz). Rates are clipped to this value.""" - - -@processor_state -class CosineTuningState: - """State for cosine tuning transformer.""" - - params: CosineTuningParams | None = None - """Tuning curve parameters.""" - - -class CosineTuningTransformer(BaseStatefulTransformer[CosineTuningSettings, AxisArray, AxisArray, CosineTuningState]): - """Transform 2D velocity into firing rates using cosine tuning model. - - Input: AxisArray with shape (n_samples, 2) containing velocity (vx, vy). - Output: AxisArray with shape (n_samples, n_units) containing firing rates (Hz). - - The model implements: - rate = b0 + m * |v| * cos(θ - θ_pd) + bs * |v| - - For spike generation, chain with EventsFromRatesTransformer from ezmsg-event. - """ - - def _reset_state(self, message: AxisArray) -> None: - """Initialize tuning parameters.""" - if self.settings.model_file is not None: - self.state.params = CosineTuningParams.from_file( - self.settings.model_file, - n_units=None, # Use all units from file - weight_gain=self.settings.weight_gain, - ) - else: - self.state.params = CosineTuningParams.from_random( - n_units=self.settings.n_units, - baseline_hz=self.settings.baseline_hz, - modulation_hz=self.settings.modulation_hz, - speed_modulation_hz=self.settings.speed_modulation_hz, - seed=self.settings.seed, - ) - - def _process(self, message: AxisArray) -> AxisArray: - """Transform velocity to firing rates.""" - v = np.asarray(message.data, dtype=np.float64) - - if v.ndim != 2 or v.shape[1] != 2: - raise ValueError(f"Expected velocity with shape (n_samples, 2), got {v.shape}") - - # Extract velocity components - vx = v[:, 0] - vy = v[:, 1] - - # Calculate speed (magnitude) and direction (angle) - speed = np.hypot(vx, vy)[:, np.newaxis] # (n_samples, 1) - theta = np.arctan2(vy, vx)[:, np.newaxis] # (n_samples, 1) - - # Get parameters as row vectors for broadcasting - params = self.state.params - b0 = params.b0[np.newaxis, :] # (1, n_units) - m = params.m[np.newaxis, :] # (1, n_units) - pd = params.pd[np.newaxis, :] # (1, n_units) - bs = params.bs[np.newaxis, :] # (1, n_units) - - # Compute firing rates: b0 + m * |v| * cos(θ - θ_pd) + bs * |v| - rates = b0 + m * speed * np.cos(theta - pd) + bs * speed - - # Clip to minimum rate - rates = np.maximum(rates, self.settings.min_rate) - - # Create channel axis - ch_labels = np.array([f"unit{i}" for i in range(params.n_units)]) - ch_axis = AxisArray.CoordinateAxis(data=ch_labels, dims=["ch"]) - - return replace( - message, - data=rates, - dims=["time", "ch"], - axes={**message.axes, "ch": ch_axis}, - ) - - -class CosineTuningUnit(BaseTransformerUnit[CosineTuningSettings, AxisArray, AxisArray, CosineTuningTransformer]): - """Unit wrapper for CosineTuningTransformer.""" - - SETTINGS = CosineTuningSettings diff --git a/src/ezmsg/simbiophys/system/velocity2spike.py b/src/ezmsg/simbiophys/system/velocity2spike.py index 04f7a40..bc9188b 100644 --- a/src/ezmsg/simbiophys/system/velocity2spike.py +++ b/src/ezmsg/simbiophys/system/velocity2spike.py @@ -5,7 +5,7 @@ waveforms. Pipeline: - velocity (x,y) -> polar coords -> cosine tuning -> Poisson events -> waveforms + velocity (x,y) -> polar coords -> cosine encoder -> clip -> Poisson events -> waveforms See Also: :mod:`ezmsg.simbiophys.system.velocity2lfp`: Velocity to LFP encoding. @@ -18,9 +18,10 @@ from ezmsg.event.kernel_insert import SparseKernelInserterSettings, SparseKernelInserterUnit from ezmsg.event.poissonevents import PoissonEventSettings, PoissonEventUnit from ezmsg.sigproc.coordinatespaces import CoordinateMode, CoordinateSpaces, CoordinateSpacesSettings +from ezmsg.sigproc.math.clip import Clip, ClipSettings from ezmsg.util.messages.axisarray import AxisArray -from ..cosine_tuning import CosineTuningSettings, CosineTuningUnit +from ..cosine_encoder import CosineEncoderSettings, CosineEncoderUnit from ..dnss.wfs import wf_orig @@ -33,6 +34,16 @@ class Velocity2SpikeSettings(ez.Settings): output_ch: int = 256 """Number of output channels (simulated electrodes).""" + baseline_rate: float = 10.0 + """Baseline firing rate in Hz.""" + + modulation_depth: float = 20.0 / 314.0 + """Directional modulation depth in Hz per (pixel/second). + At max velocity (~314 px/s), this gives ~20 Hz modulation.""" + + min_rate: float = 0.0 + """Minimum firing rate (Hz). Rates are clipped to this value.""" + seed: int = 6767 """Random seed for reproducible preferred directions and waveform selection.""" @@ -65,7 +76,8 @@ class Velocity2Spike(ez.Collection): # Velocity inputs (via mouse / gamepad system, or via task parsing system) INPUT_SIGNAL = ez.InputStream(AxisArray) COORDS = CoordinateSpaces() - RATE_ENCODER = CosineTuningUnit() + RATE_ENCODER = CosineEncoderUnit() + CLIP_RATE = Clip() SPIKE_EVENT = PoissonEventUnit() WAVEFORMS = SparseKernelInserterUnit() OUTPUT_SIGNAL = ez.OutputStream(AxisArray) @@ -73,11 +85,14 @@ class Velocity2Spike(ez.Collection): def configure(self) -> None: self.COORDS.apply_settings(CoordinateSpacesSettings(mode=CoordinateMode.CART2POL, axis="ch")) self.RATE_ENCODER.apply_settings( - CosineTuningSettings( - n_units=self.SETTINGS.output_ch, + CosineEncoderSettings( + output_ch=self.SETTINGS.output_ch, + baseline=self.SETTINGS.baseline_rate, + modulation=self.SETTINGS.modulation_depth, seed=self.SETTINGS.seed, ) ) + self.CLIP_RATE.apply_settings(ClipSettings(min=self.SETTINGS.min_rate)) self.SPIKE_EVENT.apply_settings( PoissonEventSettings( output_fs=self.SETTINGS.output_fs, @@ -94,7 +109,8 @@ def network(self) -> ez.NetworkDefinition: return ( (self.INPUT_SIGNAL, self.COORDS.INPUT_SIGNAL), (self.COORDS.OUTPUT_SIGNAL, self.RATE_ENCODER.INPUT_SIGNAL), - (self.RATE_ENCODER.OUTPUT_SIGNAL, self.SPIKE_EVENT.INPUT_SIGNAL), + (self.RATE_ENCODER.OUTPUT_SIGNAL, self.CLIP_RATE.INPUT_SIGNAL), + (self.CLIP_RATE.OUTPUT_SIGNAL, self.SPIKE_EVENT.INPUT_SIGNAL), (self.SPIKE_EVENT.OUTPUT_SIGNAL, self.WAVEFORMS.INPUT_SIGNAL), (self.WAVEFORMS.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL), ) diff --git a/tests/unit/test_cosine_encoder.py b/tests/unit/test_cosine_encoder.py new file mode 100644 index 0000000..a0ff514 --- /dev/null +++ b/tests/unit/test_cosine_encoder.py @@ -0,0 +1,232 @@ +"""Unit tests for ezmsg.simbiophys.cosine_encoder module.""" + +import numpy as np +import pytest +from ezmsg.util.messages.axisarray import AxisArray + +from ezmsg.simbiophys import ( + CosineEncoderSettings, + CosineEncoderState, + CosineEncoderTransformer, +) + + +class TestCosineEncoderState: + """Tests for CosineEncoderState.""" + + def test_init_random_basic(self): + """Test random parameter initialization.""" + state = CosineEncoderState() + state.init_random(output_ch=10, seed=42) + + assert state.output_ch == 10 + assert state.baseline.shape == (1, 10) + assert state.modulation.shape == (1, 10) + assert state.pd.shape == (1, 10) + assert state.speed_modulation.shape == (1, 10) + + # Check default values + assert np.allclose(state.baseline, 0.0) + assert np.allclose(state.modulation, 1.0) + assert np.allclose(state.speed_modulation, 0.0) + + # Check preferred directions are in [0, 2*pi) + assert np.all(state.pd >= 0) + assert np.all(state.pd < 2 * np.pi) + + def test_init_random_custom_params(self): + """Test random initialization with custom parameters.""" + state = CosineEncoderState() + state.init_random( + output_ch=5, + baseline=15.0, + modulation=30.0, + speed_modulation=5.0, + seed=123, + ) + + assert state.output_ch == 5 + assert np.allclose(state.baseline, 15.0) + assert np.allclose(state.modulation, 30.0) + assert np.allclose(state.speed_modulation, 5.0) + + def test_init_random_reproducible(self): + """Test that seed produces reproducible results.""" + state1 = CosineEncoderState() + state1.init_random(output_ch=10, seed=42) + + state2 = CosineEncoderState() + state2.init_random(output_ch=10, seed=42) + + assert np.array_equal(state1.pd, state2.pd) + + def test_validation_shape_mismatch(self): + """Test that mismatched shapes raise error.""" + state = CosineEncoderState() + state.baseline = np.array([[1.0, 2.0]]) + state.modulation = np.array([[1.0, 2.0, 3.0]]) + state.pd = np.array([[1.0, 2.0]]) + state.speed_modulation = np.array([[1.0, 2.0]]) + + with pytest.raises(ValueError, match="same shape"): + state.validate() + + def test_validation_wrong_shape(self): + """Test that wrong shape raises error.""" + state = CosineEncoderState() + state.baseline = np.array([1.0, 2.0]) # 1D instead of 2D + state.modulation = np.array([1.0, 2.0]) + state.pd = np.array([1.0, 2.0]) + state.speed_modulation = np.array([1.0, 2.0]) + + with pytest.raises(ValueError, match="shape"): + state.validate() + + def test_validation_empty(self): + """Test that empty arrays raise error.""" + state = CosineEncoderState() + state.baseline = np.array([[]]).reshape(1, 0) + state.modulation = np.array([[]]).reshape(1, 0) + state.pd = np.array([[]]).reshape(1, 0) + state.speed_modulation = np.array([[]]).reshape(1, 0) + + with pytest.raises(ValueError, match="at least 1 channel"): + state.validate() + + +class TestCosineEncoderTransformer: + """Tests for CosineEncoderTransformer.""" + + def test_basic_transform(self): + """Test basic polar to encoded output transformation.""" + transformer = CosineEncoderTransformer( + CosineEncoderSettings( + output_ch=4, + baseline=10.0, + modulation=20.0, + speed_modulation=0.0, + seed=42, + ) + ) + + # Create polar input: magnitude=1, angle=0 (rightward direction) + polar = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]) + time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) + msg_in = AxisArray(polar, dims=["time", "ch"], axes={"time": time_axis}) + + msg_out = transformer(msg_in) + + assert msg_out.data.shape == (3, 4) # (n_samples, output_ch) + assert "ch" in msg_out.axes + assert msg_out.axes["ch"].data.shape == (4,) + + def test_stationary_baseline(self): + """Test that zero magnitude produces baseline values.""" + transformer = CosineEncoderTransformer( + CosineEncoderSettings( + output_ch=3, + baseline=15.0, + modulation=25.0, + speed_modulation=5.0, + seed=42, + ) + ) + + # Zero magnitude (stationary) + polar = np.array([[0.0, 0.0], [0.0, 0.0]]) + time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) + msg_in = AxisArray(polar, dims=["time", "ch"], axes={"time": time_axis}) + + msg_out = transformer(msg_in) + + # When magnitude=0, output = baseline + modulation*0*cos(...) + speed_mod*0 = baseline + assert np.allclose(msg_out.data, 15.0) + + def test_directional_tuning(self): + """Test that tuning varies with direction.""" + transformer = CosineEncoderTransformer(CosineEncoderSettings(output_ch=1, seed=42)) + + # Manually set state with known preferred direction (shape: 1, output_ch) + transformer._state.baseline = np.array([[10.0]]) + transformer._state.modulation = np.array([[20.0]]) + transformer._state.pd = np.array([[0.0]]) # Preferred direction = 0 (rightward) + transformer._state.speed_modulation = np.array([[0.0]]) + transformer._state.ch_axis = AxisArray.CoordinateAxis(data=np.array(["ch0"]), dims=["ch"]) + transformer._hash = 0 + + time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) + + # Polar: magnitude=1, angle=0 (aligned with pd) + aligned = AxisArray(np.array([[1.0, 0.0]]), dims=["time", "ch"], axes={"time": time_axis}) + output_aligned = transformer(aligned).data[0, 0] + + # Reset hash to reuse state + transformer._hash = 0 + + # Polar: magnitude=1, angle=pi (opposite to pd) + opposite = AxisArray(np.array([[1.0, np.pi]]), dims=["time", "ch"], axes={"time": time_axis}) + output_opposite = transformer(opposite).data[0, 0] + + # Aligned should have higher output (cos(0 - 0) = 1 vs cos(pi - 0) = -1) + # output_aligned = 10 + 20*1*1 = 30 + # output_opposite = 10 + 20*1*(-1) = -10 + assert output_aligned > output_opposite + assert np.isclose(output_aligned, 30.0) + assert np.isclose(output_opposite, -10.0) + + def test_speed_modulation(self): + """Test speed modulation term (speed_modulation * magnitude).""" + transformer = CosineEncoderTransformer(CosineEncoderSettings(output_ch=1, seed=42)) + + # Manually set state (shape: 1, output_ch) + transformer._state.baseline = np.array([[10.0]]) + transformer._state.modulation = np.array([[0.0]]) # No directional modulation + transformer._state.pd = np.array([[0.0]]) + transformer._state.speed_modulation = np.array([[5.0]]) + transformer._state.ch_axis = AxisArray.CoordinateAxis(data=np.array(["ch0"]), dims=["ch"]) + transformer._hash = 0 + + time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) + + # Different magnitudes + slow = AxisArray(np.array([[1.0, 0.0]]), dims=["time", "ch"], axes={"time": time_axis}) + output_slow = transformer(slow).data[0, 0] + + transformer._hash = 0 + + fast = AxisArray(np.array([[2.0, 0.0]]), dims=["time", "ch"], axes={"time": time_axis}) + output_fast = transformer(fast).data[0, 0] + + # output = 10 + 0 + 5*magnitude + assert np.isclose(output_slow, 10.0 + 5.0 * 1.0) + assert np.isclose(output_fast, 10.0 + 5.0 * 2.0) + + def test_invalid_input_shape(self): + """Test that invalid input shape raises error.""" + transformer = CosineEncoderTransformer(CosineEncoderSettings(output_ch=3, seed=42)) + + time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) + + # Wrong number of columns (should be 2) + bad_input = AxisArray(np.array([[1.0, 2.0, 3.0]]), dims=["time", "ch"], axes={"time": time_axis}) + + with pytest.raises(ValueError, match="shape"): + transformer(bad_input) + + def test_multiple_samples(self): + """Test processing multiple samples at once.""" + transformer = CosineEncoderTransformer(CosineEncoderSettings(output_ch=5, seed=42)) + + # 100 polar coordinate samples + n_samples = 100 + np.random.seed(123) + magnitude = np.abs(np.random.randn(n_samples, 1)) + angle = np.random.uniform(-np.pi, np.pi, (n_samples, 1)) + polar = np.hstack([magnitude, angle]) + + time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) + msg_in = AxisArray(polar, dims=["time", "ch"], axes={"time": time_axis}) + + msg_out = transformer(msg_in) + + assert msg_out.data.shape == (n_samples, 5) diff --git a/tests/unit/test_cosine_tuning.py b/tests/unit/test_cosine_tuning.py deleted file mode 100644 index bbbf29c..0000000 --- a/tests/unit/test_cosine_tuning.py +++ /dev/null @@ -1,254 +0,0 @@ -"""Unit tests for ezmsg.simbiophys.cosine_tuning module.""" - -import numpy as np -import pytest -from ezmsg.util.messages.axisarray import AxisArray - -from ezmsg.simbiophys import ( - CosineTuningParams, - CosineTuningSettings, - CosineTuningTransformer, -) - - -class TestCosineTuningParams: - """Tests for CosineTuningParams dataclass.""" - - def test_from_random_basic(self): - """Test random parameter generation.""" - params = CosineTuningParams.from_random(n_units=10, seed=42) - - assert params.n_units == 10 - assert params.b0.shape == (10,) - assert params.m.shape == (10,) - assert params.pd.shape == (10,) - assert params.bs.shape == (10,) - - # Check default values - assert np.allclose(params.b0, 10.0) # baseline_hz default - assert np.allclose(params.m, 20.0) # modulation_hz default - assert np.allclose(params.bs, 0.0) # speed_modulation_hz default - - # Check preferred directions are in [0, 2*pi) - assert np.all(params.pd >= 0) - assert np.all(params.pd < 2 * np.pi) - - def test_from_random_custom_params(self): - """Test random generation with custom parameters.""" - params = CosineTuningParams.from_random( - n_units=5, - baseline_hz=15.0, - modulation_hz=30.0, - speed_modulation_hz=5.0, - seed=123, - ) - - assert params.n_units == 5 - assert np.allclose(params.b0, 15.0) - assert np.allclose(params.m, 30.0) - assert np.allclose(params.bs, 5.0) - - def test_from_random_reproducible(self): - """Test that seed produces reproducible results.""" - params1 = CosineTuningParams.from_random(n_units=10, seed=42) - params2 = CosineTuningParams.from_random(n_units=10, seed=42) - - assert np.array_equal(params1.pd, params2.pd) - - def test_validation_shape_mismatch(self): - """Test that mismatched shapes raise error.""" - with pytest.raises(ValueError, match="same shape"): - CosineTuningParams( - b0=np.array([1.0, 2.0]), - m=np.array([1.0, 2.0, 3.0]), - pd=np.array([1.0, 2.0]), - bs=np.array([1.0, 2.0]), - ) - - def test_validation_not_1d(self): - """Test that non-1D arrays raise error.""" - with pytest.raises(ValueError, match="1D"): - CosineTuningParams( - b0=np.array([[1.0, 2.0]]), - m=np.array([[1.0, 2.0]]), - pd=np.array([[1.0, 2.0]]), - bs=np.array([[1.0, 2.0]]), - ) - - def test_validation_empty(self): - """Test that empty arrays raise error.""" - with pytest.raises(ValueError, match="length >= 1"): - CosineTuningParams( - b0=np.array([]), - m=np.array([]), - pd=np.array([]), - bs=np.array([]), - ) - - -class TestCosineTuningTransformer: - """Tests for CosineTuningTransformer.""" - - def test_basic_transform(self): - """Test basic velocity to firing rate transformation.""" - transformer = CosineTuningTransformer( - CosineTuningSettings( - n_units=4, - baseline_hz=10.0, - modulation_hz=20.0, - speed_modulation_hz=0.0, - seed=42, - ) - ) - - # Create velocity input: moving right at unit speed - velocity = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]) - time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) - msg_in = AxisArray(velocity, dims=["time", "ch"], axes={"time": time_axis}) - - msg_out = transformer(msg_in) - - assert msg_out.data.shape == (3, 4) # (n_samples, n_units) - assert "ch" in msg_out.axes - assert msg_out.axes["ch"].data.shape == (4,) - - # All rates should be positive (baseline + modulation term) - assert np.all(msg_out.data >= 0) - - def test_stationary_baseline(self): - """Test that stationary input produces baseline rates.""" - transformer = CosineTuningTransformer( - CosineTuningSettings( - n_units=3, - baseline_hz=15.0, - modulation_hz=25.0, - speed_modulation_hz=5.0, - seed=42, - ) - ) - - # Zero velocity - velocity = np.array([[0.0, 0.0], [0.0, 0.0]]) - time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) - msg_in = AxisArray(velocity, dims=["time", "ch"], axes={"time": time_axis}) - - msg_out = transformer(msg_in) - - # When speed=0, rate = b0 + m*0*cos(...) + bs*0 = b0 - # But arctan2(0,0) can be arbitrary, so just check rates are ~baseline - assert np.allclose(msg_out.data, 15.0) - - def test_directional_tuning(self): - """Test that tuning varies with direction.""" - # Create transformer with known preferred direction - params = CosineTuningParams( - b0=np.array([10.0]), - m=np.array([20.0]), - pd=np.array([0.0]), # Preferred direction = 0 (rightward) - bs=np.array([0.0]), - ) - - transformer = CosineTuningTransformer(CosineTuningSettings(n_units=1, seed=42)) - # Manually set params - transformer._state.params = params - transformer._hash = 0 - - time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) - - # Moving right (direction = 0, aligned with pd) - right = AxisArray(np.array([[1.0, 0.0]]), dims=["time", "ch"], axes={"time": time_axis}) - rate_right = transformer(right).data[0, 0] - - # Reset state for fresh processing - transformer._hash = 0 - transformer._state.params = params - - # Moving left (direction = pi, opposite to pd) - left = AxisArray(np.array([[-1.0, 0.0]]), dims=["time", "ch"], axes={"time": time_axis}) - rate_left = transformer(left).data[0, 0] - - # Right should have higher rate (cos(0 - 0) = 1 vs cos(pi - 0) = -1) - # rate_right = 10 + 20*1*1 + 0 = 30 - # rate_left = 10 + 20*1*(-1) + 0 = -10, but clipped to min_rate=0 - assert rate_right > rate_left - assert np.isclose(rate_right, 30.0) - - def test_speed_modulation(self): - """Test speed modulation term (bs * |v|).""" - params = CosineTuningParams( - b0=np.array([10.0]), - m=np.array([0.0]), # No directional modulation - pd=np.array([0.0]), - bs=np.array([5.0]), # Speed modulation - ) - - transformer = CosineTuningTransformer(CosineTuningSettings(n_units=1, seed=42)) - transformer._state.params = params - transformer._hash = 0 - - time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) - - # Different speeds, same direction - slow = AxisArray(np.array([[1.0, 0.0]]), dims=["time", "ch"], axes={"time": time_axis}) - rate_slow = transformer(slow).data[0, 0] - - transformer._hash = 0 - transformer._state.params = params - - fast = AxisArray(np.array([[2.0, 0.0]]), dims=["time", "ch"], axes={"time": time_axis}) - rate_fast = transformer(fast).data[0, 0] - - # rate = 10 + 0 + 5*|v| - assert np.isclose(rate_slow, 10.0 + 5.0 * 1.0) - assert np.isclose(rate_fast, 10.0 + 5.0 * 2.0) - - def test_min_rate_clipping(self): - """Test that rates are clipped to min_rate.""" - params = CosineTuningParams( - b0=np.array([5.0]), - m=np.array([20.0]), - pd=np.array([0.0]), - bs=np.array([0.0]), - ) - - transformer = CosineTuningTransformer(CosineTuningSettings(n_units=1, min_rate=0.0, seed=42)) - transformer._state.params = params - transformer._hash = 0 - - time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) - - # Moving opposite to preferred direction - # rate = 5 + 20*1*cos(pi) = 5 - 20 = -15, should be clipped to 0 - msg_in = AxisArray(np.array([[-1.0, 0.0]]), dims=["time", "ch"], axes={"time": time_axis}) - msg_out = transformer(msg_in) - - assert msg_out.data[0, 0] >= 0.0 - - def test_invalid_input_shape(self): - """Test that invalid input shape raises error.""" - transformer = CosineTuningTransformer(CosineTuningSettings(n_units=3, seed=42)) - - time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) - - # Wrong number of columns (should be 2) - bad_input = AxisArray(np.array([[1.0, 2.0, 3.0]]), dims=["time", "ch"], axes={"time": time_axis}) - - with pytest.raises(ValueError, match="shape"): - transformer(bad_input) - - def test_multiple_samples(self): - """Test processing multiple samples at once.""" - transformer = CosineTuningTransformer(CosineTuningSettings(n_units=5, seed=42)) - - # 100 velocity samples - n_samples = 100 - np.random.seed(123) - velocity = np.random.randn(n_samples, 2) - - time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) - msg_in = AxisArray(velocity, dims=["time", "ch"], axes={"time": time_axis}) - - msg_out = transformer(msg_in) - - assert msg_out.data.shape == (n_samples, 5) - assert np.all(msg_out.data >= 0) # All rates non-negative From 53dc2d2120b8264198e197150b08ec443b88f3ac Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Wed, 7 Jan 2026 23:47:15 -0500 Subject: [PATCH 06/11] Add SpiralGenerator to oscillator --- src/ezmsg/simbiophys/oscillator.py | 104 +++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/src/ezmsg/simbiophys/oscillator.py b/src/ezmsg/simbiophys/oscillator.py index 515085e..0131238 100644 --- a/src/ezmsg/simbiophys/oscillator.py +++ b/src/ezmsg/simbiophys/oscillator.py @@ -12,6 +12,110 @@ from ezmsg.util.messages.axisarray import AxisArray, LinearAxis, replace +class SpiralGeneratorSettings(ClockDrivenSettings): + """Settings for :obj:`SpiralGenerator`. + + Generates 2D position (x, y) following a spiral pattern where both + the radius and angle change over time. + + The parametric equations are: + r(t) = r_mean + r_amp * sin(2*π*radial_freq*t + radial_phase) + θ(t) = 2*π*angular_freq*t + angular_phase + x(t) = r(t) * cos(θ(t)) + y(t) = r(t) * sin(θ(t)) + """ + + r_mean: float = 150.0 + """Mean radius of the spiral.""" + + r_amp: float = 50.0 + """Amplitude of the radial oscillation.""" + + radial_freq: float = 0.1 + """Frequency of the radial oscillation in Hz.""" + + radial_phase: float = 0.0 + """Initial phase of the radial oscillation in radians.""" + + angular_freq: float = 0.25 + """Frequency of the angular rotation in Hz.""" + + angular_phase: float = 0.0 + """Initial angular phase in radians.""" + + +@processor_state +class SpiralGeneratorState(ClockDrivenState): + """State for SpiralGenerator.""" + + template: AxisArray | None = None + + +class SpiralProducer(BaseClockDrivenProducer[SpiralGeneratorSettings, SpiralGeneratorState]): + """ + Generates spiral motion synchronized to clock ticks. + + Each clock tick produces a block of 2D position data (x, y) following + a spiral pattern where both radius and angle change over time. + """ + + def _reset_state(self, time_axis: LinearAxis) -> None: + """Initialize template.""" + self._state.template = AxisArray( + data=np.zeros((0, 2)), + dims=["time", "ch"], + axes={ + "time": time_axis, + "ch": AxisArray.CoordinateAxis( + data=np.array(["x", "y"]), + dims=["ch"], + ), + }, + ) + + def _produce(self, n_samples: int, time_axis: LinearAxis) -> AxisArray: + """Generate spiral motion for this chunk.""" + t = (np.arange(n_samples) + self._state.counter) * time_axis.gain + + # Radial component: oscillates between r_mean - r_amp and r_mean + r_amp + r = self.settings.r_mean + self.settings.r_amp * np.sin( + 2.0 * np.pi * self.settings.radial_freq * t + self.settings.radial_phase + ) + + # Angular component: rotates at angular_freq + theta = 2.0 * np.pi * self.settings.angular_freq * t + self.settings.angular_phase + + # Convert to Cartesian + x = r * np.cos(theta) + y = r * np.sin(theta) + + data = np.column_stack([x, y]) + + return replace( + self._state.template, + data=data, + axes={ + **self._state.template.axes, + "time": time_axis, + }, + ) + + +class SpiralGenerator(BaseClockDrivenUnit[SpiralGeneratorSettings, SpiralProducer]): + """ + Generates 2D spiral motion synchronized to clock ticks. + + Receives timing from INPUT_CLOCK (LinearAxis from Clock) and outputs + 2D position AxisArray (x, y) on OUTPUT_SIGNAL. + + The spiral pattern has both radius and angle varying over time: + - Radius oscillates sinusoidally (breathing in/out) + - Angle increases linearly (rotation) + """ + + SETTINGS = SpiralGeneratorSettings + + class SinGeneratorSettings(ClockDrivenSettings): """Settings for :obj:`SinGenerator`.""" From fc0e3c76beecd591ab8a98e17d0b88f313791d22 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Wed, 7 Jan 2026 23:48:55 -0500 Subject: [PATCH 07/11] For velocity2lfp, replace LinearTransform encoder with CosineEncoder --- src/ezmsg/simbiophys/system/velocity2lfp.py | 86 ++++++++++++++------- 1 file changed, 57 insertions(+), 29 deletions(-) diff --git a/src/ezmsg/simbiophys/system/velocity2lfp.py b/src/ezmsg/simbiophys/system/velocity2lfp.py index 6387ac8..b2e1997 100644 --- a/src/ezmsg/simbiophys/system/velocity2lfp.py +++ b/src/ezmsg/simbiophys/system/velocity2lfp.py @@ -4,11 +4,13 @@ properties of colored (1/f^beta) noise, producing LFP-like signals. Pipeline: - velocity (x,y) -> polar coords -> scale to beta range -> colored noise -> mix to channels + velocity (x,y) -> polar coords -> cosine encoder (beta values) -> clip + -> colored noise -> mix to channels -The velocity magnitude modulates one noise source's spectral exponent, and the -velocity angle modulates another. These two sources are then mixed across -output channels using a spatial mixing matrix. +The velocity is encoded using a cosine tuning model where multiple noise +sources have different preferred directions. Each source's spectral exponent +(beta) is modulated by the velocity direction and magnitude. These sources +are then mixed across output channels using a spatial mixing matrix. See Also: :mod:`ezmsg.simbiophys.system.velocity2spike`: Velocity to spike encoding. @@ -19,10 +21,10 @@ import numpy as np from ezmsg.sigproc.affinetransform import AffineTransform, AffineTransformSettings from ezmsg.sigproc.coordinatespaces import CoordinateMode, CoordinateSpaces, CoordinateSpacesSettings -from ezmsg.sigproc.linear import LinearTransform, LinearTransformSettings from ezmsg.sigproc.math.clip import Clip, ClipSettings from ezmsg.util.messages.axisarray import AxisArray +from ..cosine_encoder import CosineEncoderSettings, CosineEncoderUnit from ..dynamic_colored_noise import DynamicColoredNoiseSettings, DynamicColoredNoiseUnit @@ -35,8 +37,15 @@ class Velocity2LFPSettings(ez.Settings): output_ch: int = 256 """Number of output channels (simulated electrodes).""" + n_lfp_sources: int = 8 + """Number of cosine-encoded LFP sources. Each source has a different + preferred direction and generates colored noise with velocity-modulated + spectral exponent.""" + + max_velocity: float = 315.0 + seed: int = 6767 - """Random seed for reproducible mixing matrix.""" + """Random seed for reproducible preferred directions and mixing matrix.""" class Velocity2LFP(ez.Collection): @@ -46,13 +55,13 @@ class Velocity2LFP(ez.Collection): 1. **Coordinate transform**: Converts Cartesian velocity (x, y) to polar coordinates (magnitude, angle). - 2. **Scale to beta**: Maps velocity magnitude (0-314 px/s) and angle (0-2pi) - to spectral exponent beta (0.5-2.0) for colored noise generation. + 2. **Cosine encoder**: Each of n_lfp_sources has a different preferred + direction. The spectral exponent beta (0-2) is modulated by the cosine + of the angle between velocity and preferred direction, scaled by speed. 3. **Clip**: Ensures beta values stay within valid range [0, 2]. 4. **Colored noise**: Generates 1/f^beta noise where beta is dynamically - modulated by the scaled velocity. Two noise sources are generated: - one modulated by magnitude, one by angle. - 5. **Spatial mixing**: Projects the 2 noise sources onto output_ch channels + modulated per source. + 5. **Spatial mixing**: Projects the n_lfp_sources onto output_ch channels using a sinusoidal mixing matrix with random perturbations. Input: @@ -69,22 +78,30 @@ class Velocity2LFP(ez.Collection): # Velocity inputs (via mouse / gamepad, or via task parsing system) INPUT_SIGNAL = ez.InputStream(AxisArray) COORDS = CoordinateSpaces() - ALPHA_EXP = LinearTransform() - CLIP_ALPHA = Clip() + BETA_ENCODER = CosineEncoderUnit() + CLIP_BETA = Clip() PINK_NOISE = DynamicColoredNoiseUnit() - MIX_NOISE = AffineTransform() # Project 2 colored noise sources to n_chans sensors + MIX_NOISE = AffineTransform() # Project n_lfp_sources to output_ch sensors OUTPUT_SIGNAL = ez.OutputStream(AxisArray) def configure(self) -> None: self.COORDS.apply_settings(CoordinateSpacesSettings(mode=CoordinateMode.CART2POL, axis="ch")) - self.ALPHA_EXP.apply_settings( - LinearTransformSettings( - scale=[1.5 / 314, 1.5 / (2 * np.pi)], - offset=[0.5, 0.5], - axis="ch", + # COORDS output is 2-ch: [magnitude, angle] + # magnitude ranges from 0 to ~315 px/s, angle from -pi to +pi + + # Configure cosine encoder to output beta values in range [0, 2] + # baseline=1.0 (middle of range), modulation=1/315 so at max velocity we get full range + self.BETA_ENCODER.apply_settings( + CosineEncoderSettings( + output_ch=self.SETTINGS.n_lfp_sources, + baseline=1.0, + modulation=1.0 / self.SETTINGS.max_velocity, + seed=self.SETTINGS.seed, ) ) - self.CLIP_ALPHA.apply_settings(ClipSettings(min=0.0, max=2.0)) + + self.CLIP_BETA.apply_settings(ClipSettings(min=0.0, max=2.0)) + self.PINK_NOISE.apply_settings( DynamicColoredNoiseSettings( output_fs=self.SETTINGS.output_fs, @@ -95,22 +112,33 @@ def configure(self) -> None: seed=self.SETTINGS.seed, ) ) + + # Create mixing matrix: n_lfp_sources -> output_ch + # Use sinusoids at different frequencies for spatial patterns rng = np.random.default_rng(self.SETTINGS.seed) ch_idx = np.arange(self.SETTINGS.output_ch) - weights = np.array( - [ - np.sin(2 * np.pi * ch_idx / self.SETTINGS.output_ch), # radius source - np.cos(2 * np.pi * ch_idx / self.SETTINGS.output_ch), # angle (phi) source - ] - ) + 0.3 * rng.standard_normal((2, self.SETTINGS.output_ch)) + n_sources = self.SETTINGS.n_lfp_sources + + # Each source gets a sinusoidal spatial pattern with different frequency + # Plus random perturbations for more realistic mixing + weights = np.zeros((n_sources, self.SETTINGS.output_ch)) + for i in range(n_sources): + # Different spatial frequency for each source + freq = (i + 1) / n_sources + phase = 2 * np.pi * i / n_sources + weights[i, :] = np.sin(2 * np.pi * freq * ch_idx / self.SETTINGS.output_ch + phase) + + # Add random perturbations + weights += 0.3 * rng.standard_normal((n_sources, self.SETTINGS.output_ch)) + self.MIX_NOISE.apply_settings(AffineTransformSettings(weights=weights, axis="ch")) def network(self) -> ez.NetworkDefinition: return ( (self.INPUT_SIGNAL, self.COORDS.INPUT_SIGNAL), - (self.COORDS.OUTPUT_SIGNAL, self.ALPHA_EXP.INPUT_SIGNAL), - (self.ALPHA_EXP.OUTPUT_SIGNAL, self.CLIP_ALPHA.INPUT_SIGNAL), - (self.CLIP_ALPHA.OUTPUT_SIGNAL, self.PINK_NOISE.INPUT_SIGNAL), + (self.COORDS.OUTPUT_SIGNAL, self.BETA_ENCODER.INPUT_SIGNAL), + (self.BETA_ENCODER.OUTPUT_SIGNAL, self.CLIP_BETA.INPUT_SIGNAL), + (self.CLIP_BETA.OUTPUT_SIGNAL, self.PINK_NOISE.INPUT_SIGNAL), (self.PINK_NOISE.OUTPUT_SIGNAL, self.MIX_NOISE.INPUT_SIGNAL), (self.MIX_NOISE.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL), ) From 18502ac7f800b512962c9aa96428cc2d51c42df4 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Wed, 7 Jan 2026 23:49:19 -0500 Subject: [PATCH 08/11] Bump ezmsg-peripheraldevice --- examples/mouse_to_lsl_full.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/mouse_to_lsl_full.py b/examples/mouse_to_lsl_full.py index 74e6e9f..f0d62be 100644 --- a/examples/mouse_to_lsl_full.py +++ b/examples/mouse_to_lsl_full.py @@ -84,7 +84,7 @@ def main( "SINK": LSLOutletUnit(LSLOutletSettings(stream_name="MouseModulatedRaw", stream_type="ECEPhys")), } conns = ( - (comps["CLOCK"].OUTPUT_SIGNAL, comps["SOURCE"].INPUT_SIGNAL), + (comps["CLOCK"].OUTPUT_SIGNAL, comps["SOURCE"].INPUT_CLOCK), (comps["SOURCE"].OUTPUT_SIGNAL, comps["DIFF"].INPUT_SIGNAL), (comps["DIFF"].OUTPUT_SIGNAL, comps["ENCODER"].INPUT_SIGNAL), (comps["ENCODER"].OUTPUT_SIGNAL, comps["SINK"].INPUT_SIGNAL), diff --git a/pyproject.toml b/pyproject.toml index cd8a27d..e0dd1d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dev = [ {include-group = "test"}, "typer>=0.20.0", "scipy-stubs>=1.15.3.0", - "ezmsg-peripheraldevice>=0.2.0", + "ezmsg-peripheraldevice>=0.3.0", "ezmsg-lsl>=1.4.2", ] lint = [ From bef4257da449998b3e15472575f894b781c1d555 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Wed, 7 Jan 2026 23:49:39 -0500 Subject: [PATCH 09/11] Use new features in spiral_to_dynamic_pink_outlet.py --- docs/source/guides/circle_to_lfp.rst | 219 ---------------- docs/source/guides/spiral_to_lfp.rst | 236 ++++++++++++++++++ ...et.py => spiral_to_dynamic_pink_outlet.py} | 50 ++-- 3 files changed, 263 insertions(+), 242 deletions(-) delete mode 100644 docs/source/guides/circle_to_lfp.rst create mode 100644 docs/source/guides/spiral_to_lfp.rst rename examples/{circle_to_dynamic_pink_outlet.py => spiral_to_dynamic_pink_outlet.py} (59%) diff --git a/docs/source/guides/circle_to_lfp.rst b/docs/source/guides/circle_to_lfp.rst deleted file mode 100644 index c35e8bd..0000000 --- a/docs/source/guides/circle_to_lfp.rst +++ /dev/null @@ -1,219 +0,0 @@ -Circular Motion to Velocity-Modulated LFP -========================================== - -This guide walks through the ``circle_to_dynamic_pink_outlet.py`` example, -which generates a simulated cursor moving in a circle and encodes its velocity -into LFP-like colored noise streamed over Lab Streaming Layer (LSL). - -Overview --------- - -The example demonstrates velocity-to-LFP encoding with a predictable, -synthetic input: - -.. code-block:: text - - Clock -> Counter -> SinGenerator -> Diff -> Velocity2LFP -> LSLOutlet - -The circular motion produces smoothly varying velocity vectors that sweep -through all directions, providing known ground truth for testing decoders. - -Prerequisites -------------- - -Install the required packages: - -.. code-block:: bash - - uv add ezmsg-simbiophys ezmsg-lsl - -Running the Example -------------------- - -Basic usage: - -.. code-block:: bash - - cd examples - uv run python circle_to_dynamic_pink_outlet.py - -With custom parameters: - -.. code-block:: bash - - uv run python circle_to_dynamic_pink_outlet.py \ - --cursor-fs 100 \ - --output-fs 30000 \ - --output-ch 256 - -Command-Line Arguments -~~~~~~~~~~~~~~~~~~~~~~ - -``--graph-addr`` - Address for the ezmsg graph server (ip:port). Set empty to disable. - Default: ``127.0.0.1:25978`` - -``--cursor-fs`` - Simulated cursor update rate in Hz. Default: ``100.0`` - -``--output-fs`` - Output sampling rate in Hz. Default: ``30000.0`` - -``--output-ch`` - Number of output channels. Default: ``256`` - -``--seed`` - Random seed for reproducibility. Default: ``6767`` - -Pipeline Components -------------------- - -Clock -~~~~~ - -Generates timing signals at the specified rate (default 100 Hz). - -Counter -~~~~~~~ - -Converts clock signals to integer sample counts, used for computing -sinusoidal positions. - -SinGenerator -~~~~~~~~~~~~ - -Generates circular motion by outputting two sinusoids with a 90-degree -phase difference: - -.. code-block:: python - - SinGenerator(SinGeneratorSettings( - n_ch=2, # x, y coordinates - freq=0.25, # 0.25 Hz = 4 second period - amp=200.0, # 200 pixel radius - phase=[np.pi/2, 0.0], # cos, sin -> counterclockwise circle - )) - -This produces a cursor position that traces a circle with: - -- Period: 4 seconds -- Radius: 200 pixels -- Direction: counterclockwise starting at (200, 0) - -Diff -~~~~ - -Differentiates position to get velocity. With ``scale_by_fs=True``, the -output is in pixels per second. - -For circular motion at radius *r* and angular frequency *ω*, the velocity -magnitude is constant: *v = r × ω = 200 × 2π/4 ≈ 314* pixels/second. - -Velocity2LFP -~~~~~~~~~~~~ - -Encodes velocity into LFP-like colored noise: - -1. **Polar conversion:** Transform (vx, vy) to (magnitude, angle) -2. **Scale to beta:** Map magnitude and angle to spectral exponent range -3. **Colored noise:** Generate 1/f^β noise with β modulated by velocity -4. **Spatial mixing:** Project 2 noise sources to output channels - -The result is multi-channel colored noise where spectral properties vary -with cursor velocity. - -LSLOutlet -~~~~~~~~~ - -Streams the output over LSL with name ``CircleModulatedPinkNoise`` and -type ``EEG``. - -Understanding the Encoding --------------------------- - -Velocity to Beta Mapping -~~~~~~~~~~~~~~~~~~~~~~~~ - -The velocity components are mapped to spectral exponents (β): - -- **Magnitude:** 0-314 px/s → β = 0.5-2.0 -- **Angle:** 0-2π → β = 0.5-2.0 - -This creates two noise sources with different spectral characteristics that -vary as the cursor moves. - -Spectral Exponent Effects -~~~~~~~~~~~~~~~~~~~~~~~~~ - -The spectral exponent β controls the noise color: - -- **β = 0:** White noise (flat spectrum) -- **β = 1:** Pink noise (1/f, equal power per octave) -- **β = 2:** Brown noise (1/f², random walk) - -As the cursor moves faster, one noise source becomes more "red" (higher β). -As the direction changes, the other source's spectral properties shift. - -Spatial Mixing -~~~~~~~~~~~~~~ - -The two noise sources are projected onto output channels using a mixing -matrix based on sinusoidal weights with random perturbations: - -.. code-block:: python - - weights = np.array([ - np.sin(2 * np.pi * ch_idx / output_ch), # Source 1 weights - np.cos(2 * np.pi * ch_idx / output_ch), # Source 2 weights - ]) + 0.3 * rng.standard_normal((2, output_ch)) - -This creates spatially-varying patterns where different channels have -different mixtures of the two velocity-modulated sources. - -Verifying the Output --------------------- - -The circular motion provides predictable ground truth: - -1. **Constant velocity magnitude:** ~314 pixels/second -2. **Linearly varying angle:** 0 to 2π over 4 seconds -3. **Periodic behavior:** Pattern repeats every 4 seconds - -You can verify the encoding by: - -1. Recording the LSL stream -2. Computing spectral features from the output -3. Checking that spectral properties correlate with the known velocity pattern - -Example Analysis -~~~~~~~~~~~~~~~~ - -.. code-block:: python - - import numpy as np - from pylsl import StreamInlet, resolve_stream - - # Capture one period (4 seconds) - streams = resolve_stream('name', 'CircleModulatedPinkNoise') - inlet = StreamInlet(streams[0]) - - samples = [] - for _ in range(int(4 * 30000)): # 4 seconds at 30 kHz - sample, _ = inlet.pull_sample() - samples.append(sample) - - data = np.array(samples) - - # Compute spectrum for each second - from scipy import signal - for i in range(4): - segment = data[i*30000:(i+1)*30000, 0] # First channel - f, psd = signal.welch(segment, fs=30000) - # Compare spectral slope across segments - -See Also --------- - -- :doc:`mouse_to_ecephys` - Real mouse input with full ecephys output -- :mod:`ezmsg.simbiophys.system.velocity2lfp` - API documentation -- :mod:`ezmsg.simbiophys.dynamic_colored_noise` - Colored noise generator diff --git a/docs/source/guides/spiral_to_lfp.rst b/docs/source/guides/spiral_to_lfp.rst new file mode 100644 index 0000000..680c56f --- /dev/null +++ b/docs/source/guides/spiral_to_lfp.rst @@ -0,0 +1,236 @@ +Spiral Motion to Velocity-Modulated LFP +======================================= + +This guide walks through the ``spiral_to_dynamic_pink_outlet.py`` example, +which generates a simulated cursor moving in a spiral pattern and encodes its +velocity into LFP-like colored noise streamed over Lab Streaming Layer (LSL). + +Overview +-------- + +The example demonstrates velocity-to-LFP encoding with a predictable, +synthetic input: + +.. code-block:: text + + Clock -> SpiralGenerator -> Diff -> Velocity2LFP -> LSLOutlet + +The spiral motion produces smoothly varying velocity vectors that sweep +through all directions while also varying in magnitude, providing known +ground truth for testing decoders. + +Prerequisites +------------- + +Install the required packages: + +.. code-block:: bash + + uv add ezmsg-simbiophys ezmsg-lsl + +Running the Example +------------------- + +Basic usage: + +.. code-block:: bash + + cd examples + uv run python spiral_to_dynamic_pink_outlet.py + +With custom parameters: + +.. code-block:: bash + + uv run python spiral_to_dynamic_pink_outlet.py \ + --cursor-fs 100 \ + --output-fs 30000 \ + --output-ch 256 + +Command-Line Arguments +~~~~~~~~~~~~~~~~~~~~~~ + +``--graph-addr`` + Address for the ezmsg graph server (ip:port). Set empty to disable. + Default: ``127.0.0.1:25978`` + +``--cursor-fs`` + Simulated cursor update rate in Hz. Default: ``100.0`` + +``--output-fs`` + Output sampling rate in Hz. Default: ``30000.0`` + +``--output-ch`` + Number of output channels. Default: ``256`` + +``--seed`` + Random seed for reproducibility. Default: ``6767`` + +Pipeline Components +------------------- + +Clock +~~~~~ + +Generates timing signals at the specified rate (default 100 Hz). + +SpiralGenerator +~~~~~~~~~~~~~~~ + +Generates spiral 2-dimensional motion where both radius and angle vary over time: + +.. code-block:: python + + SpiralGenerator(SpiralGeneratorSettings( + r_mean=150.0, # Mean radius + r_amp=50.0, # Amplitude of radial oscillation + radial_freq=0.1, # Radial oscillation frequency (Hz) + angular_freq=0.25, # Angular rotation frequency (Hz) + )) + +The parametric equations are: + +- ``r(t) = r_mean + r_amp * sin(2*pi*radial_freq*t)`` +- ``theta(t) = 2*pi*angular_freq*t`` +- ``x(t) = r(t) * cos(theta(t))`` +- ``y(t) = r(t) * sin(theta(t))`` + +This creates a "breathing" spiral where the cursor rotates while moving +in and out from the center. + + +Diff +~~~~ + +Differentiates position to get velocity. With ``scale_by_fs=True``, the +output is in pixels per second. + +Velocity2LFP +~~~~~~~~~~~~ + +Encodes velocity into LFP-like colored noise using a cosine tuning model: + +1. **Polar conversion:** Transform (vx, vy) to (magnitude, angle) +2. **Cosine encoder:** Each of ``n_lfp_sources`` (default 8) has a random + preferred direction. The spectral exponent beta is computed as: + ``beta = baseline + modulation * magnitude * cos(angle - pd)`` +3. **Clip:** Ensures beta values stay within valid range [0, 2] +4. **Colored noise:** Generate 1/f^β noise with β dynamically modulated per source +5. **Spatial mixing:** Project n_lfp_sources onto output_ch channels using + sinusoidal mixing patterns + +The result is multi-channel colored noise where spectral properties vary +with cursor velocity direction and magnitude. + +LSLOutlet +~~~~~~~~~ + +Streams the output over LSL with name ``SpiralModulatedPinkNoise`` and +type ``EEG``. + +Understanding the Encoding +-------------------------- + +Cosine Tuning Model +~~~~~~~~~~~~~~~~~~~ + +Each of the ``n_lfp_sources`` (default 8) has a randomly assigned preferred +direction. The spectral exponent beta for each source is computed using +a cosine tuning model: + +.. code-block:: python + + beta = baseline + modulation * magnitude * cos(angle - pd) + +With default settings: + +- ``baseline = 1.0`` (pink noise at rest) +- ``modulation = 1.0 / max_velocity`` (scales with velocity) +- ``max_velocity = 315.0`` (pixels/second) + +When moving at maximum velocity in a source's preferred direction, +beta reaches 2.0 (brown noise). When moving opposite to the preferred +direction, beta reaches 0.0 (white noise). The output is clipped to [0, 2]. + +Spectral Exponent Effects +~~~~~~~~~~~~~~~~~~~~~~~~~ + +The spectral exponent β controls the noise color: + +- **β = 0:** White noise (flat spectrum) +- **β = 1:** Pink noise (1/f, equal power per octave) +- **β = 2:** Brown noise (1/f², random walk) + +Each source responds differently to velocity direction based on its +preferred direction, creating a rich mixture of spectral characteristics. + +Spatial Mixing +~~~~~~~~~~~~~~ + +The ``n_lfp_sources`` noise sources are projected onto ``output_ch`` channels +using a mixing matrix based on sinusoidal weights at different spatial +frequencies, plus random perturbations: + +.. code-block:: python + + weights = np.zeros((n_sources, output_ch)) + for i in range(n_sources): + freq = (i + 1) / n_sources + phase = 2 * np.pi * i / n_sources + weights[i, :] = np.sin(2 * np.pi * freq * ch_idx / output_ch + phase) + weights += 0.3 * rng.standard_normal((n_sources, output_ch)) + +This creates spatially-varying patterns where different channels have +different mixtures of the velocity-modulated sources, mimicking the +spatial spread of LFP signals across electrode arrays. + +Verifying the Output +-------------------- + +The spiral motion provides predictable ground truth: + +1. **Varying velocity magnitude:** Oscillates due to radial breathing +2. **Linearly varying angle:** Rotates at ``angular_freq`` Hz +3. **Periodic behavior:** Angular period = 1/angular_freq seconds (4s default), + radial period = 1/radial_freq seconds (10s default) + +You can verify the encoding by: + +1. Recording the LSL stream +2. Computing spectral features from the output +3. Checking that spectral properties correlate with the known velocity pattern + +Example Analysis +~~~~~~~~~~~~~~~~ + +.. code-block:: python + + import numpy as np + from pylsl import StreamInlet, resolve_stream + + # Capture one angular period (4 seconds with default settings) + streams = resolve_stream('name', 'SpiralModulatedPinkNoise') + inlet = StreamInlet(streams[0]) + + samples = [] + for _ in range(int(4 * 30000)): # 4 seconds at 30 kHz + sample, _ = inlet.pull_sample() + samples.append(sample) + + data = np.array(samples) + + # Compute spectrum for each second + from scipy import signal + for i in range(4): + segment = data[i*30000:(i+1)*30000, 0] # First channel + f, psd = signal.welch(segment, fs=30000) + # Compare spectral slope across segments + +See Also +-------- + +- :doc:`mouse_to_ecephys` - Real mouse input with full ecephys output +- :mod:`ezmsg.simbiophys.system.velocity2lfp` - API documentation +- :mod:`ezmsg.simbiophys.cosine_encoder` - Cosine tuning encoder +- :mod:`ezmsg.simbiophys.dynamic_colored_noise` - Colored noise generator +- :mod:`ezmsg.simbiophys.oscillator` - SpiralGenerator and SinGenerator diff --git a/examples/circle_to_dynamic_pink_outlet.py b/examples/spiral_to_dynamic_pink_outlet.py similarity index 59% rename from examples/circle_to_dynamic_pink_outlet.py rename to examples/spiral_to_dynamic_pink_outlet.py index 8e0ec22..0b4a81a 100644 --- a/examples/circle_to_dynamic_pink_outlet.py +++ b/examples/spiral_to_dynamic_pink_outlet.py @@ -1,18 +1,22 @@ -"""Circular motion to velocity-modulated LFP, streamed over LSL. +"""Spiral motion to velocity-modulated LFP, streamed over LSL. -This example generates a simulated cursor moving in a circle, computes its -velocity, encodes the velocity into LFP-like colored noise, and streams the +This example generates a simulated cursor moving in a spiral pattern, computes +its velocity, encodes the velocity into LFP-like colored noise, and streams the result over Lab Streaming Layer (LSL). Pipeline:: - Clock -> Counter -> SinGenerator (circle) -> Diff (velocity) - -> Velocity2LFP -> LSLOutlet + Clock -> SpiralGenerator -> Diff (velocity) -> Velocity2LFP -> LSLOutlet -The circular motion produces smoothly varying velocity vectors that sweep -through all directions. The Velocity2LFP system converts this into multi-channel -colored noise where the spectral properties are modulated by velocity magnitude -and direction. +The spiral motion produces varying velocity vectors where both the magnitude +(speed) and direction change over time. The SpiralGenerator creates a pattern +where: + - The radius oscillates sinusoidally (breathing in/out) + - The angle increases linearly (rotation) + +This provides richer dynamics than circular motion for testing the velocity +encoding system, as velocity is non-zero and varying even when the cursor +"pauses" at the turning points of the radial oscillation. This is useful for: - Testing LFP processing pipelines with known ground truth @@ -37,13 +41,12 @@ """ import ezmsg.core as ez -import numpy as np import typer -from ezmsg.baseproc import Clock, ClockSettings, Counter, CounterSettings +from ezmsg.baseproc import Clock, ClockSettings from ezmsg.lsl.outlet import LSLOutletSettings, LSLOutletUnit from ezmsg.sigproc.diff import DiffSettings, DiffUnit -from ezmsg.simbiophys.oscillator import SinGenerator, SinGeneratorSettings +from ezmsg.simbiophys.oscillator import SpiralGenerator, SpiralGeneratorSettings from ezmsg.simbiophys.system.velocity2lfp import Velocity2LFP, Velocity2LFPSettings GRAPH_IP = "127.0.0.1" @@ -66,34 +69,35 @@ def main( comps = { "CLOCK": Clock(ClockSettings(dispatch_rate=cursor_fs)), - "COUNTER": Counter(CounterSettings(fs=cursor_fs)), - "OSCILLATOR": SinGenerator( - SinGeneratorSettings( - n_ch=2, # x,y - freq=0.25, # 1/4 Hz = 4 second period - amp=200.0, # radius 200 pixels - phase=[np.pi / 2, 0.0], # [x, y]: cos = sin + π/2, counterclockwise from (200, 0) + "SPIRAL": SpiralGenerator( + SpiralGeneratorSettings( + fs=cursor_fs, + r_mean=150.0, # Mean radius 150 pixels + r_amp=150.0, # Radius oscillates +/- 50 pixels (100-200 range) + radial_freq=0.1, # Radial breathing at 0.1 Hz (10 second period) + angular_freq=0.25, # Rotation at 0.25 Hz (4 second period) ) ), "DIFF": DiffUnit(DiffSettings(axis="time", scale_by_fs=True)), + # DIFF Output is [[dx, dy]] pixels/sec with varying magnitude "VEL2LFP": Velocity2LFP( Velocity2LFPSettings( output_fs=output_fs, output_ch=output_ch, + max_velocity=472.0, seed=seed, ) ), "SINK": LSLOutletUnit( LSLOutletSettings( - stream_name="CircleModulatedPinkNoise", + stream_name="SpiralModulatedPinkNoise", stream_type="EEG", ) ), } conns = ( - (comps["CLOCK"].OUTPUT_SIGNAL, comps["COUNTER"].INPUT_SIGNAL), - (comps["COUNTER"].OUTPUT_SIGNAL, comps["OSCILLATOR"].INPUT_SIGNAL), - (comps["OSCILLATOR"].OUTPUT_SIGNAL, comps["DIFF"].INPUT_SIGNAL), + (comps["CLOCK"].OUTPUT_SIGNAL, comps["SPIRAL"].INPUT_CLOCK), + (comps["SPIRAL"].OUTPUT_SIGNAL, comps["DIFF"].INPUT_SIGNAL), (comps["DIFF"].OUTPUT_SIGNAL, comps["VEL2LFP"].INPUT_SIGNAL), (comps["VEL2LFP"].OUTPUT_SIGNAL, comps["SINK"].INPUT_SIGNAL), ) From 90595304bafb005315ec9ae5b903c0222caab8af Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Thu, 8 Jan 2026 00:33:11 -0500 Subject: [PATCH 10/11] Pull coordinate transformation out of velocity2lfp and velocity2spike and do it in upstream script. --- docs/source/guides/spiral_to_lfp.rst | 20 +++++--- examples/spiral_to_dynamic_pink_outlet.py | 10 ++-- .../simbiophys/system/velocity2ecephys.py | 17 +++++-- src/ezmsg/simbiophys/system/velocity2lfp.py | 46 +++++++++---------- src/ezmsg/simbiophys/system/velocity2spike.py | 40 ++++++++-------- 5 files changed, 75 insertions(+), 58 deletions(-) diff --git a/docs/source/guides/spiral_to_lfp.rst b/docs/source/guides/spiral_to_lfp.rst index 680c56f..f6343b1 100644 --- a/docs/source/guides/spiral_to_lfp.rst +++ b/docs/source/guides/spiral_to_lfp.rst @@ -13,7 +13,7 @@ synthetic input: .. code-block:: text - Clock -> SpiralGenerator -> Diff -> Velocity2LFP -> LSLOutlet + Clock -> SpiralGenerator -> Diff -> CART2POL -> Velocity2LFP -> LSLOutlet The spiral motion produces smoothly varying velocity vectors that sweep through all directions while also varying in magnitude, providing known @@ -105,18 +105,24 @@ Diff Differentiates position to get velocity. With ``scale_by_fs=True``, the output is in pixels per second. +CART2POL (CoordinateSpaces) +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Converts Cartesian velocity (vx, vy) to polar coordinates (magnitude, angle). +This transformation is done once upstream and shared if both spike and LFP +encoding are used (via VelocityEncoder). + Velocity2LFP ~~~~~~~~~~~~ -Encodes velocity into LFP-like colored noise using a cosine tuning model: +Encodes polar velocity into LFP-like colored noise using a cosine tuning model: -1. **Polar conversion:** Transform (vx, vy) to (magnitude, angle) -2. **Cosine encoder:** Each of ``n_lfp_sources`` (default 8) has a random +1. **Cosine encoder:** Each of ``n_lfp_sources`` (default 8) has a random preferred direction. The spectral exponent beta is computed as: ``beta = baseline + modulation * magnitude * cos(angle - pd)`` -3. **Clip:** Ensures beta values stay within valid range [0, 2] -4. **Colored noise:** Generate 1/f^β noise with β dynamically modulated per source -5. **Spatial mixing:** Project n_lfp_sources onto output_ch channels using +2. **Clip:** Ensures beta values stay within valid range [0, 2] +3. **Colored noise:** Generate 1/f^β noise with β dynamically modulated per source +4. **Spatial mixing:** Project n_lfp_sources onto output_ch channels using sinusoidal mixing patterns The result is multi-channel colored noise where spectral properties vary diff --git a/examples/spiral_to_dynamic_pink_outlet.py b/examples/spiral_to_dynamic_pink_outlet.py index 0b4a81a..7288f32 100644 --- a/examples/spiral_to_dynamic_pink_outlet.py +++ b/examples/spiral_to_dynamic_pink_outlet.py @@ -6,7 +6,7 @@ Pipeline:: - Clock -> SpiralGenerator -> Diff (velocity) -> Velocity2LFP -> LSLOutlet + Clock -> SpiralGenerator -> Diff (velocity) -> CART2POL -> Velocity2LFP -> LSLOutlet The spiral motion produces varying velocity vectors where both the magnitude (speed) and direction change over time. The SpiralGenerator creates a pattern @@ -44,6 +44,7 @@ import typer from ezmsg.baseproc import Clock, ClockSettings from ezmsg.lsl.outlet import LSLOutletSettings, LSLOutletUnit +from ezmsg.sigproc.coordinatespaces import CoordinateMode, CoordinateSpaces, CoordinateSpacesSettings from ezmsg.sigproc.diff import DiffSettings, DiffUnit from ezmsg.simbiophys.oscillator import SpiralGenerator, SpiralGeneratorSettings @@ -79,7 +80,9 @@ def main( ) ), "DIFF": DiffUnit(DiffSettings(axis="time", scale_by_fs=True)), - # DIFF Output is [[dx, dy]] pixels/sec with varying magnitude + # DIFF Output is [[vx, vy]] pixels/sec with varying magnitude + "COORDS": CoordinateSpaces(CoordinateSpacesSettings(mode=CoordinateMode.CART2POL, axis="ch")), + # COORDS Output is [[magnitude, angle]] polar velocity "VEL2LFP": Velocity2LFP( Velocity2LFPSettings( output_fs=output_fs, @@ -98,7 +101,8 @@ def main( conns = ( (comps["CLOCK"].OUTPUT_SIGNAL, comps["SPIRAL"].INPUT_CLOCK), (comps["SPIRAL"].OUTPUT_SIGNAL, comps["DIFF"].INPUT_SIGNAL), - (comps["DIFF"].OUTPUT_SIGNAL, comps["VEL2LFP"].INPUT_SIGNAL), + (comps["DIFF"].OUTPUT_SIGNAL, comps["COORDS"].INPUT_SIGNAL), + (comps["COORDS"].OUTPUT_SIGNAL, comps["VEL2LFP"].INPUT_SIGNAL), (comps["VEL2LFP"].OUTPUT_SIGNAL, comps["SINK"].INPUT_SIGNAL), ) diff --git a/src/ezmsg/simbiophys/system/velocity2ecephys.py b/src/ezmsg/simbiophys/system/velocity2ecephys.py index a25a8c1..c595265 100644 --- a/src/ezmsg/simbiophys/system/velocity2ecephys.py +++ b/src/ezmsg/simbiophys/system/velocity2ecephys.py @@ -5,9 +5,12 @@ background activity. Pipeline: - velocity (x,y) --+--> Velocity2Spike --> spikes --| - | +--> Add --> ecephys - +--> Velocity2LFP ----> lfp -----| + velocity (x,y) -> CART2POL --+--> Velocity2Spike --> spikes --| + | +--> Add --> ecephys + +--> Velocity2LFP ----> lfp -----| + +The coordinate transformation from Cartesian to polar is done once at the +input, then shared by both spike and LFP encoding branches. This is the top-level system for velocity-encoded neural simulation. Use this when you need full ecephys-like output suitable for testing BCI decoders. @@ -18,6 +21,7 @@ """ import ezmsg.core as ez +from ezmsg.sigproc.coordinatespaces import CoordinateMode, CoordinateSpaces, CoordinateSpacesSettings from ezmsg.sigproc.math.add import Add from ezmsg.util.messages.axisarray import AxisArray @@ -71,12 +75,14 @@ class VelocityEncoder(ez.Collection): # Velocity inputs (via mouse / gamepad system, or via task parsing system) INPUT_SIGNAL = ez.InputStream(AxisArray) + COORDS = CoordinateSpaces() # Cartesian to polar (done once, shared by both branches) SPIKES = Velocity2Spike() LFP = Velocity2LFP() ADD = Add() # Add colored noise and waveforms OUTPUT_SIGNAL = ez.OutputStream(AxisArray) def configure(self) -> None: + self.COORDS.apply_settings(CoordinateSpacesSettings(mode=CoordinateMode.CART2POL, axis="ch")) self.SPIKES.apply_settings( Velocity2SpikeSettings( output_fs=self.SETTINGS.output_fs, output_ch=self.SETTINGS.output_ch, seed=self.SETTINGS.seed @@ -90,9 +96,10 @@ def configure(self) -> None: def network(self) -> ez.NetworkDefinition: return ( - (self.INPUT_SIGNAL, self.SPIKES.INPUT_SIGNAL), + (self.INPUT_SIGNAL, self.COORDS.INPUT_SIGNAL), + (self.COORDS.OUTPUT_SIGNAL, self.SPIKES.INPUT_SIGNAL), (self.SPIKES.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_A), - (self.INPUT_SIGNAL, self.LFP.INPUT_SIGNAL), + (self.COORDS.OUTPUT_SIGNAL, self.LFP.INPUT_SIGNAL), (self.LFP.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_B), (self.ADD.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL), ) diff --git a/src/ezmsg/simbiophys/system/velocity2lfp.py b/src/ezmsg/simbiophys/system/velocity2lfp.py index b2e1997..ca6120b 100644 --- a/src/ezmsg/simbiophys/system/velocity2lfp.py +++ b/src/ezmsg/simbiophys/system/velocity2lfp.py @@ -1,17 +1,22 @@ -"""Convert 2D cursor velocity to simulated LFP-like colored noise. +"""Convert polar velocity coordinates to simulated LFP-like colored noise. -This module provides a system that encodes cursor velocity into the spectral -properties of colored (1/f^beta) noise, producing LFP-like signals. +This module provides a system that encodes velocity (in polar coordinates) into +the spectral properties of colored (1/f^beta) noise, producing LFP-like signals. Pipeline: - velocity (x,y) -> polar coords -> cosine encoder (beta values) -> clip - -> colored noise -> mix to channels + polar coords (magnitude, angle) -> cosine encoder (beta values) -> clip + -> colored noise -> mix to channels The velocity is encoded using a cosine tuning model where multiple noise sources have different preferred directions. Each source's spectral exponent (beta) is modulated by the velocity direction and magnitude. These sources are then mixed across output channels using a spatial mixing matrix. +Note: + This system expects polar coordinates as input. Use CoordinateSpaces with + mode=CART2POL upstream to convert Cartesian velocity (vx, vy) to polar + coordinates (magnitude, angle). + See Also: :mod:`ezmsg.simbiophys.system.velocity2spike`: Velocity to spike encoding. :mod:`ezmsg.simbiophys.system.velocity2ecephys`: Combined spike + LFP encoding. @@ -20,7 +25,6 @@ import ezmsg.core as ez import numpy as np from ezmsg.sigproc.affinetransform import AffineTransform, AffineTransformSettings -from ezmsg.sigproc.coordinatespaces import CoordinateMode, CoordinateSpaces, CoordinateSpacesSettings from ezmsg.sigproc.math.clip import Clip, ClipSettings from ezmsg.util.messages.axisarray import AxisArray @@ -49,24 +53,23 @@ class Velocity2LFPSettings(ez.Settings): class Velocity2LFP(ez.Collection): - """Encode cursor velocity into LFP-like colored noise. + """Encode velocity (polar coordinates) into LFP-like colored noise. - This system converts 2D cursor velocity into multi-channel LFP-like signals: + This system converts polar velocity coordinates into multi-channel LFP-like signals: - 1. **Coordinate transform**: Converts Cartesian velocity (x, y) to polar - coordinates (magnitude, angle). - 2. **Cosine encoder**: Each of n_lfp_sources has a different preferred + 1. **Cosine encoder**: Each of n_lfp_sources has a different preferred direction. The spectral exponent beta (0-2) is modulated by the cosine of the angle between velocity and preferred direction, scaled by speed. - 3. **Clip**: Ensures beta values stay within valid range [0, 2]. - 4. **Colored noise**: Generates 1/f^beta noise where beta is dynamically + 2. **Clip**: Ensures beta values stay within valid range [0, 2]. + 3. **Colored noise**: Generates 1/f^beta noise where beta is dynamically modulated per source. - 5. **Spatial mixing**: Projects the n_lfp_sources onto output_ch channels + 4. **Spatial mixing**: Projects the n_lfp_sources onto output_ch channels using a sinusoidal mixing matrix with random perturbations. Input: - AxisArray with shape (N, 2) containing cursor velocity in pixels/second. - Dimension 0 is time, dimension 1 is [vx, vy]. + AxisArray with shape (N, 2) containing polar velocity coordinates. + Dimension 0 is time, dimension 1 is [magnitude, angle]. + Use CoordinateSpaces(mode=CART2POL) upstream if starting from (vx, vy). Output: AxisArray with shape (M, output_ch) containing LFP-like colored noise @@ -75,9 +78,8 @@ class Velocity2LFP(ez.Collection): SETTINGS = Velocity2LFPSettings - # Velocity inputs (via mouse / gamepad, or via task parsing system) + # Polar velocity inputs (magnitude, angle) INPUT_SIGNAL = ez.InputStream(AxisArray) - COORDS = CoordinateSpaces() BETA_ENCODER = CosineEncoderUnit() CLIP_BETA = Clip() PINK_NOISE = DynamicColoredNoiseUnit() @@ -85,9 +87,8 @@ class Velocity2LFP(ez.Collection): OUTPUT_SIGNAL = ez.OutputStream(AxisArray) def configure(self) -> None: - self.COORDS.apply_settings(CoordinateSpacesSettings(mode=CoordinateMode.CART2POL, axis="ch")) - # COORDS output is 2-ch: [magnitude, angle] - # magnitude ranges from 0 to ~315 px/s, angle from -pi to +pi + # Input is polar coords: [magnitude, angle] + # magnitude ranges from 0 to ~max_velocity px/s, angle from -pi to +pi # Configure cosine encoder to output beta values in range [0, 2] # baseline=1.0 (middle of range), modulation=1/315 so at max velocity we get full range @@ -135,8 +136,7 @@ def configure(self) -> None: def network(self) -> ez.NetworkDefinition: return ( - (self.INPUT_SIGNAL, self.COORDS.INPUT_SIGNAL), - (self.COORDS.OUTPUT_SIGNAL, self.BETA_ENCODER.INPUT_SIGNAL), + (self.INPUT_SIGNAL, self.BETA_ENCODER.INPUT_SIGNAL), (self.BETA_ENCODER.OUTPUT_SIGNAL, self.CLIP_BETA.INPUT_SIGNAL), (self.CLIP_BETA.OUTPUT_SIGNAL, self.PINK_NOISE.INPUT_SIGNAL), (self.PINK_NOISE.OUTPUT_SIGNAL, self.MIX_NOISE.INPUT_SIGNAL), diff --git a/src/ezmsg/simbiophys/system/velocity2spike.py b/src/ezmsg/simbiophys/system/velocity2spike.py index bc9188b..b5a434e 100644 --- a/src/ezmsg/simbiophys/system/velocity2spike.py +++ b/src/ezmsg/simbiophys/system/velocity2spike.py @@ -1,11 +1,16 @@ -"""Convert 2D cursor velocity to simulated spike waveforms. +"""Convert polar velocity coordinates to simulated spike waveforms. -This module provides a system that encodes cursor velocity into spike activity -using a cosine tuning model, then generates spike events and inserts realistic -waveforms. +This module provides a system that encodes velocity (in polar coordinates) into +spike activity using a cosine tuning model, then generates spike events and +inserts realistic waveforms. Pipeline: - velocity (x,y) -> polar coords -> cosine encoder -> clip -> Poisson events -> waveforms + polar coords (magnitude, angle) -> cosine encoder -> clip -> Poisson events -> waveforms + +Note: + This system expects polar coordinates as input. Use CoordinateSpaces with + mode=CART2POL upstream to convert Cartesian velocity (vx, vy) to polar + coordinates (magnitude, angle). See Also: :mod:`ezmsg.simbiophys.system.velocity2lfp`: Velocity to LFP encoding. @@ -17,7 +22,6 @@ from ezmsg.event.kernel import ArrayKernel, MultiKernel from ezmsg.event.kernel_insert import SparseKernelInserterSettings, SparseKernelInserterUnit from ezmsg.event.poissonevents import PoissonEventSettings, PoissonEventUnit -from ezmsg.sigproc.coordinatespaces import CoordinateMode, CoordinateSpaces, CoordinateSpacesSettings from ezmsg.sigproc.math.clip import Clip, ClipSettings from ezmsg.util.messages.axisarray import AxisArray @@ -49,22 +53,21 @@ class Velocity2SpikeSettings(ez.Settings): class Velocity2Spike(ez.Collection): - """Encode cursor velocity into simulated spike waveforms. + """Encode velocity (polar coordinates) into simulated spike waveforms. - This system converts 2D cursor velocity into multi-channel spike activity: + This system converts polar velocity coordinates into multi-channel spike activity: - 1. **Coordinate transform**: Converts Cartesian velocity (x, y) to polar - coordinates (magnitude, angle). - 2. **Cosine tuning**: Each channel has a preferred direction; firing rate + 1. **Cosine tuning**: Each channel has a preferred direction; firing rate is modulated by the cosine of the angle between velocity and preferred direction, scaled by velocity magnitude. - 3. **Poisson spike generation**: Converts firing rates to discrete spike + 2. **Poisson spike generation**: Converts firing rates to discrete spike events using an inhomogeneous Poisson process. - 4. **Waveform insertion**: Inserts realistic spike waveforms at event times. + 3. **Waveform insertion**: Inserts realistic spike waveforms at event times. Input: - AxisArray with shape (N, 2) containing cursor velocity in pixels/second. - Dimension 0 is time, dimension 1 is [vx, vy]. + AxisArray with shape (N, 2) containing polar velocity coordinates. + Dimension 0 is time, dimension 1 is [magnitude, angle]. + Use CoordinateSpaces(mode=CART2POL) upstream if starting from (vx, vy). Output: AxisArray with shape (M, output_ch) containing spike waveforms at @@ -73,9 +76,8 @@ class Velocity2Spike(ez.Collection): SETTINGS = Velocity2SpikeSettings - # Velocity inputs (via mouse / gamepad system, or via task parsing system) + # Polar velocity inputs (magnitude, angle) INPUT_SIGNAL = ez.InputStream(AxisArray) - COORDS = CoordinateSpaces() RATE_ENCODER = CosineEncoderUnit() CLIP_RATE = Clip() SPIKE_EVENT = PoissonEventUnit() @@ -83,7 +85,6 @@ class Velocity2Spike(ez.Collection): OUTPUT_SIGNAL = ez.OutputStream(AxisArray) def configure(self) -> None: - self.COORDS.apply_settings(CoordinateSpacesSettings(mode=CoordinateMode.CART2POL, axis="ch")) self.RATE_ENCODER.apply_settings( CosineEncoderSettings( output_ch=self.SETTINGS.output_ch, @@ -107,8 +108,7 @@ def configure(self) -> None: def network(self) -> ez.NetworkDefinition: return ( - (self.INPUT_SIGNAL, self.COORDS.INPUT_SIGNAL), - (self.COORDS.OUTPUT_SIGNAL, self.RATE_ENCODER.INPUT_SIGNAL), + (self.INPUT_SIGNAL, self.RATE_ENCODER.INPUT_SIGNAL), (self.RATE_ENCODER.OUTPUT_SIGNAL, self.CLIP_RATE.INPUT_SIGNAL), (self.CLIP_RATE.OUTPUT_SIGNAL, self.SPIKE_EVENT.INPUT_SIGNAL), (self.SPIKE_EVENT.OUTPUT_SIGNAL, self.WAVEFORMS.INPUT_SIGNAL), From 26dcbad05f1a6419455b82fb50909afd2c3fdcd9 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Thu, 8 Jan 2026 00:47:56 -0500 Subject: [PATCH 11/11] Change default scaling factor in velocity2lfp.py to 20.0 --- src/ezmsg/simbiophys/system/velocity2lfp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ezmsg/simbiophys/system/velocity2lfp.py b/src/ezmsg/simbiophys/system/velocity2lfp.py index ca6120b..58196cd 100644 --- a/src/ezmsg/simbiophys/system/velocity2lfp.py +++ b/src/ezmsg/simbiophys/system/velocity2lfp.py @@ -109,7 +109,7 @@ def configure(self) -> None: n_poles=5, smoothing_tau=0.01, initial_beta=1.0, - scale=1.0, + scale=20.0, seed=self.SETTINGS.seed, ) )