Implement S4D structured state space model#3

Closed
sunghunkwag wants to merge 2 commits into main from codex/refactor-core/ssm.py-for-structured-ssm-implementation

Conversation

@sunghunkwag
Owner

Motivation

  • Replace the previous MLP-based placeholder with a mathematically grounded diagonal SSM (S4D) implementation to provide structured, stable recurrent dynamics.
  • Ensure the implementation is pure PyTorch and compatible with higher-order gradients required by torch.autograd.grad(..., create_graph=True) for MAML outer-loop differentiation.
  • Keep the original SSM class name and the exact API (__init__(self, state_dim, input_dim, output_dim, hidden_dim, device), forward(self, x, hidden_state) -> (output, next_hidden_state), init_hidden(self, batch_size)) and support explicit recurrence for RL timesteps.
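The MAML requirement above can be checked in isolation. The sketch below is a hypothetical, minimal diagonal-SSM step (names, shapes, and the free-standing parameters are assumptions, not the actual `core/ssm.py` code) that verifies higher-order gradients flow through the complex recurrence via `torch.autograd.grad(..., create_graph=True)`:

```python
import torch

# Hypothetical minimal diagonal SSM step; all names/shapes are assumptions.
state_dim, input_dim = 4, 3
a_real = torch.nn.Parameter(-torch.rand(state_dim))
a_imag = torch.nn.Parameter(torch.randn(state_dim))
B = torch.nn.Parameter(torch.randn(state_dim, input_dim))
log_dt = torch.nn.Parameter(torch.zeros(state_dim))

def step(x, h):
    dt = log_dt.exp()
    a = torch.complex(a_real, a_imag)
    # Bilinear (Tustin) discretization, elementwise for diagonal A.
    a_bar = (1 + dt * a / 2) / (1 - dt * a / 2)
    b_bar = (dt / (1 - dt * a / 2)).unsqueeze(-1) * B
    # Explicit single-timestep recurrence on a complex state.
    return h * a_bar + torch.complex(x, torch.zeros_like(x)) @ b_bar.T

x = torch.randn(2, input_dim)
h = torch.zeros(2, state_dim, dtype=torch.cfloat)
loss = step(x, h).real.pow(2).sum()

# First-order grad with create_graph=True keeps the graph alive...
(g,) = torch.autograd.grad(loss, a_real, create_graph=True)
# ...so a second differentiation (the MAML outer loop) also succeeds.
(gg,) = torch.autograd.grad(g.sum(), a_real)
```

If the second `torch.autograd.grad` call runs without error, the recurrence is compatible with MAML-style outer-loop differentiation.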

Description

  • Replaced the MLP-based state model with a diagonal complex-valued S4D parameterization exposing a_real, a_imag, B_real, B_imag, C_real, C_imag, D, and a learnable log_dt for the step size.
  • Implemented bilinear (Tustin) continuous-to-discrete conversion in _discretize with elementwise operations, so a_bar = (1 + Δ·a/2) / (1 − Δ·a/2) and b_bar = Δ·b / (1 − Δ·a/2) are computed in PyTorch and remain differentiable through Δ and A.
  • Implemented explicit recurrence in forward for a single timestep as next_hidden = hidden * a_bar + x_complex @ b_bar.T, and projected complex outputs to real with y_real = (next_hidden @ C.T).real + D(x) so the public output is real-valued.
  • Preserved checkpointing via save/load, kept StateSpaceModel = SSM alias, and added inline comments indicating where discretization happens and that gradients flow through these computations.
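The real-valued output projection described above can be sketched as follows. This is an illustrative fragment under assumed shapes (C built from C_real/C_imag, D as a linear feedthrough), not the actual module code:

```python
import torch

# Assumed dimensions for illustration only.
state_dim, input_dim, output_dim = 4, 3, 2
C = torch.complex(torch.randn(output_dim, state_dim),
                  torch.randn(output_dim, state_dim))
D = torch.nn.Linear(input_dim, output_dim)  # direct feedthrough term

x = torch.randn(5, input_dim)
h_next = torch.randn(5, state_dim, dtype=torch.cfloat)  # stands in for next_hidden

# Project the complex state to a real output: y_real = (h C^T).real + D(x)
y = (h_next @ C.T).real + D(x)
```

Taking `.real` after the complex matmul is what keeps the public `forward` output real-valued while the internal state stays complex.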

Testing

  • No automated tests were executed as part of this change.

Codex Task


@chatgpt-codex-connector (bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: f99d79bfe0

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread core/ssm.py Outdated
Comment on lines +117 to +125
```python
if not torch.is_complex(hidden_state):
    hidden_state = torch.complex(hidden_state, torch.zeros_like(hidden_state))

a_bar, b_bar = self._discretize()
_, _, c = self._complex_params()

# Explicit recurrence for RL inference.
x_complex = torch.complex(x, torch.zeros_like(x))
next_hidden_state = hidden_state * a_bar + x_complex @ b_bar.T

# Output: y_t = MLP(h_t) + Linear(x_t)
output = self.output_network(next_hidden_state)
feedthrough_output = self.feedthrough(x)
final_output = output + feedthrough_output
```


P1: Avoid bf16/AMP crash when creating complex inputs

This forward pass constructs complex tensors via torch.complex(...) on the raw x and hidden_state. In mixed precision (bf16 autocast is common on Ampere/TPU), torch.complex does not accept bf16 inputs and will raise at runtime, so the model will crash as soon as it sees bf16 activations. This is a regression from the prior real‑valued MLP path which worked under bf16. Consider upcasting to float32 before torch.complex (e.g., x.float()) or disabling autocast for the complex math to keep AMP training functional.
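The upcast the reviewer suggests could look like the sketch below. The helper name `to_complex` is hypothetical; the point is only that `torch.complex` rejects bfloat16 inputs, so bf16 (and, conservatively, fp16) activations are promoted to float32 first:

```python
import torch

def to_complex(t: torch.Tensor) -> torch.Tensor:
    # torch.complex does not accept bfloat16; upcast low-precision
    # activations to float32 before building the complex tensor.
    if t.dtype in (torch.bfloat16, torch.float16):
        t = t.float()
    return torch.complex(t, torch.zeros_like(t))

x = torch.randn(2, 3, dtype=torch.bfloat16)  # e.g. under autocast
xc = to_complex(x)
```

An alternative, also mentioned in the review, is wrapping the complex math in `torch.autocast(..., enabled=False)` so autocast never hands the SSM bf16 tensors in the first place.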

