Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f99d79bfe0
```python
if not torch.is_complex(hidden_state):
    hidden_state = torch.complex(hidden_state, torch.zeros_like(hidden_state))
```

```python
a_bar, b_bar = self._discretize()
_, _, c = self._complex_params()
```

```python
# Explicit recurrence for RL inference.
x_complex = torch.complex(x, torch.zeros_like(x))
next_hidden_state = hidden_state * a_bar + x_complex @ b_bar.T

# Output: y_t = MLP(h_t) + Linear(x_t)
output = self.output_network(next_hidden_state)
feedthrough_output = self.feedthrough(x)
final_output = output + feedthrough_output
```
Avoid bf16/AMP crash when creating complex inputs
This forward pass constructs complex tensors via `torch.complex(...)` on the raw `x` and `hidden_state`. In mixed precision (bf16 autocast is common on Ampere/TPU), `torch.complex` does not accept bf16 inputs and will raise at runtime, so the model will crash as soon as it sees bf16 activations. This is a regression from the prior real-valued MLP path which worked under bf16. Consider upcasting to float32 before `torch.complex` (e.g., `x.float()`) or disabling autocast for the complex math to keep AMP training functional.
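A minimal sketch of the suggested fix, assuming the forward may run inside bf16 autocast; it reuses the `x` / `hidden_state` names from the diff and upcasts to float32 with autocast disabled before building the complex tensors:

```python
import torch

def to_complex_fp32(x: torch.Tensor, hidden_state: torch.Tensor):
    # torch.complex rejects bf16/fp16 inputs, so disable autocast locally and
    # upcast to float32 before constructing the complex tensors.
    with torch.autocast(device_type=x.device.type, enabled=False):
        if not torch.is_complex(hidden_state):
            h32 = hidden_state.float()
            hidden_state = torch.complex(h32, torch.zeros_like(h32))
        x32 = x.float()
        x_complex = torch.complex(x32, torch.zeros_like(x32))
    return x_complex, hidden_state
```

Keeping the complex recurrence in float32 also sidesteps the fact that PyTorch has no bf16-backed complex dtype; the real-valued output projection can cast back to the autocast dtype afterwards if needed.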
Motivation
- Support `torch.autograd.grad(..., create_graph=True)` for MAML outer-loop differentiation (see the sketch after this list).
- Keep the `SSM` class name and the exact API (`__init__(self, state_dim, input_dim, output_dim, hidden_dim, device)`, `forward(self, x, hidden_state) -> (output, next_hidden_state)`, `init_hidden(self, batch_size)`) and support explicit recurrence for RL timesteps.
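A hedged sketch of how that requirement gets exercised; the shapes, constructor arguments, and stand-in loss below are illustrative assumptions, while the `SSM` calls follow the API contract above:

```python
import torch

model = SSM(state_dim=16, input_dim=8, output_dim=4, hidden_dim=32, device="cpu")
h = model.init_hidden(batch_size=2)
x = torch.randn(2, 8)

output, next_h = model(x, h)
inner_loss = output.pow(2).mean()  # stand-in inner-loop loss

# create_graph=True keeps the backward graph alive, so an outer (meta) loss
# computed after this gradient step stays differentiable w.r.t. the original
# parameters -- the MAML requirement this PR targets.
grads = torch.autograd.grad(inner_loss, list(model.parameters()), create_graph=True)
```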
Description

- Parameterized the model with `a_real`, `a_imag`, `B_real`, `B_imag`, `C_real`, `C_imag`, `D`, and a learnable `log_dt` for the step size.
- Implemented `_discretize` with elementwise operations so `Ā` and `B̄` are computed in PyTorch and remain differentiable through `Δ` and `A` (see the sketch after this list).
- Implemented `forward` for a single timestep as `next_hidden = hidden * a_bar + x_complex @ b_bar.T`, and projected complex outputs to real with `y_real = (next_hidden @ C.T).real + D(x)` so the public output is real-valued.
- Kept `save`/`load`, kept the `StateSpaceModel = SSM` alias, and added inline comments indicating where discretization happens and that gradients flow through these computations.
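A minimal sketch of what `_discretize` could look like; the PR names the parameters but not the discretization rule, so zero-order hold is assumed here:

```python
import torch

def _discretize(a_real, a_imag, B_real, B_imag, log_dt):
    a = torch.complex(a_real, a_imag)        # diagonal A, shape (state_dim,)
    B = torch.complex(B_real, B_imag)        # shape (state_dim, input_dim)
    dt = torch.exp(log_dt)                   # learnable positive step size Δ
    a_bar = torch.exp(dt * a)                # Ā = exp(Δ·A), elementwise
    b_bar = ((a_bar - 1.0) / a).unsqueeze(-1) * B  # B̄ = A⁻¹(Ā − I)B, per state
    return a_bar, b_bar
```

Everything is expressed through `torch.exp` and elementwise division, so gradients flow through `Δ` and `A` as the description claims, and `next_hidden = hidden * a_bar + x_complex @ b_bar.T` then matches the recurrence shown in the diff.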
Testing