[JAX] Support for cuDNN-backed flex attention #2985

Open
vcherepanov-nv wants to merge 3 commits into NVIDIA:main from vcherepanov-nv:cudnn-score-mod-jax

Conversation

@vcherepanov-nv
Collaborator

Description

This PR introduces an alternative code path for the FusedAttention backend for JAX.
The user can specify score_mod and score_mod_bprop functions, which are routed to the corresponding parameters of the sdpa and sdpa_backward calls to cuDNN FE.
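
For intuition, here is a pure-Python sketch of what a score_mod / score_mod_bprop pair computes, using tanh softcapping as the example. Names and signatures are illustrative only; in this PR the callbacks receive cuDNN FE graph tensors at trace time, not Python floats.

```python
import math

CAP = 30.0  # hypothetical softcap value


def softcap_score_mod(score):
    """Forward score modification: bound the attention logit to (-CAP, CAP)."""
    return CAP * math.tanh(score / CAP)


def softcap_score_mod_bprop(score, upstream_grad):
    """Backward: chain rule through the forward modification.

    d/dscore [CAP * tanh(score / CAP)] = 1 - tanh(score / CAP) ** 2
    """
    t = math.tanh(score / CAP)
    return upstream_grad * (1.0 - t * t)
```

The pairing mirrors the sdpa / sdpa_backward split: the forward callback modifies the raw attention scores, and the backward callback supplies the corresponding local derivative.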

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • A new code path for the FusedAttention backend, taken when score_mod (and related parameters) is specified
  • Tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@greptile-apps
Contributor

greptile-apps Bot commented May 13, 2026

Greptile Summary

This PR adds a cuDNN frontend score_mod code path to the JAX FusedAttention backend, letting users pass arbitrary score modification callbacks (e.g. causal masking, relative position bias, softcapping) that are wired into cudnn.pygraph.sdpa / sdpa_backward at JAX trace time.

  • A new _fused_attn_score_mod custom_vjp function handles forward/backward dispatch; make_fused_attn_score_mod_config normalises callbacks and tensors/scalars into a hashable static config used for JAX tracing and graph caching.
  • C++ FFI handlers (FusedAttnScoreModForwardHandler / BackwardHandler) execute the pre-built cuDNN graphs via a global ScoreModGraphRegistry; the Python graph object's refcount is bumped with Py_INCREF at registration but never decremented, because ScoreModGraphEntry has no destructor — this leaks every registered graph for the process lifetime.
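
The example callbacks named above (causal masking, relative position bias, softcapping) all share the same shape: take attention scores, return modified scores. A minimal numpy sketch of a causal mask, with hypothetical names (the real callbacks manipulate cuDNN frontend graph nodes, not arrays):

```python
import numpy as np


def causal_mask_score_mod(scores):
    """Set scores above the diagonal to -inf so softmax ignores future keys.

    scores: (seq_q, seq_k) attention logits for one head.
    """
    seq_q, seq_k = scores.shape
    # Row i may attend to columns 0..i only; k=1 masks strictly above the diagonal.
    mask = np.triu(np.ones((seq_q, seq_k), dtype=bool), k=1)
    return np.where(mask, float("-inf"), scores)
```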

Confidence Score: 3/5

The new score_mod path introduces a Python reference leak in the C++ registry that will silently accumulate cuDNN graph objects for the lifetime of any long-running process; this needs to be fixed before widespread deployment.

Every call to register_fused_attn_score_mod_graph increments the Python object refcount (Py_INCREF) but ScoreModGraphEntry has no destructor and the registry never evicts entries, so cuDNN Python graph objects are never released. In a training loop sweeping many attention shapes this continuously grows GPU/CPU memory. The GIL-under-CUDA-FFI pattern also needs documentation before use in multi-threaded environments.

transformer_engine/jax/csrc/extensions/attention.cpp is the most critical — the ScoreModGraphEntry struct and RegisterFusedAttnScoreModGraph function both need attention. transformer_engine/jax/cpp_extensions/attention.py also warrants a second look for the id()-based cache key logic and unbounded graph cache growth.

Important Files Changed

  • transformer_engine/jax/csrc/extensions/attention.cpp: Adds C++ FFI handlers and a global registry for cuDNN score_mod graphs; contains a Python reference leak (Py_INCREF without Py_DECREF in ScoreModGraphEntry) and acquires the GIL during CUDA FFI execution.
  • transformer_engine/jax/cpp_extensions/attention.py: Introduces cuDNN graph building, caching, and FFI dispatch for score_mod SDPA; id()-based cache keys carry a theoretical false-hit risk after GC, and the Python/C++ graph caches grow without bound.
  • transformer_engine/jax/attention.py: Wires score_mod into fused_attn via a new custom_vjp function; validation is thorough and the early-return path correctly bypasses the legacy attention flow.
  • transformer_engine/jax/csrc/extensions.h: Adds forward/backward handler declarations and the RegisterFusedAttnScoreModGraph signature; straightforward header additions.
  • transformer_engine/jax/csrc/extensions/pybind.cpp: Registers the two new FFI handlers and exposes register_fused_attn_score_mod_graph to Python; no issues found.
  • tests/jax/test_fused_attn.py: Adds unit tests covering causal masking, relative position bias, and softcap via score_mod, plus config validation and cache-key stability tests; good coverage of the happy paths.

Sequence Diagram

sequenceDiagram
    participant User as User (Python)
    participant FA as fused_attn()
    participant Cfg as make_fused_attn_score_mod_config()
    participant VJP as _fused_attn_score_mod (custom_vjp)
    participant PyExt as cpp_extensions/attention.py
    participant CppReg as RegisterFusedAttnScoreModGraph (C++)
    participant FFI as XLA FFI (C++)
    participant CuDNN as cuDNN Frontend (Python)

    User->>FA: "score_mod=fn, score_mod_tensors={...}"
    FA->>FA: _validate_fused_attn_score_mod()
    FA->>Cfg: make_fused_attn_score_mod_config(score_mod, ...)
    Cfg-->>FA: config, tensor_operands, bprop_tensor_operands
    FA->>VJP: _fused_attn_score_mod(qkv, tensors, bprop_tensors, config, ...)

    Note over VJP,PyExt: Forward pass (JAX trace time)
    VJP->>PyExt: fused_attn_score_mod_fwd(qkv, tensors, config)
    PyExt->>PyExt: check _score_mod_graph_cache
    alt cache miss
        PyExt->>CuDNN: build cudnn.pygraph, call score_mod callback
        CuDNN-->>PyExt: compiled graph
        PyExt->>CppReg: register_fused_attn_score_mod_graph(graph, uids, ...)
        CppReg->>CppReg: Py_INCREF(py_graph) (no matching Py_DECREF)
        CppReg-->>PyExt: graph_id
        PyExt->>PyExt: "cache[key] = (graph_id, workspace_size)"
    end
    PyExt->>FFI: "ffi.ffi_call(te_fused_attn_score_mod_forward_ffi, graph_id=graph_id)"
    FFI->>CuDNN: _execute_with_ptrs() GIL acquired
    CuDNN-->>FFI: output, stats
    FFI-->>VJP: output, softmax_stats

    Note over VJP,PyExt: Backward pass
    VJP->>PyExt: fused_attn_score_mod_bwd(qkv, output, dz, stats, tensors, bprop_tensors, config)
    PyExt->>PyExt: check _score_mod_graph_cache (bwd key)
    alt cache miss
        PyExt->>CuDNN: build sdpa_backward graph, call score_mod + score_mod_bprop callbacks
        PyExt->>CppReg: register_fused_attn_score_mod_graph(...)
        CppReg-->>PyExt: graph_id
    end
    PyExt->>FFI: "ffi.ffi_call(te_fused_attn_score_mod_backward_ffi, graph_id=graph_id)"
    FFI->>CuDNN: _execute_with_ptrs() GIL acquired
    CuDNN-->>FFI: dq, dk, dv
    FFI-->>User: gradients

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +706 to +713
struct ScoreModGraphEntry {
  PyObject *py_graph = nullptr;
  std::vector<int64_t> user_uids;
  std::vector<int64_t> input_uids;
  std::vector<int64_t> output_uids;
  std::vector<int64_t> scalar_uids;
  std::vector<ScoreModScalarStorage> scalar_values;
};

P1 Python reference leak: Py_INCREF without a matching Py_DECREF

ScoreModGraphEntry stores a raw PyObject* and its refcount is bumped at registration (Py_INCREF(entry->py_graph) at line 833), but the struct has no destructor to call Py_DECREF. Because ScoreModGraphRegistry never removes entries either, every cuDNN Python graph object registered here is permanently immortalised — it will never be collected by Python's GC regardless of what the call site does. Over many different attention shapes or graph configurations this accumulates silently. The fix is to add a destructor that acquires the GIL and calls Py_DECREF, or to store a pybind11::object (which manages the refcount automatically) and ensure destruction always happens under the GIL.

Comment on lines +684 to +692
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
)

q_dim, q_stride = _bshd_as_bhsd_dim_stride(q_aval.shape)
k_dim, k_stride = _bshd_as_bhsd_dim_stride(k_aval.shape)
v_dim, v_stride = _bshd_as_bhsd_dim_stride(v_aval.shape)
o_dim, o_stride = _bshd_as_bhsd_dim_stride(output_aval.shape)
do_dim, do_stride = _bshd_as_bhsd_dim_stride(doutput_aval.shape)

P2 id()-based cache keys can produce false cache hits after GC

_score_mod_callback_cache_key builds its key from id(self_obj) and id(func). Python recycles object addresses after GC, so if a callback instance is collected and a new object (of a different class or with different graph logic) is allocated at the same address, the new config will compare equal to the old one under __eq__. JAX's nondiff-argnum caching then reuses the traced function and graph built for the original callback, silently executing the wrong cuDNN graph. The risk is low for long-lived module-level functions but real for short-lived class instances. Anchoring the key to a non-id stable identifier (e.g., a weakref plus explicit id, or requiring callers to supply an explicit stable key) would eliminate the ambiguity.
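
One possible mitigation, sketched with hypothetical names: pair the id with a weakref, so a key built from a collected callback can never compare equal to one built from a new object that happens to reuse the same address.

```python
import gc
import weakref


def stable_callback_key(func):
    """(id, weakref) pair as a cache key.

    id() alone can be recycled after GC. A weakref to a dead object only
    compares equal to itself, so a key anchored to a collected callback
    can never falsely match a key built from a new live object.
    """
    return (id(func), weakref.ref(func))


class Callback:
    def __call__(self, score):
        return score


cb = Callback()
key1 = stable_callback_key(cb)
assert key1 == stable_callback_key(cb)  # stable while the callback lives

del cb
gc.collect()
other = Callback()                      # may even reuse the old address
key2 = stable_callback_key(other)
assert key1 != key2                     # dead weakref never matches a live one
```

One caveat if such keys go into a dict: hash the weakref while the referent is alive, since hashing a weakref whose referent has died (and was never hashed before) raises TypeError.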

Comment on lines +765 to +807
Error_Type ExecuteScoreModGraph(cudaStream_t stream, int64_t graph_id,
                                const std::vector<void *> &input_ptrs,
                                const std::vector<void *> &output_ptrs, void *workspace) {
  auto entry = GetScoreModGraphEntry(graph_id);
  NVTE_CHECK(input_ptrs.size() == entry->input_uids.size(), "cuDNN score_mod graph expected ",
             entry->input_uids.size(), " inputs but got ", input_ptrs.size());
  NVTE_CHECK(output_ptrs.size() >= entry->output_uids.size(),
             "cuDNN score_mod graph expected at least ", entry->output_uids.size(),
             " outputs but got ", output_ptrs.size());

  std::unordered_map<int64_t, void *> variant_pack;
  for (size_t i = 0; i < entry->input_uids.size(); ++i) {
    variant_pack.emplace(entry->input_uids[i], input_ptrs[i]);
  }
  for (size_t i = 0; i < entry->output_uids.size(); ++i) {
    variant_pack.emplace(entry->output_uids[i], output_ptrs[i]);
  }
  for (size_t i = 0; i < entry->scalar_uids.size(); ++i) {
    variant_pack.emplace(entry->scalar_uids[i], entry->scalar_values[i].data.data());
  }

  std::vector<std::intptr_t> user_ptrs;
  user_ptrs.reserve(entry->user_uids.size());
  for (const auto uid : entry->user_uids) {
    auto it = variant_pack.find(uid);
    NVTE_CHECK(it != variant_pack.end(), "cuDNN score_mod graph variant pack is missing UID ", uid);
    user_ptrs.push_back(reinterpret_cast<std::intptr_t>(it->second));
  }

  auto handle = GetScoreModCudnnHandle();
  NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
  {
    pybind11::gil_scoped_acquire gil;
    try {
      auto graph = pybind11::reinterpret_borrow<pybind11::object>(entry->py_graph);
      graph.attr("_execute_with_ptrs")(user_ptrs, reinterpret_cast<std::intptr_t>(workspace),
                                       reinterpret_cast<std::intptr_t>(handle));
    } catch (const pybind11::error_already_set &exc) {
      NVTE_ERROR("cuDNN score_mod SDPA graph execution failed: ", exc.what());
    }
  }
  return ffi_with_cuda_error_check();
}

P2 GIL held across a CUDA FFI call boundary

ExecuteScoreModGraph acquires pybind11::gil_scoped_acquire while the CUDA stream is live and calls a Python method (_execute_with_ptrs) synchronously. Any other Python thread that holds the GIL and is waiting on CUDA work will deadlock. More broadly, acquiring the GIL inside an XLA/JAX FFI handler — which JAX may dispatch from a non-Python thread — creates a locking inversion risk. This is by-design if cuDNN's Python frontend has no C-level execution path, but the limitation should be documented and the possibility of multi-threaded JAX dispatch should be explicitly considered.

_SCORE_MOD_UID_DQ = 7
_SCORE_MOD_UID_DK = 8
_SCORE_MOD_UID_DV = 9
_SCORE_MOD_FWD_TENSOR_UID_BASE = 1000

P2 _score_mod_graph_cache and C++ registry grow without bound

_score_mod_graph_cache is a module-level dict that accumulates (graph_id, workspace_size) entries for every unique (direction, config, aval-tuple) seen during tracing, and the C++ ScoreModGraphRegistry holds the corresponding cuDNN graph objects forever. Each entry keeps a Python cuDNN graph alive (and, due to the missing Py_DECREF noted separately, prevents GC). In long-running services or evaluation loops that sweep over many shapes/dtypes, this leads to unbounded cuDNN graph memory accumulation. An LRU eviction strategy or an explicit graph-release API paired with cache invalidation would contain the growth.
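
A bounded replacement for the module-level dict could look like the following sketch (hypothetical names; real eviction would also need an explicit release call into the C++ registry so the evicted cuDNN graph is actually freed):

```python
from collections import OrderedDict


class BoundedGraphCache:
    """LRU cache for (graph_id, workspace_size) entries.

    On eviction, a release callback gives the owner a chance to free the
    underlying cuDNN graph (e.g. an explicit C++ unregister call).
    """

    def __init__(self, maxsize, release=None):
        self._entries = OrderedDict()
        self._maxsize = maxsize
        self._release = release

    def get(self, key):
        if key not in self._entries:
            return None
        self._entries.move_to_end(key)  # mark as most recently used
        return self._entries[key]

    def put(self, key, value):
        if key in self._entries:
            self._entries.move_to_end(key)
        self._entries[key] = value
        while len(self._entries) > self._maxsize:
            _, (graph_id, _) = self._entries.popitem(last=False)
            if self._release is not None:
                self._release(graph_id)  # free the evicted cuDNN graph


released = []
cache = BoundedGraphCache(maxsize=2, release=released.append)
cache.put("a", (1, 1024))
cache.put("b", (2, 2048))
cache.get("a")             # "a" becomes most recently used
cache.put("c", (3, 4096))  # evicts "b", the least recently used entry
```

The same maxsize/release pair would also cap the C++ ScoreModGraphRegistry, since every Python-side eviction triggers the matching registry release.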
