From 11c3ed21fd82b0451481a7181fdcd0d740bcbf77 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Fri, 8 May 2026 21:41:26 +0000 Subject: [PATCH 1/8] Add cuDNN score_mod attention path Signed-off-by: Vladimir Cherepanov --- tests/pytorch/attention/test_attention.py | 125 ++++++++ .../dot_product_attention/backends.py | 291 +++++++++++++++++- .../dot_product_attention.py | 141 ++++++++- 3 files changed, 555 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 32ea1694ee..879f48dc0c 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1390,6 +1390,131 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: return out, max_logit, (None, None, None, d_softmax_offset) +def _score_mod_causal(score_mod_graph, score_tensor, tensors): + """cuDNN frontend score_mod implementing top-left causal masking.""" + import cudnn # pylint: disable=import-outside-toplevel + + row_index = score_mod_graph.gen_index(input=score_tensor, axis=2) + row_index.set_data_type(cudnn.data_type.INT32) + col_index = score_mod_graph.gen_index(input=score_tensor, axis=3) + col_index.set_data_type(cudnn.data_type.INT32) + keep = score_mod_graph.cmp_ge( + input=row_index, + comparison=col_index, + compute_data_type=cudnn.data_type.BOOLEAN, + ) + keep.set_data_type(cudnn.data_type.BOOLEAN) + return score_mod_graph.binary_select( + input0=score_tensor, + input1=tensors["neg_inf"], + mask=keep, + ) + + +def _score_mod_causal_bprop(score_mod_graph, dP_tensor, tensors): + """cuDNN frontend score_mod_bprop implementing top-left causal masking.""" + import cudnn # pylint: disable=import-outside-toplevel + + row_index = score_mod_graph.gen_index(input=dP_tensor, axis=2) + row_index.set_data_type(cudnn.data_type.INT32) + col_index = score_mod_graph.gen_index(input=dP_tensor, axis=3) + col_index.set_data_type(cudnn.data_type.INT32) + keep = score_mod_graph.cmp_ge( + input=row_index, + comparison=col_index, + compute_data_type=cudnn.data_type.BOOLEAN, + ) + keep.set_data_type(cudnn.data_type.BOOLEAN) + return score_mod_graph.binary_select( + input0=dP_tensor, + input1=tensors["zero"], + mask=keep, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.") +@pytest.mark.skipif(get_cudnn_version() < (9, 6, 0), reason="cuDNN 9.6.0+ is required.") +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd"]) +def test_dot_product_attention_score_mod(dtype, qkv_format): + """Compare score_mod causal masking against standard cuDNN causal attention.""" + try: + import cudnn # pylint: disable=unused-import,import-outside-toplevel + except ImportError: + pytest.skip("cuDNN Python frontend is required for score_mod attention.") + + reset_rng_states() + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + + config = ModelConfig(2, 64, 4, 64, attn_mask_type="no_mask") + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=f"{qkv_format}_{qkv_format}_{qkv_format}", + ) + if not available_backends[1] or not fused_attn_backends: + pytest.skip("FusedAttention is not available for this score_mod configuration.") + + if qkv_format == "sbhd": + q_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim_qk) + kv_shape = q_shape + else: + q_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim_qk) + kv_shape = q_shape + + q = (0.1 * torch.randn(q_shape, dtype=dtype, device="cuda")).requires_grad_() + k = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_() + v = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_() + q_ref, k_ref, v_ref = [x.detach().clone().requires_grad_() for x in (q, k, v)] + + flex_attn = DotProductAttention( + config.num_heads, + config.head_dim_qk, + qkv_format=qkv_format, + attn_mask_type="no_mask", + layer_number=1, + ).to(dtype=dtype, device="cuda") + ref_attn = DotProductAttention( + config.num_heads, + config.head_dim_qk, + qkv_format=qkv_format, + attn_mask_type="causal", + layer_number=1, + ).to(dtype=dtype, device="cuda") + + out = flex_attn( + q, + k, + v, + qkv_format=qkv_format, + attn_mask_type="no_mask", + score_mod=_score_mod_causal, + score_mod_bprop=_score_mod_causal_bprop, + score_mod_tensors={"neg_inf": torch.full((1, 1, 1, 1), -1e9)}, + score_mod_bprop_tensors={"zero": torch.full((1, 1, 1, 1), 0.0)}, + ) + out_ref = ref_attn( + q_ref, + k_ref, + v_ref, + qkv_format=qkv_format, + attn_mask_type="causal", + ) + + d_out = torch.randn_like(out) + out.backward(d_out) + out_ref.backward(d_out) + + tols = dict(atol=5e-2, rtol=5e-2) + torch.testing.assert_close(out, out_ref, **tols) + torch.testing.assert_close(q.grad, q_ref.grad, **tols) + torch.testing.assert_close(k.grad, k_ref.grad, **tols) + torch.testing.assert_close(v.grad, v_ref.grad, **tols) + + model_configs_te_layer = { # test: ModelConfig(b, sq, hq, dqk) "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"), diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 79ebbd4afa..e7e7a7d0f2 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1244,6 +1244,255 @@ def convert_to_torch_float8(tensor, dtype): return output.contiguous() +def _format_to_bhsd(tensor: torch.Tensor, tensor_format: str) -> torch.Tensor: + """Convert TE's SBHD/BSHD tensor formats to cuDNN frontend's BHSD format.""" + if tensor_format == "sbhd": + return tensor.permute(1, 2, 0, 3).contiguous() + if tensor_format == "bshd": + return tensor.permute(0, 2, 1, 3).contiguous() + raise ValueError(f"score_mod only supports SBHD/BSHD tensor formats, got {tensor_format}.") + + +def _bhsd_to_format(tensor: torch.Tensor, tensor_format: str) -> torch.Tensor: + """Convert cuDNN frontend's BHSD format back to TE's SBHD/BSHD tensor formats.""" + if tensor_format == "sbhd": + return tensor.permute(2, 0, 1, 3).contiguous() + if tensor_format == "bshd": + return tensor.permute(0, 2, 1, 3).contiguous() + raise ValueError(f"score_mod only supports SBHD/BSHD tensor formats, got {tensor_format}.") + + +def _make_cudnn_graph_tensor_dict(graph, tensors: Optional[Dict[str, torch.Tensor]]): + """Create cuDNN graph tensors matching runtime tensors.""" + if tensors is None: + return {} + return {name: graph.tensor_like(tensor) for name, tensor in tensors.items()} + + +def _wrap_score_mod(score_mod: Optional[Callable], graph_tensors: Dict[str, Any]): + """Adapt TE's score_mod signature to cuDNN frontend's two-argument callback.""" + if score_mod is None: + return None + + def _wrapped_score_mod(sdpa_graph, score_tensor): + return score_mod(sdpa_graph, score_tensor, graph_tensors) + + return _wrapped_score_mod + + +def _build_cudnn_pygraph(dtype: torch.dtype): + """Create a cuDNN frontend Python graph for F16/BF16 SDPA.""" + import cudnn # pylint: disable=import-outside-toplevel + + if dtype == torch.float16: + io_data_type = cudnn.data_type.HALF + elif dtype == torch.bfloat16: + io_data_type = cudnn.data_type.BFLOAT16 + else: + raise ValueError(f"score_mod only supports FP16/BF16 tensors, got {dtype}.") + + graph = cudnn.pygraph( + io_data_type=io_data_type, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + return cudnn, graph + + +def _build_and_run_cudnn_graph(graph, variant_pack: Dict[Any, torch.Tensor], device: torch.device): + """Build and execute a cuDNN frontend Python graph without caching.""" + import cudnn # pylint: disable=import-outside-toplevel + + graph.validate() + graph.build_operation_graph() + try: + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + except cudnn.cudnnGraphNotSupportedError as exc: + raise RuntimeError(f"cuDNN score_mod SDPA graph is not supported: {exc}") from exc + graph.build_plans(cudnn.build_plan_policy.HEURISTICS_CHOICE) + + workspace = torch.empty( + max(graph.get_workspace_size(), 1), + device=device, + dtype=torch.uint8, + ) + graph.execute(variant_pack, workspace) + + +class FusedAttentionWithScoreModFunc(torch.autograd.Function): + """cuDNN frontend Python SDPA path with score_mod callback support.""" + + @staticmethod + def forward( + ctx, + is_training: bool, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_bprop: Optional[Callable], + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]], + deterministic: bool, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + query_bhsd = _format_to_bhsd(query_layer, q_format) + key_bhsd = _format_to_bhsd(key_layer, kv_format) + value_bhsd = _format_to_bhsd(value_layer, kv_format) + + cudnn, graph = _build_cudnn_pygraph(query_bhsd.dtype) + q = graph.tensor_like(query_bhsd) + k = graph.tensor_like(key_bhsd) + v = graph.tensor_like(value_bhsd) + + score_mod_graph_tensors = _make_cudnn_graph_tensor_dict(graph, score_mod_tensors) + wrapped_score_mod = _wrap_score_mod(score_mod, score_mod_graph_tensors) + + output_bhsd = torch.empty( + (*query_bhsd.shape[:-1], value_bhsd.shape[-1]), + device=query_bhsd.device, + dtype=query_bhsd.dtype, + ) + output, stats = graph.sdpa( + name="te_score_mod_sdpa", + q=q, + k=k, + v=v, + generate_stats=is_training, + attn_scale=attn_scale, + use_causal_mask=False, + score_mod=wrapped_score_mod, + ) + output.set_output(True).set_dim(output_bhsd.size()).set_stride(output_bhsd.stride()) + + variant_pack = { + q: query_bhsd, + k: key_bhsd, + v: value_bhsd, + output: output_bhsd, + } + if is_training: + stats_bhs1 = torch.empty( + (*query_bhsd.shape[:-1], 1), + device=query_bhsd.device, + dtype=torch.float32, + ) + stats.set_output(True).set_dim(stats_bhs1.size()).set_stride( + stats_bhs1.stride() + ).set_data_type(cudnn.data_type.FLOAT) + variant_pack[stats] = stats_bhs1 + else: + stats_bhs1 = None + for name, graph_tensor in score_mod_graph_tensors.items(): + variant_pack[graph_tensor] = score_mod_tensors[name] + + _build_and_run_cudnn_graph(graph, variant_pack, query_bhsd.device) + + ctx.is_training = is_training + ctx.q_format = q_format + ctx.kv_format = kv_format + ctx.attn_scale = attn_scale + ctx.score_mod = score_mod + ctx.score_mod_bprop = score_mod_bprop + ctx.score_mod_tensors = dict(score_mod_tensors or {}) + ctx.score_mod_bprop_tensors = dict(score_mod_bprop_tensors or {}) + ctx.deterministic = deterministic + if is_training: + ctx.save_for_backward(query_bhsd, key_bhsd, value_bhsd, output_bhsd, stats_bhs1) + else: + ctx.save_for_backward(query_bhsd, key_bhsd, value_bhsd, output_bhsd) + + return _bhsd_to_format(output_bhsd, q_format) + + @staticmethod + def backward(ctx, d_out: torch.Tensor): + # pylint: disable=missing-function-docstring + if not ctx.is_training: + raise RuntimeError( + "score_mod backward requires DotProductAttention to be in training mode." + ) + + query_bhsd, key_bhsd, value_bhsd, output_bhsd, stats_bhs1 = ctx.saved_tensors + d_out_bhsd = _format_to_bhsd(d_out, ctx.q_format) + + cudnn, graph = _build_cudnn_pygraph(query_bhsd.dtype) + q = graph.tensor_like(query_bhsd) + k = graph.tensor_like(key_bhsd) + v = graph.tensor_like(value_bhsd) + output = graph.tensor_like(output_bhsd) + d_output = graph.tensor_like(d_out_bhsd) + stats = graph.tensor_like(stats_bhs1) + + score_mod_graph_tensors = _make_cudnn_graph_tensor_dict(graph, ctx.score_mod_tensors) + score_mod_bprop_graph_tensors = ( + _make_cudnn_graph_tensor_dict(graph, ctx.score_mod_bprop_tensors) + if ctx.score_mod_bprop is not None + else {} + ) + wrapped_score_mod = _wrap_score_mod(ctx.score_mod, score_mod_graph_tensors) + wrapped_score_mod_bprop = _wrap_score_mod( + ctx.score_mod_bprop, score_mod_bprop_graph_tensors + ) + + dq_bhsd = torch.empty_like(query_bhsd) + dk_bhsd = torch.empty_like(key_bhsd) + dv_bhsd = torch.empty_like(value_bhsd) + dq, dk, dv = graph.sdpa_backward( + name="te_score_mod_sdpa_backward", + q=q, + k=k, + v=v, + o=output, + dO=d_output, + stats=stats, + attn_scale=ctx.attn_scale, + use_causal_mask=False, + score_mod=wrapped_score_mod, + score_mod_bprop=wrapped_score_mod_bprop, + use_deterministic_algorithm=ctx.deterministic, + ) + dq.set_output(True).set_dim(dq_bhsd.size()).set_stride(dq_bhsd.stride()) + dk.set_output(True).set_dim(dk_bhsd.size()).set_stride(dk_bhsd.stride()) + dv.set_output(True).set_dim(dv_bhsd.size()).set_stride(dv_bhsd.stride()) + + variant_pack = { + q: query_bhsd, + k: key_bhsd, + v: value_bhsd, + output: output_bhsd, + d_output: d_out_bhsd, + stats: stats_bhs1, + dq: dq_bhsd, + dk: dk_bhsd, + dv: dv_bhsd, + } + for name, graph_tensor in score_mod_graph_tensors.items(): + variant_pack[graph_tensor] = ctx.score_mod_tensors[name] + for name, graph_tensor in score_mod_bprop_graph_tensors.items(): + variant_pack[graph_tensor] = ctx.score_mod_bprop_tensors[name] + + _build_and_run_cudnn_graph(graph, variant_pack, query_bhsd.device) + + return ( + None, + _bhsd_to_format(dq_bhsd, ctx.q_format), + _bhsd_to_format(dk_bhsd, ctx.kv_format), + _bhsd_to_format(dv_bhsd, ctx.kv_format), + None, + None, + None, + None, + None, + None, + None, + None, + ) + + class FusedAttnFunc(torch.autograd.Function): """FusedAttention forward and backward implementation""" @@ -1945,6 +2194,10 @@ def forward( inference_params: Optional[InferenceParams] = None, softmax_offset: torch.Tensor = None, fp8_output: bool = False, + score_mod: Optional[Callable] = None, + score_mod_bprop: Optional[Callable] = None, + score_mod_tensors: Optional[Dict[str, torch.Tensor]] = None, + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.Tensor: """fused attention fprop""" assert ( @@ -2067,7 +2320,43 @@ def forward( cp_group[0] if cp_comm_type == "a2a+p2p" else cp_group ) - if context_parallel: + if score_mod is not None: + assert ( + not context_parallel + ), "score_mod is not supported with context parallelism!" + assert ( + not fp8 + ), "score_mod is not supported with FP8 FusedAttention!" + assert ( + type(query_layer) is torch.Tensor + and type(key_layer) is torch.Tensor + and type(value_layer) is torch.Tensor + ), "score_mod only supports unquantized torch.Tensor Q, K and V inputs!" + assert ( + fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + ), "score_mod requires the F16/BF16 cuDNN fused attention backend!" + assert ( + attn_mask_type == "no_mask" + and core_attention_bias_type == "no_bias" + and core_attention_bias is None + and self.softmax_type == "vanilla" + and self.attention_dropout == 0.0 + ), "score_mod is mutually exclusive with masks, bias, sink attention and dropout!" + output = FusedAttentionWithScoreModFunc.apply( + self.training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + self.softmax_scale, + score_mod, + score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, + self.deterministic, + ) + elif context_parallel: assert ( fp8 or fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 17e9a337a4..7a7745d7c2 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -892,6 +892,10 @@ def forward( pad_between_seqs: Optional[bool] = None, fp8_output: Optional[bool] = False, num_splits: Optional[int] = 1, + score_mod: Optional[Callable] = None, + score_mod_bprop: Optional[Callable] = None, + score_mod_tensors: Optional[Dict[str, torch.Tensor]] = None, + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.Tensor: r""" Dot Product Attention Layer. @@ -1080,6 +1084,18 @@ def forward( Optional split control for FlashAttention-3 only. When set, this value is forwarded to the FA3 backend to control internal kernel splitting behavior for non-context-parallel cases. It is ignored for other backends and when context parallelism is enabled. + score_mod: Optional[Callable], default = None + cuDNN frontend score modification callback. This is a cuDNN-only path and is mutually + exclusive with masks, bias, ALiBi, sink attention, dropout, FP8, context parallelism, + THD format, KV caching, and return_max_logit. The callback signature is + ``score_mod(graph, score, tensors) -> score``. + score_mod_bprop: Optional[Callable], default = None + Optional cuDNN frontend callback for the backward pass of score_mod. The callback + signature is ``score_mod_bprop(graph, dP, tensors) -> dP``. + score_mod_tensors: Optional[Dict[str, torch.Tensor]], default = None + Runtime tensors exposed to score_mod as cuDNN graph tensors. + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]], default = None + Runtime tensors exposed to score_mod_bprop as cuDNN graph tensors. """ with self.prepare_forward_ctx( @@ -1088,6 +1104,13 @@ def forward( allow_non_contiguous=True, allow_different_data_and_param_types=self.softmax_type != "vanilla", ) as query_layer: + user_supplied_seqlens = ( + cu_seqlens_q is not None + or cu_seqlens_kv is not None + or cu_seqlens_q_padded is not None + or cu_seqlens_kv_padded is not None + ) + # checks for RNG if self.rng_states_tracker is not None and is_graph_capturing(): assert isinstance( @@ -1406,6 +1429,86 @@ def forward( else: pad_between_seqs = False + if score_mod is None: + assert score_mod_bprop is None, "score_mod_bprop requires score_mod!" + assert score_mod_tensors is None, "score_mod_tensors requires score_mod!" + assert score_mod_bprop_tensors is None, "score_mod_bprop_tensors requires score_mod!" + else: + assert callable(score_mod), "score_mod must be callable!" + assert ( + score_mod_bprop is None or callable(score_mod_bprop) + ), "score_mod_bprop must be callable when provided!" + assert ( + query_layer.dtype in [torch.float16, torch.bfloat16] + ), "score_mod only supports FP16 and BF16 tensors!" + assert ( + key_layer.dtype == query_layer.dtype and value_layer.dtype == query_layer.dtype + ), "score_mod requires Q, K and V tensors to have the same dtype!" + assert ( + type(query_layer) is torch.Tensor + and type(key_layer) is torch.Tensor + and type(value_layer) is torch.Tensor + ), "score_mod only supports unquantized torch.Tensor Q, K and V inputs!" + assert ( + not self.fp8 + ), "score_mod is not supported with FP8 DotProductAttention!" + assert not fp8_output, "score_mod is not supported with fp8_output!" + assert ( + not context_parallel + ), "score_mod is not supported with context parallelism!" + assert inference_params is None, "score_mod is not supported with KV caching!" + assert qkv_format != "thd", "score_mod is not supported with qkv_format='thd'!" + assert ( + not user_supplied_seqlens + ), "score_mod is mutually exclusive with explicit sequence length metadata!" + assert not pad_between_seqs, "score_mod is not supported with pad_between_seqs!" + assert attention_mask is None, "score_mod is mutually exclusive with attention_mask!" + assert ( + attn_mask_type == "no_mask" + ), "score_mod requires attn_mask_type='no_mask'!" + assert ( + window_size is None or window_size == (-1, -1) + ), "score_mod is mutually exclusive with sliding window attention!" + assert ( + core_attention_bias_type == "no_bias" and core_attention_bias is None + ), "score_mod is mutually exclusive with attention bias!" + assert alibi_slopes is None, "score_mod is mutually exclusive with ALiBi!" + assert ( + self.softmax_type == "vanilla" + ), "score_mod is mutually exclusive with sink attention!" + assert ( + self.attention_dropout == 0.0 + ), "score_mod is not supported with attention dropout!" + assert ( + not self.return_max_logit + ), "score_mod is not supported with return_max_logit!" + assert ( + not checkpoint_core_attention + ), "score_mod is not supported with checkpoint_core_attention!" + assert ( + not is_graph_capturing() + ), "score_mod is not supported with CUDA graph capture!" + assert num_splits == 1, "score_mod is not supported with num_splits != 1!" + assert ( + q_format in ["sbhd", "bshd"] and kv_format in ["sbhd", "bshd"] + ), "score_mod only supports SBHD/BSHD QKV formats!" + if score_mod_tensors is not None: + assert isinstance( + score_mod_tensors, dict + ), "score_mod_tensors must be a dict!" + assert all( + isinstance(k, str) and isinstance(v, torch.Tensor) + for k, v in score_mod_tensors.items() + ), "score_mod_tensors must map string names to torch.Tensor instances!" + if score_mod_bprop_tensors is not None: + assert isinstance( + score_mod_bprop_tensors, dict + ), "score_mod_bprop_tensors must be a dict!" + assert all( + isinstance(k, str) and isinstance(v, torch.Tensor) + for k, v in score_mod_bprop_tensors.items() + ), "score_mod_bprop_tensors must map string names to torch.Tensor instances!" + # gather attention params for get_attention_backend attention_params = dpa_utils.AttentionParams( qkv_type=type(query_layer), @@ -1443,7 +1546,39 @@ def forward( num_splits=num_splits, ) global _attention_backends - if is_in_onnx_export_mode(): + if score_mod is not None: + use_flash_attention = False + flash_attention_backend = None + use_fused_attention = True + use_unfused_attention = False + q_type = dpa_utils.TE_DType[query_layer.dtype] + fused_attention_backend = tex.get_fused_attn_backend( + self.training, + q_type, + q_type, + dpa_utils.QKVLayout["bshd_bshd_bshd"], + dpa_utils.AttnBiasType["no_bias"], + dpa_utils.AttnMaskType["no_mask"], + dpa_utils.SoftmaxType["vanilla"], + 0.0, + num_attention_heads, + num_gqa_groups, + max_seqlen_q, + max_seqlen_kv, + head_dim_qk, + head_dim_v, + -1, + -1, + False, + is_graph_capturing(), + self.deterministic, + ) + if fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend: + raise ValueError( + "score_mod requires a cuDNN FusedAttention backend, but no fused " + "attention backend supports the provided inputs." + ) + elif is_in_onnx_export_mode(): # We do not want to call get_attention_backend() in ONNX mode # and we want to avoid using any global variables like _attention_backends. use_flash_attention = False @@ -1619,6 +1754,10 @@ def forward( inference_params=inference_params, softmax_offset=softmax_offset, fp8_output=fp8_output, + score_mod=score_mod, + score_mod_bprop=score_mod_bprop, + score_mod_tensors=score_mod_tensors, + score_mod_bprop_tensors=score_mod_bprop_tensors, ) if use_unfused_attention: From eb35191a7071d868ff4163b79666004e79d87b54 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Mon, 11 May 2026 18:43:04 +0000 Subject: [PATCH 2/8] Avoid BHSD copies in score_mod attention Signed-off-by: Vladimir Cherepanov --- tests/pytorch/attention/test_attention.py | 13 +- .../dot_product_attention/backends.py | 124 +++++++++--------- 2 files changed, 73 insertions(+), 64 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 879f48dc0c..195087efb0 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1436,7 +1436,8 @@ def _score_mod_causal_bprop(score_mod_graph, dP_tensor, tensors): @pytest.mark.skipif(get_cudnn_version() < (9, 6, 0), reason="cuDNN 9.6.0+ is required.") @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("qkv_format", ["sbhd", "bshd"]) -def test_dot_product_attention_score_mod(dtype, qkv_format): +@pytest.mark.parametrize("scalar_loss", [False, True]) +def test_dot_product_attention_score_mod(dtype, qkv_format, scalar_loss): """Compare score_mod causal masking against standard cuDNN causal attention.""" try: import cudnn # pylint: disable=unused-import,import-outside-toplevel @@ -1504,9 +1505,13 @@ def test_dot_product_attention_score_mod(dtype, qkv_format): attn_mask_type="causal", ) - d_out = torch.randn_like(out) - out.backward(d_out) - out_ref.backward(d_out) + if scalar_loss: + out.sum().backward() + out_ref.sum().backward() + else: + d_out = torch.randn_like(out) + out.backward(d_out) + out_ref.backward(d_out) tols = dict(atol=5e-2, rtol=5e-2) torch.testing.assert_close(out, out_ref, **tols) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index e7e7a7d0f2..3f765dd634 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1244,22 +1244,27 @@ def convert_to_torch_float8(tensor, dtype): return output.contiguous() -def _format_to_bhsd(tensor: torch.Tensor, tensor_format: str) -> torch.Tensor: - """Convert TE's SBHD/BSHD tensor formats to cuDNN frontend's BHSD format.""" +def _bhsd_dim_stride( + tensor: torch.Tensor, tensor_format: str +) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + """Describe an SBHD/BSHD tensor as cuDNN frontend's logical BHSD format.""" if tensor_format == "sbhd": - return tensor.permute(1, 2, 0, 3).contiguous() + return ( + (tensor.shape[1], tensor.shape[2], tensor.shape[0], tensor.shape[3]), + (tensor.stride(1), tensor.stride(2), tensor.stride(0), tensor.stride(3)), + ) if tensor_format == "bshd": - return tensor.permute(0, 2, 1, 3).contiguous() + return ( + (tensor.shape[0], tensor.shape[2], tensor.shape[1], tensor.shape[3]), + (tensor.stride(0), tensor.stride(2), tensor.stride(1), tensor.stride(3)), + ) raise ValueError(f"score_mod only supports SBHD/BSHD tensor formats, got {tensor_format}.") -def _bhsd_to_format(tensor: torch.Tensor, tensor_format: str) -> torch.Tensor: - """Convert cuDNN frontend's BHSD format back to TE's SBHD/BSHD tensor formats.""" - if tensor_format == "sbhd": - return tensor.permute(2, 0, 1, 3).contiguous() - if tensor_format == "bshd": - return tensor.permute(0, 2, 1, 3).contiguous() - raise ValueError(f"score_mod only supports SBHD/BSHD tensor formats, got {tensor_format}.") +def _bhsd_graph_tensor(graph, tensor: torch.Tensor, tensor_format: str): + """Create a cuDNN graph tensor with BHSD dims and TE-layout strides.""" + dim, stride = _bhsd_dim_stride(tensor, tensor_format) + return graph.tensor(dim=dim, stride=stride, data_type=tensor.dtype) def _make_cudnn_graph_tensor_dict(graph, tensors: Optional[Dict[str, torch.Tensor]]): @@ -1340,23 +1345,19 @@ def forward( deterministic: bool, ) -> torch.Tensor: # pylint: disable=missing-function-docstring - query_bhsd = _format_to_bhsd(query_layer, q_format) - key_bhsd = _format_to_bhsd(key_layer, kv_format) - value_bhsd = _format_to_bhsd(value_layer, kv_format) + q_bhsd_dim, _ = _bhsd_dim_stride(query_layer, q_format) - cudnn, graph = _build_cudnn_pygraph(query_bhsd.dtype) - q = graph.tensor_like(query_bhsd) - k = graph.tensor_like(key_bhsd) - v = graph.tensor_like(value_bhsd) + cudnn, graph = _build_cudnn_pygraph(query_layer.dtype) + q = _bhsd_graph_tensor(graph, query_layer, q_format) + k = _bhsd_graph_tensor(graph, key_layer, kv_format) + v = _bhsd_graph_tensor(graph, value_layer, kv_format) score_mod_graph_tensors = _make_cudnn_graph_tensor_dict(graph, score_mod_tensors) wrapped_score_mod = _wrap_score_mod(score_mod, score_mod_graph_tensors) - output_bhsd = torch.empty( - (*query_bhsd.shape[:-1], value_bhsd.shape[-1]), - device=query_bhsd.device, - dtype=query_bhsd.dtype, - ) + output_shape = (*query_layer.shape[:-1], value_layer.shape[-1]) + output_layer = torch.empty(output_shape, device=query_layer.device, dtype=query_layer.dtype) + output_dim, output_stride = _bhsd_dim_stride(output_layer, q_format) output, stats = graph.sdpa( name="te_score_mod_sdpa", q=q, @@ -1367,18 +1368,18 @@ def forward( use_causal_mask=False, score_mod=wrapped_score_mod, ) - output.set_output(True).set_dim(output_bhsd.size()).set_stride(output_bhsd.stride()) + output.set_output(True).set_dim(output_dim).set_stride(output_stride) variant_pack = { - q: query_bhsd, - k: key_bhsd, - v: value_bhsd, - output: output_bhsd, + q: query_layer, + k: key_layer, + v: value_layer, + output: output_layer, } if is_training: stats_bhs1 = torch.empty( - (*query_bhsd.shape[:-1], 1), - device=query_bhsd.device, + (*q_bhsd_dim[:-1], 1), + device=query_layer.device, dtype=torch.float32, ) stats.set_output(True).set_dim(stats_bhs1.size()).set_stride( @@ -1390,7 +1391,7 @@ def forward( for name, graph_tensor in score_mod_graph_tensors.items(): variant_pack[graph_tensor] = score_mod_tensors[name] - _build_and_run_cudnn_graph(graph, variant_pack, query_bhsd.device) + _build_and_run_cudnn_graph(graph, variant_pack, query_layer.device) ctx.is_training = is_training ctx.q_format = q_format @@ -1402,11 +1403,11 @@ def forward( ctx.score_mod_bprop_tensors = dict(score_mod_bprop_tensors or {}) ctx.deterministic = deterministic if is_training: - ctx.save_for_backward(query_bhsd, key_bhsd, value_bhsd, output_bhsd, stats_bhs1) + ctx.save_for_backward(query_layer, key_layer, value_layer, output_layer, stats_bhs1) else: - ctx.save_for_backward(query_bhsd, key_bhsd, value_bhsd, output_bhsd) + ctx.save_for_backward(query_layer, key_layer, value_layer, output_layer) - return _bhsd_to_format(output_bhsd, q_format) + return output_layer @staticmethod def backward(ctx, d_out: torch.Tensor): @@ -1416,15 +1417,15 @@ def backward(ctx, d_out: torch.Tensor): "score_mod backward requires DotProductAttention to be in training mode." ) - query_bhsd, key_bhsd, value_bhsd, output_bhsd, stats_bhs1 = ctx.saved_tensors - d_out_bhsd = _format_to_bhsd(d_out, ctx.q_format) + query_layer, key_layer, value_layer, output_layer, stats_bhs1 = ctx.saved_tensors + d_out = d_out.contiguous() - cudnn, graph = _build_cudnn_pygraph(query_bhsd.dtype) - q = graph.tensor_like(query_bhsd) - k = graph.tensor_like(key_bhsd) - v = graph.tensor_like(value_bhsd) - output = graph.tensor_like(output_bhsd) - d_output = graph.tensor_like(d_out_bhsd) + cudnn, graph = _build_cudnn_pygraph(query_layer.dtype) + q = _bhsd_graph_tensor(graph, query_layer, ctx.q_format) + k = _bhsd_graph_tensor(graph, key_layer, ctx.kv_format) + v = _bhsd_graph_tensor(graph, value_layer, ctx.kv_format) + output = _bhsd_graph_tensor(graph, output_layer, ctx.q_format) + d_output = _bhsd_graph_tensor(graph, d_out, ctx.q_format) stats = graph.tensor_like(stats_bhs1) score_mod_graph_tensors = _make_cudnn_graph_tensor_dict(graph, ctx.score_mod_tensors) @@ -1438,9 +1439,12 @@ def backward(ctx, d_out: torch.Tensor): ctx.score_mod_bprop, score_mod_bprop_graph_tensors ) - dq_bhsd = torch.empty_like(query_bhsd) - dk_bhsd = torch.empty_like(key_bhsd) - dv_bhsd = torch.empty_like(value_bhsd) + dq_layer = torch.empty_like(query_layer) + dk_layer = torch.empty_like(key_layer) + dv_layer = torch.empty_like(value_layer) + dq_dim, dq_stride = _bhsd_dim_stride(dq_layer, ctx.q_format) + dk_dim, dk_stride = _bhsd_dim_stride(dk_layer, ctx.kv_format) + dv_dim, dv_stride = _bhsd_dim_stride(dv_layer, ctx.kv_format) dq, dk, dv = graph.sdpa_backward( name="te_score_mod_sdpa_backward", q=q, @@ -1455,33 +1459,33 @@ def backward(ctx, d_out: torch.Tensor): score_mod_bprop=wrapped_score_mod_bprop, use_deterministic_algorithm=ctx.deterministic, ) - dq.set_output(True).set_dim(dq_bhsd.size()).set_stride(dq_bhsd.stride()) - dk.set_output(True).set_dim(dk_bhsd.size()).set_stride(dk_bhsd.stride()) - dv.set_output(True).set_dim(dv_bhsd.size()).set_stride(dv_bhsd.stride()) + dq.set_output(True).set_dim(dq_dim).set_stride(dq_stride) + dk.set_output(True).set_dim(dk_dim).set_stride(dk_stride) + dv.set_output(True).set_dim(dv_dim).set_stride(dv_stride) variant_pack = { - q: query_bhsd, - k: key_bhsd, - v: value_bhsd, - output: output_bhsd, - d_output: d_out_bhsd, + q: query_layer, + k: key_layer, + v: value_layer, + output: output_layer, + d_output: d_out, stats: stats_bhs1, - dq: dq_bhsd, - dk: dk_bhsd, - dv: dv_bhsd, + dq: dq_layer, + dk: dk_layer, + dv: dv_layer, } for name, graph_tensor in score_mod_graph_tensors.items(): variant_pack[graph_tensor] = ctx.score_mod_tensors[name] for name, graph_tensor in score_mod_bprop_graph_tensors.items(): variant_pack[graph_tensor] = ctx.score_mod_bprop_tensors[name] - _build_and_run_cudnn_graph(graph, variant_pack, query_bhsd.device) + _build_and_run_cudnn_graph(graph, variant_pack, query_layer.device) return ( None, - _bhsd_to_format(dq_bhsd, ctx.q_format), - _bhsd_to_format(dk_bhsd, ctx.kv_format), - _bhsd_to_format(dv_bhsd, ctx.kv_format), + dq_layer, + dk_layer, + dv_layer, None, None, None, From 57ce106435dba1e95ef8e58d15b64bfdcbfca59c Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Mon, 11 May 2026 21:21:44 +0000 Subject: [PATCH 3/8] Test relative position score_mod attention Signed-off-by: Vladimir Cherepanov --- tests/pytorch/attention/test_attention.py | 138 ++++++++++++++++++++++ 1 file changed, 138 insertions(+) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 195087efb0..f715836a30 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1432,6 +1432,41 @@ def _score_mod_causal_bprop(score_mod_graph, dP_tensor, tensors): ) +def _score_mod_relative_position(score_mod_graph, score_tensor, _tensors): + """cuDNN frontend score_mod adding relative position bias.""" + import cudnn # pylint: disable=import-outside-toplevel + + row_index = score_mod_graph.gen_index(input=score_tensor, axis=2) + row_index.set_data_type(cudnn.data_type.INT32) + col_index = score_mod_graph.gen_index(input=score_tensor, axis=3) + col_index.set_data_type(cudnn.data_type.INT32) + relative_position = score_mod_graph.sub( + a=row_index, + b=col_index, + compute_data_type=cudnn.data_type.FLOAT, + ) + relative_position.set_data_type(cudnn.data_type.FLOAT) + return score_mod_graph.add( + a=score_tensor, + b=relative_position, + compute_data_type=cudnn.data_type.FLOAT, + ) + + +def _score_mod_identity_bprop(_score_mod_graph, dP_tensor, _tensors): + """cuDNN frontend score_mod_bprop for score_mods with unit score derivative.""" + return dP_tensor + + +def _relative_position_bias(config, dtype): + """Materialize score + (q_idx - kv_idx) as post-scale attention bias.""" + q_idx = torch.arange(config.max_seqlen_q, dtype=torch.float32, device="cuda").view(1, 1, -1, 1) + kv_idx = torch.arange(config.max_seqlen_kv, dtype=torch.float32, device="cuda").view( + 1, 1, 1, -1 + ) + return (q_idx - kv_idx).to(dtype).expand(1, config.num_heads, -1, -1).contiguous() + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.") @pytest.mark.skipif(get_cudnn_version() < (9, 6, 0), reason="cuDNN 9.6.0+ is required.") @pytest.mark.parametrize("dtype", param_types) @@ -1520,6 +1555,109 @@ def test_dot_product_attention_score_mod(dtype, qkv_format, scalar_loss): torch.testing.assert_close(v.grad, v_ref.grad, **tols) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.") +@pytest.mark.skipif(get_cudnn_version() < (9, 6, 0), reason="cuDNN 9.6.0+ is required.") +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd"]) +def test_dot_product_attention_score_mod_relative_position(dtype, qkv_format): + """Compare relative-position score_mod against materialized post-scale bias.""" + try: + import cudnn # pylint: disable=unused-import,import-outside-toplevel + except ImportError: + pytest.skip("cuDNN Python frontend is required for score_mod attention.") + + reset_rng_states() + + config = ModelConfig(2, 16, 4, 64, attn_mask_type="no_mask") + bias_config = ModelConfig( + config.batch_size, + config.max_seqlen_q, + config.num_heads, + config.head_dim_qk, + attn_mask_type="no_mask", + attn_bias_type="post_scale_bias", + bias_shape="1hss", + ) + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=f"{qkv_format}_{qkv_format}_{qkv_format}", + ) + bias_available_backends, _, bias_fused_attn_backends = get_available_attention_backends( + bias_config, + qkv_dtype=dtype, + qkv_layout=f"{qkv_format}_{qkv_format}_{qkv_format}", + ) + if ( + not available_backends[1] + or not fused_attn_backends + or not bias_available_backends[1] + or not bias_fused_attn_backends + ): + pytest.skip("FusedAttention is not available for this relative-position configuration.") + + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + + if qkv_format == "sbhd": + q_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim_qk) + kv_shape = q_shape + else: + q_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim_qk) + kv_shape = q_shape + + q = (0.1 * torch.randn(q_shape, dtype=dtype, device="cuda")).requires_grad_() + k = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_() + v = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_() + q_ref, k_ref, v_ref = [x.detach().clone().requires_grad_() for x in (q, k, v)] + + flex_attn = DotProductAttention( + config.num_heads, + config.head_dim_qk, + qkv_format=qkv_format, + attn_mask_type="no_mask", + layer_number=1, + ).to(dtype=dtype, device="cuda") + ref_attn = DotProductAttention( + config.num_heads, + config.head_dim_qk, + qkv_format=qkv_format, + attn_mask_type="no_mask", + layer_number=1, + ).to(dtype=dtype, device="cuda") + + out = flex_attn( + q, + k, + v, + qkv_format=qkv_format, + attn_mask_type="no_mask", + score_mod=_score_mod_relative_position, + score_mod_bprop=_score_mod_identity_bprop, + ) + out_ref = ref_attn( + q_ref, + k_ref, + v_ref, + qkv_format=qkv_format, + attn_mask_type="no_mask", + core_attention_bias_type="post_scale_bias", + core_attention_bias=_relative_position_bias(config, dtype), + ) + + d_out = torch.randn_like(out) + out.backward(d_out) + out_ref.backward(d_out) + + tols = dict(atol=5e-2, rtol=5e-2) + torch.testing.assert_close(out, out_ref, **tols) + torch.testing.assert_close(q.grad, q_ref.grad, **tols) + torch.testing.assert_close(k.grad, k_ref.grad, **tols) + torch.testing.assert_close(v.grad, v_ref.grad, **tols) + + model_configs_te_layer = { # test: ModelConfig(b, sq, hq, dqk) "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"), From e6ba0ea8907c7bb26a4499bb2b61dc605fe71c8d Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Mon, 11 May 2026 21:44:30 +0000 Subject: [PATCH 4/8] Test softcap score_mod attention Signed-off-by: Vladimir Cherepanov --- tests/pytorch/attention/test_attention.py | 153 ++++++++++++++++++++++ 1 file changed, 153 insertions(+) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index f715836a30..c42a42d4b4 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1458,6 +1458,53 @@ def _score_mod_identity_bprop(_score_mod_graph, dP_tensor, _tensors): return dP_tensor +class _ScoreModSoftcap: + """cuDNN frontend score_mod implementing softcapping.""" + + def __init__(self): + self.before_tanh_activation = None + + def forward(self, score_mod_graph, score_tensor, tensors): + """Apply softcap * tanh(score / softcap).""" + import cudnn # pylint: disable=import-outside-toplevel + + self.before_tanh_activation = score_mod_graph.div( + a=score_tensor, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + self.before_tanh_activation.set_data_type(cudnn.data_type.FLOAT) + tanh_out = score_mod_graph.tanh(input=self.before_tanh_activation) + tanh_out.set_data_type(cudnn.data_type.FLOAT) + return score_mod_graph.mul( + a=tanh_out, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + + def backward(self, score_mod_graph, dP_tensor, tensors): + """Apply softcap derivative to dP.""" + import cudnn # pylint: disable=import-outside-toplevel + + d_tanh_out = score_mod_graph.mul( + a=dP_tensor, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + d_tanh_out.set_data_type(cudnn.data_type.FLOAT) + d_before_tanh_activation = score_mod_graph.tanh_backward( + loss=d_tanh_out, + input=self.before_tanh_activation, + compute_data_type=cudnn.data_type.FLOAT, + ) + d_before_tanh_activation.set_data_type(cudnn.data_type.FLOAT) + return score_mod_graph.div( + a=d_before_tanh_activation, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + + def _relative_position_bias(config, dtype): """Materialize score + (q_idx - kv_idx) as post-scale attention bias.""" q_idx = torch.arange(config.max_seqlen_q, dtype=torch.float32, device="cuda").view(1, 1, -1, 1) @@ -1467,6 +1514,32 @@ def _relative_position_bias(config, dtype): return (q_idx - kv_idx).to(dtype).expand(1, config.num_heads, -1, -1).contiguous() +def _to_bhsd(tensor, qkv_format): + """Convert SBHD/BSHD test tensors to logical BHSD.""" + if qkv_format == "sbhd": + return tensor.permute(1, 2, 0, 3) + return tensor.permute(0, 2, 1, 3) + + +def _from_bhsd(tensor, qkv_format): + """Convert logical BHSD test tensors to SBHD/BSHD.""" + if qkv_format == "sbhd": + return tensor.permute(2, 0, 1, 3).contiguous() + return tensor.permute(0, 2, 1, 3).contiguous() + + +def _pytorch_softcap_attention(q, k, v, qkv_format, softmax_scale, softcap): + """PyTorch reference for softcapped scaled dot-product attention.""" + q_bhsd = _to_bhsd(q, qkv_format).float() + k_bhsd = _to_bhsd(k, qkv_format).float() + v_bhsd = _to_bhsd(v, qkv_format).float() + scores = torch.matmul(q_bhsd, k_bhsd.transpose(-2, -1)) * softmax_scale + scores = softcap * torch.tanh(scores / softcap) + probs = torch.softmax(scores, dim=-1) + out = _from_bhsd(torch.matmul(probs, v_bhsd), qkv_format).to(v.dtype) + return out.reshape(*out.shape[:-2], out.shape[-2] * out.shape[-1]) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.") @pytest.mark.skipif(get_cudnn_version() < (9, 6, 0), reason="cuDNN 9.6.0+ is required.") @pytest.mark.parametrize("dtype", param_types) @@ -1555,6 +1628,86 @@ def test_dot_product_attention_score_mod(dtype, qkv_format, scalar_loss): torch.testing.assert_close(v.grad, v_ref.grad, **tols) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.") +@pytest.mark.skipif(get_cudnn_version() < (9, 6, 0), reason="cuDNN 9.6.0+ is required.") +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd"]) +def test_dot_product_attention_score_mod_softcap(dtype, qkv_format): + """Compare softcap score_mod against PyTorch math attention.""" + try: + import cudnn # pylint: disable=unused-import,import-outside-toplevel + except ImportError: + pytest.skip("cuDNN Python frontend is required for score_mod attention.") + + reset_rng_states() + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + + config = ModelConfig(2, 16, 4, 64, attn_mask_type="no_mask") + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=f"{qkv_format}_{qkv_format}_{qkv_format}", + ) + if not available_backends[1] or not fused_attn_backends: + pytest.skip("FusedAttention is not available for this softcap configuration.") + + if qkv_format == "sbhd": + q_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim_qk) + kv_shape = q_shape + else: + q_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim_qk) + kv_shape = q_shape + + q = torch.randn(q_shape, dtype=dtype, device="cuda").requires_grad_() + k = torch.randn(kv_shape, dtype=dtype, device="cuda").requires_grad_() + v = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_() + q_ref, k_ref, v_ref = [x.detach().clone().requires_grad_() for x in (q, k, v)] + + softcap = 0.8 + softcap_tensor = torch.full((1, 1, 1, 1), softcap) + softcap_score_mod = _ScoreModSoftcap() + flex_attn = DotProductAttention( + config.num_heads, + config.head_dim_qk, + qkv_format=qkv_format, + attn_mask_type="no_mask", + layer_number=1, + ).to(dtype=dtype, device="cuda") + + out = flex_attn( + q, + k, + v, + qkv_format=qkv_format, + attn_mask_type="no_mask", + score_mod=softcap_score_mod.forward, + score_mod_bprop=softcap_score_mod.backward, + score_mod_tensors={"softcap": softcap_tensor}, + score_mod_bprop_tensors={"softcap": softcap_tensor}, + ) + out_ref = _pytorch_softcap_attention( + q_ref, + k_ref, + v_ref, + qkv_format, + 1.0 / config.head_dim_qk**0.5, + softcap, + ) + + d_out = torch.randn_like(out) + out.backward(d_out) + out_ref.backward(d_out) + + tols = dict(atol=7e-2, rtol=7e-2) + torch.testing.assert_close(out, out_ref, **tols) + torch.testing.assert_close(q.grad, q_ref.grad, **tols) + torch.testing.assert_close(k.grad, k_ref.grad, **tols) + torch.testing.assert_close(v.grad, v_ref.grad, **tols) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.") @pytest.mark.skipif(get_cudnn_version() < (9, 6, 0), reason="cuDNN 9.6.0+ is required.") @pytest.mark.parametrize("dtype", param_types) From dcb6b492cc3396a95a57132ba8fcd90711020233 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Mon, 11 May 2026 23:17:04 +0000 Subject: [PATCH 5/8] Run score_mod graphs on current CUDA stream Signed-off-by: Vladimir Cherepanov --- .../dot_product_attention/backends.py | 32 ++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 3f765dd634..625c030e0c 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -89,6 +89,7 @@ _flash_attn_bwd = None _flash_attn_varlen_fwd = None _flash_attn_varlen_bwd = None +_cudnn_score_mod_handles: Dict[torch.device, Any] = {} # Try to import Flash Attention v2 try: @@ -1285,7 +1286,25 @@ def _wrapped_score_mod(sdpa_graph, score_tensor): return _wrapped_score_mod -def _build_cudnn_pygraph(dtype: torch.dtype): +def _get_cudnn_current_stream_handle(cudnn, device: torch.device): + """Return a cuDNN handle for device, bound to PyTorch's current stream.""" + if device.type != "cuda": + raise ValueError(f"score_mod only supports CUDA tensors, got device {device}.") + if device.index is None: + device = torch.device("cuda", torch.cuda.current_device()) + + handle = _cudnn_score_mod_handles.get(device) + with torch.cuda.device(device): + if handle is None: + handle = cudnn.create_handle() + _cudnn_score_mod_handles[device] = handle + + stream = torch.cuda.current_stream(device).cuda_stream + cudnn.set_stream(handle=handle, stream=stream) + return handle + + +def _build_cudnn_pygraph(dtype: torch.dtype, device: torch.device): """Create a cuDNN frontend Python graph for F16/BF16 SDPA.""" import cudnn # pylint: disable=import-outside-toplevel @@ -1300,6 +1319,7 @@ def _build_cudnn_pygraph(dtype: torch.dtype): io_data_type=io_data_type, intermediate_data_type=cudnn.data_type.FLOAT, compute_data_type=cudnn.data_type.FLOAT, + handle=_get_cudnn_current_stream_handle(cudnn, device), ) return cudnn, graph @@ -1322,7 +1342,11 @@ def _build_and_run_cudnn_graph(graph, variant_pack: Dict[Any, torch.Tensor], dev device=device, dtype=torch.uint8, ) - graph.execute(variant_pack, workspace) + graph.execute( + variant_pack, + workspace, + handle=_get_cudnn_current_stream_handle(cudnn, device), + ) class FusedAttentionWithScoreModFunc(torch.autograd.Function): @@ -1347,7 +1371,7 @@ def forward( # pylint: disable=missing-function-docstring q_bhsd_dim, _ = _bhsd_dim_stride(query_layer, q_format) - cudnn, graph = _build_cudnn_pygraph(query_layer.dtype) + cudnn, graph = _build_cudnn_pygraph(query_layer.dtype, query_layer.device) q = _bhsd_graph_tensor(graph, query_layer, q_format) k = _bhsd_graph_tensor(graph, key_layer, kv_format) v = _bhsd_graph_tensor(graph, value_layer, kv_format) @@ -1420,7 +1444,7 @@ def backward(ctx, d_out: torch.Tensor): query_layer, key_layer, value_layer, output_layer, stats_bhs1 = ctx.saved_tensors d_out = d_out.contiguous() - cudnn, graph = _build_cudnn_pygraph(query_layer.dtype) + cudnn, graph = _build_cudnn_pygraph(query_layer.dtype, query_layer.device) q = _bhsd_graph_tensor(graph, query_layer, ctx.q_format) k = _bhsd_graph_tensor(graph, key_layer, ctx.kv_format) v = _bhsd_graph_tensor(graph, value_layer, ctx.kv_format) From fefcbe7f37f995a09ff4df7c5c88c227d6fd0e71 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Tue, 12 May 2026 21:47:04 +0000 Subject: [PATCH 6/8] Add PyTorch score_mod execution plan cache Signed-off-by: Vladimir Cherepanov --- tests/pytorch/attention/test_attention.py | 157 +++++ .../dot_product_attention/backends.py | 547 +++++++++++++++--- 2 files changed, 616 insertions(+), 88 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index c42a42d4b4..a8f1811dc0 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -26,6 +26,7 @@ from transformer_engine.pytorch.attention.dot_product_attention import ( _attention_backends, ) +import transformer_engine.pytorch.attention.dot_product_attention.backends as dpa_backends from transformer_engine.pytorch.attention.dot_product_attention.utils import ( FlashAttentionUtils, check_set_window_size, @@ -1505,6 +1506,162 @@ def backward(self, score_mod_graph, dP_tensor, tensors): ) +def _score_mod_cache_cpu_inputs(): + """Small CPU tensors for score_mod cache-key tests.""" + q = torch.empty((2, 4, 3, 8), dtype=torch.float16) + k = torch.empty((2, 4, 3, 8), dtype=torch.float16) + v = torch.empty((2, 4, 3, 8), dtype=torch.float16) + o = torch.empty((2, 4, 3, 8), dtype=torch.float16) + stats = torch.empty((2, 3, 4, 1), dtype=torch.float32) + return q, k, v, o, stats + + +def test_score_mod_cache_bound_method_key_stable(): + """Bound method keys should be stable across repeated attribute access.""" + softcap = _ScoreModSoftcap() + key_0 = dpa_backends._score_mod_callback_cache_key(softcap.forward) + key_1 = dpa_backends._score_mod_callback_cache_key(softcap.forward) + other_key = dpa_backends._score_mod_callback_cache_key(_ScoreModSoftcap().forward) + + assert key_0 == key_1 + assert key_0 != other_key + + +def test_score_mod_cache_key_ignores_pass_by_value_values(): + """Scalar CPU tensor values are runtime inputs, not execution-plan metadata.""" + q, k, v, o, stats = _score_mod_cache_cpu_inputs() + key_0 = dpa_backends._cudnn_score_mod_fwd_cache_key( + True, + q, + k, + v, + o, + stats, + "bshd", + "bshd", + 1.0, + _score_mod_causal, + {"softcap": torch.tensor(0.8, dtype=torch.float32)}, + ) + key_1 = dpa_backends._cudnn_score_mod_fwd_cache_key( + True, + q, + k, + v, + o, + stats, + "bshd", + "bshd", + 1.0, + _score_mod_causal, + {"softcap": torch.tensor(1.2, dtype=torch.float32)}, + ) + key_2 = dpa_backends._cudnn_score_mod_fwd_cache_key( + True, + q, + k, + v, + o, + stats, + "bshd", + "bshd", + 1.0, + _score_mod_causal, + {"softcap": torch.tensor([0.8], dtype=torch.float32)}, + ) + + assert key_0 == key_1 + assert key_0 != key_2 + + +def test_score_mod_cache_fwd_reuses_graph_for_pass_by_value_changes(monkeypatch): + """Fprop graph cache should reuse entries when only scalar CPU tensor values change.""" + q, k, v, o, stats = _score_mod_cache_cpu_inputs() + cache = dpa_backends._cudnn_score_mod_graph_cache + saved_cache = dict(cache) + build_entries = [] + + def fake_build( + is_training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + output_layer, + stats_bhs1, + ): + del ( + is_training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + output_layer, + stats_bhs1, + ) + entry = object() + build_entries.append(entry) + return entry + + monkeypatch.setattr(dpa_backends, "_build_cudnn_score_mod_fwd_graph", fake_build) + try: + cache.clear() + entry_0 = dpa_backends._get_cudnn_score_mod_fwd_graph( + True, + q, + k, + v, + "bshd", + "bshd", + 1.0, + _score_mod_causal, + {"softcap": torch.tensor(0.8, dtype=torch.float32)}, + o, + stats, + ) + entry_1 = dpa_backends._get_cudnn_score_mod_fwd_graph( + True, + q, + k, + v, + "bshd", + "bshd", + 1.0, + _score_mod_causal, + {"softcap": torch.tensor(1.2, dtype=torch.float32)}, + o, + stats, + ) + entry_2 = dpa_backends._get_cudnn_score_mod_fwd_graph( + True, + q, + k, + v, + "bshd", + "bshd", + 1.0, + _score_mod_causal, + {"softcap": torch.tensor([0.8], dtype=torch.float32)}, + o, + stats, + ) + finally: + cache.clear() + cache.update(saved_cache) + + assert entry_0 is entry_1 + assert entry_2 is not entry_0 + assert len(build_entries) == 2 + + def _relative_position_bias(config, dtype): """Materialize score + (q_idx - kv_idx) as post-scale attention bias.""" q_idx = torch.arange(config.max_seqlen_q, dtype=torch.float32, device="cuda").view(1, 1, -1, 1) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 625c030e0c..66a49bbc64 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -4,6 +4,7 @@ """Attention Backends.""" from contextlib import nullcontext +from dataclasses import dataclass from importlib.metadata import version as get_pkg_version from importlib.metadata import PackageNotFoundError import os @@ -90,6 +91,7 @@ _flash_attn_varlen_fwd = None _flash_attn_varlen_bwd = None _cudnn_score_mod_handles: Dict[torch.device, Any] = {} +_cudnn_score_mod_graph_cache: Dict[Tuple[Any, ...], Any] = {} # Try to import Flash Attention v2 try: @@ -1268,6 +1270,52 @@ def _bhsd_graph_tensor(graph, tensor: torch.Tensor, tensor_format: str): return graph.tensor(dim=dim, stride=stride, data_type=tensor.dtype) +def _score_mod_callback_cache_key(callback: Optional[Callable]) -> Optional[Tuple[Any, ...]]: + """Create a stable cache key for a score_mod callable.""" + if callback is None: + return None + self_obj = getattr(callback, "__self__", None) + func_obj = getattr(callback, "__func__", None) + if self_obj is not None and func_obj is not None: + return ("bound_method", id(self_obj), id(func_obj)) + return ("callable", id(callback)) + + +def _score_mod_device_key(device: torch.device) -> Tuple[Any, ...]: + """Normalize a tensor device for graph cache keys.""" + if device.type == "cuda": + index = device.index + if index is None: + index = torch.cuda.current_device() + return (device.type, index) + return (device.type, device.index) + + +def _score_mod_tensor_metadata(tensor: torch.Tensor) -> Tuple[Any, ...]: + """Describe tensor metadata that can affect cuDNN graph construction.""" + return ( + tuple(tensor.size()), + tuple(tensor.stride()), + tensor.dtype, + _score_mod_device_key(tensor.device), + ) + + +def _score_mod_tensor_dict_metadata( + tensors: Optional[Dict[str, torch.Tensor]], +) -> Tuple[Tuple[str, Tuple[Any, ...]], ...]: + """Describe score_mod tensor parameters without including their values.""" + if tensors is None: + return () + return tuple((name, _score_mod_tensor_metadata(tensor)) for name, tensor in tensors.items()) + + +def _score_mod_bhsd_tensor_metadata(tensor: torch.Tensor, tensor_format: str) -> Tuple[Any, ...]: + """Describe an SBHD/BSHD runtime tensor as a cuDNN BHSD graph tensor.""" + dim, stride = _bhsd_dim_stride(tensor, tensor_format) + return (dim, stride, tensor.dtype, _score_mod_device_key(tensor.device)) + + def _make_cudnn_graph_tensor_dict(graph, tensors: Optional[Dict[str, torch.Tensor]]): """Create cuDNN graph tensors matching runtime tensors.""" if tensors is None: @@ -1324,8 +1372,41 @@ def _build_cudnn_pygraph(dtype: torch.dtype, device: torch.device): return cudnn, graph -def _build_and_run_cudnn_graph(graph, variant_pack: Dict[Any, torch.Tensor], device: torch.device): - """Build and execute a cuDNN frontend Python graph without caching.""" +@dataclass +class _CudnnScoreModFwdGraphEntry: + """Cached cuDNN frontend graph and graph tensor handles for score_mod fprop.""" + + graph: Any + q: Any + k: Any + v: Any + output: Any + stats: Optional[Any] + score_mod_graph_tensors: Dict[str, Any] + workspace_size: int + + +@dataclass +class _CudnnScoreModBwdGraphEntry: + """Cached cuDNN frontend graph and graph tensor handles for score_mod bprop.""" + + graph: Any + q: Any + k: Any + v: Any + output: Any + d_output: Any + stats: Any + dq: Any + dk: Any + dv: Any + score_mod_graph_tensors: Dict[str, Any] + score_mod_bprop_graph_tensors: Dict[str, Any] + workspace_size: int + + +def _finalize_cudnn_graph(graph) -> int: + """Build a cuDNN frontend Python graph and return its workspace size.""" import cudnn # pylint: disable=import-outside-toplevel graph.validate() @@ -1336,9 +1417,22 @@ def _build_and_run_cudnn_graph(graph, variant_pack: Dict[Any, torch.Tensor], dev except cudnn.cudnnGraphNotSupportedError as exc: raise RuntimeError(f"cuDNN score_mod SDPA graph is not supported: {exc}") from exc graph.build_plans(cudnn.build_plan_policy.HEURISTICS_CHOICE) + return max(graph.get_workspace_size(), 1) + + +def _execute_cudnn_graph( + graph, + variant_pack: Dict[Any, torch.Tensor], + workspace_size: int, + device: torch.device, +): + """Execute a built cuDNN frontend Python graph.""" + import cudnn # pylint: disable=import-outside-toplevel + if device.type == "cuda" and device.index is None: + device = torch.device("cuda", torch.cuda.current_device()) workspace = torch.empty( - max(graph.get_workspace_size(), 1), + workspace_size, device=device, dtype=torch.uint8, ) @@ -1349,6 +1443,307 @@ def _build_and_run_cudnn_graph(graph, variant_pack: Dict[Any, torch.Tensor], dev ) +def _cudnn_score_mod_fwd_cache_key( + is_training: bool, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + output_layer: torch.Tensor, + stats_bhs1: Optional[torch.Tensor], + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_tensors: Optional[Dict[str, torch.Tensor]], +) -> Tuple[Any, ...]: + """Cache key for score_mod fprop execution plans.""" + return ( + "fwd", + is_training, + q_format, + kv_format, + attn_scale, + _score_mod_callback_cache_key(score_mod), + _score_mod_bhsd_tensor_metadata(query_layer, q_format), + _score_mod_bhsd_tensor_metadata(key_layer, kv_format), + _score_mod_bhsd_tensor_metadata(value_layer, kv_format), + _score_mod_bhsd_tensor_metadata(output_layer, q_format), + _score_mod_tensor_metadata(stats_bhs1) if stats_bhs1 is not None else None, + _score_mod_tensor_dict_metadata(score_mod_tensors), + ) + + +def _cudnn_score_mod_bwd_cache_key( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + output_layer: torch.Tensor, + d_out: torch.Tensor, + stats_bhs1: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_bprop: Optional[Callable], + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]], + deterministic: bool, +) -> Tuple[Any, ...]: + """Cache key for score_mod bprop execution plans.""" + return ( + "bwd", + q_format, + kv_format, + attn_scale, + deterministic, + _score_mod_callback_cache_key(score_mod), + _score_mod_callback_cache_key(score_mod_bprop), + _score_mod_bhsd_tensor_metadata(query_layer, q_format), + _score_mod_bhsd_tensor_metadata(key_layer, kv_format), + _score_mod_bhsd_tensor_metadata(value_layer, kv_format), + _score_mod_bhsd_tensor_metadata(output_layer, q_format), + _score_mod_bhsd_tensor_metadata(d_out, q_format), + _score_mod_tensor_metadata(stats_bhs1), + _score_mod_tensor_dict_metadata(score_mod_tensors), + _score_mod_tensor_dict_metadata(score_mod_bprop_tensors), + ) + + +def _build_cudnn_score_mod_fwd_graph( + is_training: bool, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + output_layer: torch.Tensor, + stats_bhs1: Optional[torch.Tensor], +) -> _CudnnScoreModFwdGraphEntry: + """Build a cached cuDNN frontend graph for score_mod fprop.""" + import cudnn # pylint: disable=import-outside-toplevel + + _, graph = _build_cudnn_pygraph(query_layer.dtype, query_layer.device) + q = _bhsd_graph_tensor(graph, query_layer, q_format) + k = _bhsd_graph_tensor(graph, key_layer, kv_format) + v = _bhsd_graph_tensor(graph, value_layer, kv_format) + + score_mod_graph_tensors = _make_cudnn_graph_tensor_dict(graph, score_mod_tensors) + wrapped_score_mod = _wrap_score_mod(score_mod, score_mod_graph_tensors) + + output_dim, output_stride = _bhsd_dim_stride(output_layer, q_format) + output, stats = graph.sdpa( + name="te_score_mod_sdpa", + q=q, + k=k, + v=v, + generate_stats=is_training, + attn_scale=attn_scale, + use_causal_mask=False, + score_mod=wrapped_score_mod, + ) + output.set_output(True).set_dim(output_dim).set_stride(output_stride) + + if is_training: + assert stats_bhs1 is not None + stats.set_output(True).set_dim(stats_bhs1.size()).set_stride( + stats_bhs1.stride() + ).set_data_type(cudnn.data_type.FLOAT) + else: + stats = None + + workspace_size = _finalize_cudnn_graph(graph) + return _CudnnScoreModFwdGraphEntry( + graph=graph, + q=q, + k=k, + v=v, + output=output, + stats=stats, + score_mod_graph_tensors=score_mod_graph_tensors, + workspace_size=workspace_size, + ) + + +def _get_cudnn_score_mod_fwd_graph( + is_training: bool, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + output_layer: torch.Tensor, + stats_bhs1: Optional[torch.Tensor], +) -> _CudnnScoreModFwdGraphEntry: + """Return a cached cuDNN frontend graph for score_mod fprop.""" + key = _cudnn_score_mod_fwd_cache_key( + is_training, + query_layer, + key_layer, + value_layer, + output_layer, + stats_bhs1, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + ) + entry = _cudnn_score_mod_graph_cache.get(key) + if entry is None: + entry = _build_cudnn_score_mod_fwd_graph( + is_training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + output_layer, + stats_bhs1, + ) + _cudnn_score_mod_graph_cache[key] = entry + return entry + + +def _build_cudnn_score_mod_bwd_graph( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + output_layer: torch.Tensor, + d_out: torch.Tensor, + stats_bhs1: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_bprop: Optional[Callable], + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]], + deterministic: bool, +) -> _CudnnScoreModBwdGraphEntry: + """Build a cached cuDNN frontend graph for score_mod bprop.""" + _, graph = _build_cudnn_pygraph(query_layer.dtype, query_layer.device) + q = _bhsd_graph_tensor(graph, query_layer, q_format) + k = _bhsd_graph_tensor(graph, key_layer, kv_format) + v = _bhsd_graph_tensor(graph, value_layer, kv_format) + output = _bhsd_graph_tensor(graph, output_layer, q_format) + d_output = _bhsd_graph_tensor(graph, d_out, q_format) + stats = graph.tensor_like(stats_bhs1) + + score_mod_graph_tensors = _make_cudnn_graph_tensor_dict(graph, score_mod_tensors) + score_mod_bprop_graph_tensors = ( + _make_cudnn_graph_tensor_dict(graph, score_mod_bprop_tensors) + if score_mod_bprop is not None + else {} + ) + wrapped_score_mod = _wrap_score_mod(score_mod, score_mod_graph_tensors) + wrapped_score_mod_bprop = _wrap_score_mod(score_mod_bprop, score_mod_bprop_graph_tensors) + + dq_layer = torch.empty_like(query_layer) + dk_layer = torch.empty_like(key_layer) + dv_layer = torch.empty_like(value_layer) + dq_dim, dq_stride = _bhsd_dim_stride(dq_layer, q_format) + dk_dim, dk_stride = _bhsd_dim_stride(dk_layer, kv_format) + dv_dim, dv_stride = _bhsd_dim_stride(dv_layer, kv_format) + dq, dk, dv = graph.sdpa_backward( + name="te_score_mod_sdpa_backward", + q=q, + k=k, + v=v, + o=output, + dO=d_output, + stats=stats, + attn_scale=attn_scale, + use_causal_mask=False, + score_mod=wrapped_score_mod, + score_mod_bprop=wrapped_score_mod_bprop, + use_deterministic_algorithm=deterministic, + ) + dq.set_output(True).set_dim(dq_dim).set_stride(dq_stride) + dk.set_output(True).set_dim(dk_dim).set_stride(dk_stride) + dv.set_output(True).set_dim(dv_dim).set_stride(dv_stride) + + workspace_size = _finalize_cudnn_graph(graph) + return _CudnnScoreModBwdGraphEntry( + graph=graph, + q=q, + k=k, + v=v, + output=output, + d_output=d_output, + stats=stats, + dq=dq, + dk=dk, + dv=dv, + score_mod_graph_tensors=score_mod_graph_tensors, + score_mod_bprop_graph_tensors=score_mod_bprop_graph_tensors, + workspace_size=workspace_size, + ) + + +def _get_cudnn_score_mod_bwd_graph( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + output_layer: torch.Tensor, + d_out: torch.Tensor, + stats_bhs1: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_bprop: Optional[Callable], + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]], + deterministic: bool, +) -> _CudnnScoreModBwdGraphEntry: + """Return a cached cuDNN frontend graph for score_mod bprop.""" + key = _cudnn_score_mod_bwd_cache_key( + query_layer, + key_layer, + value_layer, + output_layer, + d_out, + stats_bhs1, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, + deterministic, + ) + entry = _cudnn_score_mod_graph_cache.get(key) + if entry is None: + entry = _build_cudnn_score_mod_bwd_graph( + query_layer, + key_layer, + value_layer, + output_layer, + d_out, + stats_bhs1, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, + deterministic, + ) + _cudnn_score_mod_graph_cache[key] = entry + return entry + + class FusedAttentionWithScoreModFunc(torch.autograd.Function): """cuDNN frontend Python SDPA path with score_mod callback support.""" @@ -1370,52 +1765,47 @@ def forward( ) -> torch.Tensor: # pylint: disable=missing-function-docstring q_bhsd_dim, _ = _bhsd_dim_stride(query_layer, q_format) - - cudnn, graph = _build_cudnn_pygraph(query_layer.dtype, query_layer.device) - q = _bhsd_graph_tensor(graph, query_layer, q_format) - k = _bhsd_graph_tensor(graph, key_layer, kv_format) - v = _bhsd_graph_tensor(graph, value_layer, kv_format) - - score_mod_graph_tensors = _make_cudnn_graph_tensor_dict(graph, score_mod_tensors) - wrapped_score_mod = _wrap_score_mod(score_mod, score_mod_graph_tensors) - output_shape = (*query_layer.shape[:-1], value_layer.shape[-1]) output_layer = torch.empty(output_shape, device=query_layer.device, dtype=query_layer.dtype) - output_dim, output_stride = _bhsd_dim_stride(output_layer, q_format) - output, stats = graph.sdpa( - name="te_score_mod_sdpa", - q=q, - k=k, - v=v, - generate_stats=is_training, - attn_scale=attn_scale, - use_causal_mask=False, - score_mod=wrapped_score_mod, - ) - output.set_output(True).set_dim(output_dim).set_stride(output_stride) - - variant_pack = { - q: query_layer, - k: key_layer, - v: value_layer, - output: output_layer, - } if is_training: stats_bhs1 = torch.empty( (*q_bhsd_dim[:-1], 1), device=query_layer.device, dtype=torch.float32, ) - stats.set_output(True).set_dim(stats_bhs1.size()).set_stride( - stats_bhs1.stride() - ).set_data_type(cudnn.data_type.FLOAT) - variant_pack[stats] = stats_bhs1 else: stats_bhs1 = None - for name, graph_tensor in score_mod_graph_tensors.items(): + + entry = _get_cudnn_score_mod_fwd_graph( + is_training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + output_layer, + stats_bhs1, + ) + variant_pack = { + entry.q: query_layer, + entry.k: key_layer, + entry.v: value_layer, + entry.output: output_layer, + } + if is_training: + variant_pack[entry.stats] = stats_bhs1 + for name, graph_tensor in entry.score_mod_graph_tensors.items(): variant_pack[graph_tensor] = score_mod_tensors[name] - _build_and_run_cudnn_graph(graph, variant_pack, query_layer.device) + _execute_cudnn_graph( + entry.graph, + variant_pack, + entry.workspace_size, + query_layer.device, + ) ctx.is_training = is_training ctx.q_format = q_format @@ -1444,66 +1834,47 @@ def backward(ctx, d_out: torch.Tensor): query_layer, key_layer, value_layer, output_layer, stats_bhs1 = ctx.saved_tensors d_out = d_out.contiguous() - cudnn, graph = _build_cudnn_pygraph(query_layer.dtype, query_layer.device) - q = _bhsd_graph_tensor(graph, query_layer, ctx.q_format) - k = _bhsd_graph_tensor(graph, key_layer, ctx.kv_format) - v = _bhsd_graph_tensor(graph, value_layer, ctx.kv_format) - output = _bhsd_graph_tensor(graph, output_layer, ctx.q_format) - d_output = _bhsd_graph_tensor(graph, d_out, ctx.q_format) - stats = graph.tensor_like(stats_bhs1) - - score_mod_graph_tensors = _make_cudnn_graph_tensor_dict(graph, ctx.score_mod_tensors) - score_mod_bprop_graph_tensors = ( - _make_cudnn_graph_tensor_dict(graph, ctx.score_mod_bprop_tensors) - if ctx.score_mod_bprop is not None - else {} - ) - wrapped_score_mod = _wrap_score_mod(ctx.score_mod, score_mod_graph_tensors) - wrapped_score_mod_bprop = _wrap_score_mod( - ctx.score_mod_bprop, score_mod_bprop_graph_tensors - ) - dq_layer = torch.empty_like(query_layer) dk_layer = torch.empty_like(key_layer) dv_layer = torch.empty_like(value_layer) - dq_dim, dq_stride = _bhsd_dim_stride(dq_layer, ctx.q_format) - dk_dim, dk_stride = _bhsd_dim_stride(dk_layer, ctx.kv_format) - dv_dim, dv_stride = _bhsd_dim_stride(dv_layer, ctx.kv_format) - dq, dk, dv = graph.sdpa_backward( - name="te_score_mod_sdpa_backward", - q=q, - k=k, - v=v, - o=output, - dO=d_output, - stats=stats, - attn_scale=ctx.attn_scale, - use_causal_mask=False, - score_mod=wrapped_score_mod, - score_mod_bprop=wrapped_score_mod_bprop, - use_deterministic_algorithm=ctx.deterministic, + entry = _get_cudnn_score_mod_bwd_graph( + query_layer, + key_layer, + value_layer, + output_layer, + d_out, + stats_bhs1, + ctx.q_format, + ctx.kv_format, + ctx.attn_scale, + ctx.score_mod, + ctx.score_mod_bprop, + ctx.score_mod_tensors, + ctx.score_mod_bprop_tensors, + ctx.deterministic, ) - dq.set_output(True).set_dim(dq_dim).set_stride(dq_stride) - dk.set_output(True).set_dim(dk_dim).set_stride(dk_stride) - dv.set_output(True).set_dim(dv_dim).set_stride(dv_stride) - variant_pack = { - q: query_layer, - k: key_layer, - v: value_layer, - output: output_layer, - d_output: d_out, - stats: stats_bhs1, - dq: dq_layer, - dk: dk_layer, - dv: dv_layer, + entry.q: query_layer, + entry.k: key_layer, + entry.v: value_layer, + entry.output: output_layer, + entry.d_output: d_out, + entry.stats: stats_bhs1, + entry.dq: dq_layer, + entry.dk: dk_layer, + entry.dv: dv_layer, } - for name, graph_tensor in score_mod_graph_tensors.items(): + for name, graph_tensor in entry.score_mod_graph_tensors.items(): variant_pack[graph_tensor] = ctx.score_mod_tensors[name] - for name, graph_tensor in score_mod_bprop_graph_tensors.items(): + for name, graph_tensor in entry.score_mod_bprop_graph_tensors.items(): variant_pack[graph_tensor] = ctx.score_mod_bprop_tensors[name] - _build_and_run_cudnn_graph(graph, variant_pack, query_layer.device) + _execute_cudnn_graph( + entry.graph, + variant_pack, + entry.workspace_size, + query_layer.device, + ) return ( None, From ac4c60d03e6192ef4e5fd0f1cd1aebeea83c2791 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Tue, 12 May 2026 22:40:47 +0000 Subject: [PATCH 7/8] Fix score_mod cache edge cases Signed-off-by: Vladimir Cherepanov --- tests/pytorch/attention/test_attention.py | 45 +++++++++++++++++++ .../dot_product_attention/backends.py | 40 +++++++++++++---- .../dot_product_attention.py | 4 +- 3 files changed, 80 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index a8f1811dc0..c88634fd61 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1662,6 +1662,51 @@ def fake_build( assert len(build_entries) == 2 +def test_score_mod_tensors_are_version_checked_for_backward(monkeypatch): + """In-place score_mod tensor updates before backward should be rejected.""" + + class FakeEntry: + graph = object() + q = object() + k = object() + v = object() + output = object() + stats = object() + score_mod_graph_tensors = {"softcap": object()} + workspace_size = 1 + + def fake_execute(graph, variant_pack, workspace_size, device): + del graph, variant_pack, workspace_size, device + + q, k, v, _, _ = _score_mod_cache_cpu_inputs() + q = q.requires_grad_() + k = k.requires_grad_() + v = v.requires_grad_() + softcap = torch.tensor(0.8, dtype=torch.float32) + + monkeypatch.setattr(dpa_backends, "_get_cudnn_score_mod_fwd_graph", lambda *args: FakeEntry()) + monkeypatch.setattr(dpa_backends, "_execute_cudnn_graph", fake_execute) + + out = dpa_backends.FusedAttentionWithScoreModFunc.apply( + True, + q, + k, + v, + "bshd", + "bshd", + 1.0, + _score_mod_causal, + None, + {"softcap": softcap}, + None, + False, + ) + softcap.add_(1.0) + + with pytest.raises(RuntimeError, match="modified by an inplace operation"): + out.sum().backward() + + def _relative_position_bias(config, dtype): """Materialize score + (q_idx - kv_idx) as post-scale attention bias.""" q_idx = torch.arange(config.max_seqlen_q, dtype=torch.float32, device="cuda").view(1, 1, -1, 1) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 66a49bbc64..2bc8e596d3 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1765,6 +1765,8 @@ def forward( ) -> torch.Tensor: # pylint: disable=missing-function-docstring q_bhsd_dim, _ = _bhsd_dim_stride(query_layer, q_format) + score_mod_tensors = dict(score_mod_tensors or {}) + score_mod_bprop_tensors = dict(score_mod_bprop_tensors or {}) output_shape = (*query_layer.shape[:-1], value_layer.shape[-1]) output_layer = torch.empty(output_shape, device=query_layer.device, dtype=query_layer.dtype) if is_training: @@ -1813,11 +1815,21 @@ def forward( ctx.attn_scale = attn_scale ctx.score_mod = score_mod ctx.score_mod_bprop = score_mod_bprop - ctx.score_mod_tensors = dict(score_mod_tensors or {}) - ctx.score_mod_bprop_tensors = dict(score_mod_bprop_tensors or {}) + ctx.score_mod_tensor_names = tuple(score_mod_tensors.keys()) + ctx.score_mod_bprop_tensor_names = tuple(score_mod_bprop_tensors.keys()) ctx.deterministic = deterministic if is_training: - ctx.save_for_backward(query_layer, key_layer, value_layer, output_layer, stats_bhs1) + # save_for_backward records version counters without copying tensor data. + # This catches in-place score_mod tensor updates before backward. + ctx.save_for_backward( + query_layer, + key_layer, + value_layer, + output_layer, + stats_bhs1, + *score_mod_tensors.values(), + *score_mod_bprop_tensors.values(), + ) else: ctx.save_for_backward(query_layer, key_layer, value_layer, output_layer) @@ -1831,7 +1843,15 @@ def backward(ctx, d_out: torch.Tensor): "score_mod backward requires DotProductAttention to be in training mode." ) - query_layer, key_layer, value_layer, output_layer, stats_bhs1 = ctx.saved_tensors + saved_tensors = ctx.saved_tensors + query_layer, key_layer, value_layer, output_layer, stats_bhs1 = saved_tensors[:5] + score_mod_tensors_end = 5 + len(ctx.score_mod_tensor_names) + score_mod_tensors = dict( + zip(ctx.score_mod_tensor_names, saved_tensors[5:score_mod_tensors_end]) + ) + score_mod_bprop_tensors = dict( + zip(ctx.score_mod_bprop_tensor_names, saved_tensors[score_mod_tensors_end:]) + ) d_out = d_out.contiguous() dq_layer = torch.empty_like(query_layer) @@ -1849,8 +1869,8 @@ def backward(ctx, d_out: torch.Tensor): ctx.attn_scale, ctx.score_mod, ctx.score_mod_bprop, - ctx.score_mod_tensors, - ctx.score_mod_bprop_tensors, + score_mod_tensors, + score_mod_bprop_tensors, ctx.deterministic, ) variant_pack = { @@ -1865,9 +1885,9 @@ def backward(ctx, d_out: torch.Tensor): entry.dv: dv_layer, } for name, graph_tensor in entry.score_mod_graph_tensors.items(): - variant_pack[graph_tensor] = ctx.score_mod_tensors[name] + variant_pack[graph_tensor] = score_mod_tensors[name] for name, graph_tensor in entry.score_mod_bprop_graph_tensors.items(): - variant_pack[graph_tensor] = ctx.score_mod_bprop_tensors[name] + variant_pack[graph_tensor] = score_mod_bprop_tensors[name] _execute_cudnn_graph( entry.graph, @@ -2726,6 +2746,10 @@ def forward( assert ( not fp8 ), "score_mod is not supported with FP8 FusedAttention!" + assert not fp8_output, "score_mod is not supported with fp8_output!" + assert ( + not self.return_max_logit + ), "score_mod is not supported with return_max_logit!" assert ( type(query_layer) is torch.Tensor and type(key_layer) is torch.Tensor diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 7a7745d7c2..b887ed50a1 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1249,6 +1249,9 @@ def forward( seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64) + if score_mod is not None: + assert inference_params is None, "score_mod is not supported with KV caching!" + # update KV cache and retrieve saved tokens from cache for inference if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" @@ -1456,7 +1459,6 @@ def forward( assert ( not context_parallel ), "score_mod is not supported with context parallelism!" - assert inference_params is None, "score_mod is not supported with KV caching!" assert qkv_format != "thd", "score_mod is not supported with qkv_format='thd'!" assert ( not user_supplied_seqlens From 6446825aeb7870c7088a1c7724fa101b151de61e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 May 2026 03:19:08 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../dot_product_attention/backends.py | 12 ++--- .../dot_product_attention.py | 45 +++++++++---------- 2 files changed, 25 insertions(+), 32 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 2bc8e596d3..fbb55250ef 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -2740,16 +2740,10 @@ def forward( ) if score_mod is not None: - assert ( - not context_parallel - ), "score_mod is not supported with context parallelism!" - assert ( - not fp8 - ), "score_mod is not supported with FP8 FusedAttention!" + assert not context_parallel, "score_mod is not supported with context parallelism!" + assert not fp8, "score_mod is not supported with FP8 FusedAttention!" assert not fp8_output, "score_mod is not supported with fp8_output!" - assert ( - not self.return_max_logit - ), "score_mod is not supported with return_max_logit!" + assert not self.return_max_logit, "score_mod is not supported with return_max_logit!" assert ( type(query_layer) is torch.Tensor and type(key_layer) is torch.Tensor diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index b887ed50a1..95a0b53b4a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1435,15 +1435,18 @@ def forward( if score_mod is None: assert score_mod_bprop is None, "score_mod_bprop requires score_mod!" assert score_mod_tensors is None, "score_mod_tensors requires score_mod!" - assert score_mod_bprop_tensors is None, "score_mod_bprop_tensors requires score_mod!" + assert ( + score_mod_bprop_tensors is None + ), "score_mod_bprop_tensors requires score_mod!" else: assert callable(score_mod), "score_mod must be callable!" - assert ( - score_mod_bprop is None or callable(score_mod_bprop) + assert score_mod_bprop is None or callable( + score_mod_bprop ), "score_mod_bprop must be callable when provided!" - assert ( - query_layer.dtype in [torch.float16, torch.bfloat16] - ), "score_mod only supports FP16 and BF16 tensors!" + assert query_layer.dtype in [ + torch.float16, + torch.bfloat16, + ], "score_mod only supports FP16 and BF16 tensors!" assert ( key_layer.dtype == query_layer.dtype and value_layer.dtype == query_layer.dtype ), "score_mod requires Q, K and V tensors to have the same dtype!" @@ -1452,24 +1455,21 @@ def forward( and type(key_layer) is torch.Tensor and type(value_layer) is torch.Tensor ), "score_mod only supports unquantized torch.Tensor Q, K and V inputs!" - assert ( - not self.fp8 - ), "score_mod is not supported with FP8 DotProductAttention!" + assert not self.fp8, "score_mod is not supported with FP8 DotProductAttention!" assert not fp8_output, "score_mod is not supported with fp8_output!" - assert ( - not context_parallel - ), "score_mod is not supported with context parallelism!" + assert not context_parallel, "score_mod is not supported with context parallelism!" assert qkv_format != "thd", "score_mod is not supported with qkv_format='thd'!" assert ( not user_supplied_seqlens ), "score_mod is mutually exclusive with explicit sequence length metadata!" assert not pad_between_seqs, "score_mod is not supported with pad_between_seqs!" - assert attention_mask is None, "score_mod is mutually exclusive with attention_mask!" assert ( - attn_mask_type == "no_mask" - ), "score_mod requires attn_mask_type='no_mask'!" - assert ( - window_size is None or window_size == (-1, -1) + attention_mask is None + ), "score_mod is mutually exclusive with attention_mask!" + assert attn_mask_type == "no_mask", "score_mod requires attn_mask_type='no_mask'!" + assert window_size is None or window_size == ( + -1, + -1, ), "score_mod is mutually exclusive with sliding window attention!" assert ( core_attention_bias_type == "no_bias" and core_attention_bias is None @@ -1491,13 +1491,12 @@ def forward( not is_graph_capturing() ), "score_mod is not supported with CUDA graph capture!" assert num_splits == 1, "score_mod is not supported with num_splits != 1!" - assert ( - q_format in ["sbhd", "bshd"] and kv_format in ["sbhd", "bshd"] - ), "score_mod only supports SBHD/BSHD QKV formats!" + assert q_format in ["sbhd", "bshd"] and kv_format in [ + "sbhd", + "bshd", + ], "score_mod only supports SBHD/BSHD QKV formats!" if score_mod_tensors is not None: - assert isinstance( - score_mod_tensors, dict - ), "score_mod_tensors must be a dict!" + assert isinstance(score_mod_tensors, dict), "score_mod_tensors must be a dict!" assert all( isinstance(k, str) and isinstance(v, torch.Tensor) for k, v in score_mod_tensors.items()