
feat: NPU operator adaptations for MindIE-SD #243

Open
Chitandaaaaa wants to merge 9 commits into modelscope:v1 from Chitandaaaaa:feat/npu-adaptations

Conversation

@Chitandaaaaa

Summary

Add 5 NPU operator adaptations using MindIE-SD fused kernels for Huawei Ascend NPU:

| Module | NPU Operator | Fallback |
| --- | --- | --- |
| FastGELU | `torch_npu.npu_fast_gelu` | `F.gelu(approximate="tanh")` |
| RMSNorm | `torch_npu.npu_rms_norm` | Diffusers `RMSNorm` |
| Attention | `mindiesd.attention_forward` (FIA) | existing SDPA |
| RoPE | `mindiesd.rotary_position_embedding` | original complex/real impl |
| AdaLayerNorm | `mindiesd.layernorm_scale_shift` | manual `norm * (1 + scale) + shift` |

Each operator takes the NPU fused path when is_npu_available() returns True and falls back to the original implementation otherwise; non-NPU environments are completely unaffected.
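
A minimal sketch of this dispatch pattern, using FastGELU as the example (the detection shown here is simplified; the PR's is_npu_available() in import_utils.py also probes for mindiesd):

```python
# Sketch of the per-operator NPU dispatch described above. Detection is
# simplified to a torch_npu probe; the PR's helper also checks mindiesd.
import torch
import torch.nn.functional as F

def is_npu_available() -> bool:
    try:
        import torch_npu  # noqa: F401  # registers the "npu" device with torch
        return torch.npu.is_available()
    except ImportError:
        return False

def fast_gelu(x: torch.Tensor) -> torch.Tensor:
    if is_npu_available():
        import torch_npu
        return torch_npu.npu_fast_gelu(x)  # fused NPU kernel
    return F.gelu(x, approximate="tanh")   # original fallback path
```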

Changes

  • layers/norm.py: RMSNorm + AdaLayerNorm NPU wrappers
  • layers/mlp.py: FastGELUMLP replacing FeedForward (checkpoint-key compatible)
  • backends/mindie_attn.py: MINDIE attention backend via existing Backend registry
  • backends/abstract.py: Add AttentionType.MINDIE enum
  • selector.py: Auto-switch to MINDIE when NPU detected
  • transformer_qwenimage.py: RoPE NPU paths + AdaLayerNorm integration
  • import_utils.py: is_npu_available() with mindiesd + manual fallback detection
  • test_adalayernorm.py: Unit tests for 2D/3D/mixed dims, equivalence, batch sizes

Additional fixes

  • Fix cos/sin broadcast dims in the RoPE use_real=True path ([None, None] -> [None, :, None, :]); see the shape sketch after this list
  • Cache RMSNorm fallback instance to share weight tensor with NPU path
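
A shape-only illustration of the cos/sin broadcast fix (tensor names are illustrative; the shapes follow the commit messages below):

```python
# Why [None, None] is wrong for cos/sin in the use_real=True path:
# cos has shape [S, D] while x has shape [B, S, H, D].
import torch

B, S, H, D = 2, 16, 8, 64
x = torch.randn(B, S, H, D)
cos = torch.randn(S, D)

bad = cos[None, None]         # [1, 1, S, D]: aligns S against the head
                              # axis H, so x * bad fails whenever S != H
good = cos[None, :, None, :]  # [1, S, 1, D]: broadcasts over batch and
                              # heads, keeping cos per sequence position
assert (x * good).shape == (B, S, H, D)
```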

hammer added 6 commits April 22, 2026 16:50
- Add FastGELUMLP class in layers/mlp.py with npu_fast_gelu on NPU,
  fallback to F.gelu on non-NPU devices
- Add is_npu_available() to utils/import_utils.py for NPU detection
- Replace FeedForward with FastGELUMLP in transformer_qwenimage.py
- Add RMSNorm class in layers/norm.py with npu_rms_norm on NPU,
  fallback to DiffusersRMSNorm on non-NPU devices
- Replace diffusers RMSNorm import with diffsynth_engine.layers.norm
- Add AttentionType.MINDIE to abstract.py
- Add MindieAttentionBackend and MindieAttentionImpl in mindie_attn.py
  using mindiesd.layers.flash_attn.attention_forward
- Register MINDIE backend in selector.py, auto-switch when NPU available
…4 of 4)

- Add is_npu_available import
- use_real=True: NPU path uses mindiesd rotary_position_embedding
  with rotated_mode mapping (rotated_half/rotated_interleaved)
  Also fixes cos/sin broadcast bug: [None, None] -> [None, :, None, :]
- use_real=False: NPU path uses rotated_complex mode to handle
  dimension mismatch between freqs_cis [S, D//2] and x [B, S, H, D]
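
For reference, a hedged sketch of the standard complex-multiplication RoPE that the use_real=False fallback corresponds to, with freqs_cis of shape [S, D//2] applied to x of shape [B, S, H, D] (function name is illustrative, not the PR's code):

```python
# Standard complex-number RoPE matching the freqs_cis [S, D//2] vs
# x [B, S, H, D] layout described above.
import torch

def apply_rope_complex(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    B, S, H, D = x.shape
    # Pair adjacent channels into complex numbers: [B, S, H, D//2]
    x_c = torch.view_as_complex(x.float().reshape(B, S, H, D // 2, 2))
    # Broadcast freqs_cis over batch and heads: [1, S, 1, D//2]
    x_rot = x_c * freqs_cis[None, :, None, :]
    # Back to real channel pairs and flatten: [B, S, H, D]
    return torch.view_as_real(x_rot).flatten(3).type_as(x)
```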
- Add AdaLayerNorm class in layers/norm.py with NPU operator
  support via MindIE-SD layernorm_scale_shift
- Replace 4 nn.LayerNorm instances in QwenImageTransformerBlock
  with AdaLayerNorm
- Adjust forward method to use one-step AdaLayerNorm instead
  of two-step norm + _modulate
- Add unit tests for AdaLayerNorm
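
A minimal sketch of the one-step fallback math this commit describes (LayerNorm without affine parameters, then norm * (1 + scale) + shift); on NPU the PR fuses this via mindiesd.layernorm_scale_shift. Class and argument names here are illustrative:

```python
# Fallback math for AdaLayerNorm: out = LN(x) * (1 + scale) + shift.
# Illustrative; the PR's version lives in layers/norm.py and dispatches
# to mindiesd.layernorm_scale_shift on NPU.
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaLayerNormSketch(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim, self.eps = dim, eps

    def forward(self, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor):
        # x: [B, S, D]; scale, shift: [B, D], broadcast as [B, 1, D]
        normed = F.layer_norm(x, (self.dim,), eps=self.eps)  # no affine weights
        return normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
```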
- norm.py: store DiffusersRMSNorm as self._fallback so the fallback path
  reuses the same weight tensor (shared via register_parameter), fixing
  checkpoint weight loss on non-NPU devices; see the sketch after this commit.
- transformer_qwenimage.py: remove DEBUG_ATTN print block left from
  attention output debugging.
- transformer_qwenimage.py: add NOTE on _modulate explaining it is
  preserved for future zero_cond_t=True conditional modulation path.
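
A sketch of the shared-weight arrangement the norm.py bullet describes (hypothetical structure; npu_rms_norm's exact return convention is an assumption here):

```python
# One weight tensor shared between the NPU fused path and the
# DiffusersRMSNorm fallback, so checkpoint keys load identically on both.
import torch
import torch.nn as nn
from diffusers.models.normalization import RMSNorm as DiffusersRMSNorm

class RMSNormSketch(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self._fallback = DiffusersRMSNorm(dim, eps=eps, elementwise_affine=True)
        # Expose the fallback's weight under this module's own key so both
        # paths read, and checkpoint-load, the exact same tensor.
        self.register_parameter("weight", self._fallback.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        try:
            import torch_npu
            # Assumption: npu_rms_norm returns (output, rstd).
            return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.eps)[0]
        except ImportError:
            return self._fallback(x)
```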

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces NPU-optimized components to the DiffSynth engine, including a new MINDIE attention backend and optimized implementations for MLP, RMSNorm, and AdaLayerNorm. The QwenImage transformer has been updated to leverage these optimizations along with a fused RoPE operator. However, several issues were identified: the updated forward pass in the QwenImage transformer breaks the zero_cond_t functionality by ignoring modulate_index, and a dimension mismatch occurs when applying modulation gates to 3D tensors. Additionally, the attention backend selection logic should be adjusted to respect user-specified backends even when an NPU is present.

Comment on lines +664 to +667

```diff
+ img_modulated = self.img_norm1(hidden_states, img_scale1, img_shift1)

- # Process text stream - norm1 + modulation
- txt_normed = self.txt_norm1(encoder_hidden_states)
- txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
+ # Process text stream - norm1 + modulation (AdaLayerNorm)
+ txt_modulated = self.txt_norm1(encoder_hidden_states, txt_scale1, txt_shift1)
```

Severity: high

The updated forward pass ignores modulate_index, which breaks the zero_cond_t=True functionality. When zero_cond_t is enabled, temb is doubled and modulate_index is used to select the correct modulation parameters per token.

By using AdaLayerNorm directly with chunked parameters, the code will crash due to a batch size mismatch (img_mod1 has 2*B while hidden_states has B) and bypass the conditional logic. As noted in the docstring for _modulate (line 599), the code should switch back to using _modulate when modulate_index is provided.

Comment on lines 686 to 687

```python
hidden_states = hidden_states + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
```

Severity: high

The gates (e.g., img_gate1, txt_gate1) are 2D tensors of shape [B, D]. Multiplying them directly with the 3D attention outputs [B, S, D] will raise a RuntimeError because the dimensions are not broadcastable. They must be unsqueezed to [B, 1, D]. This also applies to img_gate2 and txt_gate2 in lines 692 and 697.

Suggested change

```diff
- hidden_states = hidden_states + img_gate1 * img_attn_output
- encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
+ hidden_states = hidden_states + img_gate1.unsqueeze(1) * img_attn_output
+ encoder_hidden_states = encoder_hidden_states + txt_gate1.unsqueeze(1) * txt_attn_output
```

Comment on lines +38 to +43

```python
if attn_type is None:
    attn_type = AttentionType.SDPA

# NPU auto-switch: use MINDIE when NPU is available
if is_npu_available():
    attn_type = AttentionType.MINDIE
```

Severity: medium

The current implementation unconditionally overrides the attn_type to MINDIE whenever an NPU is detected, even if the user explicitly requested a different backend (e.g., SDPA). This limits flexibility for debugging or benchmarking. The auto-switch should only apply when attn_type is not specified.

Suggested change

```diff
- if attn_type is None:
-     attn_type = AttentionType.SDPA
- # NPU auto-switch: use MINDIE when NPU is available
- if is_npu_available():
-     attn_type = AttentionType.MINDIE
+ if attn_type is None:
+     attn_type = AttentionType.MINDIE if is_npu_available() else AttentionType.SDPA
```

hammer added 3 commits May 6, 2026 14:52
Chunk produces gates as [B, dim] (2D), but they multiply with
[B, S, dim] attention/MLP outputs. PyTorch broadcast rules require
matching trailing dimensions — [B, dim] * [B, S, dim] fails when
B > 1 because B != S and neither is 1. Add .unsqueeze(1) at all
4 gate-multiply sites to restore the [B, 1, dim] shape that
_modulate previously guaranteed.
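
A quick demonstration of the failure mode and the fix this commit describes (values are arbitrary):

```python
# [B, dim] * [B, S, dim] fails: broadcasting right-aligns shapes, so B is
# compared against S; unsqueeze(1) restores the [B, 1, dim] gate shape.
import torch

B, S, D = 2, 16, 8
gate, out = torch.randn(B, D), torch.randn(B, S, D)

try:
    _ = gate * out                 # RuntimeError: size 2 vs 16 on dim 1
except RuntimeError as e:
    print("broadcast failed:", e)

fixed = gate.unsqueeze(1) * out    # [B, 1, D] * [B, S, D] -> [B, S, D]
assert fixed.shape == (B, S, D)
```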
…crash

When zero_cond_t=True, temb has batch 2*B (cond+uncond CFG). The old code
computed img_mod_params from the pre-chunk temb, producing [2*B, 6*dim]
scale/shift/gate that crash against [B, S, dim] hidden_states in AdaLayerNorm.

Move img_mod_params after the zero_cond_t chunk so img and txt both use the
cond half (B). Per-token CFG via modulate_index is unsupported with
AdaLayerNorm; _modulate is preserved for when full support is needed.
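
An illustrative shape walk-through of this fix (the chunk order and projection are assumptions based on the commit message, not the PR's exact code):

```python
# With zero_cond_t=True, temb is [2*B, dim] (cond + uncond). Deriving the
# 6*dim modulation params before chunking yields [2*B, 6*dim] tensors that
# cannot modulate [B, S, dim] hidden states; chunk first, then project.
import torch

B, S, dim = 2, 16, 8
temb = torch.randn(2 * B, dim)
proj = torch.nn.Linear(dim, 6 * dim)

cond_temb, _ = temb.chunk(2, dim=0)   # assumed order: cond half first
img_mod_params = proj(cond_temb)      # [B, 6*dim]: batch now matches
scale, shift, gate, *_ = img_mod_params.chunk(6, dim=-1)  # each [B, dim]
```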
…_type

Previously, NPU detection unconditionally overrode any attn_type to MINDIE,
even when the user explicitly chose SDPA or FA2. Now auto-detect only fires
when attn_type is None (user didn't choose). Three changes must work together:

- selector.py: auto-detect only on attn_type is None
- configs/base.py: default from SDPA to None (None = "not chosen")
- args.py: CLI default from "sdpa" to None, parse handles None
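
Condensed, the resulting selection logic looks like this (a sketch; the enum values follow the PR's description, with plain strings standing in for AttentionType members):

```python
# "None means the user didn't choose": auto-detect fires only then, so an
# explicit SDPA/FA2 request survives on NPU machines.
def resolve_attn_type(attn_type, npu_available: bool):
    if attn_type is None:
        return "MINDIE" if npu_available else "SDPA"
    return attn_type  # respect the user's explicit backend choice
```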
