Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,10 @@ def decorator(func):
def get_active_backend(cls):
return cls._active_backend, cls._backends[cls._active_backend]

@classmethod
def set_active_backend(cls, backend: str):
    # Record `backend` as the registry-wide active backend; it is read back
    # via `get_active_backend()`. No validation is performed here — callers
    # (e.g. the `attention_backend` context manager) are expected to pass a
    # backend name already known to the registry.
    cls._active_backend = backend

@classmethod
def list_backends(cls):
    # Return the names of all registered attention backends as a new list,
    # so callers can mutate the result without touching registry state.
    return list(cls._backends.keys())
Expand Down Expand Up @@ -294,12 +298,12 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
_maybe_download_kernel_for_backend(backend)

old_backend = _AttentionBackendRegistry._active_backend
_AttentionBackendRegistry._active_backend = backend
_AttentionBackendRegistry.set_active_backend(backend)

try:
yield
finally:
_AttentionBackendRegistry._active_backend = old_backend
_AttentionBackendRegistry.set_active_backend(old_backend)


def dispatch_attention_fn(
Expand Down Expand Up @@ -348,6 +352,18 @@ def dispatch_attention_fn(
check(**kwargs)

kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}

if "_parallel_config" in kwargs and kwargs["_parallel_config"] is not None:
attention_backend = AttentionBackendName(backend_name)
if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend):
compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
raise ValueError(
f"Context parallelism is enabled but the active attention backend "
f"'{attention_backend.value}' does not support context parallelism. "
f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
f"calling `model.enable_parallelism()`."
)

return backend_fn(**kwargs)


Expand Down
6 changes: 5 additions & 1 deletion src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,7 @@ def set_attention_backend(self, backend: str) -> None:
from .attention import AttentionModuleMixin
from .attention_dispatch import (
AttentionBackendName,
_AttentionBackendRegistry,
_check_attention_backend_requirements,
_maybe_download_kernel_for_backend,
)
Expand Down Expand Up @@ -629,6 +630,9 @@ def set_attention_backend(self, backend: str) -> None:
continue
processor._attention_backend = backend

# Important to set the active backend so that it propagates gracefully throughout.
_AttentionBackendRegistry.set_active_backend(backend)

def reset_attention_backend(self) -> None:
"""
Resets the attention backend for the model. Following calls to `forward` will use the environment default, if
Expand Down Expand Up @@ -1541,7 +1545,7 @@ def enable_parallelism(
f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
f"is using backend '{attention_backend.value}' which does not support context parallelism. "
f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
f"calling `enable_parallelism()`."
f"calling `model.enable_parallelism()`."
)

# All modules use the same attention processor and backend. We don't need to
Expand Down
Loading