Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions apps/Castor/configs/eval.yaml
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
# python -m apps.Castor.generate config=apps/Castor/configs/eval.yaml
name: "debug_evals"
stage: eval
ckpt_dir: "/mnt/pollux/checkpoints/chandan/qwen2_5_vl_flux_dynamic/checkpoints/0000005000/"
dump_dir: "/mnt/pollux/checkpoints/chandan/qwen2_5_vl_flux_dynamic/MJHQ/0000015000"
ckpt_dir: "/mnt/pollux/checkpoints/chandan/qwen2_5_vl_hunyuan_dynamic/checkpoints/0000007500"
dump_dir: "/mnt/pollux/checkpoints/chandan/qwen2_5_vl_hunyuan_dynamic/MJHQ/0000007500"
generator:
guidance_scale: 6.5
dtype: bf16
resolution: 256
resolution: 512
show_progress: False
inference_steps: 50
vae_scale_factor: 8.0
tvae:
vae_type: flux
pretrained_model_name_or_path: '/mnt/pollux/checkpoints/FLUX.1-dev/vae'
vae_type: hunyuan
pretrained_model_name_or_path: '/mnt/pollux/checkpoints/HunyuanVideo/vae'
enable_tiling: false
enable_slicing: false
data:
Expand Down
117 changes: 117 additions & 0 deletions apps/Castor/configs/train_bucket_256_Castor_hunyuan_qwen_dynamic.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# torchrun --standalone --nnodes 1 --nproc-per-node 8 -m apps.Castor.train config=apps/Castor/configs/train_bucket_256_Castor_hunyuan_qwen_dynamic.yaml

# Set up single experiment
version: v1.0
# From now on, align the train, data, and model settings with the train stage (the data refactor is finished)
train_stage: stage-1 # options: preliminary, pretraining, posttraining; aligned with data setting
name: qwen2_5_vl_hunyuan_dynamic #used for local dump and wandb log
output_dir: /mnt/pollux/checkpoints/chandan
dump_dir: '' # No need now
steps: 500000
seed: 777
optim:
lr: 1e-4
warmup: 4000
lr_min_ratio: 1.5e-5
clip: 1.0
weight_decay: 0.01

distributed:
gpus: 0,1,2,3,4,5,6,7
fsdp_type: full_shard
dp_shard: 8
dp_replicate: 1
compile: false
  model_dtype: bf16 # options: `fp8` is only supported by H100
matmul_allow_tf32: false
selective_activation_checkpointing: true
tp_size: 1
compile_cache_size_limit: 64

model:
scheduler:
num_train_timesteps: 1000
base_image_seq_len: 256
base_shift: 0.5
max_image_seq_len: 4096
max_shift: 1.15
    shift: 1.0 # consider 3.0 or 1.0
weighting_scheme: 'logit_normal'
logit_mean: 0.0
logit_std: 1.0
mode_scale: 1.29
use_dynamic_shifting: true
diffusion_model:
dim: 2048
ffn_dim_multiplier: 1.5
multiple_of: 256
n_heads: 32
n_kv_heads: 8
n_layers: 24
time_step_dim: 2048
patch_size: 2
in_channels: 16
out_channels: 16
tmb_size: 256
gen_seqlen: 32
condition_seqlen: 256
norm_eps: 1e-5
condition_dim: 2048
qk_norm: false
text_cfg_ratio: 0.1
with_vae: true
vae_args:
vae_type: "hunyuan"
pretrained_model_name_or_path: '/mnt/pollux/checkpoints/HunyuanVideo/vae'
enable_tiling: false
enable_slicing: false
text_encoder:
config_name: "Qwen/Qwen2.5-VL-3B-Instruct"
dtype: "bf16"
text_seqlen: 256

data:
- stage: stage-1
id: 1
data_name: bucket-256-2
task: text_to_image
source: mongodb
image_size: 256
condition_image_size: 256
max_ratio: 2.0
partition_key: 'partition_key'
retries: 3
extract_field:
"media_path": "image"
use: true
root_dir: "/mnt/pollux/mongo_db_cache_train"
dataloader:
prefetch_factor: 2
batch_size: 48
num_workers: 8
seed: 1024
shuffle: True
pin_memory: True
drop_last: False

profiling:
run: false

checkpoint:
dump:
every: 2500
keep: 0 # Don't remove the ckpt
eval:
every: 5000
keep: 0 # Don't remove the ckpt

logging:
freq: 100
wandb:
project: Pollux
entity: metauto
name: ''

env:
  ENABLE_INTRA_NODE_COMM: '0' # '0' for local machine (otherwise errors happen); '1' for slurm (needs testing)
NCCL_DEBUG: 'ERROR'
3 changes: 2 additions & 1 deletion apps/Castor/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,12 @@ def forward(self, context: Dict[str, Any]) -> torch.Tensor:
latent_model_input = torch.cat([latent] * 2)
timestep = t.expand(latent_model_input.shape[0])
noise_pred = self.model.diffusion_transformer(
x=latent_model_input,
x=[latent for latent in latent_model_input],
time_steps=timestep,
condition=context,
condition_mask=context_mask,
)
noise_pred = torch.stack(noise_pred)
noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (
noise_pred_text - noise_pred_uncond
Expand Down
87 changes: 22 additions & 65 deletions apps/Castor/modules/component.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from enum import Enum
import math
from types import SimpleNamespace
from typing import Optional, Union, Tuple

import torch
Expand All @@ -13,6 +14,7 @@
_mask_mod_signature,
create_block_mask,
)
from liger_kernel.transformers import LigerSwiGLUMLP, LigerRMSNorm
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
import warnings
Expand Down Expand Up @@ -139,35 +141,6 @@ def precompute_2d_freqs_cls(
return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
"""
Reshape frequency tensor for broadcasting it with another tensor.

This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
for the purpose of broadcasting the frequency tensor during element-wise operations.

Args:
freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
x (torch.Tensor): Target tensor for broadcasting compatibility.
seq_dim (int): Sequence dimension index.

Returns:
torch.Tensor: Reshaped frequency tensor.
"""
ndim = x.ndim
assert 0 <= seq_dim < ndim
assert freqs_cis.shape == (
x.shape[seq_dim],
x.shape[-3],
2,
2,
), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
shape = [
d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])
] + [2, 2]
return freqs_cis.view(*shape)


def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
Expand All @@ -176,9 +149,8 @@ def apply_rotary_emb(
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
freqs_cis = reshape_for_broadcast(
freqs_cis, xq_, seq_dim
).float() # S D/2 2 2 -> 1 S 1 D/2 2 2
# B S D/2 2 2 -> B S 1 D/2 2 2
freqs_cis = freqs_cis.unsqueeze(seq_dim)
xq_out = (xq_ * freqs_cis).sum(5).flatten(3)
xk_out = (xk_ * freqs_cis).sum(5).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
Expand Down Expand Up @@ -332,18 +304,15 @@ class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x: torch.Tensor):
return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
# casting gemma: where everything is cast to fp32, then computed, then cast back to the original dtype.
# casting llama: where only the inverse RMS is computed on fp32.
self.rms_norm = LigerRMSNorm(dim, init_fn="ones", eps=self.eps, casting_mode="llama")

def forward(self, x: torch.Tensor):
x = probe.log_stats(x, "resid")
output = self._norm(x.float())
return (output * self.weight.float()).type_as(x)
return self.rms_norm(x)

def reset_parameters(self):
torch.nn.init.ones_(self.weight) # type: ignore
torch.nn.init.ones_(self.rms_norm.weight)


class Attention(nn.Module):
Expand Down Expand Up @@ -400,7 +369,7 @@ def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
freq_cis: torch.Tensor,
freqs_cis: torch.Tensor,
tok_idx: Optional[torch.Tensor] = None,
mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
attn_impl: str = "sdpa",
Expand All @@ -420,7 +389,7 @@ def forward(
xq = self.q_norm(xq)
xk = self.k_norm(xk)

xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len])
xq, xk = apply_rotary_emb(xq, xk, 2, freqs_cis[:, 0:seq_len])

# This condition helps us be easily compatible
# with inference by adding a pluggable KVCache
Expand Down Expand Up @@ -637,7 +606,7 @@ def forward(
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq = self.q_norm(xq)
xk = self.k_norm(xk)
xq, xk = apply_rotary_emb(xq, xk, 1, freqs_cis[0:seqlen])
xq, xk = apply_rotary_emb(xq, xk, 2, freqs_cis[:, 0:seqlen])
xq, xk = xq.to(dtype), xk.to(dtype)

softmax_scale = math.sqrt(1 / self.head_dim)
Expand Down Expand Up @@ -713,34 +682,22 @@ def __init__(
self.dim = dim
self.hidden_dim = hidden_dim

self.w1 = nn.Linear(
dim,
hidden_dim,
bias=False,
)
self.w3 = nn.Linear(
dim,
hidden_dim,
bias=False,
)
self.w2 = nn.Linear(
hidden_dim,
dim,
bias=False,
config = SimpleNamespace(
hidden_size=dim,
intermediate_size=hidden_dim,
hidden_act="silu",
)
self.swiglu = LigerSwiGLUMLP(config)

def forward(self, x: torch.Tensor) -> torch.Tensor:
# B S D
x1 = self.w1(x.view_as(x))
x3 = self.w3(x.view_as(x))
output = self.w2(F.silu(x1) * x3)
return output
return self.swiglu(x)

def reset_parameters(self, init_std=None, factor=1.0):
in_init_std = init_std or (self.dim ** (-0.5))
out_init_std = init_std or (self.hidden_dim ** (-0.5))
out_init_std = out_init_std / factor
for w in [self.w1, self.w3]:
for w in [self.swiglu.gate_proj, self.swiglu.up_proj]:
nn.init.trunc_normal_(
w.weight,
mean=0.0,
Expand All @@ -749,7 +706,7 @@ def reset_parameters(self, init_std=None, factor=1.0):
b=3 * in_init_std,
)
nn.init.trunc_normal_(
self.w2.weight,
self.swiglu.down_proj.weight,
mean=0.0,
std=out_init_std,
a=-3 * out_init_std,
Expand Down Expand Up @@ -791,15 +748,15 @@ def __init__(self, args: BaseTransformerArgs):
def forward(
self,
x: torch.Tensor,
freq_cis: torch.Tensor,
freqs_cis: torch.Tensor,
tok_idx: Optional[torch.Tensor] = None,
mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
attn_impl: str = "sdpa",
) -> torch.Tensor:

h = x + self.attention(
self.attention_norm(x),
freq_cis,
freqs_cis,
tok_idx=tok_idx,
mask=mask,
attn_impl=attn_impl,
Expand Down
4 changes: 4 additions & 0 deletions apps/Castor/modules/text_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def __call__(self, batch: dict[str:any]) -> Tuple[torch.Tensor, torch.Tensor]:
class Qwen2_5_VL(BaseTextEncoder):
def __init__(self, args: TextEncoderArgs):
super().__init__(args)

from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl
apply_liger_kernel_to_qwen2_5_vl()

self.model = AutoModel.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct",
torch_dtype=self.dtype,
Expand Down
15 changes: 7 additions & 8 deletions apps/Castor/modules/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,15 +242,13 @@ def patchify_and_embed_image(
use_dynamic_res = isinstance(x, list)
if use_dynamic_res:
cond_l = condition_mask.sum(dim=1, dtype=torch.int32).tolist()
max_cond_l = max(cond_l)
bsz = len(x)
H_list = [x[i].size(1) for i in range(bsz)]
W_list = [x[i].size(2) for i in range(bsz)]
H_max = max(H_list)
W_max = max(W_list)
max_seq_len = max_cond_l + (H_max // pH) * (W_max // pW)
max_seq_len = max([cond_l[i] + (H_list[i] // pH) * (W_list[i] // pW) for i in range(bsz)])
x_new = torch.zeros(bsz, max_seq_len, self.dim, dtype=x[0].dtype).to(x[0].device)
x_mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool).to(x[0].device)
freqs_cis = torch.zeros((bsz, max_seq_len,) + (self.rope_embeddings_conditions.freqs_cis.shape[-3:]), dtype=x[0].dtype).to(x[0].device)
for i in range(bsz):
_x = x[i]
C, H, W = x[i].size()
Expand All @@ -263,10 +261,11 @@ def patchify_and_embed_image(
x_new[i, :cond_l[i]] = condition[i, :cond_l[i]] # TODO: assumes condition is right padded!
x_new[i, cond_l[i]:cond_l[i] + (H // pH) * (W // pW)] = _x
x_mask[i, :cond_l[i] + (H // pH) * (W // pW)] = True
# rope embeddings
freqs_cis_cond = self.rope_embeddings_conditions.freqs_cis[:max_cond_l].to(x[0].device)
freqs_cis_img = self.rope_embeddings_image.freqs_cis[: H_max // pH, : W_max // pW].flatten(0, 1)
freqs_cis = torch.cat([freqs_cis_cond, freqs_cis_img], dim=0)

# rope embeddings
freqs_cis[i, :cond_l[i]] = self.rope_embeddings_conditions.freqs_cis[:cond_l[i]].to(x[0].device)
freqs_cis[i, cond_l[i]:cond_l[i] + (H // pH) * (W // pW)] = self.rope_embeddings_image.freqs_cis[: H // pH, : W // pW].flatten(0, 1).to(x[0].device)

return x_new, x_mask, cond_l, (H_list, W_list), freqs_cis
else:
B, C, H, W = x.size()
Expand Down
2 changes: 1 addition & 1 deletion apps/main/utils/mongodb_data_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def set_local_partition(self):
if partition_key % self.num_shards == self.shard_idx:
data.append(item)
# Note: used for debugging
# if len(data) > 10000:
# if len(data) > 2000000:
# break
self.data = pd.DataFrame(data).reset_index()
end_time = time.time() # Record the end time
Expand Down