diff --git a/dion/muon.py b/dion/muon.py index 92f444d..a27e1c6 100644 --- a/dion/muon.py +++ b/dion/muon.py @@ -9,6 +9,7 @@ from typing import Callable, Generator, List, Optional, Tuple, Union from .newton_schulz_triton import newton_schulz_triton +from .polar_express import polar_express, polar_express_triton from .opt_utils import ( AsyncRuntime, AsyncTask, @@ -43,6 +44,7 @@ class Muon(Optimizer): True: Tensors with 3+ dimensions are flattened to 2D. Use this for convolutional layers. False: Tensors are not flattened. 3D+ tensors are treated as batches of 2D matrices. use_triton: Whether to use Triton kernel for Newton-Schulz. Ignored if custom function is provided. + use_polar_express: Whether to use Polar Express instead of Newton-Schulz. Ignored if custom function is provided. newton_schulz_func: Use a custom Newton-Schulz function for orthogonalization. Signature is `func(input: Tensor, epsilon: float) -> Tensor`. @@ -64,6 +66,7 @@ def __init__( adjust_lr: Optional[str] = "spectral_norm", flatten: bool = False, use_triton: bool = False, + use_polar_express: bool = False, newton_schulz_func: Optional[Callable] = None, ): # Check hyperparameters @@ -118,13 +121,17 @@ def __init__( ) self._distributed_mesh = distributed_mesh - # Newton-Schulz configuration + # Orthogonalization function configuration if newton_schulz_func is not None: if not callable(newton_schulz_func): raise TypeError( f"newton_schulz_func must be a callable function, got {type(newton_schulz_func)}" ) self._newton_schulz_func = newton_schulz_func + elif use_polar_express and use_triton: + self._newton_schulz_func = polar_express_triton + elif use_polar_express: + self._newton_schulz_func = polar_express elif use_triton: self._newton_schulz_func = newton_schulz_triton else: diff --git a/dion/normuon.py b/dion/normuon.py index e2a2479..25675ed 100644 --- a/dion/normuon.py +++ b/dion/normuon.py @@ -9,6 +9,7 @@ from typing import Callable, Generator, List, Optional, Tuple, Union from .newton_schulz_triton import newton_schulz_triton +from .polar_express import polar_express, polar_express_triton from .opt_utils import ( AsyncRuntime, AsyncTask, @@ -54,6 +55,7 @@ class NorMuon(Optimizer): True: Tensors with 3+ dimensions are flattened to 2D. Use this for convolutional layers. False: Tensors are not flattened. 3D+ tensors are treated as batches of 2D matrices. use_triton: Whether to use Triton kernel for Newton-Schulz. Ignored if custom function is provided. + use_polar_express: Whether to use Polar Express instead of Newton-Schulz. Ignored if custom function is provided. newton_schulz_func: Use a custom Newton-Schulz function for orthogonalization. Signature is `func(input: Tensor, epsilon: float) -> Tensor`. @@ -77,6 +79,7 @@ def __init__( adjust_lr: Optional[str] = "rms_norm", flatten: bool = False, use_triton: bool = False, + use_polar_express: bool = False, newton_schulz_func: Optional[Callable] = None, ): # Check hyperparameters @@ -134,13 +137,17 @@ def __init__( ) self._distributed_mesh = distributed_mesh - # Newton-Schulz configuration + # Orthogonalization function configuration if newton_schulz_func is not None: if not callable(newton_schulz_func): raise TypeError( f"newton_schulz_func must be a callable function, got {type(newton_schulz_func)}" ) self._newton_schulz_func = newton_schulz_func + elif use_polar_express and use_triton: + self._newton_schulz_func = polar_express_triton + elif use_polar_express: + self._newton_schulz_func = polar_express elif use_triton: self._newton_schulz_func = newton_schulz_triton else: diff --git a/dion/polar_express.py b/dion/polar_express.py new file mode 100644 index 0000000..0f233ed --- /dev/null +++ b/dion/polar_express.py @@ -0,0 +1,85 @@ +import torch +from torch import Tensor + +from .newton_schulz_triton import ns_line_1, ns_line_2 + +# Polar Express coefficients (computed for num_iters=5, safety_factor=2e-2, cushion=2) +# From https://arxiv.org/pdf/2505.16932 +# Matches the battle-tested KellerJordan/karpathy implementations. +# Safety factor of 1.02 is baked into all but the last polynomial. +_POLAR_EXPRESS_COEFFS = [ + (8.156554524902461, -22.48329292557795, 15.878769915207462), + (4.042929935166739, -2.808917465908714, 0.5000178451051316), + (3.8916678022926607, -2.772484153217685, 0.5060648178503393), + (3.285753657755655, -2.3681294933425376, 0.46449024233003106), + (2.3465413258596377, -1.7097828382687081, 0.42323551169305323), +] + + +@torch.compile(dynamic=False, fullgraph=True) +def polar_express(G: Tensor, epsilon: float = 1e-6) -> Tensor: + """ + Polar Express orthogonalization (pure PyTorch, torch.compile'd). + + Polar Express: https://arxiv.org/pdf/2505.16932 + by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. + + Signature matches zeropower_via_newtonschulz5: func(input, epsilon) -> Tensor. + """ + assert G.ndim >= 2 + X = G.bfloat16() + + is_tall = G.size(-2) > G.size(-1) + + # Ensure spectral norm is at most 1, with safety factor matching coefficients + X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + epsilon) + + if is_tall: + # Tall: use X.mT @ X (small cols x cols) + right-multiply + for a, b, c in _POLAR_EXPRESS_COEFFS: + A = X.mT @ X + B = b * A + c * (A @ A) + X = a * X + X @ B + else: + # Wide: use X @ X.mT (small rows x rows) + left-multiply + for a, b, c in _POLAR_EXPRESS_COEFFS: + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + + return X + + +@torch.compile(dynamic=False, fullgraph=True) +def polar_express_triton(G: Tensor, epsilon: float = 1e-6) -> Tensor: + """ + Polar Express orthogonalization using Triton kernels that exploit + the symmetry of A = X @ X.mT and B = c*(A@A) + b*A, computing only + half the blocks and mirroring across the diagonal. + + Signature matches zeropower_via_newtonschulz5: func(input, epsilon) -> Tensor. + """ + X = G.to(dtype=torch.bfloat16) + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1, with safety factor matching coefficients + X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + epsilon) + + # Pre-allocate buffers for the symmetric intermediates and output + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + for a, b, c in _POLAR_EXPRESS_COEFFS: + ns_line_1(X, out=A) # A = X @ X.mT (symmetric, half-compute) + ns_line_2(A, alpha=c, beta=b, out=B) # B = c*(A@A) + b*A (symmetric, half-compute) + line_3(X, B, X, beta=a, out=C) # C = a*X + B@X + X, C = C, X + + if G.size(-2) > G.size(-1): + X = X.mT + return X