From 624e66f1b3aff058a404018ef7f0885858776a10 Mon Sep 17 00:00:00 2001 From: Murali Nandan Date: Sun, 8 Jun 2025 06:19:50 +0000 Subject: [PATCH 1/5] Added constant negative token --- apps/Castor/configs/eval.yaml | 4 ++-- ...ucket_256_Castor_flux_qwen_fixed_siglip2.yaml | 4 ++-- apps/Castor/generate.py | 13 +++++-------- apps/Castor/model.py | 16 ++++++++-------- apps/Castor/train.py | 2 +- apps/main/utils/mongodb_data_load.py | 4 ++-- 6 files changed, 20 insertions(+), 23 deletions(-) diff --git a/apps/Castor/configs/eval.yaml b/apps/Castor/configs/eval.yaml index 8961228..67c50e6 100644 --- a/apps/Castor/configs/eval.yaml +++ b/apps/Castor/configs/eval.yaml @@ -1,8 +1,8 @@ # python -m apps.Castor.generate config=apps/Castor/configs/eval.yaml name: "debug_evals" stage: eval -ckpt_dir: "/mnt/pollux/checkpoints/aj/unpadded3/checkpoints/0000080000/" -dump_dir: "/mnt/pollux/checkpoints/aj/unpadded3/MJHQ/0000080000" +ckpt_dir: "/mnt/pollux/checkpoints/aj/neg_token/checkpoints/0000080000/" +dump_dir: "/mnt/pollux/checkpoints/aj/neg_token/MJHQ/0000080000" generator: guidance_scale: 6.5 dtype: bf16 diff --git a/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml b/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml index 5b110a0..55bed60 100644 --- a/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml +++ b/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml @@ -4,8 +4,8 @@ version: v1.0 # From now on, start to align train, data, and model setting with train stage (just finish refactor for dara) train_stage: stage-1 # options: preliminary, pretraining, posttraining; aligned with data setting -name: fp8_ffn_unpadded #used for local dump and wandb log -output_dir: /mnt/pollux/checkpoints/aj/fp8_ffn_unpadded +name: neg_token #used for local dump and wandb log +output_dir: /mnt/pollux/checkpoints/aj/ dump_dir: '' # No need now steps: 500000 seed: 777 diff --git a/apps/Castor/generate.py b/apps/Castor/generate.py index 0ed6dd7..61f939e 100644 --- a/apps/Castor/generate.py +++ b/apps/Castor/generate.py @@ -93,14 +93,11 @@ def forward(self, context: Dict[str, Any]) -> torch.Tensor: ) latent = self.prepare_latent(context, device=cur_device) pos_conditional_signal, pos_conditional_mask = self.model.text_encoder(context) - negative_conditional_signal = ( - self.model.diffusion_transformer.negative_token.repeat( - pos_conditional_signal.size(0), pos_conditional_signal.size(1), 1 - ) - ) - negative_conditional_mask = torch.ones_like( - pos_conditional_mask, dtype=pos_conditional_mask.dtype - ) + negative_conditional_signal = torch.zeros_like(pos_conditional_signal) + negative_conditional_signal[:, 0, :] = self.model.diffusion_transformer.negative_token + negative_conditional_mask = torch.zeros_like(pos_conditional_mask) + negative_conditional_mask[:, 0] = 1 + context = torch.cat( [ pos_conditional_signal, diff --git a/apps/Castor/model.py b/apps/Castor/model.py index 2db2462..54f2adf 100644 --- a/apps/Castor/model.py +++ b/apps/Castor/model.py @@ -92,7 +92,7 @@ def __init__(self, args: ModelArgs): self.scheduler = RectifiedFlow(args.scheduler) self.text_cfg_ratio = args.text_cfg_ratio - def forward(self, batch: dict[str:any], flops_meter= None) -> dict[str:any]: + def forward(self, batch: dict[str:any], flops_meter= None, is_training=False) -> dict[str:any]: if hasattr(self, "compressor"): batch["latent_code"] = self.compressor.extract_latents(batch, flops_meter) @@ -104,13 +104,13 @@ def forward(self, batch: dict[str:any], 
flops_meter= None) -> dict[str:any]: conditional_signal, conditional_mask = batch["text_embedding"], batch["attention_mask"] - if random.random() <= self.text_cfg_ratio: - conditional_signal = self.diffusion_transformer.negative_token.repeat( - conditional_signal.size(0), conditional_signal.size(1), 1 - ) - conditional_mask = torch.ones_like( - conditional_mask, dtype=conditional_signal.dtype - ) + if is_training: + cfg_mask = (torch.rand(conditional_signal.shape[0], device=conditional_signal.device) < self.text_cfg_ratio) + if cfg_mask.any(): + conditional_signal[cfg_mask, 0:1, :] = self.diffusion_transformer.negative_token + # For unconditional samples, the attention mask should be [1, 0, 0, ...]. + conditional_mask[cfg_mask, 0] = 1 + conditional_mask[cfg_mask, 1:] = 0 latent_code = batch["latent_code"] noised_x, t, target = self.scheduler.sample_noised_input(latent_code) diff --git a/apps/Castor/train.py b/apps/Castor/train.py index 8e889d3..e3dcf68 100644 --- a/apps/Castor/train.py +++ b/apps/Castor/train.py @@ -413,7 +413,7 @@ def train(args: TrainArgs): start_timer.record() with te.fp8_autocast(enabled=args.model.diffusion_model.use_fp8_ffn, fp8_recipe=fp8_recipe, fp8_group=all_gpus): - outputs = model(batch, flops_meter) + outputs = model(batch, flops_meter, is_training=True) # We scale loss with grad_acc_steps so the gradient is the same # regardless of grad_acc_steps loss = outputs.loss / args.grad_acc_steps diff --git a/apps/main/utils/mongodb_data_load.py b/apps/main/utils/mongodb_data_load.py index ef48a4a..db6f337 100644 --- a/apps/main/utils/mongodb_data_load.py +++ b/apps/main/utils/mongodb_data_load.py @@ -148,8 +148,8 @@ def set_local_partition(self): data.extend(df.to_dict(orient="records")) # # Note: used for debugging - # if len(data) > 10000: - # break + if len(data) > int(2e6): + break else: raise ValueError(f"Invalid Root Directory Type. 
Set root_dir_type to 'json' or 'parquet'") From 1c98325c4d61694f6c8b37bb7c14d423a5798583 Mon Sep 17 00:00:00 2001 From: Murali Nandan Date: Sun, 8 Jun 2025 20:23:48 +0000 Subject: [PATCH 2/5] intermmediate commit --- apps/Castor/configs/eval.yaml | 4 ++-- ...n_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml | 3 ++- apps/Castor/generate.py | 8 ++++++-- apps/Castor/model.py | 10 ++++++---- apps/Castor/modules/transformer.py | 4 +++- apps/Castor/train.py | 13 +++++++++---- 6 files changed, 28 insertions(+), 14 deletions(-) diff --git a/apps/Castor/configs/eval.yaml b/apps/Castor/configs/eval.yaml index 67c50e6..82d797d 100644 --- a/apps/Castor/configs/eval.yaml +++ b/apps/Castor/configs/eval.yaml @@ -1,8 +1,8 @@ # python -m apps.Castor.generate config=apps/Castor/configs/eval.yaml name: "debug_evals" stage: eval -ckpt_dir: "/mnt/pollux/checkpoints/aj/neg_token/checkpoints/0000080000/" -dump_dir: "/mnt/pollux/checkpoints/aj/neg_token/MJHQ/0000080000" +ckpt_dir: "/mnt/pollux/checkpoints/aj/neg_token2/checkpoints/0000080000/" +dump_dir: "/mnt/pollux/checkpoints/aj/neg_token2/MJHQ/0000080000" generator: guidance_scale: 6.5 dtype: bf16 diff --git a/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml b/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml index 55bed60..1b2601a 100644 --- a/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml +++ b/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml @@ -4,7 +4,7 @@ version: v1.0 # From now on, start to align train, data, and model setting with train stage (just finish refactor for dara) train_stage: stage-1 # options: preliminary, pretraining, posttraining; aligned with data setting -name: neg_token #used for local dump and wandb log +name: neg_token2 #used for local dump and wandb log output_dir: /mnt/pollux/checkpoints/aj/ dump_dir: '' # No need now steps: 500000 @@ -68,6 +68,7 @@ model: unpadded: true use_fp8_ffn: true fp8_ffn_skip_layers: [] # falls back to liger_ffn + n_unconditional_tokens: 64 with_vae: true vae_args: vae_type: "flux" diff --git a/apps/Castor/generate.py b/apps/Castor/generate.py index 61f939e..0203f3a 100644 --- a/apps/Castor/generate.py +++ b/apps/Castor/generate.py @@ -93,10 +93,14 @@ def forward(self, context: Dict[str, Any]) -> torch.Tensor: ) latent = self.prepare_latent(context, device=cur_device) pos_conditional_signal, pos_conditional_mask = self.model.text_encoder(context) + + n_unconditional_tokens = self.model.diffusion_transformer.n_unconditional_tokens negative_conditional_signal = torch.zeros_like(pos_conditional_signal) - negative_conditional_signal[:, 0, :] = self.model.diffusion_transformer.negative_token + negative_conditional_signal[ + :, :n_unconditional_tokens, : + ] = self.model.diffusion_transformer.negative_token negative_conditional_mask = torch.zeros_like(pos_conditional_mask) - negative_conditional_mask[:, 0] = 1 + negative_conditional_mask[:, :n_unconditional_tokens] = 1 context = torch.cat( [ diff --git a/apps/Castor/model.py b/apps/Castor/model.py index 54f2adf..8ed7392 100644 --- a/apps/Castor/model.py +++ b/apps/Castor/model.py @@ -107,10 +107,11 @@ def forward(self, batch: dict[str:any], flops_meter= None, is_training=False) -> if is_training: cfg_mask = (torch.rand(conditional_signal.shape[0], device=conditional_signal.device) < self.text_cfg_ratio) if cfg_mask.any(): - conditional_signal[cfg_mask, 0:1, :] = self.diffusion_transformer.negative_token - # For unconditional samples, the attention mask should be [1, 
0, 0, ...]. - conditional_mask[cfg_mask, 0] = 1 - conditional_mask[cfg_mask, 1:] = 0 + n_unconditional_tokens = self.diffusion_transformer.n_unconditional_tokens + conditional_signal[cfg_mask, :n_unconditional_tokens, :] = self.diffusion_transformer.negative_token + # For unconditional samples, the attention mask should be [1, 1, ..., 0, ...]. + conditional_mask[cfg_mask, :n_unconditional_tokens] = 1 + conditional_mask[cfg_mask, n_unconditional_tokens:] = 0 latent_code = batch["latent_code"] noised_x, t, target = self.scheduler.sample_noised_input(latent_code) @@ -235,6 +236,7 @@ def build_2B_Castor(): condition_seqlen=128, norm_eps=1e-5, condition_dim=512, + n_unconditional_tokens=64, ) vae_args = VideoVAEArgs( pretrained_model_name_or_path="/mnt/pollux/checkpoints/HunyuanVideo/vae", diff --git a/apps/Castor/modules/transformer.py b/apps/Castor/modules/transformer.py index 4a1ca92..3d1a8f1 100644 --- a/apps/Castor/modules/transformer.py +++ b/apps/Castor/modules/transformer.py @@ -46,6 +46,7 @@ class TransformerArgs(BaseTransformerArgs): tmb_size: int = 256 condition_seqlen: int = 1000 gen_seqlen: int = 1000 + n_unconditional_tokens: int = 64 pre_trained_path: Optional[str] = None qk_norm: bool = True shared_adaLN: bool = False @@ -305,6 +306,7 @@ def __init__(self, args: TransformerArgs): self.out_channels = args.out_channels self.in_channels = args.in_channels self.unpadded = args.unpadded + self.n_unconditional_tokens = args.n_unconditional_tokens self.tmb_embed = TimestepEmbedder( hidden_size=args.time_step_dim, time_embedding_size=args.tmb_size ) @@ -335,7 +337,7 @@ def __init__(self, args: TransformerArgs): self.cond_norm = RMSNorm( args.dim, eps=args.norm_eps, liger_rms_norm=args.liger_rms_norm ) - self.negative_token = nn.Parameter(torch.zeros(1, 1, args.condition_dim)) + self.negative_token = nn.Parameter(torch.zeros(1, self.n_unconditional_tokens, args.condition_dim)) self.cond_proj = nn.Linear( in_features=args.condition_dim, out_features=args.dim, diff --git a/apps/Castor/train.py b/apps/Castor/train.py index e3dcf68..7a53276 100644 --- a/apps/Castor/train.py +++ b/apps/Castor/train.py @@ -315,10 +315,6 @@ def train(args: TrainArgs): scheduler=scheduler, ) - checkpoint = CheckpointManager.instantiate_and_make_dir(args.checkpoint) - checkpoint.load(model, optimizer, train_state, world_mesh) - # Either load from latest checkpoint or start from scratch - fp8_recipe = None if args.model.diffusion_model.use_fp8_ffn: logger.info("FP8 is enabled. 
Defining FP8 recipe.") @@ -327,6 +323,15 @@ def train(args: TrainArgs): fp8_format = Format.HYBRID # Or Format.E4M3 fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") # You might want to make recipe parameters configurable via TrainArgs + te.fp8_global_state.FP8GlobalStateManager.set_fp8_enabled(True) + te.fp8_global_state.FP8GlobalStateManager.set_fp8_recipe( + fp8_recipe + ) + + checkpoint = CheckpointManager.instantiate_and_make_dir(args.checkpoint) + checkpoint.load(model, optimizer, train_state, world_mesh) + # Either load from latest checkpoint or start from scratch + gc.disable() From 98739a205ceca51c65a6878b2066f45844e85cda Mon Sep 17 00:00:00 2001 From: Murali Nandan Date: Mon, 16 Jun 2025 06:53:57 +0000 Subject: [PATCH 3/5] f-lite scheduler --- apps/Castor/model.py | 49 ++++++++++++++++++++++++++++++++- apps/Castor/train.py | 30 ++++++++++---------- apps/Castor/utils/flop_meter.py | 13 +++++---- 3 files changed, 71 insertions(+), 21 deletions(-) diff --git a/apps/Castor/model.py b/apps/Castor/model.py index 8ed7392..90ede7d 100644 --- a/apps/Castor/model.py +++ b/apps/Castor/model.py @@ -2,6 +2,7 @@ import logging import random +import math from dataclasses import dataclass, field from typing import List, Optional, Tuple @@ -92,6 +93,51 @@ def __init__(self, args: ModelArgs): self.scheduler = RectifiedFlow(args.scheduler) self.text_cfg_ratio = args.text_cfg_ratio + @torch.compile + @torch.no_grad() + def sample_timesteps( + self, vae_latent: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sample a timestep `t`, create the corresponding noisy latent `z_t`, and + compute the velocity target `v_objective`. + + Parameters + ---------- + vae_latent : torch.Tensor + Clean latent tensor with shape (B, C, H, W). + + Returns + ------- + z_t : torch.Tensor + Noisy latent at timestep `t`, same shape and dtype as `vae_latent`. + t : torch.Tensor + Sampled timestep in the continuous range [0, 1] with shape (B,). + v_objective : torch.Tensor + Velocity target (`x₀ − noise`) used for training. 
+ """ + B, _, H, W = vae_latent.shape + device, dtype = vae_latent.device, vae_latent.dtype + + # --- sample timesteps -------------------------------------------------- + image_token_size = H * W + alpha = 2 * math.sqrt(image_token_size / (64 * 64)) + + z = torch.randn(B, device=device, dtype=torch.float32) + logistic = torch.sigmoid(z) + lognormal = logistic * alpha / (1 + (alpha - 1) * logistic) + + do_uniform = torch.rand(B, device=device) < 0.1 + t = torch.where(do_uniform, torch.rand(B, device=device), lognormal).to(dtype) + + # --- add noise --------------------------------------------------------- + noise = torch.randn_like(vae_latent) + t_reshaped = t.view(B, 1, 1, 1) + z_t = vae_latent * (1 - t_reshaped) + noise * t_reshaped + + v_objective = vae_latent - noise + return z_t, t, v_objective + def forward(self, batch: dict[str:any], flops_meter= None, is_training=False) -> dict[str:any]: if hasattr(self, "compressor"): batch["latent_code"] = self.compressor.extract_latents(batch, flops_meter) @@ -114,7 +160,8 @@ def forward(self, batch: dict[str:any], flops_meter= None, is_training=False) -> conditional_mask[cfg_mask, n_unconditional_tokens:] = 0 latent_code = batch["latent_code"] - noised_x, t, target = self.scheduler.sample_noised_input(latent_code) + noised_x, t, target = self.sample_timesteps(latent_code) + output = self.diffusion_transformer( x=noised_x, time_steps=t, diff --git a/apps/Castor/train.py b/apps/Castor/train.py index 7a53276..8d85504 100644 --- a/apps/Castor/train.py +++ b/apps/Castor/train.py @@ -260,11 +260,26 @@ def train(args: TrainArgs): torch.manual_seed(args.seed) logger.info("Building model") + + fp8_recipe = None + if args.model.diffusion_model.use_fp8_ffn: + logger.info("FP8 is enabled. Defining FP8 recipe.") + # Example recipe, adjust as needed + all_gpus = torch.distributed.new_group(backend="nccl") + fp8_format = Format.HYBRID # Or Format.E4M3 + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") + # with te.fp8_model_init(enabled=True, recipe=fp8_recipe): model = Castor(args.model) logger.info("Model is built !") logger.info(model) + with te.fp8_autocast(enabled=args.model.diffusion_model.use_fp8_ffn, fp8_recipe=fp8_recipe): + for module in model.modules(): + if isinstance(module, te.LayerNormMLP): + module.init_fp8_metadata(2) + + model_param_count = get_num_params(model) flops_meter = FlopsMeter(args.model, model) @@ -315,19 +330,6 @@ def train(args: TrainArgs): scheduler=scheduler, ) - fp8_recipe = None - if args.model.diffusion_model.use_fp8_ffn: - logger.info("FP8 is enabled. 
Defining FP8 recipe.") - # Example recipe, adjust as needed - all_gpus = torch.distributed.new_group(backend="nccl") - fp8_format = Format.HYBRID # Or Format.E4M3 - fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") - # You might want to make recipe parameters configurable via TrainArgs - te.fp8_global_state.FP8GlobalStateManager.set_fp8_enabled(True) - te.fp8_global_state.FP8GlobalStateManager.set_fp8_recipe( - fp8_recipe - ) - checkpoint = CheckpointManager.instantiate_and_make_dir(args.checkpoint) checkpoint.load(model, optimizer, train_state, world_mesh) # Either load from latest checkpoint or start from scratch @@ -617,4 +619,4 @@ def main(): if __name__ == "__main__": torch.set_float32_matmul_precision('high') - main() + main() \ No newline at end of file diff --git a/apps/Castor/utils/flop_meter.py b/apps/Castor/utils/flop_meter.py index 95dbc81..ff971c1 100644 --- a/apps/Castor/utils/flop_meter.py +++ b/apps/Castor/utils/flop_meter.py @@ -116,12 +116,13 @@ def __init__( self.cond_num_heads = model.text_encoder.model.config.num_attention_heads self.cond_headdim = self.cond_dim // self.cond_num_heads - self.vision_params = get_num_params(model.vision_encoder.model) - self.vision_num_layers = model.vision_encoder.model.config.num_hidden_layers - self.vision_num_heads = model.vision_encoder.model.config.num_attention_heads - self.vision_dim = model.vision_encoder.model.config.hidden_size - self.vision_patch_size = model.vision_encoder.model.config.patch_size - self.vision_headdim = self.vision_dim // self.vision_num_heads + if hasattr(model, "vision_encoder"): + self.vision_params = get_num_params(model.vision_encoder.model) + self.vision_num_layers = model.vision_encoder.model.config.num_hidden_layers + self.vision_num_heads = model.vision_encoder.model.config.num_attention_heads + self.vision_dim = model.vision_encoder.model.config.hidden_size + self.vision_patch_size = model.vision_encoder.model.config.patch_size + self.vision_headdim = self.vision_dim // self.vision_num_heads self.vae_params = get_num_params(model.compressor.vae) From d1f0fd7cb1b67cd844b39810fcc44e9d5e34160e Mon Sep 17 00:00:00 2001 From: Murali Nandan Date: Mon, 16 Jun 2025 07:09:43 +0000 Subject: [PATCH 4/5] Minor fix --- apps/Castor/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/Castor/model.py b/apps/Castor/model.py index 90ede7d..a93e2f2 100644 --- a/apps/Castor/model.py +++ b/apps/Castor/model.py @@ -135,7 +135,7 @@ def sample_timesteps( t_reshaped = t.view(B, 1, 1, 1) z_t = vae_latent * (1 - t_reshaped) + noise * t_reshaped - v_objective = vae_latent - noise + v_objective = noise - vae_latent return z_t, t, v_objective def forward(self, batch: dict[str:any], flops_meter= None, is_training=False) -> dict[str:any]: From 1eeba6d44acc5bad192ef7cd4d7a9ec8af765877 Mon Sep 17 00:00:00 2001 From: Murali Nandan Date: Mon, 16 Jun 2025 07:12:35 +0000 Subject: [PATCH 5/5] Added assert to crash on dynamic resolution --- apps/Castor/model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/apps/Castor/model.py b/apps/Castor/model.py index a93e2f2..2449f93 100644 --- a/apps/Castor/model.py +++ b/apps/Castor/model.py @@ -116,6 +116,9 @@ def sample_timesteps( v_objective : torch.Tensor Velocity target (`x₀ − noise`) used for training. """ + + assert isinstance(vae_latent, torch.Tensor), "vae_latent must be a tensor, implement dynamic resolution" + B, _, H, W = vae_latent.shape device, dtype = vae_latent.device, vae_latent.dtype