Xiao elo#8

Closed
xiaol827 wants to merge 6 commits into main from xiao_elo

Conversation

@xiaol827

Collaborator

Add learned optimizers (LOs) trained with ELO.

xiaol827 and others added 5 commits June 9, 2026 17:23
Port the inference-time forward pass of the CELO2 / ELO-CELO2 learned
optimizers from the original JAX/optax implementation into pure PyTorch,
as drop-in pylo optimizers.

- pylo/models/CELO2_MLP.py: CELO2MLP, the split-input per-parameter MLP
  backbone (14 split first-layer weights + dense layers, HF Hub mixin).
- pylo/optim/CELO2_naive.py: CELO2_naive optimizer — momentum / RMS /
  factored Adafactor accumulators, CELO2 feature stack, Newton-Schulz
  orthogonalization for 2D+ params, AdamW for 1D params over the shared
  accumulators, and a warmup + cosine LR schedule. Loads its meta-model
  from HuggingFace (default DiamondXL/celo2); a local converted checkpoint
  or an explicit network take precedence.
- pylo/optim/ELO_CELO2_naive.py: ELO_CELO2_naive — at inference the ELO
  expert mechanism is disabled, so it reduces to the CELO2 forward with
  the ELO default hyper-parameters (weight_decay=0.1, clip_grad=True);
  default meta-model DiamondXL/elo-celo2.
- scripts/convert_celo2_checkpoint.py: convert a JAX/Haiku theta
  checkpoint into a CELO2MLP state_dict.
- tests/test_celo2.py: step/update, higher-rank param, state-dict resume,
  and a JAX numerical-alignment test (2D update matches the reference to
  ~3e-6; auto-skips when the JAX source is unavailable).
- Register the new classes in the pylo / pylo.optim / pylo.models inits.
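
The Newton-Schulz orthogonalization mentioned for 2D+ params can be illustrated with the classical cubic iteration. This is a minimal sketch in plain Python, not the code from `CELO2_naive.py`: the helper names, the step count, and the choice of the cubic polynomial are assumptions, and the actual port may use a different (for example, tuned higher-order) variant on tensors.

```python
import math

def matmul(a, b):
    """Multiply two matrices stored as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    """Transpose a matrix stored as nested lists."""
    return [list(col) for col in zip(*a)]

def newton_schulz_orthogonalize(g, steps=20):
    """Approximate the orthogonal polar factor of a nonzero matrix g.

    Hypothetical helper using the classical cubic iteration
    X <- 1.5 X - 0.5 X X^T X, which converges when the nonzero singular
    values of the starting point lie in (0, sqrt(3)).
    """
    # Dividing by the Frobenius norm puts every singular value in [0, 1].
    norm = math.sqrt(sum(v * v for row in g for v in row))
    x = [[v / norm for v in row] for row in g]
    for _ in range(steps):
        xxtx = matmul(matmul(x, transpose(x)), x)
        x = [[1.5 * a - 0.5 * b for a, b in zip(rx, rc)] for rx, rc in zip(x, xxtx)]
    return x

# A scaled rotation: the iteration should recover the pure rotation.
q = newton_schulz_orthogonalize([[0.0, 2.0], [-1.0, 0.0]])
```

The iteration leaves the singular vectors untouched and pushes each nonzero singular value toward 1, which is why the scaling in the input disappears from `q`.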

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Port the inference-time forward pass of the ELO learned optimizer from
the original JAX implementation (ELO_AdafacMLPLOpt) into pure PyTorch.
At inference the ELO expert mechanism is disabled, so the update reduces
to the Adafactor-MLP forward — identical features and meta-model
(MetaMLP, 39 inputs / 2 outputs) to AdafacLO. ELO differs only in using
raw accumulator decays, a warmup-then-constant (optionally cosine) LR
schedule, and the update rule p -= lr * (dir*exp(mag*exp_mult) + wd*p).
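
The update rule above can be written out element-wise. The following is a rough sketch in plain Python rather than the tensor code in `ELO_naive.py`; the function name `elo_update` and all numeric values (including `exp_mult=0.001`) are illustrative assumptions, and `direction` / `magnitude` are presumed to be the two meta-model outputs.

```python
import math

def elo_update(p, direction, magnitude, lr, exp_mult, weight_decay):
    """Apply p -= lr * (dir * exp(mag * exp_mult) + wd * p) element-wise."""
    return [
        pi - lr * (d * math.exp(m * exp_mult) + weight_decay * pi)
        for pi, d, m in zip(p, direction, magnitude)
    ]

new_p = elo_update(
    p=[1.0, -2.0],
    direction=[0.5, -0.5],  # illustrative values, not real MLP outputs
    magnitude=[0.0, 0.0],   # exp(0) = 1, so only the direction scales the step
    lr=0.1,
    exp_mult=0.001,         # assumed value for illustration
    weight_decay=0.1,
)
```

Note that the weight-decay term sits inside the parentheses, so it is scaled by `lr` together with the learned step.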

- pylo/optim/ELO_naive.py: ELO_naive optimizer (reuses MetaMLP and the
  AdafacLO feature helpers); default meta-model DiamondXL/elo.
- scripts/convert_elo_checkpoint.py: convert a JAX/Haiku ELO theta into a
  MetaMLP state_dict (transposes the dense weights).
- tests/test_elo.py: step/update, state-dict resume, and a JAX
  numerical-alignment test (matches the reference to ~1.5e-8;
  auto-skips when the JAX source is unavailable).
- Register ELO_naive / ELO in the pylo and pylo.optim inits.
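
The transpose in the conversion script follows from the two layout conventions: Haiku's `Linear` stores `w` as (in_features, out_features), while `torch.nn.Linear` stores `weight` as (out_features, in_features). A minimal sketch of that step using nested lists, assuming hypothetical key names (`"w"`, `"b"`, `"layers.0"`) that may not match the real checkpoint or `MetaMLP`:

```python
def convert_dense_layer(haiku_layer, torch_prefix):
    """Map one Haiku-style dense layer to PyTorch-style state_dict entries.

    haiku_layer["w"] has shape (in_features, out_features);
    haiku_layer["b"] has shape (out_features,).
    """
    w = haiku_layer["w"]
    return {
        # (in, out) -> (out, in)
        f"{torch_prefix}.weight": [list(col) for col in zip(*w)],
        # The bias keeps its shape.
        f"{torch_prefix}.bias": list(haiku_layer["b"]),
    }

# Toy layer with 2 inputs and 3 outputs.
theta = {"w": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], "b": [0.1, 0.2, 0.3]}
state = convert_dense_layer(theta, "layers.0")
```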

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ests

End-to-end comparison against the real JAX optimizers revealed that the
standalone Celo2LOpt drives its LR schedule through an optax chain whose
step count starts at 0 (so the first update uses schedule(0)), whereas
ELO_Celo2LOpt evaluates the schedule at iteration+1. CELO2_naive was
1-indexed (matching ELO-CELO2), leaving a ~1.8e-3 warmup-phase
discrepancy vs Celo2LOpt.

- Add a per-class LR-schedule offset: CELO2_naive is 0-indexed (matches
  Celo2LOpt), ELO_CELO2_naive overrides to 1-indexed (matches
  ELO_Celo2LOpt). AdamW bias correction stays 1-indexed in both.
- Add test_jax_end_to_end_alignment: drives the real Celo2LOpt /
  ELO_Celo2LOpt over a multi-step trajectory with a 2D weight + 1D bias,
  nonzero weight decay and enabled gradient clipping, exercising the
  full step() (1D AdamW, schedule, weight decay, global-norm clipping).
  Both match the reference to ~6e-8 (previously only the 2D core was verified).
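
The effect of the per-class schedule offset can be seen with a toy schedule. This sketch assumes a linear warmup followed by cosine decay; the function names, the `schedule_offset` parameter, and the hyper-parameter values are illustrative and are not taken from the pylo source.

```python
import math

def warmup_cosine(step, peak_lr, warmup_steps, total_steps, min_lr=0.0):
    """Linear warmup to peak_lr, then cosine decay to min_lr (assumed shape)."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

def lr_for_iteration(iteration, schedule_offset, **schedule_kwargs):
    """iteration counts completed updates, starting at 0."""
    return warmup_cosine(iteration + schedule_offset, **schedule_kwargs)

kwargs = dict(peak_lr=1e-3, warmup_steps=100, total_steps=1000)
# 0-indexed: the first update uses schedule(0), as described for Celo2LOpt.
lr_zero_indexed = lr_for_iteration(0, schedule_offset=0, **kwargs)
# 1-indexed: the first update uses schedule(1), as described for ELO_Celo2LOpt.
lr_one_indexed = lr_for_iteration(0, schedule_offset=1, **kwargs)
```

Under this assumed schedule the 0-indexed variant takes its first step with a learning rate of zero, while the 1-indexed variant is one warmup step ahead throughout, which is the kind of warmup-phase gap the commit describes.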

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Keep the existing VeLO_CUDA Quick Start intact and append an ELO-CELO2
example plus an "ELO series" entry with the arXiv link. Pure additions
(no content removed from the xiao_elo README).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Merge the CUDA path from the ELO-torch line into xiao_elo (additive; the
existing naive optimizers and VeLO/AdafacLO kernels are untouched):

- pylo/csrc/celo2_kernel.cu: fused feature-construction + split-input MLP
  forward kernel for CELO2 / ELO-CELO2
- pylo/optim/CELO2_cuda.py, ELO_CELO2_cuda.py: Python wrappers
- tests/test_celo2_cuda.py: CUDA-vs-naive numerical alignment tests
- setup.py: register the celo2_cuda_kernel CUDAExtension
- pylo/optim/__init__.py: override CELO2 / ELO_CELO2 to the CUDA variants
  when the extension is available, falling back to naive otherwise
- .gitignore: ignore *.pickle LO checkpoints
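
The fallback described for `pylo/optim/__init__.py` is commonly written as a guarded import. A minimal sketch of the pattern, assuming the compiled module is importable as `celo2_cuda_kernel` and using strings as hypothetical stand-ins for the optimizer classes; the real init may be structured differently:

```python
import importlib

def select_backend(extension_module):
    """Return "cuda" if the compiled extension imports, else "naive"."""
    try:
        importlib.import_module(extension_module)
    except ImportError:
        return "naive"
    return "cuda"

# Hypothetical registry standing in for the actual optimizer classes.
IMPLEMENTATIONS = {"naive": "CELO2_naive", "cuda": "CELO2_cuda"}
CELO2 = IMPLEMENTATIONS[select_backend("celo2_cuda_kernel")]
```

With this shape, user code can always refer to `CELO2` and gets the fused kernel only on machines where the extension was built.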

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@xiaol827 xiaol827 requested a review from Pauljanson002 June 18, 2026 21:27
@xiaol827 xiaol827 closed this Jun 18, 2026