diff --git a/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml b/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml
index 62d312e..651a2d7 100644
--- a/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml
+++ b/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml
@@ -4,8 +4,8 @@ version: v1.0
 # From now on, start to align train, data, and model setting with train stage (just finish refactor for dara)
 train_stage: stage-1 # options: preliminary, pretraining, posttraining; aligned with data setting
-name: unpadded3 #used for local dump and wandb log
-output_dir: /mnt/pollux/checkpoints/aj
+name: mup_test #used for local dump and wandb log
+output_dir: /mnt/pollux/checkpoints/ablations
 dump_dir: '' # No need now
 steps: 500000
 seed: 777
diff --git a/apps/Castor/model.py b/apps/Castor/model.py
index 2db2462..a074ea7 100644
--- a/apps/Castor/model.py
+++ b/apps/Castor/model.py
@@ -15,6 +15,9 @@
 from .modules.vae import VideoVAEArgs, create_vae
 from .modules.vision_encoder import VisionEncoderArgs, create_vision_encoder
 
+from .modules.component import layer_init_kaiming_normal
+from mup import MuReadout
+
 logger = logging.getLogger()
 
 
@@ -45,19 +48,26 @@ def __init__(self, input_dim: int, hidden_dim: int, encoder_dim: int):
         super(AlignmentProjection, self).__init__()
 
         self.proj = nn.Sequential(
-            nn.Sequential(
-                nn.Linear(input_dim, hidden_dim),
-                nn.SiLU(),
-                nn.Linear(hidden_dim, hidden_dim),
-                nn.SiLU(),
-                nn.Linear(hidden_dim, encoder_dim),
-            )
+            MuReadout(input_dim, hidden_dim),  # mup
+            nn.SiLU(),
+            nn.Linear(hidden_dim, hidden_dim),  # mup
+            nn.SiLU(),
+            nn.Linear(hidden_dim, encoder_dim),  # mup
         )
+
+        self.reset_parameters()
 
     def forward(self, x):
         x = self.proj(x)
         return x
 
+    def reset_parameters(self):
+        # MuReadout has its own initialization
+        layer_init_kaiming_normal(self.proj[2])
+        nn.init.constant_(self.proj[4].weight, 0.)  # initialize output weights by zero.
+        if self.proj[4].bias is not None:
+            nn.init.constant_(self.proj[4].bias, 0.)
+
 
 class Castor(nn.Module):
     VERSION: str = "v1.0"
diff --git a/apps/Castor/modules/component.py b/apps/Castor/modules/component.py
index b3f55f5..95f206f 100644
--- a/apps/Castor/modules/component.py
+++ b/apps/Castor/modules/component.py
@@ -15,6 +15,7 @@
 from xformers.ops import AttentionBias, fmha
 from liger_kernel.transformers import LigerSwiGLUMLP, LigerRMSNorm, liger_rotary_pos_emb
 from types import SimpleNamespace
+from mup import MuReadout
 
 # fa3
 from flash_attn_interface import flash_attn_varlen_func
@@ -23,11 +24,10 @@
 flex_attention_comp = torch.compile(flex_attention)
 
 
-class InitStdFactor(Enum):
-    DISABLED = "disabled"  # Init std is divided by 1.0
-    GLOBAL_DEPTH = "global_depth"  # Init std is divided by sqrt(2*n_layers)
-    CURRENT_DEPTH = "current_depth"  # Init std is divided by sqrt(2*depth)
-    DIM_RATIO = "dim_ratio"  # Init std is divided by model_dim/4096
+def layer_init_kaiming_normal(x):
+    nn.init.kaiming_normal_(x.weight, a=1, mode='fan_in')
+    if x.bias is not None:
+        nn.init.constant_(x.bias, 0.)
 
 
 @dataclass
@@ -482,26 +482,11 @@ def forward(
 
         return output
 
-    def reset_parameters(self, init_std=None, factor=1.0):
-        init_std = init_std or (self.dim ** (-0.5))
-
-        for w in [self.wq, self.wk, self.wv]:
-            nn.init.trunc_normal_(
-                w.weight,
-                mean=0.0,
-                std=init_std,
-                a=-3 * init_std,
-                b=3 * init_std,
-            )
-
-        nn.init.trunc_normal_(
-            self.wo.weight,
-            mean=0.0,
-            std=init_std / factor,
-            a=-3 * init_std,
-            b=3 * init_std,
-        )
-
+    def reset_parameters(self):
+        layer_init_kaiming_normal(self.wq)
+        layer_init_kaiming_normal(self.wk)
+        layer_init_kaiming_normal(self.wv)
+        layer_init_kaiming_normal(self.wo)
         if isinstance(self.q_norm, RMSNorm):
             self.q_norm.reset_parameters()
         if isinstance(self.k_norm, RMSNorm):
@@ -540,44 +525,45 @@ def __init__(
         self.liger_rotary_emb = liger_rotary_emb
         self.liger_rms_norm = liger_rms_norm
         self.window_size = window_size
+        self.qk_norm = qk_norm
 
         self.wq = nn.Linear(
             dim,
             n_heads * self.head_dim,
             bias=False,
-        )
-        self.wk = nn.Linear(
+        )  # mup
+        self.wk = MuReadout(
             dim,
             n_kv_heads * self.head_dim,
             bias=False,
-        )
-        self.wv = nn.Linear(
+        )  # mup
+        self.wv = MuReadout(
             dim,
             n_kv_heads * self.head_dim,
             bias=False,
-        )
-        nn.init.xavier_uniform_(self.wq.weight)
-        nn.init.xavier_uniform_(self.wk.weight)
-        nn.init.xavier_uniform_(self.wv.weight)
+        )  # mup
 
         self.wo = nn.Linear(
             n_heads * self.head_dim,
             dim,
             bias=False,
-        )
-        nn.init.xavier_uniform_(self.wo.weight)
+        )  # mup
 
-        if qk_norm:
+        if self.qk_norm:
             self.q_norm = RMSNorm(self.head_dim, liger_rms_norm=liger_rms_norm)
             self.k_norm = RMSNorm(self.head_dim, liger_rms_norm=liger_rms_norm)
         else:
             self.q_norm = self.k_norm = nn.Identity()
+
+        self.reset_parameters()
 
     def reset_parameters(self, *args, **kwargs):
-        nn.init.xavier_uniform_(self.wq.weight)
-        nn.init.xavier_uniform_(self.wk.weight)
-        nn.init.xavier_uniform_(self.wv.weight)
-        nn.init.xavier_uniform_(self.wo.weight)
+        layer_init_kaiming_normal(self.wq)
+        # MuReadout layers have their own initialization
+        layer_init_kaiming_normal(self.wo)
+        if self.qk_norm:
+            self.q_norm.reset_parameters()
+            self.k_norm.reset_parameters()
 
     # copied from huggingface modeling_llama.py
     def _upad_input(
@@ -780,24 +766,24 @@ def __init__(
                 hidden_size=dim,
                 intermediate_size=hidden_dim,
                 hidden_act="silu",
-            )
+            )  # mup
             self.ffn = LigerSwiGLUMLP(config)
         else:
             self.w1 = nn.Linear(
                 dim,
                 hidden_dim,
                 bias=False,
-            )
+            )  # mup
             self.w3 = nn.Linear(
                 dim,
                 hidden_dim,
                 bias=False,
-            )
+            )  # mup
             self.w2 = nn.Linear(
                 hidden_dim,
                 dim,
                 bias=False,
-            )
+            )  # mup
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # B S D
@@ -809,45 +795,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         output = self.w2(F.silu(x1) * x3)
         return output
 
-    def reset_parameters(self, init_std=None, factor=1.0):
-        in_init_std = init_std or (self.dim ** (-0.5))
-        out_init_std = init_std or (self.hidden_dim ** (-0.5))
-        out_init_std = out_init_std / factor
+    def reset_parameters(self):
         if self.liger_ffn:
-            # Initialize LigerSwiGLUMLP parameters
-            # gate_proj and up_proj correspond to w1 and w3
-            for w in [self.ffn.gate_proj, self.ffn.up_proj]:
-                nn.init.trunc_normal_(
-                    w.weight,
-                    mean=0.0,
-                    std=in_init_std,
-                    a=-3 * in_init_std,
-                    b=3 * in_init_std,
-                )
-            # down_proj corresponds to w2
-            nn.init.trunc_normal_(
-                self.ffn.down_proj.weight,
-                mean=0.0,
-                std=out_init_std,
-                a=-3 * out_init_std,
-                b=3 * out_init_std,
-            )
+            layer_init_kaiming_normal(self.ffn.gate_proj)
+            layer_init_kaiming_normal(self.ffn.up_proj)
+            layer_init_kaiming_normal(self.ffn.down_proj)
         else:
-            for w in [self.w1, self.w3]:
-                nn.init.trunc_normal_(
-                    w.weight,
-                    mean=0.0,
-                    std=in_init_std,
-                    a=-3 * in_init_std,
-                    b=3 * in_init_std,
-                )
-            nn.init.trunc_normal_(
-                self.w2.weight,
-                mean=0.0,
-                std=out_init_std,
-                a=-3 * out_init_std,
-                b=3 * out_init_std,
-            )
+            layer_init_kaiming_normal(self.w1)
+            layer_init_kaiming_normal(self.w3)
+            layer_init_kaiming_normal(self.w2)
 
 
 class TransformerBlock(nn.Module):
@@ -901,11 +857,11 @@ def forward(
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
 
-    def init_weights(self, init_std=None, factor=1.0):
-        self.attention.reset_parameters(init_std, factor)
+    def init_weights(self):
+        self.attention.reset_parameters()
         self.attention_norm.reset_parameters()
 
-        self.feed_forward.reset_parameters(init_std, factor)
+        self.feed_forward.reset_parameters()
         self.ffn_norm.reset_parameters()
 
 
@@ -1014,16 +970,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         output = self.w1(F.silu(x))
         return output
 
-    def reset_parameters(self, init_std=None, factor=1.0):
-        init_std = init_std or (self.in_dim ** (-0.5))
-        init_std = init_std / factor
-        nn.init.trunc_normal_(
-            self.w1.weight,
-            mean=0.0,
-            std=init_std,
-            a=-3 * init_std,
-            b=3 * init_std,
-        )
+    def reset_parameters(self):
+        layer_init_kaiming_normal(self.w1)
 
 
 class TimestepEmbedder(nn.Module):
@@ -1033,10 +981,14 @@ class TimestepEmbedder(nn.Module):
     def __init__(self, hidden_size: int, time_embedding_size: int = 256):
         super().__init__()
 
-        self.w1 = nn.Linear(
-            time_embedding_size,
-            hidden_size,
-            bias=True,
+        self.mlp = nn.Sequential(
+            nn.Linear(
+                time_embedding_size,
+                hidden_size,
+                bias=True,
+            ),  # mup: input weights
+            nn.SiLU(),
+            nn.Linear(hidden_size, hidden_size, bias=True),  # mup: hidden weights
         )
         self.w2 = nn.Linear(
             hidden_size,
@@ -1046,6 +998,8 @@ def __init__(self, hidden_size: int, time_embedding_size: int = 256):
         self.hidden_size = hidden_size
         self.time_embedding_size = time_embedding_size
 
+        self.reset_parameters()
+
     @staticmethod
     def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000):
         # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
@@ -1065,30 +1019,12 @@ def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000):
 
     def forward(self, t: torch.Tensor) -> torch.Tensor:
         t_freq = self.timestep_embedding(t, self.time_embedding_size)
-        t_emb = self.w1(t_freq.to(self.w1.weight.dtype))
-        t_emb = self.w2(F.silu(t_emb))
+        t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
         return t_emb
 
-    def reset_parameters(self, init_std=None, factor=1.0):
-        in_init_std = init_std or (self.time_embedding_size ** (-0.5))
-        out_init_std = init_std or (self.hidden_size ** (-0.5))
-        out_init_std = out_init_std / factor
-        nn.init.trunc_normal_(
-            self.w1.weight,
-            mean=0.0,
-            std=in_init_std,
-            a=-3 * in_init_std,
-            b=3 * in_init_std,
-        )
-        nn.init.trunc_normal_(
-            self.w2.weight,
-            mean=0.0,
-            std=out_init_std,
-            a=-3 * out_init_std,
-            b=3 * out_init_std,
-        )
-        nn.init.normal_(self.w1.bias, std=0.02)
-        nn.init.normal_(self.w2.bias, std=0.02)
+    def reset_parameters(self):
+        layer_init_kaiming_normal(self.mlp[0])
+        layer_init_kaiming_normal(self.mlp[2])
 
 
 class ImageEmbedder(nn.Module):
@@ -1102,21 +1038,12 @@ def __init__(self, in_dim, out_dim):
             in_features=in_dim,
             out_features=out_dim,
             bias=True,
-        )
+        )  # mup: input weights
         self.in_dim = in_dim
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.w1(x)
 
-    def reset_parameters(self, init_std=None, factor=1.0):
-        init_std = init_std or (self.in_dim ** (-0.5))
-        init_std = init_std / factor
-        nn.init.trunc_normal_(
-            self.w1.weight,
-            mean=0.0,
-            std=init_std,
-            a=-3 * init_std,
-            b=3 * init_std,
-        )
-
-        nn.init.normal_(self.w1.bias, std=0.02)
+    def reset_parameters(self):
+        layer_init_kaiming_normal(self.w1)
+
\ No newline at end of file
diff --git a/apps/Castor/modules/transformer.py b/apps/Castor/modules/transformer.py
index c3b44df..9dc8594 100644
--- a/apps/Castor/modules/transformer.py
+++ b/apps/Castor/modules/transformer.py
@@ -13,12 +13,14 @@
 from torch.nn.attention.flex_attention import BlockMask
 
 from .component import (AdaLN, BaseTransformerArgs, FeedForward,
-                        FlashAttention, ImageEmbedder, InitStdFactor, RMSNorm,
+                        FlashAttention, ImageEmbedder, RMSNorm,
                         RotaryEmbedding1D, RotaryEmbedding2D, TimestepEmbedder,
-                        create_causal_mask, modulate_and_gate, nearest_multiple_of_8, modulate_and_gate_unpadded)
+                        create_causal_mask, layer_init_kaiming_normal, modulate_and_gate, nearest_multiple_of_8, modulate_and_gate_unpadded)
 from flash_attn.bert_padding import unpad_input, pad_input
 import copy
 
+from mup import MuReadout
+
 logger = logging.getLogger()
 
@@ -150,10 +152,10 @@ def forward(
 
         return h
 
-    def init_weights(self, init_std=None, factor=1.0):
-        self.attention.reset_parameters(init_std, factor)
+    def init_weights(self):
+        self.attention.reset_parameters()
         self.attention_norm.reset_parameters()
 
-        self.feed_forward.reset_parameters(init_std, factor)
+        self.feed_forward.reset_parameters()
         self.ffn_norm.reset_parameters()
         if not self.shared_adaLN:
             self.adaLN_modulation.reset_parameters()
@@ -171,8 +173,6 @@ class BaseDiffusionTransformer(nn.Module):
     def __init__(self, args: TransformerArgs):
         super().__init__()
         self.dim = args.dim
-        self.init_base_std = args.init_base_std
-        self.init_std_factor = InitStdFactor(args.init_std_factor)
         self.gen_seqlen = args.gen_seqlen
         self.layers = nn.ModuleList()
         self.shared_adaLN = args.shared_adaLN
@@ -223,15 +223,8 @@ def reset_parameters(self):
 
     def init_weights(self, pre_trained_path: Optional[str] = None):
         self.reset_parameters()
-        for depth, layer in enumerate(self.layers):
-            factor = {
-                InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
-                InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
-                InitStdFactor.DIM_RATIO: self.dim / 4096,
-                InitStdFactor.DISABLED: 1.0,
-            }[self.init_std_factor]
-
-            layer.init_weights(self.init_base_std, factor)
+        for layer in self.layers:
+            layer.init_weights()
         if pre_trained_path:
             assert os.path.exists(pre_trained_path)
             ckpt_state_dict = torch.load(pre_trained_path, map_location="cpu")
@@ -274,10 +267,10 @@ def __init__(self, args: TransformerArgs):
             in_dim=self.patch_size * self.patch_size * args.in_channels,
             out_dim=args.dim,
         )
-        self.img_output = nn.Linear(
+        self.img_output = MuReadout(
             args.dim,
             self.patch_size * self.patch_size * args.out_channels,
-            bias=False,
+            bias=True,
         )
         self.rope_embeddings_image = RotaryEmbedding2D(
             theta=args.rope_theta,
@@ -559,30 +552,23 @@ def unpatchify_image(
         img_features = img_features.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
         return img_features
 
-    def reset_parameters(self, init_std=None):
+    def reset_parameters(self):
         # Either use fixed base std or sqrt model dim
         super().reset_parameters()
         self.rope_embeddings_image.reset_parameters()
         self.rope_embeddings_conditions.reset_parameters()
-        init_std = init_std or (self.dim ** (-0.5))
         self.norm.reset_parameters()
         self.cond_norm.reset_parameters()
         self.tmb_embed.reset_parameters()
         self.img_embed.reset_parameters()
-        nn.init.trunc_normal_(
-            self.img_output.weight,
-            mean=0.0,
-            std=init_std,
-            a=-3 * init_std,
-            b=3 * init_std,
-        )
-        nn.init.trunc_normal_(
-            self.cond_proj.weight,
-            mean=0.0,
-            std=init_std,
-            a=-3 * init_std,
-            b=3 * init_std,
-        )
+
+        # muReadout has its own initialization
+        # nn.init.constant_(self.img_output.weight, 0.)  # initialize output weights by zero.
+        # if self.img_output.bias is not None:
+        #     nn.init.constant_(self.img_output.bias, 0.)
+
+        layer_init_kaiming_normal(self.cond_proj)
+        nn.init.normal_(self.negative_token, std=0.02)
 
         if self.shared_adaLN:
             self.adaLN_modulation.reset_parameters()
diff --git a/apps/Castor/train.py b/apps/Castor/train.py
index 779ddad..7ef482a 100644
--- a/apps/Castor/train.py
+++ b/apps/Castor/train.py
@@ -1,6 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import copy
 import gc
 import logging
 import os
@@ -54,6 +55,8 @@
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.optim import lr_scheduler
 
+from mup import set_base_shapes
+
 logger = logging.getLogger()
 
 
@@ -227,6 +230,25 @@ def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
     return test
 
 
+def mup_set_base_shapes(model, args):
+    base_args = copy.deepcopy(args)
+    base_args.diffusion_model.dim = 288
+    base_args.diffusion_model.n_heads = 4
+    base_args.diffusion_model.n_kv_heads = 2  # Scale this too
+    base_model = Castor(base_args)
+
+    delta_args = copy.deepcopy(args)
+    delta_args.diffusion_model.dim = 360
+    delta_args.diffusion_model.n_heads = 6
+    delta_args.diffusion_model.n_kv_heads = 3  # Scale this too
+    delta_model = Castor(delta_args)
+
+    set_base_shapes(model, base_model, delta=delta_model)
+
+    del base_model, delta_model
+    gc.collect()
+
+
 def train(args: TrainArgs):
     with ExitStack() as context_stack:
         validate_train_args(
@@ -261,6 +283,8 @@ def train(args: TrainArgs):
         model = Castor(args.model)
         logger.info("Model is built !")
 
+        mup_set_base_shapes(model, args.model)
+
         model_param_count = get_num_params(model)
         flops_meter = FlopsMeter(args.model, model)
 
diff --git a/apps/main/utils/mongodb_data_load.py b/apps/main/utils/mongodb_data_load.py
index 2f29a86..16be9b1 100644
--- a/apps/main/utils/mongodb_data_load.py
+++ b/apps/main/utils/mongodb_data_load.py
@@ -132,8 +132,8 @@ def set_local_partition(self):
                 if partition_key % self.num_shards == self.shard_idx:
                     data.append(item)
                     # # Note: used for debugging
-                    # if len(data) > 10000:
-                    #     break
+                    if len(data) > 1400000:
+                        break
         self.data = pd.DataFrame(data).reset_index()
         end_time = time.time()  # Record the end time
diff --git a/lingua/optim.py b/lingua/optim.py
index 1c7ba9b..b24a6e7 100644
--- a/lingua/optim.py
+++ b/lingua/optim.py
@@ -7,6 +7,7 @@
 import logging
 from torch import nn
 from torch.optim import AdamW, lr_scheduler
+from mup import MuAdamW
 
 logger = logging.getLogger()
 
@@ -145,7 +146,7 @@ def build_lr_fn(args: OptimArgs, n_steps: int):
 
 def build_optimizer(model: nn.Module, args: OptimArgs, n_steps: int):
     logger.info("Starting build of optimizer...")
-    optimizer = AdamW(
+    optimizer = MuAdamW(
         (param for param in model.parameters() if param.requires_grad),
         lr=args.lr,
         betas=(args.beta1, args.beta2),
diff --git a/requirements.txt b/requirements.txt
index 4b25c42..9bb589a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,3 +27,4 @@ transformers
 liger-kernel
 timm
 open-clip-torch
+mup
\ No newline at end of file
diff --git a/test_mup_fix.py b/test_mup_fix.py
new file mode 100644
index 0000000..ce37cea
--- /dev/null
+++ b/test_mup_fix.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+
+import sys
+import os
+sys.path.append('/home/akito/Pollux')
+
+import torch
+import copy
+import gc
+from apps.Castor.model import Castor, ModelArgs, build_2B_Castor
+from apps.Castor.modules.transformer import TransformerArgs
+from mup import set_base_shapes
+
+def test_mup_setup():
+    print("Testing MuP setup...")
+
+    # Build the main model
+    print("Building main model...")
+    model_args = build_2B_Castor().args
+    print(f"Main model args: {model_args}")
+    model = Castor(model_args)
+
+    print(f"Main model config:")
+    print(f"  dim: {model_args.diffusion_model.dim}")
+    print(f"  n_heads: {model_args.diffusion_model.n_heads}")
+    print(f"  n_kv_heads: {model_args.diffusion_model.n_kv_heads}")
+
+    # Create base model (matching train.py)
+    base_args = copy.deepcopy(model_args)
+    base_args.diffusion_model.dim = 288
+    base_args.diffusion_model.n_heads = 4
+    base_args.diffusion_model.n_kv_heads = 2
+    base_model = Castor(base_args)
+
+    print(f"Base model config:")
+    print(f"  dim: {base_args.diffusion_model.dim}")
+    print(f"  n_heads: {base_args.diffusion_model.n_heads}")
+    print(f"  n_kv_heads: {base_args.diffusion_model.n_kv_heads}")
+
+    # Create delta model (matching train.py)
+    delta_args = copy.deepcopy(model_args)
+    delta_args.diffusion_model.dim = 360
+    delta_args.diffusion_model.n_heads = 6  # Updated to match your change
+    delta_args.diffusion_model.n_kv_heads = 3
+    delta_model = Castor(delta_args)
+
+    print(f"Delta model config:")
+    print(f"  dim: {delta_args.diffusion_model.dim}")
+    print(f"  n_heads: {delta_args.diffusion_model.n_heads}")
+    print(f"  n_kv_heads: {delta_args.diffusion_model.n_kv_heads}")
+
+    # Debug: Check parameters before MuP setup
+    print("\nChecking parameters that might need MuP treatment...")
+    for name, param in model.named_parameters():
+        if param.shape == torch.Size([2048, 2048]):
+            print(f"Found 2048x2048 parameter: {name}")
+        if len(param.shape) == 2 and (2048 in param.shape):
+            print(f"Parameter with 2048 dimension: {name} - shape: {param.shape}")
+
+    # Test MuP setup
+    try:
+        print("\nSetting base shapes...")
+        set_base_shapes(model, base_model, delta=delta_model)
+        print("āœ… MuP setup successful!")
+
+        # Test a forward pass
+        print("\nTesting forward pass...")
+        model.eval()
+
+        # Create dummy batch
+        batch = {
+            'latent_code': torch.randn(2, 16, 32, 32),  # B, C, H, W
+            'text_embedding': torch.randn(2, 128, 512),  # B, seq_len, dim
+            'attention_mask': torch.ones(2, 128, dtype=torch.bool)
+        }
+
+        with torch.no_grad():
+            output = model(batch)
+        print(f"āœ… Forward pass successful! Loss: {output.loss.item():.4f}")
+
+    except Exception as e:
+        print(f"āŒ Error: {e}")
+        print("\nDebugging: Checking which parameters don't have infshape...")
+        for name, param in model.named_parameters():
+            if not hasattr(param, 'infshape'):
+                print(f"Missing infshape: {name} - shape: {param.shape}")
+        return False
+
+    finally:
+        # Cleanup
+        del base_model, delta_model
+        gc.collect()
+
+    return True
+
+if __name__ == "__main__":
+    success = test_mup_setup()
+    if success:
+        print("\nšŸŽ‰ All tests passed! The MuP fix is working correctly.")
+    else:
+        print("\nšŸ’„ Tests failed. There may still be issues with the MuP setup.")
\ No newline at end of file