From cf1d30b2465bbe40ac77b45b59632ed58cd82b0f Mon Sep 17 00:00:00 2001
From: Murali Nandan
Date: Mon, 26 May 2025 08:51:05 +0000
Subject: [PATCH 1/7] removes nn.Parameter dependency

---
 ...ain_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml | 4 ++--
 apps/Castor/generate.py | 10 +++-------
 apps/Castor/model.py | 10 ++--------
 apps/Castor/modules/text_encoder.py | 2 +-
 apps/Castor/modules/transformer.py | 11 ++---------
 lingua/checkpoint.py | 2 +-
 6 files changed, 11 insertions(+), 28 deletions(-)

diff --git a/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml b/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml
index 62d312e..28837de 100644
--- a/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml
+++ b/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml
@@ -4,8 +4,8 @@ version: v1.0
 
 # From now on, start to align train, data, and model setting with train stage (just finish refactor for dara)
 train_stage: stage-1 # options: preliminary, pretraining, posttraining; aligned with data setting
-name: unpadded3 #used for local dump and wandb log
-output_dir: /mnt/pollux/checkpoints/aj
+name: simplify #used for local dump and wandb log
+output_dir: /mnt/pollux/checkpoints/ablations
 dump_dir: '' # No need now
 steps: 500000
 seed: 777
diff --git a/apps/Castor/generate.py b/apps/Castor/generate.py
index 0ed6dd7..f6db513 100644
--- a/apps/Castor/generate.py
+++ b/apps/Castor/generate.py
@@ -93,13 +93,9 @@ def forward(self, context: Dict[str, Any]) -> torch.Tensor:
         )
         latent = self.prepare_latent(context, device=cur_device)
         pos_conditional_signal, pos_conditional_mask = self.model.text_encoder(context)
-        negative_conditional_signal = (
-            self.model.diffusion_transformer.negative_token.repeat(
-                pos_conditional_signal.size(0), pos_conditional_signal.size(1), 1
-            )
-        )
-        negative_conditional_mask = torch.ones_like(
-            pos_conditional_mask, dtype=pos_conditional_mask.dtype
+        negative_conditional_signal, negative_conditional_mask = self.model.text_encoder(
+            # empty context
+            {"caption": ["" for _ in context["caption"]]}
         )
         context = torch.cat(
             [
diff --git a/apps/Castor/model.py b/apps/Castor/model.py
index 2db2462..e15f54e 100644
--- a/apps/Castor/model.py
+++ b/apps/Castor/model.py
@@ -100,18 +100,12 @@ def forward(self, batch: dict[str:any], flops_meter= None) -> dict[str:any]:
             batch["vision_encoder_target"] = self.vision_encoder.extract_image_representations(batch, flops_meter)
 
         if "text_embedding" not in batch:
+            if random.random() <= self.text_cfg_ratio:
+                batch["caption"] = ["" for _ in batch["caption"]]
             batch["text_embedding"], batch["attention_mask"] = self.text_encoder(batch, flops_meter)
 
         conditional_signal, conditional_mask = batch["text_embedding"], batch["attention_mask"]
 
-        if random.random() <= self.text_cfg_ratio:
-            conditional_signal = self.diffusion_transformer.negative_token.repeat(
-                conditional_signal.size(0), conditional_signal.size(1), 1
-            )
-            conditional_mask = torch.ones_like(
-                conditional_mask, dtype=conditional_signal.dtype
-            )
-
         latent_code = batch["latent_code"]
         noised_x, t, target = self.scheduler.sample_noised_input(latent_code)
         output = self.diffusion_transformer(
diff --git a/apps/Castor/modules/text_encoder.py b/apps/Castor/modules/text_encoder.py
index e3c408c..c3058c1 100644
--- a/apps/Castor/modules/text_encoder.py
+++ b/apps/Castor/modules/text_encoder.py
@@ -158,7 +158,7 @@ def __call__(self, batch: dict[str:any], flops_meter= None) -> Tuple[torch.Tenso
             return_tensors="pt",
max_length=self.text_seqlen, truncation=True, - ).to(device=self.model.device, dtype=self.dtype) + ).to(device=self.model.device) if flops_meter is not None: flops_meter.log_text_encoder_flops(inputs['input_ids'].shape) diff --git a/apps/Castor/modules/transformer.py b/apps/Castor/modules/transformer.py index c3b44df..f620ad4 100644 --- a/apps/Castor/modules/transformer.py +++ b/apps/Castor/modules/transformer.py @@ -94,11 +94,6 @@ def __init__(self, args: TransformerArgs): in_dim=args.time_step_dim, out_dim=4 * args.dim, ) - else: - self.register_parameter( - 'modulation', - nn.Parameter(torch.randn(1, 4, args.dim) / args.dim**0.5) - ) def forward( @@ -119,9 +114,9 @@ def forward( modulation_signal ).chunk(4, dim=1) elif self.unpadded: - scale_msa, gate_msa, scale_mlp, gate_mlp = (self.modulation + modulation_values)[batch_indices].unbind(dim=1) + scale_msa, gate_msa, scale_mlp, gate_mlp = modulation_values[batch_indices].unbind(dim=1) else: - scale_msa, gate_msa, scale_mlp, gate_mlp = (self.modulation + modulation_values).unbind(dim=1) + scale_msa, gate_msa, scale_mlp, gate_mlp = modulation_values.unbind(dim=1) h = x + self.modulate_and_gate( @@ -293,7 +288,6 @@ def __init__(self, args: TransformerArgs): self.dim = args.dim self.norm = RMSNorm(args.dim, eps=args.norm_eps, liger_rms_norm=args.liger_rms_norm) self.cond_norm = RMSNorm(args.dim, eps=args.norm_eps, liger_rms_norm=args.liger_rms_norm) - self.negative_token = nn.Parameter(torch.zeros(1, 1, args.condition_dim)) self.cond_proj = nn.Linear( in_features=args.condition_dim, out_features=args.dim, @@ -583,6 +577,5 @@ def reset_parameters(self, init_std=None): a=-3 * init_std, b=3 * init_std, ) - nn.init.normal_(self.negative_token, std=0.02) if self.shared_adaLN: self.adaLN_modulation.reset_parameters() diff --git a/lingua/checkpoint.py b/lingua/checkpoint.py index c3b780d..1494d5e 100644 --- a/lingua/checkpoint.py +++ b/lingua/checkpoint.py @@ -305,7 +305,7 @@ def load( # If none of those are available don't do anything if path is None: # If no checkpoints exist do nothing - logger.info("No checkpoints found ! Init train state from sratch...") + logger.info("No checkpoints found ! 
Init train state from scratch...")
         return
 
     # Only load train state if it's provided, the files exist and we're not loading from init path

From 73dab2d67e897057238ad80101ea51267e1ba7fc Mon Sep 17 00:00:00 2001
From: Murali Nandan
Date: Mon, 26 May 2025 09:42:30 +0000
Subject: [PATCH 2/7] simplify CFG context construction in generate.py

---
 apps/Castor/generate.py | 19 ++-----------------
 1 file changed, 2 insertions(+), 17 deletions(-)

diff --git a/apps/Castor/generate.py b/apps/Castor/generate.py
index f6db513..f51b852 100644
--- a/apps/Castor/generate.py
+++ b/apps/Castor/generate.py
@@ -92,23 +92,8 @@ def forward(self, context: Dict[str, Any]) -> torch.Tensor:
             mu=mu,
         )
         latent = self.prepare_latent(context, device=cur_device)
-        pos_conditional_signal, pos_conditional_mask = self.model.text_encoder(context)
-        negative_conditional_signal, negative_conditional_mask = self.model.text_encoder(
-            # empty context
-            {"caption": ["" for _ in context["caption"]]}
-        )
-        context = torch.cat(
-            [
-                pos_conditional_signal,
-                negative_conditional_signal,
-            ]
-        )
-        context_mask = torch.cat(
-            [
-                pos_conditional_mask,
-                negative_conditional_mask,
-            ]
-        )
+        context['caption'] = context['caption'] + ["" for _ in context['caption']]
+        context, context_mask = self.model.text_encoder(context)
         for i, t in enumerate(timesteps):
             latent_model_input = torch.cat([latent] * 2)
             timestep = t.expand(latent_model_input.shape[0])

From fcfe38437d0bf9b4033854d461daa251d691408a Mon Sep 17 00:00:00 2001
From: Murali Nandan
Date: Mon, 26 May 2025 18:06:24 +0000
Subject: [PATCH 3/7] support empty condition for CFG

---
 ...et_256_Castor_flux_qwen_fixed_siglip2.yaml | 6 ++---
 apps/Castor/modules/text_encoder.py | 25 +++++++++++++++++--
 2 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml b/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml
index 28837de..099047f 100644
--- a/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml
+++ b/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml
@@ -4,7 +4,7 @@ version: v1.0
 
 # From now on, start to align train, data, and model setting with train stage (just finish refactor for dara)
 train_stage: stage-1 # options: preliminary, pretraining, posttraining; aligned with data setting
-name: simplify #used for local dump and wandb log
+name: simplify2 #used for local dump and wandb log
 output_dir: /mnt/pollux/checkpoints/ablations
 dump_dir: '' # No need now
 steps: 500000
@@ -57,7 +57,7 @@ model:
     out_channels: 16
     tmb_size: 256
     gen_seqlen: 32
-    condition_seqlen: 256
+    condition_seqlen: 512
     norm_eps: 1e-5
     condition_dim: 3584
     qk_norm: false
@@ -83,7 +83,7 @@ model:
   text_encoder:
     config_name: "Qwen/Qwen2.5-VL-7B-Instruct"
     dtype: "bf16"
-    text_seqlen: 256
+    text_seqlen: 512
     model_path: "/mnt/pollux/checkpoints/Qwen2.5-VL-7B-Instruct"
     relative_depth: 0.75
 
diff --git a/apps/Castor/modules/text_encoder.py b/apps/Castor/modules/text_encoder.py
index c3058c1..e0739b6 100644
--- a/apps/Castor/modules/text_encoder.py
+++ b/apps/Castor/modules/text_encoder.py
@@ -20,7 +20,6 @@
 from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl
 
 logger = logging.getLogger()
-from typing import Optional
 
 
 @dataclass
@@ -95,6 +94,17 @@ def __init__(self, args: TextEncoderArgs):
         self.init_model(args, model_path)
         self.init_tokenizer(model_path)
         self.init_processor(model_path)
+
+        self.system_prompts = [
+            """You are NucleusT2I, a text-to-image assistant, trained to generate high-quality images from user 
prompts.""", + """You are NucleusT2I, a text-to-image assistant. Your task is to generate the highest-quality image possible based solely on the user's description, even if it is vague or incomplete. + +- Do not ask the user for clarification or provide suggestions. +- If the user's input is ambiguous, make reasonable creative assumptions and proceed. +- Only modify or refine the image when the user provides additional instructions in follow-up messages. +- Always maintain context from previous turns to support iterative improvements, but never initiate clarification. +""" + ] def init_model(self, args: TextEncoderArgs, model_path: str): config = AutoConfig.from_pretrained(model_path) @@ -125,10 +135,12 @@ def dim(self) -> int: return self.model.config.hidden_size def _convert_caption_to_messages(self, caption: str) -> str: + if not caption: + return "" messages = [ { "role": "system", - "content": "You are an assistant designed to generate high-quality images based on user prompts.", + "content": self.system_prompts[0], }, { "role": "user", @@ -152,6 +164,15 @@ def __call__(self, batch: dict[str:any], flops_meter= None) -> Tuple[torch.Tenso self._convert_caption_to_messages(caption) for caption in batch["caption"] ] + + if all(msg == "" for msg in messages): + B = len(batch["caption"]) + context = torch.zeros( + B, 0, self.dim(), dtype=self.dtype, device=self.model.device) + context_mask = torch.zeros( + B, 0, dtype=torch.bool, device=self.model.device) + return context, context_mask + inputs = self.processor( text=messages, padding=True, From 13507135a36f74c8202fda95824b0da6f445e7f0 Mon Sep 17 00:00:00 2001 From: Murali Nandan Date: Mon, 26 May 2025 18:29:24 +0000 Subject: [PATCH 4/7] reintroduce modulation parameter --- apps/Castor/modules/transformer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/apps/Castor/modules/transformer.py b/apps/Castor/modules/transformer.py index f620ad4..46c6177 100644 --- a/apps/Castor/modules/transformer.py +++ b/apps/Castor/modules/transformer.py @@ -94,6 +94,11 @@ def __init__(self, args: TransformerArgs): in_dim=args.time_step_dim, out_dim=4 * args.dim, ) + else: + self.register_parameter( + 'modulation', + nn.Parameter(torch.randn(1, 4, args.dim) / args.dim**0.5) + ) def forward( @@ -114,9 +119,9 @@ def forward( modulation_signal ).chunk(4, dim=1) elif self.unpadded: - scale_msa, gate_msa, scale_mlp, gate_mlp = modulation_values[batch_indices].unbind(dim=1) + scale_msa, gate_msa, scale_mlp, gate_mlp = (self.modulation + modulation_values)[batch_indices].unbind(dim=1) else: - scale_msa, gate_msa, scale_mlp, gate_mlp = modulation_values.unbind(dim=1) + scale_msa, gate_msa, scale_mlp, gate_mlp = (self.modulation + modulation_values).unbind(dim=1) h = x + self.modulate_and_gate( From ebeb8b169cd131676e251fbe802fbfceaeb28f47 Mon Sep 17 00:00:00 2001 From: Murali Nandan Date: Mon, 26 May 2025 18:31:43 +0000 Subject: [PATCH 5/7] training run config --- .../train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml b/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml index 099047f..a5ed3db 100644 --- a/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml +++ b/apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml @@ -4,7 +4,7 @@ version: v1.0 # From now on, start to align train, data, and model setting with train stage (just finish 
refactor for dara) train_stage: stage-1 # options: preliminary, pretraining, posttraining; aligned with data setting -name: simplify2 #used for local dump and wandb log +name: simplify3 #used for local dump and wandb log output_dir: /mnt/pollux/checkpoints/ablations dump_dir: '' # No need now steps: 500000 From b94cca445b6b4ee88a3b9999b4e2365e6b3bcc26 Mon Sep 17 00:00:00 2001 From: Murali Nandan Date: Tue, 27 May 2025 02:26:11 +0000 Subject: [PATCH 6/7] in-batch dropping for captions --- apps/Castor/model.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/apps/Castor/model.py b/apps/Castor/model.py index e15f54e..16c7df6 100644 --- a/apps/Castor/model.py +++ b/apps/Castor/model.py @@ -100,8 +100,11 @@ def forward(self, batch: dict[str:any], flops_meter= None) -> dict[str:any]: batch["vision_encoder_target"] = self.vision_encoder.extract_image_representations(batch, flops_meter) if "text_embedding" not in batch: - if random.random() <= self.text_cfg_ratio: - batch["caption"] = ["" for _ in batch["caption"]] + batch["caption"] = [ + "" if random.random() <= self.text_cfg_ratio + else cap + for cap in batch["caption"] + ] batch["text_embedding"], batch["attention_mask"] = self.text_encoder(batch, flops_meter) conditional_signal, conditional_mask = batch["text_embedding"], batch["attention_mask"] From 5fe1bbad0bfc00a1d2a5a2ea63ab1aba1446a455 Mon Sep 17 00:00:00 2001 From: Murali Nandan Date: Tue, 27 May 2025 02:26:44 +0000 Subject: [PATCH 7/7] resolution-aware timestep sampling --- apps/Castor/modules/schedulers.py | 95 ++++++++++++++++++++++++++----- 1 file changed, 80 insertions(+), 15 deletions(-) diff --git a/apps/Castor/modules/schedulers.py b/apps/Castor/modules/schedulers.py index 74fb56d..503d42e 100644 --- a/apps/Castor/modules/schedulers.py +++ b/apps/Castor/modules/schedulers.py @@ -132,7 +132,9 @@ def create_schedulers(self, args: SchedulerArgs): ) return scheduler - def compute_density_for_timestep_sampling(self, batch_size: int) -> torch.Tensor: + def compute_density_for_timestep_sampling( + self, batch_size: int, image_seq_len: Optional[Union[int, List[int]]] = None + ) -> torch.Tensor: """ Compute the density for sampling the timesteps when doing SD3 training. @@ -140,19 +142,62 @@ def compute_density_for_timestep_sampling(self, batch_size: int) -> torch.Tensor SD3 paper reference: https://arxiv.org/abs/2403.03206v1. """ - if self.weighting_scheme == "logit_normal": - # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). - u = torch.normal( - mean=self.logit_mean, - std=self.logit_std, - size=(batch_size,), - device="cpu", + # Determine the logit_mean parameter for torch.normal + # It can be a scalar or a tensor of shape (batch_size,) + if image_seq_len is not None and self.scheduler.config.use_dynamic_shifting: + if isinstance(image_seq_len, int): + # Single sequence length for the whole batch + mu = calculate_shift( + image_seq_len=image_seq_len, + base_seq_len=self.scheduler.config.base_image_seq_len, + max_seq_len=self.scheduler.config.max_image_seq_len, + base_shift=self.scheduler.config.base_shift, + max_shift=self.scheduler.config.max_shift, + ) + adjusted_logit_mean = mu * self.scheduler.config.shift + logit_mean_param = torch.full( + (batch_size,), adjusted_logit_mean, device="cpu" + ) + elif isinstance(image_seq_len, list): + # List of sequence lengths, one for each item in the batch + if len(image_seq_len) != batch_size: + raise ValueError( + "If image_seq_len is a list, its length must match batch_size." 
+ ) + + mus = [ + calculate_shift( + image_seq_len=sl, + base_seq_len=self.scheduler.config.base_image_seq_len, + max_seq_len=self.scheduler.config.max_image_seq_len, + base_shift=self.scheduler.config.base_shift, + max_shift=self.scheduler.config.max_shift, + ) + for sl in image_seq_len + ] + # Create a tensor of means, one for each item in the batch + logit_mean_param = torch.tensor( + mus, device="cpu" + ) * self.scheduler.config.shift + else: + raise TypeError( + "image_seq_len must be an int or a list of ints, but got " + f"{type(image_seq_len)}" + ) + else: + # No dynamic shifting or no seq_len provided, use default logit_mean for all + logit_mean_param = torch.full( + (batch_size,), self.logit_mean, device="cpu" ) + + if self.weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(mu * shift, std_dev)$). + u = torch.normal(logit_mean_param, self.logit_std) u = torch.nn.functional.sigmoid(u) elif self.weighting_scheme == "mode": u = torch.rand(size=(batch_size,), device="cpu") u = 1 - u - self.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) - else: + else: # 'uniform' u = torch.rand(size=(batch_size,), device="cpu") return u @@ -167,8 +212,12 @@ def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32): return sigma def sample_noised_input( - self, x: torch.Tensor - ) -> Tuple[torch.tensor, torch.tensor, torch.tensor]: + self, x: Union[torch.Tensor, List[torch.Tensor]] + ) -> Tuple[ + Union[torch.Tensor, List[torch.Tensor]], + torch.Tensor, + Union[torch.Tensor, List[torch.Tensor]], + ]: """ Samples a noisy input given a clean latent x and returns noisy input, timesteps and target. """ @@ -176,31 +225,47 @@ def sample_noised_input( use_dynamic_res = isinstance(x, list) if use_dynamic_res: bsz = len(x) + # Assuming x[i] is shaped [seq_len_i, ...], so seq_len is shape[0] + # Adjust if seq_len is at a different dimension. + image_seq_lens = [ + s.shape[0] for s in x + ] # Get seq_len for each tensor in the list + u = self.compute_density_for_timestep_sampling( batch_size=bsz, + image_seq_len=image_seq_lens, # Pass list of sequence lengths ) indices = (u * self.scheduler.config.num_train_timesteps).long() timesteps = self.scheduler.timesteps[indices].to(device=x[0].device) - sigmas = self.get_sigmas(timesteps, n_dim=x[0].ndim + 1, dtype=x[0].dtype) + # n_dim ensures sigmas[i] can broadcast with x[i] + # If x[i] is [S_i, D], x[0].ndim = 2. sigmas[i] will be (1,) or (1,1) after unsqueeze. + sigmas = self.get_sigmas(timesteps, n_dim=x[0].ndim, dtype=x[0].dtype) noise_model_input_list = [] target_list = [] for i in range(bsz): _noise = torch.randn_like(x[i]) + # sigmas is (bsz, 1, ...), so sigmas[i] is (1, ...) _noisy_model_input = (1.0 - sigmas[i]) * x[i] + sigmas[i] * _noise - _target = _noise - x[i] + _target = _noise - x[i] # Velocity prediction target noise_model_input_list.append(_noisy_model_input) target_list.append(_target) return noise_model_input_list, timesteps, target_list - else: + else: # x is a single torch.Tensor bsz = x.size(0) + # Assuming x is of shape [bsz, seq_len, dim] + # The sequence length for shifting is x.shape[1] + image_seq_len_for_shift = x.shape[1] + noise = torch.randn_like(x) u = self.compute_density_for_timestep_sampling( batch_size=bsz, + image_seq_len=image_seq_len_for_shift, # Pass single sequence length ) indices = (u * self.scheduler.config.num_train_timesteps).long() timesteps = self.scheduler.timesteps[indices].to(device=x.device) + # If x is [bsz, S, D], x.ndim = 3. sigmas will be (bsz, 1, 1) after unsqueeze. 
sigmas = self.get_sigmas(timesteps, n_dim=x.ndim, dtype=x.dtype) noisy_model_input = (1.0 - sigmas) * x + sigmas * noise - target = noise - x + target = noise - x # Velocity prediction target return noisy_model_input, timesteps, target
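
Reviewer note on PATCHES 1-3 (CFG without negative_token): classifier-free
guidance now derives its unconditional branch from embeddings of the empty
caption "" instead of a learned negative_token. At sampling time, generate.py
doubles the latent batch (torch.cat([latent] * 2)) and appends empty captions
to context['caption'], so the first half of the model output is conditional and
the second half unconditional. The guidance step itself lies outside the hunks
shown; below is a minimal sketch of the standard combination it feeds, where
the variable names and guidance_scale are assumptions, not code from this
series:

    # Split the doubled batch back into conditional / unconditional halves,
    # then extrapolate away from the unconditional prediction.
    # NOTE: noise_pred and guidance_scale are assumed names, not from the diffs.
    noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (
        noise_pred_cond - noise_pred_uncond
    )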
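
Reviewer note on PATCH 7 (resolution-aware timestep sampling): the dynamic
shifting path calls calculate_shift(), which is not included in this diff. For
reference, a minimal sketch of the FLUX-style shift it presumably implements
(this mirrors the calculate_shift helper in diffusers' Flux pipeline; the exact
signature and default values below are assumptions, not taken from this repo):

    def calculate_shift(
        image_seq_len: int,
        base_seq_len: int = 256,    # assumed default, cf. scheduler config
        max_seq_len: int = 4096,    # assumed default
        base_shift: float = 0.5,    # assumed default
        max_shift: float = 1.15,    # assumed default
    ) -> float:
        # Linearly interpolate the timestep-shift parameter mu between
        # base_shift and max_shift as the number of image tokens grows,
        # so larger images are pushed toward higher noise levels.
        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        b = base_shift - m * base_seq_len
        return image_seq_len * m + b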