
Conversation

@Watebear
Collaborator

No description provided.

@gemini-code-assist
Contributor

Summary of Changes

Hello @Watebear, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces comprehensive support for a "layered" Qwen-Image model, significantly extending its image-to-image generation capabilities. The changes span model configuration, input processing, latent space manipulation, and output handling, allowing the system to process and generate images with multiple conceptual layers. This feature is designed to enable more complex and structured image generation tasks, supported by new configuration options, specialized positional embeddings, and an updated image processing pipeline.

Highlights

  • New Layered Qwen-Image Configuration: A new configuration file qwen_image_i2i_layered.json has been added, enabling specific parameters for layered image processing, including layered: true, layers: 4, and resolution: 640 (a minimal usage sketch follows this list).
  • Dynamic VAE Image Size and Captioning: The qwen25_vlforconditionalgeneration.py now dynamically adjusts the VAE image size based on the layered configuration and introduces image captioning capabilities (English/Chinese) to generate text prompts from input images for layered models.
  • Enhanced Positional Embeddings: A new QwenEmbedLayer3DRope class has been introduced in scheduler.py to handle 3D Rotary Positional Embeddings, specifically designed for layered image data, and the scheduler now conditionally uses this or the standard QwenEmbedRope.
  • Layer-Aware Latent and Image Processing: Modifications across scheduler.py, vae.py, and qwen_image_runner.py ensure that latent representations and image processing steps (packing, unpacking, decoding, saving) correctly account for the new layers dimension, allowing for multi-layered image generation and output.
  • Additional Time Conditioning: The pre_infer.py and pre_weights.py files have been updated to support an additional_t_cond for time embeddings, which is conditionally applied based on the configuration, potentially enhancing temporal coherence or control in layered generation.
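As referenced in the first highlight above, here is a minimal, hypothetical sketch of how such a layered config might be consumed. The keys layered, layers, and resolution come from the PR summary; the LayeredRunnerConfig dataclass, the parse_runner_config helper, and the default values are illustrative assumptions, not code from this pull request.

```python
import json
from dataclasses import dataclass


@dataclass
class LayeredRunnerConfig:
    # Hypothetical container; field names mirror the keys quoted in the highlight above.
    layered: bool = False
    layers: int = 1
    resolution: int = 640


def parse_runner_config(raw: dict) -> LayeredRunnerConfig:
    """Fall back to non-layered defaults when keys are absent."""
    return LayeredRunnerConfig(
        layered=raw.get("layered", False),
        layers=raw.get("layers", 1),
        resolution=raw.get("resolution", 640),
    )


# Payload shaped like the values described for qwen_image_i2i_layered.json.
example_json = '{"layered": true, "layers": 4, "resolution": 640}'
cfg = parse_runner_config(json.loads(example_json))
if cfg.layered:
    print(f"layered generation: {cfg.layers} layers at {cfg.resolution}px")
```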

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request adds support for qwen-image-layered, which involves changes across configurations, model runners, encoders, and schedulers. The implementation introduces new logic for handling layered images, including RGBA support, shape calculations, and multi-image output.

I've identified several critical issues that need to be addressed. These include leftover debugging code that breaks functionality, a potential AttributeError due to an incorrect conditional check, and a TypeError that will occur for non-layered outputs. Additionally, there are opportunities for refactoring to improve code quality by reducing code duplication and improving naming conventions. Please see the detailed comments for specific suggestions.

Comment on lines +161 to +163
text = [
"A charming anime character with short, light blue hair adorned with white flowers and a purple ribbon stands gracefully. She wears a detailed maid outfit featuring a white blouse with ruffled cuffs and a black apron, accessorized with a bow at the neckline. Her hands are clasped together in front of her, and she gazes slightly downward with a gentle expression. The background is a soft, light blue gradient, giving the scene a serene and ethereal atmosphere."
]

critical

This hardcoded text overwrites the generated image caption from get_image_caption on the preceding line. This appears to be leftover debugging code and will prevent the model from using the correct, dynamically generated prompt. Please remove these lines.
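As an illustration only, the intended flow after removing the debug lines might look like the sketch below; get_image_caption is the method named in this comment, but its signature and the surrounding helper are assumptions rather than the pipeline's actual code.

```python
# Hypothetical reconstruction of the intended control flow.
def build_layered_prompt(encoder, img, lang="en"):
    # Use the dynamically generated caption instead of a hardcoded debug string.
    caption = encoder.get_image_caption(img, lang=lang)  # signature assumed
    return [caption]  # previously overwritten by the hardcoded list
```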

Comment on lines 334 to 342
if len(images[0]) > 1:
    image_prefix = f"{input_info.save_result_path}".split(".")[0]
    for idx, image in enumerate(images[0]):
        image.save(f"{image_prefix}_{idx}.png")
        logger.info(f"Image saved: {image_prefix}_{idx}.png")
else:
    image = images[0]
    image.save(f"{input_info.save_result_path}")
    logger.info(f"Image saved: {input_info.save_result_path}")

critical

The check len(images[0]) > 1 will raise a TypeError for non-layered outputs because images[0] will be a PIL Image object, which does not support len(). You should use self.is_layered to distinguish between layered and non-layered outputs to avoid this error.

Suggested change
if len(images[0]) > 1:
    image_prefix = f"{input_info.save_result_path}".split(".")[0]
    for idx, image in enumerate(images[0]):
        image.save(f"{image_prefix}_{idx}.png")
        logger.info(f"Image saved: {image_prefix}_{idx}.png")
else:
    image = images[0]
    image.save(f"{input_info.save_result_path}")
    logger.info(f"Image saved: {input_info.save_result_path}")
if self.is_layered:
    image_prefix = f"{input_info.save_result_path}".split(".")[0]
    for idx, image in enumerate(images[0]):
        image.save(f"{image_prefix}_{idx}.png")
        logger.info(f"Image saved: {image_prefix}_{idx}.png")
else:
    image = images[0]
    image.save(f"{input_info.save_result_path}")
    logger.info(f"Image saved: {input_info.save_result_path}")

self.scheduler_config.get("base_shift", 0.5),
self.scheduler_config.get("max_shift", 1.15),
)
if self.layers:

critical

This will raise an AttributeError if self.is_layered is False, because self.layers is only defined within the if self.is_layered: block in the __init__ method. You should check self.is_layered instead of self.layers to avoid this bug.

Suggested change
if self.layers:
if self.is_layered:

Comment on lines +286 to +399
class QwenEmbedLayer3DRope(nn.Module):
    def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        pos_index = torch.arange(4096)
        neg_index = torch.arange(4096).flip(0) * -1 - 1
        self.pos_freqs = torch.cat(
            [
                self.rope_params(pos_index, self.axes_dim[0], self.theta),
                self.rope_params(pos_index, self.axes_dim[1], self.theta),
                self.rope_params(pos_index, self.axes_dim[2], self.theta),
            ],
            dim=1,
        )
        self.neg_freqs = torch.cat(
            [
                self.rope_params(neg_index, self.axes_dim[0], self.theta),
                self.rope_params(neg_index, self.axes_dim[1], self.theta),
                self.rope_params(neg_index, self.axes_dim[2], self.theta),
            ],
            dim=1,
        )

        self.scale_rope = scale_rope

    def rope_params(self, index, dim, theta=10000):
        """
        Args:
            index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
        """
        assert dim % 2 == 0
        freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
        freqs = torch.polar(torch.ones_like(freqs), freqs)
        return freqs

    def forward(self, video_fhw, txt_seq_lens, device):
        """
        Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
        txt_length: [bs] a list of 1 integers representing the length of the text
        """
        if self.pos_freqs.device != device:
            self.pos_freqs = self.pos_freqs.to(device)
            self.neg_freqs = self.neg_freqs.to(device)

        if isinstance(video_fhw, list):
            video_fhw = video_fhw[0]
        if not isinstance(video_fhw, list):
            video_fhw = [video_fhw]

        vid_freqs = []
        max_vid_index = 0
        layer_num = len(video_fhw) - 1
        for idx, fhw in enumerate(video_fhw):
            frame, height, width = fhw
            if idx != layer_num:
                video_freq = self._compute_video_freqs(frame, height, width, idx)
            else:
                ### For the condition image, we set the layer index to -1
                video_freq = self._compute_condition_freqs(frame, height, width)
            video_freq = video_freq.to(device)
            vid_freqs.append(video_freq)

            if self.scale_rope:
                max_vid_index = max(height // 2, width // 2, max_vid_index)
            else:
                max_vid_index = max(height, width, max_vid_index)

        max_vid_index = max(max_vid_index, layer_num)

        max_len = txt_seq_lens
        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
        vid_freqs = torch.cat(vid_freqs, dim=0)

        return vid_freqs, txt_freqs

    @functools.lru_cache(maxsize=None)
    def _compute_video_freqs(self, frame, height, width, idx=0):
        seq_lens = frame * height * width
        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

        freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
        if self.scale_rope:
            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
        else:
            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
        return freqs.clone().contiguous()

    @functools.lru_cache(maxsize=None)
    def _compute_condition_freqs(self, frame, height, width):
        seq_lens = frame * height * width
        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

        freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
        if self.scale_rope:
            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
        else:
            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
        return freqs.clone().contiguous()


high

The new class QwenEmbedLayer3DRope contains a significant amount of code duplicated from the existing QwenEmbedRope class (e.g., __init__, rope_params). To improve maintainability and reduce redundancy, consider refactoring QwenEmbedLayer3DRope to inherit from QwenEmbedRope and only override the methods that have different implementations, such as forward.
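A structural sketch of that refactor is shown below. It assumes QwenEmbedRope is defined in (or importable from) the same scheduler.py and exposes the same __init__(theta, axes_dim, scale_rope) and rope_params seen in the code above; the ellipsis bodies stand in for the layer-aware logic quoted in this PR.

```python
import functools


class QwenEmbedLayer3DRope(QwenEmbedRope):  # QwenEmbedRope assumed to be in scope from scheduler.py
    """Layer-aware 3D RoPE: __init__ and rope_params are inherited unchanged."""

    def forward(self, video_fhw, txt_seq_lens, device):
        # Only the layer-indexed frequency handling differs from the parent,
        # so the forward body from this PR would move here as-is.
        ...

    @functools.lru_cache(maxsize=None)
    def _compute_video_freqs(self, frame, height, width, idx=0):
        ...  # frame-axis frequencies offset by the layer index, as quoted above

    @functools.lru_cache(maxsize=None)
    def _compute_condition_freqs(self, frame, height, width):
        ...  # condition image mapped to the last negative frame frequency
```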

Comment on lines +101 to 111
def _pack_latents(latents, batchsize, num_channels_latents, height, width, layers=None):
    if not layers:
        latents = latents.view(batchsize, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batchsize, (height // 2) * (width // 2), num_channels_latents * 4)
    else:
        latents = latents.permute(0, 2, 1, 3, 4)
        latents = latents.view(batchsize, layers, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 1, 3, 5, 2, 4, 6)
        latents = latents.reshape(batchsize, layers * (height // 2) * (width // 2), num_channels_latents * 4)
    return latents

high

The _pack_latents static method is also defined in lightx2v/models/schedulers/qwen_image/scheduler.py. Although they have slightly different implementations to handle different input tensor shapes, this duplication can be confusing and hard to maintain. Consider moving them to a shared utility file and giving them more descriptive names to clarify their purpose, for example pack_latents_from_bcfhw and pack_latents_from_bfchw.
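A rough sketch of such a shared helper module is shown below. The function names come from this suggestion, the pack_latents_from_bcfhw body mirrors the layered branch quoted above, and the pack_latents_from_bfchw variant is an assumption about the other copy, which is not quoted here.

```python
# Hypothetical shared module (e.g. a latent_packing.py utility); not part of this PR.
import torch


def pack_latents_from_bcfhw(latents: torch.Tensor, batch_size: int, num_channels_latents: int, height: int, width: int, layers: int) -> torch.Tensor:
    """Pack (B, C, F, H, W) layered latents into (B, F * H/2 * W/2, C * 4) patch tokens."""
    latents = latents.permute(0, 2, 1, 3, 4)  # -> (B, F, C, H, W)
    latents = latents.view(batch_size, layers, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 1, 3, 5, 2, 4, 6)
    return latents.reshape(batch_size, layers * (height // 2) * (width // 2), num_channels_latents * 4)


def pack_latents_from_bfchw(latents: torch.Tensor, batch_size: int, num_channels_latents: int, height: int, width: int, layers: int) -> torch.Tensor:
    """Same packing for latents already laid out as (B, F, C, H, W)."""
    latents = latents.view(batch_size, layers, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 1, 3, 5, 2, 4, 6)
    return latents.reshape(batch_size, layers * (height // 2) * (width // 2), num_channels_latents * 4)
```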

embed0 = weights.time_text_embed_timestep_embedder_linear_2.apply(embed0)

if self.use_additional_t_cond:
    is_rgb = torch.tensor([0] * 1).to(device=AI_DEVICE, dtype=torch.long)

medium

The expression [0] * 1 is equivalent to [0]. Using torch.tensor([0], ...) or torch.zeros(1, ...) is more idiomatic and clearer for creating a tensor with a single zero.

Suggested change
is_rgb = torch.tensor([0] * 1).to(device=AI_DEVICE, dtype=torch.long)
is_rgb = torch.tensor([0], device=AI_DEVICE, dtype=torch.long)

Comment on lines +120 to +123
if self.config.get("layered", False):
    img_ori = Image.open(img_path).convert("RGBA")
else:
    img_ori = Image.open(img_path).convert("RGB")

medium

For consistency, it's better to use the self.is_layered attribute, which is already initialized in __init__, instead of calling self.config.get("layered", False) again.

Suggested change
if self.config.get("layered", False):
    img_ori = Image.open(img_path).convert("RGBA")
else:
    img_ori = Image.open(img_path).convert("RGB")
if self.is_layered:
    img_ori = Image.open(img_path).convert("RGBA")
else:
    img_ori = Image.open(img_path).convert("RGB")

    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
def _pack_latents(latents, batchsize, num_channels_latents, height, width, layers=None):

medium

The parameter name batchsize does not follow the PEP 8 snake_case convention for function and variable names. Please rename it to batch_size for consistency.

Suggested change
def _pack_latents(latents, batchsize, num_channels_latents, height, width, layers=None):
def _pack_latents(latents, batch_size, num_channels_latents, height, width, layers=None):

images = self.image_processor.postprocess(images, output_type="pil")
if self.is_layered:
    b, c, f, h, w = latents.shape
    latents = latents[:, :, 1:] # remove the first frame as it is the orgin input

medium

There is a typo in the comment. 'orgin' should be 'origin'.

Suggested change
latents = latents[:, :, 1:] # remove the first frame as it is the orgin input
latents = latents[:, :, 1:] # remove the first frame as it is the origin input

    latents = latents.view(batchsize, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    latents = latents.reshape(batchsize, (height // 2) * (width // 2), num_channels_latents * 4)
def _pack_latents(latents, batchsize, num_channels_latents, height, width, layers=None):

medium

The parameter name batchsize does not follow the PEP 8 snake_case convention for function and variable names. Please rename it to batch_size for consistency.

Suggested change
def _pack_latents(latents, batchsize, num_channels_latents, height, width, layers=None):
def _pack_latents(latents, batch_size, num_channels_latents, height, width, layers=None):

@helloyongyang helloyongyang merged commit b4c7f5f into main Jan 5, 2026
2 checks passed
@helloyongyang helloyongyang deleted the layered branch January 5, 2026 08:03