From dd8cc72b4823f47987b65f80971d0fce859dbb85 Mon Sep 17 00:00:00 2001 From: sippycoder Date: Wed, 9 Apr 2025 09:12:22 +0000 Subject: [PATCH 1/2] changes as of 090425 --- apps/Castor/configs/eval.yaml | 10 +- ...ucket_256_Castor_hunyuan_qwen_dynamic.yaml | 117 ++++++++++++++ apps/Castor/generate.py | 3 +- apps/Castor/modules/component.py | 151 ++++++++++-------- apps/Castor/modules/text_encoder.py | 4 + apps/Castor/modules/transformer.py | 15 +- apps/main/utils/mongodb_data_load.py | 2 +- 7 files changed, 222 insertions(+), 80 deletions(-) create mode 100644 apps/Castor/configs/train_bucket_256_Castor_hunyuan_qwen_dynamic.yaml diff --git a/apps/Castor/configs/eval.yaml b/apps/Castor/configs/eval.yaml index 246d43d6..2ad0cbe4 100644 --- a/apps/Castor/configs/eval.yaml +++ b/apps/Castor/configs/eval.yaml @@ -1,18 +1,18 @@ # python -m apps.Castor.generate config=apps/Castor/configs/eval.yaml name: "debug_evals" stage: eval -ckpt_dir: "/mnt/pollux/checkpoints/chandan/qwen2_5_vl_flux_dynamic/checkpoints/0000005000/" -dump_dir: "/mnt/pollux/checkpoints/chandan/qwen2_5_vl_flux_dynamic/MJHQ/0000015000" +ckpt_dir: "/mnt/pollux/checkpoints/chandan/qwen2_5_vl_hunyuan_dynamic/checkpoints/0000007500" +dump_dir: "/mnt/pollux/checkpoints/chandan/qwen2_5_vl_hunyuan_dynamic/MJHQ/0000007500" generator: guidance_scale: 6.5 dtype: bf16 - resolution: 256 + resolution: 512 show_progress: False inference_steps: 50 vae_scale_factor: 8.0 tvae: - vae_type: flux - pretrained_model_name_or_path: '/mnt/pollux/checkpoints/FLUX.1-dev/vae' + vae_type: hunyuan + pretrained_model_name_or_path: '/mnt/pollux/checkpoints/HunyuanVideo/vae' enable_tiling: false enable_slicing: false data: diff --git a/apps/Castor/configs/train_bucket_256_Castor_hunyuan_qwen_dynamic.yaml b/apps/Castor/configs/train_bucket_256_Castor_hunyuan_qwen_dynamic.yaml new file mode 100644 index 00000000..bc28fbc8 --- /dev/null +++ b/apps/Castor/configs/train_bucket_256_Castor_hunyuan_qwen_dynamic.yaml @@ -0,0 +1,117 @@ +# torchrun --standalone --nnodes 1 --nproc-per-node 8 -m apps.Castor.train config=apps/Castor/configs/train_bucket_256_Castor_image_cache_json.yaml + +# Set up single experiment +version: v1.0 +# From now on, start to align train, data, and model setting with train stage (just finish refactor for dara) +train_stage: stage-1 # options: preliminary, pretraining, posttraining; aligned with data setting +name: qwen2_5_vl_hunyuan_dynamic #used for local dump and wandb log +output_dir: /mnt/pollux/checkpoints/chandan +dump_dir: '' # No need now +steps: 500000 +seed: 777 +optim: + lr: 1e-4 + warmup: 4000 + lr_min_ratio: 1.5e-5 + clip: 1.0 + weight_decay: 0.01 + +distributed: + gpus: 0,1,2,3,4,5,6,7 + fsdp_type: full_shard + dp_shard: 8 + dp_replicate: 1 + compile: false + model_dtype: bf16 # options: `fb8` is only supported by H100 + matmul_allow_tf32: false + selective_activation_checkpointing: true + tp_size: 1 + compile_cache_size_limit: 64 + +model: + scheduler: + num_train_timesteps: 1000 + base_image_seq_len: 256 + base_shift: 0.5 + max_image_seq_len: 4096 + max_shift: 1.15 + shift: 1.0 # need consider 3.0 or 1.0 + weighting_scheme: 'logit_normal' + logit_mean: 0.0 + logit_std: 1.0 + mode_scale: 1.29 + use_dynamic_shifting: true + diffusion_model: + dim: 2048 + ffn_dim_multiplier: 1.5 + multiple_of: 256 + n_heads: 32 + n_kv_heads: 8 + n_layers: 24 + time_step_dim: 2048 + patch_size: 2 + in_channels: 16 + out_channels: 16 + tmb_size: 256 + gen_seqlen: 32 + condition_seqlen: 256 + norm_eps: 1e-5 + condition_dim: 2048 + qk_norm: false + text_cfg_ratio: 0.1 + with_vae: true + vae_args: + vae_type: "hunyuan" + pretrained_model_name_or_path: '/mnt/pollux/checkpoints/HunyuanVideo/vae' + enable_tiling: false + enable_slicing: false + text_encoder: + config_name: "Qwen/Qwen2.5-VL-3B-Instruct" + dtype: "bf16" + text_seqlen: 256 + +data: + - stage: stage-1 + id: 1 + data_name: bucket-256-2 + task: text_to_image + source: mongodb + image_size: 256 + condition_image_size: 256 + max_ratio: 2.0 + partition_key: 'partition_key' + retries: 3 + extract_field: + "media_path": "image" + use: true + root_dir: "/mnt/pollux/mongo_db_cache_train" + dataloader: + prefetch_factor: 2 + batch_size: 48 + num_workers: 8 + seed: 1024 + shuffle: True + pin_memory: True + drop_last: False + +profiling: + run: false + +checkpoint: + dump: + every: 2500 + keep: 0 # Don't remove the ckpt + eval: + every: 5000 + keep: 0 # Don't remove the ckpt + +logging: + freq: 100 + wandb: + project: Pollux + entity: metauto + name: '' + +env: + ENABLE_INTRA_NODE_COMM: '0' # '0' for local machine (otherwise errors happen); '1' for slurmn (need test) + NCCL_DEBUG: 'ERROR' \ No newline at end of file diff --git a/apps/Castor/generate.py b/apps/Castor/generate.py index a8241c39..dbd07fe6 100644 --- a/apps/Castor/generate.py +++ b/apps/Castor/generate.py @@ -115,11 +115,12 @@ def forward(self, context: Dict[str, Any]) -> torch.Tensor: latent_model_input = torch.cat([latent] * 2) timestep = t.expand(latent_model_input.shape[0]) noise_pred = self.model.diffusion_transformer( - x=latent_model_input, + x=[latent for latent in latent_model_input], time_steps=timestep, condition=context, condition_mask=context_mask, ) + noise_pred = torch.stack(noise_pred) noise_pred_text, noise_pred_uncond = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * ( noise_pred_text - noise_pred_uncond diff --git a/apps/Castor/modules/component.py b/apps/Castor/modules/component.py index b5a692db..1d31a7d7 100644 --- a/apps/Castor/modules/component.py +++ b/apps/Castor/modules/component.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from enum import Enum import math +from types import SimpleNamespace from typing import Optional, Union, Tuple import torch @@ -13,6 +14,7 @@ _mask_mod_signature, create_block_mask, ) +from liger_kernel.transformers import LigerSwiGLUMLP, LigerRMSNorm from flash_attn import flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa import warnings @@ -139,35 +141,6 @@ def precompute_2d_freqs_cls( return freqs_cis -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int): - """ - Reshape frequency tensor for broadcasting it with another tensor. - - This function reshapes the frequency tensor to have the same shape as the target tensor 'x' - for the purpose of broadcasting the frequency tensor during element-wise operations. - - Args: - freqs_cis (torch.Tensor): Frequency tensor to be reshaped. - x (torch.Tensor): Target tensor for broadcasting compatibility. - seq_dim (int): Sequence dimension index. - - Returns: - torch.Tensor: Reshaped frequency tensor. - """ - ndim = x.ndim - assert 0 <= seq_dim < ndim - assert freqs_cis.shape == ( - x.shape[seq_dim], - x.shape[-3], - 2, - 2, - ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}" - shape = [ - d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2]) - ] + [2, 2] - return freqs_cis.view(*shape) - - def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, @@ -176,9 +149,8 @@ def apply_rotary_emb( ) -> Tuple[torch.Tensor, torch.Tensor]: xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2 - freqs_cis = reshape_for_broadcast( - freqs_cis, xq_, seq_dim - ).float() # S D/2 2 2 -> 1 S 1 D/2 2 2 + # B S D/2 2 2 -> B S 1 D/2 2 2 + freqs_cis = freqs_cis.unsqueeze(seq_dim) xq_out = (xq_ * freqs_cis).sum(5).flatten(3) xk_out = (xk_ * freqs_cis).sum(5).flatten(3) return xq_out.type_as(xq), xk_out.type_as(xk) @@ -332,18 +304,15 @@ class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x: torch.Tensor): - return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps) + # casting gemma: where everything is cast to fp32, then computed, then cast back to the original dtype. + # casting llama: where only the inverse RMS is computed on fp32. + self.rms_norm = LigerRMSNorm(dim, init_fn="ones", eps=self.eps, casting_mode="llama") def forward(self, x: torch.Tensor): - x = probe.log_stats(x, "resid") - output = self._norm(x.float()) - return (output * self.weight.float()).type_as(x) + return self.rms_norm(x) def reset_parameters(self): - torch.nn.init.ones_(self.weight) # type: ignore + torch.nn.init.ones_(self.rms_norm.weight) class Attention(nn.Module): @@ -400,7 +369,7 @@ def forward( self, x: torch.Tensor, x_mask: torch.Tensor, - freq_cis: torch.Tensor, + freqs_cis: torch.Tensor, tok_idx: Optional[torch.Tensor] = None, mask: Optional[Union[BlockMask, AttentionBias, str]] = None, attn_impl: str = "sdpa", @@ -420,7 +389,7 @@ def forward( xq = self.q_norm(xq) xk = self.k_norm(xk) - xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len]) + xq, xk = apply_rotary_emb(xq, xk, 2, freqs_cis[:, 0:seq_len]) # This condition helps us be easily compatible # with inference by adding a pluggable KVCache @@ -637,7 +606,7 @@ def forward( xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) xq = self.q_norm(xq) xk = self.k_norm(xk) - xq, xk = apply_rotary_emb(xq, xk, 1, freqs_cis[0:seqlen]) + xq, xk = apply_rotary_emb(xq, xk, 2, freqs_cis[:, 0:seqlen]) xq, xk = xq.to(dtype), xk.to(dtype) softmax_scale = math.sqrt(1 / self.head_dim) @@ -693,6 +662,70 @@ def forward( return self.wo(output) +# class FeedForward(nn.Module): +# def __init__( +# self, +# dim: int, +# hidden_dim: int, +# multiple_of: int, +# ffn_dim_multiplier: Optional[float], +# mp_size: int = 1, +# ): +# super().__init__() + +# hidden_dim = int(2 * hidden_dim / 3) +# if ffn_dim_multiplier is not None: +# hidden_dim = int(ffn_dim_multiplier * hidden_dim) +# hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) +# assert hidden_dim % mp_size == 0 + +# self.dim = dim +# self.hidden_dim = hidden_dim + +# self.w1 = nn.Linear( +# dim, +# hidden_dim, +# bias=False, +# ) +# self.w3 = nn.Linear( +# dim, +# hidden_dim, +# bias=False, +# ) +# self.w2 = nn.Linear( +# hidden_dim, +# dim, +# bias=False, +# ) + +# def forward(self, x: torch.Tensor) -> torch.Tensor: +# # B S D +# x1 = self.w1(x.view_as(x)) +# x3 = self.w3(x.view_as(x)) +# output = self.w2(F.silu(x1) * x3) +# return output + +# def reset_parameters(self, init_std=None, factor=1.0): +# in_init_std = init_std or (self.dim ** (-0.5)) +# out_init_std = init_std or (self.hidden_dim ** (-0.5)) +# out_init_std = out_init_std / factor +# for w in [self.w1, self.w3]: +# nn.init.trunc_normal_( +# w.weight, +# mean=0.0, +# std=in_init_std, +# a=-3 * in_init_std, +# b=3 * in_init_std, +# ) +# nn.init.trunc_normal_( +# self.w2.weight, +# mean=0.0, +# std=out_init_std, +# a=-3 * out_init_std, +# b=3 * out_init_std, +# ) + + class FeedForward(nn.Module): def __init__( self, @@ -713,34 +746,22 @@ def __init__( self.dim = dim self.hidden_dim = hidden_dim - self.w1 = nn.Linear( - dim, - hidden_dim, - bias=False, - ) - self.w3 = nn.Linear( - dim, - hidden_dim, - bias=False, - ) - self.w2 = nn.Linear( - hidden_dim, - dim, - bias=False, + config = SimpleNamespace( + hidden_size=dim, + intermediate_size=hidden_dim, + hidden_act="silu", ) + self.swiglu = LigerSwiGLUMLP(config) def forward(self, x: torch.Tensor) -> torch.Tensor: # B S D - x1 = self.w1(x.view_as(x)) - x3 = self.w3(x.view_as(x)) - output = self.w2(F.silu(x1) * x3) - return output + return self.swiglu(x) def reset_parameters(self, init_std=None, factor=1.0): in_init_std = init_std or (self.dim ** (-0.5)) out_init_std = init_std or (self.hidden_dim ** (-0.5)) out_init_std = out_init_std / factor - for w in [self.w1, self.w3]: + for w in [self.swiglu.gate_proj, self.swiglu.up_proj]: nn.init.trunc_normal_( w.weight, mean=0.0, @@ -749,7 +770,7 @@ def reset_parameters(self, init_std=None, factor=1.0): b=3 * in_init_std, ) nn.init.trunc_normal_( - self.w2.weight, + self.swiglu.down_proj.weight, mean=0.0, std=out_init_std, a=-3 * out_init_std, @@ -791,7 +812,7 @@ def __init__(self, args: BaseTransformerArgs): def forward( self, x: torch.Tensor, - freq_cis: torch.Tensor, + freqs_cis: torch.Tensor, tok_idx: Optional[torch.Tensor] = None, mask: Optional[Union[BlockMask, AttentionBias, str]] = None, attn_impl: str = "sdpa", @@ -799,7 +820,7 @@ def forward( h = x + self.attention( self.attention_norm(x), - freq_cis, + freqs_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl, diff --git a/apps/Castor/modules/text_encoder.py b/apps/Castor/modules/text_encoder.py index a3617014..601a1d2e 100644 --- a/apps/Castor/modules/text_encoder.py +++ b/apps/Castor/modules/text_encoder.py @@ -67,6 +67,10 @@ def __call__(self, batch: dict[str:any]) -> Tuple[torch.Tensor, torch.Tensor]: class Qwen2_5_VL(BaseTextEncoder): def __init__(self, args: TextEncoderArgs): super().__init__(args) + + from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl + apply_liger_kernel_to_qwen2_5_vl() + self.model = AutoModel.from_pretrained( "Qwen/Qwen2.5-VL-3B-Instruct", torch_dtype=self.dtype, diff --git a/apps/Castor/modules/transformer.py b/apps/Castor/modules/transformer.py index 3b0f1238..5e14d57f 100644 --- a/apps/Castor/modules/transformer.py +++ b/apps/Castor/modules/transformer.py @@ -242,15 +242,13 @@ def patchify_and_embed_image( use_dynamic_res = isinstance(x, list) if use_dynamic_res: cond_l = condition_mask.sum(dim=1, dtype=torch.int32).tolist() - max_cond_l = max(cond_l) bsz = len(x) H_list = [x[i].size(1) for i in range(bsz)] W_list = [x[i].size(2) for i in range(bsz)] - H_max = max(H_list) - W_max = max(W_list) - max_seq_len = max_cond_l + (H_max // pH) * (W_max // pW) + max_seq_len = max([cond_l[i] + (H_list[i] // pH) * (W_list[i] // pW) for i in range(bsz)]) x_new = torch.zeros(bsz, max_seq_len, self.dim, dtype=x[0].dtype).to(x[0].device) x_mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool).to(x[0].device) + freqs_cis = torch.zeros((bsz, max_seq_len,) + (self.rope_embeddings_conditions.freqs_cis.shape[-3:]), dtype=x[0].dtype).to(x[0].device) for i in range(bsz): _x = x[i] C, H, W = x[i].size() @@ -263,10 +261,11 @@ def patchify_and_embed_image( x_new[i, :cond_l[i]] = condition[i, :cond_l[i]] # TODO: assumes condition is right padded! x_new[i, cond_l[i]:cond_l[i] + (H // pH) * (W // pW)] = _x x_mask[i, :cond_l[i] + (H // pH) * (W // pW)] = True - # rope embeddings - freqs_cis_cond = self.rope_embeddings_conditions.freqs_cis[:max_cond_l].to(x[0].device) - freqs_cis_img = self.rope_embeddings_image.freqs_cis[: H_max // pH, : W_max // pW].flatten(0, 1) - freqs_cis = torch.cat([freqs_cis_cond, freqs_cis_img], dim=0) + + # rope embeddings + freqs_cis[i, :cond_l[i]] = self.rope_embeddings_conditions.freqs_cis[:cond_l[i]].to(x[0].device) + freqs_cis[i, cond_l[i]:cond_l[i] + (H // pH) * (W // pW)] = self.rope_embeddings_image.freqs_cis[: H // pH, : W // pW].flatten(0, 1).to(x[0].device) + return x_new, x_mask, cond_l, (H_list, W_list), freqs_cis else: B, C, H, W = x.size() diff --git a/apps/main/utils/mongodb_data_load.py b/apps/main/utils/mongodb_data_load.py index 3eb3b88d..aaa872c7 100644 --- a/apps/main/utils/mongodb_data_load.py +++ b/apps/main/utils/mongodb_data_load.py @@ -128,7 +128,7 @@ def set_local_partition(self): if partition_key % self.num_shards == self.shard_idx: data.append(item) # Note: used for debugging - # if len(data) > 10000: + # if len(data) > 2000000: # break self.data = pd.DataFrame(data).reset_index() end_time = time.time() # Record the end time From 3f92ef67ed4ee402c0905cfb145df455cf58e149 Mon Sep 17 00:00:00 2001 From: sippycoder Date: Thu, 10 Apr 2025 08:19:45 +0000 Subject: [PATCH 2/2] clean up --- apps/Castor/modules/component.py | 64 -------------------------------- 1 file changed, 64 deletions(-) diff --git a/apps/Castor/modules/component.py b/apps/Castor/modules/component.py index 1d31a7d7..bafc495a 100644 --- a/apps/Castor/modules/component.py +++ b/apps/Castor/modules/component.py @@ -662,70 +662,6 @@ def forward( return self.wo(output) -# class FeedForward(nn.Module): -# def __init__( -# self, -# dim: int, -# hidden_dim: int, -# multiple_of: int, -# ffn_dim_multiplier: Optional[float], -# mp_size: int = 1, -# ): -# super().__init__() - -# hidden_dim = int(2 * hidden_dim / 3) -# if ffn_dim_multiplier is not None: -# hidden_dim = int(ffn_dim_multiplier * hidden_dim) -# hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) -# assert hidden_dim % mp_size == 0 - -# self.dim = dim -# self.hidden_dim = hidden_dim - -# self.w1 = nn.Linear( -# dim, -# hidden_dim, -# bias=False, -# ) -# self.w3 = nn.Linear( -# dim, -# hidden_dim, -# bias=False, -# ) -# self.w2 = nn.Linear( -# hidden_dim, -# dim, -# bias=False, -# ) - -# def forward(self, x: torch.Tensor) -> torch.Tensor: -# # B S D -# x1 = self.w1(x.view_as(x)) -# x3 = self.w3(x.view_as(x)) -# output = self.w2(F.silu(x1) * x3) -# return output - -# def reset_parameters(self, init_std=None, factor=1.0): -# in_init_std = init_std or (self.dim ** (-0.5)) -# out_init_std = init_std or (self.hidden_dim ** (-0.5)) -# out_init_std = out_init_std / factor -# for w in [self.w1, self.w3]: -# nn.init.trunc_normal_( -# w.weight, -# mean=0.0, -# std=in_init_std, -# a=-3 * in_init_std, -# b=3 * in_init_std, -# ) -# nn.init.trunc_normal_( -# self.w2.weight, -# mean=0.0, -# std=out_init_std, -# a=-3 * out_init_std, -# b=3 * out_init_std, -# ) - - class FeedForward(nn.Module): def __init__( self,