diff --git a/docs/source/guides/circle_to_lfp.rst b/docs/source/guides/circle_to_lfp.rst deleted file mode 100644 index c35e8bd..0000000 --- a/docs/source/guides/circle_to_lfp.rst +++ /dev/null @@ -1,219 +0,0 @@ -Circular Motion to Velocity-Modulated LFP -========================================== - -This guide walks through the ``circle_to_dynamic_pink_outlet.py`` example, -which generates a simulated cursor moving in a circle and encodes its velocity -into LFP-like colored noise streamed over Lab Streaming Layer (LSL). - -Overview --------- - -The example demonstrates velocity-to-LFP encoding with a predictable, -synthetic input: - -.. code-block:: text - - Clock -> Counter -> SinGenerator -> Diff -> Velocity2LFP -> LSLOutlet - -The circular motion produces smoothly varying velocity vectors that sweep -through all directions, providing known ground truth for testing decoders. - -Prerequisites -------------- - -Install the required packages: - -.. code-block:: bash - - uv add ezmsg-simbiophys ezmsg-lsl - -Running the Example -------------------- - -Basic usage: - -.. code-block:: bash - - cd examples - uv run python circle_to_dynamic_pink_outlet.py - -With custom parameters: - -.. code-block:: bash - - uv run python circle_to_dynamic_pink_outlet.py \ - --cursor-fs 100 \ - --output-fs 30000 \ - --output-ch 256 - -Command-Line Arguments -~~~~~~~~~~~~~~~~~~~~~~ - -``--graph-addr`` - Address for the ezmsg graph server (ip:port). Set empty to disable. - Default: ``127.0.0.1:25978`` - -``--cursor-fs`` - Simulated cursor update rate in Hz. Default: ``100.0`` - -``--output-fs`` - Output sampling rate in Hz. Default: ``30000.0`` - -``--output-ch`` - Number of output channels. Default: ``256`` - -``--seed`` - Random seed for reproducibility. Default: ``6767`` - -Pipeline Components -------------------- - -Clock -~~~~~ - -Generates timing signals at the specified rate (default 100 Hz). - -Counter -~~~~~~~ - -Converts clock signals to integer sample counts, used for computing -sinusoidal positions. - -SinGenerator -~~~~~~~~~~~~ - -Generates circular motion by outputting two sinusoids with a 90-degree -phase difference: - -.. code-block:: python - - SinGenerator(SinGeneratorSettings( - n_ch=2, # x, y coordinates - freq=0.25, # 0.25 Hz = 4 second period - amp=200.0, # 200 pixel radius - phase=[np.pi/2, 0.0], # cos, sin -> counterclockwise circle - )) - -This produces a cursor position that traces a circle with: - -- Period: 4 seconds -- Radius: 200 pixels -- Direction: counterclockwise starting at (200, 0) - -Diff -~~~~ - -Differentiates position to get velocity. With ``scale_by_fs=True``, the -output is in pixels per second. - -For circular motion at radius *r* and angular frequency *ω*, the velocity -magnitude is constant: *v = r × ω = 200 × 2π/4 ≈ 314* pixels/second. - -Velocity2LFP -~~~~~~~~~~~~ - -Encodes velocity into LFP-like colored noise: - -1. **Polar conversion:** Transform (vx, vy) to (magnitude, angle) -2. **Scale to beta:** Map magnitude and angle to spectral exponent range -3. **Colored noise:** Generate 1/f^β noise with β modulated by velocity -4. **Spatial mixing:** Project 2 noise sources to output channels - -The result is multi-channel colored noise where spectral properties vary -with cursor velocity. - -LSLOutlet -~~~~~~~~~ - -Streams the output over LSL with name ``CircleModulatedPinkNoise`` and -type ``EEG``. - -Understanding the Encoding --------------------------- - -Velocity to Beta Mapping -~~~~~~~~~~~~~~~~~~~~~~~~ - -The velocity components are mapped to spectral exponents (β): - -- **Magnitude:** 0-314 px/s → β = 0.5-2.0 -- **Angle:** 0-2π → β = 0.5-2.0 - -This creates two noise sources with different spectral characteristics that -vary as the cursor moves. - -Spectral Exponent Effects -~~~~~~~~~~~~~~~~~~~~~~~~~ - -The spectral exponent β controls the noise color: - -- **β = 0:** White noise (flat spectrum) -- **β = 1:** Pink noise (1/f, equal power per octave) -- **β = 2:** Brown noise (1/f², random walk) - -As the cursor moves faster, one noise source becomes more "red" (higher β). -As the direction changes, the other source's spectral properties shift. - -Spatial Mixing -~~~~~~~~~~~~~~ - -The two noise sources are projected onto output channels using a mixing -matrix based on sinusoidal weights with random perturbations: - -.. code-block:: python - - weights = np.array([ - np.sin(2 * np.pi * ch_idx / output_ch), # Source 1 weights - np.cos(2 * np.pi * ch_idx / output_ch), # Source 2 weights - ]) + 0.3 * rng.standard_normal((2, output_ch)) - -This creates spatially-varying patterns where different channels have -different mixtures of the two velocity-modulated sources. - -Verifying the Output --------------------- - -The circular motion provides predictable ground truth: - -1. **Constant velocity magnitude:** ~314 pixels/second -2. **Linearly varying angle:** 0 to 2π over 4 seconds -3. **Periodic behavior:** Pattern repeats every 4 seconds - -You can verify the encoding by: - -1. Recording the LSL stream -2. Computing spectral features from the output -3. Checking that spectral properties correlate with the known velocity pattern - -Example Analysis -~~~~~~~~~~~~~~~~ - -.. code-block:: python - - import numpy as np - from pylsl import StreamInlet, resolve_stream - - # Capture one period (4 seconds) - streams = resolve_stream('name', 'CircleModulatedPinkNoise') - inlet = StreamInlet(streams[0]) - - samples = [] - for _ in range(int(4 * 30000)): # 4 seconds at 30 kHz - sample, _ = inlet.pull_sample() - samples.append(sample) - - data = np.array(samples) - - # Compute spectrum for each second - from scipy import signal - for i in range(4): - segment = data[i*30000:(i+1)*30000, 0] # First channel - f, psd = signal.welch(segment, fs=30000) - # Compare spectral slope across segments - -See Also --------- - -- :doc:`mouse_to_ecephys` - Real mouse input with full ecephys output -- :mod:`ezmsg.simbiophys.system.velocity2lfp` - API documentation -- :mod:`ezmsg.simbiophys.dynamic_colored_noise` - Colored noise generator diff --git a/docs/source/guides/spiral_to_lfp.rst b/docs/source/guides/spiral_to_lfp.rst new file mode 100644 index 0000000..f6343b1 --- /dev/null +++ b/docs/source/guides/spiral_to_lfp.rst @@ -0,0 +1,242 @@ +Spiral Motion to Velocity-Modulated LFP +======================================= + +This guide walks through the ``spiral_to_dynamic_pink_outlet.py`` example, +which generates a simulated cursor moving in a spiral pattern and encodes its +velocity into LFP-like colored noise streamed over Lab Streaming Layer (LSL). + +Overview +-------- + +The example demonstrates velocity-to-LFP encoding with a predictable, +synthetic input: + +.. code-block:: text + + Clock -> SpiralGenerator -> Diff -> CART2POL -> Velocity2LFP -> LSLOutlet + +The spiral motion produces smoothly varying velocity vectors that sweep +through all directions while also varying in magnitude, providing known +ground truth for testing decoders. + +Prerequisites +------------- + +Install the required packages: + +.. code-block:: bash + + uv add ezmsg-simbiophys ezmsg-lsl + +Running the Example +------------------- + +Basic usage: + +.. code-block:: bash + + cd examples + uv run python spiral_to_dynamic_pink_outlet.py + +With custom parameters: + +.. code-block:: bash + + uv run python spiral_to_dynamic_pink_outlet.py \ + --cursor-fs 100 \ + --output-fs 30000 \ + --output-ch 256 + +Command-Line Arguments +~~~~~~~~~~~~~~~~~~~~~~ + +``--graph-addr`` + Address for the ezmsg graph server (ip:port). Set empty to disable. + Default: ``127.0.0.1:25978`` + +``--cursor-fs`` + Simulated cursor update rate in Hz. Default: ``100.0`` + +``--output-fs`` + Output sampling rate in Hz. Default: ``30000.0`` + +``--output-ch`` + Number of output channels. Default: ``256`` + +``--seed`` + Random seed for reproducibility. Default: ``6767`` + +Pipeline Components +------------------- + +Clock +~~~~~ + +Generates timing signals at the specified rate (default 100 Hz). + +SpiralGenerator +~~~~~~~~~~~~~~~ + +Generates spiral 2-dimensional motion where both radius and angle vary over time: + +.. code-block:: python + + SpiralGenerator(SpiralGeneratorSettings( + r_mean=150.0, # Mean radius + r_amp=50.0, # Amplitude of radial oscillation + radial_freq=0.1, # Radial oscillation frequency (Hz) + angular_freq=0.25, # Angular rotation frequency (Hz) + )) + +The parametric equations are: + +- ``r(t) = r_mean + r_amp * sin(2*pi*radial_freq*t)`` +- ``theta(t) = 2*pi*angular_freq*t`` +- ``x(t) = r(t) * cos(theta(t))`` +- ``y(t) = r(t) * sin(theta(t))`` + +This creates a "breathing" spiral where the cursor rotates while moving +in and out from the center. + + +Diff +~~~~ + +Differentiates position to get velocity. With ``scale_by_fs=True``, the +output is in pixels per second. + +CART2POL (CoordinateSpaces) +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Converts Cartesian velocity (vx, vy) to polar coordinates (magnitude, angle). +This transformation is done once upstream and shared if both spike and LFP +encoding are used (via VelocityEncoder). + +Velocity2LFP +~~~~~~~~~~~~ + +Encodes polar velocity into LFP-like colored noise using a cosine tuning model: + +1. **Cosine encoder:** Each of ``n_lfp_sources`` (default 8) has a random + preferred direction. The spectral exponent beta is computed as: + ``beta = baseline + modulation * magnitude * cos(angle - pd)`` +2. **Clip:** Ensures beta values stay within valid range [0, 2] +3. **Colored noise:** Generate 1/f^β noise with β dynamically modulated per source +4. **Spatial mixing:** Project n_lfp_sources onto output_ch channels using + sinusoidal mixing patterns + +The result is multi-channel colored noise where spectral properties vary +with cursor velocity direction and magnitude. + +LSLOutlet +~~~~~~~~~ + +Streams the output over LSL with name ``SpiralModulatedPinkNoise`` and +type ``EEG``. + +Understanding the Encoding +-------------------------- + +Cosine Tuning Model +~~~~~~~~~~~~~~~~~~~ + +Each of the ``n_lfp_sources`` (default 8) has a randomly assigned preferred +direction. The spectral exponent beta for each source is computed using +a cosine tuning model: + +.. code-block:: python + + beta = baseline + modulation * magnitude * cos(angle - pd) + +With default settings: + +- ``baseline = 1.0`` (pink noise at rest) +- ``modulation = 1.0 / max_velocity`` (scales with velocity) +- ``max_velocity = 315.0`` (pixels/second) + +When moving at maximum velocity in a source's preferred direction, +beta reaches 2.0 (brown noise). When moving opposite to the preferred +direction, beta reaches 0.0 (white noise). The output is clipped to [0, 2]. + +Spectral Exponent Effects +~~~~~~~~~~~~~~~~~~~~~~~~~ + +The spectral exponent β controls the noise color: + +- **β = 0:** White noise (flat spectrum) +- **β = 1:** Pink noise (1/f, equal power per octave) +- **β = 2:** Brown noise (1/f², random walk) + +Each source responds differently to velocity direction based on its +preferred direction, creating a rich mixture of spectral characteristics. + +Spatial Mixing +~~~~~~~~~~~~~~ + +The ``n_lfp_sources`` noise sources are projected onto ``output_ch`` channels +using a mixing matrix based on sinusoidal weights at different spatial +frequencies, plus random perturbations: + +.. code-block:: python + + weights = np.zeros((n_sources, output_ch)) + for i in range(n_sources): + freq = (i + 1) / n_sources + phase = 2 * np.pi * i / n_sources + weights[i, :] = np.sin(2 * np.pi * freq * ch_idx / output_ch + phase) + weights += 0.3 * rng.standard_normal((n_sources, output_ch)) + +This creates spatially-varying patterns where different channels have +different mixtures of the velocity-modulated sources, mimicking the +spatial spread of LFP signals across electrode arrays. + +Verifying the Output +-------------------- + +The spiral motion provides predictable ground truth: + +1. **Varying velocity magnitude:** Oscillates due to radial breathing +2. **Linearly varying angle:** Rotates at ``angular_freq`` Hz +3. **Periodic behavior:** Angular period = 1/angular_freq seconds (4s default), + radial period = 1/radial_freq seconds (10s default) + +You can verify the encoding by: + +1. Recording the LSL stream +2. Computing spectral features from the output +3. Checking that spectral properties correlate with the known velocity pattern + +Example Analysis +~~~~~~~~~~~~~~~~ + +.. code-block:: python + + import numpy as np + from pylsl import StreamInlet, resolve_stream + + # Capture one angular period (4 seconds with default settings) + streams = resolve_stream('name', 'SpiralModulatedPinkNoise') + inlet = StreamInlet(streams[0]) + + samples = [] + for _ in range(int(4 * 30000)): # 4 seconds at 30 kHz + sample, _ = inlet.pull_sample() + samples.append(sample) + + data = np.array(samples) + + # Compute spectrum for each second + from scipy import signal + for i in range(4): + segment = data[i*30000:(i+1)*30000, 0] # First channel + f, psd = signal.welch(segment, fs=30000) + # Compare spectral slope across segments + +See Also +-------- + +- :doc:`mouse_to_ecephys` - Real mouse input with full ecephys output +- :mod:`ezmsg.simbiophys.system.velocity2lfp` - API documentation +- :mod:`ezmsg.simbiophys.cosine_encoder` - Cosine tuning encoder +- :mod:`ezmsg.simbiophys.dynamic_colored_noise` - Colored noise generator +- :mod:`ezmsg.simbiophys.oscillator` - SpiralGenerator and SinGenerator diff --git a/examples/mouse_to_lsl_full.py b/examples/mouse_to_lsl_full.py index 74e6e9f..f0d62be 100644 --- a/examples/mouse_to_lsl_full.py +++ b/examples/mouse_to_lsl_full.py @@ -84,7 +84,7 @@ def main( "SINK": LSLOutletUnit(LSLOutletSettings(stream_name="MouseModulatedRaw", stream_type="ECEPhys")), } conns = ( - (comps["CLOCK"].OUTPUT_SIGNAL, comps["SOURCE"].INPUT_SIGNAL), + (comps["CLOCK"].OUTPUT_SIGNAL, comps["SOURCE"].INPUT_CLOCK), (comps["SOURCE"].OUTPUT_SIGNAL, comps["DIFF"].INPUT_SIGNAL), (comps["DIFF"].OUTPUT_SIGNAL, comps["ENCODER"].INPUT_SIGNAL), (comps["ENCODER"].OUTPUT_SIGNAL, comps["SINK"].INPUT_SIGNAL), diff --git a/examples/circle_to_dynamic_pink_outlet.py b/examples/spiral_to_dynamic_pink_outlet.py similarity index 52% rename from examples/circle_to_dynamic_pink_outlet.py rename to examples/spiral_to_dynamic_pink_outlet.py index 8e0ec22..7288f32 100644 --- a/examples/circle_to_dynamic_pink_outlet.py +++ b/examples/spiral_to_dynamic_pink_outlet.py @@ -1,18 +1,22 @@ -"""Circular motion to velocity-modulated LFP, streamed over LSL. +"""Spiral motion to velocity-modulated LFP, streamed over LSL. -This example generates a simulated cursor moving in a circle, computes its -velocity, encodes the velocity into LFP-like colored noise, and streams the +This example generates a simulated cursor moving in a spiral pattern, computes +its velocity, encodes the velocity into LFP-like colored noise, and streams the result over Lab Streaming Layer (LSL). Pipeline:: - Clock -> Counter -> SinGenerator (circle) -> Diff (velocity) - -> Velocity2LFP -> LSLOutlet + Clock -> SpiralGenerator -> Diff (velocity) -> CART2POL -> Velocity2LFP -> LSLOutlet -The circular motion produces smoothly varying velocity vectors that sweep -through all directions. The Velocity2LFP system converts this into multi-channel -colored noise where the spectral properties are modulated by velocity magnitude -and direction. +The spiral motion produces varying velocity vectors where both the magnitude +(speed) and direction change over time. The SpiralGenerator creates a pattern +where: + - The radius oscillates sinusoidally (breathing in/out) + - The angle increases linearly (rotation) + +This provides richer dynamics than circular motion for testing the velocity +encoding system, as velocity is non-zero and varying even when the cursor +"pauses" at the turning points of the radial oscillation. This is useful for: - Testing LFP processing pipelines with known ground truth @@ -37,13 +41,13 @@ """ import ezmsg.core as ez -import numpy as np import typer -from ezmsg.baseproc import Clock, ClockSettings, Counter, CounterSettings +from ezmsg.baseproc import Clock, ClockSettings from ezmsg.lsl.outlet import LSLOutletSettings, LSLOutletUnit +from ezmsg.sigproc.coordinatespaces import CoordinateMode, CoordinateSpaces, CoordinateSpacesSettings from ezmsg.sigproc.diff import DiffSettings, DiffUnit -from ezmsg.simbiophys.oscillator import SinGenerator, SinGeneratorSettings +from ezmsg.simbiophys.oscillator import SpiralGenerator, SpiralGeneratorSettings from ezmsg.simbiophys.system.velocity2lfp import Velocity2LFP, Velocity2LFPSettings GRAPH_IP = "127.0.0.1" @@ -66,35 +70,39 @@ def main( comps = { "CLOCK": Clock(ClockSettings(dispatch_rate=cursor_fs)), - "COUNTER": Counter(CounterSettings(fs=cursor_fs)), - "OSCILLATOR": SinGenerator( - SinGeneratorSettings( - n_ch=2, # x,y - freq=0.25, # 1/4 Hz = 4 second period - amp=200.0, # radius 200 pixels - phase=[np.pi / 2, 0.0], # [x, y]: cos = sin + π/2, counterclockwise from (200, 0) + "SPIRAL": SpiralGenerator( + SpiralGeneratorSettings( + fs=cursor_fs, + r_mean=150.0, # Mean radius 150 pixels + r_amp=150.0, # Radius oscillates +/- 50 pixels (100-200 range) + radial_freq=0.1, # Radial breathing at 0.1 Hz (10 second period) + angular_freq=0.25, # Rotation at 0.25 Hz (4 second period) ) ), "DIFF": DiffUnit(DiffSettings(axis="time", scale_by_fs=True)), + # DIFF Output is [[vx, vy]] pixels/sec with varying magnitude + "COORDS": CoordinateSpaces(CoordinateSpacesSettings(mode=CoordinateMode.CART2POL, axis="ch")), + # COORDS Output is [[magnitude, angle]] polar velocity "VEL2LFP": Velocity2LFP( Velocity2LFPSettings( output_fs=output_fs, output_ch=output_ch, + max_velocity=472.0, seed=seed, ) ), "SINK": LSLOutletUnit( LSLOutletSettings( - stream_name="CircleModulatedPinkNoise", + stream_name="SpiralModulatedPinkNoise", stream_type="EEG", ) ), } conns = ( - (comps["CLOCK"].OUTPUT_SIGNAL, comps["COUNTER"].INPUT_SIGNAL), - (comps["COUNTER"].OUTPUT_SIGNAL, comps["OSCILLATOR"].INPUT_SIGNAL), - (comps["OSCILLATOR"].OUTPUT_SIGNAL, comps["DIFF"].INPUT_SIGNAL), - (comps["DIFF"].OUTPUT_SIGNAL, comps["VEL2LFP"].INPUT_SIGNAL), + (comps["CLOCK"].OUTPUT_SIGNAL, comps["SPIRAL"].INPUT_CLOCK), + (comps["SPIRAL"].OUTPUT_SIGNAL, comps["DIFF"].INPUT_SIGNAL), + (comps["DIFF"].OUTPUT_SIGNAL, comps["COORDS"].INPUT_SIGNAL), + (comps["COORDS"].OUTPUT_SIGNAL, comps["VEL2LFP"].INPUT_SIGNAL), (comps["VEL2LFP"].OUTPUT_SIGNAL, comps["SINK"].INPUT_SIGNAL), ) diff --git a/pyproject.toml b/pyproject.toml index bc87f32..e0dd1d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ requires-python = ">=3.10.15" dynamic = ["version"] dependencies = [ "ezmsg>=3.6.0", - "ezmsg-baseproc>=1.1.0", + "ezmsg-baseproc>=1.2.1", "ezmsg-event>=0.6.0", "ezmsg-sigproc>=2.8.0", "numpy>=1.26.0", @@ -24,7 +24,7 @@ dev = [ {include-group = "test"}, "typer>=0.20.0", "scipy-stubs>=1.15.3.0", - "ezmsg-peripheraldevice>=0.2.0", + "ezmsg-peripheraldevice>=0.3.0", "ezmsg-lsl>=1.4.2", ] lint = [ diff --git a/src/ezmsg/simbiophys/__init__.py b/src/ezmsg/simbiophys/__init__.py index a03c88a..b6d6e4f 100644 --- a/src/ezmsg/simbiophys/__init__.py +++ b/src/ezmsg/simbiophys/__init__.py @@ -14,24 +14,23 @@ from .__version__ import __version__ as __version__ -# Cosine Tuning -from .cosine_tuning import ( - CosineTuningParams, - CosineTuningSettings, - CosineTuningState, - CosineTuningTransformer, - CosineTuningUnit, +# Cosine Encoder +from .cosine_encoder import ( + CosineEncoderSettings, + CosineEncoderState, + CosineEncoderTransformer, + CosineEncoderUnit, ) # DNSS (Digital Neural Signal Simulator) from .dnss import ( # LFP + DNSSLFPProducer, DNSSLFPSettings, - DNSSLFPTransformer, DNSSLFPUnit, # Spike + DNSSSpikeProducer, DNSSSpikeSettings, - DNSSSpikeTransformer, DNSSSpikeUnit, ) @@ -54,18 +53,24 @@ # Noise from .noise import ( PinkNoise, + PinkNoiseProducer, PinkNoiseSettings, - PinkNoiseTransformer, WhiteNoise, + WhiteNoiseProducer, WhiteNoiseSettings, - WhiteNoiseTransformer, + WhiteNoiseState, ) # Oscillator from .oscillator import ( SinGenerator, SinGeneratorSettings, - SinTransformer, + SinGeneratorState, + SinProducer, + SpiralGenerator, + SpiralGeneratorSettings, + SpiralGeneratorState, + SpiralProducer, ) __all__ = [ @@ -84,23 +89,28 @@ # Oscillator "SinGenerator", "SinGeneratorSettings", - "SinTransformer", + "SinGeneratorState", + "SinProducer", + "SpiralGenerator", + "SpiralGeneratorSettings", + "SpiralGeneratorState", + "SpiralProducer", # Noise "PinkNoise", + "PinkNoiseProducer", "PinkNoiseSettings", - "PinkNoiseTransformer", "WhiteNoise", + "WhiteNoiseProducer", "WhiteNoiseSettings", - "WhiteNoiseTransformer", + "WhiteNoiseState", # EEG "EEGSynth", "EEGSynthSettings", - # Cosine Tuning - "CosineTuningParams", - "CosineTuningSettings", - "CosineTuningState", - "CosineTuningTransformer", - "CosineTuningUnit", + # Cosine Encoder + "CosineEncoderSettings", + "CosineEncoderState", + "CosineEncoderTransformer", + "CosineEncoderUnit", # Dynamic Colored Noise "ColoredNoiseFilterState", "DynamicColoredNoiseSettings", @@ -109,11 +119,11 @@ "DynamicColoredNoiseUnit", "compute_kasdin_coefficients", # DNSS LFP + "DNSSLFPProducer", "DNSSLFPSettings", - "DNSSLFPTransformer", "DNSSLFPUnit", # DNSS Spike + "DNSSSpikeProducer", "DNSSSpikeSettings", - "DNSSSpikeTransformer", "DNSSSpikeUnit", ] diff --git a/src/ezmsg/simbiophys/cosine_encoder.py b/src/ezmsg/simbiophys/cosine_encoder.py new file mode 100644 index 0000000..808faee --- /dev/null +++ b/src/ezmsg/simbiophys/cosine_encoder.py @@ -0,0 +1,249 @@ +"""Generic cosine-tuning encoder for polar coordinates. + +This module provides a generalized cosine-tuning encoder that maps polar +coordinates (magnitude, angle) to multiple output channels with configurable +preferred directions, baseline, and modulation parameters. + +The encoding formula is: + output = baseline + modulation * magnitude * cos(angle - preferred_direction) + + speed_modulation * magnitude + +This implements the offset model from "Decoding arm speed during reaching" +(https://ncbi.nlm.nih.gov/pmc/articles/PMC6286377/) with generic terminology +suitable for various applications: + - Neural firing rate encoding (baseline=10Hz, modulation=20Hz) + - LFP spectral parameter modulation (baseline=1.0, modulation=0.5) + - Any other cosine-tuning based encoding + +Input: + Polar coordinates (magnitude, angle) as AxisArray with shape (n_samples, 2). + Use CoordinateSpaces(mode=CART2POL) upstream to convert from Cartesian. + +Output: + AxisArray with shape (n_samples, output_ch) containing encoded values. +""" + +from pathlib import Path + +import ezmsg.core as ez +import numpy as np +import numpy.typing as npt +from ezmsg.baseproc import ( + BaseStatefulTransformer, + BaseTransformerUnit, + processor_state, +) +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.messages.util import replace + + +class CosineEncoderSettings(ez.Settings): + """Settings for CosineEncoder. + + Either `model_file` OR the random generation parameters should be specified. + If `model_file` is provided, parameters are loaded from file. + Otherwise, parameters are randomly generated. + """ + + # File-based parameters + model_file: str | None = None + """Path to .npz file with encoder parameters (baseline, modulation, pd, speed_modulation). + Also supports legacy neural tuning files with keys (b0, m, pd, bs).""" + + # Random generation parameters + output_ch: int = 10 + """Number of output channels (used if model_file is None).""" + + baseline: float = 0.0 + """Baseline output value for all channels (used if model_file is None).""" + + modulation: float = 1.0 + """Directional modulation depth for all channels (used if model_file is None).""" + + speed_modulation: float = 0.0 + """Speed modulation (non-directional) for all channels (used if model_file is None).""" + + seed: int | None = None + """Random seed for reproducibility of preferred directions (used if model_file is None).""" + + +@processor_state +class CosineEncoderState: + """State for cosine encoder transformer. + + Holds the per-channel encoding parameters. All arrays have shape (1, output_ch) + for efficient broadcasting during processing. + + Attributes: + baseline: Baseline output value for each channel. + modulation: Directional modulation depth for each channel. + pd: Preferred direction (radians) for each channel. + speed_modulation: Speed modulation (non-directional) for each channel. + ch_axis: Pre-built channel axis for output messages. + """ + + baseline: npt.NDArray[np.floating] | None = None + modulation: npt.NDArray[np.floating] | None = None + pd: npt.NDArray[np.floating] | None = None + speed_modulation: npt.NDArray[np.floating] | None = None + ch_axis: AxisArray.CoordinateAxis | None = None + + @property + def output_ch(self) -> int: + """Number of output channels.""" + return self.baseline.shape[1] if self.baseline is not None else 0 + + def validate(self) -> None: + """Validate that all parameters have consistent shapes.""" + if any(x is None for x in [self.baseline, self.modulation, self.pd, self.speed_modulation]): + raise ValueError("All parameters must be set") + if not (self.baseline.shape == self.modulation.shape == self.pd.shape == self.speed_modulation.shape): + raise ValueError("All parameters must have the same shape") + if self.baseline.ndim != 2 or self.baseline.shape[0] != 1: + raise ValueError("Parameters must have shape (1, output_ch)") + if self.baseline.shape[1] < 1: + raise ValueError("Parameters must have at least 1 channel") + + def load_from_file( + self, + filepath: str | Path, + output_ch: int | None = None, + ) -> None: + """Load parameters from a .npz file. + + The file should contain arrays with keys matching the parameter names. + For backwards compatibility with neural tuning files, the following + key mappings are supported: + - 'b0' -> baseline + - 'm' -> modulation + - 'pd' -> pd (preferred direction) + - 'bs' -> speed_modulation + + Args: + filepath: Path to .npz file containing parameter arrays. + output_ch: Number of channels to use. If None, uses all in file. + """ + params = np.load(filepath) + + # Support both new names and legacy neural tuning names + baseline = np.asarray(params.get("baseline", params.get("b0")), dtype=np.float64).ravel() + modulation = np.asarray(params.get("modulation", params.get("m")), dtype=np.float64).ravel() + pd = np.asarray(params["pd"], dtype=np.float64).ravel() + speed_modulation = np.asarray(params.get("speed_modulation", params.get("bs")), dtype=np.float64).ravel() + + if output_ch is not None: + baseline = baseline[:output_ch] + modulation = modulation[:output_ch] + pd = pd[:output_ch] + speed_modulation = speed_modulation[:output_ch] + + # Reshape to (1, output_ch) for broadcasting + self.baseline = baseline[np.newaxis, :] + self.modulation = modulation[np.newaxis, :] + self.pd = pd[np.newaxis, :] + self.speed_modulation = speed_modulation[np.newaxis, :] + + # Create channel axis for output messages + ch_labels = np.array([f"ch{i}" for i in range(len(baseline))]) + self.ch_axis = AxisArray.CoordinateAxis(data=ch_labels, dims=["ch"]) + + self.validate() + + def init_random( + self, + output_ch: int, + baseline: float = 0.0, + modulation: float = 1.0, + speed_modulation: float = 0.0, + seed: int | None = None, + ) -> None: + """Initialize encoder parameters with random preferred directions. + + Args: + output_ch: Number of output channels. + baseline: Baseline value for all channels. + modulation: Directional modulation depth for all channels. + speed_modulation: Speed modulation (non-directional) for all channels. + seed: Random seed for reproducibility. + """ + rng = np.random.default_rng(seed) + + # Shape (1, output_ch) for efficient broadcasting + self.baseline = np.full((1, output_ch), baseline, dtype=np.float64) + self.modulation = np.full((1, output_ch), modulation, dtype=np.float64) + self.pd = rng.uniform(0.0, 2.0 * np.pi, size=(1, output_ch)).astype(np.float64) + self.speed_modulation = np.full((1, output_ch), speed_modulation, dtype=np.float64) + + # Create channel axis for output messages + ch_labels = np.array([f"ch{i}" for i in range(output_ch)]) + self.ch_axis = AxisArray.CoordinateAxis(data=ch_labels, dims=["ch"]) + + self.validate() + + +class CosineEncoderTransformer( + BaseStatefulTransformer[CosineEncoderSettings, AxisArray, AxisArray, CosineEncoderState] +): + """Transform polar coordinates to multi-channel encoded output. + + Input: AxisArray with shape (n_samples, 2) containing polar coordinates + (magnitude, angle) where magnitude is speed and angle is direction. + Output: AxisArray with shape (n_samples, output_ch) containing encoded values. + + The encoding formula is: + output = baseline + modulation * magnitude * cos(angle - pd) + + speed_modulation * magnitude + + This is a generic encoder suitable for various applications including: + - Neural firing rate encoding (baseline=10Hz, modulation=20Hz) + - LFP spectral parameter modulation (baseline=1.0, modulation=0.5) + - Any other cosine-tuning based encoding + """ + + def _reset_state(self, message: AxisArray) -> None: + """Initialize encoder parameters.""" + if self.settings.model_file is not None: + self.state.load_from_file( + self.settings.model_file, + output_ch=None, # Use all channels from file + ) + else: + self.state.init_random( + output_ch=self.settings.output_ch, + baseline=self.settings.baseline, + modulation=self.settings.modulation, + speed_modulation=self.settings.speed_modulation, + seed=self.settings.seed, + ) + + def _process(self, message: AxisArray) -> AxisArray: + """Transform polar coordinates to encoded output.""" + polar = np.asarray(message.data, dtype=np.float64) + + if polar.ndim != 2 or polar.shape[1] != 2: + raise ValueError(f"Expected polar coords with shape (n_samples, 2), got {polar.shape}") + + # Extract polar components (from CART2POL: magnitude, angle) + magnitude = polar[:, 0:1] # (n_samples, 1) + angle = polar[:, 1:2] # (n_samples, 1) + + # Compute output: baseline + modulation * magnitude * cos(angle - pd) + speed_mod * magnitude + # State arrays are pre-shaped to (1, output_ch) for broadcasting + output = ( + self.state.baseline + + self.state.modulation * magnitude * np.cos(angle - self.state.pd) + + self.state.speed_modulation * magnitude + ) + + return replace( + message, + data=output, + dims=["time", "ch"], + axes={**message.axes, "ch": self.state.ch_axis}, + ) + + +class CosineEncoderUnit(BaseTransformerUnit[CosineEncoderSettings, AxisArray, AxisArray, CosineEncoderTransformer]): + """Unit wrapper for CosineEncoderTransformer.""" + + SETTINGS = CosineEncoderSettings diff --git a/src/ezmsg/simbiophys/cosine_tuning.py b/src/ezmsg/simbiophys/cosine_tuning.py deleted file mode 100644 index 8b7b02e..0000000 --- a/src/ezmsg/simbiophys/cosine_tuning.py +++ /dev/null @@ -1,249 +0,0 @@ -"""Cosine tuning model for neural encoding of velocity/movement. - -Implements the offset model from "Decoding arm speed during reaching" -(https://ncbi.nlm.nih.gov/pmc/articles/PMC6286377/): - - firing_rate = b0 + m * |v| * cos(θ - θ_pd) + bs * |v| - -Where: - - b0: baseline firing rate - - m: directional modulation depth - - θ: velocity direction (angle) - - θ_pd: preferred direction - - bs: speed modulation (non-directional) - - |v|: velocity magnitude (speed) - -For spike generation from firing rates, use EventsFromRatesTransformer -from ezmsg-event. -""" - -from dataclasses import dataclass -from pathlib import Path - -import ezmsg.core as ez -import numpy as np -import numpy.typing as npt -from ezmsg.baseproc import ( - BaseStatefulTransformer, - BaseTransformerUnit, - processor_state, -) -from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.util.messages.util import replace - - -@dataclass -class CosineTuningParams: - """Parameters for cosine tuning model. - - All arrays must have the same shape (n_units,). - - Attributes: - b0: Baseline firing rate (Hz) for each unit. - m: Directional modulation depth for each unit. - pd: Preferred direction (radians) for each unit. - bs: Speed modulation (non-directional) for each unit. - """ - - b0: npt.NDArray[np.floating] - m: npt.NDArray[np.floating] - pd: npt.NDArray[np.floating] - bs: npt.NDArray[np.floating] - - def __post_init__(self): - """Validate that all parameters have consistent shapes.""" - if not (self.b0.shape == self.m.shape == self.pd.shape == self.bs.shape): - raise ValueError("All parameters must have the same shape") - if self.b0.ndim != 1: - raise ValueError("Parameters must be 1D arrays") - if len(self.b0) < 1: - raise ValueError("Parameters must have length >= 1") - - @property - def n_units(self) -> int: - """Number of neural units.""" - return len(self.b0) - - @classmethod - def from_file( - cls, - filepath: str | Path, - n_units: int | None = None, - weight_gain: float = 1.0, - ) -> "CosineTuningParams": - """Load parameters from a .npz file. - - Args: - filepath: Path to .npz file containing b0, m, pd, bs arrays. - n_units: Number of units to use. If None, uses all units in file. - weight_gain: Scaling factor applied to m and bs parameters. - - Returns: - CosineTuningParams instance. - """ - params = np.load(filepath) - - b0 = np.asarray(params["b0"], dtype=np.float64) - m = np.asarray(params["m"], dtype=np.float64) - pd = np.asarray(params["pd"], dtype=np.float64) - bs = np.asarray(params["bs"], dtype=np.float64) - - if n_units is not None: - b0 = b0[:n_units] - m = m[:n_units] - pd = pd[:n_units] - bs = bs[:n_units] - - m = m * weight_gain - bs = bs * weight_gain - - return cls(b0=b0, m=m, pd=pd, bs=bs) - - @classmethod - def from_random( - cls, - n_units: int, - baseline_hz: float = 10.0, - modulation_hz: float = 20.0, - speed_modulation_hz: float = 0.0, - seed: int | None = None, - ) -> "CosineTuningParams": - """Generate random tuning parameters. - - Args: - n_units: Number of neural units. - baseline_hz: Baseline firing rate (Hz) for all units. - modulation_hz: Directional modulation depth for all units. - speed_modulation_hz: Speed modulation (non-directional) for all units. - seed: Random seed for reproducibility. - - Returns: - CosineTuningParams instance with random preferred directions. - """ - rng = np.random.default_rng(seed) - - return cls( - b0=np.full(n_units, baseline_hz, dtype=np.float64), - m=np.full(n_units, modulation_hz, dtype=np.float64), - pd=rng.uniform(0.0, 2.0 * np.pi, size=n_units).astype(np.float64), - bs=np.full(n_units, speed_modulation_hz, dtype=np.float64), - ) - - -class CosineTuningSettings(ez.Settings): - """Settings for CosineTuningTransformer. - - Either `model_file` OR the random generation parameters should be specified. - If `model_file` is provided, parameters are loaded from file. - Otherwise, parameters are randomly generated. - """ - - # File-based parameters - model_file: str | None = None - """Path to .npz file with tuning parameters (b0, m, pd, bs).""" - - weight_gain: float = 1.0 - """Scaling factor for m and bs when loading from file.""" - - # Random generation parameters - n_units: int = 50 - """Number of neural units (used if model_file is None).""" - - baseline_hz: float = 10.0 - """Baseline firing rate in Hz (used if model_file is None).""" - - modulation_hz: float = 20.0 - """Directional modulation depth in Hz (used if model_file is None).""" - - speed_modulation_hz: float = 0.0 - """Speed modulation (non-directional) in Hz (used if model_file is None).""" - - seed: int | None = None - """Random seed for reproducibility (used if model_file is None).""" - - # Output settings - min_rate: float = 0.0 - """Minimum firing rate (Hz). Rates are clipped to this value.""" - - -@processor_state -class CosineTuningState: - """State for cosine tuning transformer.""" - - params: CosineTuningParams | None = None - """Tuning curve parameters.""" - - -class CosineTuningTransformer(BaseStatefulTransformer[CosineTuningSettings, AxisArray, AxisArray, CosineTuningState]): - """Transform 2D velocity into firing rates using cosine tuning model. - - Input: AxisArray with shape (n_samples, 2) containing velocity (vx, vy). - Output: AxisArray with shape (n_samples, n_units) containing firing rates (Hz). - - The model implements: - rate = b0 + m * |v| * cos(θ - θ_pd) + bs * |v| - - For spike generation, chain with EventsFromRatesTransformer from ezmsg-event. - """ - - def _reset_state(self, message: AxisArray) -> None: - """Initialize tuning parameters.""" - if self.settings.model_file is not None: - self.state.params = CosineTuningParams.from_file( - self.settings.model_file, - n_units=None, # Use all units from file - weight_gain=self.settings.weight_gain, - ) - else: - self.state.params = CosineTuningParams.from_random( - n_units=self.settings.n_units, - baseline_hz=self.settings.baseline_hz, - modulation_hz=self.settings.modulation_hz, - speed_modulation_hz=self.settings.speed_modulation_hz, - seed=self.settings.seed, - ) - - def _process(self, message: AxisArray) -> AxisArray: - """Transform velocity to firing rates.""" - v = np.asarray(message.data, dtype=np.float64) - - if v.ndim != 2 or v.shape[1] != 2: - raise ValueError(f"Expected velocity with shape (n_samples, 2), got {v.shape}") - - # Extract velocity components - vx = v[:, 0] - vy = v[:, 1] - - # Calculate speed (magnitude) and direction (angle) - speed = np.hypot(vx, vy)[:, np.newaxis] # (n_samples, 1) - theta = np.arctan2(vy, vx)[:, np.newaxis] # (n_samples, 1) - - # Get parameters as row vectors for broadcasting - params = self.state.params - b0 = params.b0[np.newaxis, :] # (1, n_units) - m = params.m[np.newaxis, :] # (1, n_units) - pd = params.pd[np.newaxis, :] # (1, n_units) - bs = params.bs[np.newaxis, :] # (1, n_units) - - # Compute firing rates: b0 + m * |v| * cos(θ - θ_pd) + bs * |v| - rates = b0 + m * speed * np.cos(theta - pd) + bs * speed - - # Clip to minimum rate - rates = np.maximum(rates, self.settings.min_rate) - - # Create channel axis - ch_labels = np.array([f"unit{i}" for i in range(params.n_units)]) - ch_axis = AxisArray.CoordinateAxis(data=ch_labels, dims=["ch"]) - - return replace( - message, - data=rates, - dims=["time", "ch"], - axes={**message.axes, "ch": ch_axis}, - ) - - -class CosineTuningUnit(BaseTransformerUnit[CosineTuningSettings, AxisArray, AxisArray, CosineTuningTransformer]): - """Unit wrapper for CosineTuningTransformer.""" - - SETTINGS = CosineTuningSettings diff --git a/src/ezmsg/simbiophys/dnss/__init__.py b/src/ezmsg/simbiophys/dnss/__init__.py index cd90df9..7f61616 100644 --- a/src/ezmsg/simbiophys/dnss/__init__.py +++ b/src/ezmsg/simbiophys/dnss/__init__.py @@ -7,9 +7,9 @@ LFP_PERIOD, LFP_TIME_SHIFTS, OTHER_PERIOD, + DNSSLFPProducer, DNSSLFPSettings, - DNSSLFPTransformer, - DNSSLFPTransformerState, + DNSSLFPState, DNSSLFPUnit, lfp_generator, ) @@ -22,9 +22,9 @@ N_SLOW_SPIKES, SAMPS_BURST, SAMPS_SLOW, + DNSSSpikeProducer, DNSSSpikeSettings, - DNSSSpikeTransformer, - DNSSSpikeTransformerState, + DNSSSpikeState, DNSSSpikeUnit, spike_event_generator, ) @@ -43,9 +43,9 @@ "LFP_TIME_SHIFTS", "OTHER_PERIOD", # LFP classes + "DNSSLFPProducer", "DNSSLFPSettings", - "DNSSLFPTransformer", - "DNSSLFPTransformerState", + "DNSSLFPState", "DNSSLFPUnit", "lfp_generator", # Spike constants @@ -58,9 +58,9 @@ "SAMPS_BURST", "SAMPS_SLOW", # Spike classes + "DNSSSpikeProducer", "DNSSSpikeSettings", - "DNSSSpikeTransformer", - "DNSSSpikeTransformerState", + "DNSSSpikeState", "DNSSSpikeUnit", "spike_event_generator", # Synth classes diff --git a/src/ezmsg/simbiophys/dnss/lfp.py b/src/ezmsg/simbiophys/dnss/lfp.py index fcebbee..feb7d2c 100644 --- a/src/ezmsg/simbiophys/dnss/lfp.py +++ b/src/ezmsg/simbiophys/dnss/lfp.py @@ -27,15 +27,16 @@ from typing import Generator -import ezmsg.core as ez import numpy as np import numpy.typing as npt from ezmsg.baseproc import ( - BaseStatefulTransformer, - BaseTransformerUnit, + BaseClockDrivenProducer, + BaseClockDrivenUnit, + ClockDrivenSettings, + ClockDrivenState, processor_state, ) -from ezmsg.util.messages.axisarray import AxisArray, replace +from ezmsg.util.messages.axisarray import AxisArray, LinearAxis, replace # Default sample rate for DNSS DEFAULT_FS = 30_000 @@ -196,8 +197,11 @@ def lfp_generator( # ============================================================================= -class DNSSLFPSettings(ez.Settings): - """Settings for DNSS LFP transformer.""" +class DNSSLFPSettings(ClockDrivenSettings): + """Settings for DNSS LFP producer.""" + + fs: float = DEFAULT_FS + """Sample rate in Hz. DNSS is fixed at 30kHz.""" n_ch: int = 256 """Number of channels.""" @@ -210,30 +214,28 @@ class DNSSLFPSettings(ez.Settings): @processor_state -class DNSSLFPTransformerState: - """State for DNSS LFP transformer.""" +class DNSSLFPState(ClockDrivenState): + """State for DNSS LFP producer.""" lfp_gen: Generator | None = None template: AxisArray | None = None -class DNSSLFPTransformer(BaseStatefulTransformer[DNSSLFPSettings, AxisArray, AxisArray, DNSSLFPTransformerState]): +class DNSSLFPProducer(BaseClockDrivenProducer[DNSSLFPSettings, DNSSLFPState]): """ - Transforms input AxisArray into DNSS LFP signal. + Produces DNSS LFP signal synchronized to clock ticks. - Takes timing information from input message and generates LFP data. + Each clock tick produces a block of LFP data based on the + sample rate (fs) and chunk size (n_time) settings. All channels receive identical LFP values. """ - def _reset_state(self, message: AxisArray) -> None: + def _reset_state(self, time_axis: LinearAxis) -> None: """Initialize the LFP generator.""" - # Get sample rate from input message (fs = 1/gain for LinearAxis) - time_axis = message.axes["time"] - fs = getattr(time_axis, "fs", 1.0 / time_axis.gain) self._state.lfp_gen = lfp_generator( pattern=self.settings.pattern, mode=self.settings.mode, - fs=fs, + fs=self.settings.fs, ) next(self._state.lfp_gen) @@ -242,7 +244,7 @@ def _reset_state(self, message: AxisArray) -> None: data=np.zeros((0, self.settings.n_ch), dtype=np.float64), dims=["time", "ch"], axes={ - "time": message.axes["time"], + "time": time_axis, "ch": AxisArray.CoordinateAxis( data=np.arange(self.settings.n_ch), dims=["ch"], @@ -250,10 +252,8 @@ def _reset_state(self, message: AxisArray) -> None: }, ) - def _process(self, message: AxisArray) -> AxisArray: - """Transform input into LFP signal.""" - n_samples = message.data.shape[0] - + def _produce(self, n_samples: int, time_axis: LinearAxis) -> AxisArray: + """Generate LFP signal for this chunk.""" # Generate LFP samples lfp_1d = self._state.lfp_gen.send(n_samples) @@ -267,13 +267,13 @@ def _process(self, message: AxisArray) -> AxisArray: self._state.template, data=lfp_data, axes={ - "time": message.axes["time"], - "ch": self._state.template.axes["ch"], + **self._state.template.axes, + "time": time_axis, }, ) -class DNSSLFPUnit(BaseTransformerUnit[DNSSLFPSettings, AxisArray, AxisArray, DNSSLFPTransformer]): - """Unit for generating DNSS LFP from counter input.""" +class DNSSLFPUnit(BaseClockDrivenUnit[DNSSLFPSettings, DNSSLFPProducer]): + """Unit for generating DNSS LFP from clock input.""" SETTINGS = DNSSLFPSettings diff --git a/src/ezmsg/simbiophys/dnss/spike.py b/src/ezmsg/simbiophys/dnss/spike.py index f44c69e..986c269 100644 --- a/src/ezmsg/simbiophys/dnss/spike.py +++ b/src/ezmsg/simbiophys/dnss/spike.py @@ -2,16 +2,17 @@ from typing import Generator -import ezmsg.core as ez import numpy as np import numpy.typing as npt import sparse from ezmsg.baseproc import ( - BaseStatefulTransformer, - BaseTransformerUnit, + BaseClockDrivenProducer, + BaseClockDrivenUnit, + ClockDrivenSettings, + ClockDrivenState, processor_state, ) -from ezmsg.util.messages.axisarray import AxisArray, replace +from ezmsg.util.messages.axisarray import AxisArray, LinearAxis, replace """ ## Spike Pattern @@ -258,8 +259,11 @@ def spike_event_generator( # ============================================================================= -class DNSSSpikeSettings(ez.Settings): - """Settings for DNSS spike transformer.""" +class DNSSSpikeSettings(ClockDrivenSettings): + """Settings for DNSS spike producer.""" + + fs: float = FS + """Sample rate in Hz. DNSS is fixed at 30kHz.""" n_ch: int = 256 """Number of channels.""" @@ -269,30 +273,28 @@ class DNSSSpikeSettings(ez.Settings): @processor_state -class DNSSSpikeTransformerState: - """State for DNSS spike transformer.""" +class DNSSSpikeState(ClockDrivenState): + """State for DNSS spike producer.""" spike_gen: Generator | None = None template: AxisArray | None = None -class DNSSSpikeTransformer(BaseStatefulTransformer[DNSSSpikeSettings, AxisArray, AxisArray, DNSSSpikeTransformerState]): +class DNSSSpikeProducer(BaseClockDrivenProducer[DNSSSpikeSettings, DNSSSpikeState]): """ - Transforms input AxisArray into DNSS spike signal. + Produces DNSS spike signal synchronized to clock ticks. - Takes timing information from input message and generates spike data as sparse COO arrays. + Each clock tick produces a block of spike data as sparse COO arrays + based on the sample rate (fs) and chunk size (n_time) settings. """ - def _reset_state(self, message: AxisArray) -> None: + def _reset_state(self, time_axis: LinearAxis) -> None: """Initialize the spike generator.""" # Verify sample rate is 30kHz - spike patterns are tied to this rate - time_axis = message.axes["time"] - fs = getattr(time_axis, "fs", 1.0 / time_axis.gain) - expected_gain = 1.0 / FS - if not np.isclose(time_axis.gain, expected_gain, rtol=1e-6): + if not np.isclose(self.settings.fs, FS, rtol=1e-6): raise ValueError( - f"DNSSSpikeTransformer requires fs={FS} Hz (gain={expected_gain:.6e}), " - f"but received fs={fs:.1f} Hz (gain={time_axis.gain:.6e}). " + f"DNSSSpikeProducer requires fs={FS} Hz, " + f"but settings.fs={self.settings.fs:.1f} Hz. " f"Spike patterns cannot be resampled to other rates." ) @@ -311,7 +313,7 @@ def _reset_state(self, message: AxisArray) -> None: ), dims=["time", "ch"], axes={ - "time": message.axes["time"], + "time": time_axis, "ch": AxisArray.CoordinateAxis( data=np.arange(self.settings.n_ch), dims=["ch"], @@ -319,10 +321,8 @@ def _reset_state(self, message: AxisArray) -> None: }, ) - def _process(self, message: AxisArray) -> AxisArray: - """Transform input into spike signal.""" - n_samples = message.data.shape[0] - + def _produce(self, n_samples: int, time_axis: LinearAxis) -> AxisArray: + """Generate spike signal for this chunk.""" # Generate spike events coords, waveform_ids = self._state.spike_gen.send(n_samples) @@ -337,13 +337,13 @@ def _process(self, message: AxisArray) -> AxisArray: self._state.template, data=spike_data, axes={ - "time": message.axes["time"], - "ch": self._state.template.axes["ch"], + **self._state.template.axes, + "time": time_axis, }, ) -class DNSSSpikeUnit(BaseTransformerUnit[DNSSSpikeSettings, AxisArray, AxisArray, DNSSSpikeTransformer]): - """Unit for generating DNSS spikes from counter input.""" +class DNSSSpikeUnit(BaseClockDrivenUnit[DNSSSpikeSettings, DNSSSpikeProducer]): + """Unit for generating DNSS spikes from clock input.""" SETTINGS = DNSSSpikeSettings diff --git a/src/ezmsg/simbiophys/dnss/synth.py b/src/ezmsg/simbiophys/dnss/synth.py index 7000515..46cc33f 100644 --- a/src/ezmsg/simbiophys/dnss/synth.py +++ b/src/ezmsg/simbiophys/dnss/synth.py @@ -2,14 +2,14 @@ import ezmsg.core as ez import numpy as np -from ezmsg.baseproc import Clock, ClockSettings, Counter, CounterSettings +from ezmsg.baseproc import Clock, ClockSettings from ezmsg.event.kernel import ArrayKernel, MultiKernel from ezmsg.event.kernel_insert import SparseKernelInserterSettings, SparseKernelInserterUnit from ezmsg.sigproc.math.add import Add from ezmsg.util.messages.axisarray import AxisArray -from .lfp import DNSSLFPSettings, DNSSLFPUnit -from .spike import FS, DNSSSpikeSettings, DNSSSpikeUnit +from .lfp import DEFAULT_FS, DNSSLFPSettings, DNSSLFPUnit +from .spike import DNSSSpikeSettings, DNSSSpikeUnit from .wfs import wf_orig @@ -61,7 +61,7 @@ class DNSSSynth(ez.Collection): The final output is the sum of spike waveforms and LFP signal. Network flow: - Clock -> Counter -> {SpikeGenerator, LFPGenerator} + Clock -> {SpikeGenerator, LFPGenerator} SpikeGenerator -> KernelInserter -> Add.A LFPGenerator -> Add.B Add -> OUTPUT @@ -74,9 +74,6 @@ class DNSSSynth(ez.Collection): # Clock produces timestamps at the block rate CLOCK = Clock() - # Counter converts timestamps to AxisArray with timing metadata - COUNTER = Counter() - # Spike path: produces sparse events, then inserts waveforms SPIKE = DNSSSpikeUnit() KERNEL_INSERT = SparseKernelInserterUnit() @@ -89,19 +86,13 @@ class DNSSSynth(ez.Collection): def configure(self) -> None: # Calculate dispatch rate for blocks (DNSS is fixed at 30kHz) - dispatch_rate = FS / self.SETTINGS.n_time + dispatch_rate = DEFAULT_FS / self.SETTINGS.n_time self.CLOCK.apply_settings(ClockSettings(dispatch_rate=dispatch_rate)) - self.COUNTER.apply_settings( - CounterSettings( - n_time=self.SETTINGS.n_time, - fs=FS, - ) - ) - self.SPIKE.apply_settings( DNSSSpikeSettings( + n_time=self.SETTINGS.n_time, n_ch=self.SETTINGS.n_ch, mode=self.SETTINGS.mode, ) @@ -118,6 +109,7 @@ def configure(self) -> None: self.LFP.apply_settings( DNSSLFPSettings( + n_time=self.SETTINGS.n_time, n_ch=self.SETTINGS.n_ch, pattern=self.SETTINGS.lfp_pattern, mode=self.SETTINGS.mode, @@ -126,11 +118,9 @@ def configure(self) -> None: def network(self) -> ez.NetworkDefinition: return ( - # Clock drives Counter - (self.CLOCK.OUTPUT_SIGNAL, self.COUNTER.INPUT_SIGNAL), - # Counter fans out to both Spike and LFP generators - (self.COUNTER.OUTPUT_SIGNAL, self.SPIKE.INPUT_SIGNAL), - (self.COUNTER.OUTPUT_SIGNAL, self.LFP.INPUT_SIGNAL), + # Clock drives Spike and LFP generators directly + (self.CLOCK.OUTPUT_SIGNAL, self.SPIKE.INPUT_CLOCK), + (self.CLOCK.OUTPUT_SIGNAL, self.LFP.INPUT_CLOCK), # Spike path: insert waveforms (self.SPIKE.OUTPUT_SIGNAL, self.KERNEL_INSERT.INPUT_SIGNAL), # Combine spike waveforms and LFP diff --git a/src/ezmsg/simbiophys/eeg.py b/src/ezmsg/simbiophys/eeg.py index 6ff9a2d..be7bc8a 100644 --- a/src/ezmsg/simbiophys/eeg.py +++ b/src/ezmsg/simbiophys/eeg.py @@ -1,7 +1,7 @@ """EEG signal synthesis.""" import ezmsg.core as ez -from ezmsg.baseproc import Clock, ClockSettings, Counter, CounterSettings +from ezmsg.baseproc import Clock, ClockSettings from ezmsg.sigproc.math.add import Add from ezmsg.util.messages.axisarray import AxisArray @@ -29,11 +29,11 @@ class EEGSynth(ez.Collection): """ A Collection that generates synthetic EEG signals. - Combines pink noise with alpha oscillations using a diamond flow: - Clock -> Counter -> {Noise, Oscillator} -> Add -> Output + Combines white noise with alpha oscillations using a diamond flow: + Clock -> {Noise, Oscillator} -> Add -> Output Network flow: - Clock -> Counter -> {Noise, Oscillator} + Clock -> {Noise, Oscillator} Noise -> Add.A Oscillator -> Add.B Add -> OUTPUT @@ -44,7 +44,6 @@ class EEGSynth(ez.Collection): OUTPUT_SIGNAL = ez.OutputStream(AxisArray) CLOCK = Clock() - COUNTER = Counter() NOISE = WhiteNoise() OSC = SinGenerator() ADD = Add() @@ -53,15 +52,10 @@ def configure(self) -> None: dispatch_rate = self.SETTINGS.fs / self.SETTINGS.n_time self.CLOCK.apply_settings(ClockSettings(dispatch_rate=dispatch_rate)) - self.COUNTER.apply_settings( - CounterSettings( - fs=self.SETTINGS.fs, - n_time=self.SETTINGS.n_time, - ) - ) - self.NOISE.apply_settings( WhiteNoiseSettings( + fs=self.SETTINGS.fs, + n_time=self.SETTINGS.n_time, n_ch=self.SETTINGS.n_ch, scale=5.0, ) @@ -69,6 +63,8 @@ def configure(self) -> None: self.OSC.apply_settings( SinGeneratorSettings( + fs=self.SETTINGS.fs, + n_time=self.SETTINGS.n_time, n_ch=self.SETTINGS.n_ch, freq=self.SETTINGS.alpha_freq, ) @@ -76,11 +72,9 @@ def configure(self) -> None: def network(self) -> ez.NetworkDefinition: return ( - # Clock drives Counter - (self.CLOCK.OUTPUT_SIGNAL, self.COUNTER.INPUT_SIGNAL), - # Counter fans out to both Noise and Oscillator - (self.COUNTER.OUTPUT_SIGNAL, self.OSC.INPUT_SIGNAL), - (self.COUNTER.OUTPUT_SIGNAL, self.NOISE.INPUT_SIGNAL), + # Clock drives Noise and Oscillator directly + (self.CLOCK.OUTPUT_SIGNAL, self.OSC.INPUT_CLOCK), + (self.CLOCK.OUTPUT_SIGNAL, self.NOISE.INPUT_CLOCK), # Combine outputs (self.OSC.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_A), (self.NOISE.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_B), diff --git a/src/ezmsg/simbiophys/noise.py b/src/ezmsg/simbiophys/noise.py index a9c1140..b1eff63 100644 --- a/src/ezmsg/simbiophys/noise.py +++ b/src/ezmsg/simbiophys/noise.py @@ -1,11 +1,12 @@ """Noise signal generators.""" -import ezmsg.core as ez import numpy as np from ezmsg.baseproc import ( + BaseClockDrivenProducer, + BaseClockDrivenUnit, BaseProcessor, - BaseStatefulTransformer, - BaseTransformerUnit, + ClockDrivenSettings, + ClockDrivenState, CompositeProcessor, processor_state, ) @@ -13,10 +14,10 @@ ButterworthFilterSettings, ButterworthFilterTransformer, ) -from ezmsg.util.messages.axisarray import AxisArray, replace +from ezmsg.util.messages.axisarray import AxisArray, LinearAxis, replace -class WhiteNoiseSettings(ez.Settings): +class WhiteNoiseSettings(ClockDrivenSettings): """Settings for white noise generators.""" n_ch: int = 1 @@ -30,30 +31,28 @@ class WhiteNoiseSettings(ez.Settings): @processor_state -class WhiteNoiseTransformerState: - """State for NoiseTransformer.""" +class WhiteNoiseState(ClockDrivenState): + """State for WhiteNoiseProducer.""" template: AxisArray | None = None -class WhiteNoiseTransformer( - BaseStatefulTransformer[WhiteNoiseSettings, AxisArray, AxisArray, WhiteNoiseTransformerState] -): +class WhiteNoiseProducer(BaseClockDrivenProducer[WhiteNoiseSettings, WhiteNoiseState]): """ - Transforms input AxisArray into white noise signal. + Generates white noise synchronized to clock ticks. - Takes timing information from input message (counter output) and - generates random data with the specified number of channels. + Each clock tick produces a block of Gaussian white noise based on the + sample rate (fs) and chunk size (n_time) settings. """ - def _reset_state(self, message: AxisArray) -> None: + def _reset_state(self, time_axis: LinearAxis) -> None: """Initialize template with channel axis.""" n_ch = self.settings.n_ch self._state.template = AxisArray( data=np.zeros((0, n_ch)), dims=["time", "ch"], axes={ - "time": message.axes["time"], + "time": time_axis, "ch": AxisArray.CoordinateAxis( data=np.arange(n_ch), dims=["ch"], @@ -61,14 +60,13 @@ def _reset_state(self, message: AxisArray) -> None: }, ) - def _process(self, message: AxisArray) -> AxisArray: - n_time = message.data.shape[0] - + def _produce(self, n_samples: int, time_axis: LinearAxis) -> AxisArray: + """Generate white noise for this chunk.""" # Generate random data random_data = np.random.normal( loc=self.settings.loc, scale=self.settings.scale, - size=(n_time, self.settings.n_ch), + size=(n_samples, self.settings.n_ch), ) # Create output using template @@ -76,18 +74,18 @@ def _process(self, message: AxisArray) -> AxisArray: self._state.template, data=random_data, axes={ - "time": message.axes["time"], - "ch": self._state.template.axes["ch"], + **self._state.template.axes, + "time": time_axis, }, ) -class WhiteNoise(BaseTransformerUnit[WhiteNoiseSettings, AxisArray, AxisArray, WhiteNoiseTransformer]): +class WhiteNoise(BaseClockDrivenUnit[WhiteNoiseSettings, WhiteNoiseProducer]): """ - Transforms counter input into white noise signal. + Generates white noise synchronized to clock ticks. - Receives timing from INPUT_SIGNAL (AxisArray from Counter) and outputs - white noise AxisArray. + Receives timing from INPUT_CLOCK (LinearAxis from Clock) and outputs + white noise AxisArray on OUTPUT_SIGNAL. """ SETTINGS = WhiteNoiseSettings @@ -98,23 +96,26 @@ class PinkNoiseSettings(WhiteNoiseSettings): cutoff: float = 300.0 """ - Highpass Cutoff frequency (Hz). Lowpass corner of first order Butterworth filter. + Lowpass cutoff frequency (Hz) for the first-order Butterworth filter + that creates the 1/f characteristic. """ -class PinkNoiseTransformer(CompositeProcessor[PinkNoiseSettings, AxisArray, AxisArray]): +class PinkNoiseProducer(CompositeProcessor[PinkNoiseSettings, LinearAxis, AxisArray]): """ - Transforms input AxisArray into pink (1/f) noise signal. + Generates pink (1/f) noise synchronized to clock ticks. Pink noise is generated by filtering white noise with a first-order - lowpass Butterworth filter with cutoff at fs * 0.01 Hz. + lowpass Butterworth filter. """ @staticmethod def _initialize_processors(settings: PinkNoiseSettings) -> dict[str, BaseProcessor]: return { - "white_noise": WhiteNoiseTransformer( + "white_noise": WhiteNoiseProducer( WhiteNoiseSettings( + fs=settings.fs, + n_time=settings.n_time, n_ch=settings.n_ch, loc=settings.loc, scale=settings.scale, @@ -126,12 +127,12 @@ def _initialize_processors(settings: PinkNoiseSettings) -> dict[str, BaseProcess } -class PinkNoise(BaseTransformerUnit[PinkNoiseSettings, AxisArray, AxisArray, PinkNoiseTransformer]): +class PinkNoise(BaseClockDrivenUnit[PinkNoiseSettings, PinkNoiseProducer]): """ - Transforms counter input into pink (1/f) noise signal. + Generates pink (1/f) noise synchronized to clock ticks. - Receives timing from INPUT_SIGNAL (AxisArray from Counter) and outputs - pink noise AxisArray. + Receives timing from INPUT_CLOCK (LinearAxis from Clock) and outputs + pink noise AxisArray on OUTPUT_SIGNAL. """ SETTINGS = PinkNoiseSettings diff --git a/src/ezmsg/simbiophys/oscillator.py b/src/ezmsg/simbiophys/oscillator.py index 475ab69..0131238 100644 --- a/src/ezmsg/simbiophys/oscillator.py +++ b/src/ezmsg/simbiophys/oscillator.py @@ -1,17 +1,122 @@ """Oscillator/sinusoidal signal generators.""" -import ezmsg.core as ez import numpy as np import numpy.typing as npt from ezmsg.baseproc import ( - BaseStatefulTransformer, - BaseTransformerUnit, + BaseClockDrivenProducer, + BaseClockDrivenUnit, + ClockDrivenSettings, + ClockDrivenState, processor_state, ) -from ezmsg.util.messages.axisarray import AxisArray, replace +from ezmsg.util.messages.axisarray import AxisArray, LinearAxis, replace -class SinGeneratorSettings(ez.Settings): +class SpiralGeneratorSettings(ClockDrivenSettings): + """Settings for :obj:`SpiralGenerator`. + + Generates 2D position (x, y) following a spiral pattern where both + the radius and angle change over time. + + The parametric equations are: + r(t) = r_mean + r_amp * sin(2*π*radial_freq*t + radial_phase) + θ(t) = 2*π*angular_freq*t + angular_phase + x(t) = r(t) * cos(θ(t)) + y(t) = r(t) * sin(θ(t)) + """ + + r_mean: float = 150.0 + """Mean radius of the spiral.""" + + r_amp: float = 50.0 + """Amplitude of the radial oscillation.""" + + radial_freq: float = 0.1 + """Frequency of the radial oscillation in Hz.""" + + radial_phase: float = 0.0 + """Initial phase of the radial oscillation in radians.""" + + angular_freq: float = 0.25 + """Frequency of the angular rotation in Hz.""" + + angular_phase: float = 0.0 + """Initial angular phase in radians.""" + + +@processor_state +class SpiralGeneratorState(ClockDrivenState): + """State for SpiralGenerator.""" + + template: AxisArray | None = None + + +class SpiralProducer(BaseClockDrivenProducer[SpiralGeneratorSettings, SpiralGeneratorState]): + """ + Generates spiral motion synchronized to clock ticks. + + Each clock tick produces a block of 2D position data (x, y) following + a spiral pattern where both radius and angle change over time. + """ + + def _reset_state(self, time_axis: LinearAxis) -> None: + """Initialize template.""" + self._state.template = AxisArray( + data=np.zeros((0, 2)), + dims=["time", "ch"], + axes={ + "time": time_axis, + "ch": AxisArray.CoordinateAxis( + data=np.array(["x", "y"]), + dims=["ch"], + ), + }, + ) + + def _produce(self, n_samples: int, time_axis: LinearAxis) -> AxisArray: + """Generate spiral motion for this chunk.""" + t = (np.arange(n_samples) + self._state.counter) * time_axis.gain + + # Radial component: oscillates between r_mean - r_amp and r_mean + r_amp + r = self.settings.r_mean + self.settings.r_amp * np.sin( + 2.0 * np.pi * self.settings.radial_freq * t + self.settings.radial_phase + ) + + # Angular component: rotates at angular_freq + theta = 2.0 * np.pi * self.settings.angular_freq * t + self.settings.angular_phase + + # Convert to Cartesian + x = r * np.cos(theta) + y = r * np.sin(theta) + + data = np.column_stack([x, y]) + + return replace( + self._state.template, + data=data, + axes={ + **self._state.template.axes, + "time": time_axis, + }, + ) + + +class SpiralGenerator(BaseClockDrivenUnit[SpiralGeneratorSettings, SpiralProducer]): + """ + Generates 2D spiral motion synchronized to clock ticks. + + Receives timing from INPUT_CLOCK (LinearAxis from Clock) and outputs + 2D position AxisArray (x, y) on OUTPUT_SIGNAL. + + The spiral pattern has both radius and angle varying over time: + - Radius oscillates sinusoidally (breathing in/out) + - Angle increases linearly (rotation) + """ + + SETTINGS = SpiralGeneratorSettings + + +class SinGeneratorSettings(ClockDrivenSettings): """Settings for :obj:`SinGenerator`.""" n_ch: int = 1 @@ -28,8 +133,8 @@ class SinGeneratorSettings(ez.Settings): @processor_state -class SinTransformerState: - """State for SinTransformer.""" +class SinGeneratorState(ClockDrivenState): + """State for SinGenerator.""" template: AxisArray | None = None # Pre-computed arrays for efficient processing, shape (1, 1) or (1, n_ch) @@ -38,15 +143,15 @@ class SinTransformerState: phase: np.ndarray | None = None -class SinTransformer(BaseStatefulTransformer[SinGeneratorSettings, AxisArray, AxisArray, SinTransformerState]): +class SinProducer(BaseClockDrivenProducer[SinGeneratorSettings, SinGeneratorState]): """ - Transforms counter values into sinusoidal waveforms. + Generates sinusoidal waveforms synchronized to clock ticks. - Takes AxisArray with integer counter values and generates sinusoidal - output based on the time axis sample rate. + Each clock tick produces a block of sinusoidal data based on the + sample rate (fs) and chunk size (n_time) settings. """ - def _reset_state(self, message: AxisArray) -> None: + def _reset_state(self, time_axis: LinearAxis) -> None: """Initialize template and pre-compute parameter arrays.""" n_ch = self.settings.n_ch @@ -55,7 +160,7 @@ def _reset_state(self, message: AxisArray) -> None: data=np.zeros((0, n_ch)), dims=["time", "ch"], axes={ - "time": message.axes["time"], + "time": time_axis, "ch": AxisArray.CoordinateAxis( data=np.arange(n_ch), dims=["ch"], @@ -85,11 +190,11 @@ def _reset_state(self, message: AxisArray) -> None: self._state.amp = amp self._state.phase = phase - def _process(self, message: AxisArray) -> AxisArray: - """Transform input counter values into sinusoidal waveform.""" + def _produce(self, n_samples: int, time_axis: LinearAxis) -> AxisArray: + """Generate sinusoidal waveform for this chunk.""" # Calculate sinusoid: amp * sin(ang_freq*t + phase) # t shape: (n_time,) -> (n_time, 1) for broadcasting with (1, n_ch) - t = message.data[:, np.newaxis] * message.axes["time"].gain + t = (np.arange(n_samples) + self._state.counter)[:, np.newaxis] * time_axis.gain sin_data = self._state.amp * np.sin(self._state.ang_freq * t + self._state.phase) # Tile if all params were scalar but n_ch > 1 @@ -100,18 +205,18 @@ def _process(self, message: AxisArray) -> AxisArray: self._state.template, data=sin_data, axes={ - "time": message.axes["time"], - "ch": self._state.template.axes["ch"], + **self._state.template.axes, + "time": time_axis, }, ) -class SinGenerator(BaseTransformerUnit[SinGeneratorSettings, AxisArray, AxisArray, SinTransformer]): +class SinGenerator(BaseClockDrivenUnit[SinGeneratorSettings, SinProducer]): """ - Transforms counter input into sinusoidal waveform. + Generates sinusoidal waveforms synchronized to clock ticks. - Receives timing from INPUT_SIGNAL (AxisArray from Counter) and outputs - sinusoidal AxisArray. + Receives timing from INPUT_CLOCK (LinearAxis from Clock) and outputs + sinusoidal AxisArray on OUTPUT_SIGNAL. """ SETTINGS = SinGeneratorSettings diff --git a/src/ezmsg/simbiophys/system/velocity2ecephys.py b/src/ezmsg/simbiophys/system/velocity2ecephys.py index a25a8c1..c595265 100644 --- a/src/ezmsg/simbiophys/system/velocity2ecephys.py +++ b/src/ezmsg/simbiophys/system/velocity2ecephys.py @@ -5,9 +5,12 @@ background activity. Pipeline: - velocity (x,y) --+--> Velocity2Spike --> spikes --| - | +--> Add --> ecephys - +--> Velocity2LFP ----> lfp -----| + velocity (x,y) -> CART2POL --+--> Velocity2Spike --> spikes --| + | +--> Add --> ecephys + +--> Velocity2LFP ----> lfp -----| + +The coordinate transformation from Cartesian to polar is done once at the +input, then shared by both spike and LFP encoding branches. This is the top-level system for velocity-encoded neural simulation. Use this when you need full ecephys-like output suitable for testing BCI decoders. @@ -18,6 +21,7 @@ """ import ezmsg.core as ez +from ezmsg.sigproc.coordinatespaces import CoordinateMode, CoordinateSpaces, CoordinateSpacesSettings from ezmsg.sigproc.math.add import Add from ezmsg.util.messages.axisarray import AxisArray @@ -71,12 +75,14 @@ class VelocityEncoder(ez.Collection): # Velocity inputs (via mouse / gamepad system, or via task parsing system) INPUT_SIGNAL = ez.InputStream(AxisArray) + COORDS = CoordinateSpaces() # Cartesian to polar (done once, shared by both branches) SPIKES = Velocity2Spike() LFP = Velocity2LFP() ADD = Add() # Add colored noise and waveforms OUTPUT_SIGNAL = ez.OutputStream(AxisArray) def configure(self) -> None: + self.COORDS.apply_settings(CoordinateSpacesSettings(mode=CoordinateMode.CART2POL, axis="ch")) self.SPIKES.apply_settings( Velocity2SpikeSettings( output_fs=self.SETTINGS.output_fs, output_ch=self.SETTINGS.output_ch, seed=self.SETTINGS.seed @@ -90,9 +96,10 @@ def configure(self) -> None: def network(self) -> ez.NetworkDefinition: return ( - (self.INPUT_SIGNAL, self.SPIKES.INPUT_SIGNAL), + (self.INPUT_SIGNAL, self.COORDS.INPUT_SIGNAL), + (self.COORDS.OUTPUT_SIGNAL, self.SPIKES.INPUT_SIGNAL), (self.SPIKES.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_A), - (self.INPUT_SIGNAL, self.LFP.INPUT_SIGNAL), + (self.COORDS.OUTPUT_SIGNAL, self.LFP.INPUT_SIGNAL), (self.LFP.OUTPUT_SIGNAL, self.ADD.INPUT_SIGNAL_B), (self.ADD.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL), ) diff --git a/src/ezmsg/simbiophys/system/velocity2lfp.py b/src/ezmsg/simbiophys/system/velocity2lfp.py index 6387ac8..58196cd 100644 --- a/src/ezmsg/simbiophys/system/velocity2lfp.py +++ b/src/ezmsg/simbiophys/system/velocity2lfp.py @@ -1,14 +1,21 @@ -"""Convert 2D cursor velocity to simulated LFP-like colored noise. +"""Convert polar velocity coordinates to simulated LFP-like colored noise. -This module provides a system that encodes cursor velocity into the spectral -properties of colored (1/f^beta) noise, producing LFP-like signals. +This module provides a system that encodes velocity (in polar coordinates) into +the spectral properties of colored (1/f^beta) noise, producing LFP-like signals. Pipeline: - velocity (x,y) -> polar coords -> scale to beta range -> colored noise -> mix to channels + polar coords (magnitude, angle) -> cosine encoder (beta values) -> clip + -> colored noise -> mix to channels -The velocity magnitude modulates one noise source's spectral exponent, and the -velocity angle modulates another. These two sources are then mixed across -output channels using a spatial mixing matrix. +The velocity is encoded using a cosine tuning model where multiple noise +sources have different preferred directions. Each source's spectral exponent +(beta) is modulated by the velocity direction and magnitude. These sources +are then mixed across output channels using a spatial mixing matrix. + +Note: + This system expects polar coordinates as input. Use CoordinateSpaces with + mode=CART2POL upstream to convert Cartesian velocity (vx, vy) to polar + coordinates (magnitude, angle). See Also: :mod:`ezmsg.simbiophys.system.velocity2spike`: Velocity to spike encoding. @@ -18,11 +25,10 @@ import ezmsg.core as ez import numpy as np from ezmsg.sigproc.affinetransform import AffineTransform, AffineTransformSettings -from ezmsg.sigproc.coordinatespaces import CoordinateMode, CoordinateSpaces, CoordinateSpacesSettings -from ezmsg.sigproc.linear import LinearTransform, LinearTransformSettings from ezmsg.sigproc.math.clip import Clip, ClipSettings from ezmsg.util.messages.axisarray import AxisArray +from ..cosine_encoder import CosineEncoderSettings, CosineEncoderUnit from ..dynamic_colored_noise import DynamicColoredNoiseSettings, DynamicColoredNoiseUnit @@ -35,29 +41,35 @@ class Velocity2LFPSettings(ez.Settings): output_ch: int = 256 """Number of output channels (simulated electrodes).""" + n_lfp_sources: int = 8 + """Number of cosine-encoded LFP sources. Each source has a different + preferred direction and generates colored noise with velocity-modulated + spectral exponent.""" + + max_velocity: float = 315.0 + seed: int = 6767 - """Random seed for reproducible mixing matrix.""" + """Random seed for reproducible preferred directions and mixing matrix.""" class Velocity2LFP(ez.Collection): - """Encode cursor velocity into LFP-like colored noise. - - This system converts 2D cursor velocity into multi-channel LFP-like signals: - - 1. **Coordinate transform**: Converts Cartesian velocity (x, y) to polar - coordinates (magnitude, angle). - 2. **Scale to beta**: Maps velocity magnitude (0-314 px/s) and angle (0-2pi) - to spectral exponent beta (0.5-2.0) for colored noise generation. - 3. **Clip**: Ensures beta values stay within valid range [0, 2]. - 4. **Colored noise**: Generates 1/f^beta noise where beta is dynamically - modulated by the scaled velocity. Two noise sources are generated: - one modulated by magnitude, one by angle. - 5. **Spatial mixing**: Projects the 2 noise sources onto output_ch channels + """Encode velocity (polar coordinates) into LFP-like colored noise. + + This system converts polar velocity coordinates into multi-channel LFP-like signals: + + 1. **Cosine encoder**: Each of n_lfp_sources has a different preferred + direction. The spectral exponent beta (0-2) is modulated by the cosine + of the angle between velocity and preferred direction, scaled by speed. + 2. **Clip**: Ensures beta values stay within valid range [0, 2]. + 3. **Colored noise**: Generates 1/f^beta noise where beta is dynamically + modulated per source. + 4. **Spatial mixing**: Projects the n_lfp_sources onto output_ch channels using a sinusoidal mixing matrix with random perturbations. Input: - AxisArray with shape (N, 2) containing cursor velocity in pixels/second. - Dimension 0 is time, dimension 1 is [vx, vy]. + AxisArray with shape (N, 2) containing polar velocity coordinates. + Dimension 0 is time, dimension 1 is [magnitude, angle]. + Use CoordinateSpaces(mode=CART2POL) upstream if starting from (vx, vy). Output: AxisArray with shape (M, output_ch) containing LFP-like colored noise @@ -66,51 +78,67 @@ class Velocity2LFP(ez.Collection): SETTINGS = Velocity2LFPSettings - # Velocity inputs (via mouse / gamepad, or via task parsing system) + # Polar velocity inputs (magnitude, angle) INPUT_SIGNAL = ez.InputStream(AxisArray) - COORDS = CoordinateSpaces() - ALPHA_EXP = LinearTransform() - CLIP_ALPHA = Clip() + BETA_ENCODER = CosineEncoderUnit() + CLIP_BETA = Clip() PINK_NOISE = DynamicColoredNoiseUnit() - MIX_NOISE = AffineTransform() # Project 2 colored noise sources to n_chans sensors + MIX_NOISE = AffineTransform() # Project n_lfp_sources to output_ch sensors OUTPUT_SIGNAL = ez.OutputStream(AxisArray) def configure(self) -> None: - self.COORDS.apply_settings(CoordinateSpacesSettings(mode=CoordinateMode.CART2POL, axis="ch")) - self.ALPHA_EXP.apply_settings( - LinearTransformSettings( - scale=[1.5 / 314, 1.5 / (2 * np.pi)], - offset=[0.5, 0.5], - axis="ch", + # Input is polar coords: [magnitude, angle] + # magnitude ranges from 0 to ~max_velocity px/s, angle from -pi to +pi + + # Configure cosine encoder to output beta values in range [0, 2] + # baseline=1.0 (middle of range), modulation=1/315 so at max velocity we get full range + self.BETA_ENCODER.apply_settings( + CosineEncoderSettings( + output_ch=self.SETTINGS.n_lfp_sources, + baseline=1.0, + modulation=1.0 / self.SETTINGS.max_velocity, + seed=self.SETTINGS.seed, ) ) - self.CLIP_ALPHA.apply_settings(ClipSettings(min=0.0, max=2.0)) + + self.CLIP_BETA.apply_settings(ClipSettings(min=0.0, max=2.0)) + self.PINK_NOISE.apply_settings( DynamicColoredNoiseSettings( output_fs=self.SETTINGS.output_fs, n_poles=5, smoothing_tau=0.01, initial_beta=1.0, - scale=1.0, + scale=20.0, seed=self.SETTINGS.seed, ) ) + + # Create mixing matrix: n_lfp_sources -> output_ch + # Use sinusoids at different frequencies for spatial patterns rng = np.random.default_rng(self.SETTINGS.seed) ch_idx = np.arange(self.SETTINGS.output_ch) - weights = np.array( - [ - np.sin(2 * np.pi * ch_idx / self.SETTINGS.output_ch), # radius source - np.cos(2 * np.pi * ch_idx / self.SETTINGS.output_ch), # angle (phi) source - ] - ) + 0.3 * rng.standard_normal((2, self.SETTINGS.output_ch)) + n_sources = self.SETTINGS.n_lfp_sources + + # Each source gets a sinusoidal spatial pattern with different frequency + # Plus random perturbations for more realistic mixing + weights = np.zeros((n_sources, self.SETTINGS.output_ch)) + for i in range(n_sources): + # Different spatial frequency for each source + freq = (i + 1) / n_sources + phase = 2 * np.pi * i / n_sources + weights[i, :] = np.sin(2 * np.pi * freq * ch_idx / self.SETTINGS.output_ch + phase) + + # Add random perturbations + weights += 0.3 * rng.standard_normal((n_sources, self.SETTINGS.output_ch)) + self.MIX_NOISE.apply_settings(AffineTransformSettings(weights=weights, axis="ch")) def network(self) -> ez.NetworkDefinition: return ( - (self.INPUT_SIGNAL, self.COORDS.INPUT_SIGNAL), - (self.COORDS.OUTPUT_SIGNAL, self.ALPHA_EXP.INPUT_SIGNAL), - (self.ALPHA_EXP.OUTPUT_SIGNAL, self.CLIP_ALPHA.INPUT_SIGNAL), - (self.CLIP_ALPHA.OUTPUT_SIGNAL, self.PINK_NOISE.INPUT_SIGNAL), + (self.INPUT_SIGNAL, self.BETA_ENCODER.INPUT_SIGNAL), + (self.BETA_ENCODER.OUTPUT_SIGNAL, self.CLIP_BETA.INPUT_SIGNAL), + (self.CLIP_BETA.OUTPUT_SIGNAL, self.PINK_NOISE.INPUT_SIGNAL), (self.PINK_NOISE.OUTPUT_SIGNAL, self.MIX_NOISE.INPUT_SIGNAL), (self.MIX_NOISE.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL), ) diff --git a/src/ezmsg/simbiophys/system/velocity2spike.py b/src/ezmsg/simbiophys/system/velocity2spike.py index 04f7a40..b5a434e 100644 --- a/src/ezmsg/simbiophys/system/velocity2spike.py +++ b/src/ezmsg/simbiophys/system/velocity2spike.py @@ -1,11 +1,16 @@ -"""Convert 2D cursor velocity to simulated spike waveforms. +"""Convert polar velocity coordinates to simulated spike waveforms. -This module provides a system that encodes cursor velocity into spike activity -using a cosine tuning model, then generates spike events and inserts realistic -waveforms. +This module provides a system that encodes velocity (in polar coordinates) into +spike activity using a cosine tuning model, then generates spike events and +inserts realistic waveforms. Pipeline: - velocity (x,y) -> polar coords -> cosine tuning -> Poisson events -> waveforms + polar coords (magnitude, angle) -> cosine encoder -> clip -> Poisson events -> waveforms + +Note: + This system expects polar coordinates as input. Use CoordinateSpaces with + mode=CART2POL upstream to convert Cartesian velocity (vx, vy) to polar + coordinates (magnitude, angle). See Also: :mod:`ezmsg.simbiophys.system.velocity2lfp`: Velocity to LFP encoding. @@ -17,10 +22,10 @@ from ezmsg.event.kernel import ArrayKernel, MultiKernel from ezmsg.event.kernel_insert import SparseKernelInserterSettings, SparseKernelInserterUnit from ezmsg.event.poissonevents import PoissonEventSettings, PoissonEventUnit -from ezmsg.sigproc.coordinatespaces import CoordinateMode, CoordinateSpaces, CoordinateSpacesSettings +from ezmsg.sigproc.math.clip import Clip, ClipSettings from ezmsg.util.messages.axisarray import AxisArray -from ..cosine_tuning import CosineTuningSettings, CosineTuningUnit +from ..cosine_encoder import CosineEncoderSettings, CosineEncoderUnit from ..dnss.wfs import wf_orig @@ -33,27 +38,36 @@ class Velocity2SpikeSettings(ez.Settings): output_ch: int = 256 """Number of output channels (simulated electrodes).""" + baseline_rate: float = 10.0 + """Baseline firing rate in Hz.""" + + modulation_depth: float = 20.0 / 314.0 + """Directional modulation depth in Hz per (pixel/second). + At max velocity (~314 px/s), this gives ~20 Hz modulation.""" + + min_rate: float = 0.0 + """Minimum firing rate (Hz). Rates are clipped to this value.""" + seed: int = 6767 """Random seed for reproducible preferred directions and waveform selection.""" class Velocity2Spike(ez.Collection): - """Encode cursor velocity into simulated spike waveforms. + """Encode velocity (polar coordinates) into simulated spike waveforms. - This system converts 2D cursor velocity into multi-channel spike activity: + This system converts polar velocity coordinates into multi-channel spike activity: - 1. **Coordinate transform**: Converts Cartesian velocity (x, y) to polar - coordinates (magnitude, angle). - 2. **Cosine tuning**: Each channel has a preferred direction; firing rate + 1. **Cosine tuning**: Each channel has a preferred direction; firing rate is modulated by the cosine of the angle between velocity and preferred direction, scaled by velocity magnitude. - 3. **Poisson spike generation**: Converts firing rates to discrete spike + 2. **Poisson spike generation**: Converts firing rates to discrete spike events using an inhomogeneous Poisson process. - 4. **Waveform insertion**: Inserts realistic spike waveforms at event times. + 3. **Waveform insertion**: Inserts realistic spike waveforms at event times. Input: - AxisArray with shape (N, 2) containing cursor velocity in pixels/second. - Dimension 0 is time, dimension 1 is [vx, vy]. + AxisArray with shape (N, 2) containing polar velocity coordinates. + Dimension 0 is time, dimension 1 is [magnitude, angle]. + Use CoordinateSpaces(mode=CART2POL) upstream if starting from (vx, vy). Output: AxisArray with shape (M, output_ch) containing spike waveforms at @@ -62,22 +76,24 @@ class Velocity2Spike(ez.Collection): SETTINGS = Velocity2SpikeSettings - # Velocity inputs (via mouse / gamepad system, or via task parsing system) + # Polar velocity inputs (magnitude, angle) INPUT_SIGNAL = ez.InputStream(AxisArray) - COORDS = CoordinateSpaces() - RATE_ENCODER = CosineTuningUnit() + RATE_ENCODER = CosineEncoderUnit() + CLIP_RATE = Clip() SPIKE_EVENT = PoissonEventUnit() WAVEFORMS = SparseKernelInserterUnit() OUTPUT_SIGNAL = ez.OutputStream(AxisArray) def configure(self) -> None: - self.COORDS.apply_settings(CoordinateSpacesSettings(mode=CoordinateMode.CART2POL, axis="ch")) self.RATE_ENCODER.apply_settings( - CosineTuningSettings( - n_units=self.SETTINGS.output_ch, + CosineEncoderSettings( + output_ch=self.SETTINGS.output_ch, + baseline=self.SETTINGS.baseline_rate, + modulation=self.SETTINGS.modulation_depth, seed=self.SETTINGS.seed, ) ) + self.CLIP_RATE.apply_settings(ClipSettings(min=self.SETTINGS.min_rate)) self.SPIKE_EVENT.apply_settings( PoissonEventSettings( output_fs=self.SETTINGS.output_fs, @@ -92,9 +108,9 @@ def configure(self) -> None: def network(self) -> ez.NetworkDefinition: return ( - (self.INPUT_SIGNAL, self.COORDS.INPUT_SIGNAL), - (self.COORDS.OUTPUT_SIGNAL, self.RATE_ENCODER.INPUT_SIGNAL), - (self.RATE_ENCODER.OUTPUT_SIGNAL, self.SPIKE_EVENT.INPUT_SIGNAL), + (self.INPUT_SIGNAL, self.RATE_ENCODER.INPUT_SIGNAL), + (self.RATE_ENCODER.OUTPUT_SIGNAL, self.CLIP_RATE.INPUT_SIGNAL), + (self.CLIP_RATE.OUTPUT_SIGNAL, self.SPIKE_EVENT.INPUT_SIGNAL), (self.SPIKE_EVENT.OUTPUT_SIGNAL, self.WAVEFORMS.INPUT_SIGNAL), (self.WAVEFORMS.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL), ) diff --git a/tests/integration/test_dnss_system.py b/tests/integration/test_dnss_system.py index df02537..f782268 100644 --- a/tests/integration/test_dnss_system.py +++ b/tests/integration/test_dnss_system.py @@ -10,7 +10,7 @@ from ezmsg.util.messages.axisarray import AxisArray from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTotalSettings -from ezmsg.simbiophys import Clock, ClockSettings, Counter, CounterSettings +from ezmsg.simbiophys import Clock, ClockSettings from ezmsg.simbiophys.dnss import ( DEFAULT_FS, LFP_GAINS, @@ -26,34 +26,30 @@ class DNSSLFPTestSystemSettings(ez.Settings): clock_settings: ClockSettings - counter_settings: CounterSettings lfp_settings: DNSSLFPSettings log_settings: MessageLoggerSettings term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings) class DNSSLFPTestSystem(ez.Collection): - """Test system for DNSS LFP: Clock -> Counter -> LFP.""" + """Test system for DNSS LFP: Clock -> LFP.""" SETTINGS = DNSSLFPTestSystemSettings CLOCK = Clock() - COUNTER = Counter() LFP = DNSSLFPUnit() LOG = MessageLogger() TERM = TerminateOnTotal() def configure(self) -> None: self.CLOCK.apply_settings(self.SETTINGS.clock_settings) - self.COUNTER.apply_settings(self.SETTINGS.counter_settings) self.LFP.apply_settings(self.SETTINGS.lfp_settings) self.LOG.apply_settings(self.SETTINGS.log_settings) self.TERM.apply_settings(self.SETTINGS.term_settings) def network(self) -> ez.NetworkDefinition: return ( - (self.CLOCK.OUTPUT_SIGNAL, self.COUNTER.INPUT_SIGNAL), - (self.COUNTER.OUTPUT_SIGNAL, self.LFP.INPUT_SIGNAL), + (self.CLOCK.OUTPUT_SIGNAL, self.LFP.INPUT_CLOCK), (self.LFP.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE), (self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE), ) @@ -72,8 +68,8 @@ def test_dnss_lfp_unit(test_name: str | None = None): settings = DNSSLFPTestSystemSettings( clock_settings=ClockSettings(dispatch_rate=dispatch_rate), - counter_settings=CounterSettings(fs=fs, n_time=n_time), lfp_settings=DNSSLFPSettings( + n_time=n_time, n_ch=n_ch, pattern="spike", mode="hdmi", @@ -126,8 +122,8 @@ def test_dnss_lfp_unit_other_pattern(test_name: str | None = None): settings = DNSSLFPTestSystemSettings( clock_settings=ClockSettings(dispatch_rate=dispatch_rate), - counter_settings=CounterSettings(fs=fs, n_time=n_time), lfp_settings=DNSSLFPSettings( + n_time=n_time, n_ch=n_ch, pattern="other", mode="hdmi", @@ -153,34 +149,30 @@ def test_dnss_lfp_unit_other_pattern(test_name: str | None = None): class DNSSSpikeTestSystemSettings(ez.Settings): clock_settings: ClockSettings - counter_settings: CounterSettings spike_settings: DNSSSpikeSettings log_settings: MessageLoggerSettings term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings) class DNSSSpikeTestSystem(ez.Collection): - """Test system for DNSS Spike: Clock -> Counter -> Spike.""" + """Test system for DNSS Spike: Clock -> Spike.""" SETTINGS = DNSSSpikeTestSystemSettings CLOCK = Clock() - COUNTER = Counter() SPIKE = DNSSSpikeUnit() LOG = MessageLogger() TERM = TerminateOnTotal() def configure(self) -> None: self.CLOCK.apply_settings(self.SETTINGS.clock_settings) - self.COUNTER.apply_settings(self.SETTINGS.counter_settings) self.SPIKE.apply_settings(self.SETTINGS.spike_settings) self.LOG.apply_settings(self.SETTINGS.log_settings) self.TERM.apply_settings(self.SETTINGS.term_settings) def network(self) -> ez.NetworkDefinition: return ( - (self.CLOCK.OUTPUT_SIGNAL, self.COUNTER.INPUT_SIGNAL), - (self.COUNTER.OUTPUT_SIGNAL, self.SPIKE.INPUT_SIGNAL), + (self.CLOCK.OUTPUT_SIGNAL, self.SPIKE.INPUT_CLOCK), (self.SPIKE.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE), (self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE), ) @@ -199,8 +191,8 @@ def test_dnss_spike_unit(test_name: str | None = None): settings = DNSSSpikeTestSystemSettings( clock_settings=ClockSettings(dispatch_rate=dispatch_rate), - counter_settings=CounterSettings(fs=fs, n_time=n_time), spike_settings=DNSSSpikeSettings( + n_time=n_time, n_ch=n_ch, mode="hdmi", ), @@ -236,7 +228,6 @@ def test_dnss_spike_unit(test_name: str | None = None): def test_dnss_spike_unit_burst_period(test_name: str | None = None): """Test DNSSSpikeUnit during burst period (more spikes).""" - fs = DEFAULT_FS n_time = 3000 # 100ms blocks n_ch = 4 # Run long enough to hit burst period (starts at sample 270000 = 9 seconds) @@ -249,8 +240,8 @@ def test_dnss_spike_unit_burst_period(test_name: str | None = None): settings = DNSSSpikeTestSystemSettings( clock_settings=ClockSettings(dispatch_rate=dispatch_rate), - counter_settings=CounterSettings(fs=fs, n_time=n_time), spike_settings=DNSSSpikeSettings( + n_time=n_time, n_ch=n_ch, mode="ideal", ), diff --git a/tests/integration/test_synth_system.py b/tests/integration/test_synth_system.py index e1da9f8..769782d 100644 --- a/tests/integration/test_synth_system.py +++ b/tests/integration/test_synth_system.py @@ -1,239 +1,21 @@ """Integration tests for ezmsg.simbiophys signal generator systems.""" -import math import os from dataclasses import field import ezmsg.core as ez -import numpy as np -import pytest from ezmsg.util.messagecodec import message_log from ezmsg.util.messagelogger import MessageLogger, MessageLoggerSettings from ezmsg.util.messages.axisarray import AxisArray from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTotalSettings from ezmsg.simbiophys import ( - Clock, - ClockSettings, - Counter, - CounterSettings, EEGSynth, EEGSynthSettings, ) from tests.helpers.util import get_test_fn -class ClockTestSystemSettings(ez.Settings): - clock_settings: ClockSettings - log_settings: MessageLoggerSettings - term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings) - - -class ClockTestSystem(ez.Collection): - SETTINGS = ClockTestSystemSettings - - CLOCK = Clock() - LOG = MessageLogger() - TERM = TerminateOnTotal() - - def configure(self) -> None: - self.CLOCK.apply_settings(self.SETTINGS.clock_settings) - self.LOG.apply_settings(self.SETTINGS.log_settings) - self.TERM.apply_settings(self.SETTINGS.term_settings) - - def network(self) -> ez.NetworkDefinition: - return ( - (self.CLOCK.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE), - (self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE), - ) - - -@pytest.mark.parametrize("dispatch_rate", [math.inf, 2.0, 20.0]) -def test_clock_system( - dispatch_rate: float, - test_name: str | None = None, -): - run_time = 1.0 - n_target = 100 if math.isinf(dispatch_rate) else int(np.ceil(dispatch_rate * run_time)) - test_filename = get_test_fn(test_name) - ez.logger.info(test_filename) - settings = ClockTestSystemSettings( - clock_settings=ClockSettings(dispatch_rate=dispatch_rate), - log_settings=MessageLoggerSettings(output=test_filename), - term_settings=TerminateOnTotalSettings(total=n_target), - ) - system = ClockTestSystem(settings) - ez.run(SYSTEM=system) - - # Collect result - messages = list(message_log(test_filename)) - os.remove(test_filename) - - # Clock produces LinearAxis with gain and offset - assert all(isinstance(m, AxisArray.LinearAxis) for m in messages) - assert len(messages) >= n_target - - -class CounterTestSystemSettings(ez.Settings): - clock_settings: ClockSettings - counter_settings: CounterSettings - log_settings: MessageLoggerSettings - term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings) - - -class CounterTestSystem(ez.Collection): - """Counter must be driven by Clock in the new architecture.""" - - SETTINGS = CounterTestSystemSettings - - CLOCK = Clock() - COUNTER = Counter() - LOG = MessageLogger() - TERM = TerminateOnTotal() - - def configure(self) -> None: - self.CLOCK.apply_settings(self.SETTINGS.clock_settings) - self.COUNTER.apply_settings(self.SETTINGS.counter_settings) - self.LOG.apply_settings(self.SETTINGS.log_settings) - self.TERM.apply_settings(self.SETTINGS.term_settings) - - def network(self) -> ez.NetworkDefinition: - return ( - (self.CLOCK.OUTPUT_SIGNAL, self.COUNTER.INPUT_SIGNAL), - (self.COUNTER.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE), - (self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE), - ) - - -@pytest.mark.parametrize( - "n_time, fs, dispatch_rate, mod", - [ - (1, 10.0, math.inf, None), # AFAP mode - (20, 1000.0, 50.0, None), # Realtime mode (50 Hz dispatch = 20 samples/tick @ 1000 Hz) - (1, 1000.0, 100.0, 2**3), # 100 Hz dispatch with mod - (10, 10.0, 10.0, 2**3), # 10 Hz dispatch with mod - ], -) -def test_counter_system( - n_time: int, - fs: float, - dispatch_rate: float, - mod: int | None, - test_name: str | None = None, -): - target_dur = 2.6 # 2.6 seconds per test - if math.isinf(dispatch_rate): - # AFAP mode - runs as fast as possible - target_messages = 100 # Fixed target for AFAP - else: - target_messages = int(target_dur * dispatch_rate) - - test_filename = get_test_fn(test_name) - ez.logger.info(test_filename) - settings = CounterTestSystemSettings( - clock_settings=ClockSettings(dispatch_rate=dispatch_rate), - counter_settings=CounterSettings( - n_time=n_time, - fs=fs, - mod=mod, - ), - log_settings=MessageLoggerSettings( - output=test_filename, - ), - term_settings=TerminateOnTotalSettings( - total=target_messages, - ), - ) - system = CounterTestSystem(settings) - ez.run(SYSTEM=system) - - # Collect result - messages: list[AxisArray] = [_ for _ in message_log(test_filename)] - os.remove(test_filename) - - if math.isinf(dispatch_rate): - # The number of messages depends on how fast the computer is - target_messages = len(messages) - # This should be an equivalence assertion (==) but the use of TerminateOnTotal does - # not guarantee that MessageLogger will exit before an additional message is received. - # Let's just clip the last message if we exceed the target messages. - if len(messages) > target_messages: - messages = messages[:target_messages] - assert len(messages) >= target_messages - - # Just do one quick data check (Counter now outputs 1D array) - agg = AxisArray.concatenate(*messages, dim="time") - target_samples = n_time * target_messages - expected_data = np.arange(target_samples) - if mod is not None: - expected_data = expected_data % mod - assert np.array_equal(agg.data, expected_data) - - -@pytest.mark.parametrize( - "clock_rate, fs, n_time", - [ - (10.0, 1000.0, 100), # 10 Hz clock, fs=1000, n_time=100 (fixed) - (20.0, 500.0, None), # 20 Hz clock, fs=500, n_time derived (25 samples per tick) - (5.0, 1000.0, None), # 5 Hz clock, fs=1000, n_time derived (200 samples per tick) - ], -) -def test_counter_with_external_clock( - clock_rate: float, - fs: float, - n_time: int | None, - test_name: str | None = None, -): - """Test Counter driven by external Clock (now the standard pattern).""" - target_messages = 20 - test_filename = get_test_fn(test_name) - ez.logger.info(test_filename) - - # This now uses the same CounterTestSystem since all counters need clocks - settings = CounterTestSystemSettings( - clock_settings=ClockSettings(dispatch_rate=clock_rate), - counter_settings=CounterSettings( - fs=fs, - n_time=n_time, - ), - log_settings=MessageLoggerSettings(output=test_filename), - term_settings=TerminateOnTotalSettings(total=target_messages), - ) - system = CounterTestSystem(settings) - ez.run(SYSTEM=system) - - # Collect result - messages: list[AxisArray] = [_ for _ in message_log(test_filename)] - os.remove(test_filename) - - assert len(messages) >= target_messages - - # Verify each message has correct sample rate (gain = 1/fs) - for msg in messages: - assert msg.axes["time"].gain == 1.0 / fs - - # Verify data continuity - messages = messages[:target_messages] # Trim to target - agg = AxisArray.concatenate(*messages, dim="time") - - # Expected samples per tick - if n_time is not None: - expected_samples_per_tick = n_time - else: - expected_samples_per_tick = int(fs / clock_rate) - - expected_total = expected_samples_per_tick * target_messages - # Allow for fractional sample accumulation variance - assert abs(len(agg.data) - expected_total) <= target_messages - - # Counter values should be sequential (0, 1, 2, ...) - expected_data = np.arange(len(agg.data)) - assert np.array_equal(agg.data, expected_data) - - -# TODO: test SinGenerator in a system. - - class EEGSynthSettingsTest(ez.Settings): synth_settings: EEGSynthSettings log_settings: MessageLoggerSettings diff --git a/tests/unit/test_cosine_encoder.py b/tests/unit/test_cosine_encoder.py new file mode 100644 index 0000000..a0ff514 --- /dev/null +++ b/tests/unit/test_cosine_encoder.py @@ -0,0 +1,232 @@ +"""Unit tests for ezmsg.simbiophys.cosine_encoder module.""" + +import numpy as np +import pytest +from ezmsg.util.messages.axisarray import AxisArray + +from ezmsg.simbiophys import ( + CosineEncoderSettings, + CosineEncoderState, + CosineEncoderTransformer, +) + + +class TestCosineEncoderState: + """Tests for CosineEncoderState.""" + + def test_init_random_basic(self): + """Test random parameter initialization.""" + state = CosineEncoderState() + state.init_random(output_ch=10, seed=42) + + assert state.output_ch == 10 + assert state.baseline.shape == (1, 10) + assert state.modulation.shape == (1, 10) + assert state.pd.shape == (1, 10) + assert state.speed_modulation.shape == (1, 10) + + # Check default values + assert np.allclose(state.baseline, 0.0) + assert np.allclose(state.modulation, 1.0) + assert np.allclose(state.speed_modulation, 0.0) + + # Check preferred directions are in [0, 2*pi) + assert np.all(state.pd >= 0) + assert np.all(state.pd < 2 * np.pi) + + def test_init_random_custom_params(self): + """Test random initialization with custom parameters.""" + state = CosineEncoderState() + state.init_random( + output_ch=5, + baseline=15.0, + modulation=30.0, + speed_modulation=5.0, + seed=123, + ) + + assert state.output_ch == 5 + assert np.allclose(state.baseline, 15.0) + assert np.allclose(state.modulation, 30.0) + assert np.allclose(state.speed_modulation, 5.0) + + def test_init_random_reproducible(self): + """Test that seed produces reproducible results.""" + state1 = CosineEncoderState() + state1.init_random(output_ch=10, seed=42) + + state2 = CosineEncoderState() + state2.init_random(output_ch=10, seed=42) + + assert np.array_equal(state1.pd, state2.pd) + + def test_validation_shape_mismatch(self): + """Test that mismatched shapes raise error.""" + state = CosineEncoderState() + state.baseline = np.array([[1.0, 2.0]]) + state.modulation = np.array([[1.0, 2.0, 3.0]]) + state.pd = np.array([[1.0, 2.0]]) + state.speed_modulation = np.array([[1.0, 2.0]]) + + with pytest.raises(ValueError, match="same shape"): + state.validate() + + def test_validation_wrong_shape(self): + """Test that wrong shape raises error.""" + state = CosineEncoderState() + state.baseline = np.array([1.0, 2.0]) # 1D instead of 2D + state.modulation = np.array([1.0, 2.0]) + state.pd = np.array([1.0, 2.0]) + state.speed_modulation = np.array([1.0, 2.0]) + + with pytest.raises(ValueError, match="shape"): + state.validate() + + def test_validation_empty(self): + """Test that empty arrays raise error.""" + state = CosineEncoderState() + state.baseline = np.array([[]]).reshape(1, 0) + state.modulation = np.array([[]]).reshape(1, 0) + state.pd = np.array([[]]).reshape(1, 0) + state.speed_modulation = np.array([[]]).reshape(1, 0) + + with pytest.raises(ValueError, match="at least 1 channel"): + state.validate() + + +class TestCosineEncoderTransformer: + """Tests for CosineEncoderTransformer.""" + + def test_basic_transform(self): + """Test basic polar to encoded output transformation.""" + transformer = CosineEncoderTransformer( + CosineEncoderSettings( + output_ch=4, + baseline=10.0, + modulation=20.0, + speed_modulation=0.0, + seed=42, + ) + ) + + # Create polar input: magnitude=1, angle=0 (rightward direction) + polar = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]) + time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) + msg_in = AxisArray(polar, dims=["time", "ch"], axes={"time": time_axis}) + + msg_out = transformer(msg_in) + + assert msg_out.data.shape == (3, 4) # (n_samples, output_ch) + assert "ch" in msg_out.axes + assert msg_out.axes["ch"].data.shape == (4,) + + def test_stationary_baseline(self): + """Test that zero magnitude produces baseline values.""" + transformer = CosineEncoderTransformer( + CosineEncoderSettings( + output_ch=3, + baseline=15.0, + modulation=25.0, + speed_modulation=5.0, + seed=42, + ) + ) + + # Zero magnitude (stationary) + polar = np.array([[0.0, 0.0], [0.0, 0.0]]) + time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) + msg_in = AxisArray(polar, dims=["time", "ch"], axes={"time": time_axis}) + + msg_out = transformer(msg_in) + + # When magnitude=0, output = baseline + modulation*0*cos(...) + speed_mod*0 = baseline + assert np.allclose(msg_out.data, 15.0) + + def test_directional_tuning(self): + """Test that tuning varies with direction.""" + transformer = CosineEncoderTransformer(CosineEncoderSettings(output_ch=1, seed=42)) + + # Manually set state with known preferred direction (shape: 1, output_ch) + transformer._state.baseline = np.array([[10.0]]) + transformer._state.modulation = np.array([[20.0]]) + transformer._state.pd = np.array([[0.0]]) # Preferred direction = 0 (rightward) + transformer._state.speed_modulation = np.array([[0.0]]) + transformer._state.ch_axis = AxisArray.CoordinateAxis(data=np.array(["ch0"]), dims=["ch"]) + transformer._hash = 0 + + time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) + + # Polar: magnitude=1, angle=0 (aligned with pd) + aligned = AxisArray(np.array([[1.0, 0.0]]), dims=["time", "ch"], axes={"time": time_axis}) + output_aligned = transformer(aligned).data[0, 0] + + # Reset hash to reuse state + transformer._hash = 0 + + # Polar: magnitude=1, angle=pi (opposite to pd) + opposite = AxisArray(np.array([[1.0, np.pi]]), dims=["time", "ch"], axes={"time": time_axis}) + output_opposite = transformer(opposite).data[0, 0] + + # Aligned should have higher output (cos(0 - 0) = 1 vs cos(pi - 0) = -1) + # output_aligned = 10 + 20*1*1 = 30 + # output_opposite = 10 + 20*1*(-1) = -10 + assert output_aligned > output_opposite + assert np.isclose(output_aligned, 30.0) + assert np.isclose(output_opposite, -10.0) + + def test_speed_modulation(self): + """Test speed modulation term (speed_modulation * magnitude).""" + transformer = CosineEncoderTransformer(CosineEncoderSettings(output_ch=1, seed=42)) + + # Manually set state (shape: 1, output_ch) + transformer._state.baseline = np.array([[10.0]]) + transformer._state.modulation = np.array([[0.0]]) # No directional modulation + transformer._state.pd = np.array([[0.0]]) + transformer._state.speed_modulation = np.array([[5.0]]) + transformer._state.ch_axis = AxisArray.CoordinateAxis(data=np.array(["ch0"]), dims=["ch"]) + transformer._hash = 0 + + time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) + + # Different magnitudes + slow = AxisArray(np.array([[1.0, 0.0]]), dims=["time", "ch"], axes={"time": time_axis}) + output_slow = transformer(slow).data[0, 0] + + transformer._hash = 0 + + fast = AxisArray(np.array([[2.0, 0.0]]), dims=["time", "ch"], axes={"time": time_axis}) + output_fast = transformer(fast).data[0, 0] + + # output = 10 + 0 + 5*magnitude + assert np.isclose(output_slow, 10.0 + 5.0 * 1.0) + assert np.isclose(output_fast, 10.0 + 5.0 * 2.0) + + def test_invalid_input_shape(self): + """Test that invalid input shape raises error.""" + transformer = CosineEncoderTransformer(CosineEncoderSettings(output_ch=3, seed=42)) + + time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) + + # Wrong number of columns (should be 2) + bad_input = AxisArray(np.array([[1.0, 2.0, 3.0]]), dims=["time", "ch"], axes={"time": time_axis}) + + with pytest.raises(ValueError, match="shape"): + transformer(bad_input) + + def test_multiple_samples(self): + """Test processing multiple samples at once.""" + transformer = CosineEncoderTransformer(CosineEncoderSettings(output_ch=5, seed=42)) + + # 100 polar coordinate samples + n_samples = 100 + np.random.seed(123) + magnitude = np.abs(np.random.randn(n_samples, 1)) + angle = np.random.uniform(-np.pi, np.pi, (n_samples, 1)) + polar = np.hstack([magnitude, angle]) + + time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) + msg_in = AxisArray(polar, dims=["time", "ch"], axes={"time": time_axis}) + + msg_out = transformer(msg_in) + + assert msg_out.data.shape == (n_samples, 5) diff --git a/tests/unit/test_cosine_tuning.py b/tests/unit/test_cosine_tuning.py deleted file mode 100644 index bbbf29c..0000000 --- a/tests/unit/test_cosine_tuning.py +++ /dev/null @@ -1,254 +0,0 @@ -"""Unit tests for ezmsg.simbiophys.cosine_tuning module.""" - -import numpy as np -import pytest -from ezmsg.util.messages.axisarray import AxisArray - -from ezmsg.simbiophys import ( - CosineTuningParams, - CosineTuningSettings, - CosineTuningTransformer, -) - - -class TestCosineTuningParams: - """Tests for CosineTuningParams dataclass.""" - - def test_from_random_basic(self): - """Test random parameter generation.""" - params = CosineTuningParams.from_random(n_units=10, seed=42) - - assert params.n_units == 10 - assert params.b0.shape == (10,) - assert params.m.shape == (10,) - assert params.pd.shape == (10,) - assert params.bs.shape == (10,) - - # Check default values - assert np.allclose(params.b0, 10.0) # baseline_hz default - assert np.allclose(params.m, 20.0) # modulation_hz default - assert np.allclose(params.bs, 0.0) # speed_modulation_hz default - - # Check preferred directions are in [0, 2*pi) - assert np.all(params.pd >= 0) - assert np.all(params.pd < 2 * np.pi) - - def test_from_random_custom_params(self): - """Test random generation with custom parameters.""" - params = CosineTuningParams.from_random( - n_units=5, - baseline_hz=15.0, - modulation_hz=30.0, - speed_modulation_hz=5.0, - seed=123, - ) - - assert params.n_units == 5 - assert np.allclose(params.b0, 15.0) - assert np.allclose(params.m, 30.0) - assert np.allclose(params.bs, 5.0) - - def test_from_random_reproducible(self): - """Test that seed produces reproducible results.""" - params1 = CosineTuningParams.from_random(n_units=10, seed=42) - params2 = CosineTuningParams.from_random(n_units=10, seed=42) - - assert np.array_equal(params1.pd, params2.pd) - - def test_validation_shape_mismatch(self): - """Test that mismatched shapes raise error.""" - with pytest.raises(ValueError, match="same shape"): - CosineTuningParams( - b0=np.array([1.0, 2.0]), - m=np.array([1.0, 2.0, 3.0]), - pd=np.array([1.0, 2.0]), - bs=np.array([1.0, 2.0]), - ) - - def test_validation_not_1d(self): - """Test that non-1D arrays raise error.""" - with pytest.raises(ValueError, match="1D"): - CosineTuningParams( - b0=np.array([[1.0, 2.0]]), - m=np.array([[1.0, 2.0]]), - pd=np.array([[1.0, 2.0]]), - bs=np.array([[1.0, 2.0]]), - ) - - def test_validation_empty(self): - """Test that empty arrays raise error.""" - with pytest.raises(ValueError, match="length >= 1"): - CosineTuningParams( - b0=np.array([]), - m=np.array([]), - pd=np.array([]), - bs=np.array([]), - ) - - -class TestCosineTuningTransformer: - """Tests for CosineTuningTransformer.""" - - def test_basic_transform(self): - """Test basic velocity to firing rate transformation.""" - transformer = CosineTuningTransformer( - CosineTuningSettings( - n_units=4, - baseline_hz=10.0, - modulation_hz=20.0, - speed_modulation_hz=0.0, - seed=42, - ) - ) - - # Create velocity input: moving right at unit speed - velocity = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]) - time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) - msg_in = AxisArray(velocity, dims=["time", "ch"], axes={"time": time_axis}) - - msg_out = transformer(msg_in) - - assert msg_out.data.shape == (3, 4) # (n_samples, n_units) - assert "ch" in msg_out.axes - assert msg_out.axes["ch"].data.shape == (4,) - - # All rates should be positive (baseline + modulation term) - assert np.all(msg_out.data >= 0) - - def test_stationary_baseline(self): - """Test that stationary input produces baseline rates.""" - transformer = CosineTuningTransformer( - CosineTuningSettings( - n_units=3, - baseline_hz=15.0, - modulation_hz=25.0, - speed_modulation_hz=5.0, - seed=42, - ) - ) - - # Zero velocity - velocity = np.array([[0.0, 0.0], [0.0, 0.0]]) - time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) - msg_in = AxisArray(velocity, dims=["time", "ch"], axes={"time": time_axis}) - - msg_out = transformer(msg_in) - - # When speed=0, rate = b0 + m*0*cos(...) + bs*0 = b0 - # But arctan2(0,0) can be arbitrary, so just check rates are ~baseline - assert np.allclose(msg_out.data, 15.0) - - def test_directional_tuning(self): - """Test that tuning varies with direction.""" - # Create transformer with known preferred direction - params = CosineTuningParams( - b0=np.array([10.0]), - m=np.array([20.0]), - pd=np.array([0.0]), # Preferred direction = 0 (rightward) - bs=np.array([0.0]), - ) - - transformer = CosineTuningTransformer(CosineTuningSettings(n_units=1, seed=42)) - # Manually set params - transformer._state.params = params - transformer._hash = 0 - - time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) - - # Moving right (direction = 0, aligned with pd) - right = AxisArray(np.array([[1.0, 0.0]]), dims=["time", "ch"], axes={"time": time_axis}) - rate_right = transformer(right).data[0, 0] - - # Reset state for fresh processing - transformer._hash = 0 - transformer._state.params = params - - # Moving left (direction = pi, opposite to pd) - left = AxisArray(np.array([[-1.0, 0.0]]), dims=["time", "ch"], axes={"time": time_axis}) - rate_left = transformer(left).data[0, 0] - - # Right should have higher rate (cos(0 - 0) = 1 vs cos(pi - 0) = -1) - # rate_right = 10 + 20*1*1 + 0 = 30 - # rate_left = 10 + 20*1*(-1) + 0 = -10, but clipped to min_rate=0 - assert rate_right > rate_left - assert np.isclose(rate_right, 30.0) - - def test_speed_modulation(self): - """Test speed modulation term (bs * |v|).""" - params = CosineTuningParams( - b0=np.array([10.0]), - m=np.array([0.0]), # No directional modulation - pd=np.array([0.0]), - bs=np.array([5.0]), # Speed modulation - ) - - transformer = CosineTuningTransformer(CosineTuningSettings(n_units=1, seed=42)) - transformer._state.params = params - transformer._hash = 0 - - time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) - - # Different speeds, same direction - slow = AxisArray(np.array([[1.0, 0.0]]), dims=["time", "ch"], axes={"time": time_axis}) - rate_slow = transformer(slow).data[0, 0] - - transformer._hash = 0 - transformer._state.params = params - - fast = AxisArray(np.array([[2.0, 0.0]]), dims=["time", "ch"], axes={"time": time_axis}) - rate_fast = transformer(fast).data[0, 0] - - # rate = 10 + 0 + 5*|v| - assert np.isclose(rate_slow, 10.0 + 5.0 * 1.0) - assert np.isclose(rate_fast, 10.0 + 5.0 * 2.0) - - def test_min_rate_clipping(self): - """Test that rates are clipped to min_rate.""" - params = CosineTuningParams( - b0=np.array([5.0]), - m=np.array([20.0]), - pd=np.array([0.0]), - bs=np.array([0.0]), - ) - - transformer = CosineTuningTransformer(CosineTuningSettings(n_units=1, min_rate=0.0, seed=42)) - transformer._state.params = params - transformer._hash = 0 - - time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) - - # Moving opposite to preferred direction - # rate = 5 + 20*1*cos(pi) = 5 - 20 = -15, should be clipped to 0 - msg_in = AxisArray(np.array([[-1.0, 0.0]]), dims=["time", "ch"], axes={"time": time_axis}) - msg_out = transformer(msg_in) - - assert msg_out.data[0, 0] >= 0.0 - - def test_invalid_input_shape(self): - """Test that invalid input shape raises error.""" - transformer = CosineTuningTransformer(CosineTuningSettings(n_units=3, seed=42)) - - time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) - - # Wrong number of columns (should be 2) - bad_input = AxisArray(np.array([[1.0, 2.0, 3.0]]), dims=["time", "ch"], axes={"time": time_axis}) - - with pytest.raises(ValueError, match="shape"): - transformer(bad_input) - - def test_multiple_samples(self): - """Test processing multiple samples at once.""" - transformer = CosineTuningTransformer(CosineTuningSettings(n_units=5, seed=42)) - - # 100 velocity samples - n_samples = 100 - np.random.seed(123) - velocity = np.random.randn(n_samples, 2) - - time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0) - msg_in = AxisArray(velocity, dims=["time", "ch"], axes={"time": time_axis}) - - msg_out = transformer(msg_in) - - assert msg_out.data.shape == (n_samples, 5) - assert np.all(msg_out.data >= 0) # All rates non-negative diff --git a/tests/unit/test_dnss_lfp.py b/tests/unit/test_dnss_lfp.py index 5f58862..e29272b 100644 --- a/tests/unit/test_dnss_lfp.py +++ b/tests/unit/test_dnss_lfp.py @@ -228,25 +228,20 @@ def test_other_mode_frequency_segments(self): assert freqs[peak_idx] == pytest.approx(10.0, abs=1.0) -class TestDNSSLFPTransformer: - """Tests for DNSSLFPTransformer.""" - - def _create_counter_input(self, n_time: int, offset: float = 0.0, counter_start: int = 0) -> "AxisArray": - """Create a counter AxisArray input.""" - data = np.arange(counter_start, counter_start + n_time) - return AxisArray( - data=data, - dims=["time"], - axes={"time": AxisArray.TimeAxis(fs=DEFAULT_FS, offset=offset)}, - ) - - def test_transformer_sync_call(self): - """Test synchronous transformer via __call__.""" - from ezmsg.simbiophys.dnss.lfp import DNSSLFPSettings, DNSSLFPTransformer - - transformer = DNSSLFPTransformer(DNSSLFPSettings(n_ch=4)) - input_msg = self._create_counter_input(n_time=600) - result = transformer(input_msg) +class TestDNSSLFPProducer: + """Tests for DNSSLFPProducer.""" + + def _create_clock_tick(self, n_time: int, offset: float = 0.0) -> "AxisArray.LinearAxis": + """Create a clock tick (LinearAxis).""" + return AxisArray.LinearAxis(gain=1.0 / DEFAULT_FS, offset=offset) + + def test_producer_sync_call(self): + """Test synchronous producer via __call__.""" + from ezmsg.simbiophys.dnss.lfp import DNSSLFPProducer, DNSSLFPSettings + + producer = DNSSLFPProducer(DNSSLFPSettings(n_time=600, n_ch=4)) + clock_tick = self._create_clock_tick(n_time=600) + result = producer(clock_tick) assert result is not None assert result.data.shape[1] == 4 # n_ch @@ -254,43 +249,45 @@ def test_transformer_sync_call(self): assert "time" in result.dims assert "ch" in result.dims - def test_transformer_output_shape(self): - """Transformer output has correct shape (time, ch).""" - from ezmsg.simbiophys.dnss.lfp import DNSSLFPSettings, DNSSLFPTransformer + def test_producer_output_shape(self): + """Producer output has correct shape (time, ch).""" + from ezmsg.simbiophys.dnss.lfp import DNSSLFPProducer, DNSSLFPSettings n_ch = 16 - transformer = DNSSLFPTransformer(DNSSLFPSettings(n_ch=n_ch)) - input_msg = self._create_counter_input(n_time=100) + n_time = 100 + producer = DNSSLFPProducer(DNSSLFPSettings(n_time=n_time, n_ch=n_ch)) + clock_tick = self._create_clock_tick(n_time=n_time) - result = transformer(input_msg) + result = producer(clock_tick) assert result.data.ndim == 2 assert result.data.shape[1] == n_ch - def test_transformer_channels_identical(self): + def test_producer_channels_identical(self): """All channels have identical LFP values.""" - from ezmsg.simbiophys.dnss.lfp import DNSSLFPSettings, DNSSLFPTransformer + from ezmsg.simbiophys.dnss.lfp import DNSSLFPProducer, DNSSLFPSettings - transformer = DNSSLFPTransformer(DNSSLFPSettings(n_ch=8)) - input_msg = self._create_counter_input(n_time=600) - result = transformer(input_msg) + n_time = 600 + producer = DNSSLFPProducer(DNSSLFPSettings(n_time=n_time, n_ch=8)) + clock_tick = self._create_clock_tick(n_time=n_time) + result = producer(clock_tick) # All columns should be identical for ch in range(1, result.data.shape[1]): np.testing.assert_allclose(result.data[:, 0], result.data[:, ch]) - def test_transformer_continuity(self): + def test_producer_continuity(self): """Multiple calls produce continuous data.""" - from ezmsg.simbiophys.dnss.lfp import DNSSLFPSettings, DNSSLFPTransformer + from ezmsg.simbiophys.dnss.lfp import DNSSLFPProducer, DNSSLFPSettings - transformer = DNSSLFPTransformer(DNSSLFPSettings(n_ch=4)) + n_time = 600 + producer = DNSSLFPProducer(DNSSLFPSettings(n_time=n_time, n_ch=4)) # Get multiple chunks results = [] - counter = 0 for i in range(5): - input_msg = self._create_counter_input(n_time=600, offset=counter / DEFAULT_FS, counter_start=counter) - results.append(transformer(input_msg)) - counter += 600 + offset = i * n_time / DEFAULT_FS + clock_tick = self._create_clock_tick(n_time=n_time, offset=offset) + results.append(producer(clock_tick)) # Concatenate first channel combined = np.concatenate([r.data[:, 0] for r in results]) @@ -302,17 +299,18 @@ def test_transformer_continuity(self): np.testing.assert_allclose(combined, expected) - def test_transformer_different_patterns(self): - """Transformer works with different patterns.""" - from ezmsg.simbiophys.dnss.lfp import DNSSLFPSettings, DNSSLFPTransformer + def test_producer_different_patterns(self): + """Producer works with different patterns.""" + from ezmsg.simbiophys.dnss.lfp import DNSSLFPProducer, DNSSLFPSettings - spike_transformer = DNSSLFPTransformer(DNSSLFPSettings(n_ch=4, pattern="spike")) - other_transformer = DNSSLFPTransformer(DNSSLFPSettings(n_ch=4, pattern="other")) + n_time = 600 + spike_producer = DNSSLFPProducer(DNSSLFPSettings(n_time=n_time, n_ch=4, pattern="spike")) + other_producer = DNSSLFPProducer(DNSSLFPSettings(n_time=n_time, n_ch=4, pattern="other")) - input_msg = self._create_counter_input(n_time=600) + clock_tick = self._create_clock_tick(n_time=n_time) - spike_result = spike_transformer(input_msg) - other_result = other_transformer(input_msg) + spike_result = spike_producer(clock_tick) + other_result = other_producer(clock_tick) # Both should produce valid output assert spike_result.data.shape[0] > 0 diff --git a/tests/unit/test_dnss_spike.py b/tests/unit/test_dnss_spike.py index ad4bb9e..49bfdc0 100644 --- a/tests/unit/test_dnss_spike.py +++ b/tests/unit/test_dnss_spike.py @@ -508,25 +508,20 @@ def test_different_n_chans(self): assert np.sum(burst_mask) == N_BURST_SPIKES * n_chans -class TestDNSSSpikeTransformer: - """Tests for DNSSSpikeTransformer.""" - - def _create_counter_input(self, n_time: int, offset: float = 0.0, counter_start: int = 0) -> "AxisArray": - """Create a counter AxisArray input.""" - data = np.arange(counter_start, counter_start + n_time) - return AxisArray( - data=data, - dims=["time"], - axes={"time": AxisArray.TimeAxis(fs=FS, offset=offset)}, - ) - - def test_transformer_sync_call(self): - """Test synchronous transformer via __call__.""" - from ezmsg.simbiophys.dnss.spike import DNSSSpikeSettings, DNSSSpikeTransformer - - transformer = DNSSSpikeTransformer(DNSSSpikeSettings(n_ch=4)) - input_msg = self._create_counter_input(n_time=600) - result = transformer(input_msg) +class TestDNSSSpikeProducer: + """Tests for DNSSSpikeProducer.""" + + def _create_clock_tick(self, n_time: int, offset: float = 0.0) -> "AxisArray.LinearAxis": + """Create a clock tick (LinearAxis).""" + return AxisArray.LinearAxis(gain=1.0 / FS, offset=offset) + + def test_producer_sync_call(self): + """Test synchronous producer via __call__.""" + from ezmsg.simbiophys.dnss.spike import DNSSSpikeProducer, DNSSSpikeSettings + + producer = DNSSSpikeProducer(DNSSSpikeSettings(n_time=600, n_ch=4)) + clock_tick = self._create_clock_tick(n_time=600) + result = producer(clock_tick) assert result is not None assert result.data.shape[1] == 4 # n_ch @@ -534,137 +529,120 @@ def test_transformer_sync_call(self): assert "time" in result.dims assert "ch" in result.dims - def test_transformer_output_is_sparse(self): - """Transformer output data is sparse.COO.""" + def test_producer_output_is_sparse(self): + """Producer output data is sparse.COO.""" import sparse - from ezmsg.simbiophys.dnss.spike import DNSSSpikeSettings, DNSSSpikeTransformer + from ezmsg.simbiophys.dnss.spike import DNSSSpikeProducer, DNSSSpikeSettings - transformer = DNSSSpikeTransformer(DNSSSpikeSettings(n_ch=4)) - input_msg = self._create_counter_input(n_time=600) - result = transformer(input_msg) + producer = DNSSSpikeProducer(DNSSSpikeSettings(n_time=600, n_ch=4)) + clock_tick = self._create_clock_tick(n_time=600) + result = producer(clock_tick) assert isinstance(result.data, sparse.COO) assert result.data.ndim == 2 - def test_transformer_spike_values(self): + def test_producer_spike_values(self): """Spike values are waveform IDs (1, 2, or 3).""" - from ezmsg.simbiophys.dnss.spike import DNSSSpikeSettings, DNSSSpikeTransformer + from ezmsg.simbiophys.dnss.spike import DNSSSpikeProducer, DNSSSpikeSettings - transformer = DNSSSpikeTransformer(DNSSSpikeSettings(n_ch=4)) + n_time = 600 + producer = DNSSSpikeProducer(DNSSSpikeSettings(n_time=n_time, n_ch=4)) # Collect multiple chunks to ensure we get spikes all_data = [] - counter = 0 - for _ in range(10): - input_msg = self._create_counter_input(n_time=600, offset=counter / FS, counter_start=counter) - result = transformer(input_msg) + for i in range(10): + offset = i * n_time / FS + clock_tick = self._create_clock_tick(n_time=n_time, offset=offset) + result = producer(clock_tick) if result.data.nnz > 0: all_data.extend(result.data.data.tolist()) - counter += 600 # All non-zero values should be 1, 2, or 3 assert len(all_data) > 0 assert set(all_data).issubset({1, 2, 3}) - def test_transformer_continuity(self): + def test_producer_continuity(self): """Multiple calls produce continuous spike pattern.""" - from ezmsg.simbiophys.dnss.spike import DNSSSpikeSettings, DNSSSpikeTransformer + from ezmsg.simbiophys.dnss.spike import DNSSSpikeProducer, DNSSSpikeSettings - transformer = DNSSSpikeTransformer(DNSSSpikeSettings(n_ch=4, mode="ideal")) + n_time = 600 + producer = DNSSSpikeProducer(DNSSSpikeSettings(n_time=n_time, n_ch=4, mode="ideal")) # Get multiple chunks all_coords = [] all_waveforms = [] - counter = 0 - for _ in range(20): - input_msg = self._create_counter_input(n_time=600, offset=counter / FS, counter_start=counter) - result = transformer(input_msg) + total_samples = 0 + for i in range(20): + offset = i * n_time / FS + clock_tick = self._create_clock_tick(n_time=n_time, offset=offset) + result = producer(clock_tick) if result.data.nnz > 0: coords = result.data.coords - all_coords.append(coords[0] + counter) # Adjust sample indices + all_coords.append(coords[0] + i * n_time) # Adjust sample indices all_waveforms.extend(result.data.data.tolist()) - counter += 600 + total_samples += n_time # Compare with generator output gen = spike_event_generator(mode="ideal", n_chans=4) next(gen) - expected_coords, expected_waveforms = gen.send(counter) + expected_coords, expected_waveforms = gen.send(total_samples) if len(all_coords) > 0: - transformer_samples = np.concatenate(all_coords) - np.testing.assert_array_equal(np.sort(transformer_samples), np.sort(expected_coords[0])) + producer_samples = np.concatenate(all_coords) + np.testing.assert_array_equal(np.sort(producer_samples), np.sort(expected_coords[0])) - def test_transformer_different_n_chans(self): - """Transformer works with different channel counts.""" - from ezmsg.simbiophys.dnss.spike import DNSSSpikeSettings, DNSSSpikeTransformer + def test_producer_different_n_chans(self): + """Producer works with different channel counts.""" + from ezmsg.simbiophys.dnss.spike import DNSSSpikeProducer, DNSSSpikeSettings + n_time = 600 for n_ch in [4, 8, 32, 256]: - transformer = DNSSSpikeTransformer(DNSSSpikeSettings(n_ch=n_ch)) - input_msg = self._create_counter_input(n_time=600) - result = transformer(input_msg) + producer = DNSSSpikeProducer(DNSSSpikeSettings(n_time=n_time, n_ch=n_ch)) + clock_tick = self._create_clock_tick(n_time=n_time) + result = producer(clock_tick) assert result.data.shape[1] == n_ch - def test_transformer_hdmi_vs_ideal_mode(self): + def test_producer_hdmi_vs_ideal_mode(self): """HDMI and ideal modes produce different spike patterns.""" - from ezmsg.simbiophys.dnss.spike import DNSSSpikeSettings, DNSSSpikeTransformer + from ezmsg.simbiophys.dnss.spike import DNSSSpikeProducer, DNSSSpikeSettings - hdmi_transformer = DNSSSpikeTransformer(DNSSSpikeSettings(n_ch=4, mode="hdmi")) - # ideal_transformer = DNSSSpikeTransformer(DNSSSpikeSettings(n_ch=4, mode="ideal")) + n_time = 600 + hdmi_producer = DNSSSpikeProducer(DNSSSpikeSettings(n_time=n_time, n_ch=4, mode="hdmi")) # Collect enough data to see differences hdmi_coords = [] ideal_coords = [] - counter = 0 - for _ in range(50): - input_msg = self._create_counter_input(n_time=600, offset=counter / FS, counter_start=counter) - hdmi_result = hdmi_transformer(input_msg) + for i in range(50): + offset = i * n_time / FS + clock_tick = self._create_clock_tick(n_time=n_time, offset=offset) + hdmi_result = hdmi_producer(clock_tick) - # Reset transformer for ideal (need fresh input) - ideal_transformer_fresh = DNSSSpikeTransformer(DNSSSpikeSettings(n_ch=4, mode="ideal")) - ideal_input = self._create_counter_input(n_time=600, offset=counter / FS, counter_start=counter) - ideal_result = ideal_transformer_fresh(ideal_input) + # Reset producer for ideal (need fresh input) + ideal_producer_fresh = DNSSSpikeProducer(DNSSSpikeSettings(n_time=n_time, n_ch=4, mode="ideal")) + ideal_result = ideal_producer_fresh(clock_tick) if hdmi_result.data.nnz > 0: hdmi_coords.append(hdmi_result.data.coords.copy()) if ideal_result.data.nnz > 0: ideal_coords.append(ideal_result.data.coords.copy()) - counter += 600 # Both should produce spikes assert len(hdmi_coords) > 0 assert len(ideal_coords) > 0 - def test_transformer_empty_chunks_handled(self): - """Transformer handles chunks with no spikes correctly.""" - from ezmsg.simbiophys.dnss.spike import DNSSSpikeSettings, DNSSSpikeTransformer + def test_producer_empty_chunks_handled(self): + """Producer handles chunks with no spikes correctly.""" + from ezmsg.simbiophys.dnss.spike import DNSSSpikeProducer, DNSSSpikeSettings - transformer = DNSSSpikeTransformer(DNSSSpikeSettings(n_ch=4)) + n_time = 600 + producer = DNSSSpikeProducer(DNSSSpikeSettings(n_time=n_time, n_ch=4)) # Even if no spikes, should return valid sparse array - counter = 0 - for _ in range(10): - input_msg = self._create_counter_input(n_time=600, offset=counter / FS, counter_start=counter) - result = transformer(input_msg) + for i in range(10): + offset = i * n_time / FS + clock_tick = self._create_clock_tick(n_time=n_time, offset=offset) + result = producer(clock_tick) assert result.data.shape[0] >= 0 assert result.data.shape[1] == 4 - counter += 600 - - def test_transformer_rejects_wrong_sample_rate(self): - """Transformer raises error for sample rates other than 30kHz.""" - from ezmsg.simbiophys.dnss.spike import DNSSSpikeSettings, DNSSSpikeTransformer - - transformer = DNSSSpikeTransformer(DNSSSpikeSettings(n_ch=4)) - - # Create input with wrong sample rate (e.g., 1000 Hz) - wrong_fs = 1000 - data = np.arange(100) - input_msg = AxisArray( - data=data, - dims=["time"], - axes={"time": AxisArray.TimeAxis(fs=wrong_fs, offset=0.0)}, - ) - - with pytest.raises(ValueError, match="requires fs=30000"): - transformer(input_msg) diff --git a/tests/unit/test_oscillator.py b/tests/unit/test_oscillator.py index 0a39552..a6ffd28 100644 --- a/tests/unit/test_oscillator.py +++ b/tests/unit/test_oscillator.py @@ -1,70 +1,63 @@ """Unit tests for ezmsg.simbiophys.oscillator module.""" import numpy as np +import pytest from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.simbiophys import SinGeneratorSettings, SinTransformer +from ezmsg.simbiophys import SinGeneratorSettings, SinProducer -def test_sin_transformer(freq: float = 1.0, amp: float = 1.0, phase: float = 0.0): - """Test SinTransformer via __call__.""" +def test_sin_generator_basic(freq: float = 1.0, amp: float = 1.0, phase: float = 0.0): + """Test SinProducer via __call__.""" n_ch = 1 srate = max(4.0 * freq, 1000.0) sim_dur = 30.0 n_samples = int(srate * sim_dur) n_msgs = min(n_samples, 10) - - # Create input messages with counter data (integer sample counts) - messages = [] - counter = 0 samples_per_msg = n_samples // n_msgs - for i in range(n_msgs): - n = samples_per_msg if i < n_msgs - 1 else n_samples - counter - sample_indices = np.arange(counter, counter + n) - _time_axis = AxisArray.TimeAxis(fs=srate, offset=counter / srate) - messages.append(AxisArray(sample_indices, dims=["time"], axes={"time": _time_axis})) - counter += n def f_test(t): return amp * np.sin(2 * np.pi * freq * t + phase) - # Create transformer - transformer = SinTransformer(SinGeneratorSettings(n_ch=n_ch, freq=freq, amp=amp, phase=phase)) + # Create producer with clock-driven settings + producer = SinProducer( + SinGeneratorSettings(fs=srate, n_time=samples_per_msg, n_ch=n_ch, freq=freq, amp=amp, phase=phase) + ) - # Process messages + # Process clock ticks results = [] - for msg in messages: - res = transformer(msg) + for i in range(n_msgs): + offset = i * samples_per_msg / srate + clock_tick = AxisArray.TimeAxis(fs=srate, offset=offset) + res = producer(clock_tick) + # Check output shape - assert res.data.shape == (len(msg.data), n_ch) - # Check values - t = msg.data / srate - expected = f_test(t)[:, np.newaxis] - assert np.allclose(res.data, expected) + expected_n = samples_per_msg if i < n_msgs - 1 else n_samples - i * samples_per_msg + # Note: With fixed n_time, all chunks have the same size + assert res.data.shape == (expected_n, n_ch) results.append(res) # Verify concatenated output concat_ax_arr = AxisArray.concatenate(*results, dim="time") - assert np.allclose(concat_ax_arr.data, f_test(np.arange(n_samples) / srate)[:, np.newaxis]) + total_samples = concat_ax_arr.data.shape[0] + t = np.arange(total_samples) / srate + expected = f_test(t)[:, np.newaxis] + np.testing.assert_allclose(concat_ax_arr.data, expected, rtol=1e-10) -def test_sin_transformer_multi_channel(): - """Test SinTransformer with multiple channels.""" +def test_sin_generator_multi_channel(): + """Test SinProducer with multiple channels.""" n_ch = 4 freq = 10.0 srate = 1000.0 n_samples = 100 - # Create input with counter data - sample_indices = np.arange(n_samples) - _time_axis = AxisArray.TimeAxis(fs=srate, offset=0.0) - msg = AxisArray(sample_indices, dims=["time"], axes={"time": _time_axis}) + # Create producer + producer = SinProducer(SinGeneratorSettings(fs=srate, n_time=n_samples, n_ch=n_ch, freq=freq)) - # Create transformer - transformer = SinTransformer(SinGeneratorSettings(n_ch=n_ch, freq=freq)) - - # Process - result = transformer(msg) + # Process single clock tick + clock_tick = AxisArray.TimeAxis(fs=srate, offset=0.0) + result = producer(clock_tick) # Check output shape assert result.data.shape == (n_samples, n_ch) @@ -75,8 +68,8 @@ def test_sin_transformer_multi_channel(): np.testing.assert_allclose(result.data[:, 0], result.data[:, ch]) -def test_sin_transformer_per_channel_freq(): - """Test SinTransformer with per-channel frequencies.""" +def test_sin_generator_per_channel_freq(): + """Test SinProducer with per-channel frequencies.""" n_ch = 3 freqs = [5.0, 10.0, 20.0] amp = 1.0 @@ -84,29 +77,27 @@ def test_sin_transformer_per_channel_freq(): srate = 1000.0 n_samples = 100 - # Create input with counter data - sample_indices = np.arange(n_samples) - _time_axis = AxisArray.TimeAxis(fs=srate, offset=0.0) - msg = AxisArray(sample_indices, dims=["time"], axes={"time": _time_axis}) - - # Create transformer with per-channel freqs - transformer = SinTransformer(SinGeneratorSettings(n_ch=n_ch, freq=freqs, amp=amp, phase=phase)) + # Create producer with per-channel freqs + producer = SinProducer( + SinGeneratorSettings(fs=srate, n_time=n_samples, n_ch=n_ch, freq=freqs, amp=amp, phase=phase) + ) # Process - result = transformer(msg) + clock_tick = AxisArray.TimeAxis(fs=srate, offset=0.0) + result = producer(clock_tick) # Check output shape assert result.data.shape == (n_samples, n_ch) # Verify each channel has correct frequency - t = sample_indices / srate + t = np.arange(n_samples) / srate for ch, freq in enumerate(freqs): expected = amp * np.sin(2 * np.pi * freq * t + phase) np.testing.assert_allclose(result.data[:, ch], expected, rtol=1e-10) -def test_sin_transformer_per_channel_all_params(): - """Test SinTransformer with per-channel freq, amp, and phase.""" +def test_sin_generator_per_channel_all_params(): + """Test SinProducer with per-channel freq, amp, and phase.""" n_ch = 4 freqs = [5.0, 10.0, 15.0, 20.0] amps = [1.0, 2.0, 0.5, 1.5] @@ -114,29 +105,27 @@ def test_sin_transformer_per_channel_all_params(): srate = 1000.0 n_samples = 200 - # Create input with counter data - sample_indices = np.arange(n_samples) - _time_axis = AxisArray.TimeAxis(fs=srate, offset=0.0) - msg = AxisArray(sample_indices, dims=["time"], axes={"time": _time_axis}) - - # Create transformer with all per-channel params - transformer = SinTransformer(SinGeneratorSettings(n_ch=n_ch, freq=freqs, amp=amps, phase=phases)) + # Create producer with all per-channel params + producer = SinProducer( + SinGeneratorSettings(fs=srate, n_time=n_samples, n_ch=n_ch, freq=freqs, amp=amps, phase=phases) + ) # Process - result = transformer(msg) + clock_tick = AxisArray.TimeAxis(fs=srate, offset=0.0) + result = producer(clock_tick) # Check output shape assert result.data.shape == (n_samples, n_ch) # Verify each channel (use atol for values near zero) - t = sample_indices / srate + t = np.arange(n_samples) / srate for ch in range(n_ch): expected = amps[ch] * np.sin(2 * np.pi * freqs[ch] * t + phases[ch]) np.testing.assert_allclose(result.data[:, ch], expected, rtol=1e-10, atol=1e-14) -def test_sin_transformer_mixed_scalar_array(): - """Test SinTransformer with mixed scalar and array params.""" +def test_sin_generator_mixed_scalar_array(): + """Test SinProducer with mixed scalar and array params.""" n_ch = 3 freqs = [5.0, 10.0, 20.0] # per-channel amp = 2.0 # scalar - same for all channels @@ -144,50 +133,43 @@ def test_sin_transformer_mixed_scalar_array(): srate = 1000.0 n_samples = 100 - # Create input with counter data - sample_indices = np.arange(n_samples) - _time_axis = AxisArray.TimeAxis(fs=srate, offset=0.0) - msg = AxisArray(sample_indices, dims=["time"], axes={"time": _time_axis}) - - # Create transformer - transformer = SinTransformer(SinGeneratorSettings(n_ch=n_ch, freq=freqs, amp=amp, phase=phase)) + # Create producer + producer = SinProducer( + SinGeneratorSettings(fs=srate, n_time=n_samples, n_ch=n_ch, freq=freqs, amp=amp, phase=phase) + ) # Process - result = transformer(msg) + clock_tick = AxisArray.TimeAxis(fs=srate, offset=0.0) + result = producer(clock_tick) # Check output shape assert result.data.shape == (n_samples, n_ch) # Verify each channel - t = sample_indices / srate + t = np.arange(n_samples) / srate for ch, freq in enumerate(freqs): expected = amp * np.sin(2 * np.pi * freq * t + phase) np.testing.assert_allclose(result.data[:, ch], expected, rtol=1e-10) -def test_sin_transformer_array_length_mismatch(): - """Test SinTransformer raises error when array length doesn't match n_ch.""" - import pytest - +def test_sin_generator_array_length_mismatch(): + """Test SinProducer raises error when array length doesn't match n_ch.""" n_ch = 4 freqs = [5.0, 10.0, 20.0] # length 3, but n_ch is 4 srate = 1000.0 + n_samples = 100 - # Create input - sample_indices = np.arange(100) - _time_axis = AxisArray.TimeAxis(fs=srate, offset=0.0) - msg = AxisArray(sample_indices, dims=["time"], axes={"time": _time_axis}) - - # Create transformer - should not raise here - transformer = SinTransformer(SinGeneratorSettings(n_ch=n_ch, freq=freqs)) + # Create producer - should not raise here + producer = SinProducer(SinGeneratorSettings(fs=srate, n_time=n_samples, n_ch=n_ch, freq=freqs)) - # Should raise ValueError when processing + # Should raise ValueError when processing (during _reset_state) + clock_tick = AxisArray.TimeAxis(fs=srate, offset=0.0) with pytest.raises(ValueError, match="freq has length 3 but n_ch is 4"): - transformer(msg) + producer(clock_tick) -def test_sin_transformer_numpy_array_input(): - """Test SinTransformer accepts numpy arrays for per-channel params.""" +def test_sin_generator_numpy_array_input(): + """Test SinProducer accepts numpy arrays for per-channel params.""" n_ch = 3 freqs = np.array([5.0, 10.0, 20.0]) amps = np.array([1.0, 2.0, 0.5]) @@ -195,22 +177,77 @@ def test_sin_transformer_numpy_array_input(): srate = 1000.0 n_samples = 100 - # Create input with counter data - sample_indices = np.arange(n_samples) - _time_axis = AxisArray.TimeAxis(fs=srate, offset=0.0) - msg = AxisArray(sample_indices, dims=["time"], axes={"time": _time_axis}) - - # Create transformer with numpy arrays - transformer = SinTransformer(SinGeneratorSettings(n_ch=n_ch, freq=freqs, amp=amps, phase=phases)) + # Create producer with numpy arrays + producer = SinProducer( + SinGeneratorSettings(fs=srate, n_time=n_samples, n_ch=n_ch, freq=freqs, amp=amps, phase=phases) + ) # Process - result = transformer(msg) + clock_tick = AxisArray.TimeAxis(fs=srate, offset=0.0) + result = producer(clock_tick) # Check output shape assert result.data.shape == (n_samples, n_ch) # Verify each channel (use atol for values near zero) - t = sample_indices / srate + t = np.arange(n_samples) / srate for ch in range(n_ch): expected = amps[ch] * np.sin(2 * np.pi * freqs[ch] * t + phases[ch]) np.testing.assert_allclose(result.data[:, ch], expected, rtol=1e-10, atol=1e-14) + + +def test_sin_generator_continuity_across_chunks(): + """Test that sine wave is continuous across multiple clock ticks.""" + n_ch = 1 + freq = 10.0 + srate = 1000.0 + n_samples_per_chunk = 50 + n_chunks = 4 + + # Create producer + producer = SinProducer(SinGeneratorSettings(fs=srate, n_time=n_samples_per_chunk, n_ch=n_ch, freq=freq)) + + # Process multiple clock ticks + results = [] + for i in range(n_chunks): + offset = i * n_samples_per_chunk / srate + clock_tick = AxisArray.TimeAxis(fs=srate, offset=offset) + results.append(producer(clock_tick)) + + # Concatenate and verify continuity + concat = AxisArray.concatenate(*results, dim="time") + total_samples = n_samples_per_chunk * n_chunks + t = np.arange(total_samples) / srate + expected = np.sin(2 * np.pi * freq * t)[:, np.newaxis] + np.testing.assert_allclose(concat.data, expected, rtol=1e-10) + + +def test_sin_generator_variable_chunk_mode(): + """Test SinProducer with variable chunk sizes (n_time=None).""" + n_ch = 1 + freq = 10.0 + srate = 1000.0 + + # Create producer without fixed n_time + producer = SinProducer(SinGeneratorSettings(fs=srate, n_time=None, n_ch=n_ch, freq=freq)) + + # Clock ticks with different gains -> different chunk sizes + clock_ticks = [ + AxisArray.LinearAxis(gain=0.1, offset=0.0), # 100 samples + AxisArray.LinearAxis(gain=0.05, offset=0.1), # 50 samples + AxisArray.LinearAxis(gain=0.2, offset=0.15), # 200 samples + ] + + results = [] + for tick in clock_ticks: + results.append(producer(tick)) + + assert results[0].data.shape == (100, n_ch) + assert results[1].data.shape == (50, n_ch) + assert results[2].data.shape == (200, n_ch) + + # Verify continuity + concat = AxisArray.concatenate(*results, dim="time") + t = np.arange(350) / srate + expected = np.sin(2 * np.pi * freq * t)[:, np.newaxis] + np.testing.assert_allclose(concat.data, expected, rtol=1e-10)