[feat]: support qwen-image-layered #694
Conversation
Summary of Changes (Gemini Code Assist): This pull request introduces comprehensive support for a "layered" Qwen-Image model, significantly extending its image-to-image generation capabilities. The changes span model configuration, input processing, latent-space manipulation, and output handling, allowing the system to process and generate images with multiple conceptual layers. The feature enables more complex and structured image generation tasks, supported by new configuration options, specialized positional embeddings, and an updated image-processing pipeline.
Code Review
This pull request adds support for qwen-image-layered, which involves changes across configurations, model runners, encoders, and schedulers. The implementation introduces new logic for handling layered images, including RGBA support, shape calculations, and multi-image output.
I've identified several critical issues that need to be addressed. These include leftover debugging code that breaks functionality, a potential AttributeError due to an incorrect conditional check, and a TypeError that will occur for non-layered outputs. Additionally, there are opportunities for refactoring to improve code quality by reducing code duplication and improving naming conventions. Please see the detailed comments for specific suggestions.
```python
text = [
    "A charming anime character with short, light blue hair adorned with white flowers and a purple ribbon stands gracefully. She wears a detailed maid outfit featuring a white blouse with ruffled cuffs and a black apron, accessorized with a bow at the neckline. Her hands are clasped together in front of her, and she gazes slightly downward with a gentle expression. The background is a soft, light blue gradient, giving the scene a serene and ethereal atmosphere."
]
```
```python
if len(images[0]) > 1:
    image_prefix = f"{input_info.save_result_path}".split(".")[0]
    for idx, image in enumerate(images[0]):
        image.save(f"{image_prefix}_{idx}.png")
        logger.info(f"Image saved: {image_prefix}_{idx}.png")
else:
    image = images[0]
    image.save(f"{input_info.save_result_path}")
    logger.info(f"Image saved: {input_info.save_result_path}")
```
The check len(images[0]) > 1 will raise a TypeError for non-layered outputs because images[0] will be a PIL Image object, which does not support len(). You should use self.is_layered to distinguish layered from non-layered outputs and avoid this error.
Suggested change:

```diff
-if len(images[0]) > 1:
+if self.is_layered:
     image_prefix = f"{input_info.save_result_path}".split(".")[0]
     for idx, image in enumerate(images[0]):
         image.save(f"{image_prefix}_{idx}.png")
         logger.info(f"Image saved: {image_prefix}_{idx}.png")
 else:
     image = images[0]
     image.save(f"{input_info.save_result_path}")
     logger.info(f"Image saved: {input_info.save_result_path}")
```
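If self.is_layered is not available in the scope where images are saved, an isinstance check would avoid the TypeError equally well. A minimal sketch under that assumption (not part of the PR):

```python
from PIL import Image

# A bare PIL image has no __len__, so detect layered output by type instead.
if not isinstance(images[0], Image.Image):
    image_prefix = f"{input_info.save_result_path}".split(".")[0]
    for idx, image in enumerate(images[0]):
        image.save(f"{image_prefix}_{idx}.png")
else:
    images[0].save(f"{input_info.save_result_path}")
```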
```python
    self.scheduler_config.get("base_shift", 0.5),
    self.scheduler_config.get("max_shift", 1.15),
)
if self.layers:
```
```python
class QwenEmbedLayer3DRope(nn.Module):
    def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        pos_index = torch.arange(4096)
        neg_index = torch.arange(4096).flip(0) * -1 - 1
        self.pos_freqs = torch.cat(
            [
                self.rope_params(pos_index, self.axes_dim[0], self.theta),
                self.rope_params(pos_index, self.axes_dim[1], self.theta),
                self.rope_params(pos_index, self.axes_dim[2], self.theta),
            ],
            dim=1,
        )
        self.neg_freqs = torch.cat(
            [
                self.rope_params(neg_index, self.axes_dim[0], self.theta),
                self.rope_params(neg_index, self.axes_dim[1], self.theta),
                self.rope_params(neg_index, self.axes_dim[2], self.theta),
            ],
            dim=1,
        )
        self.scale_rope = scale_rope

    def rope_params(self, index, dim, theta=10000):
        """
        Args:
            index: 1D tensor of token position indices, e.g. [0, 1, 2, 3]
        """
        assert dim % 2 == 0
        freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
        freqs = torch.polar(torch.ones_like(freqs), freqs)
        return freqs

    def forward(self, video_fhw, txt_seq_lens, device):
        """
        Args:
            video_fhw: [frame, height, width] — a list of 3 integers giving the shape of the video
                (or a list of such triples, one per layer)
            txt_seq_lens: the length of the text sequence
        """
        if self.pos_freqs.device != device:
            self.pos_freqs = self.pos_freqs.to(device)
            self.neg_freqs = self.neg_freqs.to(device)

        if isinstance(video_fhw, list):
            video_fhw = video_fhw[0]
        if not isinstance(video_fhw, list):
            video_fhw = [video_fhw]

        vid_freqs = []
        max_vid_index = 0
        layer_num = len(video_fhw) - 1
        for idx, fhw in enumerate(video_fhw):
            frame, height, width = fhw
            if idx != layer_num:
                video_freq = self._compute_video_freqs(frame, height, width, idx)
            else:
                # For the condition image, we set the layer index to -1
                video_freq = self._compute_condition_freqs(frame, height, width)
            video_freq = video_freq.to(device)
            vid_freqs.append(video_freq)

            if self.scale_rope:
                max_vid_index = max(height // 2, width // 2, max_vid_index)
            else:
                max_vid_index = max(height, width, max_vid_index)

        max_vid_index = max(max_vid_index, layer_num)

        max_len = txt_seq_lens
        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
        vid_freqs = torch.cat(vid_freqs, dim=0)

        return vid_freqs, txt_freqs

    @functools.lru_cache(maxsize=None)
    def _compute_video_freqs(self, frame, height, width, idx=0):
        seq_lens = frame * height * width
        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

        freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
        if self.scale_rope:
            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
        else:
            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
        return freqs.clone().contiguous()

    @functools.lru_cache(maxsize=None)
    def _compute_condition_freqs(self, frame, height, width):
        seq_lens = frame * height * width
        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

        freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
        if self.scale_rope:
            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
        else:
            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
        return freqs.clone().contiguous()
```
The new class QwenEmbedLayer3DRope contains a significant amount of code duplicated from the existing QwenEmbedRope class (e.g., __init__, rope_params). To improve maintainability and reduce redundancy, consider refactoring QwenEmbedLayer3DRope to inherit from QwenEmbedRope and only override the methods that have different implementations, such as forward.
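A minimal sketch of that refactor, assuming QwenEmbedRope defines the same __init__ and rope_params shown above (only the layer-aware pieces would be redefined):

```python
class QwenEmbedLayer3DRope(QwenEmbedRope):
    # __init__ and rope_params are inherited unchanged from QwenEmbedRope.

    def forward(self, video_fhw, txt_seq_lens, device):
        # Override only the layered forward pass: per-layer frame indices for
        # generated layers, and a fixed -1 frame index for the condition image.
        ...

    @functools.lru_cache(maxsize=None)
    def _compute_condition_freqs(self, frame, height, width):
        # Layer-specific helper with no counterpart in the base class.
        ...
```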
```python
def _pack_latents(latents, batchsize, num_channels_latents, height, width, layers=None):
    if not layers:
        # (B, C, H, W) -> (B, (H/2) * (W/2), C * 4)
        latents = latents.view(batchsize, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batchsize, (height // 2) * (width // 2), num_channels_latents * 4)
    else:
        # (B, C, F, H, W) -> (B, F * (H/2) * (W/2), C * 4)
        latents = latents.permute(0, 2, 1, 3, 4)
        latents = latents.view(batchsize, layers, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 1, 3, 5, 2, 4, 6)
        latents = latents.reshape(batchsize, layers * (height // 2) * (width // 2), num_channels_latents * 4)
    return latents
```
The _pack_latents static method is also defined in lightx2v/models/schedulers/qwen_image/scheduler.py. Although they have slightly different implementations to handle different input tensor shapes, this duplication can be confusing and hard to maintain. Consider moving them to a shared utility file and giving them more descriptive names to clarify their purpose, for example pack_latents_from_bcfhw and pack_latents_from_bfchw.
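A sketch of what that shared utility could look like, using the names suggested above; the module path and exact signatures are illustrative, not taken from the PR:

```python
# e.g. lightx2v/common/latent_packing.py (hypothetical location)
import torch


def pack_latents_from_bfchw(latents, batch_size, num_channels_latents, height, width, layers):
    # (B, F, C, H, W) -> (B, F * (H/2) * (W/2), C * 4) patch tokens
    latents = latents.view(batch_size, layers, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 1, 3, 5, 2, 4, 6)
    return latents.reshape(batch_size, layers * (height // 2) * (width // 2), num_channels_latents * 4)


def pack_latents_from_bcfhw(latents, batch_size, num_channels_latents, height, width, layers):
    # (B, C, F, H, W) -> move the frame axis forward, then reuse the BFCHW path
    return pack_latents_from_bfchw(latents.permute(0, 2, 1, 3, 4), batch_size, num_channels_latents, height, width, layers)
```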
```python
embed0 = weights.time_text_embed_timestep_embedder_linear_2.apply(embed0)

if self.use_additional_t_cond:
    is_rgb = torch.tensor([0] * 1).to(device=AI_DEVICE, dtype=torch.long)
```
The expression [0] * 1 is equivalent to [0]. Using torch.tensor([0], ...) or torch.zeros(1, ...) is more idiomatic and clearer for creating a tensor with a single zero.
Suggested change:

```diff
-is_rgb = torch.tensor([0] * 1).to(device=AI_DEVICE, dtype=torch.long)
+is_rgb = torch.tensor([0], device=AI_DEVICE, dtype=torch.long)
```
```python
if self.config.get("layered", False):
    img_ori = Image.open(img_path).convert("RGBA")
else:
    img_ori = Image.open(img_path).convert("RGB")
```
For consistency, it's better to use the self.is_layered attribute, which is already initialized in __init__, instead of calling self.config.get("layered", False) again.
Suggested change:

```diff
-if self.config.get("layered", False):
+if self.is_layered:
     img_ori = Image.open(img_path).convert("RGBA")
 else:
     img_ori = Image.open(img_path).convert("RGB")
```
```python
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
def _pack_latents(latents, batchsize, num_channels_latents, height, width, layers=None):
```
The parameter name batchsize does not follow the PEP 8 snake_case convention for function and variable names. Please rename it to batch_size for consistency.
Suggested change:

```diff
-def _pack_latents(latents, batchsize, num_channels_latents, height, width, layers=None):
+def _pack_latents(latents, batch_size, num_channels_latents, height, width, layers=None):
```
```python
images = self.image_processor.postprocess(images, output_type="pil")
if self.is_layered:
    b, c, f, h, w = latents.shape
    latents = latents[:, :, 1:]  # remove the first frame as it is the original input
```
```python
latents = latents.view(batchsize, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batchsize, (height // 2) * (width // 2), num_channels_latents * 4)
def _pack_latents(latents, batchsize, num_channels_latents, height, width, layers=None):
```
The parameter name batchsize does not follow the PEP 8 snake_case convention for function and variable names. Please rename it to batch_size for consistency.
Suggested change:

```diff
-def _pack_latents(latents, batchsize, num_channels_latents, height, width, layers=None):
+def _pack_latents(latents, batch_size, num_channels_latents, height, width, layers=None):
```
No description provided.