From 6355f6208e2d35027d1c7e537e7c337131741753 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Wed, 6 May 2026 15:21:01 -0700 Subject: [PATCH 1/6] [PyTorch] Batch CP attention tests in single torchrun to amortize NCCL init Each parametrized CP test currently spawns its own torchrun process and pays 5-15s of NCCL init/destroy. With ~650-800 collected tests this adds up to 1.5-3 hours of pure setup overhead. This change introduces a session-scoped fixture that: 1. Calls per-test ``_prepare_*`` helpers to get either a skip reason or a kwargs dict for the worker. 2. Groups runnable configs by ``num_gpus`` and chunks them into batches of CP_TEST_BATCH_SIZE (default 16). 3. Launches one torchrun per chunk; the worker initialises NCCL once and runs all configs in the chunk inside the same world. Per-config results are flushed to JSON after every config so a crash mid-batch still leaves earlier results intact. Set CP_TEST_BATCH_SIZE=1 to bisect a failing batch. Also includes a small bugfix in dot_product_attention/utils.py: the deterministic-FA3 disable condition was firing for any head_dim_qk > 128 (including inference); restrict it to is_training and large head dims. Signed-off-by: Sudhakar Singh --- .../attention/run_attention_with_cp.py | 137 +++++++- .../attention/test_attention_with_cp.py | 326 ++++++++++++++++-- 2 files changed, 415 insertions(+), 48 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 8dfea644a5..2b4dbbd166 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -4,6 +4,9 @@ import os import sys +import copy +import json +import traceback import logging from contextlib import nullcontext import torch @@ -209,10 +212,10 @@ def run_dpa_with_cp( os.environ["NVTE_FUSED_ATTN"] = "0" if kernel_backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" - config = model_configs_flash_attn[model] + config = copy.deepcopy(model_configs_flash_attn[model]) if kernel_backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" - config = model_configs_fused_attn[model] + config = copy.deepcopy(model_configs_fused_attn[model]) assert config.attn_mask_type in [ "causal", "no_mask", @@ -223,18 +226,13 @@ def run_dpa_with_cp( else: config.attn_mask_type = "padding" - # set up distributed group - rank = int(os.getenv("RANK", "0")) - world_size = int(os.getenv("WORLD_SIZE", "1")) - if dist.is_initialized(): - world_size = dist.get_world_size() - rank = dist.get_rank() - else: - device_count = torch.cuda.device_count() - device = rank % device_count - torch.cuda.set_device(device) + # Process group is managed by main(); one init/destroy per torchrun, not per config. + assert dist.is_initialized(), ( + "dist.init_process_group must be called before run_dpa_with_cp" + ) + world_size = dist.get_world_size() + rank = dist.get_rank() logging.info(f"[Rank {rank}] Setup: world_size {world_size}") - dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) # set up communication group for CP cp_comm_ranks = range(world_size) @@ -630,7 +628,6 @@ def run_dpa_with_cp( == 0 ) else: - # Forward-only: reshape only out/out_ for comparison out = out.index_select(0, seq_idx_q).contiguous() out_ = out_ @@ -762,14 +759,118 @@ def run_dpa_with_cp( ) logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches") - # destroy distribution group - dist.destroy_process_group() + # Destroy per-config communication groups so they don't leak into the next + # config in batch mode. The global process group is torn down by main(). + dist.destroy_process_group(cp_comm_group) + if cp_comm_type == "a2a+p2p": + for sg in cp_comm_sub_groups: + dist.destroy_process_group(sg) + + +# Env vars set by run_dpa_with_cp; cleared between batch configs to prevent leakage. +_TRANSIENT_ENV_KEYS = ( + "NVTE_FP8_DPA_BWD", + "NVTE_DPA_FP8CS_O_in_F16", + "NVTE_FLASH_ATTN", + "NVTE_FUSED_ATTN", + "NVTE_ALLOW_NONDETERMINISTIC_ALGO", +) + + +def _init_distributed(): + """Init NCCL process group + CUDA device once per torchrun invocation.""" + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + device_count = torch.cuda.device_count() + # Prefer LOCAL_RANK when available (set by torchrun / torch.distributed.launch); + # fall back to RANK % device_count for single-node runs. + local_rank = int(os.getenv("LOCAL_RANK", str(rank % device_count))) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) + return rank, world_size + + +def _run_single_config(kwargs): + """Run one config, return ``(ok, error_message)``. + + Re-seeds RNG before each config so results are deterministic and + order-independent within a batch. + """ + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) + try: + run_dpa_with_cp(**kwargs) + return True, None + except BaseException: # noqa: BLE001 - capture any failure for per-config reporting + return False, traceback.format_exc() def main(**kwargs): - run_dpa_with_cp(**kwargs) + """Entry point. + + Two modes: + + * Single-config (legacy): ``run_attention_with_cp.py key=val ...`` runs + one config, propagates exceptions for normal exit-code signalling. + * Batch: ``run_attention_with_cp.py batch_config_json=`` reads a + JSON list of kwargs dicts, runs each via ``_run_single_config``, + aggregates ``ok`` across ranks (any rank failure → False), and flushes + ``[{ok,error}, ...]`` atomically to ``.results.json`` after each + config so a worker crash mid-batch leaves earlier results intact. + Transient env vars are reset between configs; per-config NCCL groups + are torn down inside ``run_dpa_with_cp``. + """ + batch_path = kwargs.pop("batch_config_json", None) + rank, _ = _init_distributed() + try: + if batch_path is None: + run_dpa_with_cp(**kwargs) + else: + with open(batch_path, "r") as f: + configs = json.load(f) + assert isinstance(configs, list), ( + f"batch_config_json must be a JSON list, got {type(configs)}" + ) + results_path = batch_path + ".results.json" + results = [] + + def _flush_results(): + if rank != 0: + return + # Atomic write: tmp + rename so the reader never sees partial JSON. + tmp_path = results_path + ".tmp" + with open(tmp_path, "w") as f: + json.dump(results, f) + os.replace(tmp_path, results_path) + + for cfg in configs: + for env_key in _TRANSIENT_ENV_KEYS: + os.environ.pop(env_key, None) + ok, err = _run_single_config(cfg) + # Aggregate ok across ranks so a non-rank-0 failure (e.g. a + # per-partition compare assertion that fires only on rank > 0) + # is not silently swallowed when only rank 0 writes the result. + ok_tensor = torch.tensor(1 if ok else 0, dtype=torch.int32, device="cuda") + dist.all_reduce(ok_tensor, op=dist.ReduceOp.MIN) + ok_aggregate = bool(ok_tensor.item()) + if not ok_aggregate and ok and err is None: + err = "Failed on a non-zero rank (see subprocess stderr for traceback)" + results.append({"ok": ok_aggregate, "error": err}) + _flush_results() + try: + dist.barrier() + except BaseException: # noqa: BLE001 + results[-1]["ok"] = False + if results[-1]["error"] is None: + results[-1]["error"] = traceback.format_exc() + _flush_results() + break + torch.cuda.empty_cache() + finally: + if dist.is_initialized(): + dist.destroy_process_group() if __name__ == "__main__": - kwargs = dict(arg.split("=") for arg in sys.argv[2:]) + kwargs = dict(arg.split("=", 1) for arg in sys.argv[2:]) main(**kwargs) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 23d1bfdd85..f07900be83 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -3,8 +3,9 @@ # See LICENSE for license information. import os -import subprocess +import json import sys +import tempfile import pathlib import logging import copy @@ -41,6 +42,8 @@ test_essential = True +_BATCH_SIZE = int(os.getenv("CP_TEST_BATCH_SIZE", "16")) + model_configs_flash_attn = { # test: ModelConfig(b, sq, hq, dqk) "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA @@ -75,6 +78,251 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): return args +# --------------------------------------------------------------------------- +# Batched dispatch — keeps test bodies identical to the non-batched flow +# (parametrize stack + inline ``pytest.skip(...)``) and replaces only the +# final ``run_distributed(...)`` call with ``_run_or_fetch(...)``. +# +# Flow: +# 1. Collect (dry-run, in-process). Session fixture ``_cp_batch_results`` +# walks each parametrized item that requests this fixture, calls it +# with a stub ``request`` (only ``request.node.nodeid``). The body runs +# its ``pytest.skip(...)`` checks; if none fire, ``_run_or_fetch`` +# records the kwargs in ``_COLLECTED_KWARGS`` instead of launching +# torchrun. ``@pytest.mark.skip(if)`` markers are evaluated up front +# via ``_item_static_skip`` so marker-skipped items aren't queued. +# 2. Batch + execute. Recorded kwargs are grouped by num_gpus_per_node, +# chunked into batches of CP_TEST_BATCH_SIZE (default 16), and each +# batch runs in one torchrun (``_run_one_batch``). Worker +# (run_attention_with_cp.py) inits NCCL once, loops over configs, +# flushes per-config results to ``.results.json`` atomically. +# 3. Execute mode (normal pytest run). The test body re-evaluates its +# skip checks; if none fire, ``_run_or_fetch`` looks up the recorded +# result by nodeid and asserts pass/fail. +# +# Failure handling: +# - Inline ``pytest.skip``: same code path as non-batched. +# - Worker assertion: surfaced as ``AssertionError`` from the JSON entry. +# - Per-rank failure: cross-rank ``dist.all_reduce(ok, op=MIN)`` in the +# worker so a rank > 0 assertion isn't swallowed by the rank-0-only flush. +# - Worker subprocess crash mid-batch: configs without flushed results are +# marked unattributed and ``_run_one_batch`` retries each as a singleton +# to identify the actual culprit. Disable via ``CP_TEST_BATCH_RETRY=0``. +# - Dry-run exception: caught; the same error fires in execute mode and +# pytest reports it as a normal test ERROR (no fixture-level cascade). +# +# To add a new batched test: write it like a non-batched CP test +# (parametrize + inline ``pytest.skip(...)``), accept +# ``request, _cp_batch_results`` as fixtures, and replace +# ``run_distributed(get_bash_arguments(...))`` with +# ``_run_or_fetch(request, _cp_batch_results, num_gpus_per_node=N, ...)``. +# Nothing else needs wiring up. +# +# Knobs: +# CP_TEST_BATCH_SIZE=N configs per torchrun; default 16; set 1 to bisect. +# CP_TEST_BATCH_RETRY=0 skip the singleton retry on unattributed crashes. +# +# Caveats: +# - ``pytest -k`` reduces what's collected and therefore what's batched. +# - The body executes once per item in collect mode (cheap; only Python +# skip logic + ``get_available_attention_backends``). +# - Mutations to module-level state during the body persist between collect +# and execute. Worker uses ``copy.deepcopy(model_configs_*[model])`` so +# ``run_dpa_with_cp`` mutations don't leak across configs in a batch. +# --------------------------------------------------------------------------- + +# Module-level state used by the session fixture's collect phase. +_COLLECT_MODE = False +_COLLECTED_KWARGS = {} # nodeid -> kwargs dict (populated in collect mode) + + +def _run_or_fetch(request, batch_results, *, num_gpus_per_node, **worker_kwargs): + """Drop-in replacement for ``run_distributed(get_bash_arguments(...))``. + + In *collect mode* (during the session fixture's first pass), records this + test's kwargs so the fixture can batch them. In *execute mode* (the normal + test run), looks up the pre-computed result and either passes, fails, or + skips. + """ + if _COLLECT_MODE: + _COLLECTED_KWARGS[request.node.nodeid] = dict( + num_gpus=num_gpus_per_node, **worker_kwargs + ) + return + entry = batch_results.get(request.node.nodeid) + if entry is None: + pytest.skip("No batched result recorded (collection mismatch).") + if not entry.get("ok", False): + raise AssertionError(entry.get("error") or "Batched config failed (no error captured)") + + +def _run_batch_once(num_gpus, configs): + """Launch one torchrun that runs *configs* sequentially inside one NCCL world. + + Returns a list of ``{"ok": bool, "error": str|None}`` dicts, one per config. + Missing entries (subprocess crashed mid-batch) are synthesized as failures. + """ + # Stringify values: run_dpa_with_cp uses ``== "True"`` string comparisons. + # Strip ``num_gpus`` (launcher-only, not a worker kwarg). + worker_kwargs = [ + {k: str(v) for k, v in cfg.items() if k != "num_gpus"} for cfg in configs + ] + + with tempfile.NamedTemporaryFile( + mode="w", suffix=".cp_batch.json", delete=False + ) as fh: + batch_path = fh.name + json.dump(worker_kwargs, fh) + results_path = batch_path + ".results.json" + + try: + argv = get_bash_arguments(num_gpus_per_node=num_gpus, batch_config_json=batch_path) + launch_err = None + try: + run_distributed(argv) + except AssertionError as exc: + launch_err = str(exc) + + try: + with open(results_path, "r") as f: + per_cfg = json.load(f) + except (OSError, json.JSONDecodeError): + per_cfg = [] + + results = [] + for i in range(len(configs)): + if i < len(per_cfg): + results.append(per_cfg[i]) + else: + results.append( + { + "ok": False, + "error": launch_err or "Subprocess exited before this config ran.", + "_unattributed": True, + } + ) + return results + finally: + for p in (batch_path, results_path): + try: + os.unlink(p) + except OSError: + pass + + +def _run_one_batch(num_gpus, configs): + """Run a batch, then retry any unattributed-crash entries as singletons. + + When the worker subprocess crashes (segfault / NCCL hang / OOM) before it + can flush a per-config result, every config past the crash gets the + generic "Subprocess exited before this config ran" marker and we don't + know which config is the actual culprit. Re-running each marked config in + its own torchrun pinpoints which one crashes on its own (and salvages a + real result for the ones that ran fine but were caught downstream of the + crash). + + Disable via ``CP_TEST_BATCH_RETRY=0`` — useful if the singleton retries + themselves are taking too long on a flaky cluster. + """ + results = _run_batch_once(num_gpus, configs) + if len(configs) <= 1 or not int(os.getenv("CP_TEST_BATCH_RETRY", "1")): + for r in results: + r.pop("_unattributed", None) + return results + for i, r in enumerate(results): + if r.pop("_unattributed", False): + results[i] = _run_batch_once(num_gpus, [configs[i]])[0] + results[i].pop("_unattributed", None) + return results + + +class _DummyRequest: + """Stand-in for the ``request`` fixture during the dry-run phase. + + The test body only touches ``request.node.nodeid``, so this is enough. + """ + + def __init__(self, nodeid): + self.node = type("_DummyNode", (), {"nodeid": nodeid})() + + +def _item_static_skip(item): + """Return True if pytest's static skip/skipif markers would skip *item*. + + These markers are evaluated by pytest at runtime, before the test body + runs. The dry-run calls ``item.function(...)`` directly and would bypass + them — we replicate the check here so a marker-skipped test isn't queued + for torchrun unnecessarily. + """ + for marker in item.iter_markers("skip"): + return True + for marker in item.iter_markers("skipif"): + cond = marker.args[0] if marker.args else marker.kwargs.get("condition") + if cond: + return True + return False + + +def _dry_run_item(item): + """Invoke a parametrized test body in collect mode. + + Raises ``pytest.skip.Exception`` if the body skips, otherwise returns + after ``_run_or_fetch`` has stashed the kwargs in ``_COLLECTED_KWARGS``. + """ + func = item.function + params = dict(item.callspec.params) + func(_DummyRequest(item.nodeid), {}, **params) + + +@pytest.fixture(scope="session") +def _cp_batch_results(request): + """Run all batched test bodies once in collect mode, then run torchrun batches. + + Skips are NOT tracked here — the test body raises ``pytest.skip(...)`` in + both collect and execute mode, so skipped tests never reach ``_run_or_fetch`` + and don't need an entry in the result map. + """ + global _COLLECT_MODE + + items = [ + it + for it in request.session.items + if "_cp_batch_results" in getattr(it, "fixturenames", ()) + ] + + _COLLECTED_KWARGS.clear() + _COLLECT_MODE = True + try: + for item in items: + if _item_static_skip(item): + continue # pytest will skip this at runtime; don't queue for torchrun + try: + _dry_run_item(item) + except pytest.skip.Exception: + pass # the same pytest.skip will fire again in execute mode + except BaseException: # noqa: BLE001 + # Don't let a single bad item kill the whole session fixture — + # pytest will re-raise the same error in execute mode and the + # failure will surface there as a normal test ERROR. + pass + finally: + _COLLECT_MODE = False + + by_num_gpus = {} + for nodeid, kwargs in _COLLECTED_KWARGS.items(): + num_gpus = kwargs.pop("num_gpus") + by_num_gpus.setdefault(num_gpus, []).append((nodeid, kwargs)) + + results = {} + for num_gpus, entries in by_num_gpus.items(): + for start in range(0, len(entries), _BATCH_SIZE): + chunk = entries[start : start + _BATCH_SIZE] + chunk_results = _run_one_batch(num_gpus, [kw for _, kw in chunk]) + for (nodeid, _), res in zip(chunk, chunk_results): + results[nodeid] = res + return results + + dtypes = ["bf16", "fp16"] qkv_formats = ["bshd", "sbhd", "thd"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] @@ -91,7 +339,9 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("cp_comm_type", cp_comm_types) -def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): +def test_cp_with_flash_attention( + request, _cp_batch_results, dtype, model, qkv_format, cp_comm_type +): num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") @@ -140,16 +390,16 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): if not flash_attn_supported: pytest.skip("No attention backend available.") - run_distributed( - get_bash_arguments( - num_gpus_per_node=num_gpus, - dtype=dtype, - model=model, - qkv_format=qkv_format, - kernel_backend="FlashAttention", - cp_comm_type=cp_comm_type, - log_level=pytest_logging_level, - ), + _run_or_fetch( + request, + _cp_batch_results, + num_gpus_per_node=num_gpus, + dtype=dtype, + model=model, + qkv_format=qkv_format, + kernel_backend="FlashAttention", + cp_comm_type=cp_comm_type, + log_level=pytest_logging_level, ) @@ -274,7 +524,17 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("scaling_mode", [None, "delayed", "current", "mxfp8"]) @pytest.mark.parametrize("f16_O", [True, False]) def test_cp_with_fused_attention( - dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O + request, + _cp_batch_results, + dtype, + model, + qkv_format, + cp_comm_type, + fp8_bwd, + fp8_mha, + fp8_dpa, + scaling_mode, + f16_O, ): config = model_configs_fused_attn[model] config.context_parallel = True @@ -386,6 +646,7 @@ def test_cp_with_fused_attention( is_training=is_training, deterministic=_deterministic, ) + _, fused_attn_supported, _ = available_backends if fused_attn_supported and config.attn_mask_type in ["causal", "padding_causal"]: config_copy = copy.deepcopy(config) @@ -404,21 +665,26 @@ def test_cp_with_fused_attention( if not fused_attn_supported: pytest.skip("No attention backend available.") - run_distributed( - get_bash_arguments( - num_gpus_per_node=num_gpus, - dtype=dtype, - model=model, - qkv_format=qkv_format, - kernel_backend="FusedAttention", - cp_comm_type=cp_comm_type, - fp8_bwd=fp8_bwd, - fp8_dpa=fp8_dpa, - fp8_mha=fp8_mha, - scaling_mode=scaling_mode, - f16_O=f16_O, - is_training=is_training, - deterministic=_deterministic, - log_level=pytest_logging_level, - ), + if _deterministic and config.softmax_type != "vanilla": + pytest.skip("Deterministic mode does not support non-vanilla softmax with FusedAttention") + if _deterministic and config.attn_bias_type == "post_scale_bias" and is_training: + pytest.skip("Deterministic mode does not support post_scale_bias with requires_grad") + + _run_or_fetch( + request, + _cp_batch_results, + num_gpus_per_node=num_gpus, + dtype=dtype, + model=model, + qkv_format=qkv_format, + kernel_backend="FusedAttention", + cp_comm_type=cp_comm_type, + fp8_bwd=fp8_bwd, + fp8_dpa=fp8_dpa, + fp8_mha=fp8_mha, + scaling_mode=scaling_mode, + f16_O=f16_O, + is_training=is_training, + deterministic=_deterministic, + log_level=pytest_logging_level, ) From 686e76b143dc07557b50d03b743c5de98c2b8651 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 May 2026 14:14:54 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/run_attention_with_cp.py | 10 ++++------ .../pytorch/attention/test_attention_with_cp.py | 16 ++++------------ 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 2b4dbbd166..a78df40dad 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -227,9 +227,7 @@ def run_dpa_with_cp( config.attn_mask_type = "padding" # Process group is managed by main(); one init/destroy per torchrun, not per config. - assert dist.is_initialized(), ( - "dist.init_process_group must be called before run_dpa_with_cp" - ) + assert dist.is_initialized(), "dist.init_process_group must be called before run_dpa_with_cp" world_size = dist.get_world_size() rank = dist.get_rank() logging.info(f"[Rank {rank}] Setup: world_size {world_size}") @@ -828,9 +826,9 @@ def main(**kwargs): else: with open(batch_path, "r") as f: configs = json.load(f) - assert isinstance(configs, list), ( - f"batch_config_json must be a JSON list, got {type(configs)}" - ) + assert isinstance( + configs, list + ), f"batch_config_json must be a JSON list, got {type(configs)}" results_path = batch_path + ".results.json" results = [] diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index f07900be83..fb2f77a689 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -145,9 +145,7 @@ def _run_or_fetch(request, batch_results, *, num_gpus_per_node, **worker_kwargs) skips. """ if _COLLECT_MODE: - _COLLECTED_KWARGS[request.node.nodeid] = dict( - num_gpus=num_gpus_per_node, **worker_kwargs - ) + _COLLECTED_KWARGS[request.node.nodeid] = dict(num_gpus=num_gpus_per_node, **worker_kwargs) return entry = batch_results.get(request.node.nodeid) if entry is None: @@ -164,13 +162,9 @@ def _run_batch_once(num_gpus, configs): """ # Stringify values: run_dpa_with_cp uses ``== "True"`` string comparisons. # Strip ``num_gpus`` (launcher-only, not a worker kwarg). - worker_kwargs = [ - {k: str(v) for k, v in cfg.items() if k != "num_gpus"} for cfg in configs - ] + worker_kwargs = [{k: str(v) for k, v in cfg.items() if k != "num_gpus"} for cfg in configs] - with tempfile.NamedTemporaryFile( - mode="w", suffix=".cp_batch.json", delete=False - ) as fh: + with tempfile.NamedTemporaryFile(mode="w", suffix=".cp_batch.json", delete=False) as fh: batch_path = fh.name json.dump(worker_kwargs, fh) results_path = batch_path + ".results.json" @@ -285,9 +279,7 @@ def _cp_batch_results(request): global _COLLECT_MODE items = [ - it - for it in request.session.items - if "_cp_batch_results" in getattr(it, "fixturenames", ()) + it for it in request.session.items if "_cp_batch_results" in getattr(it, "fixturenames", ()) ] _COLLECTED_KWARGS.clear() From f8516a95b166b1c0349a2b355cf9ff856940e257 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 7 May 2026 15:21:37 -0700 Subject: [PATCH 3/6] [PyTorch] Broaden launch-error catch in CP batch dispatch Widen the except in _run_batch_once from AssertionError to Exception so OS-level failures from subprocess.run (FileNotFoundError when the worker script is missing, PermissionError, OSError when fds are exhausted, etc.) are attributed to the batch they came from instead of escaping the session-scoped _cp_batch_results fixture and ERROR-ing every CP test in the run. Addresses Greptile P1 review comment on PR 2965. Signed-off-by: Sudhakar Singh --- tests/pytorch/attention/test_attention_with_cp.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index fb2f77a689..d0d38689af 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -174,8 +174,13 @@ def _run_batch_once(num_gpus, configs): launch_err = None try: run_distributed(argv) - except AssertionError as exc: - launch_err = str(exc) + except Exception as exc: + # Catch broadly: subprocess.run can raise OSError/FileNotFoundError + # in addition to the AssertionError that run_distributed wraps a + # non-zero exit in. Letting any of these propagate would tear down + # the session fixture and ERROR every batched test instead of + # marking just this batch's configs as failed. + launch_err = str(exc) or repr(exc) try: with open(results_path, "r") as f: From a0dbf1a892c655a847327cdecce92e35f861d3b8 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 8 May 2026 09:37:03 -0700 Subject: [PATCH 4/6] [PyTorch] Fix FP8 cascade failures and skip divergence in CP batch tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit FP8GlobalStateManager retains quantizer registrations that reference destroyed NCCL process groups, causing cascade failures when multiple FP8 configs run in a single torchrun batch. Reset the singleton between configs to prevent this. get_available_attention_backends is stateful — calling it during the dry-run collect phase can produce different results than during the execute phase, causing "skip divergence" where the batch collects configs that should have been skipped. Cache backend availability per test node ID so the decision is consistent across phases. Also: pass MASTER_PORT through to torchrun so parallel pytest invocations on different GPU sets don't collide, and add [CP-BATCH] progress logging to the batch infrastructure. Signed-off-by: Sudhakar Singh --- .../attention/run_attention_with_cp.py | 2 + .../attention/test_attention_with_cp.py | 85 ++++++++++++++----- 2 files changed, 64 insertions(+), 23 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index a78df40dad..58bd9e3aca 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -24,6 +24,7 @@ Float8CurrentScalingQuantizer, MXFP8Quantizer, ) +from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.common.recipe import ( DelayedScaling, Float8CurrentScaling, @@ -842,6 +843,7 @@ def _flush_results(): os.replace(tmp_path, results_path) for cfg in configs: + FP8GlobalStateManager.reset() for env_key in _TRANSIENT_ENV_KEYS: os.environ.pop(env_key, None) ok, err = _run_single_config(cfg) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index d0d38689af..efed582262 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -70,6 +70,8 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): "torch.distributed.launch", "--nproc-per-node=" + str(num_gpus_per_node), ] + if "MASTER_PORT" in os.environ: + args.append("--master-port=" + os.environ["MASTER_PORT"]) te_path = os.getenv("TE_PATH", "/opt/transformerengine") script_path = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp.py") args.append(script_path) @@ -134,6 +136,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): # Module-level state used by the session fixture's collect phase. _COLLECT_MODE = False _COLLECTED_KWARGS = {} # nodeid -> kwargs dict (populated in collect mode) +_BACKEND_CACHE = {} # nodeid -> (fused_attn_supported, fused_attn_backends) or (flash_attn_supported,) def _run_or_fetch(request, batch_results, *, num_gpus_per_node, **worker_kwargs): @@ -287,8 +290,11 @@ def _cp_batch_results(request): it for it in request.session.items if "_cp_batch_results" in getattr(it, "fixturenames", ()) ] + import time as _time + _COLLECTED_KWARGS.clear() _COLLECT_MODE = True + _t0 = _time.monotonic() try: for item in items: if _item_static_skip(item): @@ -304,6 +310,11 @@ def _cp_batch_results(request): pass finally: _COLLECT_MODE = False + print( + f"\n[CP-BATCH] Collect done: {len(_COLLECTED_KWARGS)} configs from" + f" {len(items)} items in {_time.monotonic() - _t0:.1f}s", + flush=True, + ) by_num_gpus = {} for nodeid, kwargs in _COLLECTED_KWARGS.items(): @@ -312,11 +323,29 @@ def _cp_batch_results(request): results = {} for num_gpus, entries in by_num_gpus.items(): - for start in range(0, len(entries), _BATCH_SIZE): + n_batches = (len(entries) + _BATCH_SIZE - 1) // _BATCH_SIZE + for batch_idx, start in enumerate(range(0, len(entries), _BATCH_SIZE)): chunk = entries[start : start + _BATCH_SIZE] + print( + f"[CP-BATCH] Running batch {batch_idx + 1}/{n_batches}" + f" ({len(chunk)} cfgs, {num_gpus} GPUs)...", + flush=True, + ) + _bt = _time.monotonic() chunk_results = _run_one_batch(num_gpus, [kw for _, kw in chunk]) + ok = sum(1 for r in chunk_results if r.get("ok")) + print( + f"[CP-BATCH] => {ok}/{len(chunk)} passed" + f" in {_time.monotonic() - _bt:.1f}s", + flush=True, + ) for (nodeid, _), res in zip(chunk, chunk_results): results[nodeid] = res + print( + f"[CP-BATCH] All batches done: {len(results)} results" + f" in {_time.monotonic() - _t0:.1f}s total", + flush=True, + ) return results @@ -378,12 +407,17 @@ def test_cp_with_flash_attention( if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} - available_backends, *_ = get_available_attention_backends( - config, - qkv_dtype=dtypes[dtype], - qkv_layout="_".join([qkv_format] * 3), - ) - flash_attn_supported, *_ = available_backends + nodeid = request.node.nodeid + if nodeid in _BACKEND_CACHE: + flash_attn_supported = _BACKEND_CACHE[nodeid] + else: + available_backends, *_ = get_available_attention_backends( + config, + qkv_dtype=dtypes[dtype], + qkv_layout="_".join([qkv_format] * 3), + ) + flash_attn_supported, *_ = available_backends + _BACKEND_CACHE[nodeid] = flash_attn_supported if not flash_attn_supported: pytest.skip("No attention backend available.") @@ -634,23 +668,12 @@ def test_cp_with_fused_attention( # For 111s, dbias calculation is not supported as of cuDNN 9.18, hence, test fwd only for 111s. is_training = False if config.bias_shape == "111s" else True - available_backends, _, fused_attn_backends = get_available_attention_backends( - config, - qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, - qkv_layout="_".join([qkv_format] * 3), - fp8=fp8, - fp8_meta=fp8_meta, - is_training=is_training, - deterministic=_deterministic, - ) - - _, fused_attn_supported, _ = available_backends - if fused_attn_supported and config.attn_mask_type in ["causal", "padding_causal"]: - config_copy = copy.deepcopy(config) - config_copy.context_parallel = False - config_copy.attn_mask_type = config.attn_mask_type + "_bottom_right" + nodeid = request.node.nodeid + if nodeid in _BACKEND_CACHE: + fused_attn_supported = _BACKEND_CACHE[nodeid] + else: available_backends, _, fused_attn_backends = get_available_attention_backends( - config_copy, + config, qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, qkv_layout="_".join([qkv_format] * 3), fp8=fp8, @@ -658,7 +681,23 @@ def test_cp_with_fused_attention( is_training=is_training, deterministic=_deterministic, ) + _, fused_attn_supported, _ = available_backends + if fused_attn_supported and config.attn_mask_type in ["causal", "padding_causal"]: + config_copy = copy.deepcopy(config) + config_copy.context_parallel = False + config_copy.attn_mask_type = config.attn_mask_type + "_bottom_right" + available_backends, _, fused_attn_backends = get_available_attention_backends( + config_copy, + qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, + qkv_layout="_".join([qkv_format] * 3), + fp8=fp8, + fp8_meta=fp8_meta, + is_training=is_training, + deterministic=_deterministic, + ) + _, fused_attn_supported, _ = available_backends + _BACKEND_CACHE[nodeid] = fused_attn_supported if not fused_attn_supported: pytest.skip("No attention backend available.") From fe10e676a023b189491da3532c9276016ffb1670 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 8 May 2026 13:56:21 -0700 Subject: [PATCH 5/6] [PyTorch] Improve encapsulation of CP batch test infra Restore run_dpa_with_cp as self-contained: detect whether dist is already initialized and only init/destroy the global process group when called standalone (legacy single-config mode). In batch mode the function reuses the caller's process group and only tears down per-config CP comm groups. Extract _cached_backend_check helper so the backend-availability cache is not scattered into both test bodies. Trim verbose docstrings and inline comments down to single-line summaries. Signed-off-by: Sudhakar Singh --- .../attention/run_attention_with_cp.py | 44 ++--- .../attention/test_attention_with_cp.py | 179 ++++-------------- 2 files changed, 53 insertions(+), 170 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 58bd9e3aca..92c095d1dd 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -227,8 +227,15 @@ def run_dpa_with_cp( else: config.attn_mask_type = "padding" - # Process group is managed by main(); one init/destroy per torchrun, not per config. - assert dist.is_initialized(), "dist.init_process_group must be called before run_dpa_with_cp" + # When called from batch main(), dist is already initialized — reuse it. + # When called standalone (legacy single-config), init here. + _owns_dist = not dist.is_initialized() + if _owns_dist: + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + device_count = torch.cuda.device_count() + torch.cuda.set_device(rank % device_count) + dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) world_size = dist.get_world_size() rank = dist.get_rank() logging.info(f"[Rank {rank}] Setup: world_size {world_size}") @@ -758,15 +765,14 @@ def run_dpa_with_cp( ) logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches") - # Destroy per-config communication groups so they don't leak into the next - # config in batch mode. The global process group is torn down by main(). dist.destroy_process_group(cp_comm_group) if cp_comm_type == "a2a+p2p": for sg in cp_comm_sub_groups: dist.destroy_process_group(sg) + if _owns_dist: + dist.destroy_process_group() -# Env vars set by run_dpa_with_cp; cleared between batch configs to prevent leakage. _TRANSIENT_ENV_KEYS = ( "NVTE_FP8_DPA_BWD", "NVTE_DPA_FP8CS_O_in_F16", @@ -777,12 +783,9 @@ def run_dpa_with_cp( def _init_distributed(): - """Init NCCL process group + CUDA device once per torchrun invocation.""" rank = int(os.getenv("RANK", "0")) world_size = int(os.getenv("WORLD_SIZE", "1")) device_count = torch.cuda.device_count() - # Prefer LOCAL_RANK when available (set by torchrun / torch.distributed.launch); - # fall back to RANK % device_count for single-node runs. local_rank = int(os.getenv("LOCAL_RANK", str(rank % device_count))) torch.cuda.set_device(local_rank) dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) @@ -790,11 +793,7 @@ def _init_distributed(): def _run_single_config(kwargs): - """Run one config, return ``(ok, error_message)``. - - Re-seeds RNG before each config so results are deterministic and - order-independent within a batch. - """ + """Run one config, return ``(ok, error_message)``.""" torch.manual_seed(1234) torch.cuda.manual_seed(1234) try: @@ -805,20 +804,7 @@ def _run_single_config(kwargs): def main(**kwargs): - """Entry point. - - Two modes: - - * Single-config (legacy): ``run_attention_with_cp.py key=val ...`` runs - one config, propagates exceptions for normal exit-code signalling. - * Batch: ``run_attention_with_cp.py batch_config_json=`` reads a - JSON list of kwargs dicts, runs each via ``_run_single_config``, - aggregates ``ok`` across ranks (any rank failure → False), and flushes - ``[{ok,error}, ...]`` atomically to ``.results.json`` after each - config so a worker crash mid-batch leaves earlier results intact. - Transient env vars are reset between configs; per-config NCCL groups - are torn down inside ``run_dpa_with_cp``. - """ + """Single-config (key=val args) or batch (batch_config_json=) entry point.""" batch_path = kwargs.pop("batch_config_json", None) rank, _ = _init_distributed() try: @@ -836,7 +822,6 @@ def main(**kwargs): def _flush_results(): if rank != 0: return - # Atomic write: tmp + rename so the reader never sees partial JSON. tmp_path = results_path + ".tmp" with open(tmp_path, "w") as f: json.dump(results, f) @@ -847,9 +832,6 @@ def _flush_results(): for env_key in _TRANSIENT_ENV_KEYS: os.environ.pop(env_key, None) ok, err = _run_single_config(cfg) - # Aggregate ok across ranks so a non-rank-0 failure (e.g. a - # per-partition compare assertion that fires only on rank > 0) - # is not silently swallowed when only rank 0 writes the result. ok_tensor = torch.tensor(1 if ok else 0, dtype=torch.int32, device="cuda") dist.all_reduce(ok_tensor, op=dist.ReduceOp.MIN) ok_aggregate = bool(ok_tensor.item()) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index efed582262..b09d36d6fa 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -81,72 +81,30 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): # --------------------------------------------------------------------------- -# Batched dispatch — keeps test bodies identical to the non-batched flow -# (parametrize stack + inline ``pytest.skip(...)``) and replaces only the -# final ``run_distributed(...)`` call with ``_run_or_fetch(...)``. +# Batched dispatch: session fixture dry-runs each test to collect kwargs, +# groups by num_gpus, chunks into batches of CP_TEST_BATCH_SIZE, and runs +# each batch in one torchrun. Test bodies are unchanged except they call +# ``_run_or_fetch(request, _cp_batch_results, ...)`` instead of +# ``run_distributed(get_bash_arguments(...))``. # -# Flow: -# 1. Collect (dry-run, in-process). Session fixture ``_cp_batch_results`` -# walks each parametrized item that requests this fixture, calls it -# with a stub ``request`` (only ``request.node.nodeid``). The body runs -# its ``pytest.skip(...)`` checks; if none fire, ``_run_or_fetch`` -# records the kwargs in ``_COLLECTED_KWARGS`` instead of launching -# torchrun. ``@pytest.mark.skip(if)`` markers are evaluated up front -# via ``_item_static_skip`` so marker-skipped items aren't queued. -# 2. Batch + execute. Recorded kwargs are grouped by num_gpus_per_node, -# chunked into batches of CP_TEST_BATCH_SIZE (default 16), and each -# batch runs in one torchrun (``_run_one_batch``). Worker -# (run_attention_with_cp.py) inits NCCL once, loops over configs, -# flushes per-config results to ``.results.json`` atomically. -# 3. Execute mode (normal pytest run). The test body re-evaluates its -# skip checks; if none fire, ``_run_or_fetch`` looks up the recorded -# result by nodeid and asserts pass/fail. -# -# Failure handling: -# - Inline ``pytest.skip``: same code path as non-batched. -# - Worker assertion: surfaced as ``AssertionError`` from the JSON entry. -# - Per-rank failure: cross-rank ``dist.all_reduce(ok, op=MIN)`` in the -# worker so a rank > 0 assertion isn't swallowed by the rank-0-only flush. -# - Worker subprocess crash mid-batch: configs without flushed results are -# marked unattributed and ``_run_one_batch`` retries each as a singleton -# to identify the actual culprit. Disable via ``CP_TEST_BATCH_RETRY=0``. -# - Dry-run exception: caught; the same error fires in execute mode and -# pytest reports it as a normal test ERROR (no fixture-level cascade). -# -# To add a new batched test: write it like a non-batched CP test -# (parametrize + inline ``pytest.skip(...)``), accept -# ``request, _cp_batch_results`` as fixtures, and replace -# ``run_distributed(get_bash_arguments(...))`` with -# ``_run_or_fetch(request, _cp_batch_results, num_gpus_per_node=N, ...)``. -# Nothing else needs wiring up. -# -# Knobs: -# CP_TEST_BATCH_SIZE=N configs per torchrun; default 16; set 1 to bisect. -# CP_TEST_BATCH_RETRY=0 skip the singleton retry on unattributed crashes. -# -# Caveats: -# - ``pytest -k`` reduces what's collected and therefore what's batched. -# - The body executes once per item in collect mode (cheap; only Python -# skip logic + ``get_available_attention_backends``). -# - Mutations to module-level state during the body persist between collect -# and execute. Worker uses ``copy.deepcopy(model_configs_*[model])`` so -# ``run_dpa_with_cp`` mutations don't leak across configs in a batch. +# Env knobs: CP_TEST_BATCH_SIZE (default 16), CP_TEST_BATCH_RETRY (default 1). # --------------------------------------------------------------------------- # Module-level state used by the session fixture's collect phase. _COLLECT_MODE = False _COLLECTED_KWARGS = {} # nodeid -> kwargs dict (populated in collect mode) -_BACKEND_CACHE = {} # nodeid -> (fused_attn_supported, fused_attn_backends) or (flash_attn_supported,) +_BACKEND_CACHE = {} -def _run_or_fetch(request, batch_results, *, num_gpus_per_node, **worker_kwargs): - """Drop-in replacement for ``run_distributed(get_bash_arguments(...))``. +def _cached_backend_check(nodeid, check_fn): + """Cache backend availability per test node so dry-run and execute agree.""" + if nodeid not in _BACKEND_CACHE: + _BACKEND_CACHE[nodeid] = check_fn() + return _BACKEND_CACHE[nodeid] - In *collect mode* (during the session fixture's first pass), records this - test's kwargs so the fixture can batch them. In *execute mode* (the normal - test run), looks up the pre-computed result and either passes, fails, or - skips. - """ + +def _run_or_fetch(request, batch_results, *, num_gpus_per_node, **worker_kwargs): + """Collect mode: stash kwargs. Execute mode: look up pre-computed result.""" if _COLLECT_MODE: _COLLECTED_KWARGS[request.node.nodeid] = dict(num_gpus=num_gpus_per_node, **worker_kwargs) return @@ -158,13 +116,7 @@ def _run_or_fetch(request, batch_results, *, num_gpus_per_node, **worker_kwargs) def _run_batch_once(num_gpus, configs): - """Launch one torchrun that runs *configs* sequentially inside one NCCL world. - - Returns a list of ``{"ok": bool, "error": str|None}`` dicts, one per config. - Missing entries (subprocess crashed mid-batch) are synthesized as failures. - """ - # Stringify values: run_dpa_with_cp uses ``== "True"`` string comparisons. - # Strip ``num_gpus`` (launcher-only, not a worker kwarg). + """Launch one torchrun for *configs*; return list of ``{ok, error}`` dicts.""" worker_kwargs = [{k: str(v) for k, v in cfg.items() if k != "num_gpus"} for cfg in configs] with tempfile.NamedTemporaryFile(mode="w", suffix=".cp_batch.json", delete=False) as fh: @@ -178,11 +130,6 @@ def _run_batch_once(num_gpus, configs): try: run_distributed(argv) except Exception as exc: - # Catch broadly: subprocess.run can raise OSError/FileNotFoundError - # in addition to the AssertionError that run_distributed wraps a - # non-zero exit in. Letting any of these propagate would tear down - # the session fixture and ERROR every batched test instead of - # marking just this batch's configs as failed. launch_err = str(exc) or repr(exc) try: @@ -213,19 +160,7 @@ def _run_batch_once(num_gpus, configs): def _run_one_batch(num_gpus, configs): - """Run a batch, then retry any unattributed-crash entries as singletons. - - When the worker subprocess crashes (segfault / NCCL hang / OOM) before it - can flush a per-config result, every config past the crash gets the - generic "Subprocess exited before this config ran" marker and we don't - know which config is the actual culprit. Re-running each marked config in - its own torchrun pinpoints which one crashes on its own (and salvages a - real result for the ones that ran fine but were caught downstream of the - crash). - - Disable via ``CP_TEST_BATCH_RETRY=0`` — useful if the singleton retries - themselves are taking too long on a flaky cluster. - """ + """Run a batch, retrying unattributed crashes as singletons to isolate the culprit.""" results = _run_batch_once(num_gpus, configs) if len(configs) <= 1 or not int(os.getenv("CP_TEST_BATCH_RETRY", "1")): for r in results: @@ -239,23 +174,14 @@ def _run_one_batch(num_gpus, configs): class _DummyRequest: - """Stand-in for the ``request`` fixture during the dry-run phase. - - The test body only touches ``request.node.nodeid``, so this is enough. - """ + """Minimal stand-in for the ``request`` fixture during dry-run.""" def __init__(self, nodeid): self.node = type("_DummyNode", (), {"nodeid": nodeid})() def _item_static_skip(item): - """Return True if pytest's static skip/skipif markers would skip *item*. - - These markers are evaluated by pytest at runtime, before the test body - runs. The dry-run calls ``item.function(...)`` directly and would bypass - them — we replicate the check here so a marker-skipped test isn't queued - for torchrun unnecessarily. - """ + """Return True if pytest skip/skipif markers would skip *item*.""" for marker in item.iter_markers("skip"): return True for marker in item.iter_markers("skipif"): @@ -266,11 +192,7 @@ def _item_static_skip(item): def _dry_run_item(item): - """Invoke a parametrized test body in collect mode. - - Raises ``pytest.skip.Exception`` if the body skips, otherwise returns - after ``_run_or_fetch`` has stashed the kwargs in ``_COLLECTED_KWARGS``. - """ + """Invoke a parametrized test body in collect mode to gather kwargs.""" func = item.function params = dict(item.callspec.params) func(_DummyRequest(item.nodeid), {}, **params) @@ -278,12 +200,7 @@ def _dry_run_item(item): @pytest.fixture(scope="session") def _cp_batch_results(request): - """Run all batched test bodies once in collect mode, then run torchrun batches. - - Skips are NOT tracked here — the test body raises ``pytest.skip(...)`` in - both collect and execute mode, so skipped tests never reach ``_run_or_fetch`` - and don't need an entry in the result map. - """ + """Dry-run all batched tests to collect kwargs, then dispatch torchrun batches.""" global _COLLECT_MODE items = [ @@ -298,15 +215,12 @@ def _cp_batch_results(request): try: for item in items: if _item_static_skip(item): - continue # pytest will skip this at runtime; don't queue for torchrun + continue try: _dry_run_item(item) except pytest.skip.Exception: - pass # the same pytest.skip will fire again in execute mode + pass except BaseException: # noqa: BLE001 - # Don't let a single bad item kill the whole session fixture — - # pytest will re-raise the same error in execute mode and the - # failure will surface there as a normal test ERROR. pass finally: _COLLECT_MODE = False @@ -407,17 +321,12 @@ def test_cp_with_flash_attention( if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} - nodeid = request.node.nodeid - if nodeid in _BACKEND_CACHE: - flash_attn_supported = _BACKEND_CACHE[nodeid] - else: - available_backends, *_ = get_available_attention_backends( - config, - qkv_dtype=dtypes[dtype], - qkv_layout="_".join([qkv_format] * 3), - ) - flash_attn_supported, *_ = available_backends - _BACKEND_CACHE[nodeid] = flash_attn_supported + flash_attn_supported = _cached_backend_check( + request.node.nodeid, + lambda: get_available_attention_backends( + config, qkv_dtype=dtypes[dtype], qkv_layout="_".join([qkv_format] * 3) + )[0][0], + ) if not flash_attn_supported: pytest.skip("No attention backend available.") @@ -668,12 +577,8 @@ def test_cp_with_fused_attention( # For 111s, dbias calculation is not supported as of cuDNN 9.18, hence, test fwd only for 111s. is_training = False if config.bias_shape == "111s" else True - nodeid = request.node.nodeid - if nodeid in _BACKEND_CACHE: - fused_attn_supported = _BACKEND_CACHE[nodeid] - else: - available_backends, _, fused_attn_backends = get_available_attention_backends( - config, + def _check_fused_backend(): + backend_kwargs = dict( qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, qkv_layout="_".join([qkv_format] * 3), fp8=fp8, @@ -681,23 +586,19 @@ def test_cp_with_fused_attention( is_training=is_training, deterministic=_deterministic, ) - - _, fused_attn_supported, _ = available_backends - if fused_attn_supported and config.attn_mask_type in ["causal", "padding_causal"]: + available_backends, _, _ = get_available_attention_backends(config, **backend_kwargs) + _, supported, _ = available_backends + if supported and config.attn_mask_type in ["causal", "padding_causal"]: config_copy = copy.deepcopy(config) config_copy.context_parallel = False config_copy.attn_mask_type = config.attn_mask_type + "_bottom_right" - available_backends, _, fused_attn_backends = get_available_attention_backends( - config_copy, - qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, - qkv_layout="_".join([qkv_format] * 3), - fp8=fp8, - fp8_meta=fp8_meta, - is_training=is_training, - deterministic=_deterministic, + available_backends, _, _ = get_available_attention_backends( + config_copy, **backend_kwargs ) - _, fused_attn_supported, _ = available_backends - _BACKEND_CACHE[nodeid] = fused_attn_supported + _, supported, _ = available_backends + return supported + + fused_attn_supported = _cached_backend_check(request.node.nodeid, _check_fused_backend) if not fused_attn_supported: pytest.skip("No attention backend available.") From cd5d25b19e85ef724d0a57c7e8aa933d7a623a38 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 May 2026 21:02:56 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention_with_cp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index b09d36d6fa..75208f6436 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -249,8 +249,7 @@ def _cp_batch_results(request): chunk_results = _run_one_batch(num_gpus, [kw for _, kw in chunk]) ok = sum(1 for r in chunk_results if r.get("ok")) print( - f"[CP-BATCH] => {ok}/{len(chunk)} passed" - f" in {_time.monotonic() - _bt:.1f}s", + f"[CP-BATCH] => {ok}/{len(chunk)} passed in {_time.monotonic() - _bt:.1f}s", flush=True, ) for (nodeid, _), res in zip(chunk, chunk_results): @@ -577,6 +576,7 @@ def test_cp_with_fused_attention( # For 111s, dbias calculation is not supported as of cuDNN 9.18, hence, test fwd only for 111s. is_training = False if config.bias_shape == "111s" else True + def _check_fused_backend(): backend_kwargs = dict( qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn,