Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions emerging_optimizers/soap/soap.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

__all__ = [
"SOAP",
"StackedSoap",
"precondition",
"init_kronecker_factors",
"update_kronecker_factors",
Expand Down Expand Up @@ -584,3 +585,133 @@ def _clip_update_rms_in_place(u: torch.Tensor, max_rms: float, eps: float = 1e-7
scale = (max_rms / (rms + eps)).clamp(max=1.0)
# in‐place scale
u.mul_(scale)


def _stack_2d(x: torch.Tensor) -> torch.Tensor:
"""Flattens a 2D or 3D tensor to 2D, merging the batch dim into the smaller matrix edge.

A 2D tensor is returned unchanged. A 3D tensor ``(b, m, n)`` is merged into the smaller of its two
matrix edges: ``(m, b * n)`` when ``n <= m``, otherwise ``(b * m, n)``.

Args:
x: A 2D matrix ``(m, n)`` or a 3D batched matrix ``(b, m, n)``.

Returns:
The 2D stacking of ``x``.
"""
if x.ndim == 2:
return x
b, m, n = x.shape
if n <= m:
# -> (m, b*n): move the batch next to the smaller edge, then merge.
out = x.permute(1, 0, 2).reshape(m, b * n)
else:
# -> (b*m, n): contiguous merge into rows.
out = x.reshape(b * m, n)
return out.contiguous()


def _unstack(u: torch.Tensor, shape: torch.Size) -> torch.Tensor:
"""Inverse of :func:`_stack_2d`, restoring the original ``shape``."""
if len(shape) == 2:
return u
b, m, n = shape
if n <= m:
return u.reshape(m, b, n).permute(1, 0, 2).reshape(shape)
return u.reshape(shape)


@registry.register_optimizer("stacked_soap")
class StackedSoap(SOAP):
"""Limited-memory SOAP for batched / 3D parameters via transient 2D stacking.

Optimizes the real parameters directly: ``self.param_groups``, ``self.state``, and gradients are all
keyed by the user's parameters, so learning-rate schedulers, gradient clipping, and ``state_dict``
behave exactly as for plain :class:`SOAP`. Each 3D parameter is flattened to 2D by merging its batch
dim into the smaller matrix edge (see :func:`_stack_2d`) only for the duration of :meth:`step`: the
parameter's ``data`` and ``grad`` are swapped to their 2D views, the inherited SOAP step runs, and the
2D update is unstacked back into the original storage. Because the swap happens before the inherited
step, its lazy state initialization sizes the optimizer state to the stacked 2D shape automatically.

Stacking on the smaller edge keeps both Kronecker factors small (the larger edge becomes a single
shared factor) while reusing the full, unmodified SOAP machinery (KL-Shampoo + QR eigenbasis). The
stacking is a storage-sharing view except for the permute branch (``q <= p``), which allocates one
transient 2D buffer per step. A plain 2D parameter is stacked as itself, so this is exactly stock SOAP.

SOAP is configured with the fixed settings appropriate for this use: decoupled weight decay, no
Nesterov, bias correction on, the QR eigenbasis path with 1 power-iteration step, KL-Shampoo on, and
the default matmul precision.

Args:
params: Iterable of 2D or 3D parameters to optimize or dicts defining parameter groups.
lr: The learning rate.
betas: Inner Adam betas ``(b1, b2)``.
shampoo_beta: Beta for the kronecker factor moving average.
eps: Inner Adam epsilon.
weight_decay: Decoupled weight decay coefficient.
"""

def __init__(
self,
params: ParamsT,
lr: float,
betas: tuple[float, float] = (0.9, 0.95),
shampoo_beta: float = 0.95,
eps: float = 1e-8,
weight_decay: float = 0.01,
) -> None:
super().__init__(
params,
lr,
betas=betas,
shampoo_beta=shampoo_beta,
eps=eps,
weight_decay=weight_decay,
weight_decay_method="decoupled",
nesterov=False,
correct_bias=True,
use_eigh=False,
power_iter_steps=1,
use_kl_shampoo=True,
)

if TYPE_CHECKING:

@overload
def step(self, closure: None = ...) -> None: ...

@overload
def step(self, closure: Callable[[], float]) -> float: ...

@torch.no_grad() # type: ignore[misc]
@override
def step(self, closure: Callable[[], float] | None = None) -> float | None:
if closure is not None:
raise ValueError("closure is not supported")

# Swap each parameter's data/grad to their 2D stacking, run the inherited SOAP step on the 2D
# views (state is keyed by the real parameter and sized for the stacked shape), then unstack the
# update back into the original storage. The restore runs in a finally so that an exception inside
# super().step() (e.g. OOM, a NaN check) cannot leave parameters stuck in their 2D stacked shape.
saved: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = []
try:
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue # pragma: no cover
data, grad = p.data, p.grad
saved.append((p, data, grad))
p.data = _stack_2d(data)
p.grad = _stack_2d(grad)

super().step()
finally:
for p, data, grad in saved:
stacked = p.data
p.data = data
p.grad = grad
# Copy back only when stacking allocated an independent buffer (permute branch); the view
# branches already wrote the update through to the original storage.
if stacked.data_ptr() != data.data_ptr():
data.copy_(_unstack(stacked, data.shape))
return None
79 changes: 79 additions & 0 deletions examples/stacked_soap_grouped_linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os


os.environ["NVTE_GROUPED_LINEAR_SINGLE_PARAM"] = "1"

import torch
import transformer_engine.pytorch as te
from absl import app, flags

from emerging_optimizers.soap.soap import StackedSoap


FLAGS = flags.FLAGS

flags.DEFINE_integer("num_experts", 8, "Number of experts (grouped GEMMs).")
flags.DEFINE_integer("in_features", 512, "Input feature dimension per expert.")
flags.DEFINE_integer("out_features", 1024, "Output feature dimension per expert.")
flags.DEFINE_integer("tokens", 256, "Total number of tokens routed across experts.")
flags.DEFINE_integer("steps", 5, "Number of optimization steps.")
flags.DEFINE_float("lr", 1e-3, "Learning rate.")


def main(argv: list[str]) -> None:
"""Build a Transformer Engine GroupedLinear with a single 3D weight and train it with StackedSoap."""
del argv
if not torch.cuda.is_available():
raise RuntimeError("This example requires a CUDA device (Transformer Engine is GPU-only).")

device = torch.device("cuda")
dtype = torch.bfloat16

grouped_linear = te.GroupedLinear(
FLAGS.num_experts,
FLAGS.in_features,
FLAGS.out_features,
bias=False,
single_grouped_weight=True,
params_dtype=dtype,
device=device,
)

weight = grouped_linear.weight
print(f"Single expert weight tensor: shape={tuple(weight.shape)}, dtype={weight.dtype}")

optimizer = StackedSoap(grouped_linear.parameters(), lr=FLAGS.lr, weight_decay=0.0)

# MoE routing: split `tokens` rows across experts; m_splits must sum to the token count.
base = FLAGS.tokens // FLAGS.num_experts
m_splits = [base] * FLAGS.num_experts
m_splits[-1] += FLAGS.tokens - sum(m_splits)

for step in range(FLAGS.steps):
x = torch.randn(FLAGS.tokens, FLAGS.in_features, device=device, dtype=dtype, requires_grad=True)
out = grouped_linear(x, m_splits)
loss = out.float().square().mean()

optimizer.zero_grad()
loss.backward()
optimizer.step()

print(f"step {step}: loss={loss.item():.6f}")


if __name__ == "__main__":
app.run(main)
97 changes: 96 additions & 1 deletion tests/test_soap.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from absl.testing import absltest, parameterized

from emerging_optimizers.soap import REKLS, SOAP, soap
from emerging_optimizers.soap.soap import _clip_update_rms_in_place
from emerging_optimizers.soap.soap import StackedSoap, _clip_update_rms_in_place, _stack_2d, _unstack


flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on")
Expand Down Expand Up @@ -578,5 +578,100 @@ def test_eigenbasis_matches_reference(self, shape: tuple, num_steps: int):
self.assertEqual(test_state["step"], ref_state["step"])


class StackedSoapTest(parameterized.TestCase):
def setUp(self):
self.device = FLAGS.device

@parameterized.product(shape=[(8, 5), (4, 6, 3), (4, 3, 6)])
def test_smoke(self, shape) -> None:
p = torch.nn.Parameter(torch.randn(shape, device=self.device))
opt = StackedSoap([p], lr=1e-2, weight_decay=0.01)
for _ in range(3):
p.grad = torch.randn_like(p)
opt.step()
self.assertTrue(torch.isfinite(p).all())

@parameterized.product(shape=[(8, 5), (4, 6, 3), (4, 3, 6), (4, 5, 5)])
def test_stack_unstack_shapes_and_roundtrip(self, shape) -> None:
x = torch.randn(shape, device=self.device)

if x.ndim == 2:
expected_2d = shape
else:
b, m, n = shape
expected_2d = (m, b * n) if n <= m else (b * m, n)

stacked = _stack_2d(x)
self.assertEqual(stacked.shape, torch.Size(expected_2d))

restored = _unstack(stacked, x.shape)
self.assertEqual(restored.shape, x.shape)
assert_equal(restored, x)

@parameterized.product(shape=[(8, 5), (16, 16), (5, 7)])
def test_2d_input_7steps_matches_vanilla_soap(self, shape) -> None:
x = torch.randn(shape, device=self.device)
p_stacked = torch.nn.Parameter(x.clone())
p_ref = torch.nn.Parameter(x.clone())

opt_stacked = StackedSoap([p_stacked], lr=1e-2, weight_decay=0.01)
opt_ref = SOAP(
[p_ref],
1e-2,
weight_decay=0.01,
weight_decay_method="decoupled",
nesterov=False,
correct_bias=True,
use_eigh=False,
power_iter_steps=1,
use_kl_shampoo=True,
)

for _ in range(7):
grad = torch.randn(shape, device=self.device)
p_stacked.grad = grad.clone()
p_ref.grad = grad.clone()
opt_stacked.step()
opt_ref.step()
assert_equal(
p_stacked.detach(),
p_ref.detach(),
msg=lambda m: f"StackedSoap must match stock SOAP exactly on 2D params.\n\n{m}",
)

@parameterized.product(shape=[(4, 6, 3), (4, 3, 6)])
def test_3d_input_5steps_matches_vanilla_soap(self, shape) -> None:
"""StackedSoap on a 3D param must match vanilla SOAP run on the manually stacked 2D param."""
x = torch.randn(shape, device=self.device)
p_stacked = torch.nn.Parameter(x.clone())
# Reference is vanilla SOAP on the 2D stacking of the same parameter.
p_ref = torch.nn.Parameter(_stack_2d(x).clone())

opt_stacked = StackedSoap([p_stacked], lr=1e-2, weight_decay=0.01)
opt_ref = SOAP(
[p_ref],
1e-2,
weight_decay=0.01,
weight_decay_method="decoupled",
nesterov=False,
correct_bias=True,
use_eigh=False,
power_iter_steps=1,
use_kl_shampoo=True,
)

for _ in range(5):
grad = torch.randn(shape, device=self.device)
p_stacked.grad = grad.clone()
p_ref.grad = _stack_2d(grad)
opt_stacked.step()
opt_ref.step()
assert_equal(
_stack_2d(p_stacked.detach()),
p_ref.detach(),
msg=lambda m: f"StackedSoap on a 3D param must match vanilla SOAP on its 2D stacking.\n\n{m}",
)


if __name__ == "__main__":
absltest.main()
Loading