From c8abb5d7c01ca6a7c0bf82c27c91a326155a5e43 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 12 Dec 2025 15:20:18 +0530 Subject: [PATCH 1/4] gracefully error out when attn-backend x cp combo isn't supported. --- tests/others/test_attention_backends.py | 39 ++++++++++++++++--------- tests/testing_utils.py | 4 +++ 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/tests/others/test_attention_backends.py b/tests/others/test_attention_backends.py index 01f4521c5adc..2cfdde1e32fb 100644 --- a/tests/others/test_attention_backends.py +++ b/tests/others/test_attention_backends.py @@ -11,8 +11,7 @@ pytest tests/others/test_attention_backends.py ``` -Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in -"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128). +Tests were conducted on an H100 with PyTorch 2.9.1 (CUDA 12.9). Tests for aiter backend were conducted and slices for the aiter backend tests collected on a MI355X with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and @@ -24,6 +23,8 @@ import pytest import torch +from ..testing_utils import numpy_cosine_similarity_distance + pytestmark = pytest.mark.skipif( os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough." @@ -36,23 +37,28 @@ FORWARD_CASES = [ ( "flash_hub", - torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16) + torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16), + 1e-4 ), ( "_flash_3_hub", torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16), + 1e-4 ), ( "native", - torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16) - ), + torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16), + 1e-4 + ), ( "_native_cudnn", torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16), + 5e-4 ), ( "aiter", torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16), + 1e-4 ) ] @@ -60,27 +66,32 @@ ( "flash_hub", torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16), - True + True, + 1e-4 ), ( "_flash_3_hub", torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16), True, + 1e-4 ), ( "native", torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16), True, + 1e-4 ), ( "_native_cudnn", torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16), True, + 5e-4, ), ( "aiter", torch.tensor([0.0391, 0.0391, 
0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16), True, + 1e-4 ) ] # fmt: on @@ -104,11 +115,11 @@ def _backend_is_probably_supported(pipe, name: str): return False -def _check_if_slices_match(output, expected_slice): +def _check_if_slices_match(output, expected_slice, expected_diff=1e-4): img = output.images.detach().cpu() generated_slice = img.flatten() generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) - assert torch.allclose(generated_slice, expected_slice, atol=1e-4) + assert numpy_cosine_similarity_distance(generated_slice, expected_slice) < expected_diff @pytest.fixture(scope="session") @@ -126,23 +137,23 @@ def pipe(device): return pipe -@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES]) -def test_forward(pipe, backend_name, expected_slice): +@pytest.mark.parametrize("backend_name,expected_slice,expected_diff", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES]) +def test_forward(pipe, backend_name, expected_slice, expected_diff): out = _backend_is_probably_supported(pipe, backend_name) if isinstance(out, bool): pytest.xfail(f"Backend '{backend_name}' not supported in this environment.") modified_pipe = out[0] out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0)) - _check_if_slices_match(out, expected_slice) + _check_if_slices_match(out, expected_slice, expected_diff) @pytest.mark.parametrize( - "backend_name,expected_slice,error_on_recompile", + "backend_name,expected_slice,error_on_recompile,expected_diff", COMPILE_CASES, ids=[c[0] for c in COMPILE_CASES], ) -def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile): +def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile, expected_diff): if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"): pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.") @@ -160,4 +171,4 @@ def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recom ): out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0)) - _check_if_slices_match(out, expected_slice) + _check_if_slices_match(out, expected_slice, expected_diff) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 4550813259af..996853bbbebd 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -131,6 +131,10 @@ def torch_all_close(a, b, *args, **kwargs): def numpy_cosine_similarity_distance(a, b): + if isinstance(a, torch.Tensor): + a = a.detach().cpu().float().numpy() + if isinstance(b, torch.Tensor): + b = b.detach().cpu().float().numpy() similarity = np.dot(a, b) / (norm(a) * norm(b)) distance = 1.0 - similarity.mean() From 23251d6cf647fb03c40d28d8e09a57650d32e6f6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 12 Dec 2025 15:24:09 +0530 Subject: [PATCH 2/4] Revert "gracefully error out when attn-backend x cp combo isn't supported." This reverts commit c8abb5d7c01ca6a7c0bf82c27c91a326155a5e43. 
--- tests/others/test_attention_backends.py | 39 +++++++++---------------- tests/testing_utils.py | 4 --- 2 files changed, 14 insertions(+), 29 deletions(-) diff --git a/tests/others/test_attention_backends.py b/tests/others/test_attention_backends.py index 2cfdde1e32fb..01f4521c5adc 100644 --- a/tests/others/test_attention_backends.py +++ b/tests/others/test_attention_backends.py @@ -11,7 +11,8 @@ pytest tests/others/test_attention_backends.py ``` -Tests were conducted on an H100 with PyTorch 2.9.1 (CUDA 12.9). +Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in +"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128). Tests for aiter backend were conducted and slices for the aiter backend tests collected on a MI355X with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and @@ -23,8 +24,6 @@ import pytest import torch -from ..testing_utils import numpy_cosine_similarity_distance - pytestmark = pytest.mark.skipif( os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough." @@ -37,28 +36,23 @@ FORWARD_CASES = [ ( "flash_hub", - torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16), - 1e-4 + torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16) ), ( "_flash_3_hub", torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16), - 1e-4 ), ( "native", - torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16), - 1e-4 - ), + torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16) + ), ( "_native_cudnn", torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16), - 5e-4 ), ( "aiter", torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16), - 1e-4 ) ] @@ -66,32 +60,27 @@ ( "flash_hub", torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16), - True, - 1e-4 + True ), ( "_flash_3_hub", torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16), True, - 1e-4 ), ( "native", torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16), True, - 1e-4 ), ( "_native_cudnn", torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16), True, - 5e-4, ), ( "aiter", torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16), True, - 1e-4 ) ] # fmt: on @@ -115,11 +104,11 @@ def 
_backend_is_probably_supported(pipe, name: str): return False -def _check_if_slices_match(output, expected_slice, expected_diff=1e-4): +def _check_if_slices_match(output, expected_slice): img = output.images.detach().cpu() generated_slice = img.flatten() generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) - assert numpy_cosine_similarity_distance(generated_slice, expected_slice) < expected_diff + assert torch.allclose(generated_slice, expected_slice, atol=1e-4) @pytest.fixture(scope="session") @@ -137,23 +126,23 @@ def pipe(device): return pipe -@pytest.mark.parametrize("backend_name,expected_slice,expected_diff", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES]) -def test_forward(pipe, backend_name, expected_slice, expected_diff): +@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES]) +def test_forward(pipe, backend_name, expected_slice): out = _backend_is_probably_supported(pipe, backend_name) if isinstance(out, bool): pytest.xfail(f"Backend '{backend_name}' not supported in this environment.") modified_pipe = out[0] out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0)) - _check_if_slices_match(out, expected_slice, expected_diff) + _check_if_slices_match(out, expected_slice) @pytest.mark.parametrize( - "backend_name,expected_slice,error_on_recompile,expected_diff", + "backend_name,expected_slice,error_on_recompile", COMPILE_CASES, ids=[c[0] for c in COMPILE_CASES], ) -def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile, expected_diff): +def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile): if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"): pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.") @@ -171,4 +160,4 @@ def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recom ): out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0)) - _check_if_slices_match(out, expected_slice, expected_diff) + _check_if_slices_match(out, expected_slice) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 996853bbbebd..4550813259af 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -131,10 +131,6 @@ def torch_all_close(a, b, *args, **kwargs): def numpy_cosine_similarity_distance(a, b): - if isinstance(a, torch.Tensor): - a = a.detach().cpu().float().numpy() - if isinstance(b, torch.Tensor): - b = b.detach().cpu().float().numpy() similarity = np.dot(a, b) / (norm(a) * norm(b)) distance = 1.0 - similarity.mean() From 738f278d93bb3f72527ca2d5607c6cc95dacb17e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 12 Dec 2025 15:25:59 +0530 Subject: [PATCH 3/4] gracefully error out when attn-backend x cp combo isn't supported. 
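
With this change, `dispatch_attention_fn` validates the resolved attention
backend whenever a parallel config is present: if the backend does not
support context parallelism, it raises a ValueError listing the compatible
backends. `set_attention_backend()` now also records the chosen backend as
the registry's active backend so that it propagates consistently.

A rough usage sketch of the flow the error message points to (the backend
name and the parallelism arguments below are illustrative assumptions, not
prescribed by this patch):

    # `model` is an already-loaded diffusers model.
    # Pick a context-parallel-capable backend first; any backend listed in
    # the raised error works ("flash" is only an assumed example here).
    model.set_attention_backend("flash")
    # Parallelism is then enabled as before; config arguments are elided.
    model.enable_parallelism(...)
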
--- src/diffusers/models/attention_dispatch.py | 22 +++++++++++++++++++--- src/diffusers/models/modeling_utils.py | 6 +++++- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 310c44457c27..7ae8ea0f00d1 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -235,6 +235,10 @@ def decorator(func): def get_active_backend(cls): return cls._active_backend, cls._backends[cls._active_backend] + @classmethod + def set_active_backend(cls, backend: str): + cls._active_backend = backend + @classmethod def list_backends(cls): return list(cls._backends.keys()) @@ -294,12 +298,12 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke _maybe_download_kernel_for_backend(backend) old_backend = _AttentionBackendRegistry._active_backend - _AttentionBackendRegistry._active_backend = backend + _AttentionBackendRegistry.set_active_backend(backend) try: yield finally: - _AttentionBackendRegistry._active_backend = old_backend + _AttentionBackendRegistry.set_active_backend(old_backend) def dispatch_attention_fn( @@ -325,7 +329,7 @@ def dispatch_attention_fn( else: backend_name = AttentionBackendName(backend) backend_fn = _AttentionBackendRegistry._backends.get(backend_name) - + kwargs = { "query": query, "key": key, @@ -348,6 +352,18 @@ def dispatch_attention_fn( check(**kwargs) kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]} + + if "_parallel_config" in kwargs and kwargs["_parallel_config"] is not None: + attention_backend = AttentionBackendName(backend_name) + if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend): + compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel) + raise ValueError( + f"Context parallelism is enabled but backend '{attention_backend.value}' " + f"which does not support context parallelism. " + f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before " + f"calling `model.enable_parallelism()`." + ) + return backend_fn(**kwargs) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 41da95d3a2a2..1d3c62338329 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -601,6 +601,7 @@ def set_attention_backend(self, backend: str) -> None: """ from .attention import AttentionModuleMixin from .attention_dispatch import ( + _AttentionBackendRegistry, AttentionBackendName, _check_attention_backend_requirements, _maybe_download_kernel_for_backend, @@ -628,6 +629,9 @@ def set_attention_backend(self, backend: str) -> None: if processor is None or not hasattr(processor, "_attention_backend"): continue processor._attention_backend = backend + + # Important to set the active backend so that it propagates gracefully throughout. + _AttentionBackendRegistry.set_active_backend(backend) def reset_attention_backend(self) -> None: """ @@ -1541,7 +1545,7 @@ def enable_parallelism( f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' " f"is using backend '{attention_backend.value}' which does not support context parallelism. " f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before " - f"calling `enable_parallelism()`." + f"calling `model.enable_parallelism()`." 
) # All modules use the same attention processor and backend. We don't need to From 0c35ed4708c03d50da0ba4e15a7ad77086200255 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 12 Dec 2025 15:26:43 +0530 Subject: [PATCH 4/4] up --- src/diffusers/models/attention_dispatch.py | 6 +++--- src/diffusers/models/modeling_utils.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 7ae8ea0f00d1..3aed728b50d1 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -329,7 +329,7 @@ def dispatch_attention_fn( else: backend_name = AttentionBackendName(backend) backend_fn = _AttentionBackendRegistry._backends.get(backend_name) - + kwargs = { "query": query, "key": key, @@ -352,7 +352,7 @@ def dispatch_attention_fn( check(**kwargs) kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]} - + if "_parallel_config" in kwargs and kwargs["_parallel_config"] is not None: attention_backend = AttentionBackendName(backend_name) if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend): @@ -363,7 +363,7 @@ def dispatch_attention_fn( f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before " f"calling `model.enable_parallelism()`." ) - + return backend_fn(**kwargs) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 1d3c62338329..197a0d78bc36 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -601,8 +601,8 @@ def set_attention_backend(self, backend: str) -> None: """ from .attention import AttentionModuleMixin from .attention_dispatch import ( - _AttentionBackendRegistry, AttentionBackendName, + _AttentionBackendRegistry, _check_attention_backend_requirements, _maybe_download_kernel_for_backend, ) @@ -629,7 +629,7 @@ def set_attention_backend(self, backend: str) -> None: if processor is None or not hasattr(processor, "_attention_backend"): continue processor._attention_backend = backend - + # Important to set the active backend so that it propagates gracefully throughout. _AttentionBackendRegistry.set_active_backend(backend)