Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 74 additions & 3 deletions dion/muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,20 @@ class Muon(Optimizer):
use_triton: Whether to use Triton kernel for Newton-Schulz. Ignored if custom function is provided.
newton_schulz_func: Use a custom Newton-Schulz function for orthogonalization.
Signature is `func(input: Tensor, epsilon: float) -> Tensor`.
skip_update_prob: SkipUpdate survival probability p ∈ (0, 1].
At each step, each parameter matrix is independently kept with probability p and
skipped (zeroed out) with probability (1-p). Surviving updates are rescaled by
1/p to keep the update unbiased in expectation. Moment buffers always update
densely regardless of skip. None (default) disables SkipUpdate.
See: https://arxiv.org/abs/2602.15322
magma_tau: Magma temperature τ > 0. When set, enables Magma mode which replaces the
fixed 1/p rescaling with an adaptive EMA scale driven by momentum-gradient alignment:
ẽ_t = sigmoid(cossim(momentum_before, grad) / τ)
s_t = 0.9 * s_{t-1} + 0.1 * ẽ_t
The surviving update is scaled by s_t instead of 1/p. This is intentionally biased
(no 1/s_t correction) as unbiased variants were found to be unstable.
Requires skip_update_prob to also be set. None (default) uses plain SkipUpdate scaling.
See: https://arxiv.org/abs/2602.15322

Muon optimizer algorithm by Keller Jordan: https://kellerjordan.github.io/posts/muon/
FSDP2 Muon uses all-to-all communications: https://www.essential.ai/blog/infra
Expand All @@ -65,6 +79,8 @@ def __init__(
flatten: bool = False,
use_triton: bool = False,
newton_schulz_func: Optional[Callable] = None,
skip_update_prob: Optional[float] = None,
magma_tau: Optional[float] = None,
):
# Check hyperparameters
if lr < 0.0:
Expand All @@ -77,6 +93,13 @@ def __init__(
raise ValueError(
f"Invalid adjust_lr value: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or None."
)
# SkipUpdate / Magma: validate parameters
if skip_update_prob is not None and not (0.0 < skip_update_prob <= 1.0):
raise ValueError(f"skip_update_prob must be in (0, 1], got {skip_update_prob}")
if magma_tau is not None and magma_tau <= 0.0:
raise ValueError(f"magma_tau must be > 0, got {magma_tau}")
if magma_tau is not None and skip_update_prob is None:
raise ValueError("magma_tau requires skip_update_prob to be set")

# Default arguments for each param group
defaults = dict(
Expand All @@ -92,6 +115,8 @@ def __init__(
nesterov=nesterov,
flatten=flatten,
adjust_lr=adjust_lr,
skip_update_prob=skip_update_prob, # SkipUpdate: survival prob (None = disabled)
magma_tau=magma_tau, # Magma: temperature for adaptive scaling (None = disabled)
)
super().__init__(params, defaults)

Expand Down Expand Up @@ -180,6 +205,9 @@ def _get_or_initialize_state(self, param: Tensor, algo: str) -> dict:
state["momentum"] = torch.zeros_like(param)
if algo == "adamw":
state["variance"] = torch.zeros_like(param)
if algo == "muon":
# Magma: per-param EMA scale, init=0.5 (neutral alignment)
state["magma_scale"] = torch.tensor(0.5, device=param.device, dtype=param.dtype)
return state

def _create_muon_tasks(
Expand Down Expand Up @@ -215,6 +243,8 @@ def _create_muon_tasks(
process_group=self._process_group,
newton_schulz_func=self._newton_schulz_func,
cautious_wd=group["cautious_wd"],
skip_update_prob=group["skip_update_prob"], # SkipUpdate: survival probability
magma_tau=group["magma_tau"], # Magma: temperature (None = plain SkipUpdate)
)

# Create batches of parameters of size self._world_size
Expand All @@ -224,6 +254,7 @@ def _create_muon_tasks(
gradients = [p.grad for p in params]
states = [self._get_or_initialize_state(p, algo_name) for p in params]
momentums = [s["momentum"] for s in states]
magma_scales = [s["magma_scale"] for s in states] # Magma EMA scale per param

# Get sharding state for DTensor
is_batch_sharded = False
Expand Down Expand Up @@ -283,12 +314,13 @@ def _create_muon_tasks(
# As long as matrix dimensions are not sharded, each device will have whole matrices
# Each device already has different matrices of the batch, so we can't parallelize further
if is_batch_sharded and not is_matrix_sharded:
for x, g, m in zip(params, gradients, momentums):
for x, g, m, s in zip(params, gradients, momentums, magma_scales):
yield AsyncTask(
muon_update_batch_async(
X=[x],
G=[g],
M=[m],
S=[s],
shard_dim=None, # No sharded matrix dim
**muon_update_args,
)
Expand All @@ -300,6 +332,7 @@ def _create_muon_tasks(
X=pad_batch(params, self._world_size),
G=pad_batch(gradients, self._world_size),
M=pad_batch(momentums, self._world_size),
S=pad_batch(magma_scales, self._world_size),
shard_dim=sharded_tensor_dim,
**muon_update_args,
)
Expand Down Expand Up @@ -394,6 +427,7 @@ def muon_update_batch_async(
X: List[Tensor], # Model weights (modified in place)
G: List[Tensor], # Gradient
M: List[Tensor], # Momentum buffer (modified in place)
S: List[Tensor], # Magma EMA scale buffer, scalar per param (modified in place)
lr: Tensor, # Learning rate (scalar tensor)
momentum: Tensor, # Momentum factor (scalar tensor)
weight_decay: Tensor, # Weight decay (scalar tensor)
Expand All @@ -407,6 +441,8 @@ def muon_update_batch_async(
process_group: Optional[ProcessGroup] = None,
newton_schulz_func: Optional[Callable] = None,
cautious_wd: bool = False,
skip_update_prob: Optional[float] = None, # SkipUpdate: survival probability (None = disabled)
magma_tau: Optional[float] = None, # Magma: temperature for adaptive scaling (None = disabled)
) -> Generator[None, None, None]:
"""
Batched version of the Muon update. Batch size should be equal to the number of GPUs.
Expand All @@ -417,10 +453,17 @@ def muon_update_batch_async(
assert len(X) == len(G)
assert len(X) == len(M)

# Magma: snapshot momentum before it's updated, for cosine similarity with current grad.
# muon_update_pre_orthogonalize updates M in-place, so we must clone beforehand.
G_local = to_local(G)
M_local = to_local(M)
if magma_tau is not None:
M_before = [m.clone() for m in M_local]

# Update momentum and compute the inputs for orthogonalization
U = muon_update_pre_orthogonalize(
G=to_local(G),
M=to_local(M),
G=G_local,
M=M_local,
momentum=momentum,
nesterov=nesterov,
)
Expand Down Expand Up @@ -510,6 +553,34 @@ def muon_update_batch_async(
epsilon=epsilon,
)

# SkipUpdate / Magma: stochastic block masking per parameter matrix.
# Moments always update densely (above); only the final update direction is masked.
# Reference: "On Surprising Effectiveness of Masking Updates in Adaptive Optimizers".
if skip_update_prob is not None and skip_update_prob < 1.0:
# Normalize U to a list of per-parameter tensors for the masking below;
# muon_update_newton_schulz may return a single batched Tensor rather than a list.
U = list(U) if not isinstance(U, list) else U
S_local = to_local(S)

for i in range(len(U)):
# Sample one Bernoulli scalar per parameter block (not per element)
keep = torch.bernoulli(torch.tensor(skip_update_prob, device=U[i].device))

if magma_tau is not None:
# Magma: adaptive scale via momentum-gradient cosine similarity.
# ẽ_t = sigmoid(cossim(μ_t_before, g_t) / τ)
# s_t = 0.9 * s_{t-1} + 0.1 * ẽ_t (EMA, updated in-place)
mu = M_before[i].flatten().float()
g = G_local[i].flatten().float()
cos = torch.dot(mu, g) / (mu.norm() * g.norm() + 1e-8)
e_tilde = torch.sigmoid(cos / magma_tau)
S_local[i].mul_(0.9).add_(e_tilde * 0.1) # EMA update in-place
scale = S_local[i]
else:
# Plain SkipUpdate: fixed unbiasing rescale of 1/p
scale = 1.0 / skip_update_prob

U[i] = U[i] * (keep * scale) # zero-out or scale entire matrix

# Compute scaled learning rate
# Do this before to_local(X) because we use the full tensor shape, not the shard shape
if adjust_lr is None:
Expand Down
79 changes: 75 additions & 4 deletions dion/normuon.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,21 @@ class NorMuon(Optimizer):
use_triton: Whether to use Triton kernel for Newton-Schulz. Ignored if custom function is provided.
newton_schulz_func: Use a custom Newton-Schulz function for orthogonalization.
Signature is `func(input: Tensor, epsilon: float) -> Tensor`.
skip_update_prob: SkipUpdate survival probability p ∈ (0, 1].
At each step, each parameter matrix is independently kept with probability p and
skipped (zeroed out) with probability (1-p). Surviving updates are rescaled by
1/p to keep the update unbiased in expectation. Moment buffers (momentum and
variance_neuron) always update densely regardless of skip.
None (default) disables SkipUpdate.
See: https://arxiv.org/abs/2602.15322
magma_tau: Magma temperature τ > 0. When set, enables Magma mode which replaces the
fixed 1/p rescaling with an adaptive EMA scale driven by momentum-gradient alignment:
ẽ_t = sigmoid(cossim(momentum_before, grad) / τ)
s_t = 0.9 * s_{t-1} + 0.1 * ẽ_t
The surviving update is scaled by s_t instead of 1/p. This is intentionally biased
(no 1/s_t correction) as unbiased variants were found to be unstable.
Requires skip_update_prob to also be set. None (default) uses plain SkipUpdate scaling.
See: https://arxiv.org/abs/2602.15322

Muon optimizer algorithm by Keller Jordan: https://kellerjordan.github.io/posts/muon/
FSDP2 Muon uses all-to-all communications: https://www.essential.ai/blog/infra
Expand All @@ -78,6 +93,8 @@ def __init__(
flatten: bool = False,
use_triton: bool = False,
newton_schulz_func: Optional[Callable] = None,
skip_update_prob: Optional[float] = None,
magma_tau: Optional[float] = None,
):
# Check hyperparameters
if lr < 0.0:
Expand All @@ -92,6 +109,13 @@ def __init__(
raise ValueError(
f"Invalid adjust_lr value: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or None."
)
# SkipUpdate / Magma: validate parameters
if skip_update_prob is not None and not (0.0 < skip_update_prob <= 1.0):
raise ValueError(f"skip_update_prob must be in (0, 1], got {skip_update_prob}")
if magma_tau is not None and magma_tau <= 0.0:
raise ValueError(f"magma_tau must be > 0, got {magma_tau}")
if magma_tau is not None and skip_update_prob is None:
raise ValueError("magma_tau requires skip_update_prob to be set")

# Default arguments for each param group
defaults = dict(
Expand All @@ -108,6 +132,8 @@ def __init__(
nesterov=nesterov,
flatten=flatten,
adjust_lr=adjust_lr,
skip_update_prob=skip_update_prob, # SkipUpdate: survival prob (None = disabled)
magma_tau=magma_tau, # Magma: temperature for adaptive scaling (None = disabled)
)
super().__init__(params, defaults)

Expand Down Expand Up @@ -198,6 +224,8 @@ def _get_or_initialize_state(self, param: Tensor, algo: str) -> dict:
state["variance"] = torch.zeros_like(param)
if algo == "normuon":
state["variance_neuron"] = torch.zeros_like(param[..., 0:1])
# Magma: per-param EMA scale, init=0.5 (neutral alignment)
state["magma_scale"] = torch.tensor(0.5, device=param.device, dtype=param.dtype)
return state

def _create_normuon_tasks(
Expand Down Expand Up @@ -234,6 +262,8 @@ def _create_normuon_tasks(
process_group=self._process_group,
newton_schulz_func=self._newton_schulz_func,
cautious_wd=group["cautious_wd"],
skip_update_prob=group["skip_update_prob"], # SkipUpdate: survival probability
magma_tau=group["magma_tau"], # Magma: temperature (None = plain SkipUpdate)
)

# Create batches of parameters of size self._world_size
Expand All @@ -244,6 +274,7 @@ def _create_normuon_tasks(
states = [self._get_or_initialize_state(p, algo_name) for p in params]
momentums = [s["momentum"] for s in states]
variances_neuron = [s["variance_neuron"] for s in states]
magma_scales = [s["magma_scale"] for s in states] # Magma EMA scale per param

# Get sharding state for DTensor
is_batch_sharded = False
Expand Down Expand Up @@ -311,15 +342,16 @@ def _create_normuon_tasks(
# As long as matrix dimensions are not sharded, each device will have whole matrices
# Each device already has different matrices of the batch, so we can't parallelize further
if is_batch_sharded and not is_matrix_sharded:
for x, g, m, v in zip(
params, gradients, momentums, variances_neuron
for x, g, m, v, s in zip(
params, gradients, momentums, variances_neuron, magma_scales
):
yield AsyncTask(
normuon_update_batch_async(
X=[x],
G=[g],
M=[m],
V=[v],
S=[s],
shard_dim=None, # No sharded matrix dim
**normuon_update_args,
)
Expand All @@ -332,6 +364,7 @@ def _create_normuon_tasks(
G=pad_batch(gradients, self._world_size),
M=pad_batch(momentums, self._world_size),
V=pad_batch(variances_neuron, self._world_size),
S=pad_batch(magma_scales, self._world_size),
shard_dim=sharded_tensor_dim,
**normuon_update_args,
)
Expand Down Expand Up @@ -427,6 +460,7 @@ def normuon_update_batch_async(
G: List[Tensor], # Gradient
M: List[Tensor], # Momentum buffer (modified in place)
V: List[Tensor], # Variance neuron buffer (modified in place)
S: List[Tensor], # Magma EMA scale buffer, scalar per param (modified in place)
lr: Tensor, # Learning rate (scalar tensor)
momentum: Tensor, # Momentum factor (scalar tensor)
muon_beta2: Tensor, # Muon beta2 for normalization
Expand All @@ -441,6 +475,8 @@ def normuon_update_batch_async(
process_group: Optional[ProcessGroup] = None,
newton_schulz_func: Optional[Callable] = None,
cautious_wd: bool = False,
skip_update_prob: Optional[float] = None, # SkipUpdate: survival probability (None = disabled)
magma_tau: Optional[float] = None, # Magma: temperature for adaptive scaling (None = disabled)
) -> Generator[None, None, None]:
"""
Batched version of the NorMuon update. Batch size should be equal to the number of GPUs.
Expand All @@ -451,10 +487,17 @@ def normuon_update_batch_async(
assert len(X) == len(G)
assert len(X) == len(M)

# Magma: snapshot momentum before it's updated, for cosine similarity with current grad.
# muon_update_pre_orthogonalize updates M in-place, so we must clone beforehand.
G_local = to_local(G)
M_local = to_local(M)
if magma_tau is not None:
M_before = [m.clone() for m in M_local]

# Update momentum and compute the inputs for orthogonalization
U = muon_update_pre_orthogonalize(
G=to_local(G),
M=to_local(M),
G=G_local,
M=M_local,
momentum=momentum,
nesterov=nesterov,
)
Expand Down Expand Up @@ -552,6 +595,34 @@ def normuon_update_batch_async(
muon_beta2=muon_beta2,
)

# SkipUpdate / Magma: stochastic block masking per parameter matrix.
# Moments always update densely (above); only the final update direction is masked.
# Reference: "On Surprising Effectiveness of Masking Updates in Adaptive Optimizers".
if skip_update_prob is not None and skip_update_prob < 1.0:
# normuon_normalization may return a tuple under torch.compile; convert to list
U = list(U)
S_local = to_local(S)

for i in range(len(U)):
# Sample one Bernoulli scalar per parameter block (not per element)
keep = torch.bernoulli(torch.tensor(skip_update_prob, device=U[i].device))

if magma_tau is not None:
# Magma: adaptive scale via momentum-gradient cosine similarity.
# ẽ_t = sigmoid(cossim(μ_t_before, g_t) / τ)
# s_t = 0.9 * s_{t-1} + 0.1 * ẽ_t (EMA, updated in-place)
mu = M_before[i].flatten().float()
g = G_local[i].flatten().float()
cos = torch.dot(mu, g) / (mu.norm() * g.norm() + 1e-8)
e_tilde = torch.sigmoid(cos / magma_tau)
S_local[i].mul_(0.9).add_(e_tilde * 0.1) # EMA update in-place
scale = S_local[i]
else:
# Plain SkipUpdate: fixed unbiasing rescale of 1/p
scale = 1.0 / skip_update_prob

U[i] = U[i] * (keep * scale) # zero-out or scale entire matrix

# Compute scaled learning rate
# Do this before to_local(X) because we use the full tensor shape, not the shard shape
if adjust_lr is None:
Expand Down