8 changes: 5 additions & 3 deletions diffsynth_engine/args.py
@@ -17,8 +17,10 @@ def _parse_tuple(value: str) -> Tuple[int, int] | int:
        raise ValueError(f"Cannot parse tuple: {value}, format should be '256,256' or '256'")


-def _parse_attention_type(attn_type_str: str) -> AttentionType:
+def _parse_attention_type(attn_type_str: str | None) -> AttentionType | None:
    """Convert string to AttentionType enum"""
+    if attn_type_str is None:
+        return None
    return AttentionType[attn_type_str.upper()]


@@ -106,9 +108,9 @@ def parse_cli_args() -> Dict[str, Any]:
    attn_group.add_argument(
        "--attn-type",
        type=str,
-        default="sdpa",
+        default=None,
        choices=attn_type_choices,
-        help="Attention type (default: sdpa)",
+        help="Attention type (default: auto, SDPA on GPU, MINDIE on NPU)",
    )
    attn_group.add_argument(
        "--sparge-topk",
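With this change an omitted --attn-type stays None so the device-aware selector can decide later. A minimal sketch of the parser-side behaviour, using only names from this diff (the calls are illustrative, not an existing test):

assert _parse_attention_type(None) is None
assert _parse_attention_type("sdpa") is AttentionType.SDPA
assert _parse_attention_type("mindie") is AttentionType.MINDIE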
2 changes: 1 addition & 1 deletion diffsynth_engine/configs/base.py
@@ -36,7 +36,7 @@ class PipelineConfig:
    vae_tile_stride: int | Tuple[int, int] = (192, 192)

    # attention
-    attn_type: AttentionType = AttentionType.SDPA
+    attn_type: AttentionType | None = None  # None = auto-detect
    attn_params: Optional[AttentionParams] = None

    # parallelism
1 change: 1 addition & 0 deletions diffsynth_engine/layers/attention/backends/abstract.py
@@ -22,6 +22,7 @@ class AttentionType(enum.Enum):
    SAGE2 = enum.auto()
    SAGE3 = enum.auto()
    SPARGE = enum.auto()
+    MINDIE = enum.auto()

    def __str__(self) -> str:
        return self.name.lower()
79 changes: 79 additions & 0 deletions diffsynth_engine/layers/attention/backends/mindie_attn.py
@@ -0,0 +1,79 @@
import torch
from diffsynth_engine.layers.attention.backends.abstract import (
    AttentionBackend,
    AttentionImpl,
    AttentionMetadata,
    AttentionType,
)
from diffsynth_engine.utils.import_utils import is_npu_available


class MindieAttentionBackend(AttentionBackend):
    @staticmethod
    def check_availability() -> None:
        if not is_npu_available():
            raise RuntimeError("NPU is not available, cannot use MINDIE attention backend")

    @staticmethod
    def get_type() -> AttentionType:
        return AttentionType.MINDIE

    @staticmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        return MindieAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return AttentionMetadata

    @staticmethod
    def get_builder_cls() -> type | None:
        # this backend does not provide a metadata builder
        return None

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        # no specific head-size restriction is declared for this backend
        return []


class MindieAttentionImpl(AttentionImpl):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        softmax_scale: float | None = None,
        causal: bool = False,
        num_kv_heads: int | None = None,
        **extra_impl_args,
    ) -> None:
        if num_kv_heads is None:
            num_kv_heads = num_heads
        self.num_kv_groups = num_heads // num_kv_heads
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.num_heads = num_heads
        self.head_size = head_size

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
        attn_metadata=None,
    ) -> torch.Tensor:
        from mindiesd.layers.flash_attn.attention_forward import attention_forward

        scale = self.softmax_scale
        if scale is None:
            scale = self.head_size ** -0.5

        out = attention_forward(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            scale=scale,
            fused=True,
            head_first=False,
        )
        return out
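A rough usage sketch of the new implementation on an NPU host. The [batch, seq_len, num_heads, head_size] tensor layout is an assumption inferred from head_first=False; the real call sites are the attention layers that instantiate AttentionImpl subclasses:

import torch

impl = MindieAttentionImpl(num_heads=16, head_size=64)
q = torch.randn(1, 1024, 16, 64, device="npu", dtype=torch.bfloat16)
k = torch.randn(1, 1024, 16, 64, device="npu", dtype=torch.bfloat16)
v = torch.randn(1, 1024, 16, 64, device="npu", dtype=torch.bfloat16)
out = impl.forward(q, k, v)  # delegates to mindiesd attention_forward with scale=64 ** -0.5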
9 changes: 6 additions & 3 deletions diffsynth_engine/layers/attention/selector.py
@@ -1,7 +1,7 @@
from functools import cache

from diffsynth_engine.layers.attention.backends.abstract import AttentionBackend, AttentionType
-from diffsynth_engine.utils.import_utils import LazyImport
+from diffsynth_engine.utils.import_utils import LazyImport, is_npu_available

AiterBackend = LazyImport("diffsynth_engine.layers.attention.backends.aiter", "AiterBackend")
AiterFP8Backend = LazyImport("diffsynth_engine.layers.attention.backends.aiter", "AiterFP8Backend")
@@ -15,6 +15,7 @@
SageAttention3Backend = LazyImport("diffsynth_engine.layers.attention.backends.sage_attn_3", "SageAttention3Backend")
SDPABackend = LazyImport("diffsynth_engine.layers.attention.backends.sdpa", "SDPABackend")
SpargeAttentionBackend = LazyImport("diffsynth_engine.layers.attention.backends.sparge_attn", "SpargeAttentionBackend")
+MindieAttentionBackend = LazyImport("diffsynth_engine.layers.attention.backends.mindie_attn", "MindieAttentionBackend")

_attention_backends = {
    AttentionType.AITER: AiterBackend,
@@ -27,14 +28,16 @@
    AttentionType.SAGE3: SageAttention3Backend,
    AttentionType.SDPA: SDPABackend,
    AttentionType.SPARGE: SpargeAttentionBackend,
+    AttentionType.MINDIE: MindieAttentionBackend,
}


@cache
def get_attn_backend(head_size: int, attn_type: AttentionType | None = None) -> type["AttentionBackend"]:
-    # use SDPA as default
    if attn_type is None:
-        attn_type = AttentionType.SDPA
+        # Auto-detect: NPU → MINDIE, otherwise → SDPA
+        attn_type = AttentionType.MINDIE if is_npu_available() else AttentionType.SDPA

    selected_backend = _attention_backends[attn_type]
    selected_backend.check_availability()
    if not selected_backend.supports_head_size(head_size):
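The net effect of the selector change, sketched under the assumption that callers pass attn_type=None whenever the user gave no explicit --attn-type:

backend = get_attn_backend(head_size=64, attn_type=None)
# resolves to MindieAttentionBackend on NPU hosts, SDPABackend elsewhere
backend = get_attn_backend(head_size=64, attn_type=AttentionType.SDPA)
# an explicit attn_type still bypasses auto-detection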
72 changes: 72 additions & 0 deletions diffsynth_engine/layers/mlp.py
@@ -0,0 +1,72 @@
import torch.nn as nn
import torch.nn.functional as F
from diffsynth_engine.utils.import_utils import is_npu_available

try:
    import torch_npu
except ImportError:
    torch_npu = None


class _GELUProj(nn.Module):
    """Wrapper to match diffusers FeedForward GELU structure with internal proj.

    This wrapper holds the first Linear layer as .proj to match checkpoint keys.
    """

    def __init__(self, dim, inner_dim):
        super().__init__()
        self.proj = nn.Linear(dim, inner_dim)

    def forward(self, x):
        # NOTE: FastGELUMLP.forward calls self.proj directly and applies the
        # activation itself, so this path only applies the tanh-approximate GELU.
        return F.gelu(x, approximate="tanh")


class FastGELUMLP(nn.Module):
    """MLP with npu_fast_gelu on NPU, fallback to F.gelu on other devices.

    Functionally equivalent to diffusers.models.attention.FeedForward(
        dim=dim, dim_out=dim, activation_fn="gelu-approximate"
    )
    """

    def __init__(self, dim, dim_out=None, mult=4):
        """Initialize MLP.

        Args:
            dim: Input and output dimension
            dim_out: Output dimension, defaults to dim
            mult: inner_dim = dim * mult, defaults to 4
        """
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out or dim

        # Match diffusers FeedForward structure: net[0]=GELU(proj), net[2]=output
        # net[1] is Dropout which is skipped in inference
        self.net = nn.ModuleList([
            _GELUProj(dim, inner_dim),
            nn.Dropout(0.0),
            nn.Linear(inner_dim, dim_out),
        ])

    def forward(self, hidden_states):
        """Forward pass.

        Args:
            hidden_states: Input tensor, shape [B, S, dim]

        Returns:
            Output tensor, shape [B, S, dim_out]
        """
        # net[0] = _GELUProj with internal proj (dim → inner_dim)
        hidden_states = self.net[0].proj(hidden_states)

        if is_npu_available() and torch_npu is not None:
            hidden_states = torch_npu.npu_fast_gelu(hidden_states)
        else:
            hidden_states = F.gelu(hidden_states, approximate="tanh")

        # net[2] = output Linear (inner_dim → dim_out)
        hidden_states = self.net[2](hidden_states)
        return hidden_states
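Because the module keeps the diffusers FeedForward key layout (net.0.proj.*, net.2.*), it should load existing checkpoints unchanged; a small sketch based only on the structure above:

import torch

mlp = FastGELUMLP(dim=1024, mult=4)
print(sorted(mlp.state_dict().keys()))
# ['net.0.proj.bias', 'net.0.proj.weight', 'net.2.bias', 'net.2.weight']
y = mlp(torch.randn(2, 77, 1024))  # [2, 77, 1024]; npu_fast_gelu is used only when an NPU is present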
78 changes: 78 additions & 0 deletions diffsynth_engine/layers/norm.py
@@ -0,0 +1,78 @@
import torch
import torch.nn as nn
from diffusers.models.normalization import RMSNorm as DiffusersRMSNorm
from diffsynth_engine.utils.import_utils import is_npu_available

try:
    import torch_npu
except ImportError:
    torch_npu = None

try:
    from mindiesd.layers import layernorm_scale_shift
except ImportError:
    layernorm_scale_shift = None


class RMSNorm(nn.Module):
    """NPU-optimized RMSNorm wrapper with fallback to diffusers implementation."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        # Cache the fallback instance so forward() reuses the same weight
        # tensor. register_parameter is reference assignment (no copy), so
        # self.weight and self._fallback.weight share the same storage.
        # When a checkpoint writes to "weight", both paths see the update.
        self._fallback = DiffusersRMSNorm(hidden_size, eps)
        self.register_parameter("weight", self._fallback.weight)

    def forward(self, hidden_states):
        if is_npu_available() and torch_npu is not None:
            return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
        else:
            return self._fallback(hidden_states)


class AdaLayerNorm(nn.Module):
    """NPU-optimized AdaLayerNorm with fallback to original implementation.

    Performs: output = layernorm(x) * (1 + scale) + shift

    Args:
        layernorm: The underlying nn.LayerNorm module (elementwise_affine=False)
    """

    def __init__(self, layernorm: nn.LayerNorm):
        super().__init__()
        self.layernorm = layernorm

    def forward(self, hidden_states: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: Input tensor, shape [B, S, H]
            scale: Scale parameter, shape [B, H] or [B, 1, H]
            shift: Shift parameter, shape [B, H] or [B, 1, H]

        Returns:
            layernorm(x) * (1 + scale) + shift
        """
        if is_npu_available() and layernorm_scale_shift is not None:
            # NPU path: use MindIE-SD fused operator
            return layernorm_scale_shift(
                layernorm=self.layernorm,
                x=hidden_states,
                scale=scale,
                shift=shift,
                fused=True
            )
        else:
            # Fallback: original Python implementation
            normed = self.layernorm(hidden_states)
            # Broadcast [B, H] scale/shift to [B, 1, H] before applying
            if scale.dim() == 2:
                scale = scale.unsqueeze(1)
            if shift.dim() == 2:
                shift = shift.unsqueeze(1)
            return normed * (1 + scale) + shift
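A short sketch of the AdaLayerNorm wrapper around an affine-free LayerNorm, exercising the fallback path above; the zero scale/shift tensors are placeholders for the usual adaLN modulation outputs:

import torch
import torch.nn as nn

ln = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6)
ada = AdaLayerNorm(ln)
x = torch.randn(2, 77, 1024)
scale = torch.zeros(2, 1024)  # [B, H]; broadcast to [B, 1, H] inside forward
shift = torch.zeros(2, 1024)
y = ada(x, scale, shift)  # equals ln(x) when scale and shift are zero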