diff --git a/dion/normuon.py b/dion/normuon.py index e2a2479..ee8f362 100644 --- a/dion/normuon.py +++ b/dion/normuon.py @@ -1,6 +1,7 @@ import math import torch import torch.distributed as dist +from collections import defaultdict from itertools import chain from torch import Tensor from torch.distributed import ProcessGroup @@ -200,14 +201,69 @@ def _get_or_initialize_state(self, param: Tensor, algo: str) -> dict: state["variance_neuron"] = torch.zeros_like(param[..., 0:1]) return state + def _get_shard_info(self, param: Tensor, group: dict): + """Determine sharding info for a parameter. Returns (is_batch_sharded, is_matrix_sharded, sharded_tensor_dim).""" + is_batch_sharded = False + is_matrix_sharded = False + sharded_tensor_dim = None + + if not isinstance(param, DTensor): + return is_batch_sharded, is_matrix_sharded, sharded_tensor_dim + + if not isinstance(self._distributed_mesh, DeviceMesh): + raise RuntimeError( + "Must create optimizer with DeviceMesh if using DTensor parameters." + ) + + shard_placements = [ + (i, p) + for i, p in enumerate(param.placements) + if p.is_shard() and param.device_mesh.size(i) > 1 + ] + + if not group["flatten"]: + matrix_dims = {param.ndim - 1, param.ndim - 2} + is_batch_sharded = any( + p.dim not in matrix_dims for _, p in shard_placements + ) + shard_placements = [ + (i, p) for i, p in shard_placements if p.dim in matrix_dims + ] + + if any(p.dim == param.ndim - 1 for _, p in shard_placements): + raise NotImplementedError( + "NorMuon currently does not support parameters sharded along the last dimension. " + "Please avoid shards at dim -1." + ) + + if len(shard_placements) == 1: + is_matrix_sharded = True + sharded_mesh_dim = shard_placements[0][0] + sharded_tensor_dim = shard_placements[0][1].dim + + if ( + param.device_mesh.get_group(sharded_mesh_dim) + != self._process_group + ): + raise RuntimeError( + f"Got DTensor sharded over mesh dimension {sharded_mesh_dim} different from the optimizer's device mesh. " + f"DTensor has mesh: {param.device_mesh}, placements: {param.placements}, but optimizer was created with mesh: {self._distributed_mesh}." + ) + elif len(shard_placements) > 1: + raise NotImplementedError( + "NorMuon does not support parameters with multiple sharded dimensions." + ) + + return is_batch_sharded, is_matrix_sharded, sharded_tensor_dim + def _create_normuon_tasks( self, param_groups: List[dict], algo_name: str = "normuon", ) -> Generator["AsyncTask", None, None]: """ - Helper function to create batches of NorMuon matrices and generate - AsyncTask objects so we can process multiple batches concurrently. + Mega-batched NorMuon task creation: groups ALL same-shape parameters + into a single task to minimize communication rounds and kernel launches. """ for group in param_groups: assert group["algorithm"] == algo_name @@ -236,106 +292,39 @@ def _create_normuon_tasks( cautious_wd=group["cautious_wd"], ) - # Create batches of parameters of size self._world_size - for params in create_param_batches( - group_params, batch_size=self._world_size - ): + # Group parameters by shape, sharding, and dtype for mega-batching + shape_groups: dict[tuple, list] = defaultdict(list) + for p in group_params: + sharding = p.placements if isinstance(p, DTensor) else None + shape_groups[(p.shape, sharding, p.dtype)].append(p) + + for (_shape, _sharding, _dtype), params in shape_groups.items(): gradients = [p.grad for p in params] states = [self._get_or_initialize_state(p, algo_name) for p in params] momentums = [s["momentum"] for s in states] variances_neuron = [s["variance_neuron"] for s in states] - # Get sharding state for DTensor - is_batch_sharded = False - is_matrix_sharded = False - sharded_mesh_dim = None - sharded_tensor_dim = None - - if isinstance(params[0], DTensor): - if not isinstance(self._distributed_mesh, DeviceMesh): - raise RuntimeError( - "Must create optimizer with DeviceMesh if using DTensor parameters." - ) - - # Find the sharded placement and get its mesh and tensor dimensions - # Skip any Shard() placements on size-1 mesh dimension = Replicate() - shard_placements = [ - (i, p) - for i, p in enumerate(params[0].placements) - if p.is_shard() and params[0].device_mesh.size(i) > 1 - ] - - # If we don't flatten 3D matrices, we can ignore shard placements along batch dimensions - # Only keep placements that shard one of the two matrix dimensions - if not group["flatten"]: - matrix_dims = {params[0].ndim - 1, params[0].ndim - 2} - is_batch_sharded = any( - p.dim not in matrix_dims for _, p in shard_placements - ) - shard_placements = [ - (i, p) for i, p in shard_placements if p.dim in matrix_dims - ] - - # We currently do not support tensors sharded along the last dimension because NorMuon - # normalization later assumes a full trailing axis when computing means. - if any(p.dim == params[0].ndim - 1 for _, p in shard_placements): - raise NotImplementedError( - "NorMuon currently does not support parameters sharded along the last dimension. " - "Please avoid shards at dim -1." - ) - - # Check that we have no more than 1 sharded matrix dimension - # Note that non-flattened 3D tensors can have additional sharded batch dimensions - # Flattened 3D tensors are limited to one sharded dimension out of all dimensions - if len(shard_placements) == 1: - is_matrix_sharded = True - sharded_mesh_dim = shard_placements[0][0] - sharded_tensor_dim = shard_placements[0][1].dim - elif len(shard_placements) > 1: - raise NotImplementedError( - "NorMuon does not support parameters with multiple sharded dimensions." - ) - - # Check that the sharded mesh dimension matches optimizer's device mesh - if ( - sharded_mesh_dim is not None - and params[0].device_mesh.get_group(sharded_mesh_dim) - != self._process_group - ): - raise RuntimeError( - f"Got DTensor sharded over mesh dimension {sharded_mesh_dim} different from the optimizer's device mesh. " - f"DTensor has mesh: {params[0].device_mesh}, placements: {params[0].placements}, but optimizer was created with mesh: {self._distributed_mesh}." - ) - - # Special case for 3D tensors sharded along batch dimension - # As long as matrix dimensions are not sharded, each device will have whole matrices - # Each device already has different matrices of the batch, so we can't parallelize further + is_batch_sharded, is_matrix_sharded, sharded_tensor_dim = ( + self._get_shard_info(params[0], group) + ) + + megabatch_args = normuon_update_args if is_batch_sharded and not is_matrix_sharded: - for x, g, m, v in zip( - params, gradients, momentums, variances_neuron - ): - yield AsyncTask( - normuon_update_batch_async( - X=[x], - G=[g], - M=[m], - V=[v], - shard_dim=None, # No sharded matrix dim - **normuon_update_args, - ) - ) - # Otherwise, we parallelize the Muon update across devices - else: - yield AsyncTask( - normuon_update_batch_async( - X=pad_batch(params, self._world_size), - G=pad_batch(gradients, self._world_size), - M=pad_batch(momentums, self._world_size), - V=pad_batch(variances_neuron, self._world_size), - shard_dim=sharded_tensor_dim, - **normuon_update_args, - ) + # Batch-sharded 3D tensors already contain whole local matrices, + # so we can stack them and run the local megabatch path without + # cross-rank communication. + megabatch_args = {**normuon_update_args, "process_group": None} + + yield AsyncTask( + normuon_update_megabatch_async( + X=params, + G=gradients, + M=momentums, + V=variances_neuron, + shard_dim=sharded_tensor_dim, + **megabatch_args, ) + ) def _create_lion_tasks( self, @@ -574,6 +563,210 @@ def normuon_update_batch_async( ) +def normuon_update_megabatch_async( + X: List[Tensor], # All same-shape params (may be more than world_size) + G: List[Tensor], + M: List[Tensor], + V: List[Tensor], + lr: Tensor, + momentum: Tensor, + muon_beta2: Tensor, + weight_decay: Tensor, + epsilon: Tensor, + nesterov: bool, + flatten: bool, + adjust_lr: Optional[str], + device_rank: int, + world_size: int, + shard_dim: Optional[int] = None, + process_group: Optional[ProcessGroup] = None, + newton_schulz_func: Optional[Callable] = None, + cautious_wd: bool = False, +) -> Generator[None, None, None]: + """ + Mega-batched NorMuon update: processes ALL same-shape parameters in one + communication round instead of world_size-sized batches. This reduces + the number of all-to-all/all-gather rounds from O(N/world_size) to O(1) + per shape group, and enables batched Newton-Schulz on stacked 3D tensors. + """ + N = len(X) + assert N == len(G) == len(M) == len(V) + + # Update momentum and compute inputs for orthogonalization + U = muon_update_pre_orthogonalize( + G=to_local(G), + M=to_local(M), + momentum=momentum, + nesterov=nesterov, + ) + + if shard_dim is not None and process_group is not None: + # --- Mega-batched sharded FSDP2 path --- + # Pad N to be divisible by world_size + pad_n = (world_size - N % world_size) % world_size + U_work = U + [torch.zeros_like(U[0])] * pad_n if pad_n > 0 else U + N_total = len(U_work) + per_rank = N_total // world_size + + # Stack shards for each target rank into a single tensor + # input_chunks[r] = stacked local shards for matrices assigned to rank r + input_chunks = [ + torch.stack(U_work[r * per_rank : (r + 1) * per_rank]) + for r in range(world_size) + ] # each: [per_rank, *shard_shape] + + output_chunks = [torch.empty_like(c) for c in input_chunks] + work = dist.all_to_all( + output_chunks, input_chunks, group=process_group, async_op=True + ) + yield + work.wait() + + # Cat shards from all ranks along shard_dim (+1 for batch dim) to form full matrices + full_matrices = torch.cat( + output_chunks, dim=shard_dim + 1 + ) # [per_rank, *full_shape] + + # Batched Newton-Schulz on stacked 3D tensor + full_matrices = muon_update_newton_schulz( + full_matrices, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) + + # Split back into shards and all-to-all back + split_chunks = [ + s.contiguous() + for s in torch.tensor_split(full_matrices, world_size, dim=shard_dim + 1) + ] + + recv_chunks = [torch.empty_like(c) for c in split_chunks] + work = dist.all_to_all( + recv_chunks, split_chunks, group=process_group, async_op=True + ) + yield + work.wait() + + # Unstack back to list in original order + U = [recv_chunks[r][i] for r in range(world_size) for i in range(per_rank)] + U = U[:N] # Remove padding + + elif N > 1 and process_group is not None: + # --- Mega-batched non-sharded path --- + # Each GPU orthogonalizes N/world_size matrices instead of 1 + pad_n = (world_size - N % world_size) % world_size + U_work = U + [torch.zeros_like(U[0])] * pad_n if pad_n > 0 else U + N_total = len(U_work) + per_rank = N_total // world_size + + # This GPU processes its assigned chunk + start = device_rank * per_rank + my_matrices = torch.stack(U_work[start : start + per_rank]) + + my_matrices = muon_update_newton_schulz( + my_matrices, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) + + # All-gather: each rank broadcasts its processed chunk + all_chunks = [torch.empty_like(my_matrices) for _ in range(world_size)] + work = dist.all_gather( + all_chunks, my_matrices.contiguous(), group=process_group, async_op=True + ) + yield + work.wait() + + # Unstack back to list in original order + U = [all_chunks[r][i] for r in range(world_size) for i in range(per_rank)] + U = U[:N] + + elif N == 1: + U[0] = muon_update_newton_schulz( + U[0], + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) + else: + # N > 1 but no process_group (single GPU): batch Newton-Schulz + stacked = torch.stack(U) + stacked = muon_update_newton_schulz( + stacked, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) + U = [stacked[i] for i in range(N)] + + # NorMuon normalization using stacked tensors for fewer kernel launches + V_local = to_local(V) + U_stacked = torch.stack(U) + V_stacked = torch.stack(V_local) + U_stacked, V_stacked = normuon_normalization_stacked(U_stacked, V_stacked, muon_beta2) + # Write back updated V (stack creates a copy, so copy back to original buffers) + for i in range(N): + V_local[i].copy_(V_stacked[i]) + U = [U_stacked[i] for i in range(N)] + + # Compute scaled learning rate + if adjust_lr is None: + adjusted_lr = lr + elif adjust_lr == "spectral_norm": + adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape, flatten=flatten) + elif adjust_lr == "rms_norm": + adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape, flatten=flatten) + else: + raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") + + # Update model parameters + muon_update_post_orthogonalize( + X=to_local(X), + U=U, + base_lr=lr, + adjusted_lr=adjusted_lr, + weight_decay=weight_decay, + cautious_wd=cautious_wd, + ) + + +@torch.compile(fullgraph=True) +def normuon_normalization_stacked( + U: Tensor, # [N, rows, cols] + V: Tensor, # [N, rows, 1] (variance neuron buffer) + muon_beta2: Tensor, +) -> Tuple[Tensor, Tensor]: + """ + NorMuon normalization on stacked 3D tensors for minimal kernel launches. + Equivalent to normuon_normalization but operates on a single stacked tensor + instead of a list, reducing per-element kernel overhead. + Returns (normalized_U, updated_V). + """ + V_dtype = V.dtype + U = U.to(dtype=V_dtype) + + # Frobenius norm per matrix: [N, 1, 1] + norm_U = U.norm(p=2, dim=(-2, -1), keepdim=True) + + # Neuron-wise variance: mean of squares along last dim → [N, rows, 1] + neuron_norms = (U * U).mean(dim=-1, keepdim=True) + + # Update variance buffer (EMA) + V = torch.lerp(V, neuron_norms, 1 - muon_beta2) + + # Normalize + denom = V.sqrt() + 1e-8 + normalized_U = U / denom + + # Rescale to preserve Frobenius norm + norm_U_new = normalized_U.norm(p=2, dim=(-2, -1), keepdim=True).clamp(min=1e-8) + normalized_U = normalized_U * (norm_U / norm_U_new) + + return normalized_U, V + + @torch.compile(fullgraph=True) def normuon_normalization( U: List[Tensor],