"""Structured State Space Model (S4D) with explicit recurrence.

This module implements a diagonal SSM (S4D) with bilinear (Tustin)
discretization in pure PyTorch, compatible with higher-order gradients for
MAML. All parameters and states are real-valued.
"""

import os
from typing import Optional, Tuple

import torch
import torch.nn as nn


class SSM(nn.Module):
    """Structured State Space Model (S4D) with real diagonal dynamics.

    State equation (continuous-time):
        h'(t) = A h(t) + B x(t)
    Output equation:
        y(t) = C h(t) + D x(t)

    Discretization (bilinear/Tustin) with learnable step size Δ:
        Ā = (I - Δ/2 A)^{-1} (I + Δ/2 A)
        B̄ = (I - Δ/2 A)^{-1} (Δ B)

    ``A`` is diagonal, so both formulas are evaluated elementwise. The forward
    pass advances the recurrence by exactly one timestep per call.

    Args:
        state_dim: Size of the hidden state (length of the diagonal of A).
        input_dim: Size of each input vector.
        output_dim: Size of each output vector.
        hidden_dim: Not used by any layer in this implementation; it is stored
            and written to the checkpoint config so the constructor signature
            and checkpoint layout stay unchanged.
        device: Device the parameters are moved to at construction time.
    """

    def __init__(
        self,
        state_dim: int,
        input_dim: int,
        output_dim: int,
        hidden_dim: int = 128,
        device: str = "cpu",
    ) -> None:
        super().__init__()

        self.state_dim = state_dim
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        # Device requested at construction. It is NOT updated by later
        # ``.to(...)`` calls, so runtime code reads the device from the
        # parameters instead of from this attribute.
        self.device = device

        # Diagonal continuous-time dynamics A = diag(a), real-valued.
        # Initialized at -0.5 (a stable pole); the value is a free parameter
        # and is not constrained to stay negative during training.
        self.a = nn.Parameter(-0.5 * torch.ones(state_dim))

        # Input and output projections in real space.
        self.B = nn.Parameter(0.1 * torch.randn(state_dim, input_dim))
        self.C = nn.Parameter(0.1 * torch.randn(output_dim, state_dim))

        # Real feedthrough term D (a Linear layer, so it also carries a bias).
        self.D = nn.Linear(input_dim, output_dim)

        # Raw step-size parameter; the effective Δ is softplus(log_dt), which
        # is always positive. Zero init gives Δ = ln(2).
        self.log_dt = nn.Parameter(torch.zeros(state_dim))

        self.to(device)

    def _discretize(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute discrete-time (Ā, B̄) with the bilinear/Tustin method.

        Everything is computed with differentiable PyTorch ops on the
        parameters, so gradients (including higher-order ones) flow through
        both Δ and A.

        Returns:
            a_bar: Tensor of shape (state_dim,), the diagonal of Ā.
            b_bar: Tensor of shape (state_dim, input_dim), B̄.
        """
        dt = torch.nn.functional.softplus(self.log_dt)

        # Diagonal A -> elementwise discretization.
        half_step = 0.5 * dt * self.a
        denom = 1.0 - half_step
        a_bar = (1.0 + half_step) / denom
        b_bar = (dt[:, None] * self.B) / denom[:, None]
        return a_bar, b_bar

    def init_hidden(self, batch_size: int = 1) -> torch.Tensor:
        """Initialize the hidden state to zeros.

        The tensor is created on the same device and with the same dtype as
        the model parameters, so it stays usable after ``model.to(...)`` or
        ``model.double()``.

        Args:
            batch_size: Number of sequences in the batch.

        Returns:
            Zero tensor of shape (batch_size, state_dim).
        """
        return torch.zeros(
            batch_size,
            self.state_dim,
            device=self.a.device,
            dtype=self.a.dtype,
        )

    def forward(
        self, x: torch.Tensor, hidden_state: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Process a single timestep.

        Args:
            x: Input tensor of shape (batch_size, input_dim).
            hidden_state: Current hidden state of shape (batch_size, state_dim).

        Returns:
            output: Output tensor of shape (batch_size, output_dim).
            next_hidden_state: Updated state of shape (batch_size, state_dim).
        """
        # Re-discretize on every call so gradients reach ``a`` and ``log_dt``.
        a_bar, b_bar = self._discretize()

        # Explicit recurrence: h_t = Ā ⊙ h_{t-1} + x_t B̄ᵀ.
        next_hidden_state = hidden_state * a_bar + x @ b_bar.T

        # y_t = h_t Cᵀ + D(x_t).
        output = next_hidden_state @ self.C.T + self.D(x)

        return output, next_hidden_state

    def save(self, path: str) -> None:
        """Save model checkpoint.

        Writes a dict with the ``state_dict`` and the constructor ``config``.
        The parent directory is created if it does not exist.

        Args:
            path: Path to save the checkpoint.
        """
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        torch.save(
            {
                "state_dict": self.state_dict(),
                "config": {
                    "state_dim": self.state_dim,
                    "input_dim": self.input_dim,
                    "hidden_dim": self.hidden_dim,
                    "output_dim": self.output_dim,
                    # Record where the parameters actually live (as a plain
                    # string) rather than the constructor-time attribute.
                    "device": str(self.a.device),
                },
            },
            path,
        )

    @staticmethod
    def load(path: str, device: Optional[str] = None) -> "SSM":
        """Load model checkpoint.

        Args:
            path: Path to checkpoint file.
            device: Override device (default: use saved device).

        Returns:
            Loaded SSM model.
        """
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources. On PyTorch versions that support it, consider
        # passing ``weights_only=True`` (this checkpoint holds only tensors
        # and plain Python containers).
        checkpoint = torch.load(path, map_location="cpu")
        config = dict(checkpoint["config"])

        if device is not None:
            config["device"] = device

        model = SSM(**config)
        model.load_state_dict(checkpoint["state_dict"])
        model.to(config["device"])
        return model


# Alias for backward compatibility.
StateSpaceModel = SSM


if __name__ == "__main__":
    print("Testing Structured State Space Model (S4D)...")

    ssm = SSM(state_dim=64, input_dim=32, output_dim=16, hidden_dim=128)
    batch_size = 4
    hidden = ssm.init_hidden(batch_size)
    x = torch.randn(batch_size, 32)
    output, next_hidden = ssm(x, hidden)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Next hidden shape: {next_hidden.shape}")