@@ -56,7 +56,7 @@ model:
gen_seqlen: 32
condition_seqlen: 256
norm_eps: 1e-5
condition_dim: 2048
condition_dim: 3584 # 2048 - 3b qwen, 3584 - 7b qwen
qk_norm: false
liger_rms_norm: true
liger_ffn: true
@@ -70,12 +70,12 @@ model:
enable_tiling: false
enable_slicing: false
text_encoder:
config_name: "Qwen/Qwen2.5-VL-3B-Instruct"
config_name: "Qwen/Qwen2.5-VL-7B-Instruct"
dtype: "bf16"
text_seqlen: 256
model_path: "/mnt/pollux/checkpoints/Qwen2.5-VL-3B-Instruct"
# qwen has 36 layers, we can use the first 3/4 layers
layers_to_use: 30
model_path: "/mnt/pollux/checkpoints/Qwen2.5-VL-7B-Instruct"
# qwen 3B has 36 layers and 7B has 28; we can use the first 3/4 of the layers
layers_to_use: 21
data:
- stage: stage-1
id: 1
@@ -84,7 +84,7 @@ data:
source: mongodb
image_size: 256
condition_image_size: 256
max_ratio: 2.0
max_ratio: 1.0
partition_key: 'partition_key'
retries: 3
extract_field:
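A quick way to sanity-check the two numbers this config change relies on — the 7B encoder's hidden size (condition_dim) and its layer count (layers_to_use = 3/4 of the layers) — is to read them off the checkpoint config. A minimal sketch, assuming the usual Hugging Face config layout (the text_config fallback covers both older and newer transformers versions):

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
text_cfg = getattr(cfg, "text_config", cfg)  # layout differs across transformers versions
print(text_cfg.hidden_size)         # expected 3584 -> condition_dim above
print(text_cfg.num_hidden_layers)   # expected 28  -> 3/4 of 28 = 21 -> layers_to_use
```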
8 changes: 7 additions & 1 deletion apps/Castor/modules/text_encoder.py
@@ -16,6 +16,7 @@
GemmaTokenizerFast,
UMT5EncoderModel,
)
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl

logger = logging.getLogger()
from typing import Optional
@@ -93,8 +94,10 @@ def __init__(self, args: TextEncoderArgs):
"Qwen/Qwen2.5-VL-3B-Instruct" if args.model_path == "" else args.model_path,
torch_dtype=self.dtype,
).cuda()
apply_liger_kernel_to_qwen2_5_vl(model=self.model)  # pass as keyword: liger's first positional parameter is a feature flag, not the model
if args.layers_to_use is not None:
self.model.layers = self.model.layers[: args.layers_to_use]
self.model = torch.compile(self.model)
self.model.eval()
self.model.requires_grad_(False)
self.tokenizer = AutoTokenizer.from_pretrained(
@@ -127,6 +130,9 @@ def _convert_caption_to_messages(self, caption: str) -> str:
)

def __call__(self, batch: dict[str:any]) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns last_hidden_state and attention_mask, right-padded.
"""
assert "caption" in batch
if isinstance(batch["caption"][0], tuple):
batch["caption"] = [x[0] for x in batch["caption"]]
@@ -235,7 +241,7 @@ def __call__(self, batch: dict[str:any]) -> Tuple[torch.Tensor, torch.Tensor]:
def create_text_encoder(args: TextEncoderArgs) -> BaseTextEncoder:
if args.config_name == "ViT-B/32":
return CLIP(args)
elif args.config_name == "Qwen/Qwen2.5-VL-3B-Instruct":
elif args.config_name == "Qwen/Qwen2.5-VL-3B-Instruct" or args.config_name == "Qwen/Qwen2.5-VL-7B-Instruct":
return Qwen2_5_VL(args)
elif args.config_name == "Gemma2_2B_it":
return Gemma2_2B_it(args)
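The new __call__ docstring promises right-padded outputs, which is what lets the transformer below recover each caption's true length from the mask (condition_mask.sum(dim=1)). A minimal sketch of that contract with hypothetical captions and the standard tokenizer API — not code from this PR:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", padding_side="right")
batch = tok(["a red cube", "a photo of a dog riding a skateboard"],
            padding=True, return_tensors="pt")
mask = batch["attention_mask"]   # [2, L]: ones first, then zeros (right padding)
lengths = mask.sum(dim=1)        # per-caption token counts
```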
119 changes: 65 additions & 54 deletions apps/Castor/modules/transformer.py
@@ -19,6 +19,8 @@

logger = logging.getLogger()

def nearest_multiple_of_8(x):
return ((x + 7) // 8) * 8

@dataclass
class TransformerArgs(BaseTransformerArgs):
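For reference, the new nearest_multiple_of_8 helper above simply rounds a sequence length up to the next multiple of 8 (aligned lengths tend to be friendlier to fused attention kernels and tensor cores), e.g.:

```python
assert nearest_multiple_of_8(256) == 256
assert nearest_multiple_of_8(257) == 264   # 257 real tokens padded up to 264
```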
@@ -289,6 +291,7 @@ def patchify_and_embed_image(
max_seq_len = max(
[cond_l[i] + (H_list[i] // pH) * (W_list[i] // pW) for i in range(bsz)]
)
max_seq_len = nearest_multiple_of_8(max_seq_len)
x_new = torch.zeros(bsz, max_seq_len, self.dim, dtype=x[0].dtype).to(
x[0].device
)
@@ -331,47 +334,61 @@ def patchify_and_embed_image(
return x_new, x_mask, cond_l, (H_list, W_list), freqs_cis
else:
B, C, H, W = x.size()
target_dtype = x.dtype
cond_l = condition.size(1)
condition = condition.to(x.dtype)
assert H % pH == 0, f"H ({H}) should be divisible by pH ({pH})"
assert W % pW == 0, f"W ({W}) should be divisible by pW ({pW})"

# 1. Patchify and embed image tensor

# 1. Get actual condition lengths from mask
cond_l = condition_mask.sum(dim=1, dtype=torch.int32) # Shape: [B]
max_cond_l = cond_l.max(dim=0)[0].item()

# 2. Patchify and embed image tensor
x_img_patch = (
x.view(B, C, H // pH, pH, W // pW, pW)
.permute(0, 2, 4, 3, 5, 1)
.flatten(3)
)
x_img_embed = self.img_embed(x_img_patch) # Assume img_embed preserves or outputs target_dtype
x_img_embed = x_img_embed.flatten(1, 2) # Shape: [B, L_img, D]
L_img = x_img_embed.size(1)

# Ensure image embedding has the target dtype
x_img_embed = x_img_embed.to(target_dtype)

# 2. Prepare condition tensor
# Cast condition to the target dtype *before* concatenation
condition = condition.to(target_dtype) # Shape: [B, L_cond, D]
# 4. Initialize combined tensors
max_seq_len = nearest_multiple_of_8(max_cond_l + L_img)
x_combined = torch.zeros(B, max_seq_len, self.dim, dtype=x.dtype, device=x.device)
x_mask = torch.zeros(B, max_seq_len, dtype=torch.bool, device=x.device)
freqs_cis = torch.zeros(
(
B,
max_seq_len,
)
+ (self.rope_embeddings_conditions.freqs_cis.shape[-3:]),
dtype=x.dtype,
).to(x.device)

# 3. Concatenate condition and image embeddings
x_combined = torch.cat([condition, x_img_embed], dim=1) # Shape: [B, L_cond + L_img, D]
# 5. Precompute RoPE embeddings
freqs_cis_cond_all = self.rope_embeddings_conditions.freqs_cis[:max_cond_l].to(device=x.device, dtype=x.dtype)
freqs_cis_img_all = self.rope_embeddings_image.freqs_cis[: H // pH, : W // pW].flatten(0, 1).to(device=x.device, dtype=x.dtype) # Shape [L_img, ..., D/2, 2]

# 4. Create attention mask
x_mask_img = torch.ones(B, (H // pH) * (W // pW), dtype=torch.bool, device=x.device)
x_mask = torch.cat([condition_mask, x_mask_img], dim=1) # Shape: [B, L_cond + L_img]
# 6. Populate tensors respecting actual condition lengths
for i in range(B):
# Place condition and image embeddings
x_combined[i, :cond_l[i]] = condition[i, :cond_l[i]]
x_combined[i, cond_l[i] : cond_l[i] + L_img] = x_img_embed[i]

# 5. Prepare RoPE embeddings
# Ensure RoPE embeddings are on the correct device and have the target dtype
freqs_cis_cond = self.rope_embeddings_conditions.freqs_cis[:cond_l].to(device=x.device, dtype=target_dtype)
freqs_cis_img = self.rope_embeddings_image.freqs_cis[: H // pH, : W // pW].flatten(0, 1).to(device=x.device, dtype=target_dtype)
# Create mask
x_mask[i, : cond_l[i] + L_img] = True

# Concatenate RoPE embeddings
freqs_cis = torch.cat([freqs_cis_cond, freqs_cis_img], dim=0) # Shape: [L_cond + L_img, ..., D/2, 2]
# Expand for batch dimension
freqs_cis = freqs_cis.unsqueeze(0).expand(B, -1, *([-1] * (freqs_cis.dim() - 1))) # Shape: [B, L_cond + L_img, ..., D/2, 2]
# Create RoPE embeddings
freqs_cis[i, :cond_l[i]] = freqs_cis_cond_all[:cond_l[i]]
freqs_cis[i, cond_l[i] : cond_l[i] + L_img] = freqs_cis_img_all # Image RoPE is the same for all in batch here

# Return list of actual lengths for consistency
return (
x_combined,
x_mask,
cond_l,
(H, W),
cond_l.tolist(),
(H, W), # Return single H, W tuple
freqs_cis,
)
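A toy sketch (made-up sizes, plain tensors, no embedding layers) of the packed layout the rewritten non-dynamic path now builds per sample — variable-length condition tokens, then image patch tokens, then zero padding up to a multiple of 8:

```python
import torch

B, D = 2, 16
cond_l = torch.tensor([5, 3])    # per-sample condition lengths, i.e. condition_mask.sum(dim=1)
L_img = 4                        # (H // pH) * (W // pW)
max_seq_len = ((cond_l.max().item() + L_img) + 7) // 8 * 8   # 9 -> 16

x_combined = torch.zeros(B, max_seq_len, D)
x_mask = torch.zeros(B, max_seq_len, dtype=torch.bool)
for i in range(B):
    x_mask[i, : cond_l[i] + L_img] = True   # real tokens; the rest stays masked padding
# sample 0: 5 cond + 4 image + 7 pad; sample 1: 3 cond + 4 image + 9 pad
```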

@@ -417,7 +434,7 @@ def forward(
return output

def unpatchify_image(
self, x: torch.Tensor, cond_l: Union[List[int], int], img_size: Tuple[int, int]
self, x: torch.Tensor, cond_l: List[int], img_size: Union[Tuple[int, int], Tuple[List[int], List[int]]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
Convert patched image features back to the original latent format.
@@ -444,36 +461,30 @@ def unpatchify_image(
out_x_list.append(_x)
return out_x_list
else:
# Handle the case where cond_l is an integer, not a list
if isinstance(cond_l, list):
# If cond_l is unexpectedly a list here, take the first element or max?
# Assuming it should be consistent with the input type logic.
# If input was Tensor, cond_l should be int. If list, list.
# This path assumes input was Tensor, so cond_l should be int.
# If it's a list, maybe log a warning or error.
# For now, assume it's an int if we reach here.
if len(cond_l) > 0:
max_cond_l = cond_l[0] # Or max(cond_l) if that makes sense
# logger.warning("cond_l was a list in unpatchify_image tensor path.")
else:
max_cond_l = 0 # Handle empty list case
else:
max_cond_l = cond_l # It's already an int

# Handle non-dynamic (tensor) case
H, W = img_size
L_img = (H // pH) * (W // pW)
B = x.size(0)
L = (H // pH) * (W // pW)
# Ensure slicing indices are correct
img_features = x[:, max_cond_l : max_cond_l + L]

# Ensure the view dimensions match the extracted features
# B, L_img, D_out_patch = img_features.shape
# D_out_patch should be pH * pW * self.out_channels
# L_img should be (H // pH) * (W // pW)
img_features = img_features.view(
B, H // pH, W // pW, pH, pW, self.out_channels
)
img_features = img_features.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
return img_features

# Initialize output tensor
# Note: Need to ensure self.out_channels is correctly defined/accessed
out_features = torch.zeros(B, self.out_channels, H, W, dtype=x.dtype, device=x.device)

for i in range(B):
l_cond = cond_l[i]

# Extract the image part for this batch item
img_features_i = x[i, l_cond : l_cond + L_img] # Shape [L_img, D_out_patch]

# Reshape back to image format
# D_out_patch = pH * pW * self.out_channels must hold
img_features_i = img_features_i.view(
H // pH, W // pW, pH, pW, self.out_channels
)
img_features_i = img_features_i.permute(4, 0, 2, 1, 3).flatten(3, 4).flatten(1, 2)  # [C_out, H, W]
out_features[i] = img_features_i

return out_features

def reset_parameters(self, init_std=None):
# Either use fixed base std or sqrt model dim
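A standalone round-trip sketch (toy sizes, no learned projection) of the patchify/unpatchify reshapes used in this file; it is also a convenient way to check the per-sample permute/flatten order in the rewritten tensor path:

```python
import torch

B, C, H, W, pH, pW = 2, 4, 8, 8, 2, 2
x = torch.randn(B, C, H, W)

# patchify: [B, C, H, W] -> [B, (H/pH)*(W/pW), pH*pW*C]
patches = (x.view(B, C, H // pH, pH, W // pW, pW)
            .permute(0, 2, 4, 3, 5, 1)
            .flatten(3)
            .flatten(1, 2))

# unpatchify one sample: [L_img, pH*pW*C] -> [C, H, W]
p0 = patches[0].view(H // pH, W // pW, pH, pW, C)
x0 = p0.permute(4, 0, 2, 1, 3).flatten(3, 4).flatten(1, 2)
assert torch.equal(x0, x[0])
```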
5 changes: 4 additions & 1 deletion apps/Castor/modules/vae.py
@@ -191,6 +191,7 @@ def __init__(self, args: VideoVAEArgs):
self.scale = 0.3611
self.shift = 0.1159

@torch.compile
@torch.no_grad()
def encode(self, x: torch.Tensor) -> torch.Tensor:
x = x.to(device=self.vae.device, dtype=self.vae.dtype)
@@ -199,6 +200,7 @@ def encode(self, x: torch.Tensor) -> torch.Tensor:
) * self.scale
return x

@torch.compile
@torch.no_grad()
def decode(self, x: torch.Tensor) -> torch.Tensor:
x = x.to(device=self.vae.device, dtype=self.vae.dtype)
@@ -207,7 +209,8 @@ def decode(self, x: torch.Tensor) -> torch.Tensor:
# Use the VAE's decode method and get the sample
decoded = self.vae.decode(x).sample
return decoded


@torch.compile
@torch.no_grad()
def forward(self, x: torch.Tensor):
x = self.encode(x)
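The @torch.compile decorators stack on top of @torch.no_grad(), so the compiled encode/decode paths still run without autograd. A minimal toy module (not the Castor VAE) illustrating the pattern:

```python
import torch

class ToyCodec(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    @torch.compile
    @torch.no_grad()
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

out = ToyCodec().encode(torch.randn(2, 8))
assert not out.requires_grad   # no_grad is preserved inside the compiled function
```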
1 change: 1 addition & 0 deletions apps/main/train.py
@@ -583,4 +583,5 @@ def main():


if __name__ == "__main__":
torch.set_float32_matmul_precision('high')
main()
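For context on the added line: "high" allows float32 matmuls to use TF32 (or comparable reduced-precision internal) tensor-core paths, which complements the torch.compile calls added elsewhere in this PR. The setting can be inspected directly:

```python
import torch

torch.set_float32_matmul_precision("high")   # as added in train.py
print(torch.get_float32_matmul_precision())  # -> "high"
```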
4 changes: 3 additions & 1 deletion requirements.txt
@@ -24,4 +24,6 @@ boto3
python-dotenv
ijson
flash-attn
transformers
transformers
liger-kernel
timm