PRISM — modality-portable hybrid linear-recurrent backbone

One hybrid SSD + Gated-Delta sequence backbone — no modality-specific architecture — applied with identical hyperparameters to 12-lead ECG (PTB-XL), audio, and sequential images. A from-scratch reference implementation with full numerical-equivalence tests against the torch.associative_scan and FLA Triton kernels.

Python 3.12+ · PyTorch 2.5+ · License: MIT


What's the claim

Hybrid linear-recurrent backbones (Mamba-2, Gated DeltaNet) win on language, but their design choices are language-specific. PRISM tests whether a single hybrid backbone — with no modality-specific architectural tweaks — matches strong CNN baselines on PTB-XL (the primary, clinically-defensible target), Speech Commands, and sequential CIFAR-10.

The backbone interleaves two complementary mixers (plus optional attention):

  • SSD blocks (Mamba-2 state-space duality) — scalar-per-head decay with per-channel state and input-dependent (selective) Δ/B/C. This is the primitive Mamba-3 builds on, and the default ssm_kind="ssd".
  • Gated Delta Rule blocks — matrix-valued associative memory with data-dependent forget/write gates; targeted recall and overwrite.
  • Sliding-window attention (optional, swa) — for H1-style hybrid ablations, since "some attention" is what carries retrieval in SSM hybrids.

The per-layer role tokens are defined in prism.layer_tokens as ("s4", "delta", "swa"). prism.modules.block.BLOCK_REGISTRY maps each token to a builder function; new mixers can be registered with @register_block("token") without changing the core config or model code.
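
The registration mechanism described above can be pictured as a plain decorator-based registry. The sketch below is self-contained and does not import prism: the names BLOCK_REGISTRY and register_block are taken from this README, but their real signatures are not shown here, so build_block's keyword arguments and the "identity" token are assumptions for illustration only.

```python
from typing import Callable, Dict

# Stand-in for prism.modules.block.BLOCK_REGISTRY (the real one may differ).
BLOCK_REGISTRY: Dict[str, Callable[..., object]] = {}


def register_block(token: str):
    """Decorator that records a builder function under a layer token."""
    def decorator(builder):
        if token in BLOCK_REGISTRY:
            raise ValueError(f"block token {token!r} is already registered")
        BLOCK_REGISTRY[token] = builder
        return builder
    return decorator


def build_block(token: str, **kwargs):
    """Look up the builder for a token and call it with the given kwargs."""
    try:
        builder = BLOCK_REGISTRY[token]
    except KeyError:
        known = sorted(BLOCK_REGISTRY)
        raise KeyError(f"unknown block token {token!r}; known: {known}") from None
    return builder(**kwargs)


@register_block("identity")  # hypothetical token, not part of PRISM
def build_identity(hidden_dim: int):
    # A real builder would return an nn.Module; a dict keeps the sketch stdlib-only.
    return {"kind": "identity", "hidden_dim": hidden_dim}


block = build_block("identity", hidden_dim=256)
print(block)
```

Because dispatch goes through the dictionary, adding a mixer only touches the module that defines it, which is the property the README describes.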

Status: architecture implemented and tested; benchmark runs in progress. 111 tests pass (equivalence, gradcheck, streaming, property-based). The benchmark numbers require GPU + datasets — the locked experiment matrix, primary metric (macro AUROC), and baseline targets are specified in EXPERIMENTS.md before any runs complete. Paper skeleton at paper/PAPER_DRAFT.md.


Install

Option A: Nix + uv (recommended for GPU development)

If you have the Nix package manager with flakes enabled:

git clone https://github.com/kaelvalen/prism.git
cd prism
nix develop

This gives you a shell with uv, the CUDA toolkit, git, and just. Python dependencies are installed into a local .venv managed by uv (declared in pyproject.toml). The Nix shell does not ship a Nix-built PyTorch; instead it lets uv resolve PyTorch from PyPI against your system NVIDIA driver, which keeps fast-moving packages such as flash-linear-attention practical to track.

Option B: pip

git clone https://github.com/kaelvalen/prism.git
cd prism
pip install -e ".[train,test]"      # CPU-friendly: pure-PyTorch reference paths
pip install -e ".[gpu]"             # optional Triton kernels (FLA, mamba-ssm)

The pure-PyTorch reference paths run everywhere (no Triton/CUDA needed). The gpu extra adds the production kernels; everything falls back gracefully if they are absent or incompatible with the installed PyTorch/CUDA pair.

Quickstart

import torch
from prism import PRISMConfig, ModalityConfig, PRISMForClassification

cfg = PRISMConfig(
    hidden_dim=256, num_heads=8, num_layers=12,   # ~8M params, SSD+Delta 3:1
    ssm_kind="ssd",                               # "ssd" (default) | "s4d_legacy"
    modalities=[
        ModalityConfig(name="ecg",   input_dim=12, num_classes=5),
        ModalityConfig(name="image", input_dim=48, num_classes=10),
    ],
)
model = PRISMForClassification(cfg)

ecg = torch.randn(4, 1000, 12)                    # 12-lead ECG
out = model(ecg, modality="ecg", labels=torch.randint(0, 5, (4,)))
print(out["loss"].item())

The layer mix is fully configurable:

PRISMConfig(num_layers=12, block_pattern="s4,s4,s4,swa, s4,s4,s4,swa, s4,s4,s4,swa")  # H1-style
PRISMConfig(num_layers=4,  force_block_type="delta")                                  # all-delta ablation
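
To make the three ways of choosing the layer mix concrete, here is a small stand-alone sketch of how a per-layer token list could be derived. It is not the library's code: the rule that every delta_every-th layer (counting from 1) becomes a delta block, and the precedence of force_block_type over block_pattern, are assumptions consistent with the "3:1" description but not confirmed by this README.

```python
from typing import List, Optional

VALID_TOKENS = ("s4", "delta", "swa")  # role tokens named in this README


def expand_pattern(
    num_layers: int,
    block_pattern: Optional[str] = None,
    delta_every: int = 4,
    force_block_type: Optional[str] = None,
) -> List[str]:
    """Return one role token per layer (hypothetical helper)."""
    if force_block_type is not None:
        tokens = [force_block_type] * num_layers
    elif block_pattern is not None:
        # Tolerate spaces after commas, as in the H1-style example above.
        tokens = [t.strip() for t in block_pattern.split(",") if t.strip()]
    else:
        # Assumed default: layers 4, 8, 12, ... are delta, the rest are SSD ("s4").
        tokens = [
            "delta" if (i + 1) % delta_every == 0 else "s4"
            for i in range(num_layers)
        ]
    unknown = [t for t in tokens if t not in VALID_TOKENS]
    if unknown:
        raise ValueError(f"unknown layer tokens: {unknown}")
    if len(tokens) != num_layers:
        raise ValueError(f"pattern has {len(tokens)} tokens, expected {num_layers}")
    return tokens


print(expand_pattern(12))
print(expand_pattern(4, force_block_type="delta"))
```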

Reproducing the paper

DATA_ROOT=./datasets SEEDS="0 1 2" EPOCHS=50 bash scripts/run_benchmarks.sh
python scripts/aggregate_results.py output/benchmarks      # mean ± std
python scripts/bench_throughput.py --device cuda --seq-len 4096

One command per table row; full matrix, datasets, metric (macro-AUROC) and compute budget are in EXPERIMENTS.md.

Inference

After training, use the per-task inference scripts. Both load the best validation checkpoint, run the model on the test fold, and can write a JSON report:

# PTB-XL super-diagnostic (multi-label) or single-label task
python scripts/infer_ecg.py \
  --checkpoint output/ptbxl_superdiag/best.pt \
  --task superdiag \
  --output output/ptbxl_superdiag/test_report.json

# Sequential CIFAR-10
python scripts/infer_image.py \
  --checkpoint output/cifar10/best.pt \
  --output output/cifar10/test_report.json

Reproducibility

The trainer supports the standard levers:

  • --seed N seeds the Python, NumPy, and PyTorch RNGs; with --deterministic it also sets CUBLAS_WORKSPACE_CONFIG so that CUDA ops are deterministic where possible.
  • Checkpoints contain optimizer, scheduler, RNG state, and global step; resume with --resume path/to/last.pt.
  • last.pt is written every epoch; best.pt is selected by the validation metric (macro-AUROC for ECG tasks, accuracy for image/audio).
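
For readers who want to reproduce the seeding behaviour outside the trainer, the following is a minimal sketch of what such a helper commonly looks like. The function name seed_everything is hypothetical and the trainer's actual implementation is not shown in this README; NumPy and PyTorch are imported lazily so the snippet also runs where they are not installed.

```python
import os
import random


def seed_everything(seed: int, deterministic: bool = False) -> None:
    """Seed the common RNGs; optionally request deterministic CUDA kernels."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; nothing to seed
    if deterministic:
        # cuBLAS reads this when it initialises, so set it before any CUDA work.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    try:
        import torch
        torch.manual_seed(seed)  # seeds the CPU and all CUDA generators
        if deterministic:
            # warn_only avoids hard errors on ops without a deterministic kernel
            torch.use_deterministic_algorithms(True, warn_only=True)
    except ImportError:
        pass  # PyTorch not installed


seed_everything(0)
first = random.random()
seed_everything(0)
second = random.random()
print(first == second)
```

Seeding alone does not make a resumed run identical to an uninterrupted one; that also depends on restoring the RNG state stored in the checkpoint, as the second bullet notes.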

Benchmark numbers

Primary metric on PTB-XL is macro one-vs-rest AUROC (Strodthoff et al. 2020), not accuracy. Baseline to match: xresnet1d101, at 0.928 macro AUC on the 5-class super-diagnostic task (within ±0.005 bootstrap CI).
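
As a reminder of what the metric computes, here is a dependency-free sketch of macro one-vs-rest AUROC: a per-class AUROC from pairwise comparisons, then an unweighted mean over classes. It illustrates the definition only and is not the repository's metric code; in practice a library routine such as scikit-learn's roc_auc_score with average="macro" is the usual choice, and the pairwise loop here is quadratic in the number of samples. The labels and scores are made-up toy values.

```python
from typing import Sequence


def binary_auroc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """AUROC as P(score_pos > score_neg), counting ties as one half."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC is undefined unless both classes are present")
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def macro_auroc(
    y_true: Sequence[Sequence[int]], y_score: Sequence[Sequence[float]]
) -> float:
    """Unweighted mean of per-class AUROC; y_true is a 0/1 indicator matrix."""
    num_classes = len(y_true[0])
    per_class = [
        binary_auroc([row[c] for row in y_true], [row[c] for row in y_score])
        for c in range(num_classes)
    ]
    return sum(per_class) / num_classes


# Toy multi-label example with 4 samples and 2 classes (illustrative values).
y_true = [[1, 0], [0, 1], [1, 1], [0, 0]]
y_score = [[0.9, 0.2], [0.1, 0.8], [0.8, 0.3], [0.3, 0.4]]
result = macro_auroc(y_true, y_score)
print(result)  # class 0 gives 1.0, class 1 gives 0.75, so the macro mean is 0.875
```

Because every class contributes equally, a rare class that is ranked poorly pulls the score down as much as a common one, which is why the metric is preferred over accuracy for imbalanced diagnostic labels.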

Model | params | sCIFAR-10 acc | PTB-XL super-diag AUC | Speech Cmds acc
ResNet1D (xresnet1d101) | ~8M | | 0.928 (lit.) |
Small Transformer | ~8M | TODO | TODO | TODO
Mamba-2 only (SSD) | ~8M | TODO | TODO | TODO
Gated DeltaNet only | ~8M | TODO | TODO | TODO
PRISM (SSD + Delta hybrid) | ~8M | TODO | TODO | TODO
PRISM legacy (S4D + Delta) | ~8M | 0.884 acc (prior, single-seed) | TODO | TODO

The legacy 0.884 sCIFAR number is from the previous S4D backbone, kept only as a historical ablation row; see EXPERIMENTS.md. All new numbers must be mean ± std over ≥3 seeds.
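
The "mean ± std over ≥3 seeds" rule can be enforced mechanically when a table cell is produced. The helper below is a hypothetical sketch, not scripts/aggregate_results.py; it assumes the sample standard deviation (n − 1 denominator), which the README does not specify, and the input values are placeholders rather than measured results.

```python
import statistics
from typing import Sequence


def summarize(values: Sequence[float], min_seeds: int = 3) -> str:
    """Format per-seed metrics as 'mean ± std', refusing under-sampled cells."""
    if len(values) < min_seeds:
        raise ValueError(f"need at least {min_seeds} seeds, got {len(values)}")
    mean = statistics.mean(values)
    std = statistics.stdev(values)  # sample standard deviation (assumption)
    return f"{mean:.3f} ± {std:.3f}"


# Placeholder numbers for illustration only; not benchmark results.
cell = summarize([0.90, 0.91, 0.92])
print(cell)
```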


Architecture

Input (any modality)  [B, T, input_dim]
        │  ModalityProjection  Linear(input_dim → hidden_dim)   ← per-modality
        ▼
PRISMBackbone   (block_pattern of s4 / delta / swa)
   s4    → SSDBlock     RMSNorm→Conv→SSD(selective scan)→res ; RMSNorm→SwiGLU→res
   delta → DeltaBlock   RMSNorm→Conv→GatedDeltaRule→res       ; RMSNorm→SwiGLU→res
   swa   → SWABlock     RMSNorm→SlidingWindowAttn(RoPE)→res   ; RMSNorm→SwiGLU→res
        │  mean / last pooling
        ▼  PerModalityHead  LayerNorm → Linear → logits

SSD mixer (prism/modules/ssd.py) — Mamba-2 state-space duality. A is a scalar per head (A=-exp(A_log)), decay a_t = exp(Δ_t·A), and crucially the input is kept per-channel (h_t = a_t h_{t-1} + (Δ_t x_t) ⊗ B_t, y_t = ⟨h_t, C_t⟩ + D x_t). No mean-over-Dₕ collapse — that was the central weakness of the original S4D block, and the most important ablation in the paper (ssm_kind="ssd" vs "s4d_legacy").
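
The recurrence in the paragraph above can be written out step by step. The following is a deliberately slow, single-head sketch in plain Python that follows the formulas as stated (scalar A per head, per-channel state of shape P × N); it is not the code in prism/modules/ssd.py. Treating D as one scalar per head and Δ_t as one scalar per time step are assumptions made to keep the example small.

```python
import math
from typing import List


def ssd_recurrence(
    x: List[List[float]],   # T x P inputs (P channels in this head)
    delta: List[float],     # T positive step sizes
    B: List[List[float]],   # T x N input projections
    C: List[List[float]],   # T x N output projections
    A_log: float,           # scalar; A = -exp(A_log) < 0
    D: float,               # scalar skip weight (assumed per head)
) -> List[List[float]]:
    A = -math.exp(A_log)
    P, N = len(x[0]), len(B[0])
    h = [[0.0] * N for _ in range(P)]  # one N-dim state per channel
    ys = []
    for t in range(len(x)):
        a_t = math.exp(delta[t] * A)   # scalar decay in (0, 1), shared by the head
        for p in range(P):
            u = delta[t] * x[t][p]
            for n in range(N):
                # h_t = a_t * h_{t-1} + (delta_t * x_t) outer B_t
                h[p][n] = a_t * h[p][n] + u * B[t][n]
        # y_t = <h_t, C_t> + D * x_t, computed independently per channel
        ys.append([
            sum(h[p][n] * C[t][n] for n in range(N)) + D * x[t][p]
            for p in range(P)
        ])
    return ys


# One step: h = 0.5 * 2.0 * 1.0 = 1.0, so y = 1.0 * 3.0 + 1.0 * 2.0 = 5.0
y = ssd_recurrence(x=[[2.0]], delta=[0.5], B=[[1.0]], C=[[3.0]], A_log=0.0, D=1.0)
print(y)
```

Each channel p keeps its own row of h, so no averaging over the head dimension takes place before the state update; that is the property contrasted with the legacy S4D block.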

Parallel scan (prism/modules/scan.py) — the recurrence is solved with torch.associative_scan (fused HOP) or a vectorized Hillis-Steele fallback; neither uses strided indexed assignment. The original hand-derived Blelloch up/down-sweep is preserved in scan_reference.py for teaching and as an equivalence anchor.
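
To show why the recurrence admits a parallel scan, here is a scalar Hillis-Steele sketch in plain Python. Each element (a_t, b_t) stands for the affine map h → a_t·h + b_t, and composing two such maps is associative, which is all a scan needs. This is an illustration of the technique on Python lists, not the vectorized fallback in prism/modules/scan.py; the function names are hypothetical.

```python
from typing import List, Tuple

Affine = Tuple[float, float]  # (a, b) represents h -> a * h + b


def combine(left: Affine, right: Affine) -> Affine:
    """Apply `left` first, then `right`: a_r * (a_l * h + b_l) + b_r."""
    a_l, b_l = left
    a_r, b_r = right
    return (a_l * a_r, a_r * b_l + b_r)


def hillis_steele_scan(a: List[float], b: List[float]) -> List[float]:
    """Inclusive scan in ceil(log2(T)) rounds; returns h_t with h_{-1} = 0."""
    elems = list(zip(a, b))
    n = len(elems)
    offset = 1
    while offset < n:
        # Every position reads the previous round's values, so a new list is
        # built instead of assigning in place.
        elems = [
            combine(elems[i - offset], elems[i]) if i >= offset else elems[i]
            for i in range(n)
        ]
        offset *= 2
    return [b_acc for _, b_acc in elems]


def sequential_scan(a: List[float], b: List[float]) -> List[float]:
    """Ground truth: the plain left-to-right recurrence."""
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out


a = [0.9, 0.5, 0.7, 0.2, 0.8, 0.6, 0.4]
b = [1.0, -2.0, 0.5, 3.0, -1.0, 0.25, 2.0]
parallel = hillis_steele_scan(a, b)
reference = sequential_scan(a, b)
print(max(abs(p - r) for p, r in zip(parallel, reference)))
```

Hillis-Steele does O(T log T) work compared with O(T) for Blelloch, but each round is a single shifted elementwise operation, which is what makes it easy to vectorize.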

Gated delta rule (prism/modules/delta.py) — backend="reference" is the from-scratch chunked solve (UT transform via solve_triangular); backend="fla" calls FLA's chunk_gated_delta_rule Triton kernel and falls back to the reference if FLA/CUDA are unavailable.
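
Both backends compute the same underlying recurrence, which is easier to see in its sequential form. The sketch below assumes the formulation from the Gated DeltaNet paper, S_t = α_t S_{t−1}(I − β_t k_t k_tᵀ) + β_t v_t k_tᵀ with output o_t = S_t q_t, for a single head and with keys assumed to be unit-norm; it is a plain-Python illustration, not the chunked solve in prism/modules/delta.py.

```python
from typing import List

Matrix = List[List[float]]


def gated_delta_rule(
    q: Matrix, k: Matrix, v: Matrix, alpha: List[float], beta: List[float]
) -> Matrix:
    """Sequential gated delta rule; q, k are T x d_k and v is T x d_v."""
    d_k, d_v = len(k[0]), len(v[0])
    S = [[0.0] * d_k for _ in range(d_v)]  # associative memory, d_v x d_k
    outs = []
    for t in range(len(k)):
        # Forget gate: decay the whole memory by alpha_t.
        S = [[alpha[t] * s for s in row] for row in S]
        # What the decayed memory currently returns for key k_t.
        pred = [sum(S[i][j] * k[t][j] for j in range(d_k)) for i in range(d_v)]
        # Delta-rule write: move the stored value toward v_t by a factor beta_t.
        for i in range(d_v):
            err = beta[t] * (v[t][i] - pred[i])
            for j in range(d_k):
                S[i][j] += err * k[t][j]
        outs.append([sum(S[i][j] * q[t][j] for j in range(d_k)) for i in range(d_v)])
    return outs


# Write 1.0 under key e0, 2.0 under key e1, then overwrite key e0 with 5.0.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
values = [[1.0], [2.0], [5.0]]
outs = gated_delta_rule(keys, keys, values, alpha=[1.0] * 3, beta=[1.0] * 3)
print(outs)  # [[1.0], [2.0], [5.0]]
```

With α = 1 and β = 1 the third step replaces the value stored under the first key rather than adding to it, which is the "targeted recall and overwrite" behaviour that distinguishes the delta rule from purely additive linear attention.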

Backends

Component | reference (default) | production | fallback
SSD / S4D scan | Hillis-Steele (scan_backend="reference") | torch.associative_scan ("auto"/"assoc") + torch.compile | automatic to Hillis-Steele if associative_scan unavailable
Gated delta rule | pure-PyTorch chunked solve | FLA chunk_gated_delta_rule (delta_backend="fla") | automatic to reference if FLA missing, not on CUDA, wrong dtype, or Triton fails

tests/test_delta_equivalence.py and tests/test_scan_equivalence.py assert the production backends are numerically equivalent to the references. The FLA case is skipped when the Triton kernel cannot execute on the current PyTorch/CUDA/driver combination.
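
The fallback column follows a common pattern: probe the optional backend once, and resolve to the reference path on any failure. Here is a minimal sketch of that pattern. The function names are hypothetical, and the default probe only checks that the fla package (the import name of flash-linear-attention) can be found, whereas the repository's probe is described as actually running the kernel on a tiny tensor.

```python
import importlib.util
from typing import Callable


def fla_importable() -> bool:
    """Cheap availability check; says nothing about CUDA, dtype, or Triton."""
    return importlib.util.find_spec("fla") is not None


def resolve_delta_backend(
    requested: str, probe: Callable[[], bool] = fla_importable
) -> str:
    """Return the backend that will actually be used for the delta rule."""
    if requested not in ("reference", "fla"):
        raise ValueError(f"unknown delta backend {requested!r}")
    if requested == "reference":
        return "reference"
    try:
        if probe():
            return "fla"
    except Exception:
        # Any probe failure (driver mismatch, kernel compile error, ...) is
        # treated as "unavailable" rather than propagated to the caller.
        pass
    return "reference"


def broken_probe() -> bool:
    raise RuntimeError("simulated Triton failure")


print(resolve_delta_backend("fla", probe=lambda: False))
print(resolve_delta_backend("fla", probe=broken_probe))
```

Passing the probe as an argument keeps the resolution logic testable on machines without a GPU.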

Testing

# inside nix develop, or with the pip venv activated
pytest                       # 111 passed, 3 skipped (FLA probe skips when Triton unavailable)
ruff check prism tests scripts train.py
ruff format --check prism tests scripts train.py

The suite covers:
  • Numerical equivalence — scan backends and delta backends vs sequential/reference ground truth.
  • Gradcheck (float64) — scan, SSD mixer, delta rule.
  • State-passing — one-shot == chunked-with-carried-state (streaming correctness).
  • Regression — seed-locked golden losses.
  • Property-based (hypothesis) — finite outputs/loss/grads across random shapes; CPU determinism.
  • FLA probe — tests/test_delta_equivalence.py runs the FLA Triton kernel on a tiny tensor before declaring the backend available; if it fails (missing Triton, driver mismatch, PyTorch/FLA ABI incompatibility), the test skips and GatedDeltaRule(backend="fla") falls back to the reference path.
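
The state-passing property in the list above is simple to state in code: processing a sequence in chunks while carrying the final state forward must give the same outputs as processing it in one shot. The sketch below checks this for a scalar linear recurrence in plain Python; it illustrates the shape of such a test and is not taken from the repository's suite.

```python
from typing import List, Tuple


def scan_chunk(
    a: List[float], b: List[float], h0: float = 0.0
) -> Tuple[List[float], float]:
    """Run h_t = a_t * h_{t-1} + b_t from state h0; return outputs and final state."""
    h, out = h0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out, h


a = [0.9, 0.5, 0.7, 0.2, 0.8, 0.6, 0.4, 0.3]
b = [1.0, -2.0, 0.5, 3.0, -1.0, 0.25, 2.0, -0.5]

# One-shot pass over the whole sequence.
full, _ = scan_chunk(a, b)

# Streaming pass: chunks of 3, carrying the state between calls.
streamed, state = [], 0.0
for start in range(0, len(a), 3):
    out, state = scan_chunk(a[start:start + 3], b[start:start + 3], state)
    streamed.extend(out)

print(streamed == full)
```

For the sequential recurrence the two results match exactly because the arithmetic is performed in the same order; with parallel or chunked kernels a tolerance-based comparison is the usual choice.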

Key config flags

Field / CLI flag | Default | Description
ssm_kind / --ssm-kind | ssd | ssd (Mamba-2 selective) or s4d_legacy
block_pattern / --layer-pattern | None | explicit per-layer tokens s4,delta,swa (overrides interleave)
delta_every | 4 | DeltaBlock every Nth layer (3:1) when no explicit pattern
delta_backend / --delta-backend | reference | reference or fla
scan_backend / --scan-backend | auto | auto / assoc / reference
s4d_init / --s4d-init | lin | S4D-Lin (A=-½+iπn) or legacy
swa_window / --swa-window | 128 | sliding-window attention span
compile / --compile | off | torch.compile the model in the trainer
multilabel / --ecg-multilabel | off | PTB-XL multi-label targets + BCE + macro-AUROC selection

Repository layout

prism/
├── config.py                 # PRISMConfig (ssm_kind, block_pattern, backends, …)
├── layer_tokens.py           # dependency-free {s4, delta, swa} role tokens
├── model.py                  # projection → backbone → per-modality head
├── inference.py              # shared checkpoint loader for inference scripts
├── modules/
│   ├── ssd.py                # SSDMixer / SSDBlock  (Mamba-2 SSD, per-channel)
│   ├── s4.py                 # legacy S4D-Complex (ablation), parallel_scan wrapper
│   ├── delta.py              # GatedDeltaRule (reference + FLA backends)
│   ├── attention.py          # SlidingWindowAttention / SWABlock (RoPE)
│   ├── scan.py               # associative_scan + Hillis-Steele scan backends
│   ├── scan_reference.py     # preserved hand-derived Blelloch (teaching/equivalence)
│   └── block.py              # BLOCK_REGISTRY + PRISMBlock protocol + build_block dispatch
├── training/                 # Trainer, CLI (train.py), metrics (macro-AUROC), loops
├── baselines/                # ResNet1D, small Transformer
└── data/                     # ecg / image / audio loaders
EXPERIMENTS.md                # locked benchmark matrix + honest gaps
paper/PAPER_DRAFT.md          # 4-page workshop manuscript skeleton
scripts/                      # run_benchmarks.sh, bench_throughput.py, aggregate_results.py,
                              # infer_ecg.py, infer_image.py
tests/                        # pytest suite

Honest scope

This is a modality-portable architecture (same arch + hyperparameters, one training run per modality), not yet a single-set-of-weights joint model — that true "modality-agnostic" result is the follow-up. The architecture is grounded in the 2024–2026 frontier (Mamba-2/3, Gated DeltaNet, FLA); the from-scratch reference implementations and their equivalence tests are the contribution alongside the cross-modal portability study. See EXPERIMENTS.md for what still needs doing before submission.

Citation

@misc{prism2026,
  title  = {PRISM: a modality-portable hybrid linear-recurrent backbone},
  author = {Hakbilen, Mehmet Arda},
  year   = {2026},
  note   = {https://github.com/kaelvalen/prism}
}

Acknowledgments

Builds on ideas and (optionally) kernels from Mamba-2 / Mamba-3 (Dao, Gu et al.), Gated DeltaNet (Yang, Kautz, Hatamizadeh et al.), and flash-linear-attention.
