Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion apps/Castor/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from apps.Castor.generate import (GeneratorArgs, LatentGenerator,
load_consolidated_model)
from apps.Castor.model import Castor, ModelArgs
from apps.Castor.modules.text_encoder import create_text_encoder
from apps.Castor.modules.vae import (BaseLatentVideoVAE, VideoVAEArgs,
create_vae)
from apps.main.data import AutoDataLoader, DataArgs
Expand Down Expand Up @@ -90,7 +91,8 @@ def launch_eval(cfg: EvalArgs):
logger.info("Model loaded")
model.eval()
tvae = create_vae(cfg.generator.tvae)
generator = LatentGenerator(cfg.generator, model, tvae).cuda()
text_encoder = create_text_encoder(cfg.generator.text_encoder)
generator = LatentGenerator(cfg.generator, model, tvae, text_encoder).cuda()
active_data = [d for d in cfg.data if d.stage == cfg.stage and d.use]
data_loader_factory = AutoDataLoader(
shard_id=global_rank,
Expand Down
130 changes: 130 additions & 0 deletions apps/Castor/eval.sbatch
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#!/bin/bash
#
# Slurm launcher for distributed Castor evaluation: one torchrun per node,
# one worker process per GPU, rendezvousing via c10d on the first node.

#SBATCH --job-name=castor_eval
#SBATCH --output=slurm_logs/slurm-%x-%j.out # %x for job name, %j for job ID
#SBATCH --error=slurm_logs/slurm-%x-%j.err
# User will specify --nodes and --partition on sbatch command line
# e.g., sbatch --nodes=2 --partition=my_partition eval.sbatch

#SBATCH --ntasks-per-node=1 # We run one torchrun launcher per node
#SBATCH --gpus-per-node=8 # Each torchrun launcher manages one process per GPU

# NOTE(review): the slurm_logs/ directory used by --output/--error must exist
# in the submission directory before sbatch runs, or Slurm cannot write logs.

# --- Project and Log Directories ---
PROJECT_DIR=${PROJECT_DIR:-"/fsx/ubuntu/workspace/repo/Pollux"}
LOG_DIR=${LOG_DIR:-"/fsx/checkpoints/ablations/logs"}

echo "Changing directory to Project Directory: ${PROJECT_DIR}"
cd "${PROJECT_DIR}" || { echo "Failed to cd into ${PROJECT_DIR}"; exit 1; }
echo "Current working directory: $(pwd)"

# --- User defined ENVs for AWS Hyperpod (EFA networking + NCCL tuning) ---
export NCCL_PROTO="Simple"
export FI_PROVIDER="efa"
export FI_EFA_USE_DEVICE_RDMA="1"
export FI_EFA_USE_HUGE_PAGE="0"
export FI_EFA_SET_CUDA_SYNC_MEMOPS="0"
export NCCL_SOCKET_IFNAME="^docker,lo,veth,eth"
export LD_PRELOAD="/usr/local/cuda-12.8/lib/libnccl.so"

# --- Conda environment ---
CONDA_ENV_NAME="pollux"

CONDA_PATH=${CONDA_PATH:-"/fsx/ubuntu/miniconda3"}
export PATH="$CONDA_PATH/bin:$PATH"
source "$CONDA_PATH/etc/profile.d/conda.sh"

echo "Attempting to activate conda environment: ${CONDA_ENV_NAME}"
_CONDA_ROOT=$(conda info --base 2>/dev/null)

if [ -z "${_CONDA_ROOT}" ]; then
    echo "Error: conda command not found or conda base not determined."
    echo "Please ensure conda is installed and initialized."
    exit 1
fi

conda activate "${CONDA_ENV_NAME}"
if [ $? -ne 0 ]; then
    echo "Error: Failed to activate conda environment: ${CONDA_ENV_NAME}"
    echo "Please ensure the environment exists and conda is correctly set up."
    exit 1
fi
echo "Conda environment ${CONDA_ENV_NAME} activated successfully."
echo "Python executable: $(which python)"
echo "PYTHONPATH: $PYTHONPATH"

# --- PyTorch distributed setup ---
# Rendezvous host is the first node of the allocation; all nodes connect to it.
export PytorchMASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export PytorchMASTER_PORT=29500 # Default port

echo "--- Slurm Job Information ---"
echo "SLURM_JOB_ID: ${SLURM_JOB_ID}"
echo "SLURM_JOB_NODELIST: ${SLURM_JOB_NODELIST}"
echo "SLURM_NNODES: ${SLURM_NNODES}"
echo "SLURM_NTASKS_PER_NODE: ${SLURM_NTASKS_PER_NODE}"
echo "SLURM_SUBMIT_DIR: ${SLURM_SUBMIT_DIR}"
echo "PytorchMASTER_ADDR: ${PytorchMASTER_ADDR}"
echo "PytorchMASTER_PORT: ${PytorchMASTER_PORT}"
echo "--- End Slurm Job Information ---"


AUTO_RESUME=""
if [ -d "/opt/sagemaker_cluster" ]; then
    echo "Detected Hyperpod cluster.. enabling --auto-resume=1"
    AUTO_RESUME="--auto-resume=1"
fi

TORCHRUN_CMD="torchrun"

# TORCHRUN_ARGS:
# All nodes join the same c10d rendezvous (keyed by --rdzv_id) at
# PytorchMASTER_ADDR:PytorchMASTER_PORT.
#
# BUGFIX: --nproc_per_node was hard-coded to 1, which contradicted
# --gpus-per-node=8 above and left 7 of 8 GPUs idle; "gpu" tells torchrun to
# spawn one worker per visible GPU.
#
# BUGFIX: the log dir previously interpolated ${SLURM_NODEID}, which is only
# set inside srun-launched tasks — at sbatch time it expands to the empty
# string, so every node wrote to the same "..._node_" directory. torchrun
# already creates unique per-rank subdirectories under --log_dir, so a single
# per-job directory (created up front) is sufficient.
TORCHRUN_LOG_DIR="${LOG_DIR}/torchrun_logs/job_${SLURM_JOB_ID}"
mkdir -p "${TORCHRUN_LOG_DIR}"

declare -a TORCHRUN_ARGS=(
    "--nnodes=${SLURM_NNODES}"
    "--nproc_per_node=gpu"
    "--rdzv_backend=c10d"
    "--rdzv_id=${SLURM_JOB_ID}"
    "--rdzv_endpoint=${PytorchMASTER_ADDR}:${PytorchMASTER_PORT}"
    "--log_dir=${TORCHRUN_LOG_DIR}"
)

# Evaluation script module and its arguments
declare -a TRAIN_SCRIPT_ARGS=(
    "-m"
    "apps.Castor.eval"
)
declare -a TRAINING_ARGS=(
    "config=apps/Castor/configs/eval.yaml"
)

echo "--- srun command execution ---"
echo "Starting evaluation with ${SLURM_NNODES} nodes."
echo "Host where sbatch script is running: $(hostname)"
echo "User: $(whoami)"
echo "Current working directory: $(pwd)"

# srun launches this command once per node (due to --ntasks-per-node=1).

echo "TORCHRUN_CMD: ${TORCHRUN_CMD}"
echo "TORCHRUN_ARGS: ${TORCHRUN_ARGS[*]}"
echo "TRAIN_SCRIPT_ARGS: ${TRAIN_SCRIPT_ARGS[*]}"
echo "TRAINING_ARGS: ${TRAINING_ARGS[*]}"

# Ensure all necessary variables are exported for srun tasks
export PATH FI_PROVIDER FI_EFA_USE_DEVICE_RDMA FI_EFA_USE_HUGE_PAGE FI_EFA_SET_CUDA_SYNC_MEMOPS NCCL_PROTO NCCL_SOCKET_IFNAME LD_PRELOAD

srun ${AUTO_RESUME} \
    "${TORCHRUN_CMD}" \
    "${TORCHRUN_ARGS[@]}" \
    "${TRAIN_SCRIPT_ARGS[@]}" \
    "${TRAINING_ARGS[@]}"

EXIT_CODE=$?
echo "srun command finished with exit code ${EXIT_CODE}."

if [ ${EXIT_CODE} -ne 0 ]; then
    echo "Evaluation job failed. Please check logs in slurm-${SLURM_JOB_NAME}-${SLURM_JOB_ID}.out/err and any application specific logs."
fi

exit ${EXIT_CODE}
27 changes: 17 additions & 10 deletions apps/Castor/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import numpy as np
import torch
from apps.Castor.model import Castor, ModelArgs
from apps.Castor.modules.text_encoder import TextEncoderArgs, create_text_encoder
from apps.Castor.modules.vae import (BaseLatentVideoVAE, VideoVAEArgs,
create_vae)
from apps.text_encoder.text_encoder import TextEncoder
from lingua.args import dataclass_from_dict
from lingua.checkpoint import (CONSOLIDATE_FOLDER, CONSOLIDATE_NAME,
consolidate_checkpoints)
Expand All @@ -35,6 +37,7 @@ class GeneratorArgs:
inference_steps: int = 25
vae_scale_factor: float = 8.0
tvae: VideoVAEArgs = field(default_factory=VideoVAEArgs)
text_encoder: TextEncoderArgs = field(default_factory=TextEncoderArgs)


class LatentGenerator(nn.Module):
Expand All @@ -43,9 +46,11 @@ def __init__(
cfg: GeneratorArgs,
model: nn.Module,
tvae: BaseLatentVideoVAE,
text_encoder: TextEncoder,
):
super().__init__()
self.model = model

self.vae_scale_factor = cfg.vae_scale_factor
self.resolution = int(cfg.resolution // self.vae_scale_factor)
self.cond_resolution = int(cfg.cond_resolution // self.vae_scale_factor)
Expand All @@ -58,6 +63,7 @@ def __init__(
self.scheduler = model.scheduler.scheduler
self.num_inference_steps = cfg.inference_steps
self.tvae = tvae
self.text_encoder = text_encoder

def prepare_latent(self, context, device):
bsz = len(context["caption"])
Expand Down Expand Up @@ -92,7 +98,7 @@ def forward(self, context: Dict[str, Any]) -> torch.Tensor:
mu=mu,
)
latent = self.prepare_latent(context, device=cur_device)
pos_conditional_signal, pos_conditional_mask = self.model.text_encoder(context)
pos_conditional_signal, pos_conditional_mask = self.text_encoder(context)
negative_conditional_signal = (
self.model.diffusion_transformer.negative_token.repeat(
pos_conditional_signal.size(0), pos_conditional_signal.size(1), 1
Expand Down Expand Up @@ -228,28 +234,29 @@ def main():
cfg.ckpt_dir, model_cls=Castor, model_args_cls=ModelArgs
)
tvae = create_vae(gen_cfg.tvae)
generator = LatentGenerator(gen_cfg, pollux, tvae).cuda()
text_encoder = create_text_encoder(gen_cfg.text_encoder)
generator = LatentGenerator(gen_cfg, pollux, tvae, text_encoder).cuda()
print("Model loaded successfully")
context = {
"caption": [
# Short, simple descriptions
"A red rose in full bloom against a black background.",
"A happy young man sitting on a piece of cloud, reading a book.",
"A sleeping cat curled up on a windowsill.",
"Fresh snow falling in a forest.",
"Hot air balloons floating in a clear blue sky.",
"A happy young man sitting on a kitchen table chair, reading a book.",
"A sleeping cat curled up on a windowsill, looking at the sky.",
"Children playing in a forest.",
"Hot air balloons floating in a clear blue sky, with a river in the background.",

# Medium length, more detailed
"A cozy coffee shop interior with vintage furniture, warm lighting, and the aroma of freshly ground beans wafting through the air.",
"An ancient temple hidden in a misty mountain valley, its weathered stone walls covered in flowering vines.",
"A bustling night market in Tokyo, neon signs reflecting off wet streets as people hurry past food stalls.",
"A sea turtle glides gracefully through crystal-clear turquoise water above a school of small fish, with sunlight reflecting off the surface.",
"A petri dish with a bamboo forest growing within it that has tiny red pandas running around.",
"A bamboo forest growing within it that has tiny red pandas running around.",

# Technical/scientific
"A lone astronaut floats in space, gazing at a swirling black hole surrounded by vibrant landscapes, rivers and clouds below.",
"Microscopic view of CRISPR gene editing in action, with precisely rendered molecular structures.",
"A topographical map of an alien planet's surface, complete with elevation data and geological formations.",
"A lone astronaut floats in space, rivers and clouds below.",
"Image of a ribosome, a molecular machine that synthesizes proteins.",
"A topographical map of earth's surface, complete with elevation data and geological formations.",

# Artistic/abstract
"An impressionist painting of music made visible, with colorful swirls representing different instruments in an orchestra.",
Expand Down
33 changes: 9 additions & 24 deletions apps/Castor/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class ModelArgs:
diffusion_model: TransformerArgs = field(default_factory=TransformerArgs)
with_vae: bool = False
vae_args: VideoVAEArgs = field(default_factory=VideoVAEArgs)
text_encoder_dim: int = 512
vision_encoder_dim: int = 2048
vision_encoder_alignment: bool = False
vision_encoder_alignment_factor: float = 0.5
vision_encoder_args: VisionEncoderArgs = field(default_factory=VisionEncoderArgs)
Expand Down Expand Up @@ -67,23 +69,15 @@ def __init__(self, args: ModelArgs):
super().__init__()
self.args = args

# VAE
if args.with_vae:
self.compressor = create_vae(args.vae_args)

# Vision encoder
if args.vision_encoder_alignment:
self.vision_encoder = create_vision_encoder(args.vision_encoder_args)
self.vision_encoder_proj = AlignmentProjection(
args.diffusion_model.dim, args.vision_encoder_args.projection_hidden_dim, self.vision_encoder.dim)

# Text encoder
self.text_encoder = create_text_encoder(args.text_encoder)
args.diffusion_model.dim, args.vision_encoder_args.projection_hidden_dim, args.vision_encoder_dim)

if args.diffusion_model.condition_dim != self.text_encoder.dim():
logger.warning(f"Condition dim {args.diffusion_model.condition_dim} does not match text encoder dim {self.text_encoder.dim()}")
logger.warning(f"Using {self.text_encoder.dim()} as condition dim")
args.diffusion_model.condition_dim = self.text_encoder.dim()
if args.diffusion_model.condition_dim != args.text_encoder_dim:
logger.warning(f"Condition dim {args.diffusion_model.condition_dim} does not match text encoder dim {args.text_encoder_dim}")
logger.warning(f"Using {args.text_encoder_dim} as condition dim")
args.diffusion_model.condition_dim = args.text_encoder_dim

# Diffusion transformer
self.diffusion_transformer = DiffusionTransformer(args.diffusion_model)
Expand All @@ -92,16 +86,7 @@ def __init__(self, args: ModelArgs):
self.scheduler = RectifiedFlow(args.scheduler)
self.text_cfg_ratio = args.text_cfg_ratio

def forward(self, batch: dict[str:any], flops_meter= None) -> dict[str:any]:
if hasattr(self, "compressor"):
batch["latent_code"] = self.compressor.extract_latents(batch, flops_meter)

if hasattr(self, "vision_encoder"):
batch["vision_encoder_target"] = self.vision_encoder.extract_image_representations(batch, flops_meter)

if "text_embedding" not in batch:
batch["text_embedding"], batch["attention_mask"] = self.text_encoder(batch, flops_meter)

def forward(self, batch: dict[str:any], flops_meter) -> dict[str:any]:
conditional_signal, conditional_mask = batch["text_embedding"], batch["attention_mask"]

if random.random() <= self.text_cfg_ratio:
Expand All @@ -127,7 +112,7 @@ def forward(self, batch: dict[str:any], flops_meter= None) -> dict[str:any]:
target_loss = self.mse_loss(output.output, batch["target"])

align_loss = None
if hasattr(self, "vision_encoder"):
if self.args.vision_encoder_alignment:
vision_encoder_pred = self.vision_encoder_proj(output.align_hidden_state)
align_loss = self.consine_loss_with_features(
vision_encoder_pred, output.cond_l, output.img_size, batch["vision_encoder_target"])
Expand Down
12 changes: 9 additions & 3 deletions apps/Castor/modules/text_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, args: TextEncoderArgs):
self.text_seqlen = args.text_seqlen

# TODO: use this to get the dimension of the text encoder for transformer
@property
def dim(self) -> int:
raise NotImplementedError

Expand Down Expand Up @@ -67,6 +68,7 @@ def __init__(self, args: TextEncoderArgs):
),
)

@property
def dim(self) -> int:
return self.clip_model.config.hidden_size

Expand Down Expand Up @@ -98,14 +100,15 @@ def __init__(self, args: TextEncoderArgs):

def init_model(self, args: TextEncoderArgs, model_path: str):
config = AutoConfig.from_pretrained(model_path)
config.num_hidden_layers = int(math.ceil(args.relative_depth * config.num_hidden_layers))
config.text_config.num_hidden_layers = int(math.ceil(args.relative_depth * config.text_config.num_hidden_layers))
model = Qwen2_5_VLModel.from_pretrained(
model_path,
config=config,
torch_dtype=self.dtype,
).cuda()
# avoid norm layer in the last layer
model.norm = torch.nn.Identity()
if args.relative_depth < 1.0:
# avoid norm layer in the last layer
model.language_model.norm = torch.nn.Identity()
apply_liger_kernel_to_qwen2_5_vl(model)
model.eval()
model.requires_grad_(False)
Expand All @@ -121,6 +124,7 @@ def init_processor(self, model_path: str):
model_path,
)

@property
def dim(self) -> int:
return self.model.config.hidden_size

Expand Down Expand Up @@ -181,6 +185,7 @@ def __init__(self, args):
args.model_path, subfolder="text_encoder", torch_dtype=self.dtype
).cuda()

@property
def dim(self) -> int:
return self.model.config.hidden_size

Expand Down Expand Up @@ -222,6 +227,7 @@ def __init__(self, args):
args.model_path, torch_dtype=self.dtype
).cuda()

@property
def dim(self) -> int:
return self.model.config.hidden_size

Expand Down
2 changes: 1 addition & 1 deletion apps/Castor/modules/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch
from diffusers import AutoencoderKL, AutoencoderKLHunyuanVideo
from torch import nn
from cosmos_tokenizer.image_lib import ImageTokenizer

logger = logging.getLogger()
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -225,6 +224,7 @@ def forward(self, x=torch.Tensor):
class COSMOSContinuousVAE(BaseLatentVideoVAE):
def __init__(self, args: VideoVAEArgs):
super().__init__(args)
from cosmos_tokenizer.image_lib import ImageTokenizer
"""
Initialize the encoder and decoder for Continuous VAE.
Checks model type and returns the initialized VAE instance.
Expand Down
Loading