diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py
index 8dfea644a5..92c095d1dd 100644
--- a/tests/pytorch/attention/run_attention_with_cp.py
+++ b/tests/pytorch/attention/run_attention_with_cp.py
@@ -4,6 +4,9 @@

 import os
 import sys
+import copy
+import json
+import traceback
 import logging
 from contextlib import nullcontext
 import torch
@@ -21,6 +24,7 @@
     Float8CurrentScalingQuantizer,
     MXFP8Quantizer,
 )
+from transformer_engine.pytorch.quantization import FP8GlobalStateManager
 from transformer_engine.common.recipe import (
     DelayedScaling,
     Float8CurrentScaling,
@@ -209,10 +213,10 @@ def run_dpa_with_cp(
         os.environ["NVTE_FUSED_ATTN"] = "0"
     if kernel_backend == "FlashAttention":
         os.environ["NVTE_FLASH_ATTN"] = "1"
-        config = model_configs_flash_attn[model]
+        config = copy.deepcopy(model_configs_flash_attn[model])
     if kernel_backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
-        config = model_configs_fused_attn[model]
+        config = copy.deepcopy(model_configs_fused_attn[model])
     assert config.attn_mask_type in [
         "causal",
         "no_mask",
@@ -223,18 +227,18 @@ def run_dpa_with_cp(
     else:
         config.attn_mask_type = "padding"

-    # set up distributed group
-    rank = int(os.getenv("RANK", "0"))
-    world_size = int(os.getenv("WORLD_SIZE", "1"))
-    if dist.is_initialized():
-        world_size = dist.get_world_size()
-        rank = dist.get_rank()
-    else:
+    # When called via main(), dist is already initialized; reuse it.
+    # When called directly as a function (legacy single-config path), init here.
+    _owns_dist = not dist.is_initialized()
+    if _owns_dist:
+        rank = int(os.getenv("RANK", "0"))
+        world_size = int(os.getenv("WORLD_SIZE", "1"))
         device_count = torch.cuda.device_count()
-        device = rank % device_count
-        torch.cuda.set_device(device)
+        torch.cuda.set_device(rank % device_count)
+        dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
+    world_size = dist.get_world_size()
+    rank = dist.get_rank()
     logging.info(f"[Rank {rank}] Setup: world_size {world_size}")
-    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)

     # set up communication group for CP
     cp_comm_ranks = range(world_size)
@@ -630,7 +634,6 @@ def run_dpa_with_cp(
             == 0
         )
     else:
-        # Forward-only: reshape only out/out_ for comparison
         out = out.index_select(0, seq_idx_q).contiguous()
         out_ = out_

@@ -762,14 +765,94 @@ def run_dpa_with_cp(
             )
             logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches")

-    # destroy distribution group
-    dist.destroy_process_group()
+    dist.destroy_process_group(cp_comm_group)
+    if cp_comm_type == "a2a+p2p":
+        for sg in cp_comm_sub_groups:
+            dist.destroy_process_group(sg)
+    if _owns_dist:
+        dist.destroy_process_group()
+
+
+_TRANSIENT_ENV_KEYS = (
+    "NVTE_FP8_DPA_BWD",
+    "NVTE_DPA_FP8CS_O_in_F16",
+    "NVTE_FLASH_ATTN",
+    "NVTE_FUSED_ATTN",
+    "NVTE_ALLOW_NONDETERMINISTIC_ALGO",
+)
+
+
+def _init_distributed():
+    rank = int(os.getenv("RANK", "0"))
+    world_size = int(os.getenv("WORLD_SIZE", "1"))
+    device_count = torch.cuda.device_count()
+    local_rank = int(os.getenv("LOCAL_RANK", str(rank % device_count)))
+    torch.cuda.set_device(local_rank)
+    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
+    return rank, world_size
+
+
+def _run_single_config(kwargs):
+    """Run one config, return ``(ok, error_message)``."""
+    torch.manual_seed(1234)
+    torch.cuda.manual_seed(1234)
+    try:
+        run_dpa_with_cp(**kwargs)
+        return True, None
+    except BaseException:  # noqa: BLE001 - capture any failure for per-config reporting
+        return False, traceback.format_exc()


 def main(**kwargs):
-    run_dpa_with_cp(**kwargs)
+    """Single-config (key=val args) or batch (batch_config_json=) entry point."""
+    batch_path = kwargs.pop("batch_config_json", None)
+    rank, _ = _init_distributed()
+    try:
+        if batch_path is None:
+            run_dpa_with_cp(**kwargs)
+        else:
+            with open(batch_path, "r") as f:
+                configs = json.load(f)
+            assert isinstance(
+                configs, list
+            ), f"batch_config_json must be a JSON list, got {type(configs)}"
+            results_path = batch_path + ".results.json"
+            results = []
+
+            def _flush_results():
+                if rank != 0:
+                    return
+                tmp_path = results_path + ".tmp"
+                with open(tmp_path, "w") as f:
+                    json.dump(results, f)
+                os.replace(tmp_path, results_path)
+
+            for cfg in configs:
+                FP8GlobalStateManager.reset()
+                for env_key in _TRANSIENT_ENV_KEYS:
+                    os.environ.pop(env_key, None)
+                ok, err = _run_single_config(cfg)
+                ok_tensor = torch.tensor(1 if ok else 0, dtype=torch.int32, device="cuda")
+                dist.all_reduce(ok_tensor, op=dist.ReduceOp.MIN)
+                ok_aggregate = bool(ok_tensor.item())
+                if not ok_aggregate and ok and err is None:
+                    err = "Failed on a non-zero rank (see subprocess stderr for traceback)"
+                results.append({"ok": ok_aggregate, "error": err})
+                _flush_results()
+                try:
+                    dist.barrier()
+                except BaseException:  # noqa: BLE001
+                    results[-1]["ok"] = False
+                    if results[-1]["error"] is None:
+                        results[-1]["error"] = traceback.format_exc()
+                    _flush_results()
+                    break
+                torch.cuda.empty_cache()
+    finally:
+        if dist.is_initialized():
+            dist.destroy_process_group()


 if __name__ == "__main__":
-    kwargs = dict(arg.split("=") for arg in sys.argv[2:])
+    kwargs = dict(arg.split("=", 1) for arg in sys.argv[2:])
     main(**kwargs)
diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py
index 23d1bfdd85..75208f6436 100644
--- a/tests/pytorch/attention/test_attention_with_cp.py
+++ b/tests/pytorch/attention/test_attention_with_cp.py
@@ -3,8 +3,9 @@
 # See LICENSE for license information.

 import os
-import subprocess
+import json
 import sys
+import tempfile
 import pathlib
 import logging
 import copy
@@ -41,6 +42,8 @@

 test_essential = True

+_BATCH_SIZE = int(os.getenv("CP_TEST_BATCH_SIZE", "16"))
+
 model_configs_flash_attn = {
     # test: ModelConfig(b, sq, hq, dqk)
     "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"),  # MHA
@@ -67,6 +70,8 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
         "torch.distributed.launch",
         "--nproc-per-node=" + str(num_gpus_per_node),
     ]
+    if "MASTER_PORT" in os.environ:
+        args.append("--master-port=" + os.environ["MASTER_PORT"])
     te_path = os.getenv("TE_PATH", "/opt/transformerengine")
     script_path = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp.py")
     args.append(script_path)
@@ -75,6 +80,188 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
     return args

+
+# ---------------------------------------------------------------------------
+# Batched dispatch: session fixture dry-runs each test to collect kwargs,
+# groups by num_gpus, chunks into batches of CP_TEST_BATCH_SIZE, and runs
+# each batch in one torchrun. Test bodies are unchanged except they call
+# ``_run_or_fetch(request, _cp_batch_results, ...)`` instead of
+# ``run_distributed(get_bash_arguments(...))``.
+#
+# Env knobs: CP_TEST_BATCH_SIZE (default 16), CP_TEST_BATCH_RETRY (default 1).
+# ---------------------------------------------------------------------------
+
+# Module-level state used by the session fixture's collect phase.
+_COLLECT_MODE = False
+_COLLECTED_KWARGS = {}  # nodeid -> kwargs dict (populated in collect mode)
+_BACKEND_CACHE = {}
+
+
+def _cached_backend_check(nodeid, check_fn):
+    """Cache backend availability per test node so dry-run and execute agree."""
+    if nodeid not in _BACKEND_CACHE:
+        _BACKEND_CACHE[nodeid] = check_fn()
+    return _BACKEND_CACHE[nodeid]
+
+
+def _run_or_fetch(request, batch_results, *, num_gpus_per_node, **worker_kwargs):
+    """Collect mode: stash kwargs. Execute mode: look up pre-computed result."""
+    if _COLLECT_MODE:
+        _COLLECTED_KWARGS[request.node.nodeid] = dict(num_gpus=num_gpus_per_node, **worker_kwargs)
+        return
+    entry = batch_results.get(request.node.nodeid)
+    if entry is None:
+        pytest.skip("No batched result recorded (collection mismatch).")
+    if not entry.get("ok", False):
+        raise AssertionError(entry.get("error") or "Batched config failed (no error captured)")
+
+
+def _run_batch_once(num_gpus, configs):
+    """Launch one torchrun for *configs*; return list of ``{ok, error}`` dicts."""
+    worker_kwargs = [{k: str(v) for k, v in cfg.items() if k != "num_gpus"} for cfg in configs]
+
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".cp_batch.json", delete=False) as fh:
+        batch_path = fh.name
+        json.dump(worker_kwargs, fh)
+    results_path = batch_path + ".results.json"
+
+    try:
+        argv = get_bash_arguments(num_gpus_per_node=num_gpus, batch_config_json=batch_path)
+        launch_err = None
+        try:
+            run_distributed(argv)
+        except Exception as exc:
+            launch_err = str(exc) or repr(exc)
+
+        try:
+            with open(results_path, "r") as f:
+                per_cfg = json.load(f)
+        except (OSError, json.JSONDecodeError):
+            per_cfg = []
+
+        results = []
+        for i in range(len(configs)):
+            if i < len(per_cfg):
+                results.append(per_cfg[i])
+            else:
+                results.append(
+                    {
+                        "ok": False,
+                        "error": launch_err or "Subprocess exited before this config ran.",
+                        "_unattributed": True,
+                    }
+                )
+        return results
+    finally:
+        for p in (batch_path, results_path):
+            try:
+                os.unlink(p)
+            except OSError:
+                pass
+
+
+def _run_one_batch(num_gpus, configs):
+    """Run a batch, retrying unattributed crashes as singletons to isolate the culprit."""
+    results = _run_batch_once(num_gpus, configs)
+    if len(configs) <= 1 or not int(os.getenv("CP_TEST_BATCH_RETRY", "1")):
+        for r in results:
+            r.pop("_unattributed", None)
+        return results
+    for i, r in enumerate(results):
+        if r.pop("_unattributed", False):
+            results[i] = _run_batch_once(num_gpus, [configs[i]])[0]
+            results[i].pop("_unattributed", None)
+    return results
+
+
+class _DummyRequest:
+    """Minimal stand-in for the ``request`` fixture during dry-run."""
+
+    def __init__(self, nodeid):
+        self.node = type("_DummyNode", (), {"nodeid": nodeid})()
+
+
+def _item_static_skip(item):
+    """Return True if pytest skip/skipif markers would skip *item*."""
+    for _ in item.iter_markers("skip"):
+        return True
+    for marker in item.iter_markers("skipif"):
+        cond = marker.args[0] if marker.args else marker.kwargs.get("condition")
+        if cond:
+            return True
+    return False
+
+
+def _dry_run_item(item):
+    """Invoke a parametrized test body in collect mode to gather kwargs."""
+    func = item.function
+    params = dict(item.callspec.params)
+    func(_DummyRequest(item.nodeid), {}, **params)
+
+
+@pytest.fixture(scope="session")
+def _cp_batch_results(request):
+    """Dry-run all batched tests to collect kwargs, then dispatch torchrun batches."""
+    global _COLLECT_MODE
+
+    items = [
+        it for it in request.session.items if "_cp_batch_results" in getattr(it, "fixturenames", ())
+    ]
+
+    import time as _time
+
+    _COLLECTED_KWARGS.clear()
+    _COLLECT_MODE = True
+    _t0 = _time.monotonic()
+    try:
+        for item in items:
+            if _item_static_skip(item):
+                continue
+            try:
+                _dry_run_item(item)
+            except pytest.skip.Exception:
+                pass
+            except BaseException:  # noqa: BLE001
+                pass
+    finally:
+        _COLLECT_MODE = False
+    print(
+        f"\n[CP-BATCH] Collect done: {len(_COLLECTED_KWARGS)} configs from"
+        f" {len(items)} items in {_time.monotonic() - _t0:.1f}s",
+        flush=True,
+    )
+
+    by_num_gpus = {}
+    for nodeid, kwargs in _COLLECTED_KWARGS.items():
+        num_gpus = kwargs.pop("num_gpus")
+        by_num_gpus.setdefault(num_gpus, []).append((nodeid, kwargs))
+
+    results = {}
+    for num_gpus, entries in by_num_gpus.items():
+        n_batches = (len(entries) + _BATCH_SIZE - 1) // _BATCH_SIZE
+        for batch_idx, start in enumerate(range(0, len(entries), _BATCH_SIZE)):
+            chunk = entries[start : start + _BATCH_SIZE]
+            print(
+                f"[CP-BATCH] Running batch {batch_idx + 1}/{n_batches}"
+                f" ({len(chunk)} cfgs, {num_gpus} GPUs)...",
+                flush=True,
+            )
+            _bt = _time.monotonic()
+            chunk_results = _run_one_batch(num_gpus, [kw for _, kw in chunk])
+            ok = sum(1 for r in chunk_results if r.get("ok"))
+            print(
+                f"[CP-BATCH] => {ok}/{len(chunk)} passed in {_time.monotonic() - _bt:.1f}s",
+                flush=True,
+            )
+            for (nodeid, _), res in zip(chunk, chunk_results):
+                results[nodeid] = res
+    print(
+        f"[CP-BATCH] All batches done: {len(results)} results"
+        f" in {_time.monotonic() - _t0:.1f}s total",
+        flush=True,
+    )
+    return results
+
+
 dtypes = ["bf16", "fp16"]
 qkv_formats = ["bshd", "sbhd", "thd"]
 cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]

@@ -91,7 +278,9 @@
 @pytest.mark.parametrize("model", model_configs_flash_attn.keys())
 @pytest.mark.parametrize("qkv_format", qkv_formats)
 @pytest.mark.parametrize("cp_comm_type", cp_comm_types)
-def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
+def test_cp_with_flash_attention(
+    request, _cp_batch_results, dtype, model, qkv_format, cp_comm_type
+):
     num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
     if num_gpus > torch.cuda.device_count():
         pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")
@@ -131,25 +320,25 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):

     if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
         pytest.skip("MLA CP currently only support KV P2P!")
     dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}
-    available_backends, *_ = get_available_attention_backends(
-        config,
-        qkv_dtype=dtypes[dtype],
-        qkv_layout="_".join([qkv_format] * 3),
+    flash_attn_supported = _cached_backend_check(
+        request.node.nodeid,
+        lambda: get_available_attention_backends(
+            config, qkv_dtype=dtypes[dtype], qkv_layout="_".join([qkv_format] * 3)
+        )[0][0],
     )
-    flash_attn_supported, *_ = available_backends
     if not flash_attn_supported:
         pytest.skip("No attention backend available.")
-    run_distributed(
-        get_bash_arguments(
-            num_gpus_per_node=num_gpus,
-            dtype=dtype,
-            model=model,
-            qkv_format=qkv_format,
-            kernel_backend="FlashAttention",
-            cp_comm_type=cp_comm_type,
-            log_level=pytest_logging_level,
-        ),
+    _run_or_fetch(
+        request,
+        _cp_batch_results,
+        num_gpus_per_node=num_gpus,
+        dtype=dtype,
+        model=model,
+        qkv_format=qkv_format,
+        kernel_backend="FlashAttention",
+        cp_comm_type=cp_comm_type,
+        log_level=pytest_logging_level,
     )

@@ -274,7 +463,17 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type)
@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current", "mxfp8"]) @pytest.mark.parametrize("f16_O", [True, False]) def test_cp_with_fused_attention( - dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O + request, + _cp_batch_results, + dtype, + model, + qkv_format, + cp_comm_type, + fp8_bwd, + fp8_mha, + fp8_dpa, + scaling_mode, + f16_O, ): config = model_configs_fused_attn[model] config.context_parallel = True @@ -377,22 +576,9 @@ def test_cp_with_fused_attention( # For 111s, dbias calculation is not supported as of cuDNN 9.18, hence, test fwd only for 111s. is_training = False if config.bias_shape == "111s" else True - available_backends, _, fused_attn_backends = get_available_attention_backends( - config, - qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, - qkv_layout="_".join([qkv_format] * 3), - fp8=fp8, - fp8_meta=fp8_meta, - is_training=is_training, - deterministic=_deterministic, - ) - _, fused_attn_supported, _ = available_backends - if fused_attn_supported and config.attn_mask_type in ["causal", "padding_causal"]: - config_copy = copy.deepcopy(config) - config_copy.context_parallel = False - config_copy.attn_mask_type = config.attn_mask_type + "_bottom_right" - available_backends, _, fused_attn_backends = get_available_attention_backends( - config_copy, + + def _check_fused_backend(): + backend_kwargs = dict( qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, qkv_layout="_".join([qkv_format] * 3), fp8=fp8, @@ -400,25 +586,42 @@ def test_cp_with_fused_attention( is_training=is_training, deterministic=_deterministic, ) - _, fused_attn_supported, _ = available_backends + available_backends, _, _ = get_available_attention_backends(config, **backend_kwargs) + _, supported, _ = available_backends + if supported and config.attn_mask_type in ["causal", "padding_causal"]: + config_copy = copy.deepcopy(config) + config_copy.context_parallel = False + config_copy.attn_mask_type = config.attn_mask_type + "_bottom_right" + available_backends, _, _ = get_available_attention_backends( + config_copy, **backend_kwargs + ) + _, supported, _ = available_backends + return supported + + fused_attn_supported = _cached_backend_check(request.node.nodeid, _check_fused_backend) if not fused_attn_supported: pytest.skip("No attention backend available.") - run_distributed( - get_bash_arguments( - num_gpus_per_node=num_gpus, - dtype=dtype, - model=model, - qkv_format=qkv_format, - kernel_backend="FusedAttention", - cp_comm_type=cp_comm_type, - fp8_bwd=fp8_bwd, - fp8_dpa=fp8_dpa, - fp8_mha=fp8_mha, - scaling_mode=scaling_mode, - f16_O=f16_O, - is_training=is_training, - deterministic=_deterministic, - log_level=pytest_logging_level, - ), + if _deterministic and config.softmax_type != "vanilla": + pytest.skip("Deterministic mode does not support non-vanilla softmax with FusedAttention") + if _deterministic and config.attn_bias_type == "post_scale_bias" and is_training: + pytest.skip("Deterministic mode does not support post_scale_bias with requires_grad") + + _run_or_fetch( + request, + _cp_batch_results, + num_gpus_per_node=num_gpus, + dtype=dtype, + model=model, + qkv_format=qkv_format, + kernel_backend="FusedAttention", + cp_comm_type=cp_comm_type, + fp8_bwd=fp8_bwd, + fp8_dpa=fp8_dpa, + fp8_mha=fp8_mha, + scaling_mode=scaling_mode, + f16_O=f16_O, + is_training=is_training, + deterministic=_deterministic, + log_level=pytest_logging_level, )