feat: Basic MOLT Support #22

Open
Ky-Ng wants to merge 4 commits into master from feat/MOLT-basic

Ky-Ng commented Apr 29, 2026

Summary

This PR ports the MOLT branch onto master.

See the wandb comparison of training runs on the MOLT branch and on this PR's feat/MOLT-basic branch.

Visual Verification

image

What lands:

  • crosslayer_transcoder/model/molt.py — new Molt(nn.Module). Stack of low-rank (U_i V_i) transforms gated by a learned linear + non-linearity.
  • crosslayer_transcoder/model/clt_lightning.py — new MoltModule(CrossLayerTranscoderModule) Lightning wrapper with its own training_step (MSE + tanh-weighted L1 sparsity). The parent __init__ was widened to Union[CrossLayerTranscoder, Molt] and the last_active buffer / encoder.n_layers == decoder.n_layers assertion are now inside an isinstance(self.model, Molt) branch.
  • crosslayer_transcoder/model/jumprelu.py — JumpReLU now produces a 2-D theta of shape (1, d_features) when n_layers == 1, so MoLT can run single-layer.
  • crosslayer_transcoder/model/__init__.py — re-export Molt.
  • crosslayer_transcoder/data/datamodule.py — bonus fix: guard teardown with self.data_loader is not None (commit 39de8aa on the branch).
  • config/molt*.yaml — eight Lightning CLI configs (molt, molt-long, molt-5090, four molt-5090_20M_tokens_* sparsity-sweep configs, molt-5090_50M_tokens_0_00015). All class_path entries point at crosslayer_transcoder.*.
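
For reviewers unfamiliar with the architecture, here is a minimal sketch of what "a stack of low-rank transforms gated by a learned linear + non-linearity" can look like. It is not the code in molt.py. The class name TinyMolt, the init scaling, the plain ReLU gate (standing in for JumpReLU), and the rule that rank ranks[k] is used by N * 2**k transforms are all assumptions; the last one is chosen only because it reproduces the 12 features used in the CPU smoke test for N=4, ranks=[8, 4].

```python
import torch
import torch.nn as nn


class TinyMolt(nn.Module):
    """Illustrative only: a sum of gated low-rank transforms, not the PR's Molt."""

    def __init__(self, d_acts: int, n_base: int, ranks: list[int]):
        super().__init__()
        # Assumption: rank ranks[k] is used by n_base * 2**k transforms,
        # so n_base=4, ranks=[8, 4] gives 4 + 8 = 12 transforms.
        self.rank_per_transform = [
            r for k, r in enumerate(ranks) for _ in range(n_base * 2**k)
        ]
        n_transforms = len(self.rank_per_transform)
        self.gate = nn.Linear(d_acts, n_transforms)  # learned linear gate
        # V_i projects down to rank r, U_i projects back up to d_acts.
        self.V = nn.ParameterList(
            [nn.Parameter(torch.randn(d_acts, r) / d_acts**0.5)
             for r in self.rank_per_transform]
        )
        self.U = nn.ParameterList(
            [nn.Parameter(torch.randn(r, d_acts) / r**0.5)
             for r in self.rank_per_transform]
        )

    def forward(self, x: torch.Tensor):
        # ReLU is a stand-in non-linearity here (hypothetical choice).
        gate = torch.relu(self.gate(x))  # (B, n_transforms)
        recons = torch.zeros_like(x)     # (B, d_acts)
        for i, (V, U) in enumerate(zip(self.V, self.U)):
            # Each transform contributes gate_i * (x V_i U_i).
            recons = recons + gate[:, i : i + 1] * (x @ V @ U)
        return gate, recons
```

With TinyMolt(d_acts=64, n_base=4, ranks=[8, 4]) and a (4, 64) input, gate has shape (4, 12) and recons has shape (4, 64), matching the shapes reported by the CPU smoke test.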

Out of scope (follow-up PRs)

  • MoltSerializableModule so save_pretrained works. Right now MoLT runs are checkpointed via EndOfTrainingCheckpointCallback; there is no folded-weights export.
  • Parameterize the layer index in MoltModule.training_step — currently hardcoded to layer = 8.
  • Wire the compute_dead_features config flags through MoltModule — flags are inert for MoLT today.
  • Cleanup pass on clt_lightning.py — branch deleted ~180 lines of memory-debug code; defensible but separate.
  • Branch-only artifacts (stale config/21k_sweep/, config/batch_topk*.yaml, config/debug-*, config/jumprelu/, config/topk/, config/staging*.yaml, config/new_architecture_test/, shell scripts, lifetime_active.pt, snapshot.pickle, poetry.lock, setuptools build revert) are intentionally excluded.

Tests run

1. Import check

python -c "from crosslayer_transcoder.model import Molt; print(Molt)"
python -c "from crosslayer_transcoder.model.clt_lightning import MoltModule; print(MoltModule)"

What: confirms the new symbol is exported and MoltModule resolves through the rewritten clt_lightning.py.
Success criteria: both commands print the class object and exit 0.
Result: ✅ both passed.

2. CPU forward smoke test (scratch_molt_smoke.py from PR-Instructions.md Step 9)

How to run:

uv venv --python 3.12 .venv
source .venv/bin/activate
uv pip install -e .
python scratch_molt_smoke.py

What: builds Molt(d_acts=64, N=4, ranks=[8, 4]) with JumpReLU(n_layers=1, d_features=12), initializes standardizers from a fake (B=4, 2, 12, 64) batch, runs one forward pass, asserts shapes — exercises the new n_layers=1 branch in JumpReLU, the low-rank transforms, and standardizer wiring.
Success criteria: prints OK; gate.shape == (B, n_features), recons.shape == (B, d_acts); no NaN/inf.
Result: ✅ printed gate torch.Size([4, 12]) recons_norm torch.Size([4, 64]) recons torch.Size([4, 64]) then OK.
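
To illustrate why a (1, d_features) theta is enough for the single-layer case, the sketch below applies a JumpReLU-style threshold to (B, d_features) pre-activations and relies on broadcasting over the batch dimension. This is a forward-pass-only sketch under assumptions: the function name jumprelu and the threshold value 0.03 are hypothetical, and the real JumpReLU in jumprelu.py may differ (for example in how gradients reach theta).

```python
import torch


def jumprelu(pre_acts: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """Pass values strictly above the threshold, zero the rest (forward only)."""
    return pre_acts * (pre_acts > theta).to(pre_acts.dtype)


d_features = 12
# 2-D threshold for the n_layers == 1 case; 0.03 is a placeholder value.
theta = torch.full((1, d_features), 0.03)
pre_acts = torch.randn(4, d_features)   # (B, d_features)
gate = jumprelu(pre_acts, theta)        # theta broadcasts over the batch dim
```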

3. Existing test suite

source .venv/bin/activate
python -m pytest tests/

What: runs tests/test_dead_features.py, tests/test_deployment_policy.py, tests/test_load_from_pretrained.py, tests/test_process_monitor.py, tests/test_save_to_pretrained.py, tests/test_standardization_folding.py, tests/test_text_dataset.py, tests/test_topk.py — covers standardizer fold/save round-trip, dead-feature accounting, top-k variants, dataloader policy.
Success criteria: 100% of previously-passing tests still pass; no regressions from the JumpReLU shape change or the wider Lightning model type.
Result: ✅ 207 passed, 10 warnings in 136.13s.

4. GPU smoke (NOT run in this PR)

python -m crosslayer_transcoder.main fit \
  --config config/molt-5090_50M_tokens_0_00015.yaml \
  --trainer.max_steps=10
# repeat with --trainer.precision=16-mixed

What: 10 end-to-end training steps (data gen → standardizer init → MoltModule forward / sparsity loss / backprop), once in fp32 and once in 16-mixed.
Success criteria: both runs complete 10 steps with finite loss; model/d_latents, model/n_features, metrics/dead_features show up in wandb.
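
For context on the "finite loss" criterion, here is one plausible shape for an MSE + tanh-weighted L1 sparsity objective. The exact form used in MoltModule.training_step is not spelled out in this description, so everything below is an assumption: the function name molt_loss, the per-transform weight transform_norms, the scale c, and the placeholder coefficient lambda_sparsity.

```python
import torch
import torch.nn.functional as F


def molt_loss(recons, target, gate, transform_norms,
              lambda_sparsity: float = 1e-4, c: float = 1.0):
    """Reconstruction MSE plus a tanh-saturated sparsity penalty on the gates."""
    mse = F.mse_loss(recons, target)
    # tanh bounds each feature's penalty in [0, 1) for non-negative gates;
    # sum over features, then average over the batch.
    sparsity = torch.tanh(c * gate * transform_norms).sum(dim=-1).mean()
    return mse + lambda_sparsity * sparsity, mse, sparsity
```

Both terms are bounded below by zero when the gates are non-negative, so a NaN or inf loss in the 10-step run would point at the forward pass or the precision setting rather than at the penalty itself.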

Test plan

  • [x] Imports resolve (Molt, MoltModule)
  • [x] CPU forward smoke prints OK
  • [x] pytest tests/ — 207 passed
  • [ ] 10-step GPU fit on config/molt-5090_50M_tokens_0_00015.yaml (fp32)
  • [ ] 10-step GPU fit with --trainer.precision=16-mixed

KyleNg2868 and others added 4 commits April 29, 2026 07:05
Adds the Molt model class, MoltModule Lightning wrapper, and the
necessary plumbing to support a single-layer MoLT alongside the existing
CrossLayerTranscoder. Includes:

- crosslayer_transcoder/model/molt.py: new Molt nn.Module
- model/__init__.py: export Molt
- model/jumprelu.py: allow n_layers=1 to produce a 2-D theta parameter
- model/clt_lightning.py: import Molt, widen model type to
  Union[CrossLayerTranscoder, Molt], wrap the encoder/decoder
  assertions and last_active buffer in an isinstance check, and append
  the MoltModule subclass with its own training_step
- data/datamodule.py: guard self.data_loader teardown with is not None

Known limitations (follow-ups):
- MoltModule.training_step is hardcoded to layer 8
- compute_dead_features config flags are inert for MoLT
- Molt does not yet inherit from SerializableModule / save_pretrained

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Eight YAML configs for the Lightning CLI:
- config/molt.yaml, config/molt-long.yaml: baseline
- config/molt-5090.yaml: tuned for a 5090 with 31 GB /dev/shm
- config/molt-5090_20M_tokens_*.yaml: sparsity sweep at 20M tokens
- config/molt-5090_50M_tokens_0_00015.yaml: 50M-token run

class_path entries point at the master package (crosslayer_transcoder.*).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

tests/test_molt_smoke.py covers three cases:
- test_molt_cpu_forward: builds a tiny Molt, runs one forward pass,
  checks shapes and finiteness — no GPU, no Lightning, no dataset
- test_molt_gpu_fp32_train_step: forward + backward + Adam step on
  synthetic activations on cuda; asserts loss and params remain finite
- test_molt_gpu_amp_train_step: same, inside torch.amp.autocast(float16)
  with a GradScaler — mirrors Lightning's precision="16-mixed"

Both GPU tests are guarded with skipif(not cuda.is_available()), so they
silently skip on CPU CI runners while still exercising mixed-precision
locally.

Verified locally on RTX 5090: 3/3 tests pass; full suite is 210 passed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>