From 11076c136bb33df9c2ceff9fe4c833eebbc9e901 Mon Sep 17 00:00:00 2001 From: sippycoder Date: Sat, 31 May 2025 19:43:36 +0000 Subject: [PATCH 1/8] slurm script --- apps/Castor/train.sbatch | 136 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 apps/Castor/train.sbatch diff --git a/apps/Castor/train.sbatch b/apps/Castor/train.sbatch new file mode 100644 index 0000000..3da3ca8 --- /dev/null +++ b/apps/Castor/train.sbatch @@ -0,0 +1,136 @@ +#!/bin/bash + +#SBATCH --job-name=castor_training +#SBATCH --output=slurm-%x-%j.out # %x for job name, %j for job ID +#SBATCH --error=slurm-%x-%j.err +# User will specify --nodes and --partition on sbatch command line +# e.g., sbatch --nodes=2 --partition=my_partition train.sbatch + +#SBATCH --ntasks-per-node=1 # We run one torchrun launcher per node +#SBATCH --gpus-per-node=8 # Each torchrun launcher will manage 8 processes, one per GPU +#SBATCH --cpus-per-task=24 # Allocate 24 CPUs for the torchrun task (and its 8 worker processes) + +# --- Project and Log Directories --- +PROJECT_DIR=${PROJECT_DIR:-"/fsx/ubuntu/workspace/repo/Pollux"} +LOG_DIR=${LOG_DIR:-"/fsx/checkpoints/ablations/logs"} + +echo "Changing directory to Project Directory: ${PROJECT_DIR}" +cd "${PROJECT_DIR}" || { echo "Failed to cd into ${PROJECT_DIR}"; exit 1; } +echo "Current working directory: $(pwd)" + +# --- User defined ENVs for AWS Hyperpod --- +export NCCL_PROTO="Simple" +export FI_PROVIDER="efa" +export FI_EFA_USE_DEVICE_RDMA="1" +export FI_EFA_USE_HUGE_PAGE="0" +export FI_EFA_SET_CUDA_SYNC_MEMOPS="0" +export NCCL_SOCKET_IFNAME="^docker,lo,veth,eth" +export LD_PRELOAD="/usr/local/cuda-12.8/lib/libnccl.so" + +# --- Conda environment --- +CONDA_ENV_NAME="pollux" + +echo "Attempting to activate conda environment: ${CONDA_ENV_NAME}" +_CONDA_ROOT=$(conda info --base 2>/dev/null) + +if [ -z "${_CONDA_ROOT}" ]; then + echo "Error: conda command not found or conda base not determined." + echo "Please ensure conda is installed and initialized." + exit 1 +fi + +# Source conda.sh if not already sourced or conda command not available +if ! command -v conda &> /dev/null || [ -z "$CONDA_SHLVL" ] || [ "$CONDA_SHLVL" -lt 1 ]; then + echo "Sourcing conda from ${_CONDA_ROOT}/etc/profile.d/conda.sh" + # shellcheck source=/dev/null + source "${_CONDA_ROOT}/etc/profile.d/conda.sh" + if [ $? -ne 0 ]; then + echo "Error: Failed to source conda.sh from ${_CONDA_ROOT}/etc/profile.d/conda.sh" + exit 1 + fi +else + echo "Conda appears to be already initialized." +fi + +conda activate "${CONDA_ENV_NAME}" +if [ $? -ne 0 ]; then + echo "Error: Failed to activate conda environment: ${CONDA_ENV_NAME}" + echo "Please ensure the environment exists and conda is correctly set up." + exit 1 +fi +echo "Conda environment ${CONDA_ENV_NAME} activated successfully." 
+echo "Python executable: $(which python)" +echo "PYTHONPATH: $PYTHONPATH" + +# --- PyTorch distributed setup --- +# Determine Master Address and Port from Slurm +export PytorchMASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export PytorchMASTER_PORT=29500 # Default port + +echo "--- Slurm Job Information ---" +echo "SLURM_JOB_ID: ${SLURM_JOB_ID}" +echo "SLURM_JOB_NODELIST: ${SLURM_JOB_NODELIST}" +echo "SLURM_NNODES: ${SLURM_NNODES}" +echo "SLURM_NTASKS_PER_NODE: ${SLURM_NTASKS_PER_NODE}" +echo "SLURM_SUBMIT_DIR: ${SLURM_SUBMIT_DIR}" +echo "PytorchMASTER_ADDR: ${PytorchMASTER_ADDR}" +echo "PytorchMASTER_PORT: ${PytorchMASTER_PORT}" +echo "--- End Slurm Job Information ---" + + +AUTO_RESUME="" +if [ -d "/opt/sagemaker_cluster" ]; then + echo "Detected Hyperpod cluster.. enabling --auto-resume=1" + AUTO_RESUME="--auto-resume=1" +fi + +TORCHRUN_CMD="torchrun" + +# TORCHRUN_ARGS: +# torchrun will use the PytorchMASTER_ADDR and PytorchMASTER_PORT for rendezvous. +# nnodes and node_rank are typically auto-detected by torchrun from Slurm environment variables. +declare -a TORCHRUN_ARGS=( + "--nproc_per_node=8" + "--rdzv_backend=c10d" + "--rdzv_endpoint=${PytorchMASTER_ADDR}:${PytorchMASTER_PORT}" + "--log_dir=${LOG_DIR}/torchrun_logs/job_${SLURM_JOB_ID}_node_${SLURM_NODEID}" # Per-node torchrun logs +) + +# Training script module and its arguments +TRAIN_SCRIPT_MODULE="-m apps.Castor.train" +declare -a TRAINING_ARGS=( + "config=apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml" +) + +echo "--- srun command execution ---" +echo "Starting training with ${SLURM_NNODES} nodes." +echo "Host where sbatch script is running: $(hostname)" +echo "User: $(whoami)" +echo "Current working directory: $(pwd)" + +# The srun command structure requested by user. +# The -l flag labels srun output lines with the task number. +# srun will launch this command once per node (due to --ntasks-per-node=1). + +echo "TORCHRUN_CMD: ${TORCHRUN_CMD}" +echo "TORCHRUN_ARGS: ${TORCHRUN_ARGS[*]}" +echo "TRAIN_SCRIPT_MODULE: ${TRAIN_SCRIPT_MODULE}" +echo "TRAINING_ARGS: ${TRAINING_ARGS[*]}" + +# Ensure all necessary variables are exported for srun tasks +export PATH FI_PROVIDER FI_EFA_USE_DEVICE_RDMA FI_EFA_USE_HUGE_PAGE FI_EFA_SET_CUDA_SYNC_MEMOPS NCCL_PROTO NCCL_SOCKET_IFNAME LD_PRELOAD + +srun ${AUTO_RESUME} \ + "${TORCHRUN_CMD}" \ + "${TORCHRUN_ARGS[@]}" \ + "${TRAIN_SCRIPT_MODULE}" \ + "${TRAINING_ARGS[@]}" + +EXIT_CODE=$? +echo "srun command finished with exit code ${EXIT_CODE}." + +if [ ${EXIT_CODE} -ne 0 ]; then + echo "Training job failed. Please check logs in slurm-${SLURM_JOB_NAME}-${SLURM_JOB_ID}.out/err and any application specific logs." 
+fi + +exit ${EXIT_CODE} From bee01ed0d4f571f6ee3f0b7c631da73ba9809b28 Mon Sep 17 00:00:00 2001 From: nmnWithNucleus Date: Sun, 1 Jun 2025 00:18:34 +0000 Subject: [PATCH 2/8] refactor --- apps/Castor/model.py | 33 ++++++--------------- apps/Castor/modules/text_encoder.py | 5 ++++ apps/Castor/modules/vae.py | 2 +- apps/Castor/train.py | 43 ++++++++++++++++++++++------ apps/Castor/utils/flop_meter.py | 24 +++++++++------- apps/main/utils/mongodb_data_load.py | 8 +++--- 6 files changed, 67 insertions(+), 48 deletions(-) diff --git a/apps/Castor/model.py b/apps/Castor/model.py index 2db2462..0dd2b73 100644 --- a/apps/Castor/model.py +++ b/apps/Castor/model.py @@ -23,6 +23,8 @@ class ModelArgs: diffusion_model: TransformerArgs = field(default_factory=TransformerArgs) with_vae: bool = False vae_args: VideoVAEArgs = field(default_factory=VideoVAEArgs) + text_encoder_dim: int = 512 + vision_encoder_dim: int = 2048 vision_encoder_alignment: bool = False vision_encoder_alignment_factor: float = 0.5 vision_encoder_args: VisionEncoderArgs = field(default_factory=VisionEncoderArgs) @@ -67,23 +69,15 @@ def __init__(self, args: ModelArgs): super().__init__() self.args = args - # VAE - if args.with_vae: - self.compressor = create_vae(args.vae_args) - # Vision encoder if args.vision_encoder_alignment: - self.vision_encoder = create_vision_encoder(args.vision_encoder_args) self.vision_encoder_proj = AlignmentProjection( - args.diffusion_model.dim, args.vision_encoder_args.projection_hidden_dim, self.vision_encoder.dim) - - # Text encoder - self.text_encoder = create_text_encoder(args.text_encoder) + args.diffusion_model.dim, args.vision_encoder_args.projection_hidden_dim, args.vision_encoder_dim) - if args.diffusion_model.condition_dim != self.text_encoder.dim(): - logger.warning(f"Condition dim {args.diffusion_model.condition_dim} does not match text encoder dim {self.text_encoder.dim()}") - logger.warning(f"Using {self.text_encoder.dim()} as condition dim") - args.diffusion_model.condition_dim = self.text_encoder.dim() + if args.diffusion_model.condition_dim != args.text_encoder_dim: + logger.warning(f"Condition dim {args.diffusion_model.condition_dim} does not match text encoder dim {args.text_encoder_dim}") + logger.warning(f"Using {args.text_encoder_dim} as condition dim") + args.diffusion_model.condition_dim = args.text_encoder_dim # Diffusion transformer self.diffusion_transformer = DiffusionTransformer(args.diffusion_model) @@ -92,16 +86,7 @@ def __init__(self, args: ModelArgs): self.scheduler = RectifiedFlow(args.scheduler) self.text_cfg_ratio = args.text_cfg_ratio - def forward(self, batch: dict[str:any], flops_meter= None) -> dict[str:any]: - if hasattr(self, "compressor"): - batch["latent_code"] = self.compressor.extract_latents(batch, flops_meter) - - if hasattr(self, "vision_encoder"): - batch["vision_encoder_target"] = self.vision_encoder.extract_image_representations(batch, flops_meter) - - if "text_embedding" not in batch: - batch["text_embedding"], batch["attention_mask"] = self.text_encoder(batch, flops_meter) - + def forward(self, batch: dict[str:any], flops_meter) -> dict[str:any]: conditional_signal, conditional_mask = batch["text_embedding"], batch["attention_mask"] if random.random() <= self.text_cfg_ratio: @@ -127,7 +112,7 @@ def forward(self, batch: dict[str:any], flops_meter= None) -> dict[str:any]: target_loss = self.mse_loss(output.output, batch["target"]) align_loss = None - if hasattr(self, "vision_encoder"): + if self.args.vision_encoder_alignment: vision_encoder_pred = 
self.vision_encoder_proj(output.align_hidden_state) align_loss = self.consine_loss_with_features( vision_encoder_pred, output.cond_l, output.img_size, batch["vision_encoder_target"]) diff --git a/apps/Castor/modules/text_encoder.py b/apps/Castor/modules/text_encoder.py index e3c408c..acb5b83 100644 --- a/apps/Castor/modules/text_encoder.py +++ b/apps/Castor/modules/text_encoder.py @@ -38,6 +38,7 @@ def __init__(self, args: TextEncoderArgs): self.text_seqlen = args.text_seqlen # TODO: use this to get the dimension of the text encoder for transformer + @property def dim(self) -> int: raise NotImplementedError @@ -67,6 +68,7 @@ def __init__(self, args: TextEncoderArgs): ), ) + @property def dim(self) -> int: return self.clip_model.config.hidden_size @@ -121,6 +123,7 @@ def init_processor(self, model_path: str): model_path, ) + @property def dim(self) -> int: return self.model.config.hidden_size @@ -181,6 +184,7 @@ def __init__(self, args): args.model_path, subfolder="text_encoder", torch_dtype=self.dtype ).cuda() + @property def dim(self) -> int: return self.model.config.hidden_size @@ -222,6 +226,7 @@ def __init__(self, args): args.model_path, torch_dtype=self.dtype ).cuda() + @property def dim(self) -> int: return self.model.config.hidden_size diff --git a/apps/Castor/modules/vae.py b/apps/Castor/modules/vae.py index 2b2d817..dbd4649 100644 --- a/apps/Castor/modules/vae.py +++ b/apps/Castor/modules/vae.py @@ -5,7 +5,6 @@ import torch from diffusers import AutoencoderKL, AutoencoderKLHunyuanVideo from torch import nn -from cosmos_tokenizer.image_lib import ImageTokenizer logger = logging.getLogger() logger.setLevel(logging.INFO) @@ -225,6 +224,7 @@ def forward(self, x=torch.Tensor): class COSMOSContinuousVAE(BaseLatentVideoVAE): def __init__(self, args: VideoVAEArgs): super().__init__(args) + from cosmos_tokenizer.image_lib import ImageTokenizer """ Initialize the encoder and decoder for Continuous VAE. Checks model type and returns the initialized VAE instance. 
diff --git a/apps/Castor/train.py b/apps/Castor/train.py index 779ddad..de8f37d 100644 --- a/apps/Castor/train.py +++ b/apps/Castor/train.py @@ -10,6 +10,10 @@ from omegaconf import OmegaConf from tqdm import tqdm +from apps.Castor.modules.text_encoder import create_text_encoder +from apps.Castor.modules.vae import create_vae +from apps.Castor.modules.vision_encoder import create_vision_encoder + cli_args = OmegaConf.from_cli() file_cfg = OmegaConf.load(cli_args.config) os.environ["CUDA_VISIBLE_DEVICES"] = file_cfg.distributed.gpus @@ -252,20 +256,33 @@ def train(args: TrainArgs): dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank() dp_degree *= world_mesh["dp_shard"].size() - logger.info(f"Running on dp rank : {dp_rank}") - logger.info(f"Running on dp size : {dp_degree}") + logger.info(f"Running on dp rank : {dp_rank}/{dp_degree}") torch.manual_seed(args.seed) - logger.info("Building model") - model = Castor(args.model) - logger.info("Model is built !") + # build encoders + start_time = time.perf_counter() + compressor = create_vae(args.model.vae_args) + vision_encoder = create_vision_encoder(args.model.vision_encoder_args) + text_encoder = create_text_encoder(args.model.text_encoder) + end_time = time.perf_counter() + logger.info(f"Encoders are built in {end_time - start_time:.2f} seconds") + # set encoder dims + args.model.text_encoder_dim = text_encoder.dim + args.model.vision_encoder_dim = vision_encoder.dim + + # build model + start_time = time.perf_counter() + logger.info("Building model") + with torch.device("meta"): + model = Castor(args.model) + logger.info("Model is built on meta device!") model_param_count = get_num_params(model) - flops_meter = FlopsMeter(args.model, model) + flops_meter = FlopsMeter(args.model, model, text_encoder, vision_encoder, compressor) torch.manual_seed(args.seed) - model.init_weights(args.model) + logger.info("Parallelizing model") model = parallelize_model( model, world_mesh, @@ -275,7 +292,13 @@ def train(args: TrainArgs): tp_parallelize=tp_parallelize, no_recompute_ops=get_no_recompute_ops(), ) - model = model.to(device="cuda") + logger.info("Model is parallelized!") + model.to_empty(device="cuda") + logger.info("Model is moved to cuda!") + model.init_weights(args.model) + logger.info("Model is initialized!") + end_time = time.perf_counter() + logger.info(f"Model is initialized in {end_time - start_time:.2f} seconds") check_model_value_range(model, range=10.0, std=1.0) @@ -399,6 +422,10 @@ def train(args: TrainArgs): end_timer = torch.cuda.Event(enable_timing=True) start_timer.record() + batch["latent_code"] = compressor.extract_latents(batch, flops_meter) + batch["vision_encoder_target"] = vision_encoder.extract_image_representations(batch, flops_meter) + batch["text_embedding"], batch["attention_mask"] = text_encoder(batch, flops_meter) + outputs = model(batch, flops_meter) # We scale loss with grad_acc_steps so the gradient is the same # regardless of grad_acc_steps diff --git a/apps/Castor/utils/flop_meter.py b/apps/Castor/utils/flop_meter.py index 95dbc81..13538fc 100644 --- a/apps/Castor/utils/flop_meter.py +++ b/apps/Castor/utils/flop_meter.py @@ -1,5 +1,7 @@ from apps.Castor.model import Castor +from apps.Castor.modules.text_encoder import BaseTextEncoder from apps.Castor.modules.transformer import TransformerArgs +from apps.Castor.modules.vision_encoder import BaseVisionEncoder from lingua.metrics import get_num_params @@ -100,7 +102,7 @@ def estimate_mfu(self, fwdbwd_per_iter, dt): class FlopsMeter: def __init__( - 
self, args: TransformerArgs, model: Castor, device="h100", dtype="bf16" + self, args: TransformerArgs, model: Castor, text_encoder, vision_encoder, compressor, device="h100", dtype="bf16" ): self.diffusion_params = get_num_params(model.diffusion_transformer) self.diffusion_num_layers = args.diffusion_model.n_layers @@ -110,20 +112,20 @@ def __init__( ) self.diffusion_dim = args.diffusion_model.dim - self.cond_params = get_num_params(model.text_encoder.model) - self.cond_dim = model.text_encoder.dim() - self.cond_num_layers = len(model.text_encoder.model.layers) - self.cond_num_heads = model.text_encoder.model.config.num_attention_heads + self.cond_params = get_num_params(text_encoder.model) + self.cond_dim = text_encoder.dim + self.cond_num_layers = len(text_encoder.model.layers) + self.cond_num_heads = text_encoder.model.config.num_attention_heads self.cond_headdim = self.cond_dim // self.cond_num_heads - self.vision_params = get_num_params(model.vision_encoder.model) - self.vision_num_layers = model.vision_encoder.model.config.num_hidden_layers - self.vision_num_heads = model.vision_encoder.model.config.num_attention_heads - self.vision_dim = model.vision_encoder.model.config.hidden_size - self.vision_patch_size = model.vision_encoder.model.config.patch_size + self.vision_params = get_num_params(vision_encoder.model) + self.vision_num_layers = vision_encoder.model.config.num_hidden_layers + self.vision_num_heads = vision_encoder.model.config.num_attention_heads + self.vision_dim = vision_encoder.model.config.hidden_size + self.vision_patch_size = vision_encoder.model.config.patch_size self.vision_headdim = self.vision_dim // self.vision_num_heads - self.vae_params = get_num_params(model.compressor.vae) + self.vae_params = get_num_params(compressor.vae) self.diffusion_flops = 0 self.vision_encoder_flops = 0 diff --git a/apps/main/utils/mongodb_data_load.py b/apps/main/utils/mongodb_data_load.py index ef48a4a..2cd99b2 100644 --- a/apps/main/utils/mongodb_data_load.py +++ b/apps/main/utils/mongodb_data_load.py @@ -141,14 +141,14 @@ def set_local_partition(self): # break elif self.root_dir_type == "parquet": logging.info(f"Loading data from local parquet files: {self.root_dir}") - parquet_files = glob.glob(os.path.join(self.root_dir, "*.parquet")) + parquet_files = glob.glob(os.path.join(self.root_dir, self.collection_name, "**/*.parquet")) for file in tqdm(parquet_files, desc=f"Loading data to shard {self.shard_idx}"): df = pd.read_parquet(file) df = df[df[self.partition_key] % self.num_shards == self.shard_idx] data.extend(df.to_dict(orient="records")) # # Note: used for debugging - # if len(data) > 10000: + # if len(data) > 2500000: # break else: raise ValueError(f"Invalid Root Directory Type. 
Set root_dir_type to 'json' or 'parquet'") @@ -164,7 +164,7 @@ def set_local_partition(self): ) def __len__(self) -> int: - return self.data.index.max() + return len(self.data) def __getitem__(self, idx: int) -> dict[str, Any]: return self.data[idx] @@ -286,7 +286,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: # for pd data sample = self.data.iloc[idx] # Use iloc for row access in DataFrame return_sample = {} - return_sample["_id"] = str(sample["_id"]) + return_sample["_id"] = str(sample["_id"] if "_id" in sample else sample["id"]) caption = sample["caption"] if isinstance(caption, tuple): caption = caption[0] From 4b3224de4c9a442cf6ed840fab980895976c066e Mon Sep 17 00:00:00 2001 From: nmnWithNucleus Date: Sun, 1 Jun 2025 00:20:24 +0000 Subject: [PATCH 3/8] SLURM script --- apps/Castor/train.sbatch | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/apps/Castor/train.sbatch b/apps/Castor/train.sbatch index 3da3ca8..9c66884 100644 --- a/apps/Castor/train.sbatch +++ b/apps/Castor/train.sbatch @@ -8,7 +8,6 @@ #SBATCH --ntasks-per-node=1 # We run one torchrun launcher per node #SBATCH --gpus-per-node=8 # Each torchrun launcher will manage 8 processes, one per GPU -#SBATCH --cpus-per-task=24 # Allocate 24 CPUs for the torchrun task (and its 8 worker processes) # --- Project and Log Directories --- PROJECT_DIR=${PROJECT_DIR:-"/fsx/ubuntu/workspace/repo/Pollux"} @@ -30,6 +29,10 @@ export LD_PRELOAD="/usr/local/cuda-12.8/lib/libnccl.so" # --- Conda environment --- CONDA_ENV_NAME="pollux" +CONDA_PATH=${CONDA_PATH:-"/fsx/ubuntu/miniconda3"} +export PATH="$CONDA_PATH/bin:$PATH" +source $CONDA_PATH/etc/profile.d/conda.sh + echo "Attempting to activate conda environment: ${CONDA_ENV_NAME}" _CONDA_ROOT=$(conda info --base 2>/dev/null) @@ -39,19 +42,6 @@ if [ -z "${_CONDA_ROOT}" ]; then exit 1 fi -# Source conda.sh if not already sourced or conda command not available -if ! command -v conda &> /dev/null || [ -z "$CONDA_SHLVL" ] || [ "$CONDA_SHLVL" -lt 1 ]; then - echo "Sourcing conda from ${_CONDA_ROOT}/etc/profile.d/conda.sh" - # shellcheck source=/dev/null - source "${_CONDA_ROOT}/etc/profile.d/conda.sh" - if [ $? -ne 0 ]; then - echo "Error: Failed to source conda.sh from ${_CONDA_ROOT}/etc/profile.d/conda.sh" - exit 1 - fi -else - echo "Conda appears to be already initialized." -fi - conda activate "${CONDA_ENV_NAME}" if [ $? -ne 0 ]; then echo "Error: Failed to activate conda environment: ${CONDA_ENV_NAME}" @@ -97,9 +87,12 @@ declare -a TORCHRUN_ARGS=( ) # Training script module and its arguments -TRAIN_SCRIPT_MODULE="-m apps.Castor.train" +declare -a TRAIN_SCRIPT_ARGS=( + "-m" + "apps.Castor.train" +) declare -a TRAINING_ARGS=( - "config=apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml" + "config=apps/Castor/configs/aws_256_Castor_flux_qwen_fixed_siglip2.yaml" ) echo "--- srun command execution ---" @@ -114,7 +107,7 @@ echo "Current working directory: $(pwd)" echo "TORCHRUN_CMD: ${TORCHRUN_CMD}" echo "TORCHRUN_ARGS: ${TORCHRUN_ARGS[*]}" -echo "TRAIN_SCRIPT_MODULE: ${TRAIN_SCRIPT_MODULE}" +echo "TRAIN_SCRIPT_ARGS: ${TRAIN_SCRIPT_ARGS[*]}" echo "TRAINING_ARGS: ${TRAINING_ARGS[*]}" # Ensure all necessary variables are exported for srun tasks @@ -123,7 +116,7 @@ export PATH FI_PROVIDER FI_EFA_USE_DEVICE_RDMA FI_EFA_USE_HUGE_PAGE FI_EFA_SET_C srun ${AUTO_RESUME} \ "${TORCHRUN_CMD}" \ "${TORCHRUN_ARGS[@]}" \ - "${TRAIN_SCRIPT_MODULE}" \ + "${TRAIN_SCRIPT_ARGS[@]}" \ "${TRAINING_ARGS[@]}" EXIT_CODE=$? 
From 6841e4818c57b4fc6e1d2e2174097515b03c326c Mon Sep 17 00:00:00 2001 From: nmnWithNucleus Date: Sun, 1 Jun 2025 01:09:56 +0000 Subject: [PATCH 4/8] disable tqdm for slurm --- apps/Castor/train.py | 4 ++-- apps/main/utils/mongodb_data_load.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/apps/Castor/train.py b/apps/Castor/train.py index de8f37d..52bf595 100644 --- a/apps/Castor/train.py +++ b/apps/Castor/train.py @@ -46,7 +46,7 @@ from lingua.distributed import (DistributedArgs, EnvironmentArgs, check_model_value_range, dist_mean_dict, get_device_mesh, get_is_master, get_local_rank, - get_world_size, init_signal_handler, + get_world_size, get_is_slurm_job, init_signal_handler, parallelize_model, requeue_slurm_job, setup_env, setup_torch_distributed) from lingua.logger import init_logger @@ -356,7 +356,7 @@ def train(args: TrainArgs): max_data_load_time = 0.0 gc.collect() - pb = tqdm(total=args.steps, initial=train_state.step, desc="Training Steps") + pb = tqdm(total=args.steps, initial=train_state.step, desc="Training Steps", disable=get_is_slurm_job()) while train_state.step < args.steps: # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1 diff --git a/apps/main/utils/mongodb_data_load.py b/apps/main/utils/mongodb_data_load.py index 2cd99b2..b5d5b83 100644 --- a/apps/main/utils/mongodb_data_load.py +++ b/apps/main/utils/mongodb_data_load.py @@ -33,6 +33,8 @@ from urllib3.util.retry import Retry from urllib.parse import urlparse +from lingua.distributed import get_is_slurm_job + logging.getLogger("pymongo").setLevel(logging.WARNING) boto3.set_stream_logger("boto3", level=logging.WARNING) boto3.set_stream_logger("botocore", level=logging.WARNING) @@ -132,7 +134,7 @@ def set_local_partition(self): file_path = os.path.join(self.root_dir, f"{self.collection_name}.json") with open(file_path, "r") as file: - for item in tqdm(ijson.items(file, "item"), desc=f"Loading data to shard {self.shard_idx}"): + for item in tqdm(ijson.items(file, "item"), desc=f"Loading data to shard {self.shard_idx}", disable=get_is_slurm_job()): partition_key = int(item[self.partition_key]) if partition_key % self.num_shards == self.shard_idx: data.append(item) @@ -142,7 +144,7 @@ def set_local_partition(self): elif self.root_dir_type == "parquet": logging.info(f"Loading data from local parquet files: {self.root_dir}") parquet_files = glob.glob(os.path.join(self.root_dir, self.collection_name, "**/*.parquet")) - for file in tqdm(parquet_files, desc=f"Loading data to shard {self.shard_idx}"): + for file in tqdm(parquet_files, desc=f"Loading data to shard {self.shard_idx}", disable=get_is_slurm_job()): df = pd.read_parquet(file) df = df[df[self.partition_key] % self.num_shards == self.shard_idx] data.extend(df.to_dict(orient="records")) From ebba68051bf5e32ad54a01a17e9f0c80235b4207 Mon Sep 17 00:00:00 2001 From: nmnWithNucleus Date: Mon, 2 Jun 2025 08:06:03 +0000 Subject: [PATCH 5/8] Adding number of nodes to torchrun args --- apps/Castor/train.sbatch | 1 + 1 file changed, 1 insertion(+) diff --git a/apps/Castor/train.sbatch b/apps/Castor/train.sbatch index 9c66884..58f5198 100644 --- a/apps/Castor/train.sbatch +++ b/apps/Castor/train.sbatch @@ -80,6 +80,7 @@ TORCHRUN_CMD="torchrun" # torchrun will use the PytorchMASTER_ADDR and PytorchMASTER_PORT for rendezvous. # nnodes and node_rank are typically auto-detected by torchrun from Slurm environment variables. 
declare -a TORCHRUN_ARGS=( + "--nnodes=${SLURM_NNODES}" "--nproc_per_node=8" "--rdzv_backend=c10d" "--rdzv_endpoint=${PytorchMASTER_ADDR}:${PytorchMASTER_PORT}" From 6e5b33719860309a9d46e77657276022ad854310 Mon Sep 17 00:00:00 2001 From: nmnWithNucleus Date: Sun, 8 Jun 2025 20:43:55 +0000 Subject: [PATCH 6/8] support decoupled encoders in eval scripts --- apps/Castor/eval.py | 4 +++- apps/Castor/generate.py | 27 +++++++++++++++++---------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/apps/Castor/eval.py b/apps/Castor/eval.py index c2a0b18..ea4d11f 100644 --- a/apps/Castor/eval.py +++ b/apps/Castor/eval.py @@ -17,6 +17,7 @@ from apps.Castor.generate import (GeneratorArgs, LatentGenerator, load_consolidated_model) from apps.Castor.model import Castor, ModelArgs +from apps.Castor.modules.text_encoder import create_text_encoder from apps.Castor.modules.vae import (BaseLatentVideoVAE, VideoVAEArgs, create_vae) from apps.main.data import AutoDataLoader, DataArgs @@ -90,7 +91,8 @@ def launch_eval(cfg: EvalArgs): logger.info("Model loaded") model.eval() tvae = create_vae(cfg.generator.tvae) - generator = LatentGenerator(cfg.generator, model, tvae).cuda() + text_encoder = create_text_encoder(cfg.generator.text_encoder) + generator = LatentGenerator(cfg.generator, model, tvae, text_encoder).cuda() active_data = [d for d in cfg.data if d.stage == cfg.stage and d.use] data_loader_factory = AutoDataLoader( shard_id=global_rank, diff --git a/apps/Castor/generate.py b/apps/Castor/generate.py index 0ed6dd7..2fec9a1 100644 --- a/apps/Castor/generate.py +++ b/apps/Castor/generate.py @@ -7,8 +7,10 @@ import numpy as np import torch from apps.Castor.model import Castor, ModelArgs +from apps.Castor.modules.text_encoder import TextEncoderArgs, create_text_encoder from apps.Castor.modules.vae import (BaseLatentVideoVAE, VideoVAEArgs, create_vae) +from apps.text_encoder.text_encoder import TextEncoder from lingua.args import dataclass_from_dict from lingua.checkpoint import (CONSOLIDATE_FOLDER, CONSOLIDATE_NAME, consolidate_checkpoints) @@ -35,6 +37,7 @@ class GeneratorArgs: inference_steps: int = 25 vae_scale_factor: float = 8.0 tvae: VideoVAEArgs = field(default_factory=VideoVAEArgs) + text_encoder: TextEncoderArgs = field(default_factory=TextEncoderArgs) class LatentGenerator(nn.Module): @@ -43,9 +46,11 @@ def __init__( cfg: GeneratorArgs, model: nn.Module, tvae: BaseLatentVideoVAE, + text_encoder: TextEncoder, ): super().__init__() self.model = model + self.vae_scale_factor = cfg.vae_scale_factor self.resolution = int(cfg.resolution // self.vae_scale_factor) self.cond_resolution = int(cfg.cond_resolution // self.vae_scale_factor) @@ -58,6 +63,7 @@ def __init__( self.scheduler = model.scheduler.scheduler self.num_inference_steps = cfg.inference_steps self.tvae = tvae + self.text_encoder = text_encoder def prepare_latent(self, context, device): bsz = len(context["caption"]) @@ -92,7 +98,7 @@ def forward(self, context: Dict[str, Any]) -> torch.Tensor: mu=mu, ) latent = self.prepare_latent(context, device=cur_device) - pos_conditional_signal, pos_conditional_mask = self.model.text_encoder(context) + pos_conditional_signal, pos_conditional_mask = self.text_encoder(context) negative_conditional_signal = ( self.model.diffusion_transformer.negative_token.repeat( pos_conditional_signal.size(0), pos_conditional_signal.size(1), 1 @@ -228,28 +234,29 @@ def main(): cfg.ckpt_dir, model_cls=Castor, model_args_cls=ModelArgs ) tvae = create_vae(gen_cfg.tvae) - generator = LatentGenerator(gen_cfg, pollux, 
tvae).cuda() + text_encoder = create_text_encoder(gen_cfg.text_encoder) + generator = LatentGenerator(gen_cfg, pollux, tvae, text_encoder).cuda() print("Model loaded successfully") context = { "caption": [ # Short, simple descriptions "A red rose in full bloom against a black background.", - "A happy young man sitting on a piece of cloud, reading a book.", - "A sleeping cat curled up on a windowsill.", - "Fresh snow falling in a forest.", - "Hot air balloons floating in a clear blue sky.", + "A happy young man sitting on a kitchen table chair, reading a book.", + "A sleeping cat curled up on a windowsill, looking at the sky.", + "Children playing in a forest.", + "Hot air balloons floating in a clear blue sky, with a river in the background.", # Medium length, more detailed "A cozy coffee shop interior with vintage furniture, warm lighting, and the aroma of freshly ground beans wafting through the air.", "An ancient temple hidden in a misty mountain valley, its weathered stone walls covered in flowering vines.", "A bustling night market in Tokyo, neon signs reflecting off wet streets as people hurry past food stalls.", "A sea turtle glides gracefully through crystal-clear turquoise water above a school of small fish, with sunlight reflecting off the surface.", - "A petri dish with a bamboo forest growing within it that has tiny red pandas running around.", + "A bamboo forest growing within it that has tiny red pandas running around.", # Technical/scientific - "A lone astronaut floats in space, gazing at a swirling black hole surrounded by vibrant landscapes, rivers and clouds below.", - "Microscopic view of CRISPR gene editing in action, with precisely rendered molecular structures.", - "A topographical map of an alien planet's surface, complete with elevation data and geological formations.", + "A lone astronaut floats in space, rivers and clouds below.", + "Image of a ribosome, a molecular machine that synthesizes proteins.", + "A topographical map of earth's surface, complete with elevation data and geological formations.", # Artistic/abstract "An impressionist painting of music made visible, with colorful swirls representing different instruments in an orchestra.", From 2f157f66103fc44074584b69412950bb1dcfa6b6 Mon Sep 17 00:00:00 2001 From: nmnWithNucleus Date: Sun, 8 Jun 2025 20:46:43 +0000 Subject: [PATCH 7/8] slurm scripts v0 --- apps/Castor/eval.sbatch | 130 +++++++++++++++++++++++++++++++++++++++ apps/Castor/train.sbatch | 4 +- 2 files changed, 132 insertions(+), 2 deletions(-) create mode 100644 apps/Castor/eval.sbatch diff --git a/apps/Castor/eval.sbatch b/apps/Castor/eval.sbatch new file mode 100644 index 0000000..1c53363 --- /dev/null +++ b/apps/Castor/eval.sbatch @@ -0,0 +1,130 @@ +#!/bin/bash + +#SBATCH --job-name=castor_eval +#SBATCH --output=slurm_logs/slurm-%x-%j.out # %x for job name, %j for job ID +#SBATCH --error=slurm_logs/slurm-%x-%j.err +# User will specify --nodes and --partition on sbatch command line +# e.g., sbatch --nodes=2 --partition=my_partition eval.sbatch + +#SBATCH --ntasks-per-node=1 # We run one torchrun launcher per node +#SBATCH --gpus-per-node=8 # Each torchrun launcher will manage 8 processes, one per GPU + +# --- Project and Log Directories --- +PROJECT_DIR=${PROJECT_DIR:-"/fsx/ubuntu/workspace/repo/Pollux"} +LOG_DIR=${LOG_DIR:-"/fsx/checkpoints/ablations/logs"} + +echo "Changing directory to Project Directory: ${PROJECT_DIR}" +cd "${PROJECT_DIR}" || { echo "Failed to cd into ${PROJECT_DIR}"; exit 1; } +echo "Current working directory: $(pwd)" + +# 
--- User defined ENVs for AWS Hyperpod --- +export NCCL_PROTO="Simple" +export FI_PROVIDER="efa" +export FI_EFA_USE_DEVICE_RDMA="1" +export FI_EFA_USE_HUGE_PAGE="0" +export FI_EFA_SET_CUDA_SYNC_MEMOPS="0" +export NCCL_SOCKET_IFNAME="^docker,lo,veth,eth" +export LD_PRELOAD="/usr/local/cuda-12.8/lib/libnccl.so" + +# --- Conda environment --- +CONDA_ENV_NAME="pollux" + +CONDA_PATH=${CONDA_PATH:-"/fsx/ubuntu/miniconda3"} +export PATH="$CONDA_PATH/bin:$PATH" +source $CONDA_PATH/etc/profile.d/conda.sh + +echo "Attempting to activate conda environment: ${CONDA_ENV_NAME}" +_CONDA_ROOT=$(conda info --base 2>/dev/null) + +if [ -z "${_CONDA_ROOT}" ]; then + echo "Error: conda command not found or conda base not determined." + echo "Please ensure conda is installed and initialized." + exit 1 +fi + +conda activate "${CONDA_ENV_NAME}" +if [ $? -ne 0 ]; then + echo "Error: Failed to activate conda environment: ${CONDA_ENV_NAME}" + echo "Please ensure the environment exists and conda is correctly set up." + exit 1 +fi +echo "Conda environment ${CONDA_ENV_NAME} activated successfully." +echo "Python executable: $(which python)" +echo "PYTHONPATH: $PYTHONPATH" + +# --- PyTorch distributed setup --- +# Determine Master Address and Port from Slurm +export PytorchMASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export PytorchMASTER_PORT=29500 # Default port + +echo "--- Slurm Job Information ---" +echo "SLURM_JOB_ID: ${SLURM_JOB_ID}" +echo "SLURM_JOB_NODELIST: ${SLURM_JOB_NODELIST}" +echo "SLURM_NNODES: ${SLURM_NNODES}" +echo "SLURM_NTASKS_PER_NODE: ${SLURM_NTASKS_PER_NODE}" +echo "SLURM_SUBMIT_DIR: ${SLURM_SUBMIT_DIR}" +echo "PytorchMASTER_ADDR: ${PytorchMASTER_ADDR}" +echo "PytorchMASTER_PORT: ${PytorchMASTER_PORT}" +echo "--- End Slurm Job Information ---" + + +AUTO_RESUME="" +if [ -d "/opt/sagemaker_cluster" ]; then + echo "Detected Hyperpod cluster.. enabling --auto-resume=1" + AUTO_RESUME="--auto-resume=1" +fi + +TORCHRUN_CMD="torchrun" + +# TORCHRUN_ARGS: +# torchrun will use the PytorchMASTER_ADDR and PytorchMASTER_PORT for rendezvous. +# nnodes and node_rank are typically auto-detected by torchrun from Slurm environment variables. +declare -a TORCHRUN_ARGS=( + "--nnodes=${SLURM_NNODES}" + "--nproc_per_node=1" + "--rdzv_backend=c10d" + "--rdzv_endpoint=${PytorchMASTER_ADDR}:${PytorchMASTER_PORT}" + "--log_dir=${LOG_DIR}/torchrun_logs/job_${SLURM_JOB_ID}_node_${SLURM_NODEID}" # Per-node torchrun logs +) + +# Training script module and its arguments +declare -a TRAIN_SCRIPT_ARGS=( + "-m" + "apps.Castor.eval" +) +declare -a TRAINING_ARGS=( + "config=apps/Castor/configs/eval.yaml" +) + +echo "--- srun command execution ---" +echo "Starting evaluation with ${SLURM_NNODES} nodes." +echo "Host where sbatch script is running: $(hostname)" +echo "User: $(whoami)" +echo "Current working directory: $(pwd)" + +# The srun command structure requested by user. +# The -l flag labels srun output lines with the task number. +# srun will launch this command once per node (due to --ntasks-per-node=1). 
+ +echo "TORCHRUN_CMD: ${TORCHRUN_CMD}" +echo "TORCHRUN_ARGS: ${TORCHRUN_ARGS[*]}" +echo "TRAIN_SCRIPT_ARGS: ${TRAIN_SCRIPT_ARGS[*]}" +echo "TRAINING_ARGS: ${TRAINING_ARGS[*]}" + +# Ensure all necessary variables are exported for srun tasks +export PATH FI_PROVIDER FI_EFA_USE_DEVICE_RDMA FI_EFA_USE_HUGE_PAGE FI_EFA_SET_CUDA_SYNC_MEMOPS NCCL_PROTO NCCL_SOCKET_IFNAME LD_PRELOAD + +srun ${AUTO_RESUME} \ + "${TORCHRUN_CMD}" \ + "${TORCHRUN_ARGS[@]}" \ + "${TRAIN_SCRIPT_ARGS[@]}" \ + "${TRAINING_ARGS[@]}" + +EXIT_CODE=$? +echo "srun command finished with exit code ${EXIT_CODE}." + +if [ ${EXIT_CODE} -ne 0 ]; then + echo "Evaluation job failed. Please check logs in slurm-${SLURM_JOB_NAME}-${SLURM_JOB_ID}.out/err and any application specific logs." +fi + +exit ${EXIT_CODE} diff --git a/apps/Castor/train.sbatch b/apps/Castor/train.sbatch index 58f5198..5d918db 100644 --- a/apps/Castor/train.sbatch +++ b/apps/Castor/train.sbatch @@ -1,8 +1,8 @@ #!/bin/bash #SBATCH --job-name=castor_training -#SBATCH --output=slurm-%x-%j.out # %x for job name, %j for job ID -#SBATCH --error=slurm-%x-%j.err +#SBATCH --output=slurm_logs/slurm-%x-%j.out # %x for job name, %j for job ID +#SBATCH --error=slurm_logs/slurm-%x-%j.err # User will specify --nodes and --partition on sbatch command line # e.g., sbatch --nodes=2 --partition=my_partition train.sbatch From 3985286b0469cd0f6e9b608cafff7c5c9c75e1b8 Mon Sep 17 00:00:00 2001 From: sippycoder Date: Sun, 15 Jun 2025 03:49:16 +0000 Subject: [PATCH 8/8] fix text encoder --- apps/Castor/modules/text_encoder.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/apps/Castor/modules/text_encoder.py b/apps/Castor/modules/text_encoder.py index acb5b83..0a70c34 100644 --- a/apps/Castor/modules/text_encoder.py +++ b/apps/Castor/modules/text_encoder.py @@ -100,14 +100,15 @@ def __init__(self, args: TextEncoderArgs): def init_model(self, args: TextEncoderArgs, model_path: str): config = AutoConfig.from_pretrained(model_path) - config.num_hidden_layers = int(math.ceil(args.relative_depth * config.num_hidden_layers)) + config.text_config.num_hidden_layers = int(math.ceil(args.relative_depth * config.text_config.num_hidden_layers)) model = Qwen2_5_VLModel.from_pretrained( model_path, config=config, torch_dtype=self.dtype, ).cuda() - # avoid norm layer in the last layer - model.norm = torch.nn.Identity() + if args.relative_depth < 1.0: + # avoid norm layer in the last layer + model.language_model.norm = torch.nn.Identity() apply_liger_kernel_to_qwen2_5_vl(model) model.eval() model.requires_grad_(False)