From b88ef69f0094016e10a0a15c9b106d6684d288d3 Mon Sep 17 00:00:00 2001 From: sippycoder Date: Sun, 25 May 2025 18:36:30 +0000 Subject: [PATCH 01/13] mup parameterization and training --- ...et_256_Castor_flux_qwen_fixed_siglip2.yaml | 4 +- apps/Castor/model.py | 23 ++- apps/Castor/modules/component.py | 189 ++++++------------ apps/Castor/modules/transformer.py | 53 ++--- lingua/optim.py | 3 +- requirements.txt | 1 + 6 files changed, 98 insertions(+), 175 deletions(-) diff --git a/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml b/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml index 62d312e0..651a2d78 100644 --- a/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml +++ b/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml @@ -4,8 +4,8 @@ version: v1.0 # From now on, start to align train, data, and model setting with train stage (just finish refactor for dara) train_stage: stage-1 # options: preliminary, pretraining, posttraining; aligned with data setting -name: unpadded3 #used for local dump and wandb log -output_dir: /mnt/pollux/checkpoints/aj +name: mup_test #used for local dump and wandb log +output_dir: /mnt/pollux/checkpoints/ablations dump_dir: '' # No need now steps: 500000 seed: 777 diff --git a/apps/Castor/model.py b/apps/Castor/model.py index 2db24626..16a6b761 100644 --- a/apps/Castor/model.py +++ b/apps/Castor/model.py @@ -15,6 +15,8 @@ from .modules.vae import VideoVAEArgs, create_vae from .modules.vision_encoder import VisionEncoderArgs, create_vision_encoder +from .modules.component import layer_init_kaiming_normal + logger = logging.getLogger() @@ -45,19 +47,26 @@ def __init__(self, input_dim: int, hidden_dim: int, encoder_dim: int): super(AlignmentProjection, self).__init__() self.proj = nn.Sequential( - nn.Sequential( - nn.Linear(input_dim, hidden_dim), - nn.SiLU(), - nn.Linear(hidden_dim, hidden_dim), - nn.SiLU(), - nn.Linear(hidden_dim, encoder_dim), - ) + nn.Linear(input_dim, hidden_dim), + nn.SiLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.SiLU(), + nn.Linear(hidden_dim, encoder_dim), ) + + self.proj.reset_parameters() def forward(self, x): x = self.proj(x) return x + def reset_parameters(self): + layer_init_kaiming_normal(self.proj[0]) + layer_init_kaiming_normal(self.proj[2]) + nn.init.constant_(self.proj[4].weight, 0.) # initialize output weights by zero. + if self.proj[4].bias is not None: + nn.init.constant_(self.proj[4].bias, 0.) + class Castor(nn.Module): VERSION: str = "v1.0" diff --git a/apps/Castor/modules/component.py b/apps/Castor/modules/component.py index b3f55f5c..7ec8082a 100644 --- a/apps/Castor/modules/component.py +++ b/apps/Castor/modules/component.py @@ -23,11 +23,10 @@ flex_attention_comp = torch.compile(flex_attention) -class InitStdFactor(Enum): - DISABLED = "disabled" # Init std is divided by 1.0 - GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers) - CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth) - DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096 +def layer_init_kaiming_normal(x): + nn.init.kaiming_normal_(x.weight, a=1, mode='fan_in') + if x.bias is not None: + nn.init.constant_(x.bias, 0.) 
@dataclass @@ -482,26 +481,11 @@ def forward( return output - def reset_parameters(self, init_std=None, factor=1.0): - init_std = init_std or (self.dim ** (-0.5)) - - for w in [self.wq, self.wk, self.wv]: - nn.init.trunc_normal_( - w.weight, - mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, - ) - - nn.init.trunc_normal_( - self.wo.weight, - mean=0.0, - std=init_std / factor, - a=-3 * init_std, - b=3 * init_std, - ) - + def reset_parameters(self): + layer_init_kaiming_normal(self.wq) + layer_init_kaiming_normal(self.wk) + layer_init_kaiming_normal(self.wv) + layer_init_kaiming_normal(self.wo) if isinstance(self.q_norm, RMSNorm): self.q_norm.reset_parameters() if isinstance(self.k_norm, RMSNorm): @@ -540,44 +524,46 @@ def __init__( self.liger_rotary_emb = liger_rotary_emb self.liger_rms_norm = liger_rms_norm self.window_size = window_size + self.qk_norm = qk_norm self.wq = nn.Linear( dim, n_heads * self.head_dim, bias=False, - ) + ) # mup self.wk = nn.Linear( dim, n_kv_heads * self.head_dim, bias=False, - ) + ) # mup self.wv = nn.Linear( dim, n_kv_heads * self.head_dim, bias=False, - ) - nn.init.xavier_uniform_(self.wq.weight) - nn.init.xavier_uniform_(self.wk.weight) - nn.init.xavier_uniform_(self.wv.weight) + ) # mup self.wo = nn.Linear( n_heads * self.head_dim, dim, bias=False, - ) - nn.init.xavier_uniform_(self.wo.weight) + ) # mup - if qk_norm: + if self.qk_norm: self.q_norm = RMSNorm(self.head_dim, liger_rms_norm=liger_rms_norm) self.k_norm = RMSNorm(self.head_dim, liger_rms_norm=liger_rms_norm) else: self.q_norm = self.k_norm = nn.Identity() + + self.reset_parameters() def reset_parameters(self, *args, **kwargs): - nn.init.xavier_uniform_(self.wq.weight) - nn.init.xavier_uniform_(self.wk.weight) - nn.init.xavier_uniform_(self.wv.weight) - nn.init.xavier_uniform_(self.wo.weight) + layer_init_kaiming_normal(self.wq) + layer_init_kaiming_normal(self.wk) + layer_init_kaiming_normal(self.wv) + layer_init_kaiming_normal(self.wo) + if self.q_norm: + self.q_norm.reset_parameters() + self.k_norm.reset_parameters() # copied from huggingface modeling_llama.py def _upad_input( @@ -780,24 +766,24 @@ def __init__( hidden_size=dim, intermediate_size=hidden_dim, hidden_act="silu", - ) + ) # mup self.ffn = LigerSwiGLUMLP(config) else: self.w1 = nn.Linear( dim, hidden_dim, bias=False, - ) + ) # mup self.w3 = nn.Linear( dim, hidden_dim, bias=False, - ) + ) # mup self.w2 = nn.Linear( hidden_dim, dim, bias=False, - ) + ) # mup def forward(self, x: torch.Tensor) -> torch.Tensor: # B S D @@ -809,45 +795,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: output = self.w2(F.silu(x1) * x3) return output - def reset_parameters(self, init_std=None, factor=1.0): - in_init_std = init_std or (self.dim ** (-0.5)) - out_init_std = init_std or (self.hidden_dim ** (-0.5)) - out_init_std = out_init_std / factor + def reset_parameters(self): if self.liger_ffn: - # Initialize LigerSwiGLUMLP parameters - # gate_proj and up_proj correspond to w1 and w3 - for w in [self.ffn.gate_proj, self.ffn.up_proj]: - nn.init.trunc_normal_( - w.weight, - mean=0.0, - std=in_init_std, - a=-3 * in_init_std, - b=3 * in_init_std, - ) - # down_proj corresponds to w2 - nn.init.trunc_normal_( - self.ffn.down_proj.weight, - mean=0.0, - std=out_init_std, - a=-3 * out_init_std, - b=3 * out_init_std, - ) + layer_init_kaiming_normal(self.ffn.gate_proj) + layer_init_kaiming_normal(self.ffn.up_proj) + layer_init_kaiming_normal(self.ffn.down_proj) else: - for w in [self.w1, self.w3]: - nn.init.trunc_normal_( - w.weight, - mean=0.0, 
- std=in_init_std, - a=-3 * in_init_std, - b=3 * in_init_std, - ) - nn.init.trunc_normal_( - self.w2.weight, - mean=0.0, - std=out_init_std, - a=-3 * out_init_std, - b=3 * out_init_std, - ) + layer_init_kaiming_normal(self.w1) + layer_init_kaiming_normal(self.w3) + layer_init_kaiming_normal(self.w2) class TransformerBlock(nn.Module): @@ -901,11 +857,11 @@ def forward( out = h + self.feed_forward(self.ffn_norm(h)) return out - def init_weights(self, init_std=None, factor=1.0): - self.attention.reset_parameters(init_std, factor) + def init_weights(self): + self.attention.reset_parameters() self.attention_norm.reset_parameters() - self.feed_forward.reset_parameters(init_std, factor) + self.feed_forward.reset_parameters() self.ffn_norm.reset_parameters() @@ -1014,16 +970,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: output = self.w1(F.silu(x)) return output - def reset_parameters(self, init_std=None, factor=1.0): - init_std = init_std or (self.in_dim ** (-0.5)) - init_std = init_std / factor - nn.init.trunc_normal_( - self.w1.weight, - mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, - ) + def reset_parameters(self): + layer_init_kaiming_normal(self.w1) class TimestepEmbedder(nn.Module): @@ -1033,10 +981,14 @@ class TimestepEmbedder(nn.Module): def __init__(self, hidden_size: int, time_embedding_size: int = 256): super().__init__() - self.w1 = nn.Linear( - time_embedding_size, - hidden_size, - bias=True, + self.mlp = nn.Sequential( + nn.Linear( + time_embedding_size, + hidden_size, + bias=True, + ), # mup: input weights + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), # mup: hidden weights ) self.w2 = nn.Linear( hidden_size, @@ -1046,6 +998,8 @@ def __init__(self, hidden_size: int, time_embedding_size: int = 256): self.hidden_size = hidden_size self.time_embedding_size = time_embedding_size + self.reset_parameters() + @staticmethod def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000): # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py @@ -1065,30 +1019,12 @@ def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000): def forward(self, t: torch.Tensor) -> torch.Tensor: t_freq = self.timestep_embedding(t, self.time_embedding_size) - t_emb = self.w1(t_freq.to(self.w1.weight.dtype)) - t_emb = self.w2(F.silu(t_emb)) + t_emb = self.mlp(t_freq.to(self.w1.weight.dtype)) return t_emb - def reset_parameters(self, init_std=None, factor=1.0): - in_init_std = init_std or (self.time_embedding_size ** (-0.5)) - out_init_std = init_std or (self.hidden_size ** (-0.5)) - out_init_std = out_init_std / factor - nn.init.trunc_normal_( - self.w1.weight, - mean=0.0, - std=in_init_std, - a=-3 * in_init_std, - b=3 * in_init_std, - ) - nn.init.trunc_normal_( - self.w2.weight, - mean=0.0, - std=out_init_std, - a=-3 * out_init_std, - b=3 * out_init_std, - ) - nn.init.normal_(self.w1.bias, std=0.02) - nn.init.normal_(self.w2.bias, std=0.02) + def reset_parameters(self): + layer_init_kaiming_normal(self.mlp[0]) + layer_init_kaiming_normal(self.mlp[2]) class ImageEmbedder(nn.Module): @@ -1102,21 +1038,12 @@ def __init__(self, in_dim, out_dim): in_features=in_dim, out_features=out_dim, bias=True, - ) + ) # mup: input weights self.in_dim = in_dim def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w1(x) - def reset_parameters(self, init_std=None, factor=1.0): - init_std = init_std or (self.in_dim ** (-0.5)) - init_std = init_std / factor - nn.init.trunc_normal_( - self.w1.weight, - mean=0.0, - std=init_std, - a=-3 * 
init_std, - b=3 * init_std, - ) - - nn.init.normal_(self.w1.bias, std=0.02) + def reset_parameters(self): + layer_init_kaiming_normal(self.w1) + \ No newline at end of file diff --git a/apps/Castor/modules/transformer.py b/apps/Castor/modules/transformer.py index c3b44dfd..3932eb9f 100644 --- a/apps/Castor/modules/transformer.py +++ b/apps/Castor/modules/transformer.py @@ -13,12 +13,14 @@ from torch.nn.attention.flex_attention import BlockMask from .component import (AdaLN, BaseTransformerArgs, FeedForward, - FlashAttention, ImageEmbedder, InitStdFactor, RMSNorm, + FlashAttention, ImageEmbedder, RMSNorm, RotaryEmbedding1D, RotaryEmbedding2D, TimestepEmbedder, - create_causal_mask, modulate_and_gate, nearest_multiple_of_8, modulate_and_gate_unpadded) + create_causal_mask, layer_init_kaiming_normal, modulate_and_gate, nearest_multiple_of_8, modulate_and_gate_unpadded) from flash_attn.bert_padding import unpad_input, pad_input import copy +from mup import MuReadout + logger = logging.getLogger() @@ -150,10 +152,10 @@ def forward( return h - def init_weights(self, init_std=None, factor=1.0): - self.attention.reset_parameters(init_std, factor) + def init_weights(self): + self.attention.reset_parameters() self.attention_norm.reset_parameters() - self.feed_forward.reset_parameters(init_std, factor) + self.feed_forward.reset_parameters() self.ffn_norm.reset_parameters() if not self.shared_adaLN: self.adaLN_modulation.reset_parameters() @@ -171,8 +173,6 @@ class BaseDiffusionTransformer(nn.Module): def __init__(self, args: TransformerArgs): super().__init__() self.dim = args.dim - self.init_base_std = args.init_base_std - self.init_std_factor = InitStdFactor(args.init_std_factor) self.gen_seqlen = args.gen_seqlen self.layers = nn.ModuleList() self.shared_adaLN = args.shared_adaLN @@ -223,15 +223,8 @@ def reset_parameters(self): def init_weights(self, pre_trained_path: Optional[str] = None): self.reset_parameters() - for depth, layer in enumerate(self.layers): - factor = { - InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, - InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, - InitStdFactor.DIM_RATIO: self.dim / 4096, - InitStdFactor.DISABLED: 1.0, - }[self.init_std_factor] - - layer.init_weights(self.init_base_std, factor) + for layer in self.layers: + layer.init_weights() if pre_trained_path: assert os.path.exists(pre_trained_path) ckpt_state_dict = torch.load(pre_trained_path, map_location="cpu") @@ -274,10 +267,10 @@ def __init__(self, args: TransformerArgs): in_dim=self.patch_size * self.patch_size * args.in_channels, out_dim=args.dim, ) - self.img_output = nn.Linear( + self.img_output = MuReadout( args.dim, self.patch_size * self.patch_size * args.out_channels, - bias=False, + bias=True, ) self.rope_embeddings_image = RotaryEmbedding2D( theta=args.rope_theta, @@ -559,30 +552,22 @@ def unpatchify_image( img_features = img_features.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3) return img_features - def reset_parameters(self, init_std=None): + def reset_parameters(self): # Either use fixed base std or sqrt model dim super().reset_parameters() self.rope_embeddings_image.reset_parameters() self.rope_embeddings_conditions.reset_parameters() - init_std = init_std or (self.dim ** (-0.5)) self.norm.reset_parameters() self.cond_norm.reset_parameters() self.tmb_embed.reset_parameters() self.img_embed.reset_parameters() - nn.init.trunc_normal_( - self.img_output.weight, - mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, - ) - nn.init.trunc_normal_( - 
self.cond_proj.weight, - mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, - ) + + nn.init.constant_(self.img_output.weight, 0.) # initialize output weights by zero. + if self.img_output.bias is not None: + nn.init.constant_(self.img_output.bias, 0.) + + layer_init_kaiming_normal(self.cond_proj) + nn.init.normal_(self.negative_token, std=0.02) if self.shared_adaLN: self.adaLN_modulation.reset_parameters() diff --git a/lingua/optim.py b/lingua/optim.py index 1c7ba9ba..b24a6e75 100644 --- a/lingua/optim.py +++ b/lingua/optim.py @@ -7,6 +7,7 @@ import logging from torch import nn from torch.optim import AdamW, lr_scheduler +from mup import MuAdamW logger = logging.getLogger() @@ -145,7 +146,7 @@ def build_lr_fn(args: OptimArgs, n_steps: int): def build_optimizer(model: nn.Module, args: OptimArgs, n_steps: int): logger.info("Starting build of optimizer...") - optimizer = AdamW( + optimizer = MuAdamW( (param for param in model.parameters() if param.requires_grad), lr=args.lr, betas=(args.beta1, args.beta2), diff --git a/requirements.txt b/requirements.txt index 4b25c428..9bb589a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,3 +27,4 @@ transformers liger-kernel timm open-clip-torch +mup \ No newline at end of file From 9aa75e5435ed77de412be67a0a1e7d3c177b24ea Mon Sep 17 00:00:00 2001 From: sippycoder Date: Sun, 25 May 2025 18:37:45 +0000 Subject: [PATCH 02/13] mup comment --- apps/Castor/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/apps/Castor/model.py b/apps/Castor/model.py index 16a6b761..2f0adf0f 100644 --- a/apps/Castor/model.py +++ b/apps/Castor/model.py @@ -47,11 +47,11 @@ def __init__(self, input_dim: int, hidden_dim: int, encoder_dim: int): super(AlignmentProjection, self).__init__() self.proj = nn.Sequential( - nn.Linear(input_dim, hidden_dim), + nn.Linear(input_dim, hidden_dim), # mup nn.SiLU(), - nn.Linear(hidden_dim, hidden_dim), + nn.Linear(hidden_dim, hidden_dim), # mup nn.SiLU(), - nn.Linear(hidden_dim, encoder_dim), + nn.Linear(hidden_dim, encoder_dim), # mup ) self.proj.reset_parameters() From 3cec050c22248ff30fc1ab576b24a27e6c4b2f62 Mon Sep 17 00:00:00 2001 From: Murali Nandan Date: Sun, 25 May 2025 18:52:35 +0000 Subject: [PATCH 03/13] Minor fix --- apps/Castor/model.py | 2 +- apps/Castor/modules/component.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/Castor/model.py b/apps/Castor/model.py index 2f0adf0f..7c969b74 100644 --- a/apps/Castor/model.py +++ b/apps/Castor/model.py @@ -54,7 +54,7 @@ def __init__(self, input_dim: int, hidden_dim: int, encoder_dim: int): nn.Linear(hidden_dim, encoder_dim), # mup ) - self.proj.reset_parameters() + self.reset_parameters() def forward(self, x): x = self.proj(x) diff --git a/apps/Castor/modules/component.py b/apps/Castor/modules/component.py index 7ec8082a..18d45015 100644 --- a/apps/Castor/modules/component.py +++ b/apps/Castor/modules/component.py @@ -561,7 +561,7 @@ def reset_parameters(self, *args, **kwargs): layer_init_kaiming_normal(self.wk) layer_init_kaiming_normal(self.wv) layer_init_kaiming_normal(self.wo) - if self.q_norm: + if self.qk_norm: self.q_norm.reset_parameters() self.k_norm.reset_parameters() From d47d86443c12497891637cfaabdfbb9b5c08203c Mon Sep 17 00:00:00 2001 From: sippycoder Date: Sun, 25 May 2025 19:43:54 +0000 Subject: [PATCH 04/13] mup: set base shapes --- apps/Castor/train.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/apps/Castor/train.py b/apps/Castor/train.py index 779ddad2..7e255998 
100644 --- a/apps/Castor/train.py +++ b/apps/Castor/train.py @@ -1,6 +1,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. +import copy import gc import logging import os @@ -54,6 +55,8 @@ from torch.distributed.checkpoint.stateful import Stateful from torch.optim import lr_scheduler +from mup import set_base_shapes + logger = logging.getLogger() @@ -227,6 +230,23 @@ def every_n_steps(train_state, freq, acc_step=None, acc_freq=None): return test +def mup_set_base_shapes(model, args): + base_args = copy.deepcopy(args.model) + base_args.diffusion_model.dim = 288 + base_args.diffusion_model.n_heads = 4 + base_model = Castor(base_args) + + delta_args = copy.deepcopy(args.model) + delta_args.diffusion_model.dim = 360 + delta_args.diffusion_model.n_heads = 5 + delta_model = Castor(delta_args) + + set_base_shapes(model, base_model, delta=delta_model) + + del base_model, delta_model + gc.collect() + + def train(args: TrainArgs): with ExitStack() as context_stack: validate_train_args( @@ -261,6 +281,8 @@ def train(args: TrainArgs): model = Castor(args.model) logger.info("Model is built !") + mup_set_base_shapes(model, args.model) + model_param_count = get_num_params(model) flops_meter = FlopsMeter(args.model, model) From d7d63b0d16554f2a46c86952f3a36ccbb6f502da Mon Sep 17 00:00:00 2001 From: sippycoder Date: Sun, 25 May 2025 19:44:10 +0000 Subject: [PATCH 05/13] output layer correction --- apps/Castor/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/apps/Castor/model.py b/apps/Castor/model.py index 7c969b74..53c2da56 100644 --- a/apps/Castor/model.py +++ b/apps/Castor/model.py @@ -16,6 +16,7 @@ from .modules.vision_encoder import VisionEncoderArgs, create_vision_encoder from .modules.component import layer_init_kaiming_normal +from mup import MuReadout logger = logging.getLogger() @@ -51,7 +52,7 @@ def __init__(self, input_dim: int, hidden_dim: int, encoder_dim: int): nn.SiLU(), nn.Linear(hidden_dim, hidden_dim), # mup nn.SiLU(), - nn.Linear(hidden_dim, encoder_dim), # mup + MuReadout(hidden_dim, encoder_dim), # mup ) self.reset_parameters() From 1668c2cb2ab7ce2d45ae093578a55b64e88e23d9 Mon Sep 17 00:00:00 2001 From: sippycoder Date: Sun, 25 May 2025 19:47:00 +0000 Subject: [PATCH 06/13] bug fix --- apps/Castor/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/Castor/train.py b/apps/Castor/train.py index 7e255998..df9d9a71 100644 --- a/apps/Castor/train.py +++ b/apps/Castor/train.py @@ -231,12 +231,12 @@ def every_n_steps(train_state, freq, acc_step=None, acc_freq=None): def mup_set_base_shapes(model, args): - base_args = copy.deepcopy(args.model) + base_args = copy.deepcopy(args) base_args.diffusion_model.dim = 288 base_args.diffusion_model.n_heads = 4 base_model = Castor(base_args) - delta_args = copy.deepcopy(args.model) + delta_args = copy.deepcopy(args) delta_args.diffusion_model.dim = 360 delta_args.diffusion_model.n_heads = 5 delta_model = Castor(delta_args) From c654fe9694cf088203f7af50cf1a39a0bebd9a95 Mon Sep 17 00:00:00 2001 From: sippycoder Date: Sun, 25 May 2025 19:47:48 +0000 Subject: [PATCH 07/13] for debugging --- apps/main/utils/mongodb_data_load.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/main/utils/mongodb_data_load.py b/apps/main/utils/mongodb_data_load.py index 2f29a869..16be9b18 100644 --- a/apps/main/utils/mongodb_data_load.py +++ 
b/apps/main/utils/mongodb_data_load.py @@ -132,8 +132,8 @@ def set_local_partition(self): if partition_key % self.num_shards == self.shard_idx: data.append(item) # # Note: used for debugging - # if len(data) > 10000: - # break + if len(data) > 1400000: + break self.data = pd.DataFrame(data).reset_index() end_time = time.time() # Record the end time From df727fc41c92cca614d1eb399a210eb725e00d55 Mon Sep 17 00:00:00 2001 From: sippycoder Date: Sun, 25 May 2025 19:49:56 +0000 Subject: [PATCH 08/13] delta fix --- apps/Castor/train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/apps/Castor/train.py b/apps/Castor/train.py index df9d9a71..0708a018 100644 --- a/apps/Castor/train.py +++ b/apps/Castor/train.py @@ -234,11 +234,13 @@ def mup_set_base_shapes(model, args): base_args = copy.deepcopy(args) base_args.diffusion_model.dim = 288 base_args.diffusion_model.n_heads = 4 + base_args.diffusion_model.n_kv_heads = 1 base_model = Castor(base_args) delta_args = copy.deepcopy(args) delta_args.diffusion_model.dim = 360 delta_args.diffusion_model.n_heads = 5 + delta_args.diffusion_model.n_kv_heads = 1 delta_model = Castor(delta_args) set_base_shapes(model, base_model, delta=delta_model) From c6be27026e042b4ea0fc0e5991af59ad8dfdc2dc Mon Sep 17 00:00:00 2001 From: sippycoder Date: Sun, 25 May 2025 19:56:25 +0000 Subject: [PATCH 09/13] mup fix --- apps/Castor/model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/apps/Castor/model.py b/apps/Castor/model.py index 53c2da56..85fd2d62 100644 --- a/apps/Castor/model.py +++ b/apps/Castor/model.py @@ -52,21 +52,21 @@ def __init__(self, input_dim: int, hidden_dim: int, encoder_dim: int): nn.SiLU(), nn.Linear(hidden_dim, hidden_dim), # mup nn.SiLU(), - MuReadout(hidden_dim, encoder_dim), # mup ) + self.output = MuReadout(hidden_dim, encoder_dim), # mup self.reset_parameters() def forward(self, x): x = self.proj(x) - return x + return self.output(x) def reset_parameters(self): layer_init_kaiming_normal(self.proj[0]) layer_init_kaiming_normal(self.proj[2]) - nn.init.constant_(self.proj[4].weight, 0.) # initialize output weights by zero. - if self.proj[4].bias is not None: - nn.init.constant_(self.proj[4].bias, 0.) + nn.init.constant_(self.output.weight, 0.) # initialize output weights by zero. + if self.output.bias is not None: + nn.init.constant_(self.output.bias, 0.) 
class Castor(nn.Module): From 98afe7810614acc4a98fe2044f061f54cef6b743 Mon Sep 17 00:00:00 2001 From: sippycoder Date: Sun, 25 May 2025 20:03:22 +0000 Subject: [PATCH 10/13] bug fix --- apps/Castor/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/Castor/model.py b/apps/Castor/model.py index 85fd2d62..d7eb8432 100644 --- a/apps/Castor/model.py +++ b/apps/Castor/model.py @@ -53,7 +53,7 @@ def __init__(self, input_dim: int, hidden_dim: int, encoder_dim: int): nn.Linear(hidden_dim, hidden_dim), # mup nn.SiLU(), ) - self.output = MuReadout(hidden_dim, encoder_dim), # mup + self.output = MuReadout(hidden_dim, encoder_dim) # mup self.reset_parameters() From b7c1b02589bc091a51fc3454a709e8915846880b Mon Sep 17 00:00:00 2001 From: sippycoder Date: Mon, 26 May 2025 01:09:03 +0000 Subject: [PATCH 11/13] mup compatibility fix --- apps/Castor/model.py | 14 +++++++------- apps/Castor/modules/transformer.py | 7 ++++--- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/apps/Castor/model.py b/apps/Castor/model.py index d7eb8432..a074ea79 100644 --- a/apps/Castor/model.py +++ b/apps/Castor/model.py @@ -48,25 +48,25 @@ def __init__(self, input_dim: int, hidden_dim: int, encoder_dim: int): super(AlignmentProjection, self).__init__() self.proj = nn.Sequential( - nn.Linear(input_dim, hidden_dim), # mup + MuReadout(input_dim, hidden_dim), # mup nn.SiLU(), nn.Linear(hidden_dim, hidden_dim), # mup nn.SiLU(), + nn.Linear(hidden_dim, encoder_dim), # mup ) - self.output = MuReadout(hidden_dim, encoder_dim) # mup self.reset_parameters() def forward(self, x): x = self.proj(x) - return self.output(x) + return x def reset_parameters(self): - layer_init_kaiming_normal(self.proj[0]) + # MuReadout has its own initialization layer_init_kaiming_normal(self.proj[2]) - nn.init.constant_(self.output.weight, 0.) # initialize output weights by zero. - if self.output.bias is not None: - nn.init.constant_(self.output.bias, 0.) + nn.init.constant_(self.proj[4].weight, 0.) # initialize output weights by zero. + if self.proj[4].bias is not None: + nn.init.constant_(self.proj[4].bias, 0.) class Castor(nn.Module): diff --git a/apps/Castor/modules/transformer.py b/apps/Castor/modules/transformer.py index 3932eb9f..9dc8594a 100644 --- a/apps/Castor/modules/transformer.py +++ b/apps/Castor/modules/transformer.py @@ -562,9 +562,10 @@ def reset_parameters(self): self.tmb_embed.reset_parameters() self.img_embed.reset_parameters() - nn.init.constant_(self.img_output.weight, 0.) # initialize output weights by zero. - if self.img_output.bias is not None: - nn.init.constant_(self.img_output.bias, 0.) + # muReadout has its own initialization + # nn.init.constant_(self.img_output.weight, 0.) # initialize output weights by zero. + # if self.img_output.bias is not None: + # nn.init.constant_(self.img_output.bias, 0.) 
layer_init_kaiming_normal(self.cond_proj) From 4ab46264ea96ed98b125fe6f6cb129d7bf8643ee Mon Sep 17 00:00:00 2001 From: sippycoder Date: Mon, 26 May 2025 01:29:37 +0000 Subject: [PATCH 12/13] mup compatibility fixes --- apps/Castor/modules/component.py | 8 +-- apps/Castor/train.py | 4 +- test_mup_fix.py | 87 ++++++++++++++++++++++++++++++++ 3 files changed, 93 insertions(+), 6 deletions(-) create mode 100644 test_mup_fix.py diff --git a/apps/Castor/modules/component.py b/apps/Castor/modules/component.py index 18d45015..95f206f5 100644 --- a/apps/Castor/modules/component.py +++ b/apps/Castor/modules/component.py @@ -15,6 +15,7 @@ from xformers.ops import AttentionBias, fmha from liger_kernel.transformers import LigerSwiGLUMLP, LigerRMSNorm, liger_rotary_pos_emb from types import SimpleNamespace +from mup import MuReadout # fa3 from flash_attn_interface import flash_attn_varlen_func @@ -531,12 +532,12 @@ def __init__( n_heads * self.head_dim, bias=False, ) # mup - self.wk = nn.Linear( + self.wk = MuReadout( dim, n_kv_heads * self.head_dim, bias=False, ) # mup - self.wv = nn.Linear( + self.wv = MuReadout( dim, n_kv_heads * self.head_dim, bias=False, @@ -558,8 +559,7 @@ def __init__( def reset_parameters(self, *args, **kwargs): layer_init_kaiming_normal(self.wq) - layer_init_kaiming_normal(self.wk) - layer_init_kaiming_normal(self.wv) + # MuReadout layers have their own initialization layer_init_kaiming_normal(self.wo) if self.qk_norm: self.q_norm.reset_parameters() diff --git a/apps/Castor/train.py b/apps/Castor/train.py index 0708a018..84971294 100644 --- a/apps/Castor/train.py +++ b/apps/Castor/train.py @@ -234,13 +234,13 @@ def mup_set_base_shapes(model, args): base_args = copy.deepcopy(args) base_args.diffusion_model.dim = 288 base_args.diffusion_model.n_heads = 4 - base_args.diffusion_model.n_kv_heads = 1 + base_args.diffusion_model.n_kv_heads = 2 # Scale this too base_model = Castor(base_args) delta_args = copy.deepcopy(args) delta_args.diffusion_model.dim = 360 delta_args.diffusion_model.n_heads = 5 - delta_args.diffusion_model.n_kv_heads = 1 + delta_args.diffusion_model.n_kv_heads = 3 # Scale this too delta_model = Castor(delta_args) set_base_shapes(model, base_model, delta=delta_model) diff --git a/test_mup_fix.py b/test_mup_fix.py new file mode 100644 index 00000000..51c016a4 --- /dev/null +++ b/test_mup_fix.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 + +import sys +import os +sys.path.append('/home/akito/Pollux') + +import torch +import copy +import gc +from apps.Castor.model import Castor, ModelArgs, build_2B_Castor +from apps.Castor.modules.transformer import TransformerArgs +from mup import set_base_shapes + +def test_mup_setup(): + print("Testing MuP setup...") + + # Build the main model + model_args = build_2B_Castor().args + model = Castor(model_args) + + print(f"Main model config:") + print(f" dim: {model_args.diffusion_model.dim}") + print(f" n_heads: {model_args.diffusion_model.n_heads}") + print(f" n_kv_heads: {model_args.diffusion_model.n_kv_heads}") + + # Create base model + base_args = copy.deepcopy(model_args) + base_args.diffusion_model.dim = 288 + base_args.diffusion_model.n_heads = 4 + base_args.diffusion_model.n_kv_heads = 2 + base_model = Castor(base_args) + + print(f"Base model config:") + print(f" dim: {base_args.diffusion_model.dim}") + print(f" n_heads: {base_args.diffusion_model.n_heads}") + print(f" n_kv_heads: {base_args.diffusion_model.n_kv_heads}") + + # Create delta model + delta_args = copy.deepcopy(model_args) + delta_args.diffusion_model.dim = 360 + 
delta_args.diffusion_model.n_heads = 5 + delta_args.diffusion_model.n_kv_heads = 3 + delta_model = Castor(delta_args) + + print(f"Delta model config:") + print(f" dim: {delta_args.diffusion_model.dim}") + print(f" n_heads: {delta_args.diffusion_model.n_heads}") + print(f" n_kv_heads: {delta_args.diffusion_model.n_kv_heads}") + + # Test MuP setup + try: + print("\nSetting base shapes...") + set_base_shapes(model, base_model, delta=delta_model) + print("āœ… MuP setup successful!") + + # Test a forward pass + print("\nTesting forward pass...") + model.eval() + + # Create dummy batch + batch = { + 'latent_code': torch.randn(2, 16, 32, 32), # B, C, H, W + 'text_embedding': torch.randn(2, 128, 512), # B, seq_len, dim + 'attention_mask': torch.ones(2, 128, dtype=torch.bool) + } + + with torch.no_grad(): + output = model(batch) + print(f"āœ… Forward pass successful! Loss: {output.loss.item():.4f}") + + except Exception as e: + print(f"āŒ Error: {e}") + return False + + finally: + # Cleanup + del base_model, delta_model + gc.collect() + + return True + +if __name__ == "__main__": + success = test_mup_setup() + if success: + print("\nšŸŽ‰ All tests passed! The MuP fix is working correctly.") + else: + print("\nšŸ’„ Tests failed. There may still be issues with the MuP setup.") \ No newline at end of file From 279225012dc800a511640ca6fbe31c06c2bed1c0 Mon Sep 17 00:00:00 2001 From: sippycoder Date: Mon, 26 May 2025 01:59:23 +0000 Subject: [PATCH 13/13] mup fixes --- apps/Castor/train.py | 2 +- test_mup_fix.py | 20 +++++++++++++++++--- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/apps/Castor/train.py b/apps/Castor/train.py index 84971294..7ef482ae 100644 --- a/apps/Castor/train.py +++ b/apps/Castor/train.py @@ -239,7 +239,7 @@ def mup_set_base_shapes(model, args): delta_args = copy.deepcopy(args) delta_args.diffusion_model.dim = 360 - delta_args.diffusion_model.n_heads = 5 + delta_args.diffusion_model.n_heads = 6 delta_args.diffusion_model.n_kv_heads = 3 # Scale this too delta_model = Castor(delta_args) diff --git a/test_mup_fix.py b/test_mup_fix.py index 51c016a4..ce37ceac 100644 --- a/test_mup_fix.py +++ b/test_mup_fix.py @@ -15,7 +15,9 @@ def test_mup_setup(): print("Testing MuP setup...") # Build the main model + print("Building main model...") model_args = build_2B_Castor().args + print(f"Main model args: {model_args}") model = Castor(model_args) print(f"Main model config:") @@ -23,7 +25,7 @@ def test_mup_setup(): print(f" n_heads: {model_args.diffusion_model.n_heads}") print(f" n_kv_heads: {model_args.diffusion_model.n_kv_heads}") - # Create base model + # Create base model (matching train.py) base_args = copy.deepcopy(model_args) base_args.diffusion_model.dim = 288 base_args.diffusion_model.n_heads = 4 @@ -35,10 +37,10 @@ def test_mup_setup(): print(f" n_heads: {base_args.diffusion_model.n_heads}") print(f" n_kv_heads: {base_args.diffusion_model.n_kv_heads}") - # Create delta model + # Create delta model (matching train.py) delta_args = copy.deepcopy(model_args) delta_args.diffusion_model.dim = 360 - delta_args.diffusion_model.n_heads = 5 + delta_args.diffusion_model.n_heads = 6 # Updated to match your change delta_args.diffusion_model.n_kv_heads = 3 delta_model = Castor(delta_args) @@ -47,6 +49,14 @@ def test_mup_setup(): print(f" n_heads: {delta_args.diffusion_model.n_heads}") print(f" n_kv_heads: {delta_args.diffusion_model.n_kv_heads}") + # Debug: Check parameters before MuP setup + print("\nChecking parameters that might need MuP treatment...") + for name, param in 
model.named_parameters():
+        if param.shape == torch.Size([2048, 2048]):
+            print(f"Found 2048x2048 parameter: {name}")
+        if len(param.shape) == 2 and (2048 in param.shape):
+            print(f"Parameter with 2048 dimension: {name} - shape: {param.shape}")
+    
     # Test MuP setup
     try:
         print("\nSetting base shapes...")
@@ -70,6 +80,10 @@ def test_mup_setup():
     
     except Exception as e:
         print(f"❌ Error: {e}")
+        print("\nDebugging: Checking which parameters don't have infshape...")
+        for name, param in model.named_parameters():
+            if not hasattr(param, 'infshape'):
+                print(f"Missing infshape: {name} - shape: {param.shape}")
         return False
     
     finally: