Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion dion/muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Callable, Generator, List, Optional, Tuple, Union

from .newton_schulz_triton import newton_schulz_triton
from .polar_express import polar_express, polar_express_triton
from .opt_utils import (
AsyncRuntime,
AsyncTask,
Expand Down Expand Up @@ -43,6 +44,7 @@ class Muon(Optimizer):
True: Tensors with 3+ dimensions are flattened to 2D. Use this for convolutional layers.
False: Tensors are not flattened. 3D+ tensors are treated as batches of 2D matrices.
use_triton: Whether to use the Triton kernel variant of the orthogonalization function (Newton-Schulz, or Polar Express when use_polar_express is set). Ignored if custom function is provided.
use_polar_express: Whether to use Polar Express instead of Newton-Schulz. Ignored if custom function is provided.
newton_schulz_func: Use a custom Newton-Schulz function for orthogonalization.
Signature is `func(input: Tensor, epsilon: float) -> Tensor`.

Expand All @@ -64,6 +66,7 @@ def __init__(
adjust_lr: Optional[str] = "spectral_norm",
flatten: bool = False,
use_triton: bool = False,
use_polar_express: bool = False,
newton_schulz_func: Optional[Callable] = None,
):
# Check hyperparameters
Expand Down Expand Up @@ -118,13 +121,17 @@ def __init__(
)
self._distributed_mesh = distributed_mesh

# Newton-Schulz configuration
# Orthogonalization function configuration
if newton_schulz_func is not None:
if not callable(newton_schulz_func):
raise TypeError(
f"newton_schulz_func must be a callable function, got {type(newton_schulz_func)}"
)
self._newton_schulz_func = newton_schulz_func
elif use_polar_express and use_triton:
self._newton_schulz_func = polar_express_triton
elif use_polar_express:
self._newton_schulz_func = polar_express
elif use_triton:
self._newton_schulz_func = newton_schulz_triton
else:
Expand Down
9 changes: 8 additions & 1 deletion dion/normuon.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Callable, Generator, List, Optional, Tuple, Union

from .newton_schulz_triton import newton_schulz_triton
from .polar_express import polar_express, polar_express_triton
from .opt_utils import (
AsyncRuntime,
AsyncTask,
Expand Down Expand Up @@ -54,6 +55,7 @@ class NorMuon(Optimizer):
True: Tensors with 3+ dimensions are flattened to 2D. Use this for convolutional layers.
False: Tensors are not flattened. 3D+ tensors are treated as batches of 2D matrices.
use_triton: Whether to use the Triton kernel variant of the orthogonalization function (Newton-Schulz, or Polar Express when use_polar_express is set). Ignored if custom function is provided.
use_polar_express: Whether to use Polar Express instead of Newton-Schulz. Ignored if custom function is provided.
newton_schulz_func: Use a custom Newton-Schulz function for orthogonalization.
Signature is `func(input: Tensor, epsilon: float) -> Tensor`.

Expand All @@ -77,6 +79,7 @@ def __init__(
adjust_lr: Optional[str] = "rms_norm",
flatten: bool = False,
use_triton: bool = False,
use_polar_express: bool = False,
newton_schulz_func: Optional[Callable] = None,
):
# Check hyperparameters
Expand Down Expand Up @@ -134,13 +137,17 @@ def __init__(
)
self._distributed_mesh = distributed_mesh

# Newton-Schulz configuration
# Orthogonalization function configuration
if newton_schulz_func is not None:
if not callable(newton_schulz_func):
raise TypeError(
f"newton_schulz_func must be a callable function, got {type(newton_schulz_func)}"
)
self._newton_schulz_func = newton_schulz_func
elif use_polar_express and use_triton:
self._newton_schulz_func = polar_express_triton
elif use_polar_express:
self._newton_schulz_func = polar_express
elif use_triton:
self._newton_schulz_func = newton_schulz_triton
else:
Expand Down
85 changes: 85 additions & 0 deletions dion/polar_express.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import torch
from torch import Tensor

from .newton_schulz_triton import ns_line_1, ns_line_2

# Polar Express coefficients (computed for num_iters=5, safety_factor=2e-2, cushion=2)
# From https://arxiv.org/pdf/2505.16932
# Matches the battle-tested KellerJordan/karpathy implementations.
# Safety factor of 1.02 is baked into all but the last polynomial.
_POLAR_EXPRESS_COEFFS = [
(8.156554524902461, -22.48329292557795, 15.878769915207462),
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]


@torch.compile(dynamic=False, fullgraph=True)
def polar_express(G: Tensor, epsilon: float = 1e-6) -> Tensor:
    """
    Orthogonalize ``G`` with the Polar Express iteration (pure PyTorch, compiled).

    Method: "Polar Express", https://arxiv.org/pdf/2505.16932,
    by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.

    The call signature, ``func(input, epsilon) -> Tensor``, matches
    zeropower_via_newtonschulz5.

    Args:
        G: Tensor with at least 2 dimensions. The last two dimensions are the
            matrix; the norm and matmuls below operate on those two axes.
        epsilon: Small constant added to the scaled norm before dividing.

    Returns:
        A bfloat16 tensor with the same shape as ``G``.
    """
    assert G.ndim >= 2
    n_rows, n_cols = G.size(-2), G.size(-1)
    Y = G.bfloat16()

    # Normalize by the norm over the last two dims (an upper bound on the
    # spectral norm), inflated by the same 1.02 safety factor the coefficients
    # were computed with.
    scale = Y.norm(dim=(-2, -1), keepdim=True) * 1.02 + epsilon
    Y = Y / scale

    # Shapes are static under dynamic=False, so the orientation test is the
    # same on every pass. Always form the smaller of the two Gram matrices.
    for a, b, c in _POLAR_EXPRESS_COEFFS:
        if n_rows > n_cols:
            # Tall: Gram matrix is cols x cols, applied from the right.
            gram = Y.mT @ Y
            poly = b * gram + c * (gram @ gram)
            Y = a * Y + Y @ poly
        else:
            # Wide or square: Gram matrix is rows x rows, applied from the left.
            gram = Y @ Y.mT
            poly = b * gram + c * (gram @ gram)
            Y = a * Y + poly @ Y

    return Y


@torch.compile(dynamic=False, fullgraph=True)
def polar_express_triton(G: Tensor, epsilon: float = 1e-6) -> Tensor:
    """
    Orthogonalize ``G`` with Polar Express, using the Triton helper kernels.

    The helpers ``ns_line_1`` / ``ns_line_2`` exploit the symmetry of
    A = X @ X.mT and B = c*(A@A) + b*A, computing only half the blocks and
    mirroring across the diagonal.

    The call signature, ``func(input, epsilon) -> Tensor``, matches
    zeropower_via_newtonschulz5.

    Args:
        G: Input tensor; the last two dimensions are the matrix.
        epsilon: Small constant added to the scaled norm before dividing.

    Returns:
        A bfloat16 tensor with the same shape as ``G``. For tall inputs it is
        a transposed view of the internal working buffer.
    """
    # Tall inputs are processed transposed so the Gram matrix is the small one.
    transposed = G.size(-2) > G.size(-1)

    Y = G.to(dtype=torch.bfloat16)
    if transposed:
        Y = Y.mT

    # Ensure spectral norm is at most 1, with safety factor matching coefficients
    Y = Y / (Y.norm(dim=(-2, -1), keepdim=True) * 1.02 + epsilon)

    # Work buffers are allocated once and reused on every iteration.
    Y = Y.contiguous()
    rows = Y.size(-2)
    gram = torch.empty((*Y.shape[:-2], rows, rows), device=Y.device, dtype=Y.dtype)
    poly = torch.empty_like(gram)
    scratch = torch.empty_like(Y)

    # Batched inputs go through baddbmm; a single 2D matrix goes through addmm.
    if Y.ndim > 2:
        fused_update = torch.baddbmm
    else:
        fused_update = torch.addmm

    for a, b, c in _POLAR_EXPRESS_COEFFS:
        ns_line_1(Y, out=gram)  # gram = Y @ Y.mT
        ns_line_2(gram, alpha=c, beta=b, out=poly)  # poly = c*(gram@gram) + b*gram
        fused_update(Y, poly, Y, beta=a, out=scratch)  # scratch = a*Y + poly@Y
        # Ping-pong the two buffers instead of copying the result back.
        Y, scratch = scratch, Y

    return Y.mT if transposed else Y