@@ -57,7 +57,7 @@ model:
   out_channels: 16
   tmb_size: 256
   gen_seqlen: 32
-  condition_seqlen: 256
+  condition_seqlen: 512
   norm_eps: 1e-5
   condition_dim: 3584
   qk_norm: false
@@ -85,7 +85,7 @@ model:
   text_encoder:
     config_name: "Qwen/Qwen2.5-VL-7B-Instruct"
     dtype: "bf16"
-    text_seqlen: 256
+    text_seqlen: 512
     model_path: "/mnt/pollux/checkpoints/Qwen2.5-VL-7B-Instruct"
     relative_depth: 0.75

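Note: `condition_seqlen` in the diffusion-transformer config and `text_seqlen` in the text-encoder config are bumped from 256 to 512 together; the condition stream consumes the text encoder's output, so the two lengths presumably need to stay in sync.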
23 changes: 2 additions & 21 deletions apps/Castor/generate.py
@@ -92,27 +92,8 @@ def forward(self, context: Dict[str, Any]) -> torch.Tensor:
             mu=mu,
         )
         latent = self.prepare_latent(context, device=cur_device)
-        pos_conditional_signal, pos_conditional_mask = self.model.text_encoder(context)
-        negative_conditional_signal = (
-            self.model.diffusion_transformer.negative_token.repeat(
-                pos_conditional_signal.size(0), pos_conditional_signal.size(1), 1
-            )
-        )
-        negative_conditional_mask = torch.ones_like(
-            pos_conditional_mask, dtype=pos_conditional_mask.dtype
-        )
-        context = torch.cat(
-            [
-                pos_conditional_signal,
-                negative_conditional_signal,
-            ]
-        )
-        context_mask = torch.cat(
-            [
-                pos_conditional_mask,
-                negative_conditional_mask,
-            ]
-        )
+        context['caption'] = context['caption'] + ["" for _ in context['caption']]
+        context, context_mask = self.model.text_encoder(context)
        for i, t in enumerate(timesteps):
             latent_model_input = torch.cat([latent] * 2)
             timestep = t.expand(latent_model_input.shape[0])
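For context on this change: the unconditional branch for classifier-free guidance was previously built from a learned `negative_token` (whose definition is removed from transformer.py below); the new code instead appends one empty caption per prompt, so a single text-encoder pass yields conditional embeddings in the first half of the batch and unconditional embeddings in the second. A minimal sketch of the guidance combination this doubled batch feeds into (the helper name and `guidance_scale` value are illustrative, not from this PR):

import torch

def cfg_combine(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # noise_pred covers the doubled batch built above: first half conditional
    # (real captions), second half unconditional (the appended "" captions).
    noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

# Example: predictions for 2 prompts doubled to a batch of 4.
preds = torch.randn(4, 16, 8)
guided = cfg_combine(preds, guidance_scale=5.0)
print(guided.shape)  # torch.Size([2, 16, 8])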
13 changes: 5 additions & 8 deletions apps/Castor/model.py
@@ -100,18 +100,15 @@ def forward(self, batch: dict[str:any], flops_meter= None) -> dict[str:any]:
             batch["vision_encoder_target"] = self.vision_encoder.extract_image_representations(batch, flops_meter)

         if "text_embedding" not in batch:
+            batch["caption"] = [
+                "" if random.random() <= self.text_cfg_ratio
+                else cap
+                for cap in batch["caption"]
+            ]
             batch["text_embedding"], batch["attention_mask"] = self.text_encoder(batch, flops_meter)

         conditional_signal, conditional_mask = batch["text_embedding"], batch["attention_mask"]

-        if random.random() <= self.text_cfg_ratio:
-            conditional_signal = self.diffusion_transformer.negative_token.repeat(
-                conditional_signal.size(0), conditional_signal.size(1), 1
-            )
-            conditional_mask = torch.ones_like(
-                conditional_mask, dtype=conditional_signal.dtype
-            )
-
         latent_code = batch["latent_code"]
         noised_x, t, target = self.scheduler.sample_noised_input(latent_code)
         output = self.diffusion_transformer(
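Behavioral note: the removed block swapped the conditioning for the whole batch at once with probability `text_cfg_ratio`, whereas the new list comprehension drops each caption independently, which is the more common classifier-free guidance training recipe, and routes the unconditional case through the same text encoder used at generation time. A minimal sketch of the per-sample dropout in isolation (the helper name is illustrative):

import random
from typing import List

def drop_captions(captions: List[str], cfg_ratio: float) -> List[str]:
    # Independently replace each caption with "" with probability cfg_ratio;
    # empty captions become the unconditional signal in the text encoder.
    return ["" if random.random() <= cfg_ratio else cap for cap in captions]

# With cfg_ratio=0.1, roughly 10% of captions are dropped per batch.
print(drop_captions(["a red fox", "a snowy street", "a violin"], 0.1))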
95 changes: 80 additions & 15 deletions apps/Castor/modules/schedulers.py
@@ -132,27 +132,72 @@ def create_schedulers(self, args: SchedulerArgs):
         )
         return scheduler

-    def compute_density_for_timestep_sampling(self, batch_size: int) -> torch.Tensor:
+    def compute_density_for_timestep_sampling(
+        self, batch_size: int, image_seq_len: Optional[Union[int, List[int]]] = None
+    ) -> torch.Tensor:
         """
         Compute the density for sampling the timesteps when doing SD3 training.

         Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

         SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
         """
-        if self.weighting_scheme == "logit_normal":
-            # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
-            u = torch.normal(
-                mean=self.logit_mean,
-                std=self.logit_std,
-                size=(batch_size,),
-                device="cpu",
-            )
+        # Determine the logit_mean parameter for torch.normal
+        # It can be a scalar or a tensor of shape (batch_size,)
+        if image_seq_len is not None and self.scheduler.config.use_dynamic_shifting:
+            if isinstance(image_seq_len, int):
+                # Single sequence length for the whole batch
+                mu = calculate_shift(
+                    image_seq_len=image_seq_len,
+                    base_seq_len=self.scheduler.config.base_image_seq_len,
+                    max_seq_len=self.scheduler.config.max_image_seq_len,
+                    base_shift=self.scheduler.config.base_shift,
+                    max_shift=self.scheduler.config.max_shift,
+                )
+                adjusted_logit_mean = mu * self.scheduler.config.shift
+                logit_mean_param = torch.full(
+                    (batch_size,), adjusted_logit_mean, device="cpu"
+                )
+            elif isinstance(image_seq_len, list):
+                # List of sequence lengths, one for each item in the batch
+                if len(image_seq_len) != batch_size:
+                    raise ValueError(
+                        "If image_seq_len is a list, its length must match batch_size."
+                    )
+
+                mus = [
+                    calculate_shift(
+                        image_seq_len=sl,
+                        base_seq_len=self.scheduler.config.base_image_seq_len,
+                        max_seq_len=self.scheduler.config.max_image_seq_len,
+                        base_shift=self.scheduler.config.base_shift,
+                        max_shift=self.scheduler.config.max_shift,
+                    )
+                    for sl in image_seq_len
+                ]
+                # Create a tensor of means, one for each item in the batch
+                logit_mean_param = torch.tensor(
+                    mus, device="cpu"
+                ) * self.scheduler.config.shift
+            else:
+                raise TypeError(
+                    "image_seq_len must be an int or a list of ints, but got "
+                    f"{type(image_seq_len)}"
+                )
+        else:
+            # No dynamic shifting or no seq_len provided, use default logit_mean for all
+            logit_mean_param = torch.full(
+                (batch_size,), self.logit_mean, device="cpu"
+            )
+
+        if self.weighting_scheme == "logit_normal":
+            # See 3.1 in the SD3 paper ($rf/lognorm(mu * shift, std_dev)$).
+            u = torch.normal(logit_mean_param, self.logit_std)
             u = torch.nn.functional.sigmoid(u)
         elif self.weighting_scheme == "mode":
             u = torch.rand(size=(batch_size,), device="cpu")
             u = 1 - u - self.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
-        else:
+        else:  # 'uniform'
             u = torch.rand(size=(batch_size,), device="cpu")
         return u

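`calculate_shift` itself is not shown in this diff; in diffusers-style flow-matching pipelines it linearly interpolates the timestep shift between `base_shift` and `max_shift` as the image token count grows, so larger images are pushed toward higher noise levels. A sketch under that assumption (the default values are illustrative):

def calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> float:
    # Linear interpolation in sequence length: mu = m * seq_len + b.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b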
@@ -167,40 +212,60 @@ def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
         return sigma

     def sample_noised_input(
-        self, x: torch.Tensor
-    ) -> Tuple[torch.tensor, torch.tensor, torch.tensor]:
+        self, x: Union[torch.Tensor, List[torch.Tensor]]
+    ) -> Tuple[
+        Union[torch.Tensor, List[torch.Tensor]],
+        torch.Tensor,
+        Union[torch.Tensor, List[torch.Tensor]],
+    ]:
         """
         Samples a noisy input given a clean latent x and returns noisy input, timesteps and target.
         """

         use_dynamic_res = isinstance(x, list)
         if use_dynamic_res:
             bsz = len(x)
+            # Assuming x[i] is shaped [seq_len_i, ...], so seq_len is shape[0]
+            # Adjust if seq_len is at a different dimension.
+            image_seq_lens = [
+                s.shape[0] for s in x
+            ]  # Get seq_len for each tensor in the list
+
             u = self.compute_density_for_timestep_sampling(
                 batch_size=bsz,
+                image_seq_len=image_seq_lens,  # Pass list of sequence lengths
             )
             indices = (u * self.scheduler.config.num_train_timesteps).long()
             timesteps = self.scheduler.timesteps[indices].to(device=x[0].device)
-            sigmas = self.get_sigmas(timesteps, n_dim=x[0].ndim + 1, dtype=x[0].dtype)
+            # n_dim ensures sigmas[i] can broadcast with x[i]
+            # If x[i] is [S_i, D], x[0].ndim = 2. sigmas[i] will be (1,) or (1,1) after unsqueeze.
+            sigmas = self.get_sigmas(timesteps, n_dim=x[0].ndim, dtype=x[0].dtype)
             noise_model_input_list = []
             target_list = []
             for i in range(bsz):
                 _noise = torch.randn_like(x[i])
+                # sigmas is (bsz, 1, ...), so sigmas[i] is (1, ...)
                 _noisy_model_input = (1.0 - sigmas[i]) * x[i] + sigmas[i] * _noise
-                _target = _noise - x[i]
+                _target = _noise - x[i]  # Velocity prediction target
                 noise_model_input_list.append(_noisy_model_input)
                 target_list.append(_target)

             return noise_model_input_list, timesteps, target_list
-        else:
+        else:  # x is a single torch.Tensor
             bsz = x.size(0)
+            # Assuming x is of shape [bsz, seq_len, dim]
+            # The sequence length for shifting is x.shape[1]
+            image_seq_len_for_shift = x.shape[1]
+
             noise = torch.randn_like(x)
             u = self.compute_density_for_timestep_sampling(
                 batch_size=bsz,
+                image_seq_len=image_seq_len_for_shift,  # Pass single sequence length
             )
             indices = (u * self.scheduler.config.num_train_timesteps).long()
             timesteps = self.scheduler.timesteps[indices].to(device=x.device)
+            # If x is [bsz, S, D], x.ndim = 3. sigmas will be (bsz, 1, 1) after unsqueeze.
             sigmas = self.get_sigmas(timesteps, n_dim=x.ndim, dtype=x.dtype)
             noisy_model_input = (1.0 - sigmas) * x + sigmas * noise
-            target = noise - x
+            target = noise - x  # Velocity prediction target
             return noisy_model_input, timesteps, target
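Both branches implement the same rectified-flow construction: the noisy input is the linear interpolation x_t = (1 - sigma) * x_0 + sigma * eps, and the regression target is the velocity eps - x_0 (the derivative of x_t with respect to sigma). A self-contained illustration with assumed shapes:

import torch

x0 = torch.randn(2, 16, 8)        # clean latents [bsz, seq, dim]
sigma = torch.rand(2, 1, 1)       # per-sample noise level, broadcastable
eps = torch.randn_like(x0)
xt = (1.0 - sigma) * x0 + sigma * eps  # noisy model input
v = eps - x0                           # velocity target: d(xt)/d(sigma)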
27 changes: 24 additions & 3 deletions apps/Castor/modules/text_encoder.py
@@ -20,7 +20,6 @@
 from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl

 logger = logging.getLogger()
-from typing import Optional


 @dataclass
@@ -95,6 +94,17 @@ def __init__(self, args: TextEncoderArgs):
         self.init_model(args, model_path)
         self.init_tokenizer(model_path)
         self.init_processor(model_path)
+
+        self.system_prompts = [
+            """You are NucleusT2I, a text-to-image assistant, trained to generate high-quality images from user prompts.""",
+            """You are NucleusT2I, a text-to-image assistant. Your task is to generate the highest-quality image possible based solely on the user's description, even if it is vague or incomplete.
+
+- Do not ask the user for clarification or provide suggestions.
+- If the user's input is ambiguous, make reasonable creative assumptions and proceed.
+- Only modify or refine the image when the user provides additional instructions in follow-up messages.
+- Always maintain context from previous turns to support iterative improvements, but never initiate clarification.
+"""
+        ]

     def init_model(self, args: TextEncoderArgs, model_path: str):
         config = AutoConfig.from_pretrained(model_path)
@@ -125,10 +135,12 @@ def dim(self) -> int:
         return self.model.config.hidden_size

     def _convert_caption_to_messages(self, caption: str) -> str:
+        if not caption:
+            return ""
         messages = [
             {
                 "role": "system",
-                "content": "You are an assistant designed to generate high-quality images based on user prompts.",
+                "content": self.system_prompts[0],
             },
             {
                 "role": "user",
@@ -152,13 +164,22 @@ def __call__(self, batch: dict[str:any], flops_meter= None) -> Tuple[torch.Tenso
             self._convert_caption_to_messages(caption)
             for caption in batch["caption"]
         ]
+
+        if all(msg == "" for msg in messages):
+            B = len(batch["caption"])
+            context = torch.zeros(
+                B, 0, self.dim(), dtype=self.dtype, device=self.model.device)
+            context_mask = torch.zeros(
+                B, 0, dtype=torch.bool, device=self.model.device)
+            return context, context_mask
+
         inputs = self.processor(
             text=messages,
             padding=True,
             return_tensors="pt",
             max_length=self.text_seqlen,
             truncation=True,
-        ).to(device=self.model.device, dtype=self.dtype)
+        ).to(device=self.model.device)

         if flops_meter is not None:
             flops_meter.log_text_encoder_flops(inputs['input_ids'].shape)
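The new early return gives callers a consistent contract for fully unconditional batches: a (B, 0, dim) context and a (B, 0) mask, with no encoder forward pass. A minimal shape check (`text_encoder` is a hypothetical instance of this class):

# Hypothetical usage: an all-empty caption batch short-circuits the encoder.
batch = {"caption": ["", "", ""]}
context, context_mask = text_encoder(batch)
assert context.shape == (3, 0, text_encoder.dim())
assert context_mask.shape == (3, 0)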
2 changes: 0 additions & 2 deletions apps/Castor/modules/transformer.py
@@ -335,7 +335,6 @@ def __init__(self, args: TransformerArgs):
         self.cond_norm = RMSNorm(
             args.dim, eps=args.norm_eps, liger_rms_norm=args.liger_rms_norm
         )
-        self.negative_token = nn.Parameter(torch.zeros(1, 1, args.condition_dim))
         self.cond_proj = nn.Linear(
             in_features=args.condition_dim,
             out_features=args.dim,
@@ -669,6 +668,5 @@ def reset_parameters(self, init_std=None):
             a=-3 * init_std,
             b=3 * init_std,
         )
-        nn.init.normal_(self.negative_token, std=0.02)
         if self.shared_adaLN:
             self.adaLN_modulation.reset_parameters()
2 changes: 1 addition & 1 deletion lingua/checkpoint.py
@@ -305,7 +305,7 @@ def load(
         # If none of those are available don't do anything
         if path is None:
             # If no checkpoints exist do nothing
-            logger.info("No checkpoints found ! Init train state from sratch...")
+            logger.info("No checkpoints found ! Init train state from scratch...")
             return

         # Only load train state if it's provided, the files exist and we're not loading from init path