From 382986d7a255bcd4e75bef5d1f34bde882313bb1 Mon Sep 17 00:00:00 2001
From: Murali Nandan
Date: Sun, 27 Apr 2025 19:22:23 +0000
Subject: [PATCH 1/6] Applying Liger kernels to Qwen VL

---
 apps/Castor/modules/text_encoder.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/apps/Castor/modules/text_encoder.py b/apps/Castor/modules/text_encoder.py
index a952c9a..155baaa 100644
--- a/apps/Castor/modules/text_encoder.py
+++ b/apps/Castor/modules/text_encoder.py
@@ -16,6 +16,7 @@
     GemmaTokenizerFast,
     UMT5EncoderModel,
 )
+from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl

 logger = logging.getLogger()
 from typing import Optional
@@ -93,6 +94,7 @@ def __init__(self, args: TextEncoderArgs):
             "Qwen/Qwen2.5-VL-3B-Instruct" if args.model_path == "" else args.model_path,
             torch_dtype=self.dtype,
         ).cuda()
+        apply_liger_kernel_to_qwen2_5_vl(self.model)
         if args.layers_to_use is not None:
             self.model.layers = self.model.layers[: args.layers_to_use]
         self.model.eval()

From 84d2693f8aa5071f0065511e7c52258685d0d46d Mon Sep 17 00:00:00 2001
From: Murali Nandan
Date: Sun, 27 Apr 2025 19:58:34 +0000
Subject: [PATCH 2/6] Added torch.compile regions over VAE

---
 apps/Castor/modules/vae.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/apps/Castor/modules/vae.py b/apps/Castor/modules/vae.py
index 0410a20..33fde08 100644
--- a/apps/Castor/modules/vae.py
+++ b/apps/Castor/modules/vae.py
@@ -191,6 +191,7 @@ def __init__(self, args: VideoVAEArgs):
         self.scale = 0.3611
         self.shift = 0.1159

+    @torch.compile
     @torch.no_grad()
     def encode(self, x: torch.Tensor) -> torch.Tensor:
         x = x.to(device=self.vae.device, dtype=self.vae.dtype)
         x = (
@@ -199,6 +200,7 @@ def encode(self, x: torch.Tensor) -> torch.Tensor:
         ) * self.scale
         return x

+    @torch.compile
     @torch.no_grad()
     def decode(self, x: torch.Tensor) -> torch.Tensor:
         x = x.to(device=self.vae.device, dtype=self.vae.dtype)
@@ -207,7 +209,8 @@ def decode(self, x: torch.Tensor) -> torch.Tensor:
         # Use the VAE's decode method and get the sample
         decoded = self.vae.decode(x).sample
         return decoded
-
+
+    @torch.compile
     @torch.no_grad()
     def forward(self, x=torch.Tensor):
         x = self.encode(x)

From 7b612de321d59ce5fe412fd76783a9bf8af371a2 Mon Sep 17 00:00:00 2001
From: Murali Nandan
Date: Sun, 27 Apr 2025 20:11:53 +0000
Subject: [PATCH 3/6] Added torch.compile region over Qwen LM

---
 apps/Castor/modules/text_encoder.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/apps/Castor/modules/text_encoder.py b/apps/Castor/modules/text_encoder.py
index 155baaa..96f75e2 100644
--- a/apps/Castor/modules/text_encoder.py
+++ b/apps/Castor/modules/text_encoder.py
@@ -97,6 +97,7 @@ def __init__(self, args: TextEncoderArgs):
         apply_liger_kernel_to_qwen2_5_vl(self.model)
         if args.layers_to_use is not None:
             self.model.layers = self.model.layers[: args.layers_to_use]
+        self.model = torch.compile(self.model)
         self.model.eval()
         self.model.requires_grad_(False)
         self.tokenizer = AutoTokenizer.from_pretrained(

From fd5de9444e699fe3fff51b045c90cce2794d2828 Mon Sep 17 00:00:00 2001
From: Murali Nandan
Date: Mon, 28 Apr 2025 07:42:15 +0000
Subject: [PATCH 4/6] Adds proper padding of text condition

---
 apps/Castor/modules/text_encoder.py |   3 +
 apps/Castor/modules/transformer.py  | 119 +++++++++++++++-------
 requirements.txt                    |   4 +-
 3 files changed, 71 insertions(+), 55 deletions(-)

diff --git a/apps/Castor/modules/text_encoder.py b/apps/Castor/modules/text_encoder.py
index 96f75e2..c8f1e12 100644
--- a/apps/Castor/modules/text_encoder.py
+++ b/apps/Castor/modules/text_encoder.py
@@ -130,6 +130,9 @@ def _convert_caption_to_messages(self, caption: str) -> str:
         )

     def __call__(self, batch: dict[str:any]) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Returns last_hidden_state and attention_mask, right padded.
+        """
         assert "caption" in batch
         if isinstance(batch["caption"][0], tuple):
             batch["caption"] = [x[0] for x in batch["caption"]]
diff --git a/apps/Castor/modules/transformer.py b/apps/Castor/modules/transformer.py
index 244557e..21818f6 100644
--- a/apps/Castor/modules/transformer.py
+++ b/apps/Castor/modules/transformer.py
@@ -19,6 +19,8 @@

 logger = logging.getLogger()

+def nearest_multiple_of_8(x):
+    return ((x + 7) // 8) * 8

 @dataclass
 class TransformerArgs(BaseTransformerArgs):
@@ -289,6 +291,7 @@ def patchify_and_embed_image(
             max_seq_len = max(
                 [cond_l[i] + (H_list[i] // pH) * (W_list[i] // pW) for i in range(bsz)]
             )
+            max_seq_len = nearest_multiple_of_8(max_seq_len)
             x_new = torch.zeros(bsz, max_seq_len, self.dim, dtype=x[0].dtype).to(
                 x[0].device
             )
@@ -331,10 +334,16 @@ def patchify_and_embed_image(
             return x_new, x_mask, cond_l, (H_list, W_list), freqs_cis
         else:
             B, C, H, W = x.size()
-            target_dtype = x.dtype
-            cond_l = condition.size(1)
+            condition = condition.to(x.dtype)
+            assert H % pH == 0, f"H ({H}) should be divisible by pH ({pH})"
+            assert W % pW == 0, f"W ({W}) should be divisible by pW ({pW})"

-            # 1. Patchify and embed image tensor
+            # 1. Get actual condition lengths from mask
+            cond_l = condition_mask.sum(dim=1, dtype=torch.int32)  # Shape: [B]
+            max_cond_l = cond_l.max(dim=0)[0].item()
+
+            # 2. Patchify and embed image tensor
             x_img_patch = (
                 x.view(B, C, H // pH, pH, W // pW, pW)
                 .permute(0, 2, 4, 3, 5, 1)
@@ -342,36 +351,44 @@
             )
             x_img_embed = self.img_embed(x_img_patch)  # Assume img_embed preserves or outputs target_dtype
             x_img_embed = x_img_embed.flatten(1, 2)  # Shape: [B, L_img, D]
+            L_img = x_img_embed.size(1)

-            # Ensure image embedding has the target dtype
-            x_img_embed = x_img_embed.to(target_dtype)
-
-            # 2. Prepare condition tensor
-            # Cast condition to the target dtype *before* concatenation
-            condition = condition.to(target_dtype)  # Shape: [B, L_cond, D]
+            # 3. Initialize combined tensors
+            max_seq_len = nearest_multiple_of_8(max_cond_l + L_img)
+            x_combined = torch.zeros(B, max_seq_len, self.dim, dtype=x.dtype, device=x.device)
+            x_mask = torch.zeros(B, max_seq_len, dtype=torch.bool, device=x.device)
+            freqs_cis = torch.zeros(
+                (
+                    B,
+                    max_seq_len,
+                )
+                + (self.rope_embeddings_conditions.freqs_cis.shape[-3:]),
+                dtype=x.dtype,
+            ).to(x.device)

-            # 3. Concatenate condition and image embeddings
-            x_combined = torch.cat([condition, x_img_embed], dim=1)  # Shape: [B, L_cond + L_img, D]
+            # 4. Precompute RoPE embeddings
+            freqs_cis_cond_all = self.rope_embeddings_conditions.freqs_cis[:max_cond_l].to(device=x.device, dtype=x.dtype)
+            freqs_cis_img_all = self.rope_embeddings_image.freqs_cis[: H // pH, : W // pW].flatten(0, 1).to(device=x.device, dtype=x.dtype)  # Shape: [L_img, ..., D/2, 2]

-            # 4. Create attention mask
-            x_mask_img = torch.ones(B, (H // pH) * (W // pW), dtype=torch.bool, device=x.device)
-            x_mask = torch.cat([condition_mask, x_mask_img], dim=1)  # Shape: [B, L_cond + L_img]
+            # 5. Populate tensors respecting actual condition lengths
+            for i in range(B):
+                # Place condition and image embeddings
+                x_combined[i, :cond_l[i]] = condition[i, :cond_l[i]]
+                x_combined[i, cond_l[i] : cond_l[i] + L_img] = x_img_embed[i]

-            # 5. Prepare RoPE embeddings
-            # Ensure RoPE embeddings are on the correct device and have the target dtype
-            freqs_cis_cond = self.rope_embeddings_conditions.freqs_cis[:cond_l].to(device=x.device, dtype=target_dtype)
-            freqs_cis_img = self.rope_embeddings_image.freqs_cis[: H // pH, : W // pW].flatten(0, 1).to(device=x.device, dtype=target_dtype)
+                # Create mask
+                x_mask[i, : cond_l[i] + L_img] = True

-            # Concatenate RoPE embeddings
-            freqs_cis = torch.cat([freqs_cis_cond, freqs_cis_img], dim=0)  # Shape: [L_cond + L_img, ..., D/2, 2]
-            # Expand for batch dimension
-            freqs_cis = freqs_cis.unsqueeze(0).expand(B, -1, *([-1] * (freqs_cis.dim() - 1)))  # Shape: [B, L_cond + L_img, ..., D/2, 2]
+                # Create RoPE embeddings
+                freqs_cis[i, :cond_l[i]] = freqs_cis_cond_all[:cond_l[i]]
+                freqs_cis[i, cond_l[i] : cond_l[i] + L_img] = freqs_cis_img_all  # Image RoPE is the same for all in batch here

+            # Return list of actual lengths for consistency
             return (
                 x_combined,
                 x_mask,
-                cond_l,
-                (H, W),
+                cond_l.tolist(),
+                (H, W),  # Return single H, W tuple
                 freqs_cis,
             )

@@ -417,7 +434,7 @@ def forward(
         return output

     def unpatchify_image(
-        self, x: torch.Tensor, cond_l: Union[List[int], int], img_size: Tuple[int, int]
+        self, x: torch.Tensor, cond_l: List[int], img_size: Union[Tuple[int, int], Tuple[List[int], List[int]]]
     ) -> Union[torch.Tensor, List[torch.Tensor]]:
         """
         Convert patched image features back to the original latent format.
@@ -444,36 +461,30 @@ def unpatchify_image(
                 out_x_list.append(_x)
             return out_x_list
         else:
-            # Handle the case where cond_l is an integer, not a list
-            if isinstance(cond_l, list):
-                # If cond_l is unexpectedly a list here, take the first element or max?
-                # Assuming it should be consistent with the input type logic.
-                # If input was Tensor, cond_l should be int. If list, list.
-                # This path assumes input was Tensor, so cond_l should be int.
-                # If it's a list, maybe log a warning or error.
-                # For now, assume it's an int if we reach here.
-                if len(cond_l) > 0:
-                    max_cond_l = cond_l[0]  # Or max(cond_l) if that makes sense
-                    # logger.warning("cond_l was a list in unpatchify_image tensor path.")
-                else:
-                    max_cond_l = 0  # Handle empty list case
-            else:
-                max_cond_l = cond_l  # It's already an int
-
+            # Handle non-dynamic (tensor) case
+            H, W = img_size
+            L_img = (H // pH) * (W // pW)
             B = x.size(0)
-            L = (H // pH) * (W // pW)
-            # Ensure slicing indices are correct
-            img_features = x[:, max_cond_l : max_cond_l + L]
-
-            # Ensure the view dimensions match the extracted features
-            # B, L_img, D_out_patch = img_features.shape
-            # D_out_patch should be pH * pW * self.out_channels
-            # L_img should be (H // pH) * (W // pW)
-            img_features = img_features.view(
-                B, H // pH, W // pW, pH, pW, self.out_channels
-            )
-            img_features = img_features.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
-            return img_features
+
+            # Initialize output tensor
+            # Note: Need to ensure self.out_channels is correctly defined/accessed
+            out_features = torch.zeros(B, self.out_channels, H, W, dtype=x.dtype, device=x.device)
+
+            for i in range(B):
+                l_cond = cond_l[i]
+
+                # Extract the image part for this batch item
+                img_features_i = x[i, l_cond : l_cond + L_img]  # Shape [L_img, D_out_patch]
+
+                # Reshape back to image format
+                # D_out_patch = pH * pW * self.out_channels must hold
+                img_features_i = img_features_i.view(
+                    H // pH, W // pW, pH, pW, self.out_channels
+                )
+                img_features_i = img_features_i.permute(4, 0, 2, 1, 3).flatten(3, 4).flatten(1, 2)  # [C_out, H, W]
+                out_features[i] = img_features_i
+
+            return out_features

     def reset_parameters(self, init_std=None):
         # Either use fixed base std or sqrt model dim
diff --git a/requirements.txt b/requirements.txt
index eb4007b..43f58ab 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -24,4 +24,6 @@ boto3
 python-dotenv
 ijson
 flash-attn
-transformers
\ No newline at end of file
+transformers
+liger-kernel
+timm
\ No newline at end of file

From f1382beefc43a168cbc9353bd2fc92df9e204a80 Mon Sep 17 00:00:00 2001
From: Murali Nandan
Date: Mon, 28 Apr 2025 23:28:19 +0000
Subject: [PATCH 5/6] Added Qwen 7B

---
 .../train_bucket_256_Castor_flux_qwen_dynamic.yaml | 10 +++++-----
 apps/Castor/modules/text_encoder.py                |  2 +-
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_dynamic.yaml b/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_dynamic.yaml
index cee5fec..bbe7ca0 100644
--- a/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_dynamic.yaml
+++ b/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_dynamic.yaml
@@ -70,12 +70,12 @@ model:
     enable_tiling: false
     enable_slicing: false
   text_encoder:
-    config_name: "Qwen/Qwen2.5-VL-3B-Instruct"
+    config_name: "Qwen/Qwen2.5-VL-7B-Instruct"
     dtype: "bf16"
     text_seqlen: 256
-    model_path: "/mnt/pollux/checkpoints/Qwen2.5-VL-3B-Instruct"
-    # qwen has 36 layers, we can use the first 3/4 layers
-    layers_to_use: 30
+    model_path: "/mnt/pollux/checkpoints/Qwen2.5-VL-7B-Instruct"
+    # Qwen 3B has 36 layers and 7B has 28; we can use the first ~3/4 of them
+    layers_to_use: 21
 data:
   - stage: stage-1
     id: 1
@@ -84,7 +84,7 @@ data:
     source: mongodb
     image_size: 256
     condition_image_size: 256
-    max_ratio: 2.0
+    max_ratio: 1.0
     partition_key: 'partition_key'
     retries: 3
     extract_field:
diff --git a/apps/Castor/modules/text_encoder.py b/apps/Castor/modules/text_encoder.py
index c8f1e12..2d56f66 100644
--- a/apps/Castor/modules/text_encoder.py
+++ b/apps/Castor/modules/text_encoder.py
@@ -241,7 +241,7 @@ def __call__(self, batch: dict[str:any]) -> Tuple[torch.Tensor, torch.Tensor]:
 def create_text_encoder(args: TextEncoderArgs) -> BaseTextEncoder:
     if args.config_name == "ViT-B/32":
         return CLIP(args)
-    elif args.config_name == "Qwen/Qwen2.5-VL-3B-Instruct":
+    elif args.config_name == "Qwen/Qwen2.5-VL-3B-Instruct" or args.config_name == "Qwen/Qwen2.5-VL-7B-Instruct":
         return Qwen2_5_VL(args)
     elif args.config_name == "Gemma2_2B_it":
         return Gemma2_2B_it(args)

From d4b7bd54cd12e68c603d9322b40212edf97f3cb7 Mon Sep 17 00:00:00 2001
From: Murali Nandan
Date: Tue, 29 Apr 2025 01:57:50 +0000
Subject: [PATCH 6/6] Qwen 7B changes

---
 .../configs/train_bucket_256_Castor_flux_qwen_dynamic.yaml | 2 +-
 apps/main/train.py                                         | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_dynamic.yaml b/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_dynamic.yaml
index bbe7ca0..304c555 100644
--- a/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_dynamic.yaml
+++ b/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_dynamic.yaml
@@ -56,7 +56,7 @@ model:
     gen_seqlen: 32
     condition_seqlen: 256
     norm_eps: 1e-5
-    condition_dim: 2048
+    condition_dim: 3584  # 2048 for 3B Qwen, 3584 for 7B Qwen
     qk_norm: false
     liger_rms_norm: true
     liger_ffn: true
diff --git a/apps/main/train.py b/apps/main/train.py
index fb75c28..6308d58 100644
--- a/apps/main/train.py
+++ b/apps/main/train.py
@@ -583,4 +583,5 @@ def main():


 if __name__ == "__main__":
+    torch.set_float32_matmul_precision('high')
     main()