From f967a26797c77f408e016638c2a02b45e67e04f9 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Tue, 12 May 2026 05:15:43 +0000 Subject: [PATCH 1/3] Add JAX fused attention score_mod support Signed-off-by: Vladimir Cherepanov --- tests/jax/test_fused_attn.py | 359 ++++++++++- transformer_engine/jax/attention.py | 183 +++++- .../jax/cpp_extensions/attention.py | 574 +++++++++++++++++- transformer_engine/jax/csrc/extensions.h | 11 + .../jax/csrc/extensions/attention.cpp | 224 +++++++ .../jax/csrc/extensions/pybind.cpp | 7 + 6 files changed, 1354 insertions(+), 4 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 1fb0108068..f12435fe5b 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. """Tests for fused attention""" +import importlib.util import os from enum import Enum, auto from dataclasses import dataclass, field @@ -40,7 +41,7 @@ CPStrategy, ReorderStrategy, ) -from transformer_engine.jax.cpp_extensions import FusedAttnHelper +from transformer_engine.jax.cpp_extensions import FusedAttnHelper, make_fused_attn_score_mod_config from transformer_engine_jax import ( NVTE_Fused_Attn_Backend, get_cudnn_version, @@ -54,6 +55,149 @@ _deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) +def _has_cudnn_frontend_python(): + return importlib.util.find_spec("cudnn") is not None + + +def _score_mod_causal(graph, score, tensors): + import cudnn # pylint: disable=import-outside-toplevel + + row_index = graph.gen_index( + input=score, + axis=2, + compute_data_type=cudnn.data_type.INT32, + ) + row_index.set_data_type(cudnn.data_type.INT32) + col_index = graph.gen_index( + input=score, + axis=3, + compute_data_type=cudnn.data_type.INT32, + ) + col_index.set_data_type(cudnn.data_type.INT32) + keep = graph.cmp_ge( + input=row_index, + comparison=col_index, + compute_data_type=cudnn.data_type.BOOLEAN, + ) + keep.set_data_type(cudnn.data_type.BOOLEAN) + return graph.binary_select(input0=score, input1=tensors["neg_inf"], mask=keep) + + +def _score_mod_causal_bprop(graph, dscore, tensors): + import cudnn # pylint: disable=import-outside-toplevel + + row_index = graph.gen_index( + input=dscore, + axis=2, + compute_data_type=cudnn.data_type.INT32, + ) + row_index.set_data_type(cudnn.data_type.INT32) + col_index = graph.gen_index( + input=dscore, + axis=3, + compute_data_type=cudnn.data_type.INT32, + ) + col_index.set_data_type(cudnn.data_type.INT32) + keep = graph.cmp_ge( + input=row_index, + comparison=col_index, + compute_data_type=cudnn.data_type.BOOLEAN, + ) + keep.set_data_type(cudnn.data_type.BOOLEAN) + return graph.binary_select(input0=dscore, input1=tensors["zero"], mask=keep) + + +def _score_mod_relative_position(graph, score, tensors): + import cudnn # pylint: disable=import-outside-toplevel + + row_index = graph.gen_index( + input=score, + axis=2, + compute_data_type=cudnn.data_type.INT32, + ) + row_index.set_data_type(cudnn.data_type.INT32) + col_index = graph.gen_index( + input=score, + axis=3, + compute_data_type=cudnn.data_type.INT32, + ) + col_index.set_data_type(cudnn.data_type.INT32) + relative_position = graph.sub( + a=row_index, + b=col_index, + compute_data_type=cudnn.data_type.FLOAT, + ) + relative_position.set_data_type(cudnn.data_type.FLOAT) + return graph.add( + a=score, + b=relative_position, + compute_data_type=cudnn.data_type.FLOAT, + ) + + +class _ScoreModSoftcap: + """cuDNN frontend score_mod implementing softcapping.""" + + def __init__(self): + self.before_tanh_activation = None + + def forward(self, graph, score, tensors): + import cudnn # pylint: disable=import-outside-toplevel + + self.before_tanh_activation = graph.div( + a=score, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + self.before_tanh_activation.set_data_type(cudnn.data_type.FLOAT) + tanh_out = graph.tanh(input=self.before_tanh_activation) + tanh_out.set_data_type(cudnn.data_type.FLOAT) + return graph.mul( + a=tanh_out, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + + def backward(self, graph, dscore, tensors): + import cudnn # pylint: disable=import-outside-toplevel + + d_tanh_out = graph.mul( + a=dscore, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + d_tanh_out.set_data_type(cudnn.data_type.FLOAT) + d_before_tanh_activation = graph.tanh_backward( + loss=d_tanh_out, + input=self.before_tanh_activation, + compute_data_type=cudnn.data_type.FLOAT, + ) + d_before_tanh_activation.set_data_type(cudnn.data_type.FLOAT) + return graph.div( + a=d_before_tanh_activation, + b=tensors["softcap"], + compute_data_type=cudnn.data_type.FLOAT, + ) + + +def _reference_attention( + query, key, value, scale, *, causal=False, relative_position=False, softcap=None +): + scores = jnp.einsum("bqhd,bkhd->bhqk", query, key).astype(jnp.float32) * scale + if causal: + q_pos = jnp.arange(query.shape[1])[:, None] + kv_pos = jnp.arange(key.shape[1])[None, :] + scores = jnp.where(q_pos >= kv_pos, scores, -1e9) + if relative_position: + q_pos = jnp.arange(query.shape[1], dtype=jnp.float32)[:, None] + kv_pos = jnp.arange(key.shape[1], dtype=jnp.float32)[None, :] + scores = scores + q_pos - kv_pos + if softcap is not None: + scores = softcap * jnp.tanh(scores / softcap) + probs = jax.nn.softmax(scores, axis=-1) + return jnp.einsum("bhqk,bkhd->bqhd", probs, value).astype(query.dtype) + + @pytest.fixture(autouse=True, scope="module") def init(): """ @@ -138,6 +282,219 @@ def general_dot_product_attention( return context +def _require_cudnn_frontend_score_mod(): + cudnn = pytest.importorskip("cudnn", reason="cuDNN Python frontend is required for score_mod") + version = tuple(int(part) for part in cudnn.backend_version_string().split(".")[:2]) + if version < (9, 6): + pytest.skip("cuDNN score_mod SDPA requires cuDNN frontend 9.6 or newer") + + +def _identity_score_mod(_graph, score, _tensors): + return score + + +def test_fused_attn_score_mod_validation_rejects_masks_without_cudnn_frontend(): + q = jax.ShapeDtypeStruct((1, 16, 1, 128), jnp.float16) + k = jax.ShapeDtypeStruct((1, 16, 1, 128), jnp.float16) + v = jax.ShapeDtypeStruct((1, 16, 1, 128), jnp.float16) + + with pytest.raises(ValueError, match="mutually exclusive with attention masks"): + fused_attn( + (q, k, v), + None, + None, + None, + AttnBiasType.NO_BIAS, + AttnMaskType.CAUSAL_MASK, + QKVLayout.BSHD_BSHD_BSHD, + AttnSoftmaxType.VANILLA_SOFTMAX, + 1.0, + 0.0, + True, + score_mod=_identity_score_mod, + ) + + +def test_fused_attn_score_mod_config_splits_tensors_and_pass_by_value_scalars(): + tensor = jnp.ones((1, 1, 1, 1), dtype=jnp.float32) + + config, tensor_operands, bprop_tensor_operands = make_fused_attn_score_mod_config( + _identity_score_mod, + None, + {"tensor": tensor, "neg_inf": -1e9}, + None, + 0.125, + True, + ) + + assert config.score_mod_tensor_names == ("tensor",) + assert len(tensor_operands) == 1 + assert tensor_operands[0].shape == (1, 1, 1, 1) + assert len(bprop_tensor_operands) == 0 + assert len(config.score_mod_scalars) == 1 + assert config.score_mod_scalars[0].name == "neg_inf" + assert config.score_mod_scalars[0].dtype == "float32" + assert len(config.score_mod_scalars[0].value) == np.dtype(np.float32).itemsize + + +@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") +def test_fused_attn_score_mod_relative_position_optional_bprop(): + _require_cudnn_frontend_score_mod() + + key = jax.random.key(0) + q_key, k_key, v_key = jax.random.split(key, 3) + q = (0.125 * jax.random.normal(q_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( + jnp.float16 + ) + k = (0.125 * jax.random.normal(k_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( + jnp.float16 + ) + v = (0.125 * jax.random.normal(v_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( + jnp.float16 + ) + scale = 1.0 / sqrt(q.shape[-1]) + + def score_mod_loss(query, key_, value): + out = fused_attn( + (query, key_, value), + None, + None, + None, + AttnBiasType.NO_BIAS, + AttnMaskType.NO_MASK, + QKVLayout.BSHD_BSHD_BSHD, + AttnSoftmaxType.VANILLA_SOFTMAX, + scale, + 0.0, + True, + score_mod=_score_mod_relative_position, + ) + return jnp.sum(out.astype(jnp.float32)), out + + def ref_loss(query, key_, value): + out = _reference_attention(query, key_, value, scale, relative_position=True) + return jnp.sum(out.astype(jnp.float32)), out + + (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( + score_mod_loss, argnums=(0, 1, 2), has_aux=True + )(q, k, v) + (ref_value, ref_out), ref_grads = value_and_grad( + ref_loss, argnums=(0, 1, 2), has_aux=True + )(q, k, v) + + assert_allclose(score_mod_out, ref_out, rtol=5e-2, atol=5e-2) + assert_allclose(score_mod_value, ref_value, rtol=5e-2, atol=5e-2) + for grad, ref_grad in zip(score_mod_grads, ref_grads): + assert_allclose(grad, ref_grad, rtol=5e-2, atol=5e-2) + + +@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") +def test_fused_attn_score_mod_causal_with_bprop(): + _require_cudnn_frontend_score_mod() + + key = jax.random.key(1) + q_key, k_key, v_key = jax.random.split(key, 3) + q = (0.125 * jax.random.normal(q_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( + jnp.float16 + ) + k = (0.125 * jax.random.normal(k_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( + jnp.float16 + ) + v = (0.125 * jax.random.normal(v_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( + jnp.float16 + ) + scale = 1.0 / sqrt(q.shape[-1]) + + def score_mod_loss(query, key_, value): + out = fused_attn( + (query, key_, value), + None, + None, + None, + AttnBiasType.NO_BIAS, + AttnMaskType.NO_MASK, + QKVLayout.BSHD_BSHD_BSHD, + AttnSoftmaxType.VANILLA_SOFTMAX, + scale, + 0.0, + True, + score_mod=_score_mod_causal, + score_mod_bprop=_score_mod_causal_bprop, + score_mod_tensors={"neg_inf": -1e9}, + score_mod_bprop_tensors={"zero": 0.0}, + ) + return jnp.sum(out.astype(jnp.float32)), out + + def ref_loss(query, key_, value): + out = _reference_attention(query, key_, value, scale, causal=True) + return jnp.sum(out.astype(jnp.float32)), out + + (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( + score_mod_loss, argnums=(0, 1, 2), has_aux=True + )(q, k, v) + (ref_value, ref_out), ref_grads = value_and_grad( + ref_loss, argnums=(0, 1, 2), has_aux=True + )(q, k, v) + + assert_allclose(score_mod_out, ref_out, rtol=5e-2, atol=5e-2) + assert_allclose(score_mod_value, ref_value, rtol=5e-2, atol=5e-2) + for grad, ref_grad in zip(score_mod_grads, ref_grads): + assert_allclose(grad, ref_grad, rtol=5e-2, atol=5e-2) + + +@pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") +def test_fused_attn_score_mod_softcap_with_bprop(): + _require_cudnn_frontend_score_mod() + + key = jax.random.key(2) + q_key, k_key, v_key, d_out_key = jax.random.split(key, 4) + q = jax.random.normal(q_key, (1, 16, 2, 64), dtype=jnp.float16) + k = jax.random.normal(k_key, (1, 16, 2, 64), dtype=jnp.float16) + v = (0.1 * jax.random.normal(v_key, (1, 16, 2, 64), dtype=jnp.float16)).astype( + jnp.float16 + ) + d_out = jax.random.normal(d_out_key, (1, 16, 2, 64), dtype=jnp.float16) + scale = 1.0 / sqrt(q.shape[-1]) + softcap = 0.8 + softcap_score_mod = _ScoreModSoftcap() + + def score_mod_loss(query, key_, value): + out = fused_attn( + (query, key_, value), + None, + None, + None, + AttnBiasType.NO_BIAS, + AttnMaskType.NO_MASK, + QKVLayout.BSHD_BSHD_BSHD, + AttnSoftmaxType.VANILLA_SOFTMAX, + scale, + 0.0, + True, + score_mod=softcap_score_mod.forward, + score_mod_bprop=softcap_score_mod.backward, + score_mod_tensors={"softcap": softcap}, + score_mod_bprop_tensors={"softcap": softcap}, + ) + return jnp.sum(out.astype(jnp.float32) * d_out.astype(jnp.float32)), out + + def ref_loss(query, key_, value): + out = _reference_attention(query, key_, value, scale, softcap=softcap) + return jnp.sum(out.astype(jnp.float32) * d_out.astype(jnp.float32)), out + + (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( + score_mod_loss, argnums=(0, 1, 2), has_aux=True + )(q, k, v) + (ref_value, ref_out), ref_grads = value_and_grad( + ref_loss, argnums=(0, 1, 2), has_aux=True + )(q, k, v) + + assert_allclose(score_mod_out, ref_out, rtol=7e-2, atol=7e-2) + assert_allclose(score_mod_value, ref_value, rtol=7e-2, atol=7e-2) + for grad, ref_grad in zip(score_mod_grads, ref_grads): + assert_allclose(grad, ref_grad, rtol=7e-2, atol=7e-2) + + @jax.jit def make_causal_mask( segment_ids_q: ArrayLike, diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index f54a043fd2..c418dce4de 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -5,7 +5,7 @@ from __future__ import annotations from enum import Enum from functools import partial -from typing import Optional, Tuple, Union +from typing import Any, Callable, Mapping, Optional, Tuple, Union import warnings from jax.ad_checkpoint import checkpoint_name @@ -1391,10 +1391,129 @@ def _fused_attn_bwd_rule( _fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule) +@partial(jax.custom_vjp, nondiff_argnums=(3, 4)) +def _fused_attn_score_mod( + qkv: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], + score_mod_tensors: Tuple[jnp.ndarray, ...], + score_mod_bprop_tensors: Tuple[jnp.ndarray, ...], + config, + context_checkpoint_name: str, +): + output, _ = _fused_attn_score_mod_fwd_rule( + qkv, + score_mod_tensors, + score_mod_bprop_tensors, + config, + context_checkpoint_name, + ) + return output + + +def _fused_attn_score_mod_fwd_rule( + qkv, + score_mod_tensors, + score_mod_bprop_tensors, + config, + context_checkpoint_name, +): + output, softmax_stats = tex.fused_attn_score_mod_fwd(qkv, score_mod_tensors, config) + output = checkpoint_name(output, context_checkpoint_name) + softmax_stats = checkpoint_name(softmax_stats, context_checkpoint_name) + return output, (qkv, score_mod_tensors, score_mod_bprop_tensors, output, softmax_stats) + + +def _fused_attn_score_mod_bwd_rule(config, context_checkpoint_name, ctx, dz): + del context_checkpoint_name + qkv, score_mod_tensors, score_mod_bprop_tensors, output, softmax_stats = ctx + grad_qkv = tex.fused_attn_score_mod_bwd( + qkv, + output, + dz, + softmax_stats, + score_mod_tensors, + score_mod_bprop_tensors, + config, + ) + return ( + grad_qkv, + tuple(None for _ in score_mod_tensors), + tuple(None for _ in score_mod_bprop_tensors), + ) + + +_fused_attn_score_mod.defvjp( + _fused_attn_score_mod_fwd_rule, _fused_attn_score_mod_bwd_rule +) + + +def _validate_fused_attn_score_mod( + qkv: Tuple[jnp.ndarray, ...], + bias: Optional[jnp.ndarray], + sequence_descriptor: Optional[SequenceDescriptor], + seed: Optional[jnp.ndarray], + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + qkv_layout: QKVLayout, + softmax_type: AttnSoftmaxType, + dropout_probability: float, + max_segments_per_seq: int, + window_size: Optional[Tuple[int, int]], + context_parallel_strategy: CPStrategy, + context_parallel_causal_load_balanced: bool, + context_parallel_axis: str, + softmax_offset: Optional[jnp.ndarray], + stripe_size: int | None, +): + """Validate arguments for the cuDNN frontend score_mod path.""" + header = "score_mod fused_attn" + if qkv_layout is not QKVLayout.BSHD_BSHD_BSHD: + raise ValueError(f"{header} currently only supports QKVLayout.BSHD_BSHD_BSHD.") + if len(qkv) != 3: + raise ValueError(f"{header} requires separate query, key and value tensors.") + if any(tensor.ndim != 4 for tensor in qkv): + raise ValueError(f"{header} requires rank-4 BSHD query/key/value tensors.") + q, k, v = qkv + if q.dtype != k.dtype or q.dtype != v.dtype: + raise ValueError(f"{header} requires query, key and value to have the same dtype.") + if q.dtype not in (jnp.float16, jnp.bfloat16): + raise ValueError(f"{header} only supports FP16/BF16 query, key and value tensors.") + if q.shape[0] != k.shape[0] or q.shape[0] != v.shape[0]: + raise ValueError(f"{header} requires matching batch dimensions.") + if k.shape[1] != v.shape[1]: + raise ValueError(f"{header} requires key and value sequence lengths to match.") + if k.shape[2] != v.shape[2]: + raise ValueError(f"{header} requires key and value head counts to match.") + if q.shape[3] != k.shape[3]: + raise ValueError(f"{header} requires query/key head dimensions to match.") + + if bias is not None or attn_bias_type is not AttnBiasType.NO_BIAS: + raise ValueError(f"{header} is mutually exclusive with attention bias.") + if sequence_descriptor is not None: + raise ValueError(f"{header} is mutually exclusive with padding/sequence descriptors.") + if seed is not None: + raise ValueError(f"{header} is mutually exclusive with dropout seed.") + if attn_mask_type is not AttnMaskType.NO_MASK: + raise ValueError(f"{header} is mutually exclusive with attention masks.") + if softmax_type is not AttnSoftmaxType.VANILLA_SOFTMAX or softmax_offset is not None: + raise ValueError(f"{header} only supports vanilla softmax without softmax_offset.") + if dropout_probability != 0.0: + raise ValueError(f"{header} is mutually exclusive with dropout.") + if max_segments_per_seq != 1: + raise ValueError(f"{header} is mutually exclusive with packed/ragged sequence metadata.") + if window_size not in (None, (-1, -1)): + raise ValueError(f"{header} is mutually exclusive with sliding-window attention.") + if context_parallel_strategy is not CPStrategy.DEFAULT: + raise ValueError(f"{header} is mutually exclusive with context parallelism.") + if context_parallel_causal_load_balanced or context_parallel_axis: + raise ValueError(f"{header} is mutually exclusive with context parallelism.") + if stripe_size is not None: + raise ValueError(f"{header} is mutually exclusive with striped context parallelism.") + + def fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], - sequence_descriptor: SequenceDescriptor, + sequence_descriptor: Optional[SequenceDescriptor], seed: Optional[jnp.ndarray], attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, @@ -1411,6 +1530,10 @@ def fused_attn( context_checkpoint_name: str = "context", softmax_offset: Optional[jnp.ndarray] = None, stripe_size: int | None = None, + score_mod: Optional[Callable] = None, + score_mod_bprop: Optional[Callable] = None, + score_mod_tensors: Optional[Mapping[str, Any]] = None, + score_mod_bprop_tensors: Optional[Mapping[str, Any]] = None, ): """ Perform cuDNN fused attention. @@ -1453,6 +1576,20 @@ def fused_attn( Currently, a stripe_size > 1 is only supported for CP + THD + Striped + AG, whereas a stripe_size=1 is supported for both, CP + THD + Striped + AG and CP + THD + Striped + P2P(Ring) None indicates no striping strategy + score_mod (Optional[Callable]): Optional cuDNN frontend score modification callback. + The callback is called as `score_mod(graph, score, tensors)` while building a + cuDNN frontend graph. When provided, this path only supports BSHD_BSHD_BSHD + layout and is mutually exclusive with masks, padding, bias, dropout, context + parallelism, sliding windows, and non-vanilla softmax. + score_mod_bprop (Optional[Callable]): Optional score modification backward callback, + called as `score_mod_bprop(graph, dscore, tensors)`. If omitted, cuDNN uses the + default backward behavior for the forward score modification graph. + score_mod_tensors (Optional[Mapping[str, Any]]): Additional tensors or Python/NumPy + scalars made available to `score_mod` through its `tensors` dictionary. Scalars + are represented as cuDNN pass-by-value tensors. Tensor entries are treated as + non-differentiable auxiliary inputs. + score_mod_bprop_tensors (Optional[Mapping[str, Any]]): Additional tensors or + Python/NumPy scalars made available to `score_mod_bprop`. Returns: (jnp.ndarray): The output tensor from the fused attention. @@ -1485,6 +1622,48 @@ def fused_attn( AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK, QKVLayout.T3HD, 0.125, 0, True, 3) """ + if score_mod is None: + if score_mod_bprop is not None: + raise ValueError("score_mod_bprop requires score_mod to be provided.") + if score_mod_tensors is not None: + raise ValueError("score_mod_tensors requires score_mod to be provided.") + if score_mod_bprop_tensors is not None: + raise ValueError("score_mod_bprop_tensors requires score_mod to be provided.") + else: + _validate_fused_attn_score_mod( + qkv, + bias, + sequence_descriptor, + seed, + attn_bias_type, + attn_mask_type, + qkv_layout, + softmax_type, + dropout_probability, + max_segments_per_seq, + window_size, + context_parallel_strategy, + context_parallel_causal_load_balanced, + context_parallel_axis, + softmax_offset, + stripe_size, + ) + config, tensor_operands, bprop_tensor_operands = tex.make_fused_attn_score_mod_config( + score_mod, + score_mod_bprop, + score_mod_tensors, + score_mod_bprop_tensors, + scaling_factor, + is_training, + ) + return _fused_attn_score_mod( + qkv, + tensor_operands, + bprop_tensor_operands, + config, + context_checkpoint_name, + ) + if sequence_descriptor is None or isinstance(sequence_descriptor, jnp.ndarray): warnings.warn( "Pass mask to fused_attn is deprecated, please use SequenceDescriptor instead. " diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 489bfde997..f90efec2fa 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -7,10 +7,11 @@ import warnings from dataclasses import dataclass, replace from functools import partial, reduce -from typing import Optional, Tuple +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple import jax import jax.numpy as jnp +import numpy as np from jax import dtypes, lax, ffi from jax.sharding import PartitionSpec, NamedSharding from jax.experimental.custom_partitioning import SdyShardingRule @@ -54,6 +55,9 @@ "FusedAttnHelper", "fused_attn_fwd", "fused_attn_bwd", + "make_fused_attn_score_mod_config", + "fused_attn_score_mod_fwd", + "fused_attn_score_mod_bwd", ] @@ -267,6 +271,574 @@ def check_seed(self, seed, dropout_probability, is_training): return seed +@dataclass(frozen=True) +class _ScoreModScalarSpec: + """Static pass-by-value scalar used when building a cuDNN frontend graph.""" + + name: str + dtype: str + value: bytes + dim: Tuple[int, ...] = (1, 1, 1, 1) + stride: Tuple[int, ...] = (1, 1, 1, 1) + + +@dataclass(frozen=True) +class _FusedAttnScoreModConfig: + """Static configuration for cuDNN frontend score_mod SDPA graphs.""" + + score_mod: Callable + score_mod_bprop: Optional[Callable] + score_mod_tensor_names: Tuple[str, ...] + score_mod_bprop_tensor_names: Tuple[str, ...] + score_mod_scalars: Tuple[_ScoreModScalarSpec, ...] + score_mod_bprop_scalars: Tuple[_ScoreModScalarSpec, ...] + scaling_factor: float + is_training: bool + deterministic: bool + + def __hash__(self): + return hash( + ( + id(self.score_mod), + id(self.score_mod_bprop) if self.score_mod_bprop is not None else None, + self.score_mod_tensor_names, + self.score_mod_bprop_tensor_names, + self.score_mod_scalars, + self.score_mod_bprop_scalars, + self.scaling_factor, + self.is_training, + self.deterministic, + ) + ) + + def __eq__(self, other): + if not isinstance(other, _FusedAttnScoreModConfig): + return False + return ( + self.score_mod is other.score_mod + and self.score_mod_bprop is other.score_mod_bprop + and self.score_mod_tensor_names == other.score_mod_tensor_names + and self.score_mod_bprop_tensor_names == other.score_mod_bprop_tensor_names + and self.score_mod_scalars == other.score_mod_scalars + and self.score_mod_bprop_scalars == other.score_mod_bprop_scalars + and self.scaling_factor == other.scaling_factor + and self.is_training == other.is_training + and self.deterministic == other.deterministic + ) + + +_SCORE_MOD_UID_Q = 1 +_SCORE_MOD_UID_K = 2 +_SCORE_MOD_UID_V = 3 +_SCORE_MOD_UID_O = 4 +_SCORE_MOD_UID_STATS = 5 +_SCORE_MOD_UID_DO = 6 +_SCORE_MOD_UID_DQ = 7 +_SCORE_MOD_UID_DK = 8 +_SCORE_MOD_UID_DV = 9 +_SCORE_MOD_FWD_TENSOR_UID_BASE = 1000 +_SCORE_MOD_BPROP_TENSOR_UID_BASE = 2000 +_SCORE_MOD_FWD_SCALAR_UID_BASE = 3000 +_SCORE_MOD_BPROP_SCALAR_UID_BASE = 4000 + +_score_mod_graph_cache: Dict[Tuple[Any, ...], Tuple[int, int]] = {} + + +def _row_major_stride(shape: Sequence[int]) -> Tuple[int, ...]: + stride = [] + running = 1 + for dim in reversed(tuple(shape)): + stride.append(running) + running *= dim + return tuple(reversed(stride)) + + +def _bshd_as_bhsd_dim_stride(shape: Sequence[int]) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + if len(shape) != 4: + raise ValueError(f"score_mod requires rank-4 BSHD tensors, got shape={shape}.") + batch, seqlen, heads, head_dim = tuple(shape) + return ( + (batch, heads, seqlen, head_dim), + (seqlen * heads * head_dim, head_dim, heads * head_dim, 1), + ) + + +def _dtype_name(dtype) -> str: + return str(jnp.dtype(dtype)) + + +def _is_array_operand(value: Any) -> bool: + return hasattr(value, "shape") and hasattr(value, "dtype") and not isinstance( + value, (bool, int, float, complex, np.generic) + ) + + +def _scalar_to_spec(name: str, value: Any) -> _ScoreModScalarSpec: + if isinstance(value, bool): + dtype = np.bool_ + elif isinstance(value, int): + dtype = np.int32 + elif isinstance(value, float): + dtype = np.float32 + elif isinstance(value, np.generic): + dtype = value.dtype + else: + scalar = np.asarray(value) + if scalar.shape != (): + raise ValueError( + f"score_mod tensor '{name}' is neither a JAX array nor a scalar pass-by-value." + ) + dtype = scalar.dtype + + scalar = np.full((1, 1, 1, 1), value, dtype=dtype) + return _ScoreModScalarSpec(name=name, dtype=str(scalar.dtype), value=scalar.tobytes()) + + +def _split_score_mod_tensors( + tensors: Optional[Mapping[str, Any]], *, argument_name: str +) -> Tuple[Tuple[str, ...], Tuple[jnp.ndarray, ...], Tuple[_ScoreModScalarSpec, ...]]: + if tensors is None: + return (), (), () + if not isinstance(tensors, Mapping): + raise TypeError(f"{argument_name} must be a mapping from string names to tensors/scalars.") + + names = [] + operands = [] + scalars = [] + for name, value in tensors.items(): + if not isinstance(name, str): + raise TypeError(f"{argument_name} keys must be strings, got {type(name).__name__}.") + if _is_array_operand(value): + if len(value.shape) == 0: + raise ValueError( + f"{argument_name}['{name}'] is a rank-0 array. Use a Python/NumPy scalar " + "for cuDNN pass-by-value scalars, or reshape it to a tensor." + ) + names.append(name) + operands.append(jnp.asarray(value)) + else: + scalars.append(_scalar_to_spec(name, value)) + return tuple(names), tuple(operands), tuple(scalars) + + +def make_fused_attn_score_mod_config( + score_mod: Callable, + score_mod_bprop: Optional[Callable], + score_mod_tensors: Optional[Mapping[str, Any]], + score_mod_bprop_tensors: Optional[Mapping[str, Any]], + scaling_factor: float, + is_training: bool, +) -> Tuple[_FusedAttnScoreModConfig, Tuple[jnp.ndarray, ...], Tuple[jnp.ndarray, ...]]: + """Normalize score_mod operands and create a static graph-build config.""" + if not callable(score_mod): + raise TypeError("score_mod must be callable.") + if score_mod_bprop is not None and not callable(score_mod_bprop): + raise TypeError("score_mod_bprop must be callable when provided.") + if score_mod_bprop is None and score_mod_bprop_tensors: + raise ValueError("score_mod_bprop_tensors requires score_mod_bprop to be provided.") + + tensor_names, tensor_operands, scalars = _split_score_mod_tensors( + score_mod_tensors, argument_name="score_mod_tensors" + ) + bprop_tensor_names, bprop_tensor_operands, bprop_scalars = _split_score_mod_tensors( + score_mod_bprop_tensors, argument_name="score_mod_bprop_tensors" + ) + config = _FusedAttnScoreModConfig( + score_mod=score_mod, + score_mod_bprop=score_mod_bprop, + score_mod_tensor_names=tensor_names, + score_mod_bprop_tensor_names=bprop_tensor_names, + score_mod_scalars=scalars, + score_mod_bprop_scalars=bprop_scalars, + scaling_factor=float(scaling_factor), + is_training=bool(is_training), + deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), + ) + return config, tensor_operands, bprop_tensor_operands + + +def _cudnn_data_type(cudnn, dtype): + dtype = jnp.dtype(dtype) + if dtype == jnp.float16: + return cudnn.data_type.HALF + if dtype == jnp.bfloat16: + return cudnn.data_type.BFLOAT16 + if dtype == jnp.float32: + return cudnn.data_type.FLOAT + if dtype == jnp.float64: + return cudnn.data_type.DOUBLE + if dtype == jnp.int32: + return cudnn.data_type.INT32 + if dtype == jnp.int64: + return cudnn.data_type.INT64 + if dtype == jnp.uint8: + return cudnn.data_type.UINT8 + if dtype == jnp.bool_: + return cudnn.data_type.BOOLEAN + raise ValueError(f"Unsupported score_mod tensor dtype: {dtype}.") + + +def _cudnn_data_type_from_name(cudnn, dtype_name: str): + if dtype_name == "bfloat16": + return cudnn.data_type.BFLOAT16 + return _cudnn_data_type(cudnn, np.dtype(dtype_name)) + + +def _graph_tensor_from_aval(cudnn, graph, name: str, aval, uid: int): + shape = tuple(int(dim) for dim in aval.shape) + return graph.tensor( + name=name, + dim=shape, + stride=_row_major_stride(shape), + data_type=_cudnn_data_type(cudnn, aval.dtype), + uid=uid, + ) + + +def _score_mod_graph_tensors( + cudnn, + graph, + names: Tuple[str, ...], + avals: Sequence[Any], + scalars: Tuple[_ScoreModScalarSpec, ...], + tensor_uid_base: int, + scalar_uid_base: int, +): + graph_tensors = {} + tensor_uids = [] + for index, (name, aval) in enumerate(zip(names, avals)): + uid = tensor_uid_base + index + graph_tensors[name] = _graph_tensor_from_aval(cudnn, graph, name, aval, uid) + tensor_uids.append(uid) + + scalar_uids = [] + scalar_values = [] + for index, scalar in enumerate(scalars): + uid = scalar_uid_base + index + graph_tensors[scalar.name] = graph.tensor( + name=scalar.name, + dim=scalar.dim, + stride=scalar.stride, + is_pass_by_value=True, + data_type=_cudnn_data_type_from_name(cudnn, scalar.dtype), + uid=uid, + ) + scalar_uids.append(uid) + scalar_values.append(scalar.value) + + return graph_tensors, tuple(tensor_uids), tuple(scalar_uids), tuple(scalar_values) + + +def _wrap_score_mod(score_mod: Optional[Callable], graph_tensors: Dict[str, Any]): + if score_mod is None: + return None + + def wrapped_score_mod(sdpa_graph, score_tensor): + return score_mod(sdpa_graph, score_tensor, graph_tensors) + + return wrapped_score_mod + + +def _finalize_score_mod_graph(cudnn, graph) -> int: + graph.validate() + graph.build_operation_graph() + try: + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + except cudnn.cudnnGraphNotSupportedError as exc: + raise RuntimeError(f"cuDNN score_mod SDPA graph is not supported: {exc}") from exc + graph.build_plans(cudnn.build_plan_policy.HEURISTICS_CHOICE) + return max(int(graph.get_workspace_size()), 1) + + +def _graph_cache_key( + direction: str, + config: _FusedAttnScoreModConfig, + avals: Sequence[Any], +) -> Tuple[Any, ...]: + return ( + direction, + config, + tuple((tuple(aval.shape), _dtype_name(aval.dtype)) for aval in avals), + ) + + +def _shape_dtype(value) -> jax.ShapeDtypeStruct: + return jax.ShapeDtypeStruct(tuple(value.shape), value.dtype) + + +def _import_cudnn_for_score_mod(): + try: + import cudnn # pylint: disable=import-outside-toplevel + except ImportError as exc: + raise ImportError( + "score_mod fused_attn requires the cuDNN frontend Python package (`cudnn`)." + ) from exc + return cudnn + + +def _build_score_mod_fwd_graph(q_aval, k_aval, v_aval, score_mod_avals, config): + cudnn = _import_cudnn_for_score_mod() + + io_data_type = _cudnn_data_type(cudnn, q_aval.dtype) + graph = cudnn.pygraph( + io_data_type=io_data_type, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + + q_dim, q_stride = _bshd_as_bhsd_dim_stride(q_aval.shape) + k_dim, k_stride = _bshd_as_bhsd_dim_stride(k_aval.shape) + v_dim, v_stride = _bshd_as_bhsd_dim_stride(v_aval.shape) + q = graph.tensor( + name="q", dim=q_dim, stride=q_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_Q + ) + k = graph.tensor( + name="k", dim=k_dim, stride=k_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_K + ) + v = graph.tensor( + name="v", dim=v_dim, stride=v_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_V + ) + + score_mod_graph_tensors, tensor_uids, scalar_uids, scalar_values = _score_mod_graph_tensors( + cudnn, + graph, + config.score_mod_tensor_names, + score_mod_avals, + config.score_mod_scalars, + _SCORE_MOD_FWD_TENSOR_UID_BASE, + _SCORE_MOD_FWD_SCALAR_UID_BASE, + ) + + output, stats = graph.sdpa( + name="te_score_mod_sdpa", + q=q, + k=k, + v=v, + generate_stats=config.is_training, + attn_scale=config.scaling_factor, + use_causal_mask=False, + score_mod=_wrap_score_mod(config.score_mod, score_mod_graph_tensors), + ) + + batch, q_seqlen, q_heads, _ = q_aval.shape + _, _, _, v_head_dim = v_aval.shape + output_dim, output_stride = _bshd_as_bhsd_dim_stride((batch, q_seqlen, q_heads, v_head_dim)) + output.set_output(True).set_uid(_SCORE_MOD_UID_O).set_dim(output_dim).set_stride( + output_stride + ) + output.set_data_type(io_data_type) + + output_uids = [_SCORE_MOD_UID_O] + if config.is_training: + stats_shape = (batch, q_heads, q_seqlen, 1) + stats.set_output(True).set_uid(_SCORE_MOD_UID_STATS).set_dim(stats_shape).set_stride( + _row_major_stride(stats_shape) + ) + stats.set_data_type(cudnn.data_type.FLOAT) + output_uids.append(_SCORE_MOD_UID_STATS) + + workspace_size = _finalize_score_mod_graph(cudnn, graph) + graph_id = transformer_engine_jax.register_fused_attn_score_mod_graph( + graph, + [int(uid) for uid in graph._get_variant_pack_uids_sorted()], + [_SCORE_MOD_UID_Q, _SCORE_MOD_UID_K, _SCORE_MOD_UID_V, *tensor_uids], + output_uids, + list(scalar_uids), + list(scalar_values), + ) + return graph_id, workspace_size + + +def _build_score_mod_bwd_graph( + q_aval, + k_aval, + v_aval, + output_aval, + doutput_aval, + stats_aval, + score_mod_avals, + score_mod_bprop_avals, + config, +): + cudnn = _import_cudnn_for_score_mod() + + io_data_type = _cudnn_data_type(cudnn, q_aval.dtype) + graph = cudnn.pygraph( + io_data_type=io_data_type, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + + q_dim, q_stride = _bshd_as_bhsd_dim_stride(q_aval.shape) + k_dim, k_stride = _bshd_as_bhsd_dim_stride(k_aval.shape) + v_dim, v_stride = _bshd_as_bhsd_dim_stride(v_aval.shape) + o_dim, o_stride = _bshd_as_bhsd_dim_stride(output_aval.shape) + do_dim, do_stride = _bshd_as_bhsd_dim_stride(doutput_aval.shape) + q = graph.tensor( + name="q", dim=q_dim, stride=q_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_Q + ) + k = graph.tensor( + name="k", dim=k_dim, stride=k_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_K + ) + v = graph.tensor( + name="v", dim=v_dim, stride=v_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_V + ) + output = graph.tensor( + name="o", dim=o_dim, stride=o_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_O + ) + doutput = graph.tensor( + name="dO", dim=do_dim, stride=do_stride, data_type=io_data_type, uid=_SCORE_MOD_UID_DO + ) + stats = graph.tensor( + name="stats", + dim=tuple(int(dim) for dim in stats_aval.shape), + stride=_row_major_stride(stats_aval.shape), + data_type=cudnn.data_type.FLOAT, + uid=_SCORE_MOD_UID_STATS, + ) + + score_mod_graph_tensors, tensor_uids, scalar_uids, scalar_values = _score_mod_graph_tensors( + cudnn, + graph, + config.score_mod_tensor_names, + score_mod_avals, + config.score_mod_scalars, + _SCORE_MOD_FWD_TENSOR_UID_BASE, + _SCORE_MOD_FWD_SCALAR_UID_BASE, + ) + ( + score_mod_bprop_graph_tensors, + bprop_tensor_uids, + bprop_scalar_uids, + bprop_scalar_values, + ) = _score_mod_graph_tensors( + cudnn, + graph, + config.score_mod_bprop_tensor_names, + score_mod_bprop_avals, + config.score_mod_bprop_scalars, + _SCORE_MOD_BPROP_TENSOR_UID_BASE, + _SCORE_MOD_BPROP_SCALAR_UID_BASE, + ) + + dq, dk, dv = graph.sdpa_backward( + name="te_score_mod_sdpa_backward", + q=q, + k=k, + v=v, + o=output, + dO=doutput, + stats=stats, + attn_scale=config.scaling_factor, + use_causal_mask=False, + score_mod=_wrap_score_mod(config.score_mod, score_mod_graph_tensors), + score_mod_bprop=_wrap_score_mod(config.score_mod_bprop, score_mod_bprop_graph_tensors), + use_deterministic_algorithm=config.deterministic, + ) + + dq.set_output(True).set_uid(_SCORE_MOD_UID_DQ).set_dim(q_dim).set_stride(q_stride) + dk.set_output(True).set_uid(_SCORE_MOD_UID_DK).set_dim(k_dim).set_stride(k_stride) + dv.set_output(True).set_uid(_SCORE_MOD_UID_DV).set_dim(v_dim).set_stride(v_stride) + + workspace_size = _finalize_score_mod_graph(cudnn, graph) + graph_id = transformer_engine_jax.register_fused_attn_score_mod_graph( + graph, + [int(uid) for uid in graph._get_variant_pack_uids_sorted()], + [ + _SCORE_MOD_UID_Q, + _SCORE_MOD_UID_K, + _SCORE_MOD_UID_V, + _SCORE_MOD_UID_O, + _SCORE_MOD_UID_DO, + _SCORE_MOD_UID_STATS, + *tensor_uids, + *bprop_tensor_uids, + ], + [_SCORE_MOD_UID_DQ, _SCORE_MOD_UID_DK, _SCORE_MOD_UID_DV], + [*scalar_uids, *bprop_scalar_uids], + [*scalar_values, *bprop_scalar_values], + ) + return graph_id, workspace_size + + +def fused_attn_score_mod_fwd( + qkv: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], + score_mod_tensors: Tuple[jnp.ndarray, ...], + config: _FusedAttnScoreModConfig, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Run cuDNN frontend SDPA forward with a score_mod callback.""" + q, k, v = qkv + q_aval, k_aval, v_aval = map(_shape_dtype, (q, k, v)) + score_mod_avals = tuple(_shape_dtype(arg) for arg in score_mod_tensors) + key = _graph_cache_key("fwd", config, (q_aval, k_aval, v_aval, *score_mod_avals)) + if key not in _score_mod_graph_cache: + _score_mod_graph_cache[key] = _build_score_mod_fwd_graph( + q_aval, k_aval, v_aval, score_mod_avals, config + ) + graph_id, workspace_size = _score_mod_graph_cache[key] + + batch, q_seqlen, q_heads, _ = q.shape + _, _, _, v_head_dim = v.shape + output_shape = jax.ShapeDtypeStruct((batch, q_seqlen, q_heads, v_head_dim), q.dtype) + stats_shape = (batch, q_heads, q_seqlen, 1) if config.is_training else (0,) + stats = jax.ShapeDtypeStruct(stats_shape, jnp.float32) + workspace = jax.ShapeDtypeStruct((workspace_size,), jnp.uint8) + output, softmax_stats, _ = ffi.ffi_call( + "te_fused_attn_score_mod_forward_ffi", + (output_shape, stats, workspace), + )(q, k, v, *score_mod_tensors, graph_id=graph_id) + return output, softmax_stats + + +def fused_attn_score_mod_bwd( + qkv: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], + output: jnp.ndarray, + doutput: jnp.ndarray, + softmax_stats: jnp.ndarray, + score_mod_tensors: Tuple[jnp.ndarray, ...], + score_mod_bprop_tensors: Tuple[jnp.ndarray, ...], + config: _FusedAttnScoreModConfig, +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Run cuDNN frontend SDPA backward with score_mod callbacks.""" + if not config.is_training: + raise RuntimeError("score_mod backward requires fused_attn(..., is_training=True).") + + q, k, v = qkv + all_inputs = (q, k, v, output, doutput, softmax_stats, *score_mod_tensors) + all_inputs = (*all_inputs, *score_mod_bprop_tensors) + avals = tuple(_shape_dtype(arg) for arg in all_inputs) + key = _graph_cache_key("bwd", config, avals) + if key not in _score_mod_graph_cache: + _score_mod_graph_cache[key] = _build_score_mod_bwd_graph( + *avals[:6], + avals[6 : 6 + len(score_mod_tensors)], + avals[6 + len(score_mod_tensors) :], + config, + ) + graph_id, workspace_size = _score_mod_graph_cache[key] + + dq = jax.ShapeDtypeStruct(q.shape, q.dtype) + dk = jax.ShapeDtypeStruct(k.shape, k.dtype) + dv = jax.ShapeDtypeStruct(v.shape, v.dtype) + workspace = jax.ShapeDtypeStruct((workspace_size,), jnp.uint8) + dq, dk, dv, _ = ffi.ffi_call( + "te_fused_attn_score_mod_backward_ffi", + (dq, dk, dv, workspace), + )( + q, + k, + v, + output, + doutput, + softmax_stats, + *score_mod_tensors, + *score_mod_bprop_tensors, + graph_id=graph_id, + ) + return dq, dk, dv + + def generate_cu_seqlen(actual_seqlen): """ Generating cumsum seqlen for a batch diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 2ecfedc8a2..f17788a068 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -146,6 +146,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnScoreModForwardHandler); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnScoreModBackwardHandler); + NVTE_Fused_Attn_Backend GetFusedAttnBackend( bool is_training, DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, @@ -169,6 +173,13 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal); +int64_t RegisterFusedAttnScoreModGraph(pybind11::object graph, + const std::vector& user_uids, + const std::vector& input_uids, + const std::vector& output_uids, + const std::vector& scalar_uids, + const std::vector& scalar_values); + // GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmV2Handler); diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index ed136d7b9e..26a3e5e8c0 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -8,6 +8,13 @@ #include "transformer_engine/fused_attn.h" #include "transformer_engine/transformer_engine.h" +#include +#include +#include +#include +#include +#include + namespace transformer_engine { namespace jax { @@ -689,5 +696,222 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI, .Attrs(), FFI_CudaGraph_Traits); +namespace { + +struct ScoreModScalarStorage { + alignas(16) std::array data{}; + size_t size = 0; +}; + +struct ScoreModGraphEntry { + PyObject *py_graph = nullptr; + std::vector user_uids; + std::vector input_uids; + std::vector output_uids; + std::vector scalar_uids; + std::vector scalar_values; +}; + +std::unordered_map> &ScoreModGraphRegistry() { + static std::unordered_map> registry; + return registry; +} + +std::mutex &ScoreModGraphRegistryMutex() { + static std::mutex mutex; + return mutex; +} + +std::atomic &NextScoreModGraphId() { + static std::atomic next_id{1}; + return next_id; +} + +struct ScoreModCudnnHandleCache { + std::unordered_map handles; + + cudnnHandle_t GetHandle() { + int device_id = 0; + NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); + auto it = handles.find(device_id); + if (it == handles.end()) { + cudnnHandle_t handle = nullptr; + NVTE_CHECK_CUDNN(cudnnCreate(&handle)); + it = handles.emplace(device_id, handle).first; + } + return it->second; + } + + ~ScoreModCudnnHandleCache() { + for (auto &[_, handle] : handles) { + cudnnDestroy(handle); + } + } +}; + +cudnnHandle_t GetScoreModCudnnHandle() { + static thread_local ScoreModCudnnHandleCache cache; + return cache.GetHandle(); +} + +std::shared_ptr GetScoreModGraphEntry(int64_t graph_id) { + std::lock_guard lock(ScoreModGraphRegistryMutex()); + auto ®istry = ScoreModGraphRegistry(); + auto it = registry.find(graph_id); + NVTE_CHECK(it != registry.end(), "Unknown cuDNN score_mod graph id: ", graph_id); + return it->second; +} + +Error_Type ExecuteScoreModGraph(cudaStream_t stream, int64_t graph_id, + const std::vector &input_ptrs, + const std::vector &output_ptrs, void *workspace) { + auto entry = GetScoreModGraphEntry(graph_id); + NVTE_CHECK(input_ptrs.size() == entry->input_uids.size(), + "cuDNN score_mod graph expected ", entry->input_uids.size(), " inputs but got ", + input_ptrs.size()); + NVTE_CHECK(output_ptrs.size() >= entry->output_uids.size(), + "cuDNN score_mod graph expected at least ", entry->output_uids.size(), + " outputs but got ", output_ptrs.size()); + + std::unordered_map variant_pack; + for (size_t i = 0; i < entry->input_uids.size(); ++i) { + variant_pack.emplace(entry->input_uids[i], input_ptrs[i]); + } + for (size_t i = 0; i < entry->output_uids.size(); ++i) { + variant_pack.emplace(entry->output_uids[i], output_ptrs[i]); + } + for (size_t i = 0; i < entry->scalar_uids.size(); ++i) { + variant_pack.emplace(entry->scalar_uids[i], entry->scalar_values[i].data.data()); + } + + std::vector user_ptrs; + user_ptrs.reserve(entry->user_uids.size()); + for (const auto uid : entry->user_uids) { + auto it = variant_pack.find(uid); + NVTE_CHECK(it != variant_pack.end(), "cuDNN score_mod graph variant pack is missing UID ", + uid); + user_ptrs.push_back(reinterpret_cast(it->second)); + } + + auto handle = GetScoreModCudnnHandle(); + NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); + { + pybind11::gil_scoped_acquire gil; + try { + auto graph = pybind11::reinterpret_borrow(entry->py_graph); + graph.attr("_execute_with_ptrs")(user_ptrs, reinterpret_cast(workspace), + reinterpret_cast(handle)); + } catch (const pybind11::error_already_set &exc) { + NVTE_ERROR("cuDNN score_mod SDPA graph execution failed: ", exc.what()); + } + } + return ffi_with_cuda_error_check(); +} + +void AppendRemainingBuffers(Variadic_Buffer_Type args, std::vector *ptrs) { + ptrs->reserve(ptrs->size() + args.size()); + for (size_t i = 0; i < args.size(); ++i) { + auto maybe_buf = args.get(i); + NVTE_CHECK(!maybe_buf.has_error(), "Failed to decode variadic score_mod input buffer."); + ptrs->push_back(maybe_buf.value().untyped_data()); + } +} + +} // namespace + +int64_t RegisterFusedAttnScoreModGraph(pybind11::object graph, + const std::vector &user_uids, + const std::vector &input_uids, + const std::vector &output_uids, + const std::vector &scalar_uids, + const std::vector &scalar_values) { + NVTE_CHECK(!graph.is_none(), "Cannot register an empty cuDNN score_mod graph."); + NVTE_CHECK(!user_uids.empty(), "Cannot register a cuDNN score_mod graph without variant UIDs."); + NVTE_CHECK(scalar_uids.size() == scalar_values.size(), + "Mismatched score_mod scalar uid/value counts."); + + auto entry = std::make_shared(); + entry->py_graph = graph.ptr(); + Py_INCREF(entry->py_graph); + entry->user_uids = user_uids; + entry->input_uids = input_uids; + entry->output_uids = output_uids; + entry->scalar_uids = scalar_uids; + entry->scalar_values.reserve(scalar_values.size()); + for (const auto &value : scalar_values) { + NVTE_CHECK(value.size() <= 16, "score_mod pass-by-value scalars must be at most 16 bytes."); + ScoreModScalarStorage storage; + storage.size = value.size(); + std::copy(value.begin(), value.end(), storage.data.begin()); + entry->scalar_values.push_back(storage); + } + + const int64_t graph_id = NextScoreModGraphId().fetch_add(1); + { + std::lock_guard lock(ScoreModGraphRegistryMutex()); + ScoreModGraphRegistry().emplace(graph_id, std::move(entry)); + } + return graph_id; +} + +Error_Type FusedAttnScoreModForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, + Buffer_Type v_buf, Variadic_Buffer_Type score_mod_args, + Result_Type output_buf, Result_Type stats_buf, + Result_Type workspace_buf, Dictionary attrs) { + int64_t graph_id = get_attr_value(attrs, "graph_id"); + std::vector input_ptrs = {q_buf.untyped_data(), k_buf.untyped_data(), + v_buf.untyped_data()}; + AppendRemainingBuffers(score_mod_args, &input_ptrs); + + std::vector output_ptrs = {output_buf->untyped_data(), stats_buf->untyped_data()}; + return ExecuteScoreModGraph(stream, graph_id, input_ptrs, output_ptrs, + workspace_buf->untyped_data()); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnScoreModForwardHandler, FusedAttnScoreModForwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // q + .Arg() // k + .Arg() // v + .RemainingArgs() // score_mod tensor operands + .Ret() // output + .Ret() // stats + .Ret() // workspace + .Attrs()); + +Error_Type FusedAttnScoreModBackwardFFI( + cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, Buffer_Type v_buf, + Buffer_Type output_buf, Buffer_Type doutput_buf, Buffer_Type stats_buf, + Variadic_Buffer_Type score_mod_args, Result_Type dq_buf, Result_Type dk_buf, + Result_Type dv_buf, Result_Type workspace_buf, Dictionary attrs) { + int64_t graph_id = get_attr_value(attrs, "graph_id"); + std::vector input_ptrs = {q_buf.untyped_data(), k_buf.untyped_data(), + v_buf.untyped_data(), output_buf.untyped_data(), + doutput_buf.untyped_data(), stats_buf.untyped_data()}; + AppendRemainingBuffers(score_mod_args, &input_ptrs); + + std::vector output_ptrs = {dq_buf->untyped_data(), dk_buf->untyped_data(), + dv_buf->untyped_data()}; + return ExecuteScoreModGraph(stream, graph_id, input_ptrs, output_ptrs, + workspace_buf->untyped_data()); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnScoreModBackwardHandler, FusedAttnScoreModBackwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // q + .Arg() // k + .Arg() // v + .Arg() // output + .Arg() // doutput + .Arg() // stats + .RemainingArgs() // score_mod tensor operands + .Ret() // dq + .Ret() // dk + .Ret() // dv + .Ret() // workspace + .Attrs()); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 70d0403b3e..bdb4507323 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -64,6 +64,12 @@ pybind11::dict Registrations() { dict["te_fused_attn_backward_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(FusedAttnBackwardHandler)); + dict["te_fused_attn_score_mod_forward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(FusedAttnScoreModForwardHandler)); + dict["te_fused_attn_score_mod_backward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(FusedAttnScoreModBackwardHandler)); // GEMM dict["te_gemm_ffi"] = @@ -121,6 +127,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_norm_bwd_workspace_sizes", &GetNormBackwardWorkspaceSizes); m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); + m.def("register_fused_attn_score_mod_graph", &RegisterFusedAttnScoreModGraph); m.def("get_topk_workspace_sizes", &GetTopkWorkspaceSizes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported); From 6b0532824215979ba7ef2dff64aae5ec637ef984 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Tue, 12 May 2026 05:38:04 +0000 Subject: [PATCH 2/3] Stabilize score_mod callback cache keys Signed-off-by: Vladimir Cherepanov --- tests/jax/test_fused_attn.py | 41 +++++++++++++++++++ .../jax/cpp_extensions/attention.py | 23 +++++++++-- 2 files changed, 60 insertions(+), 4 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index f12435fe5b..001fa049ce 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -337,6 +337,47 @@ def test_fused_attn_score_mod_config_splits_tensors_and_pass_by_value_scalars(): assert len(config.score_mod_scalars[0].value) == np.dtype(np.float32).itemsize +def test_fused_attn_score_mod_config_stabilizes_bound_method_cache_keys(): + softcap_score_mod = _ScoreModSoftcap() + first_forward = softcap_score_mod.forward + second_forward = softcap_score_mod.forward + first_backward = softcap_score_mod.backward + second_backward = softcap_score_mod.backward + + assert first_forward is not second_forward + assert first_backward is not second_backward + + config_1, _, _ = make_fused_attn_score_mod_config( + first_forward, + first_backward, + {"softcap": 0.8}, + {"softcap": 0.8}, + 0.125, + True, + ) + config_2, _, _ = make_fused_attn_score_mod_config( + second_forward, + second_backward, + {"softcap": 0.8}, + {"softcap": 0.8}, + 0.125, + True, + ) + other_softcap_score_mod = _ScoreModSoftcap() + config_3, _, _ = make_fused_attn_score_mod_config( + other_softcap_score_mod.forward, + other_softcap_score_mod.backward, + {"softcap": 0.8}, + {"softcap": 0.8}, + 0.125, + True, + ) + + assert config_1 == config_2 + assert hash(config_1) == hash(config_2) + assert config_1 != config_3 + + @pytest.mark.skipif(not _has_cudnn_frontend_python(), reason="cuDNN Python frontend is required") def test_fused_attn_score_mod_relative_position_optional_bprop(): _require_cudnn_frontend_score_mod() diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index f90efec2fa..6b4d9f8025 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -282,12 +282,25 @@ class _ScoreModScalarSpec: stride: Tuple[int, ...] = (1, 1, 1, 1) +def _score_mod_callback_cache_key(callback: Optional[Callable]) -> Optional[Tuple[Any, ...]]: + """Return a stable cache key for callbacks that may be bound methods.""" + if callback is None: + return None + self_obj = getattr(callback, "__self__", None) + func = getattr(callback, "__func__", None) + if self_obj is not None and func is not None: + return ("bound_method", id(self_obj), id(func)) + return ("callable", id(callback)) + + @dataclass(frozen=True) class _FusedAttnScoreModConfig: """Static configuration for cuDNN frontend score_mod SDPA graphs.""" score_mod: Callable score_mod_bprop: Optional[Callable] + score_mod_key: Tuple[Any, ...] + score_mod_bprop_key: Optional[Tuple[Any, ...]] score_mod_tensor_names: Tuple[str, ...] score_mod_bprop_tensor_names: Tuple[str, ...] score_mod_scalars: Tuple[_ScoreModScalarSpec, ...] @@ -299,8 +312,8 @@ class _FusedAttnScoreModConfig: def __hash__(self): return hash( ( - id(self.score_mod), - id(self.score_mod_bprop) if self.score_mod_bprop is not None else None, + self.score_mod_key, + self.score_mod_bprop_key, self.score_mod_tensor_names, self.score_mod_bprop_tensor_names, self.score_mod_scalars, @@ -315,8 +328,8 @@ def __eq__(self, other): if not isinstance(other, _FusedAttnScoreModConfig): return False return ( - self.score_mod is other.score_mod - and self.score_mod_bprop is other.score_mod_bprop + self.score_mod_key == other.score_mod_key + and self.score_mod_bprop_key == other.score_mod_bprop_key and self.score_mod_tensor_names == other.score_mod_tensor_names and self.score_mod_bprop_tensor_names == other.score_mod_bprop_tensor_names and self.score_mod_scalars == other.score_mod_scalars @@ -446,6 +459,8 @@ def make_fused_attn_score_mod_config( config = _FusedAttnScoreModConfig( score_mod=score_mod, score_mod_bprop=score_mod_bprop, + score_mod_key=_score_mod_callback_cache_key(score_mod), + score_mod_bprop_key=_score_mod_callback_cache_key(score_mod_bprop), score_mod_tensor_names=tensor_names, score_mod_bprop_tensor_names=bprop_tensor_names, score_mod_scalars=scalars, From 1a9635201ffaeb9463956867cc7cfd903eab620b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 May 2026 03:19:59 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_fused_attn.py | 46 +++++++------------ transformer_engine/jax/attention.py | 4 +- .../jax/cpp_extensions/attention.py | 10 ++-- .../jax/csrc/extensions/attention.cpp | 37 ++++++++------- 4 files changed, 40 insertions(+), 57 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 001fa049ce..1a7658fb8a 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -384,15 +384,9 @@ def test_fused_attn_score_mod_relative_position_optional_bprop(): key = jax.random.key(0) q_key, k_key, v_key = jax.random.split(key, 3) - q = (0.125 * jax.random.normal(q_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( - jnp.float16 - ) - k = (0.125 * jax.random.normal(k_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( - jnp.float16 - ) - v = (0.125 * jax.random.normal(v_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( - jnp.float16 - ) + q = (0.125 * jax.random.normal(q_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) + k = (0.125 * jax.random.normal(k_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) + v = (0.125 * jax.random.normal(v_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) scale = 1.0 / sqrt(q.shape[-1]) def score_mod_loss(query, key_, value): @@ -419,9 +413,9 @@ def ref_loss(query, key_, value): (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( score_mod_loss, argnums=(0, 1, 2), has_aux=True )(q, k, v) - (ref_value, ref_out), ref_grads = value_and_grad( - ref_loss, argnums=(0, 1, 2), has_aux=True - )(q, k, v) + (ref_value, ref_out), ref_grads = value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True)( + q, k, v + ) assert_allclose(score_mod_out, ref_out, rtol=5e-2, atol=5e-2) assert_allclose(score_mod_value, ref_value, rtol=5e-2, atol=5e-2) @@ -435,15 +429,9 @@ def test_fused_attn_score_mod_causal_with_bprop(): key = jax.random.key(1) q_key, k_key, v_key = jax.random.split(key, 3) - q = (0.125 * jax.random.normal(q_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( - jnp.float16 - ) - k = (0.125 * jax.random.normal(k_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( - jnp.float16 - ) - v = (0.125 * jax.random.normal(v_key, (1, 64, 2, 128), dtype=jnp.float16)).astype( - jnp.float16 - ) + q = (0.125 * jax.random.normal(q_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) + k = (0.125 * jax.random.normal(k_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) + v = (0.125 * jax.random.normal(v_key, (1, 64, 2, 128), dtype=jnp.float16)).astype(jnp.float16) scale = 1.0 / sqrt(q.shape[-1]) def score_mod_loss(query, key_, value): @@ -473,9 +461,9 @@ def ref_loss(query, key_, value): (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( score_mod_loss, argnums=(0, 1, 2), has_aux=True )(q, k, v) - (ref_value, ref_out), ref_grads = value_and_grad( - ref_loss, argnums=(0, 1, 2), has_aux=True - )(q, k, v) + (ref_value, ref_out), ref_grads = value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True)( + q, k, v + ) assert_allclose(score_mod_out, ref_out, rtol=5e-2, atol=5e-2) assert_allclose(score_mod_value, ref_value, rtol=5e-2, atol=5e-2) @@ -491,9 +479,7 @@ def test_fused_attn_score_mod_softcap_with_bprop(): q_key, k_key, v_key, d_out_key = jax.random.split(key, 4) q = jax.random.normal(q_key, (1, 16, 2, 64), dtype=jnp.float16) k = jax.random.normal(k_key, (1, 16, 2, 64), dtype=jnp.float16) - v = (0.1 * jax.random.normal(v_key, (1, 16, 2, 64), dtype=jnp.float16)).astype( - jnp.float16 - ) + v = (0.1 * jax.random.normal(v_key, (1, 16, 2, 64), dtype=jnp.float16)).astype(jnp.float16) d_out = jax.random.normal(d_out_key, (1, 16, 2, 64), dtype=jnp.float16) scale = 1.0 / sqrt(q.shape[-1]) softcap = 0.8 @@ -526,9 +512,9 @@ def ref_loss(query, key_, value): (score_mod_value, score_mod_out), score_mod_grads = value_and_grad( score_mod_loss, argnums=(0, 1, 2), has_aux=True )(q, k, v) - (ref_value, ref_out), ref_grads = value_and_grad( - ref_loss, argnums=(0, 1, 2), has_aux=True - )(q, k, v) + (ref_value, ref_out), ref_grads = value_and_grad(ref_loss, argnums=(0, 1, 2), has_aux=True)( + q, k, v + ) assert_allclose(score_mod_out, ref_out, rtol=7e-2, atol=7e-2) assert_allclose(score_mod_value, ref_value, rtol=7e-2, atol=7e-2) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index c418dce4de..adbbcd02fa 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -1441,9 +1441,7 @@ def _fused_attn_score_mod_bwd_rule(config, context_checkpoint_name, ctx, dz): ) -_fused_attn_score_mod.defvjp( - _fused_attn_score_mod_fwd_rule, _fused_attn_score_mod_bwd_rule -) +_fused_attn_score_mod.defvjp(_fused_attn_score_mod_fwd_rule, _fused_attn_score_mod_bwd_rule) def _validate_fused_attn_score_mod( diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 6b4d9f8025..fa9d9f1467 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -381,8 +381,10 @@ def _dtype_name(dtype) -> str: def _is_array_operand(value: Any) -> bool: - return hasattr(value, "shape") and hasattr(value, "dtype") and not isinstance( - value, (bool, int, float, complex, np.generic) + return ( + hasattr(value, "shape") + and hasattr(value, "dtype") + and not isinstance(value, (bool, int, float, complex, np.generic)) ) @@ -639,9 +641,7 @@ def _build_score_mod_fwd_graph(q_aval, k_aval, v_aval, score_mod_avals, config): batch, q_seqlen, q_heads, _ = q_aval.shape _, _, _, v_head_dim = v_aval.shape output_dim, output_stride = _bshd_as_bhsd_dim_stride((batch, q_seqlen, q_heads, v_head_dim)) - output.set_output(True).set_uid(_SCORE_MOD_UID_O).set_dim(output_dim).set_stride( - output_stride - ) + output.set_output(True).set_uid(_SCORE_MOD_UID_O).set_dim(output_dim).set_stride(output_stride) output.set_data_type(io_data_type) output_uids = [_SCORE_MOD_UID_O] diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 26a3e5e8c0..665ac40277 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -4,10 +4,6 @@ * See LICENSE for license information. ************************************************************************/ -#include "../extensions.h" -#include "transformer_engine/fused_attn.h" -#include "transformer_engine/transformer_engine.h" - #include #include #include @@ -15,6 +11,10 @@ #include #include +#include "../extensions.h" +#include "transformer_engine/fused_attn.h" +#include "transformer_engine/transformer_engine.h" + namespace transformer_engine { namespace jax { @@ -766,9 +766,8 @@ Error_Type ExecuteScoreModGraph(cudaStream_t stream, int64_t graph_id, const std::vector &input_ptrs, const std::vector &output_ptrs, void *workspace) { auto entry = GetScoreModGraphEntry(graph_id); - NVTE_CHECK(input_ptrs.size() == entry->input_uids.size(), - "cuDNN score_mod graph expected ", entry->input_uids.size(), " inputs but got ", - input_ptrs.size()); + NVTE_CHECK(input_ptrs.size() == entry->input_uids.size(), "cuDNN score_mod graph expected ", + entry->input_uids.size(), " inputs but got ", input_ptrs.size()); NVTE_CHECK(output_ptrs.size() >= entry->output_uids.size(), "cuDNN score_mod graph expected at least ", entry->output_uids.size(), " outputs but got ", output_ptrs.size()); @@ -788,8 +787,7 @@ Error_Type ExecuteScoreModGraph(cudaStream_t stream, int64_t graph_id, user_ptrs.reserve(entry->user_uids.size()); for (const auto uid : entry->user_uids) { auto it = variant_pack.find(uid); - NVTE_CHECK(it != variant_pack.end(), "cuDNN score_mod graph variant pack is missing UID ", - uid); + NVTE_CHECK(it != variant_pack.end(), "cuDNN score_mod graph variant pack is missing UID ", uid); user_ptrs.push_back(reinterpret_cast(it->second)); } @@ -860,7 +858,7 @@ Error_Type FusedAttnScoreModForwardFFI(cudaStream_t stream, Buffer_Type q_buf, B Result_Type workspace_buf, Dictionary attrs) { int64_t graph_id = get_attr_value(attrs, "graph_id"); std::vector input_ptrs = {q_buf.untyped_data(), k_buf.untyped_data(), - v_buf.untyped_data()}; + v_buf.untyped_data()}; AppendRemainingBuffers(score_mod_args, &input_ptrs); std::vector output_ptrs = {output_buf->untyped_data(), stats_buf->untyped_data()}; @@ -880,19 +878,20 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnScoreModForwardHandler, FusedAttnScoreMod .Ret() // workspace .Attrs()); -Error_Type FusedAttnScoreModBackwardFFI( - cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, Buffer_Type v_buf, - Buffer_Type output_buf, Buffer_Type doutput_buf, Buffer_Type stats_buf, - Variadic_Buffer_Type score_mod_args, Result_Type dq_buf, Result_Type dk_buf, - Result_Type dv_buf, Result_Type workspace_buf, Dictionary attrs) { +Error_Type FusedAttnScoreModBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, + Buffer_Type v_buf, Buffer_Type output_buf, + Buffer_Type doutput_buf, Buffer_Type stats_buf, + Variadic_Buffer_Type score_mod_args, Result_Type dq_buf, + Result_Type dk_buf, Result_Type dv_buf, + Result_Type workspace_buf, Dictionary attrs) { int64_t graph_id = get_attr_value(attrs, "graph_id"); - std::vector input_ptrs = {q_buf.untyped_data(), k_buf.untyped_data(), - v_buf.untyped_data(), output_buf.untyped_data(), - doutput_buf.untyped_data(), stats_buf.untyped_data()}; + std::vector input_ptrs = {q_buf.untyped_data(), k_buf.untyped_data(), + v_buf.untyped_data(), output_buf.untyped_data(), + doutput_buf.untyped_data(), stats_buf.untyped_data()}; AppendRemainingBuffers(score_mod_args, &input_ptrs); std::vector output_ptrs = {dq_buf->untyped_data(), dk_buf->untyped_data(), - dv_buf->untyped_data()}; + dv_buf->untyped_data()}; return ExecuteScoreModGraph(stream, graph_id, input_ptrs, output_ptrs, workspace_buf->untyped_data()); }