143 changes: 143 additions & 0 deletions apps/Castor/configs/mup_v0.yaml
@@ -0,0 +1,143 @@
# torchrun --standalone --nnodes 1 --nproc-per-node 8 -m apps.Castor.train config=apps/Castor/configs/mup_v0.yaml

# Set up single experiment
version: v1.0
# Align the train, data, and model settings with the train stage (data refactor just finished)
train_stage: stage-1 # options: preliminary, pretraining, posttraining; aligned with data setting
name: mup_dim_144_bs_128_lr_5e-4 # used for local dump and wandb log
output_dir: /mnt/pollux/checkpoints/mup/lr_sweep
dump_dir: '' # not needed currently
steps: 500000
seed: 777
optim:
lr: 5e-4
warmup: 4000
lr_min_ratio: 1.5e-5
clip: 1.0
weight_decay: 0.01
mup: true

distributed:
gpus: "0,1,2,3"
fsdp_type: full_shard
dp_shard: 4
dp_replicate: 1
compile: false
model_dtype: bf16 # note: `fp8` is only supported on H100
matmul_allow_tf32: false
selective_activation_checkpointing: false
tp_size: 1
compile_cache_size_limit: 64

model:
scheduler:
num_train_timesteps: 1000
base_image_seq_len: 256
base_shift: 0.5
max_image_seq_len: 4096
max_shift: 1.15
shift: 1.0 # consider 3.0 or 1.0
weighting_scheme: 'logit_normal'
logit_mean: 0.0
logit_std: 1.0
mode_scale: 1.29
use_dynamic_shifting: true

diffusion_model:
dim: &hidden_dim 256
ffn_dim_multiplier: 1.5
multiple_of: 256
n_heads: 32
n_kv_heads: 8
n_layers: 24
attention_window: [-1, -1] # [-1, -1] for full attention
align_layer: 8
time_step_dim: *hidden_dim
patch_size: 2
in_channels: 16
out_channels: 16
tmb_size: 256
gen_seqlen: 32
condition_seqlen: 256
norm_eps: 1e-5
condition_dim: 3584
qk_norm: false
liger_rms_norm: true
liger_ffn: true
liger_rotary_emb: false
shared_adaLN: true
unpadded: true
use_fp8_ffn: false
fp8_ffn_skip_layers: [] # listed layers fall back to liger_ffn
with_vae: true
vae_args:
vae_type: "flux"
pretrained_model_name_or_path: '/mnt/pollux/checkpoints/FLUX.1-dev/vae'
enable_tiling: false
enable_slicing: false

vision_encoder_alignment: false
vision_encoder_args:
encoder_name: "siglip2"
weight_path: "/mnt/pollux/checkpoints/siglip2-base-patch16-naflex"
projection_hidden_dim: *hidden_dim

text_cfg_ratio: 0.1
text_encoder:
config_name: "Qwen/Qwen2.5-VL-7B-Instruct"
dtype: "bf16"
text_seqlen: 256
model_path: "/mnt/pollux/checkpoints/Qwen2.5-VL-7B-Instruct"
relative_depth: 0.75

# save_base_shapes: "base_shapes.pkl"
load_base_shapes: "base_shapes.pkl"

data:
- stage: stage-1
id: 1
data_name: bucket-256-2
task: text_to_image
source: mongodb
image_size: 256
condition_image_size: 256
max_ratio: 1.0
partition_key: 'partition_key'
retries: 3
extract_field:
"media_path": "image"
use: true
root_dir: "/mnt/pollux/mongo_db_cache_train"
root_dir_type: "json" # options: "json", "parquet"
# https://worldmodeldata-prod.s3.us-east-2.amazonaws.com
base_url: s3://worldmodeldata-prod
dataloader:
prefetch_factor: 2
batch_size: 32
num_workers: 8
seed: 1024
shuffle: True
pin_memory: True
drop_last: False

profiling:
run: false

checkpoint:
dump:
every: 10000
keep: 0 # 0 keeps all checkpoints
eval:
every: 5000
keep: 0 # 0 keeps all checkpoints

logging:
freq: 100
wandb:
project: ablations
entity: metauto
name: ''

env:
ENABLE_INTRA_NODE_COMM: '0' # '0' for a local machine (otherwise errors occur); '1' for Slurm (needs testing)
NCCL_DEBUG: 'ERROR'
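Note on the new muP keys: `optim.mup`, `load_base_shapes`, and `save_base_shapes` (the latter added in the next config) suggest the usual two-step muP workflow, where one run dumps a base-shapes file and every run in the lr sweep then loads it before building the optimizer. Below is a minimal sketch of that wiring, assuming the `mup` package is used; `build_model` and the toy module are hypothetical placeholders, not the repo's actual API.

```python
import torch.nn as nn
# Assumes the `mup` package (microsoft/mup); the repo's actual integration may differ.
from mup import MuAdam, MuReadout, make_base_shapes, set_base_shapes


def build_model(dim: int) -> nn.Module:
    # Hypothetical toy stand-in for the Castor transformer: only the width varies.
    return nn.Sequential(nn.Linear(16, dim), nn.ReLU(), MuReadout(dim, 16))


def save_shapes(path: str = "base_shapes.pkl") -> None:
    # One-off step (cf. `save_base_shapes`): build the model at a base width and
    # at a scaled ("delta") width, then record how each parameter's shape changes.
    make_base_shapes(build_model(dim=256), build_model(dim=512), savefile=path)


def build_for_sweep(dim: int, lr: float, path: str = "base_shapes.pkl"):
    # Every sweep run (cf. `load_base_shapes`) loads the same base-shapes file so
    # muP can rescale the initialization and per-parameter learning rates by fan-in.
    model = build_model(dim=dim)
    set_base_shapes(model, path)
    return model, MuAdam(model.parameters(), lr=lr, weight_decay=0.01)
```

If this matches the repo's wiring, a learning rate tuned at the small `dim` here should transfer to wider models, which appears to be the point of the `mup/lr_sweep` output_dir.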
@@ -15,6 +15,7 @@ optim:
lr_min_ratio: 1.5e-5
clip: 1.0
weight_decay: 0.01
mup: true

distributed:
gpus: "0,1,2,3,4,5,6,7"
@@ -89,6 +90,8 @@ model:
model_path: "/mnt/pollux/checkpoints/Qwen2.5-VL-7B-Instruct"
relative_depth: 0.75

save_base_shapes: "base_shapes.pkl"

data:
- stage: stage-1
id: 1
4 changes: 3 additions & 1 deletion apps/Castor/eval.py
@@ -17,6 +17,7 @@
from apps.Castor.generate import (GeneratorArgs, LatentGenerator,
load_consolidated_model)
from apps.Castor.model import Castor, ModelArgs
from apps.Castor.modules.text_encoder import create_text_encoder
from apps.Castor.modules.vae import (BaseLatentVideoVAE, VideoVAEArgs,
create_vae)
from apps.main.data import AutoDataLoader, DataArgs
@@ -90,7 +91,8 @@ def launch_eval(cfg: EvalArgs):
logger.info("Model loaded")
model.eval()
tvae = create_vae(cfg.generator.tvae)
generator = LatentGenerator(cfg.generator, model, tvae).cuda()
text_encoder = create_text_encoder(cfg.generator.text_encoder)
generator = LatentGenerator(cfg.generator, model, tvae, text_encoder).cuda()
active_data = [d for d in cfg.data if d.stage == cfg.stage and d.use]
data_loader_factory = AutoDataLoader(
shard_id=global_rank,
130 changes: 130 additions & 0 deletions apps/Castor/eval.sbatch
@@ -0,0 +1,130 @@
#!/bin/bash

#SBATCH --job-name=castor_eval
#SBATCH --output=slurm_logs/slurm-%x-%j.out # %x for job name, %j for job ID
#SBATCH --error=slurm_logs/slurm-%x-%j.err
# User will specify --nodes and --partition on sbatch command line
# e.g., sbatch --nodes=2 --partition=my_partition eval.sbatch

#SBATCH --ntasks-per-node=1 # We run one torchrun launcher per node
#SBATCH --gpus-per-node=8 # GPUs reserved per node; torchrun's --nproc_per_node below controls how many worker processes are launched

# --- Project and Log Directories ---
PROJECT_DIR=${PROJECT_DIR:-"/fsx/ubuntu/workspace/repo/Pollux"}
LOG_DIR=${LOG_DIR:-"/fsx/checkpoints/ablations/logs"}

echo "Changing directory to Project Directory: ${PROJECT_DIR}"
cd "${PROJECT_DIR}" || { echo "Failed to cd into ${PROJECT_DIR}"; exit 1; }
echo "Current working directory: $(pwd)"

# --- User defined ENVs for AWS Hyperpod ---
export NCCL_PROTO="Simple"
export FI_PROVIDER="efa"
export FI_EFA_USE_DEVICE_RDMA="1"
export FI_EFA_USE_HUGE_PAGE="0"
export FI_EFA_SET_CUDA_SYNC_MEMOPS="0"
export NCCL_SOCKET_IFNAME="^docker,lo,veth,eth"
export LD_PRELOAD="/usr/local/cuda-12.8/lib/libnccl.so"

# --- Conda environment ---
CONDA_ENV_NAME="pollux"

CONDA_PATH=${CONDA_PATH:-"/fsx/ubuntu/miniconda3"}
export PATH="$CONDA_PATH/bin:$PATH"
source $CONDA_PATH/etc/profile.d/conda.sh

echo "Attempting to activate conda environment: ${CONDA_ENV_NAME}"
_CONDA_ROOT=$(conda info --base 2>/dev/null)

if [ -z "${_CONDA_ROOT}" ]; then
echo "Error: conda command not found or conda base not determined."
echo "Please ensure conda is installed and initialized."
exit 1
fi

conda activate "${CONDA_ENV_NAME}"
if [ $? -ne 0 ]; then
echo "Error: Failed to activate conda environment: ${CONDA_ENV_NAME}"
echo "Please ensure the environment exists and conda is correctly set up."
exit 1
fi
echo "Conda environment ${CONDA_ENV_NAME} activated successfully."
echo "Python executable: $(which python)"
echo "PYTHONPATH: $PYTHONPATH"

# --- PyTorch distributed setup ---
# Determine Master Address and Port from Slurm
export PytorchMASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export PytorchMASTER_PORT=29500 # Default port

echo "--- Slurm Job Information ---"
echo "SLURM_JOB_ID: ${SLURM_JOB_ID}"
echo "SLURM_JOB_NODELIST: ${SLURM_JOB_NODELIST}"
echo "SLURM_NNODES: ${SLURM_NNODES}"
echo "SLURM_NTASKS_PER_NODE: ${SLURM_NTASKS_PER_NODE}"
echo "SLURM_SUBMIT_DIR: ${SLURM_SUBMIT_DIR}"
echo "PytorchMASTER_ADDR: ${PytorchMASTER_ADDR}"
echo "PytorchMASTER_PORT: ${PytorchMASTER_PORT}"
echo "--- End Slurm Job Information ---"


AUTO_RESUME=""
if [ -d "/opt/sagemaker_cluster" ]; then
echo "Detected Hyperpod cluster.. enabling --auto-resume=1"
AUTO_RESUME="--auto-resume=1"
fi

TORCHRUN_CMD="torchrun"

# TORCHRUN_ARGS:
# torchrun uses PytorchMASTER_ADDR and PytorchMASTER_PORT for the c10d rendezvous.
# nnodes is passed explicitly below; node_rank is resolved through the rendezvous backend.
declare -a TORCHRUN_ARGS=(
"--nnodes=${SLURM_NNODES}"
"--nproc_per_node=1"
"--rdzv_backend=c10d"
"--rdzv_endpoint=${PytorchMASTER_ADDR}:${PytorchMASTER_PORT}"
"--log_dir=${LOG_DIR}/torchrun_logs/job_${SLURM_JOB_ID}_node_${SLURM_NODEID}" # Per-node torchrun logs
)

# Training script module and its arguments
declare -a TRAIN_SCRIPT_ARGS=(
"-m"
"apps.Castor.eval"
)
declare -a TRAINING_ARGS=(
"config=apps/Castor/configs/eval.yaml"
)

echo "--- srun command execution ---"
echo "Starting evaluation with ${SLURM_NNODES} nodes."
echo "Host where sbatch script is running: $(hostname)"
echo "User: $(whoami)"
echo "Current working directory: $(pwd)"

# srun launches the command below once per node (due to --ntasks-per-node=1);
# torchrun then spawns and coordinates the local worker processes.

echo "TORCHRUN_CMD: ${TORCHRUN_CMD}"
echo "TORCHRUN_ARGS: ${TORCHRUN_ARGS[*]}"
echo "TRAIN_SCRIPT_ARGS: ${TRAIN_SCRIPT_ARGS[*]}"
echo "TRAINING_ARGS: ${TRAINING_ARGS[*]}"

# Ensure all necessary variables are exported for srun tasks
export PATH FI_PROVIDER FI_EFA_USE_DEVICE_RDMA FI_EFA_USE_HUGE_PAGE FI_EFA_SET_CUDA_SYNC_MEMOPS NCCL_PROTO NCCL_SOCKET_IFNAME LD_PRELOAD

srun ${AUTO_RESUME} \
"${TORCHRUN_CMD}" \
"${TORCHRUN_ARGS[@]}" \
"${TRAIN_SCRIPT_ARGS[@]}" \
"${TRAINING_ARGS[@]}"

EXIT_CODE=$?
echo "srun command finished with exit code ${EXIT_CODE}."

if [ ${EXIT_CODE} -ne 0 ]; then
echo "Evaluation job failed. Please check logs in slurm-${SLURM_JOB_NAME}-${SLURM_JOB_ID}.out/err and any application specific logs."
fi

exit ${EXIT_CODE}
27 changes: 17 additions & 10 deletions apps/Castor/generate.py
@@ -7,8 +7,10 @@
import numpy as np
import torch
from apps.Castor.model import Castor, ModelArgs
from apps.Castor.modules.text_encoder import TextEncoderArgs, create_text_encoder
from apps.Castor.modules.vae import (BaseLatentVideoVAE, VideoVAEArgs,
create_vae)
from apps.text_encoder.text_encoder import TextEncoder
from lingua.args import dataclass_from_dict
from lingua.checkpoint import (CONSOLIDATE_FOLDER, CONSOLIDATE_NAME,
consolidate_checkpoints)
@@ -35,6 +37,7 @@ class GeneratorArgs:
inference_steps: int = 25
vae_scale_factor: float = 8.0
tvae: VideoVAEArgs = field(default_factory=VideoVAEArgs)
text_encoder: TextEncoderArgs = field(default_factory=TextEncoderArgs)


class LatentGenerator(nn.Module):
@@ -43,9 +46,11 @@ def __init__(
cfg: GeneratorArgs,
model: nn.Module,
tvae: BaseLatentVideoVAE,
text_encoder: TextEncoder,
):
super().__init__()
self.model = model

self.vae_scale_factor = cfg.vae_scale_factor
self.resolution = int(cfg.resolution // self.vae_scale_factor)
self.cond_resolution = int(cfg.cond_resolution // self.vae_scale_factor)
@@ -58,6 +63,7 @@ def __init__(
self.scheduler = model.scheduler.scheduler
self.num_inference_steps = cfg.inference_steps
self.tvae = tvae
self.text_encoder = text_encoder

def prepare_latent(self, context, device):
bsz = len(context["caption"])
@@ -92,7 +98,7 @@ def forward(self, context: Dict[str, Any]) -> torch.Tensor:
mu=mu,
)
latent = self.prepare_latent(context, device=cur_device)
pos_conditional_signal, pos_conditional_mask = self.model.text_encoder(context)
pos_conditional_signal, pos_conditional_mask = self.text_encoder(context)
negative_conditional_signal = (
self.model.diffusion_transformer.negative_token.repeat(
pos_conditional_signal.size(0), pos_conditional_signal.size(1), 1
@@ -228,28 +234,29 @@ def main():
cfg.ckpt_dir, model_cls=Castor, model_args_cls=ModelArgs
)
tvae = create_vae(gen_cfg.tvae)
generator = LatentGenerator(gen_cfg, pollux, tvae).cuda()
text_encoder = create_text_encoder(gen_cfg.text_encoder)
generator = LatentGenerator(gen_cfg, pollux, tvae, text_encoder).cuda()
print("Model loaded successfully")
context = {
"caption": [
# Short, simple descriptions
"A red rose in full bloom against a black background.",
"A happy young man sitting on a piece of cloud, reading a book.",
"A sleeping cat curled up on a windowsill.",
"Fresh snow falling in a forest.",
"Hot air balloons floating in a clear blue sky.",
"A happy young man sitting on a kitchen table chair, reading a book.",
"A sleeping cat curled up on a windowsill, looking at the sky.",
"Children playing in a forest.",
"Hot air balloons floating in a clear blue sky, with a river in the background.",

# Medium length, more detailed
"A cozy coffee shop interior with vintage furniture, warm lighting, and the aroma of freshly ground beans wafting through the air.",
"An ancient temple hidden in a misty mountain valley, its weathered stone walls covered in flowering vines.",
"A bustling night market in Tokyo, neon signs reflecting off wet streets as people hurry past food stalls.",
"A sea turtle glides gracefully through crystal-clear turquoise water above a school of small fish, with sunlight reflecting off the surface.",
"A petri dish with a bamboo forest growing within it that has tiny red pandas running around.",
"A bamboo forest growing within it that has tiny red pandas running around.",

# Technical/scientific
"A lone astronaut floats in space, gazing at a swirling black hole surrounded by vibrant landscapes, rivers and clouds below.",
"Microscopic view of CRISPR gene editing in action, with precisely rendered molecular structures.",
"A topographical map of an alien planet's surface, complete with elevation data and geological formations.",
"A lone astronaut floats in space, rivers and clouds below.",
"Image of a ribosome, a molecular machine that synthesizes proteins.",
"A topographical map of earth's surface, complete with elevation data and geological formations.",

# Artistic/abstract
"An impressionist painting of music made visible, with colorful swirls representing different instruments in an orchestra.",
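For context on the `mu=mu` argument in the `forward()` hunk above: with `use_dynamic_shifting: true`, the timestep shift is normally interpolated from the latent sequence length using the `base_shift`/`max_shift` and `base_image_seq_len`/`max_image_seq_len` values in the scheduler config, as in Flux-style flow-matching schedulers. A hedged sketch of that interpolation (assumed here, not necessarily Castor's exact code):

```python
def calculate_mu(image_seq_len: int,
                 base_seq_len: int = 256, max_seq_len: int = 4096,
                 base_shift: float = 0.5, max_shift: float = 1.15) -> float:
    # Linearly interpolate the shift between the base and max sequence lengths;
    # the defaults mirror the scheduler block in mup_v0.yaml.
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    return base_shift + slope * (image_seq_len - base_seq_len)


# A 256px image with a /8 VAE and patch size 2 gives (256 / 8 / 2) ** 2 = 256
# latent tokens, so mu equals base_shift; larger images get a larger shift
# before being passed to scheduler.set_timesteps(..., mu=mu).
mu = calculate_mu(image_seq_len=256)
```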