diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index 32ea1694ee..c88634fd61 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -26,6 +26,7 @@
 from transformer_engine.pytorch.attention.dot_product_attention import (
     _attention_backends,
 )
+import transformer_engine.pytorch.attention.dot_product_attention.backends as dpa_backends
 from transformer_engine.pytorch.attention.dot_product_attention.utils import (
     FlashAttentionUtils,
     check_set_window_size,
@@ -1390,6 +1391,628 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
     return out, max_logit, (None, None, None, d_softmax_offset)
 
 
+def _score_mod_causal(score_mod_graph, score_tensor, tensors):
+    """cuDNN frontend score_mod implementing top-left causal masking."""
+    import cudnn  # pylint: disable=import-outside-toplevel
+
+    row_index = score_mod_graph.gen_index(input=score_tensor, axis=2)
+    row_index.set_data_type(cudnn.data_type.INT32)
+    col_index = score_mod_graph.gen_index(input=score_tensor, axis=3)
+    col_index.set_data_type(cudnn.data_type.INT32)
+    keep = score_mod_graph.cmp_ge(
+        input=row_index,
+        comparison=col_index,
+        compute_data_type=cudnn.data_type.BOOLEAN,
+    )
+    keep.set_data_type(cudnn.data_type.BOOLEAN)
+    return score_mod_graph.binary_select(
+        input0=score_tensor,
+        input1=tensors["neg_inf"],
+        mask=keep,
+    )
+
+
+def _score_mod_causal_bprop(score_mod_graph, dP_tensor, tensors):
+    """cuDNN frontend score_mod_bprop implementing top-left causal masking."""
+    import cudnn  # pylint: disable=import-outside-toplevel
+
+    row_index = score_mod_graph.gen_index(input=dP_tensor, axis=2)
+    row_index.set_data_type(cudnn.data_type.INT32)
+    col_index = score_mod_graph.gen_index(input=dP_tensor, axis=3)
+    col_index.set_data_type(cudnn.data_type.INT32)
+    keep = score_mod_graph.cmp_ge(
+        input=row_index,
+        comparison=col_index,
+        compute_data_type=cudnn.data_type.BOOLEAN,
+    )
+    keep.set_data_type(cudnn.data_type.BOOLEAN)
+    return score_mod_graph.binary_select(
+        input0=dP_tensor,
+        input1=tensors["zero"],
+        mask=keep,
+    )
+
+
+def _score_mod_relative_position(score_mod_graph, score_tensor, _tensors):
+    """cuDNN frontend score_mod adding relative position bias."""
+    import cudnn  # pylint: disable=import-outside-toplevel
+
+    row_index = score_mod_graph.gen_index(input=score_tensor, axis=2)
+    row_index.set_data_type(cudnn.data_type.INT32)
+    col_index = score_mod_graph.gen_index(input=score_tensor, axis=3)
+    col_index.set_data_type(cudnn.data_type.INT32)
+    relative_position = score_mod_graph.sub(
+        a=row_index,
+        b=col_index,
+        compute_data_type=cudnn.data_type.FLOAT,
+    )
+    relative_position.set_data_type(cudnn.data_type.FLOAT)
+    return score_mod_graph.add(
+        a=score_tensor,
+        b=relative_position,
+        compute_data_type=cudnn.data_type.FLOAT,
+    )
+
+
+def _score_mod_identity_bprop(_score_mod_graph, dP_tensor, _tensors):
+    """cuDNN frontend score_mod_bprop for score_mods with unit score derivative."""
+    return dP_tensor
+
+
+class _ScoreModSoftcap:
+    """cuDNN frontend score_mod implementing softcapping."""
+
+    def __init__(self):
+        self.before_tanh_activation = None
+
+    def forward(self, score_mod_graph, score_tensor, tensors):
+        """Apply softcap * tanh(score / softcap)."""
+        import cudnn  # pylint: disable=import-outside-toplevel
+
+        self.before_tanh_activation = score_mod_graph.div(
+            a=score_tensor,
+            b=tensors["softcap"],
+            compute_data_type=cudnn.data_type.FLOAT,
+        )
+        self.before_tanh_activation.set_data_type(cudnn.data_type.FLOAT)
+        tanh_out = score_mod_graph.tanh(input=self.before_tanh_activation)
+        tanh_out.set_data_type(cudnn.data_type.FLOAT)
+        return score_mod_graph.mul(
+            a=tanh_out,
+            b=tensors["softcap"],
+            compute_data_type=cudnn.data_type.FLOAT,
+        )
+
+    def backward(self, score_mod_graph, dP_tensor, tensors):
+        """Apply softcap derivative to dP."""
+        import cudnn  # pylint: disable=import-outside-toplevel
+
+        d_tanh_out = score_mod_graph.mul(
+            a=dP_tensor,
+            b=tensors["softcap"],
+            compute_data_type=cudnn.data_type.FLOAT,
+        )
+        d_tanh_out.set_data_type(cudnn.data_type.FLOAT)
+        d_before_tanh_activation = score_mod_graph.tanh_backward(
+            loss=d_tanh_out,
+            input=self.before_tanh_activation,
+            compute_data_type=cudnn.data_type.FLOAT,
+        )
+        d_before_tanh_activation.set_data_type(cudnn.data_type.FLOAT)
+        return score_mod_graph.div(
+            a=d_before_tanh_activation,
+            b=tensors["softcap"],
+            compute_data_type=cudnn.data_type.FLOAT,
+        )
+
+
+def _score_mod_cache_cpu_inputs():
+    """Small CPU tensors for score_mod cache-key tests."""
+    q = torch.empty((2, 4, 3, 8), dtype=torch.float16)
+    k = torch.empty((2, 4, 3, 8), dtype=torch.float16)
+    v = torch.empty((2, 4, 3, 8), dtype=torch.float16)
+    o = torch.empty((2, 4, 3, 8), dtype=torch.float16)
+    stats = torch.empty((2, 3, 4, 1), dtype=torch.float32)
+    return q, k, v, o, stats
+
+
+def test_score_mod_cache_bound_method_key_stable():
+    """Bound method keys should be stable across repeated attribute access."""
+    softcap = _ScoreModSoftcap()
+    key_0 = dpa_backends._score_mod_callback_cache_key(softcap.forward)
+    key_1 = dpa_backends._score_mod_callback_cache_key(softcap.forward)
+    other_key = dpa_backends._score_mod_callback_cache_key(_ScoreModSoftcap().forward)
+
+    assert key_0 == key_1
+    assert key_0 != other_key
+
+
+def test_score_mod_cache_key_ignores_pass_by_value_values():
+    """Scalar CPU tensor values are runtime inputs, not execution-plan metadata."""
+    q, k, v, o, stats = _score_mod_cache_cpu_inputs()
+    key_0 = dpa_backends._cudnn_score_mod_fwd_cache_key(
+        True,
+        q,
+        k,
+        v,
+        o,
+        stats,
+        "bshd",
+        "bshd",
+        1.0,
+        _score_mod_causal,
+        {"softcap": torch.tensor(0.8, dtype=torch.float32)},
+    )
+    key_1 = dpa_backends._cudnn_score_mod_fwd_cache_key(
+        True,
+        q,
+        k,
+        v,
+        o,
+        stats,
+        "bshd",
+        "bshd",
+        1.0,
+        _score_mod_causal,
+        {"softcap": torch.tensor(1.2, dtype=torch.float32)},
+    )
+    key_2 = dpa_backends._cudnn_score_mod_fwd_cache_key(
+        True,
+        q,
+        k,
+        v,
+        o,
+        stats,
+        "bshd",
+        "bshd",
+        1.0,
+        _score_mod_causal,
+        {"softcap": torch.tensor([0.8], dtype=torch.float32)},
+    )
+
+    assert key_0 == key_1
+    assert key_0 != key_2
+
+
+def test_score_mod_cache_fwd_reuses_graph_for_pass_by_value_changes(monkeypatch):
+    """Fprop graph cache should reuse entries when only scalar CPU tensor values change."""
+    q, k, v, o, stats = _score_mod_cache_cpu_inputs()
+    cache = dpa_backends._cudnn_score_mod_graph_cache
+    saved_cache = dict(cache)
+    build_entries = []
+
+    def fake_build(
+        is_training,
+        query_layer,
+        key_layer,
+        value_layer,
+        q_format,
+        kv_format,
+        attn_scale,
+        score_mod,
+        score_mod_tensors,
+        output_layer,
+        stats_bhs1,
+    ):
+        del (
+            is_training,
+            query_layer,
+            key_layer,
+            value_layer,
+            q_format,
+            kv_format,
+            attn_scale,
+            score_mod,
+            score_mod_tensors,
+            output_layer,
+            stats_bhs1,
+        )
+        entry = object()
+        build_entries.append(entry)
+        return entry
+
+    monkeypatch.setattr(dpa_backends, "_build_cudnn_score_mod_fwd_graph", fake_build)
+    try:
+        cache.clear()
+        entry_0 = dpa_backends._get_cudnn_score_mod_fwd_graph(
+            True,
+            q,
+            k,
+            v,
+            "bshd",
+            "bshd",
+            1.0,
+            _score_mod_causal,
+            {"softcap": torch.tensor(0.8, dtype=torch.float32)},
+            o,
+            stats,
+        )
+        entry_1 = dpa_backends._get_cudnn_score_mod_fwd_graph(
+            True,
+            q,
+            k,
+            v,
+            "bshd",
+            "bshd",
+            1.0,
+            _score_mod_causal,
+            {"softcap": torch.tensor(1.2, dtype=torch.float32)},
+            o,
+            stats,
+        )
+        entry_2 = dpa_backends._get_cudnn_score_mod_fwd_graph(
+            True,
+            q,
+            k,
+            v,
+            "bshd",
+            "bshd",
+            1.0,
+            _score_mod_causal,
+            {"softcap": torch.tensor([0.8], dtype=torch.float32)},
+            o,
+            stats,
+        )
+    finally:
+        cache.clear()
+        cache.update(saved_cache)
+
+    assert entry_0 is entry_1
+    assert entry_2 is not entry_0
+    assert len(build_entries) == 2
+
+
+def test_score_mod_tensors_are_version_checked_for_backward(monkeypatch):
+    """In-place score_mod tensor updates before backward should be rejected."""
+
+    class FakeEntry:
+        graph = object()
+        q = object()
+        k = object()
+        v = object()
+        output = object()
+        stats = object()
+        score_mod_graph_tensors = {"softcap": object()}
+        workspace_size = 1
+
+    def fake_execute(graph, variant_pack, workspace_size, device):
+        del graph, variant_pack, workspace_size, device
+
+    q, k, v, _, _ = _score_mod_cache_cpu_inputs()
+    q = q.requires_grad_()
+    k = k.requires_grad_()
+    v = v.requires_grad_()
+    softcap = torch.tensor(0.8, dtype=torch.float32)
+
+    monkeypatch.setattr(dpa_backends, "_get_cudnn_score_mod_fwd_graph", lambda *args: FakeEntry())
+    monkeypatch.setattr(dpa_backends, "_execute_cudnn_graph", fake_execute)
+
+    out = dpa_backends.FusedAttentionWithScoreModFunc.apply(
+        True,
+        q,
+        k,
+        v,
+        "bshd",
+        "bshd",
+        1.0,
+        _score_mod_causal,
+        None,
+        {"softcap": softcap},
+        None,
+        False,
+    )
+    softcap.add_(1.0)
+
+    with pytest.raises(RuntimeError, match="modified by an inplace operation"):
+        out.sum().backward()
+
+
+def _relative_position_bias(config, dtype):
+    """Materialize score + (q_idx - kv_idx) as post-scale attention bias."""
+    q_idx = torch.arange(config.max_seqlen_q, dtype=torch.float32, device="cuda").view(1, 1, -1, 1)
+    kv_idx = torch.arange(config.max_seqlen_kv, dtype=torch.float32, device="cuda").view(
+        1, 1, 1, -1
+    )
+    return (q_idx - kv_idx).to(dtype).expand(1, config.num_heads, -1, -1).contiguous()
+
+
+def _to_bhsd(tensor, qkv_format):
+    """Convert SBHD/BSHD test tensors to logical BHSD."""
+    if qkv_format == "sbhd":
+        return tensor.permute(1, 2, 0, 3)
+    return tensor.permute(0, 2, 1, 3)
+
+
+def _from_bhsd(tensor, qkv_format):
+    """Convert logical BHSD test tensors to SBHD/BSHD."""
+    if qkv_format == "sbhd":
+        return tensor.permute(2, 0, 1, 3).contiguous()
+    return tensor.permute(0, 2, 1, 3).contiguous()
+
+
+def _pytorch_softcap_attention(q, k, v, qkv_format, softmax_scale, softcap):
+    """PyTorch reference for softcapped scaled dot-product attention."""
+    q_bhsd = _to_bhsd(q, qkv_format).float()
+    k_bhsd = _to_bhsd(k, qkv_format).float()
+    v_bhsd = _to_bhsd(v, qkv_format).float()
+    scores = torch.matmul(q_bhsd, k_bhsd.transpose(-2, -1)) * softmax_scale
+    scores = softcap * torch.tanh(scores / softcap)
+    probs = torch.softmax(scores, dim=-1)
+    out = _from_bhsd(torch.matmul(probs, v_bhsd), qkv_format).to(v.dtype)
+    return out.reshape(*out.shape[:-2], out.shape[-2] * out.shape[-1])
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.")
+@pytest.mark.skipif(get_cudnn_version() < (9, 6, 0), reason="cuDNN 9.6.0+ is required.")
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd"])
+@pytest.mark.parametrize("scalar_loss", [False, True])
+def test_dot_product_attention_score_mod(dtype, qkv_format, scalar_loss):
+    """Compare score_mod causal masking against standard cuDNN causal attention."""
+    try:
+        import cudnn  # pylint: disable=unused-import,import-outside-toplevel
+    except ImportError:
+        pytest.skip("cuDNN Python frontend is required for score_mod attention.")
+
+    reset_rng_states()
+    os.environ["NVTE_FLASH_ATTN"] = "0"
+    os.environ["NVTE_FUSED_ATTN"] = "1"
+    os.environ["NVTE_UNFUSED_ATTN"] = "0"
+    _attention_backends["backend_selection_requires_update"] = True
+
+    config = ModelConfig(2, 64, 4, 64, attn_mask_type="no_mask")
+    available_backends, _, fused_attn_backends = get_available_attention_backends(
+        config,
+        qkv_dtype=dtype,
+        qkv_layout=f"{qkv_format}_{qkv_format}_{qkv_format}",
+    )
+    if not available_backends[1] or not fused_attn_backends:
+        pytest.skip("FusedAttention is not available for this score_mod configuration.")
+
+    if qkv_format == "sbhd":
+        q_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim_qk)
+        kv_shape = q_shape
+    else:
+        q_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim_qk)
+        kv_shape = q_shape
+
+    q = (0.1 * torch.randn(q_shape, dtype=dtype, device="cuda")).requires_grad_()
+    k = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_()
+    v = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_()
+    q_ref, k_ref, v_ref = [x.detach().clone().requires_grad_() for x in (q, k, v)]
+
+    flex_attn = DotProductAttention(
+        config.num_heads,
+        config.head_dim_qk,
+        qkv_format=qkv_format,
+        attn_mask_type="no_mask",
+        layer_number=1,
+    ).to(dtype=dtype, device="cuda")
+    ref_attn = DotProductAttention(
+        config.num_heads,
+        config.head_dim_qk,
+        qkv_format=qkv_format,
+        attn_mask_type="causal",
+        layer_number=1,
+    ).to(dtype=dtype, device="cuda")
+
+    out = flex_attn(
+        q,
+        k,
+        v,
+        qkv_format=qkv_format,
+        attn_mask_type="no_mask",
+        score_mod=_score_mod_causal,
+        score_mod_bprop=_score_mod_causal_bprop,
+        score_mod_tensors={"neg_inf": torch.full((1, 1, 1, 1), -1e9)},
+        score_mod_bprop_tensors={"zero": torch.full((1, 1, 1, 1), 0.0)},
+    )
+    out_ref = ref_attn(
+        q_ref,
+        k_ref,
+        v_ref,
+        qkv_format=qkv_format,
+        attn_mask_type="causal",
+    )
+
+    if scalar_loss:
+        out.sum().backward()
+        out_ref.sum().backward()
+    else:
+        d_out = torch.randn_like(out)
+        out.backward(d_out)
+        out_ref.backward(d_out)
+
+    tols = dict(atol=5e-2, rtol=5e-2)
+    torch.testing.assert_close(out, out_ref, **tols)
+    torch.testing.assert_close(q.grad, q_ref.grad, **tols)
+    torch.testing.assert_close(k.grad, k_ref.grad, **tols)
+    torch.testing.assert_close(v.grad, v_ref.grad, **tols)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.")
+@pytest.mark.skipif(get_cudnn_version() < (9, 6, 0), reason="cuDNN 9.6.0+ is required.")
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd"])
+def test_dot_product_attention_score_mod_softcap(dtype, qkv_format):
+    """Compare softcap score_mod against PyTorch math attention."""
+    try:
+        import cudnn  # pylint: disable=unused-import,import-outside-toplevel
+    except ImportError:
+        pytest.skip("cuDNN Python frontend is required for score_mod attention.")
+
+    reset_rng_states()
+    os.environ["NVTE_FLASH_ATTN"] = "0"
+    os.environ["NVTE_FUSED_ATTN"] = "1"
+    os.environ["NVTE_UNFUSED_ATTN"] = "0"
+    _attention_backends["backend_selection_requires_update"] = True
+
+    config = ModelConfig(2, 16, 4, 64, attn_mask_type="no_mask")
+    available_backends, _, fused_attn_backends = get_available_attention_backends(
+        config,
+        qkv_dtype=dtype,
+        qkv_layout=f"{qkv_format}_{qkv_format}_{qkv_format}",
+    )
+    if not available_backends[1] or not fused_attn_backends:
+        pytest.skip("FusedAttention is not available for this softcap configuration.")
+
+    if qkv_format == "sbhd":
+        q_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim_qk)
+        kv_shape = q_shape
+    else:
+        q_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim_qk)
+        kv_shape = q_shape
+
+    q = torch.randn(q_shape, dtype=dtype, device="cuda").requires_grad_()
+    k = torch.randn(kv_shape, dtype=dtype, device="cuda").requires_grad_()
+    v = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_()
+    q_ref, k_ref, v_ref = [x.detach().clone().requires_grad_() for x in (q, k, v)]
+
+    softcap = 0.8
+    softcap_tensor = torch.full((1, 1, 1, 1), softcap)
+    softcap_score_mod = _ScoreModSoftcap()
+    flex_attn = DotProductAttention(
+        config.num_heads,
+        config.head_dim_qk,
+        qkv_format=qkv_format,
+        attn_mask_type="no_mask",
+        layer_number=1,
+    ).to(dtype=dtype, device="cuda")
+
+    out = flex_attn(
+        q,
+        k,
+        v,
+        qkv_format=qkv_format,
+        attn_mask_type="no_mask",
+        score_mod=softcap_score_mod.forward,
+        score_mod_bprop=softcap_score_mod.backward,
+        score_mod_tensors={"softcap": softcap_tensor},
+        score_mod_bprop_tensors={"softcap": softcap_tensor},
+    )
+    out_ref = _pytorch_softcap_attention(
+        q_ref,
+        k_ref,
+        v_ref,
+        qkv_format,
+        1.0 / config.head_dim_qk**0.5,
+        softcap,
+    )
+
+    d_out = torch.randn_like(out)
+    out.backward(d_out)
+    out_ref.backward(d_out)
+
+    tols = dict(atol=7e-2, rtol=7e-2)
+    torch.testing.assert_close(out, out_ref, **tols)
+    torch.testing.assert_close(q.grad, q_ref.grad, **tols)
+    torch.testing.assert_close(k.grad, k_ref.grad, **tols)
+    torch.testing.assert_close(v.grad, v_ref.grad, **tols)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.")
+@pytest.mark.skipif(get_cudnn_version() < (9, 6, 0), reason="cuDNN 9.6.0+ is required.")
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd"])
+def test_dot_product_attention_score_mod_relative_position(dtype, qkv_format):
+    """Compare relative-position score_mod against materialized post-scale bias."""
+    try:
+        import cudnn  # pylint: disable=unused-import,import-outside-toplevel
+    except ImportError:
+        pytest.skip("cuDNN Python frontend is required for score_mod attention.")
+
+    reset_rng_states()
+
+    config = ModelConfig(2, 16, 4, 64, attn_mask_type="no_mask")
+    bias_config = ModelConfig(
+        config.batch_size,
+        config.max_seqlen_q,
+        config.num_heads,
+        config.head_dim_qk,
+        attn_mask_type="no_mask",
+        attn_bias_type="post_scale_bias",
+        bias_shape="1hss",
+    )
+    available_backends, _, fused_attn_backends = get_available_attention_backends(
+        config,
+        qkv_dtype=dtype,
+        qkv_layout=f"{qkv_format}_{qkv_format}_{qkv_format}",
+    )
+    bias_available_backends, _, bias_fused_attn_backends = get_available_attention_backends(
+        bias_config,
+        qkv_dtype=dtype,
+        qkv_layout=f"{qkv_format}_{qkv_format}_{qkv_format}",
+    )
+    if (
+        not available_backends[1]
+        or not fused_attn_backends
+        or not bias_available_backends[1]
+        or not bias_fused_attn_backends
+    ):
+        pytest.skip("FusedAttention is not available for this relative-position configuration.")
+
+    os.environ["NVTE_FLASH_ATTN"] = "0"
+    os.environ["NVTE_FUSED_ATTN"] = "1"
+    os.environ["NVTE_UNFUSED_ATTN"] = "0"
+    _attention_backends["backend_selection_requires_update"] = True
+
+    if qkv_format == "sbhd":
+        q_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim_qk)
+        kv_shape = q_shape
+    else:
+        q_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim_qk)
+        kv_shape = q_shape
+
+    q = (0.1 * torch.randn(q_shape, dtype=dtype, device="cuda")).requires_grad_()
+    k = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_()
+    v = (0.1 * torch.randn(kv_shape, dtype=dtype, device="cuda")).requires_grad_()
+    q_ref, k_ref, v_ref = [x.detach().clone().requires_grad_() for x in (q, k, v)]
+
+    flex_attn = DotProductAttention(
+        config.num_heads,
+        config.head_dim_qk,
+        qkv_format=qkv_format,
+        attn_mask_type="no_mask",
+        layer_number=1,
+    ).to(dtype=dtype, device="cuda")
+    ref_attn = DotProductAttention(
+        config.num_heads,
+        config.head_dim_qk,
+        qkv_format=qkv_format,
+        attn_mask_type="no_mask",
+        layer_number=1,
+    ).to(dtype=dtype, device="cuda")
+
+    out = flex_attn(
+        q,
+        k,
+        v,
+        qkv_format=qkv_format,
+        attn_mask_type="no_mask",
+        score_mod=_score_mod_relative_position,
+        score_mod_bprop=_score_mod_identity_bprop,
+    )
+    out_ref = ref_attn(
+        q_ref,
+        k_ref,
+        v_ref,
+        qkv_format=qkv_format,
+        attn_mask_type="no_mask",
+        core_attention_bias_type="post_scale_bias",
+        core_attention_bias=_relative_position_bias(config, dtype),
+    )
+
+    d_out = torch.randn_like(out)
+    out.backward(d_out)
+    out_ref.backward(d_out)
+
+    tols = dict(atol=5e-2, rtol=5e-2)
+    torch.testing.assert_close(out, out_ref, **tols)
+    torch.testing.assert_close(q.grad, q_ref.grad, **tols)
+    torch.testing.assert_close(k.grad, k_ref.grad, **tols)
+    torch.testing.assert_close(v.grad, v_ref.grad, **tols)
+
+
 model_configs_te_layer = {  # test: ModelConfig(b, sq, hq, dqk)
     "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"),
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
index 79ebbd4afa..fbb55250ef 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -4,6 +4,7 @@
 """Attention Backends."""
 from contextlib import nullcontext
+from dataclasses import dataclass
 from importlib.metadata import version as get_pkg_version
 from importlib.metadata import PackageNotFoundError
 import os
@@ -89,6 +90,8 @@
 _flash_attn_bwd = None
 _flash_attn_varlen_fwd = None
 _flash_attn_varlen_bwd = None
+_cudnn_score_mod_handles: Dict[torch.device, Any] = {}
+_cudnn_score_mod_graph_cache: Dict[Tuple[Any, ...], Any] = {}
 
 # Try to import Flash Attention v2
 try:
@@ -1244,6 +1247,671 @@
     return output.contiguous()
 
 
+def _bhsd_dim_stride(
+    tensor: torch.Tensor, tensor_format: str
+) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
+    """Describe an SBHD/BSHD tensor as cuDNN frontend's logical BHSD format."""
+    if tensor_format == "sbhd":
+        return (
+            (tensor.shape[1], tensor.shape[2], tensor.shape[0], tensor.shape[3]),
+            (tensor.stride(1), tensor.stride(2), tensor.stride(0), tensor.stride(3)),
+        )
+    if tensor_format == "bshd":
+        return (
+            (tensor.shape[0], tensor.shape[2], tensor.shape[1], tensor.shape[3]),
+            (tensor.stride(0), tensor.stride(2), tensor.stride(1), tensor.stride(3)),
+        )
+    raise ValueError(f"score_mod only supports SBHD/BSHD tensor formats, got {tensor_format}.")
+
+
+def _bhsd_graph_tensor(graph, tensor: torch.Tensor, tensor_format: str):
+    """Create a cuDNN graph tensor with BHSD dims and TE-layout strides."""
+    dim, stride = _bhsd_dim_stride(tensor, tensor_format)
+    return graph.tensor(dim=dim, stride=stride, data_type=tensor.dtype)
+
+
+def _score_mod_callback_cache_key(callback: Optional[Callable]) -> Optional[Tuple[Any, ...]]:
+    """Create a stable cache key for a score_mod callable."""
+    if callback is None:
+        return None
+    self_obj = getattr(callback, "__self__", None)
+    func_obj = getattr(callback, "__func__", None)
+    if self_obj is not None and func_obj is not None:
+        return ("bound_method", id(self_obj), id(func_obj))
+    return ("callable", id(callback))
+
+
+def _score_mod_device_key(device: torch.device) -> Tuple[Any, ...]:
+    """Normalize a tensor device for graph cache keys."""
+    if device.type == "cuda":
+        index = device.index
+        if index is None:
+            index = torch.cuda.current_device()
+        return (device.type, index)
+    return (device.type, device.index)
+
+
+def _score_mod_tensor_metadata(tensor: torch.Tensor) -> Tuple[Any, ...]:
+    """Describe tensor metadata that can affect cuDNN graph construction."""
+    return (
+        tuple(tensor.size()),
+        tuple(tensor.stride()),
+        tensor.dtype,
+        _score_mod_device_key(tensor.device),
+    )
+
+
+def _score_mod_tensor_dict_metadata(
+    tensors: Optional[Dict[str, torch.Tensor]],
+) -> Tuple[Tuple[str, Tuple[Any, ...]], ...]:
+    """Describe score_mod tensor parameters without including their values."""
+    if tensors is None:
+        return ()
+    return tuple((name, _score_mod_tensor_metadata(tensor)) for name, tensor in tensors.items())
+
+
+def _score_mod_bhsd_tensor_metadata(tensor: torch.Tensor, tensor_format: str) -> Tuple[Any, ...]:
+    """Describe an SBHD/BSHD runtime tensor as a cuDNN BHSD graph tensor."""
+    dim, stride = _bhsd_dim_stride(tensor, tensor_format)
+    return (dim, stride, tensor.dtype, _score_mod_device_key(tensor.device))
+
+
+def _make_cudnn_graph_tensor_dict(graph, tensors: Optional[Dict[str, torch.Tensor]]):
+    """Create cuDNN graph tensors matching runtime tensors."""
+    if tensors is None:
+        return {}
+    return {name: graph.tensor_like(tensor) for name, tensor in tensors.items()}
+
+
+def _wrap_score_mod(score_mod: Optional[Callable], graph_tensors: Dict[str, Any]):
+    """Adapt TE's score_mod signature to cuDNN frontend's two-argument callback."""
+    if score_mod is None:
+        return None
+
+    def _wrapped_score_mod(sdpa_graph, score_tensor):
+        return score_mod(sdpa_graph, score_tensor, graph_tensors)
+
+    return _wrapped_score_mod
+
+
+def _get_cudnn_current_stream_handle(cudnn, device: torch.device):
+    """Return a cuDNN handle for device, bound to PyTorch's current stream."""
+    if device.type != "cuda":
+        raise ValueError(f"score_mod only supports CUDA tensors, got device {device}.")
+    if device.index is None:
+        device = torch.device("cuda", torch.cuda.current_device())
+
+    handle = _cudnn_score_mod_handles.get(device)
+    with torch.cuda.device(device):
+        if handle is None:
+            handle = cudnn.create_handle()
+            _cudnn_score_mod_handles[device] = handle
+
+        stream = torch.cuda.current_stream(device).cuda_stream
+        cudnn.set_stream(handle=handle, stream=stream)
+    return handle
+
+
+def _build_cudnn_pygraph(dtype: torch.dtype, device: torch.device):
+    """Create a cuDNN frontend Python graph for F16/BF16 SDPA."""
+    import cudnn  # pylint: disable=import-outside-toplevel
+
+    if dtype == torch.float16:
+        io_data_type = cudnn.data_type.HALF
+    elif dtype == torch.bfloat16:
+        io_data_type = cudnn.data_type.BFLOAT16
+    else:
+        raise ValueError(f"score_mod only supports FP16/BF16 tensors, got {dtype}.")
+
+    graph = cudnn.pygraph(
+        io_data_type=io_data_type,
+        intermediate_data_type=cudnn.data_type.FLOAT,
+        compute_data_type=cudnn.data_type.FLOAT,
+        handle=_get_cudnn_current_stream_handle(cudnn, device),
+    )
+    return cudnn, graph
+
+
+@dataclass
+class _CudnnScoreModFwdGraphEntry:
+    """Cached cuDNN frontend graph and graph tensor handles for score_mod fprop."""
+
+    graph: Any
+    q: Any
+    k: Any
+    v: Any
+    output: Any
+    stats: Optional[Any]
+    score_mod_graph_tensors: Dict[str, Any]
+    workspace_size: int
+
+
+@dataclass
+class _CudnnScoreModBwdGraphEntry:
+    """Cached cuDNN frontend graph and graph tensor handles for score_mod bprop."""
+
+    graph: Any
+    q: Any
+    k: Any
+    v: Any
+    output: Any
+    d_output: Any
+    stats: Any
+    dq: Any
+    dk: Any
+    dv: Any
+    score_mod_graph_tensors: Dict[str, Any]
+    score_mod_bprop_graph_tensors: Dict[str, Any]
+    workspace_size: int
+
+
+def _finalize_cudnn_graph(graph) -> int:
+    """Build a cuDNN frontend Python graph and return its workspace size."""
+    import cudnn  # pylint: disable=import-outside-toplevel
+
+    graph.validate()
+    graph.build_operation_graph()
+    try:
+        graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
+        graph.check_support()
+    except cudnn.cudnnGraphNotSupportedError as exc:
+        raise RuntimeError(f"cuDNN score_mod SDPA graph is not supported: {exc}") from exc
+    graph.build_plans(cudnn.build_plan_policy.HEURISTICS_CHOICE)
+    return max(graph.get_workspace_size(), 1)
+
+
+def _execute_cudnn_graph(
+    graph,
+    variant_pack: Dict[Any, torch.Tensor],
+    workspace_size: int,
+    device: torch.device,
+):
+    """Execute a built cuDNN frontend Python graph."""
+    import cudnn  # pylint: disable=import-outside-toplevel
+
+    if device.type == "cuda" and device.index is None:
+        device = torch.device("cuda", torch.cuda.current_device())
+    workspace = torch.empty(
+        workspace_size,
+        device=device,
+        dtype=torch.uint8,
+    )
+    graph.execute(
+        variant_pack,
+        workspace,
+        handle=_get_cudnn_current_stream_handle(cudnn, device),
+    )
+
+
+def _cudnn_score_mod_fwd_cache_key(
+    is_training: bool,
+    query_layer: torch.Tensor,
+    key_layer: torch.Tensor,
+    value_layer: torch.Tensor,
+    output_layer: torch.Tensor,
+    stats_bhs1: Optional[torch.Tensor],
+    q_format: str,
+    kv_format: str,
+    attn_scale: float,
+    score_mod: Callable,
+    score_mod_tensors: Optional[Dict[str, torch.Tensor]],
+) -> Tuple[Any, ...]:
+    """Cache key for score_mod fprop execution plans."""
+    return (
+        "fwd",
+        is_training,
+        q_format,
+        kv_format,
+        attn_scale,
+        _score_mod_callback_cache_key(score_mod),
+        _score_mod_bhsd_tensor_metadata(query_layer, q_format),
+        _score_mod_bhsd_tensor_metadata(key_layer, kv_format),
+        _score_mod_bhsd_tensor_metadata(value_layer, kv_format),
+        _score_mod_bhsd_tensor_metadata(output_layer, q_format),
+        _score_mod_tensor_metadata(stats_bhs1) if stats_bhs1 is not None else None,
+        _score_mod_tensor_dict_metadata(score_mod_tensors),
+    )
+
+
+def _cudnn_score_mod_bwd_cache_key( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + output_layer: torch.Tensor, + d_out: torch.Tensor, + stats_bhs1: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_bprop: Optional[Callable], + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]], + deterministic: bool, +) -> Tuple[Any, ...]: + """Cache key for score_mod bprop execution plans.""" + return ( + "bwd", + q_format, + kv_format, + attn_scale, + deterministic, + _score_mod_callback_cache_key(score_mod), + _score_mod_callback_cache_key(score_mod_bprop), + _score_mod_bhsd_tensor_metadata(query_layer, q_format), + _score_mod_bhsd_tensor_metadata(key_layer, kv_format), + _score_mod_bhsd_tensor_metadata(value_layer, kv_format), + _score_mod_bhsd_tensor_metadata(output_layer, q_format), + _score_mod_bhsd_tensor_metadata(d_out, q_format), + _score_mod_tensor_metadata(stats_bhs1), + _score_mod_tensor_dict_metadata(score_mod_tensors), + _score_mod_tensor_dict_metadata(score_mod_bprop_tensors), + ) + + +def _build_cudnn_score_mod_fwd_graph( + is_training: bool, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + output_layer: torch.Tensor, + stats_bhs1: Optional[torch.Tensor], +) -> _CudnnScoreModFwdGraphEntry: + """Build a cached cuDNN frontend graph for score_mod fprop.""" + import cudnn # pylint: disable=import-outside-toplevel + + _, graph = _build_cudnn_pygraph(query_layer.dtype, query_layer.device) + q = _bhsd_graph_tensor(graph, query_layer, q_format) + k = _bhsd_graph_tensor(graph, key_layer, kv_format) + v = _bhsd_graph_tensor(graph, value_layer, kv_format) + + score_mod_graph_tensors = _make_cudnn_graph_tensor_dict(graph, score_mod_tensors) + 
wrapped_score_mod = _wrap_score_mod(score_mod, score_mod_graph_tensors) + + output_dim, output_stride = _bhsd_dim_stride(output_layer, q_format) + output, stats = graph.sdpa( + name="te_score_mod_sdpa", + q=q, + k=k, + v=v, + generate_stats=is_training, + attn_scale=attn_scale, + use_causal_mask=False, + score_mod=wrapped_score_mod, + ) + output.set_output(True).set_dim(output_dim).set_stride(output_stride) + + if is_training: + assert stats_bhs1 is not None + stats.set_output(True).set_dim(stats_bhs1.size()).set_stride( + stats_bhs1.stride() + ).set_data_type(cudnn.data_type.FLOAT) + else: + stats = None + + workspace_size = _finalize_cudnn_graph(graph) + return _CudnnScoreModFwdGraphEntry( + graph=graph, + q=q, + k=k, + v=v, + output=output, + stats=stats, + score_mod_graph_tensors=score_mod_graph_tensors, + workspace_size=workspace_size, + ) + + +def _get_cudnn_score_mod_fwd_graph( + is_training: bool, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + output_layer: torch.Tensor, + stats_bhs1: Optional[torch.Tensor], +) -> _CudnnScoreModFwdGraphEntry: + """Return a cached cuDNN frontend graph for score_mod fprop.""" + key = _cudnn_score_mod_fwd_cache_key( + is_training, + query_layer, + key_layer, + value_layer, + output_layer, + stats_bhs1, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + ) + entry = _cudnn_score_mod_graph_cache.get(key) + if entry is None: + entry = _build_cudnn_score_mod_fwd_graph( + is_training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + output_layer, + stats_bhs1, + ) + _cudnn_score_mod_graph_cache[key] = entry + return entry + + +def _build_cudnn_score_mod_bwd_graph( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + output_layer: 
torch.Tensor, + d_out: torch.Tensor, + stats_bhs1: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_bprop: Optional[Callable], + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]], + deterministic: bool, +) -> _CudnnScoreModBwdGraphEntry: + """Build a cached cuDNN frontend graph for score_mod bprop.""" + _, graph = _build_cudnn_pygraph(query_layer.dtype, query_layer.device) + q = _bhsd_graph_tensor(graph, query_layer, q_format) + k = _bhsd_graph_tensor(graph, key_layer, kv_format) + v = _bhsd_graph_tensor(graph, value_layer, kv_format) + output = _bhsd_graph_tensor(graph, output_layer, q_format) + d_output = _bhsd_graph_tensor(graph, d_out, q_format) + stats = graph.tensor_like(stats_bhs1) + + score_mod_graph_tensors = _make_cudnn_graph_tensor_dict(graph, score_mod_tensors) + score_mod_bprop_graph_tensors = ( + _make_cudnn_graph_tensor_dict(graph, score_mod_bprop_tensors) + if score_mod_bprop is not None + else {} + ) + wrapped_score_mod = _wrap_score_mod(score_mod, score_mod_graph_tensors) + wrapped_score_mod_bprop = _wrap_score_mod(score_mod_bprop, score_mod_bprop_graph_tensors) + + dq_layer = torch.empty_like(query_layer) + dk_layer = torch.empty_like(key_layer) + dv_layer = torch.empty_like(value_layer) + dq_dim, dq_stride = _bhsd_dim_stride(dq_layer, q_format) + dk_dim, dk_stride = _bhsd_dim_stride(dk_layer, kv_format) + dv_dim, dv_stride = _bhsd_dim_stride(dv_layer, kv_format) + dq, dk, dv = graph.sdpa_backward( + name="te_score_mod_sdpa_backward", + q=q, + k=k, + v=v, + o=output, + dO=d_output, + stats=stats, + attn_scale=attn_scale, + use_causal_mask=False, + score_mod=wrapped_score_mod, + score_mod_bprop=wrapped_score_mod_bprop, + use_deterministic_algorithm=deterministic, + ) + dq.set_output(True).set_dim(dq_dim).set_stride(dq_stride) + dk.set_output(True).set_dim(dk_dim).set_stride(dk_stride) + 
dv.set_output(True).set_dim(dv_dim).set_stride(dv_stride) + + workspace_size = _finalize_cudnn_graph(graph) + return _CudnnScoreModBwdGraphEntry( + graph=graph, + q=q, + k=k, + v=v, + output=output, + d_output=d_output, + stats=stats, + dq=dq, + dk=dk, + dv=dv, + score_mod_graph_tensors=score_mod_graph_tensors, + score_mod_bprop_graph_tensors=score_mod_bprop_graph_tensors, + workspace_size=workspace_size, + ) + + +def _get_cudnn_score_mod_bwd_graph( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + output_layer: torch.Tensor, + d_out: torch.Tensor, + stats_bhs1: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: Callable, + score_mod_bprop: Optional[Callable], + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]], + deterministic: bool, +) -> _CudnnScoreModBwdGraphEntry: + """Return a cached cuDNN frontend graph for score_mod bprop.""" + key = _cudnn_score_mod_bwd_cache_key( + query_layer, + key_layer, + value_layer, + output_layer, + d_out, + stats_bhs1, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, + deterministic, + ) + entry = _cudnn_score_mod_graph_cache.get(key) + if entry is None: + entry = _build_cudnn_score_mod_bwd_graph( + query_layer, + key_layer, + value_layer, + output_layer, + d_out, + stats_bhs1, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, + deterministic, + ) + _cudnn_score_mod_graph_cache[key] = entry + return entry + + +class FusedAttentionWithScoreModFunc(torch.autograd.Function): + """cuDNN frontend Python SDPA path with score_mod callback support.""" + + @staticmethod + def forward( + ctx, + is_training: bool, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + q_format: str, + kv_format: str, + attn_scale: float, + score_mod: 
Callable, + score_mod_bprop: Optional[Callable], + score_mod_tensors: Optional[Dict[str, torch.Tensor]], + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]], + deterministic: bool, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + q_bhsd_dim, _ = _bhsd_dim_stride(query_layer, q_format) + score_mod_tensors = dict(score_mod_tensors or {}) + score_mod_bprop_tensors = dict(score_mod_bprop_tensors or {}) + output_shape = (*query_layer.shape[:-1], value_layer.shape[-1]) + output_layer = torch.empty(output_shape, device=query_layer.device, dtype=query_layer.dtype) + if is_training: + stats_bhs1 = torch.empty( + (*q_bhsd_dim[:-1], 1), + device=query_layer.device, + dtype=torch.float32, + ) + else: + stats_bhs1 = None + + entry = _get_cudnn_score_mod_fwd_graph( + is_training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + attn_scale, + score_mod, + score_mod_tensors, + output_layer, + stats_bhs1, + ) + variant_pack = { + entry.q: query_layer, + entry.k: key_layer, + entry.v: value_layer, + entry.output: output_layer, + } + if is_training: + variant_pack[entry.stats] = stats_bhs1 + for name, graph_tensor in entry.score_mod_graph_tensors.items(): + variant_pack[graph_tensor] = score_mod_tensors[name] + + _execute_cudnn_graph( + entry.graph, + variant_pack, + entry.workspace_size, + query_layer.device, + ) + + ctx.is_training = is_training + ctx.q_format = q_format + ctx.kv_format = kv_format + ctx.attn_scale = attn_scale + ctx.score_mod = score_mod + ctx.score_mod_bprop = score_mod_bprop + ctx.score_mod_tensor_names = tuple(score_mod_tensors.keys()) + ctx.score_mod_bprop_tensor_names = tuple(score_mod_bprop_tensors.keys()) + ctx.deterministic = deterministic + if is_training: + # save_for_backward records version counters without copying tensor data. + # This catches in-place score_mod tensor updates before backward. 
+ ctx.save_for_backward( + query_layer, + key_layer, + value_layer, + output_layer, + stats_bhs1, + *score_mod_tensors.values(), + *score_mod_bprop_tensors.values(), + ) + else: + ctx.save_for_backward(query_layer, key_layer, value_layer, output_layer) + + return output_layer + + @staticmethod + def backward(ctx, d_out: torch.Tensor): + # pylint: disable=missing-function-docstring + if not ctx.is_training: + raise RuntimeError( + "score_mod backward requires DotProductAttention to be in training mode." + ) + + saved_tensors = ctx.saved_tensors + query_layer, key_layer, value_layer, output_layer, stats_bhs1 = saved_tensors[:5] + score_mod_tensors_end = 5 + len(ctx.score_mod_tensor_names) + score_mod_tensors = dict( + zip(ctx.score_mod_tensor_names, saved_tensors[5:score_mod_tensors_end]) + ) + score_mod_bprop_tensors = dict( + zip(ctx.score_mod_bprop_tensor_names, saved_tensors[score_mod_tensors_end:]) + ) + d_out = d_out.contiguous() + + dq_layer = torch.empty_like(query_layer) + dk_layer = torch.empty_like(key_layer) + dv_layer = torch.empty_like(value_layer) + entry = _get_cudnn_score_mod_bwd_graph( + query_layer, + key_layer, + value_layer, + output_layer, + d_out, + stats_bhs1, + ctx.q_format, + ctx.kv_format, + ctx.attn_scale, + ctx.score_mod, + ctx.score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, + ctx.deterministic, + ) + variant_pack = { + entry.q: query_layer, + entry.k: key_layer, + entry.v: value_layer, + entry.output: output_layer, + entry.d_output: d_out, + entry.stats: stats_bhs1, + entry.dq: dq_layer, + entry.dk: dk_layer, + entry.dv: dv_layer, + } + for name, graph_tensor in entry.score_mod_graph_tensors.items(): + variant_pack[graph_tensor] = score_mod_tensors[name] + for name, graph_tensor in entry.score_mod_bprop_graph_tensors.items(): + variant_pack[graph_tensor] = score_mod_bprop_tensors[name] + + _execute_cudnn_graph( + entry.graph, + variant_pack, + entry.workspace_size, + query_layer.device, + ) + + return ( + None, + 
dq_layer, + dk_layer, + dv_layer, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + class FusedAttnFunc(torch.autograd.Function): """FusedAttention forward and backward implementation""" @@ -1945,6 +2613,10 @@ def forward( inference_params: Optional[InferenceParams] = None, softmax_offset: torch.Tensor = None, fp8_output: bool = False, + score_mod: Optional[Callable] = None, + score_mod_bprop: Optional[Callable] = None, + score_mod_tensors: Optional[Dict[str, torch.Tensor]] = None, + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.Tensor: """fused attention fprop""" assert ( @@ -2067,7 +2739,41 @@ def forward( cp_group[0] if cp_comm_type == "a2a+p2p" else cp_group ) - if context_parallel: + if score_mod is not None: + assert not context_parallel, "score_mod is not supported with context parallelism!" + assert not fp8, "score_mod is not supported with FP8 FusedAttention!" + assert not fp8_output, "score_mod is not supported with fp8_output!" + assert not self.return_max_logit, "score_mod is not supported with return_max_logit!" + assert ( + type(query_layer) is torch.Tensor + and type(key_layer) is torch.Tensor + and type(value_layer) is torch.Tensor + ), "score_mod only supports unquantized torch.Tensor Q, K and V inputs!" + assert ( + fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen + ), "score_mod requires the F16/BF16 cuDNN fused attention backend!" + assert ( + attn_mask_type == "no_mask" + and core_attention_bias_type == "no_bias" + and core_attention_bias is None + and self.softmax_type == "vanilla" + and self.attention_dropout == 0.0 + ), "score_mod is mutually exclusive with masks, bias, sink attention and dropout!" 
+ output = FusedAttentionWithScoreModFunc.apply( + self.training, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + self.softmax_scale, + score_mod, + score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, + self.deterministic, + ) + elif context_parallel: assert ( fp8 or fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 17e9a337a4..95a0b53b4a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -892,6 +892,10 @@ def forward( pad_between_seqs: Optional[bool] = None, fp8_output: Optional[bool] = False, num_splits: Optional[int] = 1, + score_mod: Optional[Callable] = None, + score_mod_bprop: Optional[Callable] = None, + score_mod_tensors: Optional[Dict[str, torch.Tensor]] = None, + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.Tensor: r""" Dot Product Attention Layer. @@ -1080,6 +1084,18 @@ def forward( Optional split control for FlashAttention-3 only. When set, this value is forwarded to the FA3 backend to control internal kernel splitting behavior for non-context-parallel cases. It is ignored for other backends and when context parallelism is enabled. + score_mod: Optional[Callable], default = None + cuDNN frontend score modification callback. This is a cuDNN-only path and is mutually + exclusive with masks, bias, ALiBi, sink attention, dropout, FP8, context parallelism, + THD format, KV caching, and return_max_logit. The callback signature is + ``score_mod(graph, score, tensors) -> score``. + score_mod_bprop: Optional[Callable], default = None + Optional cuDNN frontend callback for the backward pass of score_mod. 
The callback + signature is ``score_mod_bprop(graph, dP, tensors) -> dP``. + score_mod_tensors: Optional[Dict[str, torch.Tensor]], default = None + Runtime tensors exposed to score_mod as cuDNN graph tensors. + score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]], default = None + Runtime tensors exposed to score_mod_bprop as cuDNN graph tensors. """ with self.prepare_forward_ctx( @@ -1088,6 +1104,13 @@ def forward( allow_non_contiguous=True, allow_different_data_and_param_types=self.softmax_type != "vanilla", ) as query_layer: + user_supplied_seqlens = ( + cu_seqlens_q is not None + or cu_seqlens_kv is not None + or cu_seqlens_q_padded is not None + or cu_seqlens_kv_padded is not None + ) + # checks for RNG if self.rng_states_tracker is not None and is_graph_capturing(): assert isinstance( @@ -1226,6 +1249,9 @@ def forward( seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64) + if score_mod is not None: + assert inference_params is None, "score_mod is not supported with KV caching!" + # update KV cache and retrieve saved tokens from cache for inference if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" @@ -1406,6 +1432,84 @@ def forward( else: pad_between_seqs = False + if score_mod is None: + assert score_mod_bprop is None, "score_mod_bprop requires score_mod!" + assert score_mod_tensors is None, "score_mod_tensors requires score_mod!" + assert ( + score_mod_bprop_tensors is None + ), "score_mod_bprop_tensors requires score_mod!" + else: + assert callable(score_mod), "score_mod must be callable!" + assert score_mod_bprop is None or callable( + score_mod_bprop + ), "score_mod_bprop must be callable when provided!" + assert query_layer.dtype in [ + torch.float16, + torch.bfloat16, + ], "score_mod only supports FP16 and BF16 tensors!" 
+ assert ( + key_layer.dtype == query_layer.dtype and value_layer.dtype == query_layer.dtype + ), "score_mod requires Q, K and V tensors to have the same dtype!" + assert ( + type(query_layer) is torch.Tensor + and type(key_layer) is torch.Tensor + and type(value_layer) is torch.Tensor + ), "score_mod only supports unquantized torch.Tensor Q, K and V inputs!" + assert not self.fp8, "score_mod is not supported with FP8 DotProductAttention!" + assert not fp8_output, "score_mod is not supported with fp8_output!" + assert not context_parallel, "score_mod is not supported with context parallelism!" + assert qkv_format != "thd", "score_mod is not supported with qkv_format='thd'!" + assert ( + not user_supplied_seqlens + ), "score_mod is mutually exclusive with explicit sequence length metadata!" + assert not pad_between_seqs, "score_mod is not supported with pad_between_seqs!" + assert ( + attention_mask is None + ), "score_mod is mutually exclusive with attention_mask!" + assert attn_mask_type == "no_mask", "score_mod requires attn_mask_type='no_mask'!" + assert window_size is None or window_size == ( + -1, + -1, + ), "score_mod is mutually exclusive with sliding window attention!" + assert ( + core_attention_bias_type == "no_bias" and core_attention_bias is None + ), "score_mod is mutually exclusive with attention bias!" + assert alibi_slopes is None, "score_mod is mutually exclusive with ALiBi!" + assert ( + self.softmax_type == "vanilla" + ), "score_mod is mutually exclusive with sink attention!" + assert ( + self.attention_dropout == 0.0 + ), "score_mod is not supported with attention dropout!" + assert ( + not self.return_max_logit + ), "score_mod is not supported with return_max_logit!" + assert ( + not checkpoint_core_attention + ), "score_mod is not supported with checkpoint_core_attention!" + assert ( + not is_graph_capturing() + ), "score_mod is not supported with CUDA graph capture!" + assert num_splits == 1, "score_mod is not supported with num_splits != 1!" 
+ assert q_format in ["sbhd", "bshd"] and kv_format in [ + "sbhd", + "bshd", + ], "score_mod only supports SBHD/BSHD QKV formats!" + if score_mod_tensors is not None: + assert isinstance(score_mod_tensors, dict), "score_mod_tensors must be a dict!" + assert all( + isinstance(k, str) and isinstance(v, torch.Tensor) + for k, v in score_mod_tensors.items() + ), "score_mod_tensors must map string names to torch.Tensor instances!" + if score_mod_bprop_tensors is not None: + assert isinstance( + score_mod_bprop_tensors, dict + ), "score_mod_bprop_tensors must be a dict!" + assert all( + isinstance(k, str) and isinstance(v, torch.Tensor) + for k, v in score_mod_bprop_tensors.items() + ), "score_mod_bprop_tensors must map string names to torch.Tensor instances!" + # gather attention params for get_attention_backend attention_params = dpa_utils.AttentionParams( qkv_type=type(query_layer), @@ -1443,7 +1547,39 @@ def forward( num_splits=num_splits, ) global _attention_backends - if is_in_onnx_export_mode(): + if score_mod is not None: + use_flash_attention = False + flash_attention_backend = None + use_fused_attention = True + use_unfused_attention = False + q_type = dpa_utils.TE_DType[query_layer.dtype] + fused_attention_backend = tex.get_fused_attn_backend( + self.training, + q_type, + q_type, + dpa_utils.QKVLayout["bshd_bshd_bshd"], + dpa_utils.AttnBiasType["no_bias"], + dpa_utils.AttnMaskType["no_mask"], + dpa_utils.SoftmaxType["vanilla"], + 0.0, + num_attention_heads, + num_gqa_groups, + max_seqlen_q, + max_seqlen_kv, + head_dim_qk, + head_dim_v, + -1, + -1, + False, + is_graph_capturing(), + self.deterministic, + ) + if fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend: + raise ValueError( + "score_mod requires a cuDNN FusedAttention backend, but no fused " + "attention backend supports the provided inputs." 
+ ) + elif is_in_onnx_export_mode(): # We do not want to call get_attention_backend() in ONNX mode # and we want to avoid using any global variables like _attention_backends. use_flash_attention = False @@ -1619,6 +1755,10 @@ def forward( inference_params=inference_params, softmax_offset=softmax_offset, fp8_output=fp8_output, + score_mod=score_mod, + score_mod_bprop=score_mod_bprop, + score_mod_tensors=score_mod_tensors, + score_mod_bprop_tensors=score_mod_bprop_tensors, ) if use_unfused_attention: