diff --git a/apps/Castor/configs/mup_v0.yaml b/apps/Castor/configs/mup_v0.yaml new file mode 100644 index 0000000..f931076 --- /dev/null +++ b/apps/Castor/configs/mup_v0.yaml @@ -0,0 +1,143 @@ +# torchrun --standalone --nnodes 1 --nproc-per-node 8 -m apps.Castor.train config=apps/Castor/configs/aws_256_Castor_flux_qwen_fixed_siglip2.yaml + +# Set up single experiment +version: v1.0 +# From now on, start to align train, data, and model setting with train stage (just finish refactor for dara) +train_stage: stage-1 # options: preliminary, pretraining, posttraining; aligned with data setting +name: mup_dim_144_bs_128_lr_5e-4 #used for local dump and wandb log +output_dir: /mnt/pollux/checkpoints/mup/lr_sweep +dump_dir: '' # No need now +steps: 500000 +seed: 777 +optim: + lr: 5e-4 + warmup: 4000 + lr_min_ratio: 1.5e-5 + clip: 1.0 + weight_decay: 0.01 + mup: true + +distributed: + gpus: "0,1,2,3" + fsdp_type: full_shard + dp_shard: 4 + dp_replicate: 1 + compile: false + model_dtype: bf16 # options: `fb8` is only supported by H100 + matmul_allow_tf32: false + selective_activation_checkpointing: false + tp_size: 1 + compile_cache_size_limit: 64 + +model: + scheduler: + num_train_timesteps: 1000 + base_image_seq_len: 256 + base_shift: 0.5 + max_image_seq_len: 4096 + max_shift: 1.15 + shift: 1.0 # need consider 3.0 or 1.0 + weighting_scheme: 'logit_normal' + logit_mean: 0.0 + logit_std: 1.0 + mode_scale: 1.29 + use_dynamic_shifting: true + + diffusion_model: + dim: &hidden_dim 256 + ffn_dim_multiplier: 1.5 + multiple_of: 256 + n_heads: 32 + n_kv_heads: 8 + n_layers: 24 + attention_window: [-1, -1] # [-1, -1] for full attention + align_layer: 8 + time_step_dim: *hidden_dim + patch_size: 2 + in_channels: 16 + out_channels: 16 + tmb_size: 256 + gen_seqlen: 32 + condition_seqlen: 256 + norm_eps: 1e-5 + condition_dim: 3584 + qk_norm: false + liger_rms_norm: true + liger_ffn: true + liger_rotary_emb: false + shared_adaLN: true + unpadded: true + use_fp8_ffn: false + fp8_ffn_skip_layers: [] # falls back to liger_ffn + with_vae: true + vae_args: + vae_type: "flux" + pretrained_model_name_or_path: '/mnt/pollux/checkpoints/FLUX.1-dev/vae' + enable_tiling: false + enable_slicing: false + + vision_encoder_alignment: false + vision_encoder_args: + encoder_name: "siglip2" + weight_path: "/mnt/pollux/checkpoints/siglip2-base-patch16-naflex" + projection_hidden_dim: *hidden_dim + + text_cfg_ratio: 0.1 + text_encoder: + config_name: "Qwen/Qwen2.5-VL-7B-Instruct" + dtype: "bf16" + text_seqlen: 256 + model_path: "/mnt/pollux/checkpoints/Qwen2.5-VL-7B-Instruct" + relative_depth: 0.75 + +# save_base_shapes: "base_shapes.pkl" +load_base_shapes: "base_shapes.pkl" + +data: + - stage: stage-1 + id: 1 + data_name: bucket-256-2 + task: text_to_image + source: mongodb + image_size: 256 + condition_image_size: 256 + max_ratio: 1.0 + partition_key: 'partition_key' + retries: 3 + extract_field: + "media_path": "image" + use: true + root_dir: "/mnt/pollux/mongo_db_cache_train" + root_dir_type: "json" # options: "json", "parquet" + # https://worldmodeldata-prod.s3.us-east-2.amazonaws.com + base_url: s3://worldmodeldata-prod + dataloader: + prefetch_factor: 2 + batch_size: 32 + num_workers: 8 + seed: 1024 + shuffle: True + pin_memory: True + drop_last: False + +profiling: + run: false + +checkpoint: + dump: + every: 10000 + keep: 0 # Don't remove the ckpt + eval: + every: 5000 + keep: 0 # Don't remove the ckpt + +logging: + freq: 100 + wandb: + project: ablations + entity: metauto + name: '' + +env: + ENABLE_INTRA_NODE_COMM: '0' # '0' for local machine (otherwise errors happen); '1' for slurmn (need test) + NCCL_DEBUG: 'ERROR' diff --git a/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml b/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml index 5b110a0..5a3ae07 100644 --- a/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml +++ b/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml @@ -15,6 +15,7 @@ optim: lr_min_ratio: 1.5e-5 clip: 1.0 weight_decay: 0.01 + mup: true distributed: gpus: "0,1,2,3,4,5,6,7" @@ -89,6 +90,8 @@ model: model_path: "/mnt/pollux/checkpoints/Qwen2.5-VL-7B-Instruct" relative_depth: 0.75 +save_base_shapes: "base_shapes.pkl" + data: - stage: stage-1 id: 1 diff --git a/apps/Castor/eval.py b/apps/Castor/eval.py index c2a0b18..ea4d11f 100644 --- a/apps/Castor/eval.py +++ b/apps/Castor/eval.py @@ -17,6 +17,7 @@ from apps.Castor.generate import (GeneratorArgs, LatentGenerator, load_consolidated_model) from apps.Castor.model import Castor, ModelArgs +from apps.Castor.modules.text_encoder import create_text_encoder from apps.Castor.modules.vae import (BaseLatentVideoVAE, VideoVAEArgs, create_vae) from apps.main.data import AutoDataLoader, DataArgs @@ -90,7 +91,8 @@ def launch_eval(cfg: EvalArgs): logger.info("Model loaded") model.eval() tvae = create_vae(cfg.generator.tvae) - generator = LatentGenerator(cfg.generator, model, tvae).cuda() + text_encoder = create_text_encoder(cfg.generator.text_encoder) + generator = LatentGenerator(cfg.generator, model, tvae, text_encoder).cuda() active_data = [d for d in cfg.data if d.stage == cfg.stage and d.use] data_loader_factory = AutoDataLoader( shard_id=global_rank, diff --git a/apps/Castor/eval.sbatch b/apps/Castor/eval.sbatch new file mode 100644 index 0000000..1c53363 --- /dev/null +++ b/apps/Castor/eval.sbatch @@ -0,0 +1,130 @@ +#!/bin/bash + +#SBATCH --job-name=castor_eval +#SBATCH --output=slurm_logs/slurm-%x-%j.out # %x for job name, %j for job ID +#SBATCH --error=slurm_logs/slurm-%x-%j.err +# User will specify --nodes and --partition on sbatch command line +# e.g., sbatch --nodes=2 --partition=my_partition eval.sbatch + +#SBATCH --ntasks-per-node=1 # We run one torchrun launcher per node +#SBATCH --gpus-per-node=8 # Each torchrun launcher will manage 8 processes, one per GPU + +# --- Project and Log Directories --- +PROJECT_DIR=${PROJECT_DIR:-"/fsx/ubuntu/workspace/repo/Pollux"} +LOG_DIR=${LOG_DIR:-"/fsx/checkpoints/ablations/logs"} + +echo "Changing directory to Project Directory: ${PROJECT_DIR}" +cd "${PROJECT_DIR}" || { echo "Failed to cd into ${PROJECT_DIR}"; exit 1; } +echo "Current working directory: $(pwd)" + +# --- User defined ENVs for AWS Hyperpod --- +export NCCL_PROTO="Simple" +export FI_PROVIDER="efa" +export FI_EFA_USE_DEVICE_RDMA="1" +export FI_EFA_USE_HUGE_PAGE="0" +export FI_EFA_SET_CUDA_SYNC_MEMOPS="0" +export NCCL_SOCKET_IFNAME="^docker,lo,veth,eth" +export LD_PRELOAD="/usr/local/cuda-12.8/lib/libnccl.so" + +# --- Conda environment --- +CONDA_ENV_NAME="pollux" + +CONDA_PATH=${CONDA_PATH:-"/fsx/ubuntu/miniconda3"} +export PATH="$CONDA_PATH/bin:$PATH" +source $CONDA_PATH/etc/profile.d/conda.sh + +echo "Attempting to activate conda environment: ${CONDA_ENV_NAME}" +_CONDA_ROOT=$(conda info --base 2>/dev/null) + +if [ -z "${_CONDA_ROOT}" ]; then + echo "Error: conda command not found or conda base not determined." + echo "Please ensure conda is installed and initialized." + exit 1 +fi + +conda activate "${CONDA_ENV_NAME}" +if [ $? -ne 0 ]; then + echo "Error: Failed to activate conda environment: ${CONDA_ENV_NAME}" + echo "Please ensure the environment exists and conda is correctly set up." + exit 1 +fi +echo "Conda environment ${CONDA_ENV_NAME} activated successfully." +echo "Python executable: $(which python)" +echo "PYTHONPATH: $PYTHONPATH" + +# --- PyTorch distributed setup --- +# Determine Master Address and Port from Slurm +export PytorchMASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export PytorchMASTER_PORT=29500 # Default port + +echo "--- Slurm Job Information ---" +echo "SLURM_JOB_ID: ${SLURM_JOB_ID}" +echo "SLURM_JOB_NODELIST: ${SLURM_JOB_NODELIST}" +echo "SLURM_NNODES: ${SLURM_NNODES}" +echo "SLURM_NTASKS_PER_NODE: ${SLURM_NTASKS_PER_NODE}" +echo "SLURM_SUBMIT_DIR: ${SLURM_SUBMIT_DIR}" +echo "PytorchMASTER_ADDR: ${PytorchMASTER_ADDR}" +echo "PytorchMASTER_PORT: ${PytorchMASTER_PORT}" +echo "--- End Slurm Job Information ---" + + +AUTO_RESUME="" +if [ -d "/opt/sagemaker_cluster" ]; then + echo "Detected Hyperpod cluster.. enabling --auto-resume=1" + AUTO_RESUME="--auto-resume=1" +fi + +TORCHRUN_CMD="torchrun" + +# TORCHRUN_ARGS: +# torchrun will use the PytorchMASTER_ADDR and PytorchMASTER_PORT for rendezvous. +# nnodes and node_rank are typically auto-detected by torchrun from Slurm environment variables. +declare -a TORCHRUN_ARGS=( + "--nnodes=${SLURM_NNODES}" + "--nproc_per_node=1" + "--rdzv_backend=c10d" + "--rdzv_endpoint=${PytorchMASTER_ADDR}:${PytorchMASTER_PORT}" + "--log_dir=${LOG_DIR}/torchrun_logs/job_${SLURM_JOB_ID}_node_${SLURM_NODEID}" # Per-node torchrun logs +) + +# Training script module and its arguments +declare -a TRAIN_SCRIPT_ARGS=( + "-m" + "apps.Castor.eval" +) +declare -a TRAINING_ARGS=( + "config=apps/Castor/configs/eval.yaml" +) + +echo "--- srun command execution ---" +echo "Starting evaluation with ${SLURM_NNODES} nodes." +echo "Host where sbatch script is running: $(hostname)" +echo "User: $(whoami)" +echo "Current working directory: $(pwd)" + +# The srun command structure requested by user. +# The -l flag labels srun output lines with the task number. +# srun will launch this command once per node (due to --ntasks-per-node=1). + +echo "TORCHRUN_CMD: ${TORCHRUN_CMD}" +echo "TORCHRUN_ARGS: ${TORCHRUN_ARGS[*]}" +echo "TRAIN_SCRIPT_ARGS: ${TRAIN_SCRIPT_ARGS[*]}" +echo "TRAINING_ARGS: ${TRAINING_ARGS[*]}" + +# Ensure all necessary variables are exported for srun tasks +export PATH FI_PROVIDER FI_EFA_USE_DEVICE_RDMA FI_EFA_USE_HUGE_PAGE FI_EFA_SET_CUDA_SYNC_MEMOPS NCCL_PROTO NCCL_SOCKET_IFNAME LD_PRELOAD + +srun ${AUTO_RESUME} \ + "${TORCHRUN_CMD}" \ + "${TORCHRUN_ARGS[@]}" \ + "${TRAIN_SCRIPT_ARGS[@]}" \ + "${TRAINING_ARGS[@]}" + +EXIT_CODE=$? +echo "srun command finished with exit code ${EXIT_CODE}." + +if [ ${EXIT_CODE} -ne 0 ]; then + echo "Evaluation job failed. Please check logs in slurm-${SLURM_JOB_NAME}-${SLURM_JOB_ID}.out/err and any application specific logs." +fi + +exit ${EXIT_CODE} diff --git a/apps/Castor/generate.py b/apps/Castor/generate.py index 0ed6dd7..2fec9a1 100644 --- a/apps/Castor/generate.py +++ b/apps/Castor/generate.py @@ -7,8 +7,10 @@ import numpy as np import torch from apps.Castor.model import Castor, ModelArgs +from apps.Castor.modules.text_encoder import TextEncoderArgs, create_text_encoder from apps.Castor.modules.vae import (BaseLatentVideoVAE, VideoVAEArgs, create_vae) +from apps.text_encoder.text_encoder import TextEncoder from lingua.args import dataclass_from_dict from lingua.checkpoint import (CONSOLIDATE_FOLDER, CONSOLIDATE_NAME, consolidate_checkpoints) @@ -35,6 +37,7 @@ class GeneratorArgs: inference_steps: int = 25 vae_scale_factor: float = 8.0 tvae: VideoVAEArgs = field(default_factory=VideoVAEArgs) + text_encoder: TextEncoderArgs = field(default_factory=TextEncoderArgs) class LatentGenerator(nn.Module): @@ -43,9 +46,11 @@ def __init__( cfg: GeneratorArgs, model: nn.Module, tvae: BaseLatentVideoVAE, + text_encoder: TextEncoder, ): super().__init__() self.model = model + self.vae_scale_factor = cfg.vae_scale_factor self.resolution = int(cfg.resolution // self.vae_scale_factor) self.cond_resolution = int(cfg.cond_resolution // self.vae_scale_factor) @@ -58,6 +63,7 @@ def __init__( self.scheduler = model.scheduler.scheduler self.num_inference_steps = cfg.inference_steps self.tvae = tvae + self.text_encoder = text_encoder def prepare_latent(self, context, device): bsz = len(context["caption"]) @@ -92,7 +98,7 @@ def forward(self, context: Dict[str, Any]) -> torch.Tensor: mu=mu, ) latent = self.prepare_latent(context, device=cur_device) - pos_conditional_signal, pos_conditional_mask = self.model.text_encoder(context) + pos_conditional_signal, pos_conditional_mask = self.text_encoder(context) negative_conditional_signal = ( self.model.diffusion_transformer.negative_token.repeat( pos_conditional_signal.size(0), pos_conditional_signal.size(1), 1 @@ -228,28 +234,29 @@ def main(): cfg.ckpt_dir, model_cls=Castor, model_args_cls=ModelArgs ) tvae = create_vae(gen_cfg.tvae) - generator = LatentGenerator(gen_cfg, pollux, tvae).cuda() + text_encoder = create_text_encoder(gen_cfg.text_encoder) + generator = LatentGenerator(gen_cfg, pollux, tvae, text_encoder).cuda() print("Model loaded successfully") context = { "caption": [ # Short, simple descriptions "A red rose in full bloom against a black background.", - "A happy young man sitting on a piece of cloud, reading a book.", - "A sleeping cat curled up on a windowsill.", - "Fresh snow falling in a forest.", - "Hot air balloons floating in a clear blue sky.", + "A happy young man sitting on a kitchen table chair, reading a book.", + "A sleeping cat curled up on a windowsill, looking at the sky.", + "Children playing in a forest.", + "Hot air balloons floating in a clear blue sky, with a river in the background.", # Medium length, more detailed "A cozy coffee shop interior with vintage furniture, warm lighting, and the aroma of freshly ground beans wafting through the air.", "An ancient temple hidden in a misty mountain valley, its weathered stone walls covered in flowering vines.", "A bustling night market in Tokyo, neon signs reflecting off wet streets as people hurry past food stalls.", "A sea turtle glides gracefully through crystal-clear turquoise water above a school of small fish, with sunlight reflecting off the surface.", - "A petri dish with a bamboo forest growing within it that has tiny red pandas running around.", + "A bamboo forest growing within it that has tiny red pandas running around.", # Technical/scientific - "A lone astronaut floats in space, gazing at a swirling black hole surrounded by vibrant landscapes, rivers and clouds below.", - "Microscopic view of CRISPR gene editing in action, with precisely rendered molecular structures.", - "A topographical map of an alien planet's surface, complete with elevation data and geological formations.", + "A lone astronaut floats in space, rivers and clouds below.", + "Image of a ribosome, a molecular machine that synthesizes proteins.", + "A topographical map of earth's surface, complete with elevation data and geological formations.", # Artistic/abstract "An impressionist painting of music made visible, with colorful swirls representing different instruments in an orchestra.", diff --git a/apps/Castor/model.py b/apps/Castor/model.py index 2db2462..de519ab 100644 --- a/apps/Castor/model.py +++ b/apps/Castor/model.py @@ -23,6 +23,8 @@ class ModelArgs: diffusion_model: TransformerArgs = field(default_factory=TransformerArgs) with_vae: bool = False vae_args: VideoVAEArgs = field(default_factory=VideoVAEArgs) + text_encoder_dim: int = 512 + vision_encoder_dim: int = 2048 vision_encoder_alignment: bool = False vision_encoder_alignment_factor: float = 0.5 vision_encoder_args: VisionEncoderArgs = field(default_factory=VisionEncoderArgs) @@ -45,19 +47,21 @@ def __init__(self, input_dim: int, hidden_dim: int, encoder_dim: int): super(AlignmentProjection, self).__init__() self.proj = nn.Sequential( - nn.Sequential( - nn.Linear(input_dim, hidden_dim), - nn.SiLU(), - nn.Linear(hidden_dim, hidden_dim), - nn.SiLU(), - nn.Linear(hidden_dim, encoder_dim), - ) + nn.Linear(input_dim, hidden_dim), + nn.SiLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.SiLU(), + nn.Linear(hidden_dim, encoder_dim), ) def forward(self, x): x = self.proj(x) return x + def init_weights(self): + for w in self.proj.parameters(): + nn.init.xavier_uniform_(w) + class Castor(nn.Module): VERSION: str = "v1.0" @@ -67,23 +71,15 @@ def __init__(self, args: ModelArgs): super().__init__() self.args = args - # VAE - if args.with_vae: - self.compressor = create_vae(args.vae_args) - # Vision encoder if args.vision_encoder_alignment: - self.vision_encoder = create_vision_encoder(args.vision_encoder_args) self.vision_encoder_proj = AlignmentProjection( - args.diffusion_model.dim, args.vision_encoder_args.projection_hidden_dim, self.vision_encoder.dim) - - # Text encoder - self.text_encoder = create_text_encoder(args.text_encoder) + args.diffusion_model.dim, args.vision_encoder_args.projection_hidden_dim, args.vision_encoder_dim) - if args.diffusion_model.condition_dim != self.text_encoder.dim(): - logger.warning(f"Condition dim {args.diffusion_model.condition_dim} does not match text encoder dim {self.text_encoder.dim()}") - logger.warning(f"Using {self.text_encoder.dim()} as condition dim") - args.diffusion_model.condition_dim = self.text_encoder.dim() + if args.diffusion_model.condition_dim != args.text_encoder_dim: + logger.warning(f"Condition dim {args.diffusion_model.condition_dim} does not match text encoder dim {args.text_encoder_dim}") + logger.warning(f"Using {args.text_encoder_dim} as condition dim") + args.diffusion_model.condition_dim = args.text_encoder_dim # Diffusion transformer self.diffusion_transformer = DiffusionTransformer(args.diffusion_model) @@ -92,16 +88,7 @@ def __init__(self, args: ModelArgs): self.scheduler = RectifiedFlow(args.scheduler) self.text_cfg_ratio = args.text_cfg_ratio - def forward(self, batch: dict[str:any], flops_meter= None) -> dict[str:any]: - if hasattr(self, "compressor"): - batch["latent_code"] = self.compressor.extract_latents(batch, flops_meter) - - if hasattr(self, "vision_encoder"): - batch["vision_encoder_target"] = self.vision_encoder.extract_image_representations(batch, flops_meter) - - if "text_embedding" not in batch: - batch["text_embedding"], batch["attention_mask"] = self.text_encoder(batch, flops_meter) - + def forward(self, batch: dict[str:any], flops_meter) -> dict[str:any]: conditional_signal, conditional_mask = batch["text_embedding"], batch["attention_mask"] if random.random() <= self.text_cfg_ratio: @@ -127,7 +114,7 @@ def forward(self, batch: dict[str:any], flops_meter= None) -> dict[str:any]: target_loss = self.mse_loss(output.output, batch["target"]) align_loss = None - if hasattr(self, "vision_encoder"): + if self.args.vision_encoder_alignment: vision_encoder_pred = self.vision_encoder_proj(output.align_hidden_state) align_loss = self.consine_loss_with_features( vision_encoder_pred, output.cond_l, output.img_size, batch["vision_encoder_target"]) @@ -187,6 +174,8 @@ def init_weights(self, args: ModelArgs): pre_trained_state_dict = pre_trained_state_dict["model"] self.load_state_dict(pre_trained_state_dict) else: + if args.vision_encoder_alignment: + self.vision_encoder_proj.init_weights() self.diffusion_transformer.init_weights( pre_trained_path=args.diffusion_model.pre_trained_path ) diff --git a/apps/Castor/modules/component.py b/apps/Castor/modules/component.py index 97bffa2..9d324ec 100644 --- a/apps/Castor/modules/component.py +++ b/apps/Castor/modules/component.py @@ -559,16 +559,12 @@ def __init__( n_kv_heads * self.head_dim, bias=False, ) - nn.init.xavier_uniform_(self.wq.weight) - nn.init.xavier_uniform_(self.wk.weight) - nn.init.xavier_uniform_(self.wv.weight) - + self.wo = nn.Linear( n_heads * self.head_dim, dim, bias=False, ) - nn.init.xavier_uniform_(self.wo.weight) if qk_norm: self.q_norm = RMSNorm(self.head_dim, liger_rms_norm=liger_rms_norm) @@ -577,7 +573,7 @@ def __init__( self.q_norm = self.k_norm = nn.Identity() def reset_parameters(self, *args, **kwargs): - nn.init.xavier_uniform_(self.wq.weight) + nn.init.constant_(self.wq.weight, 0.0) # mup init nn.init.xavier_uniform_(self.wk.weight) nn.init.xavier_uniform_(self.wv.weight) nn.init.xavier_uniform_(self.wo.weight) @@ -854,8 +850,7 @@ def reset_parameters(self, init_std=None, factor=1.0): b=3 * out_init_std, ) elif self.use_fp8_ffn: - # TODO: Initialize fp8 parameters - pass + self.ffn.reset_parameters() else: for w in [self.w1, self.w3]: nn.init.trunc_normal_( diff --git a/apps/Castor/modules/text_encoder.py b/apps/Castor/modules/text_encoder.py index e3c408c..acb5b83 100644 --- a/apps/Castor/modules/text_encoder.py +++ b/apps/Castor/modules/text_encoder.py @@ -38,6 +38,7 @@ def __init__(self, args: TextEncoderArgs): self.text_seqlen = args.text_seqlen # TODO: use this to get the dimension of the text encoder for transformer + @property def dim(self) -> int: raise NotImplementedError @@ -67,6 +68,7 @@ def __init__(self, args: TextEncoderArgs): ), ) + @property def dim(self) -> int: return self.clip_model.config.hidden_size @@ -121,6 +123,7 @@ def init_processor(self, model_path: str): model_path, ) + @property def dim(self) -> int: return self.model.config.hidden_size @@ -181,6 +184,7 @@ def __init__(self, args): args.model_path, subfolder="text_encoder", torch_dtype=self.dtype ).cuda() + @property def dim(self) -> int: return self.model.config.hidden_size @@ -222,6 +226,7 @@ def __init__(self, args): args.model_path, torch_dtype=self.dtype ).cuda() + @property def dim(self) -> int: return self.model.config.hidden_size diff --git a/apps/Castor/modules/transformer.py b/apps/Castor/modules/transformer.py index 4a1ca92..a327df2 100644 --- a/apps/Castor/modules/transformer.py +++ b/apps/Castor/modules/transformer.py @@ -32,6 +32,8 @@ from apps.Castor.utils.pad import pad_flat_tokens_to_multiple, unpad_flat_tokens import copy +from mup import MuReadout + logger = logging.getLogger() @@ -74,6 +76,7 @@ def __init__(self, args: TransformerArgs): assert args.n_heads % self.n_kv_heads == 0 assert args.dim % args.n_heads == 0 + self.dim = args.dim # self.attention = Attention( # dim=args.dim, @@ -192,6 +195,8 @@ def init_weights(self, init_std=None, factor=1.0): self.sandwich_norm.reset_parameters() if not self.shared_adaLN: self.adaLN_modulation.reset_parameters() + else: + nn.init.trunc_normal_(self.modulation, std=1) / self.dim**0.5 @dataclass @@ -312,10 +317,12 @@ def __init__(self, args: TransformerArgs): in_dim=self.patch_size * self.patch_size * args.in_channels, out_dim=args.dim, ) - self.img_output = nn.Linear( + # mup init + self.img_output = MuReadout( args.dim, self.patch_size * self.patch_size * args.out_channels, bias=False, + readout_zero_init=True ) self.rope_embeddings_image = RotaryEmbedding2D( theta=args.rope_theta, @@ -655,13 +662,11 @@ def reset_parameters(self, init_std=None): self.cond_norm.reset_parameters() self.tmb_embed.reset_parameters() self.img_embed.reset_parameters() - nn.init.trunc_normal_( - self.img_output.weight, - mean=0.0, - std=init_std, - a=-3 * init_std, - b=3 * init_std, - ) + + nn.init.constant_(self.img_output.weight, 0.) # initialize output weights by zero. + if self.img_output.bias is not None: + nn.init.constant_(self.img_output.bias, 0.) + nn.init.trunc_normal_( self.cond_proj.weight, mean=0.0, diff --git a/apps/Castor/modules/vae.py b/apps/Castor/modules/vae.py index 2b2d817..dbd4649 100644 --- a/apps/Castor/modules/vae.py +++ b/apps/Castor/modules/vae.py @@ -5,7 +5,6 @@ import torch from diffusers import AutoencoderKL, AutoencoderKLHunyuanVideo from torch import nn -from cosmos_tokenizer.image_lib import ImageTokenizer logger = logging.getLogger() logger.setLevel(logging.INFO) @@ -225,6 +224,7 @@ def forward(self, x=torch.Tensor): class COSMOSContinuousVAE(BaseLatentVideoVAE): def __init__(self, args: VideoVAEArgs): super().__init__(args) + from cosmos_tokenizer.image_lib import ImageTokenizer """ Initialize the encoder and decoder for Continuous VAE. Checks model type and returns the initialized VAE instance. diff --git a/apps/Castor/train.py b/apps/Castor/train.py index 8e889d3..87974f7 100644 --- a/apps/Castor/train.py +++ b/apps/Castor/train.py @@ -10,6 +10,10 @@ from omegaconf import OmegaConf from tqdm import tqdm +from apps.Castor.modules.text_encoder import create_text_encoder +from apps.Castor.modules.vae import create_vae +from apps.Castor.modules.vision_encoder import create_vision_encoder + cli_args = OmegaConf.from_cli() file_cfg = OmegaConf.load(cli_args.config) os.environ["CUDA_VISIBLE_DEVICES"] = file_cfg.distributed.gpus @@ -42,7 +46,7 @@ from lingua.distributed import (DistributedArgs, EnvironmentArgs, check_model_value_range, dist_mean_dict, get_device_mesh, get_is_master, get_local_rank, - get_world_size, init_signal_handler, + get_world_size, get_is_slurm_job, init_signal_handler, parallelize_model, requeue_slurm_job, setup_env, setup_torch_distributed) from lingua.logger import init_logger @@ -57,6 +61,8 @@ from transformer_engine.common.recipe import Format, DelayedScaling import transformer_engine.pytorch as te +from mup import get_shapes, set_base_shapes, make_base_shapes + logger = logging.getLogger() @@ -95,6 +101,9 @@ class TrainArgs: async_eval_gpus: Optional[int] = None eval: Optional[Any] = None + save_base_shapes: str = "" + load_base_shapes: str = "" + @dataclass class TrainState(Stateful): @@ -255,21 +264,34 @@ def train(args: TrainArgs): dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank() dp_degree *= world_mesh["dp_shard"].size() - logger.info(f"Running on dp rank : {dp_rank}") - logger.info(f"Running on dp size : {dp_degree}") + logger.info(f"Running on dp rank : {dp_rank}/{dp_degree}") torch.manual_seed(args.seed) + + # build encoders + start_time = time.perf_counter() + compressor = create_vae(args.model.vae_args) + vision_encoder = create_vision_encoder(args.model.vision_encoder_args) + text_encoder = create_text_encoder(args.model.text_encoder) + end_time = time.perf_counter() + logger.info(f"Encoders are built in {end_time - start_time:.2f} seconds") + # set encoder dims + args.model.text_encoder_dim = text_encoder.dim + args.model.vision_encoder_dim = vision_encoder.dim + + # build model + start_time = time.perf_counter() logger.info("Building model") + with torch.device("meta"): + model = Castor(args.model) - model = Castor(args.model) - logger.info("Model is built !") - logger.info(model) + logger.info("Model is built on meta device!") model_param_count = get_num_params(model) - flops_meter = FlopsMeter(args.model, model) + flops_meter = FlopsMeter(args.model, model, text_encoder, vision_encoder, compressor) torch.manual_seed(args.seed) - model.init_weights(args.model) + logger.info("Parallelizing model") model = parallelize_model( model, world_mesh, @@ -279,7 +301,18 @@ def train(args: TrainArgs): tp_parallelize=tp_parallelize, no_recompute_ops=get_no_recompute_ops(), ) - model = model.to(device="cuda") + logger.info("Model is parallelized!") + model.to_empty(device="cuda") + logger.info("Model is moved to cuda!") + model.init_weights(args.model) + logger.info("Model is initialized!") + end_time = time.perf_counter() + logger.info(f"Model is initialized in {end_time - start_time:.2f} seconds") + + if args.load_base_shapes != '': + # set base shapes only after the initialization, otherwise they don't persist + logger.info(f'Loading base shapes from {args.load_base_shapes}') + set_base_shapes(model, args.load_base_shapes) check_model_value_range(model, range=10.0, std=1.0) @@ -346,7 +379,7 @@ def train(args: TrainArgs): max_data_load_time = 0.0 gc.collect() - pb = tqdm(total=args.steps, initial=train_state.step, desc="Training Steps") + pb = tqdm(total=args.steps, initial=train_state.step, desc="Training Steps", disable=get_is_slurm_job()) while train_state.step < args.steps: # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1 @@ -412,11 +445,16 @@ def train(args: TrainArgs): end_timer = torch.cuda.Event(enable_timing=True) start_timer.record() - with te.fp8_autocast(enabled=args.model.diffusion_model.use_fp8_ffn, fp8_recipe=fp8_recipe, fp8_group=all_gpus): - outputs = model(batch, flops_meter) - # We scale loss with grad_acc_steps so the gradient is the same - # regardless of grad_acc_steps - loss = outputs.loss / args.grad_acc_steps + batch["latent_code"] = compressor.extract_latents(batch, flops_meter) + batch["vision_encoder_target"] = vision_encoder.extract_image_representations(batch, flops_meter) + batch["text_embedding"], batch["attention_mask"] = text_encoder(batch, flops_meter) + + # with te.fp8_autocast(enabled=args.model.diffusion_model.use_fp8_ffn, fp8_recipe=fp8_recipe, fp8_group=all_gpus): + outputs = model(batch, flops_meter) + # We scale loss with grad_acc_steps so the gradient is the same + # regardless of grad_acc_steps + loss = outputs.loss / args.grad_acc_steps + # backward on scaled loss to create scaled gradients loss.backward() # For logging we undo that scaling @@ -607,7 +645,26 @@ def main(): cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) cfg = OmegaConf.to_object(cfg) - train(cfg) + + if cfg.save_base_shapes != '': + logger.info(f'Saving base shapes at {cfg.save_base_shapes}') + cfg_base: ModelArgs = deepcopy(cfg.model) + cfg_base.diffusion_model.dim = 288 + cfg_base.diffusion_model.n_heads = 8 + base_model = Castor(cfg_base) + base_shapes = get_shapes(base_model) + + cfg_delta: ModelArgs = deepcopy(cfg.model) + cfg_delta.diffusion_model.dim = 352 + cfg_delta.diffusion_model.n_heads = 16 + delta_model = Castor(cfg_delta) + delta_shapes = get_shapes(delta_model) + + make_base_shapes(base_shapes, delta_shapes, savefile=cfg.save_base_shapes) + logger.info('done and exit') + else: + train(cfg) + if __name__ == "__main__": diff --git a/apps/Castor/train.sbatch b/apps/Castor/train.sbatch new file mode 100644 index 0000000..5d918db --- /dev/null +++ b/apps/Castor/train.sbatch @@ -0,0 +1,130 @@ +#!/bin/bash + +#SBATCH --job-name=castor_training +#SBATCH --output=slurm_logs/slurm-%x-%j.out # %x for job name, %j for job ID +#SBATCH --error=slurm_logs/slurm-%x-%j.err +# User will specify --nodes and --partition on sbatch command line +# e.g., sbatch --nodes=2 --partition=my_partition train.sbatch + +#SBATCH --ntasks-per-node=1 # We run one torchrun launcher per node +#SBATCH --gpus-per-node=8 # Each torchrun launcher will manage 8 processes, one per GPU + +# --- Project and Log Directories --- +PROJECT_DIR=${PROJECT_DIR:-"/fsx/ubuntu/workspace/repo/Pollux"} +LOG_DIR=${LOG_DIR:-"/fsx/checkpoints/ablations/logs"} + +echo "Changing directory to Project Directory: ${PROJECT_DIR}" +cd "${PROJECT_DIR}" || { echo "Failed to cd into ${PROJECT_DIR}"; exit 1; } +echo "Current working directory: $(pwd)" + +# --- User defined ENVs for AWS Hyperpod --- +export NCCL_PROTO="Simple" +export FI_PROVIDER="efa" +export FI_EFA_USE_DEVICE_RDMA="1" +export FI_EFA_USE_HUGE_PAGE="0" +export FI_EFA_SET_CUDA_SYNC_MEMOPS="0" +export NCCL_SOCKET_IFNAME="^docker,lo,veth,eth" +export LD_PRELOAD="/usr/local/cuda-12.8/lib/libnccl.so" + +# --- Conda environment --- +CONDA_ENV_NAME="pollux" + +CONDA_PATH=${CONDA_PATH:-"/fsx/ubuntu/miniconda3"} +export PATH="$CONDA_PATH/bin:$PATH" +source $CONDA_PATH/etc/profile.d/conda.sh + +echo "Attempting to activate conda environment: ${CONDA_ENV_NAME}" +_CONDA_ROOT=$(conda info --base 2>/dev/null) + +if [ -z "${_CONDA_ROOT}" ]; then + echo "Error: conda command not found or conda base not determined." + echo "Please ensure conda is installed and initialized." + exit 1 +fi + +conda activate "${CONDA_ENV_NAME}" +if [ $? -ne 0 ]; then + echo "Error: Failed to activate conda environment: ${CONDA_ENV_NAME}" + echo "Please ensure the environment exists and conda is correctly set up." + exit 1 +fi +echo "Conda environment ${CONDA_ENV_NAME} activated successfully." +echo "Python executable: $(which python)" +echo "PYTHONPATH: $PYTHONPATH" + +# --- PyTorch distributed setup --- +# Determine Master Address and Port from Slurm +export PytorchMASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export PytorchMASTER_PORT=29500 # Default port + +echo "--- Slurm Job Information ---" +echo "SLURM_JOB_ID: ${SLURM_JOB_ID}" +echo "SLURM_JOB_NODELIST: ${SLURM_JOB_NODELIST}" +echo "SLURM_NNODES: ${SLURM_NNODES}" +echo "SLURM_NTASKS_PER_NODE: ${SLURM_NTASKS_PER_NODE}" +echo "SLURM_SUBMIT_DIR: ${SLURM_SUBMIT_DIR}" +echo "PytorchMASTER_ADDR: ${PytorchMASTER_ADDR}" +echo "PytorchMASTER_PORT: ${PytorchMASTER_PORT}" +echo "--- End Slurm Job Information ---" + + +AUTO_RESUME="" +if [ -d "/opt/sagemaker_cluster" ]; then + echo "Detected Hyperpod cluster.. enabling --auto-resume=1" + AUTO_RESUME="--auto-resume=1" +fi + +TORCHRUN_CMD="torchrun" + +# TORCHRUN_ARGS: +# torchrun will use the PytorchMASTER_ADDR and PytorchMASTER_PORT for rendezvous. +# nnodes and node_rank are typically auto-detected by torchrun from Slurm environment variables. +declare -a TORCHRUN_ARGS=( + "--nnodes=${SLURM_NNODES}" + "--nproc_per_node=8" + "--rdzv_backend=c10d" + "--rdzv_endpoint=${PytorchMASTER_ADDR}:${PytorchMASTER_PORT}" + "--log_dir=${LOG_DIR}/torchrun_logs/job_${SLURM_JOB_ID}_node_${SLURM_NODEID}" # Per-node torchrun logs +) + +# Training script module and its arguments +declare -a TRAIN_SCRIPT_ARGS=( + "-m" + "apps.Castor.train" +) +declare -a TRAINING_ARGS=( + "config=apps/Castor/configs/aws_256_Castor_flux_qwen_fixed_siglip2.yaml" +) + +echo "--- srun command execution ---" +echo "Starting training with ${SLURM_NNODES} nodes." +echo "Host where sbatch script is running: $(hostname)" +echo "User: $(whoami)" +echo "Current working directory: $(pwd)" + +# The srun command structure requested by user. +# The -l flag labels srun output lines with the task number. +# srun will launch this command once per node (due to --ntasks-per-node=1). + +echo "TORCHRUN_CMD: ${TORCHRUN_CMD}" +echo "TORCHRUN_ARGS: ${TORCHRUN_ARGS[*]}" +echo "TRAIN_SCRIPT_ARGS: ${TRAIN_SCRIPT_ARGS[*]}" +echo "TRAINING_ARGS: ${TRAINING_ARGS[*]}" + +# Ensure all necessary variables are exported for srun tasks +export PATH FI_PROVIDER FI_EFA_USE_DEVICE_RDMA FI_EFA_USE_HUGE_PAGE FI_EFA_SET_CUDA_SYNC_MEMOPS NCCL_PROTO NCCL_SOCKET_IFNAME LD_PRELOAD + +srun ${AUTO_RESUME} \ + "${TORCHRUN_CMD}" \ + "${TORCHRUN_ARGS[@]}" \ + "${TRAIN_SCRIPT_ARGS[@]}" \ + "${TRAINING_ARGS[@]}" + +EXIT_CODE=$? +echo "srun command finished with exit code ${EXIT_CODE}." + +if [ ${EXIT_CODE} -ne 0 ]; then + echo "Training job failed. Please check logs in slurm-${SLURM_JOB_NAME}-${SLURM_JOB_ID}.out/err and any application specific logs." +fi + +exit ${EXIT_CODE} diff --git a/apps/Castor/utils/flop_meter.py b/apps/Castor/utils/flop_meter.py index 95dbc81..13538fc 100644 --- a/apps/Castor/utils/flop_meter.py +++ b/apps/Castor/utils/flop_meter.py @@ -1,5 +1,7 @@ from apps.Castor.model import Castor +from apps.Castor.modules.text_encoder import BaseTextEncoder from apps.Castor.modules.transformer import TransformerArgs +from apps.Castor.modules.vision_encoder import BaseVisionEncoder from lingua.metrics import get_num_params @@ -100,7 +102,7 @@ def estimate_mfu(self, fwdbwd_per_iter, dt): class FlopsMeter: def __init__( - self, args: TransformerArgs, model: Castor, device="h100", dtype="bf16" + self, args: TransformerArgs, model: Castor, text_encoder, vision_encoder, compressor, device="h100", dtype="bf16" ): self.diffusion_params = get_num_params(model.diffusion_transformer) self.diffusion_num_layers = args.diffusion_model.n_layers @@ -110,20 +112,20 @@ def __init__( ) self.diffusion_dim = args.diffusion_model.dim - self.cond_params = get_num_params(model.text_encoder.model) - self.cond_dim = model.text_encoder.dim() - self.cond_num_layers = len(model.text_encoder.model.layers) - self.cond_num_heads = model.text_encoder.model.config.num_attention_heads + self.cond_params = get_num_params(text_encoder.model) + self.cond_dim = text_encoder.dim + self.cond_num_layers = len(text_encoder.model.layers) + self.cond_num_heads = text_encoder.model.config.num_attention_heads self.cond_headdim = self.cond_dim // self.cond_num_heads - self.vision_params = get_num_params(model.vision_encoder.model) - self.vision_num_layers = model.vision_encoder.model.config.num_hidden_layers - self.vision_num_heads = model.vision_encoder.model.config.num_attention_heads - self.vision_dim = model.vision_encoder.model.config.hidden_size - self.vision_patch_size = model.vision_encoder.model.config.patch_size + self.vision_params = get_num_params(vision_encoder.model) + self.vision_num_layers = vision_encoder.model.config.num_hidden_layers + self.vision_num_heads = vision_encoder.model.config.num_attention_heads + self.vision_dim = vision_encoder.model.config.hidden_size + self.vision_patch_size = vision_encoder.model.config.patch_size self.vision_headdim = self.vision_dim // self.vision_num_heads - self.vae_params = get_num_params(model.compressor.vae) + self.vae_params = get_num_params(compressor.vae) self.diffusion_flops = 0 self.vision_encoder_flops = 0 diff --git a/apps/main/utils/mongodb_data_load.py b/apps/main/utils/mongodb_data_load.py index ef48a4a..d9b0df1 100644 --- a/apps/main/utils/mongodb_data_load.py +++ b/apps/main/utils/mongodb_data_load.py @@ -33,6 +33,8 @@ from urllib3.util.retry import Retry from urllib.parse import urlparse +from lingua.distributed import get_is_slurm_job + logging.getLogger("pymongo").setLevel(logging.WARNING) boto3.set_stream_logger("boto3", level=logging.WARNING) boto3.set_stream_logger("botocore", level=logging.WARNING) @@ -132,23 +134,23 @@ def set_local_partition(self): file_path = os.path.join(self.root_dir, f"{self.collection_name}.json") with open(file_path, "r") as file: - for item in tqdm(ijson.items(file, "item"), desc=f"Loading data to shard {self.shard_idx}"): + for item in tqdm(ijson.items(file, "item"), desc=f"Loading data to shard {self.shard_idx}", disable=get_is_slurm_job()): partition_key = int(item[self.partition_key]) if partition_key % self.num_shards == self.shard_idx: data.append(item) # # Note: used for debugging - # if len(data) > 10000: - # break + if len(data) > 2600000: + break elif self.root_dir_type == "parquet": logging.info(f"Loading data from local parquet files: {self.root_dir}") - parquet_files = glob.glob(os.path.join(self.root_dir, "*.parquet")) - for file in tqdm(parquet_files, desc=f"Loading data to shard {self.shard_idx}"): + parquet_files = glob.glob(os.path.join(self.root_dir, self.collection_name, "**/*.parquet")) + for file in tqdm(parquet_files, desc=f"Loading data to shard {self.shard_idx}", disable=get_is_slurm_job()): df = pd.read_parquet(file) df = df[df[self.partition_key] % self.num_shards == self.shard_idx] data.extend(df.to_dict(orient="records")) # # Note: used for debugging - # if len(data) > 10000: + # if len(data) > 2500000: # break else: raise ValueError(f"Invalid Root Directory Type. Set root_dir_type to 'json' or 'parquet'") @@ -164,7 +166,7 @@ def set_local_partition(self): ) def __len__(self) -> int: - return self.data.index.max() + return len(self.data) def __getitem__(self, idx: int) -> dict[str, Any]: return self.data[idx] @@ -238,13 +240,13 @@ def http_client(self, imageUrl: str) -> tuple[Image.Image, bool]: except (requests.RequestException, IOError) as e: status_code = getattr(locals().get('head_response'), 'status_code', 'N/A') # ensure head_response is accessed safely if isinstance(e, requests.Timeout): - logging.debug(f"Timeout downloading image: {imageUrl}") + logging.error(f"Timeout downloading image: {imageUrl}") elif isinstance(e, requests.HTTPError): - logging.debug(f"HTTP error ({status_code}) for: {imageUrl}") + logging.error(f"HTTP error ({status_code}) for: {imageUrl}") elif isinstance(e, requests.ConnectionError): - logging.debug(f"Connection error for: {imageUrl}") + logging.error(f"Connection error for: {imageUrl}") else: - logging.debug(f"Error processing image {imageUrl}: {str(e)}") + logging.error(f"Error processing image {imageUrl}: {str(e)}") # Fall back to placeholder image return self.place_holder_image, False # Signal failure @@ -274,7 +276,7 @@ def s3_client(self, imageUrl: str) -> tuple[Image.Image, bool]: except Exception as e: # Catching a broad exception. # For production, you might want to catch more specific Boto3 exceptions like ClientError - logging.debug(f"Error downloading image from S3 {imageUrl}: {str(e)}") + logging.error(f"Error downloading image from S3 {imageUrl}: {str(e)}") return self.place_holder_image, False # Signal failure @@ -286,7 +288,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: # for pd data sample = self.data.iloc[idx] # Use iloc for row access in DataFrame return_sample = {} - return_sample["_id"] = str(sample["_id"]) + return_sample["_id"] = str(sample["_id"] if "_id" in sample else sample["id"]) caption = sample["caption"] if isinstance(caption, tuple): caption = caption[0] diff --git a/lingua/optim.py b/lingua/optim.py index 1c7ba9b..ab92a85 100644 --- a/lingua/optim.py +++ b/lingua/optim.py @@ -30,6 +30,8 @@ class OptimArgs: exp_factor: float = 0.5 + mup: bool = False + def lr_linear(step: int, warmup: int, n_steps: int, min_ratio: float) -> float: if step < warmup: @@ -145,14 +147,25 @@ def build_lr_fn(args: OptimArgs, n_steps: int): def build_optimizer(model: nn.Module, args: OptimArgs, n_steps: int): logger.info("Starting build of optimizer...") - optimizer = AdamW( - (param for param in model.parameters() if param.requires_grad), - lr=args.lr, - betas=(args.beta1, args.beta2), - weight_decay=args.weight_decay, - eps=args.epsilon, - fused=True, # Faster optim.step but can throw errors - ) + if args.mup: + from mup.optim import MuAdamW + optimizer = MuAdamW( + (param for param in model.parameters() if param.requires_grad), + lr=args.lr, + betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay, + eps=args.epsilon, + fused=True, # Faster optim.step but can throw errors + ) + else: + optimizer = AdamW( + (param for param in model.parameters() if param.requires_grad), + lr=args.lr, + betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay, + eps=args.epsilon, + fused=True, # Faster optim.step but can throw errors + ) # scheduler lr_fn = build_lr_fn(args, n_steps)