diff --git a/dion/muon.py b/dion/muon.py index 92f444d..ad64722 100644 --- a/dion/muon.py +++ b/dion/muon.py @@ -45,6 +45,20 @@ class Muon(Optimizer): use_triton: Whether to use Triton kernel for Newton-Schulz. Ignored if custom function is provided. newton_schulz_func: Use a custom Newton-Schulz function for orthogonalization. Signature is `func(input: Tensor, epsilon: float) -> Tensor`. + skip_update_prob: SkipUpdate survival probability p ∈ (0, 1]. + At each step, each parameter matrix is independently kept with probability p and + skipped (zeroed out) with probability (1-p). Surviving updates are rescaled by + 1/p to keep the update unbiased in expectation. Moment buffers always update + densely regardless of skip. None (default) disables SkipUpdate. + See: https://arxiv.org/abs/2602.15322 + magma_tau: Magma temperature τ > 0. When set, enables Magma mode which replaces the + fixed 1/p rescaling with an adaptive EMA scale driven by momentum-gradient alignment: + ẽ_t = sigmoid(cossim(momentum_before, grad) / τ) + s_t = 0.9 * s_{t-1} + 0.1 * ẽ_t + The surviving update is scaled by s_t instead of 1/p. This is intentionally biased + (no 1/s_t correction) as unbiased variants were found to be unstable. + Requires skip_update_prob to also be set. None (default) uses plain SkipUpdate scaling. + See: https://arxiv.org/abs/2602.15322 Muon optimizer algorithm by Keller Jordan: https://kellerjordan.github.io/posts/muon/ FSDP2 Muon uses all-to-all communications: https://www.essential.ai/blog/infra @@ -65,6 +79,8 @@ def __init__( flatten: bool = False, use_triton: bool = False, newton_schulz_func: Optional[Callable] = None, + skip_update_prob: Optional[float] = None, + magma_tau: Optional[float] = None, ): # Check hyperparameters if lr < 0.0: @@ -77,6 +93,13 @@ def __init__( raise ValueError( f"Invalid adjust_lr value: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or None." 
) + # SkipUpdate / Magma: validate parameters + if skip_update_prob is not None and not (0.0 < skip_update_prob <= 1.0): + raise ValueError(f"skip_update_prob must be in (0, 1], got {skip_update_prob}") + if magma_tau is not None and magma_tau <= 0.0: + raise ValueError(f"magma_tau must be > 0, got {magma_tau}") + if magma_tau is not None and skip_update_prob is None: + raise ValueError("magma_tau requires skip_update_prob to be set") # Default arguments for each param group defaults = dict( @@ -92,6 +115,8 @@ def __init__( nesterov=nesterov, flatten=flatten, adjust_lr=adjust_lr, + skip_update_prob=skip_update_prob, # SkipUpdate: survival prob (None = disabled) + magma_tau=magma_tau, # Magma: temperature for adaptive scaling (None = disabled) ) super().__init__(params, defaults) @@ -180,6 +205,9 @@ def _get_or_initialize_state(self, param: Tensor, algo: str) -> dict: state["momentum"] = torch.zeros_like(param) if algo == "adamw": state["variance"] = torch.zeros_like(param) + if algo == "muon": + # Magma: per-param EMA scale, init=0.5 (neutral alignment) + state["magma_scale"] = torch.tensor(0.5, device=param.device, dtype=param.dtype) return state def _create_muon_tasks( @@ -215,6 +243,8 @@ def _create_muon_tasks( process_group=self._process_group, newton_schulz_func=self._newton_schulz_func, cautious_wd=group["cautious_wd"], + skip_update_prob=group["skip_update_prob"], # SkipUpdate: survival probability + magma_tau=group["magma_tau"], # Magma: temperature (None = plain SkipUpdate) ) # Create batches of parameters of size self._world_size @@ -224,6 +254,7 @@ def _create_muon_tasks( gradients = [p.grad for p in params] states = [self._get_or_initialize_state(p, algo_name) for p in params] momentums = [s["momentum"] for s in states] + magma_scales = [s["magma_scale"] for s in states] # Magma EMA scale per param # Get sharding state for DTensor is_batch_sharded = False @@ -283,12 +314,13 @@ def _create_muon_tasks( # As long as matrix dimensions are not sharded, each 
device will have whole matrices # Each device already has different matrices of the batch, so we can't parallelize further if is_batch_sharded and not is_matrix_sharded: - for x, g, m in zip(params, gradients, momentums): + for x, g, m, s in zip(params, gradients, momentums, magma_scales): yield AsyncTask( muon_update_batch_async( X=[x], G=[g], M=[m], + S=[s], shard_dim=None, # No sharded matrix dim **muon_update_args, ) @@ -300,6 +332,7 @@ def _create_muon_tasks( X=pad_batch(params, self._world_size), G=pad_batch(gradients, self._world_size), M=pad_batch(momentums, self._world_size), + S=pad_batch(magma_scales, self._world_size), shard_dim=sharded_tensor_dim, **muon_update_args, ) @@ -394,6 +427,7 @@ def muon_update_batch_async( X: List[Tensor], # Model weights (modified in place) G: List[Tensor], # Gradient M: List[Tensor], # Momentum buffer (modified in place) + S: List[Tensor], # Magma EMA scale buffer, scalar per param (modified in place) lr: Tensor, # Learning rate (scalar tensor) momentum: Tensor, # Momentum factor (scalar tensor) weight_decay: Tensor, # Weight decay (scalar tensor) @@ -407,6 +441,8 @@ def muon_update_batch_async( process_group: Optional[ProcessGroup] = None, newton_schulz_func: Optional[Callable] = None, cautious_wd: bool = False, + skip_update_prob: Optional[float] = None, # SkipUpdate: survival probability (None = disabled) + magma_tau: Optional[float] = None, # Magma: temperature for adaptive scaling (None = disabled) ) -> Generator[None, None, None]: """ Batched version of Muon update. Batch size should be equal to number of GPUs. @@ -417,10 +453,17 @@ def muon_update_batch_async( assert len(X) == len(G) assert len(X) == len(M) + # Magma: snapshot momentum before it's updated, for cosine similarity with current grad. + # muon_update_pre_orthogonalize updates M in-place, so we must clone beforehand. 
+ G_local = to_local(G) + M_local = to_local(M) + if magma_tau is not None: + M_before = [m.clone() for m in M_local] + # Update momentum and compute the inputs for orthogonalization U = muon_update_pre_orthogonalize( - G=to_local(G), - M=to_local(M), + G=G_local, + M=M_local, momentum=momentum, nesterov=nesterov, ) @@ -510,6 +553,34 @@ epsilon=epsilon, ) + # SkipUpdate / Magma: stochastic block masking per parameter matrix. + # Moments always update densely (above); only the final update direction is masked. + # Reference: "On Surprising Effectiveness of Masking Updates in Adaptive Optimizers". + if skip_update_prob is not None and (skip_update_prob < 1.0 or magma_tau is not None): + # muon_update_newton_schulz returns a Tensor, not a list; U may be a list already + U = list(U) if not isinstance(U, list) else U + S_local = to_local(S) + + for i in range(len(U)): + # Sample one Bernoulli scalar per parameter block (not per element) + keep = torch.bernoulli(torch.tensor(skip_update_prob, device=U[i].device)) + + if magma_tau is not None: + # Magma: adaptive scale via momentum-gradient cosine similarity. 
+ # ẽ_t = sigmoid(cossim(μ_t_before, g_t) / τ) + # s_t = 0.9 * s_{t-1} + 0.1 * ẽ_t (EMA, updated in-place) + mu = M_before[i].flatten().float() + g = G_local[i].flatten().float() + cos = torch.dot(mu, g) / (mu.norm() * g.norm() + 1e-8) + e_tilde = torch.sigmoid(cos / magma_tau) + S_local[i].mul_(0.9).add_(e_tilde * 0.1) # EMA update in-place + scale = S_local[i] + else: + # Plain SkipUpdate: fixed unbiasing rescale of 1/p + scale = 1.0 / skip_update_prob + + U[i] = U[i] * (keep * scale) # zero-out or scale entire matrix + # Compute scaled learning rate # Do this before to_local(X) because we use the full tensor shape, not the shard shape if adjust_lr is None: diff --git a/dion/normuon.py b/dion/normuon.py index e2a2479..28bec7c 100644 --- a/dion/normuon.py +++ b/dion/normuon.py @@ -56,6 +56,21 @@ class NorMuon(Optimizer): use_triton: Whether to use Triton kernel for Newton-Schulz. Ignored if custom function is provided. newton_schulz_func: Use a custom Newton-Schulz function for orthogonalization. Signature is `func(input: Tensor, epsilon: float) -> Tensor`. + skip_update_prob: SkipUpdate survival probability p ∈ (0, 1]. + At each step, each parameter matrix is independently kept with probability p and + skipped (zeroed out) with probability (1-p). Surviving updates are rescaled by + 1/p to keep the update unbiased in expectation. Moment buffers (momentum and + variance_neuron) always update densely regardless of skip. + None (default) disables SkipUpdate. + See: https://arxiv.org/abs/2602.15322 + magma_tau: Magma temperature τ > 0. When set, enables Magma mode which replaces the + fixed 1/p rescaling with an adaptive EMA scale driven by momentum-gradient alignment: + ẽ_t = sigmoid(cossim(momentum_before, grad) / τ) + s_t = 0.9 * s_{t-1} + 0.1 * ẽ_t + The surviving update is scaled by s_t instead of 1/p. This is intentionally biased + (no 1/s_t correction) as unbiased variants were found to be unstable. + Requires skip_update_prob to also be set. 
None (default) uses plain SkipUpdate scaling. + See: https://arxiv.org/abs/2602.15322 Muon optimizer algorithm by Keller Jordan: https://kellerjordan.github.io/posts/muon/ FSDP2 Muon uses all-to-all communications: https://www.essential.ai/blog/infra @@ -78,6 +93,8 @@ def __init__( flatten: bool = False, use_triton: bool = False, newton_schulz_func: Optional[Callable] = None, + skip_update_prob: Optional[float] = None, + magma_tau: Optional[float] = None, ): # Check hyperparameters if lr < 0.0: @@ -92,6 +109,13 @@ def __init__( raise ValueError( f"Invalid adjust_lr value: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or None." ) + # SkipUpdate / Magma: validate parameters + if skip_update_prob is not None and not (0.0 < skip_update_prob <= 1.0): + raise ValueError(f"skip_update_prob must be in (0, 1], got {skip_update_prob}") + if magma_tau is not None and magma_tau <= 0.0: + raise ValueError(f"magma_tau must be > 0, got {magma_tau}") + if magma_tau is not None and skip_update_prob is None: + raise ValueError("magma_tau requires skip_update_prob to be set") # Default arguments for each param group defaults = dict( @@ -108,6 +132,8 @@ def __init__( nesterov=nesterov, flatten=flatten, adjust_lr=adjust_lr, + skip_update_prob=skip_update_prob, # SkipUpdate: survival prob (None = disabled) + magma_tau=magma_tau, # Magma: temperature for adaptive scaling (None = disabled) ) super().__init__(params, defaults) @@ -198,6 +224,8 @@ def _get_or_initialize_state(self, param: Tensor, algo: str) -> dict: state["variance"] = torch.zeros_like(param) if algo == "normuon": state["variance_neuron"] = torch.zeros_like(param[..., 0:1]) + # Magma: per-param EMA scale, init=0.5 (neutral alignment) + state["magma_scale"] = torch.tensor(0.5, device=param.device, dtype=param.dtype) return state def _create_normuon_tasks( @@ -234,6 +262,8 @@ def _create_normuon_tasks( process_group=self._process_group, newton_schulz_func=self._newton_schulz_func, cautious_wd=group["cautious_wd"], + 
skip_update_prob=group["skip_update_prob"], # SkipUpdate: survival probability + magma_tau=group["magma_tau"], # Magma: temperature (None = plain SkipUpdate) ) # Create batches of parameters of size self._world_size @@ -244,6 +274,7 @@ def _create_normuon_tasks( states = [self._get_or_initialize_state(p, algo_name) for p in params] momentums = [s["momentum"] for s in states] variances_neuron = [s["variance_neuron"] for s in states] + magma_scales = [s["magma_scale"] for s in states] # Magma EMA scale per param # Get sharding state for DTensor is_batch_sharded = False @@ -311,8 +342,8 @@ def _create_normuon_tasks( # As long as matrix dimensions are not sharded, each device will have whole matrices # Each device already has different matrices of the batch, so we can't parallelize further if is_batch_sharded and not is_matrix_sharded: - for x, g, m, v in zip( - params, gradients, momentums, variances_neuron + for x, g, m, v, s in zip( + params, gradients, momentums, variances_neuron, magma_scales ): yield AsyncTask( normuon_update_batch_async( @@ -320,6 +351,7 @@ def _create_normuon_tasks( G=[g], M=[m], V=[v], + S=[s], shard_dim=None, # No sharded matrix dim **normuon_update_args, ) @@ -332,6 +364,7 @@ def _create_normuon_tasks( G=pad_batch(gradients, self._world_size), M=pad_batch(momentums, self._world_size), V=pad_batch(variances_neuron, self._world_size), + S=pad_batch(magma_scales, self._world_size), shard_dim=sharded_tensor_dim, **normuon_update_args, ) @@ -427,6 +460,7 @@ def normuon_update_batch_async( G: List[Tensor], # Gradient M: List[Tensor], # Momentum buffer (modified in place) V: List[Tensor], # Variance neuron buffer (modified in place) + S: List[Tensor], # Magma EMA scale buffer, scalar per param (modified in place) lr: Tensor, # Learning rate (scalar tensor) momentum: Tensor, # Momentum factor (scalar tensor) muon_beta2: Tensor, # Muon beta2 for normalization @@ -441,6 +475,8 @@ def normuon_update_batch_async( process_group: Optional[ProcessGroup] = 
None, newton_schulz_func: Optional[Callable] = None, cautious_wd: bool = False, + skip_update_prob: Optional[float] = None, # SkipUpdate: survival probability (None = disabled) + magma_tau: Optional[float] = None, # Magma: temperature for adaptive scaling (None = disabled) ) -> Generator[None, None, None]: """ Batched version of Muon update. Batch size should be equal to number of GPUs. @@ -451,10 +487,17 @@ assert len(X) == len(G) assert len(X) == len(M) + # Magma: snapshot momentum before it's updated, for cosine similarity with current grad. + # muon_update_pre_orthogonalize updates M in-place, so we must clone beforehand. + G_local = to_local(G) + M_local = to_local(M) + if magma_tau is not None: + M_before = [m.clone() for m in M_local] + # Update momentum and compute the inputs for orthogonalization U = muon_update_pre_orthogonalize( - G=to_local(G), - M=to_local(M), + G=G_local, + M=M_local, momentum=momentum, nesterov=nesterov, ) @@ -552,6 +595,34 @@ normuon_update_batch_async( muon_beta2=muon_beta2, ) + # SkipUpdate / Magma: stochastic block masking per parameter matrix. + # Moments always update densely (above); only the final update direction is masked. + # Reference: "On Surprising Effectiveness of Masking Updates in Adaptive Optimizers". + if skip_update_prob is not None and (skip_update_prob < 1.0 or magma_tau is not None): + # normuon_normalization may return a tuple under torch.compile; convert to list + U = list(U) + S_local = to_local(S) + + for i in range(len(U)): + # Sample one Bernoulli scalar per parameter block (not per element) + keep = torch.bernoulli(torch.tensor(skip_update_prob, device=U[i].device)) + + if magma_tau is not None: + # Magma: adaptive scale via momentum-gradient cosine similarity. 
+ # ẽ_t = sigmoid(cossim(μ_t_before, g_t) / τ) + # s_t = 0.9 * s_{t-1} + 0.1 * ẽ_t (EMA, updated in-place) + mu = M_before[i].flatten().float() + g = G_local[i].flatten().float() + cos = torch.dot(mu, g) / (mu.norm() * g.norm() + 1e-8) + e_tilde = torch.sigmoid(cos / magma_tau) + S_local[i].mul_(0.9).add_(e_tilde * 0.1) # EMA update in-place + scale = S_local[i] + else: + # Plain SkipUpdate: fixed unbiasing rescale of 1/p + scale = 1.0 / skip_update_prob + + U[i] = U[i] * (keep * scale) # zero-out or scale entire matrix + # Compute scaled learning rate # Do this before to_local(X) because we use the full tensor shape, not the shard shape if adjust_lr is None: