Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions apps/Castor/configs/eval.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# python -m apps.Castor.generate config=apps/Castor/configs/eval.yaml
name: "debug_evals"
stage: eval
ckpt_dir: "/mnt/pollux/checkpoints/aj/unpadded3/checkpoints/0000080000/"
dump_dir: "/mnt/pollux/checkpoints/aj/unpadded3/MJHQ/0000080000"
ckpt_dir: "/mnt/pollux/checkpoints/aj/neg_token2/checkpoints/0000080000/"
dump_dir: "/mnt/pollux/checkpoints/aj/neg_token2/MJHQ/0000080000"
generator:
guidance_scale: 6.5
dtype: bf16
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
version: v1.0
# From now on, align the train, data, and model settings with the train stage (the refactor for data is now finished)
train_stage: stage-1 # options: preliminary, pretraining, posttraining; aligned with data setting
name: fp8_ffn_unpadded #used for local dump and wandb log
output_dir: /mnt/pollux/checkpoints/aj/fp8_ffn_unpadded
name: neg_token2 #used for local dump and wandb log
output_dir: /mnt/pollux/checkpoints/aj/
dump_dir: '' # No need now
steps: 500000
seed: 777
Expand Down Expand Up @@ -68,6 +68,7 @@ model:
unpadded: true
use_fp8_ffn: true
fp8_ffn_skip_layers: [] # falls back to liger_ffn
n_unconditional_tokens: 64
with_vae: true
vae_args:
vae_type: "flux"
Expand Down
17 changes: 9 additions & 8 deletions apps/Castor/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,15 @@ def forward(self, context: Dict[str, Any]) -> torch.Tensor:
)
latent = self.prepare_latent(context, device=cur_device)
pos_conditional_signal, pos_conditional_mask = self.model.text_encoder(context)
negative_conditional_signal = (
self.model.diffusion_transformer.negative_token.repeat(
pos_conditional_signal.size(0), pos_conditional_signal.size(1), 1
)
)
negative_conditional_mask = torch.ones_like(
pos_conditional_mask, dtype=pos_conditional_mask.dtype
)

n_unconditional_tokens = self.model.diffusion_transformer.n_unconditional_tokens
negative_conditional_signal = torch.zeros_like(pos_conditional_signal)
negative_conditional_signal[
:, :n_unconditional_tokens, :
] = self.model.diffusion_transformer.negative_token
negative_conditional_mask = torch.zeros_like(pos_conditional_mask)
negative_conditional_mask[:, :n_unconditional_tokens] = 1

context = torch.cat(
[
pos_conditional_signal,
Expand Down
70 changes: 61 additions & 9 deletions apps/Castor/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import random
import math
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

Expand Down Expand Up @@ -92,7 +93,55 @@ def __init__(self, args: ModelArgs):
self.scheduler = RectifiedFlow(args.scheduler)
self.text_cfg_ratio = args.text_cfg_ratio

def forward(self, batch: dict[str:any], flops_meter= None) -> dict[str:any]:
@torch.compile
@torch.no_grad()
def sample_timesteps(
    self, vae_latent: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sample a timestep `t`, create the corresponding noisy latent `z_t`, and
    compute the rectified-flow velocity target `v_objective`.

    Timesteps are drawn from a logit-normal distribution (sigmoid of a
    standard normal) and then remapped with a resolution-dependent shift
    `t' = alpha*t / (1 + (alpha - 1)*t)`, where `alpha` grows with the
    number of latent tokens (H*W) relative to a 64x64 base grid.
    NOTE(review): this resembles the SD3-style timestep shifting — confirm
    against the training recipe this was ported from.
    10% of the batch instead uses a plain uniform sample, presumably to keep
    coverage of the full [0, 1] range — TODO confirm intent.

    Parameters
    ----------
    vae_latent : torch.Tensor
        Clean latent tensor with shape (B, C, H, W).

    Returns
    -------
    z_t : torch.Tensor
        Noisy latent at timestep `t`: `(1 - t) * x0 + t * noise`,
        same shape and dtype as `vae_latent`.
    t : torch.Tensor
        Sampled timestep in the continuous range [0, 1] with shape (B,),
        cast to `vae_latent`'s dtype.
    v_objective : torch.Tensor
        Velocity target `noise - x0` (i.e. d z_t / d t) used for training.
    """

    # NOTE: `assert` is stripped under `python -O`; it only guards the
    # single-tensor (fixed-resolution) assumption during development.
    assert isinstance(vae_latent, torch.Tensor), "vae_latent must be a tensor, implement dynamic resolution"

    B, _, H, W = vae_latent.shape
    device, dtype = vae_latent.device, vae_latent.dtype

    # --- sample timesteps --------------------------------------------------
    # alpha scales the shift with latent area; alpha == 2 at the 64x64 base.
    image_token_size = H * W
    alpha = 2 * math.sqrt(image_token_size / (64 * 64))

    # Logit-normal base sample in (0, 1). RNG is drawn in float32
    # regardless of the latent dtype; the cast to `dtype` happens below.
    z = torch.randn(B, device=device, dtype=torch.float32)
    logistic = torch.sigmoid(z)
    # Resolution-dependent remap of the logit-normal sample (despite the
    # name, this is NOT a log-normal draw).
    lognormal = logistic * alpha / (1 + (alpha - 1) * logistic)

    # Mix in uniform timesteps for ~10% of the batch. Note the uniform
    # draw inside torch.where is evaluated for every element, so the RNG
    # stream advances by a full batch either way.
    do_uniform = torch.rand(B, device=device) < 0.1
    t = torch.where(do_uniform, torch.rand(B, device=device), lognormal).to(dtype)

    # --- add noise ---------------------------------------------------------
    noise = torch.randn_like(vae_latent)
    # Broadcast per-sample t over (C, H, W).
    t_reshaped = t.view(B, 1, 1, 1)
    # Linear interpolation between data (t=0) and pure noise (t=1).
    z_t = vae_latent * (1 - t_reshaped) + noise * t_reshaped

    # Rectified-flow target: the constant velocity from x0 toward noise.
    v_objective = noise - vae_latent
    return z_t, t, v_objective

def forward(self, batch: dict[str:any], flops_meter= None, is_training=False) -> dict[str:any]:
if hasattr(self, "compressor"):
batch["latent_code"] = self.compressor.extract_latents(batch, flops_meter)

Expand All @@ -104,16 +153,18 @@ def forward(self, batch: dict[str:any], flops_meter= None) -> dict[str:any]:

conditional_signal, conditional_mask = batch["text_embedding"], batch["attention_mask"]

if random.random() <= self.text_cfg_ratio:
conditional_signal = self.diffusion_transformer.negative_token.repeat(
conditional_signal.size(0), conditional_signal.size(1), 1
)
conditional_mask = torch.ones_like(
conditional_mask, dtype=conditional_signal.dtype
)
if is_training:
cfg_mask = (torch.rand(conditional_signal.shape[0], device=conditional_signal.device) < self.text_cfg_ratio)
if cfg_mask.any():
n_unconditional_tokens = self.diffusion_transformer.n_unconditional_tokens
conditional_signal[cfg_mask, :n_unconditional_tokens, :] = self.diffusion_transformer.negative_token
# For unconditional samples, the attention mask should be [1, 1, ..., 0, ...].
conditional_mask[cfg_mask, :n_unconditional_tokens] = 1
conditional_mask[cfg_mask, n_unconditional_tokens:] = 0

latent_code = batch["latent_code"]
noised_x, t, target = self.scheduler.sample_noised_input(latent_code)
noised_x, t, target = self.sample_timesteps(latent_code)

output = self.diffusion_transformer(
x=noised_x,
time_steps=t,
Expand Down Expand Up @@ -235,6 +286,7 @@ def build_2B_Castor():
condition_seqlen=128,
norm_eps=1e-5,
condition_dim=512,
n_unconditional_tokens=64,
)
vae_args = VideoVAEArgs(
pretrained_model_name_or_path="/mnt/pollux/checkpoints/HunyuanVideo/vae",
Expand Down
4 changes: 3 additions & 1 deletion apps/Castor/modules/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class TransformerArgs(BaseTransformerArgs):
tmb_size: int = 256
condition_seqlen: int = 1000
gen_seqlen: int = 1000
n_unconditional_tokens: int = 64
pre_trained_path: Optional[str] = None
qk_norm: bool = True
shared_adaLN: bool = False
Expand Down Expand Up @@ -305,6 +306,7 @@ def __init__(self, args: TransformerArgs):
self.out_channels = args.out_channels
self.in_channels = args.in_channels
self.unpadded = args.unpadded
self.n_unconditional_tokens = args.n_unconditional_tokens
self.tmb_embed = TimestepEmbedder(
hidden_size=args.time_step_dim, time_embedding_size=args.tmb_size
)
Expand Down Expand Up @@ -335,7 +337,7 @@ def __init__(self, args: TransformerArgs):
self.cond_norm = RMSNorm(
args.dim, eps=args.norm_eps, liger_rms_norm=args.liger_rms_norm
)
self.negative_token = nn.Parameter(torch.zeros(1, 1, args.condition_dim))
self.negative_token = nn.Parameter(torch.zeros(1, self.n_unconditional_tokens, args.condition_dim))
self.cond_proj = nn.Linear(
in_features=args.condition_dim,
out_features=args.dim,
Expand Down
27 changes: 17 additions & 10 deletions apps/Castor/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,26 @@ def train(args: TrainArgs):

torch.manual_seed(args.seed)
logger.info("Building model")

fp8_recipe = None
if args.model.diffusion_model.use_fp8_ffn:
logger.info("FP8 is enabled. Defining FP8 recipe.")
# Example recipe, adjust as needed
all_gpus = torch.distributed.new_group(backend="nccl")
fp8_format = Format.HYBRID # Or Format.E4M3
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")

# with te.fp8_model_init(enabled=True, recipe=fp8_recipe):
model = Castor(args.model)
logger.info("Model is built !")
logger.info(model)

with te.fp8_autocast(enabled=args.model.diffusion_model.use_fp8_ffn, fp8_recipe=fp8_recipe):
for module in model.modules():
if isinstance(module, te.LayerNormMLP):
module.init_fp8_metadata(2)


model_param_count = get_num_params(model)
flops_meter = FlopsMeter(args.model, model)

Expand Down Expand Up @@ -319,14 +334,6 @@ def train(args: TrainArgs):
checkpoint.load(model, optimizer, train_state, world_mesh)
# Either load from latest checkpoint or start from scratch

fp8_recipe = None
if args.model.diffusion_model.use_fp8_ffn:
logger.info("FP8 is enabled. Defining FP8 recipe.")
# Example recipe, adjust as needed
all_gpus = torch.distributed.new_group(backend="nccl")
fp8_format = Format.HYBRID # Or Format.E4M3
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
# You might want to make recipe parameters configurable via TrainArgs

gc.disable()

Expand Down Expand Up @@ -413,7 +420,7 @@ def train(args: TrainArgs):
start_timer.record()

with te.fp8_autocast(enabled=args.model.diffusion_model.use_fp8_ffn, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
outputs = model(batch, flops_meter)
outputs = model(batch, flops_meter, is_training=True)
# We scale loss with grad_acc_steps so the gradient is the same
# regardless of grad_acc_steps
loss = outputs.loss / args.grad_acc_steps
Expand Down Expand Up @@ -612,4 +619,4 @@ def main():

if __name__ == "__main__":
torch.set_float32_matmul_precision('high')
main()
main()
13 changes: 7 additions & 6 deletions apps/Castor/utils/flop_meter.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,13 @@ def __init__(
self.cond_num_heads = model.text_encoder.model.config.num_attention_heads
self.cond_headdim = self.cond_dim // self.cond_num_heads

self.vision_params = get_num_params(model.vision_encoder.model)
self.vision_num_layers = model.vision_encoder.model.config.num_hidden_layers
self.vision_num_heads = model.vision_encoder.model.config.num_attention_heads
self.vision_dim = model.vision_encoder.model.config.hidden_size
self.vision_patch_size = model.vision_encoder.model.config.patch_size
self.vision_headdim = self.vision_dim // self.vision_num_heads
if hasattr(model, "vision_encoder"):
self.vision_params = get_num_params(model.vision_encoder.model)
self.vision_num_layers = model.vision_encoder.model.config.num_hidden_layers
self.vision_num_heads = model.vision_encoder.model.config.num_attention_heads
self.vision_dim = model.vision_encoder.model.config.hidden_size
self.vision_patch_size = model.vision_encoder.model.config.patch_size
self.vision_headdim = self.vision_dim // self.vision_num_heads

self.vae_params = get_num_params(model.compressor.vae)

Expand Down
4 changes: 2 additions & 2 deletions apps/main/utils/mongodb_data_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def set_local_partition(self):
data.extend(df.to_dict(orient="records"))

# Note: cap the number of loaded records (was a debug-only break at 10000, now active with a 2e6 limit)
# if len(data) > 10000:
# break
if len(data) > int(2e6):
break
else:
raise ValueError(f"Invalid Root Directory Type. Set root_dir_type to 'json' or 'parquet'")

Expand Down