[JAX] Support for cuDNN-backed flex attention #2985
Conversation
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Greptile Summary

This PR adds a cuDNN frontend-backed score_mod path to the JAX FusedAttention backend.
Confidence Score: 3/5

The new score_mod path introduces a Python reference leak in the C++ registry that will silently accumulate cuDNN graph objects for the lifetime of any long-running process; this needs to be fixed before widespread deployment. Every call to register_fused_attn_score_mod_graph bumps the Python refcount (Py_INCREF), but ScoreModGraphEntry has no destructor and the registry never evicts entries, so cuDNN Python graph objects are never released. In a training loop sweeping many attention shapes, this continuously grows GPU/CPU memory. The GIL-under-CUDA-FFI pattern also needs documentation before use in multi-threaded environments.

Important Files Changed: transformer_engine/jax/csrc/extensions/attention.cpp is the most critical; the ScoreModGraphEntry struct and the RegisterFusedAttnScoreModGraph function both need attention. transformer_engine/jax/cpp_extensions/attention.py also warrants a second look for the id()-based cache-key logic and unbounded graph-cache growth.
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User (Python)
    participant FA as fused_attn()
    participant Cfg as make_fused_attn_score_mod_config()
    participant VJP as _fused_attn_score_mod (custom_vjp)
    participant PyExt as cpp_extensions/attention.py
    participant CppReg as RegisterFusedAttnScoreModGraph (C++)
    participant FFI as XLA FFI (C++)
    participant CuDNN as cuDNN Frontend (Python)
    User->>FA: score_mod=fn, score_mod_tensors={...}
    FA->>FA: _validate_fused_attn_score_mod()
    FA->>Cfg: make_fused_attn_score_mod_config(score_mod, ...)
    Cfg-->>FA: config, tensor_operands, bprop_tensor_operands
    FA->>VJP: _fused_attn_score_mod(qkv, tensors, bprop_tensors, config, ...)
    Note over VJP,PyExt: Forward pass (JAX trace time)
    VJP->>PyExt: fused_attn_score_mod_fwd(qkv, tensors, config)
    PyExt->>PyExt: check _score_mod_graph_cache
    alt cache miss
        PyExt->>CuDNN: build cudnn.pygraph, call score_mod callback
        CuDNN-->>PyExt: compiled graph
        PyExt->>CppReg: register_fused_attn_score_mod_graph(graph, uids, ...)
        CppReg->>CppReg: Py_INCREF(py_graph), no DECREF
        CppReg-->>PyExt: graph_id
        PyExt->>PyExt: cache[key] = (graph_id, workspace_size)
    end
    PyExt->>FFI: ffi.ffi_call(te_fused_attn_score_mod_forward_ffi, graph_id=graph_id)
    FFI->>CuDNN: _execute_with_ptrs() (GIL acquired)
    CuDNN-->>FFI: output, stats
    FFI-->>VJP: output, softmax_stats
    Note over VJP,PyExt: Backward pass
    VJP->>PyExt: fused_attn_score_mod_bwd(qkv, output, dz, stats, tensors, bprop_tensors, config)
    PyExt->>PyExt: check _score_mod_graph_cache (bwd key)
    alt cache miss
        PyExt->>CuDNN: build sdpa_backward graph, call score_mod + score_mod_bprop callbacks
        PyExt->>CppReg: register_fused_attn_score_mod_graph(...)
        CppReg-->>PyExt: graph_id
    end
    PyExt->>FFI: ffi.ffi_call(te_fused_attn_score_mod_backward_ffi, graph_id=graph_id)
    FFI->>CuDNN: _execute_with_ptrs() (GIL acquired)
    CuDNN-->>FFI: dq, dk, dv
    FFI-->>User: gradients
```
Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
```cpp
struct ScoreModGraphEntry {
  PyObject *py_graph = nullptr;
  std::vector<int64_t> user_uids;
  std::vector<int64_t> input_uids;
  std::vector<int64_t> output_uids;
  std::vector<int64_t> scalar_uids;
  std::vector<ScoreModScalarStorage> scalar_values;
};
```
Python reference leak: Py_INCREF without a matching Py_DECREF
ScoreModGraphEntry stores a raw PyObject* and its refcount is bumped at registration (Py_INCREF(entry->py_graph) at line 833), but the struct has no destructor to call Py_DECREF. Because ScoreModGraphRegistry never removes entries either, every cuDNN Python graph object registered here is permanently immortalised — it will never be collected by Python's GC regardless of what the call site does. Over many different attention shapes or graph configurations this accumulates silently. The fix is to add a destructor that acquires the GIL and calls Py_DECREF, or to store a pybind11::object (which manages the refcount automatically) and ensure destruction always happens under the GIL.
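The immortalisation can be demonstrated in pure Python via the ctypes bindings for the CPython C API (a standalone sketch; `Graph` is a stand-in for a cuDNN Python graph object, not the real class):

```python
import ctypes
import gc
import weakref

ctypes.pythonapi.Py_IncRef.restype = None
ctypes.pythonapi.Py_DecRef.restype = None


class Graph:
    """Stand-in for a cuDNN Python graph object."""


g = Graph()
alive = weakref.ref(g)

# What the registry does today: bump the refcount with no matching decref.
ctypes.pythonapi.Py_IncRef(ctypes.py_object(g))

del g
gc.collect()
assert alive() is not None  # leaked: the extra reference immortalises it

# The fix: a matching Py_DECREF (e.g. from an entry destructor that holds
# the GIL, or implicitly from a pybind11::object member) releases the graph.
ctypes.pythonapi.Py_DecRef(ctypes.py_object(alive()))
gc.collect()
assert alive() is None  # collectable again once the refcount is balanced
```

The same balance is what a `ScoreModGraphEntry` destructor (acquiring the GIL before `Py_DECREF`) or a `pybind11::object` member would restore on the C++ side.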
```python
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
)

q_dim, q_stride = _bshd_as_bhsd_dim_stride(q_aval.shape)
k_dim, k_stride = _bshd_as_bhsd_dim_stride(k_aval.shape)
v_dim, v_stride = _bshd_as_bhsd_dim_stride(v_aval.shape)
o_dim, o_stride = _bshd_as_bhsd_dim_stride(output_aval.shape)
do_dim, do_stride = _bshd_as_bhsd_dim_stride(doutput_aval.shape)
```
id()-based cache keys can produce false cache hits after GC
_score_mod_callback_cache_key builds its key from id(self_obj) and id(func). Python recycles object addresses after GC, so if a callback instance is collected and a new object (of a different class or with different graph logic) is allocated at the same address, the new config will compare equal to the old one under __eq__. JAX's nondiff-argnum caching then reuses the traced function and graph built for the original callback, silently executing the wrong cuDNN graph. The risk is low for long-lived module-level functions but real for short-lived class instances. Anchoring the key to a non-id stable identifier (e.g., a weakref plus explicit id, or requiring callers to supply an explicit stable key) would eliminate the ambiguity.
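A sketch of one such stable key (a hypothetical `CallbackKey` wrapper, not code from this PR): pair the id with a weakref so that a collected callback can never compare equal to a newcomer that recycles its address:

```python
import weakref


class CallbackKey:
    """Hypothetical cache key anchoring a callback by weakref, not just id()."""

    def __init__(self, func):
        self._id = id(func)
        self._ref = weakref.ref(func)

    def __hash__(self):
        return self._id

    def __eq__(self, other):
        if not isinstance(other, CallbackKey):
            return NotImplemented
        a, b = self._ref(), other._ref()
        # A dead referent never compares equal, so a new object that reuses
        # the old address cannot hit the stale cache entry.
        return a is not None and a is b
```

A config's `__eq__`/`__hash__` would then delegate to such a key instead of comparing raw `id()` values.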
```cpp
Error_Type ExecuteScoreModGraph(cudaStream_t stream, int64_t graph_id,
                                const std::vector<void *> &input_ptrs,
                                const std::vector<void *> &output_ptrs, void *workspace) {
  auto entry = GetScoreModGraphEntry(graph_id);
  NVTE_CHECK(input_ptrs.size() == entry->input_uids.size(), "cuDNN score_mod graph expected ",
             entry->input_uids.size(), " inputs but got ", input_ptrs.size());
  NVTE_CHECK(output_ptrs.size() >= entry->output_uids.size(),
             "cuDNN score_mod graph expected at least ", entry->output_uids.size(),
             " outputs but got ", output_ptrs.size());

  std::unordered_map<int64_t, void *> variant_pack;
  for (size_t i = 0; i < entry->input_uids.size(); ++i) {
    variant_pack.emplace(entry->input_uids[i], input_ptrs[i]);
  }
  for (size_t i = 0; i < entry->output_uids.size(); ++i) {
    variant_pack.emplace(entry->output_uids[i], output_ptrs[i]);
  }
  for (size_t i = 0; i < entry->scalar_uids.size(); ++i) {
    variant_pack.emplace(entry->scalar_uids[i], entry->scalar_values[i].data.data());
  }

  std::vector<std::intptr_t> user_ptrs;
  user_ptrs.reserve(entry->user_uids.size());
  for (const auto uid : entry->user_uids) {
    auto it = variant_pack.find(uid);
    NVTE_CHECK(it != variant_pack.end(), "cuDNN score_mod graph variant pack is missing UID ", uid);
    user_ptrs.push_back(reinterpret_cast<std::intptr_t>(it->second));
  }

  auto handle = GetScoreModCudnnHandle();
  NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
  {
    pybind11::gil_scoped_acquire gil;
    try {
      auto graph = pybind11::reinterpret_borrow<pybind11::object>(entry->py_graph);
      graph.attr("_execute_with_ptrs")(user_ptrs, reinterpret_cast<std::intptr_t>(workspace),
                                       reinterpret_cast<std::intptr_t>(handle));
    } catch (const pybind11::error_already_set &exc) {
      NVTE_ERROR("cuDNN score_mod SDPA graph execution failed: ", exc.what());
    }
  }
  return ffi_with_cuda_error_check();
}
```
GIL held across a CUDA FFI call boundary
ExecuteScoreModGraph acquires pybind11::gil_scoped_acquire while the CUDA stream is live and calls a Python method (_execute_with_ptrs) synchronously. Any other Python thread that holds the GIL and is waiting on CUDA work will deadlock. More broadly, acquiring the GIL inside an XLA/JAX FFI handler — which JAX may dispatch from a non-Python thread — creates a locking inversion risk. This is by-design if cuDNN's Python frontend has no C-level execution path, but the limitation should be documented and the possibility of multi-threaded JAX dispatch should be explicitly considered.
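The inversion is easiest to see with two plain locks standing in for the GIL and a blocking CUDA-side wait (a schematic sketch, not the actual FFI handler):

```python
import threading

gil = threading.Lock()   # stand-in for the Python GIL
cuda = threading.Lock()  # stand-in for a blocking CUDA-side wait


def risky_execute():
    # Mirrors the pattern above: take the GIL-analog, then block on CUDA
    # work while still holding it. Deadlocks if another thread acquires
    # the same two locks in the opposite order.
    with gil:
        with cuda:
            pass


def safer_execute():
    # Finish the blocking CUDA wait first, then take the GIL-analog only
    # for the short Python call; neither lock is held across the other.
    with cuda:
        pass
    with gil:
        pass
```

Where the Python callback genuinely must run under the GIL, documenting a global lock order (never block on device work while holding the GIL) is the minimum mitigation.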
```python
_SCORE_MOD_UID_DQ = 7
_SCORE_MOD_UID_DK = 8
_SCORE_MOD_UID_DV = 9
_SCORE_MOD_FWD_TENSOR_UID_BASE = 1000
```
_score_mod_graph_cache and C++ registry grow without bound
_score_mod_graph_cache is a module-level dict that accumulates (graph_id, workspace_size) entries for every unique (direction, config, aval-tuple) seen during tracing, and the C++ ScoreModGraphRegistry holds the corresponding cuDNN graph objects forever. Each entry keeps a Python cuDNN graph alive (and, due to the missing Py_DECREF noted separately, prevents GC). In long-running services or evaluation loops that sweep over many shapes/dtypes, this leads to unbounded cuDNN graph memory accumulation. An LRU eviction strategy or an explicit graph-release API paired with cache invalidation would contain the growth.
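A minimal sketch of the suggested containment (hypothetical `GraphCache`; the `release` callback stands in for a graph-release API that the C++ registry does not currently expose):

```python
from collections import OrderedDict


class GraphCache:
    """Hypothetical LRU cache for (graph_id, workspace_size) entries."""

    def __init__(self, capacity, release):
        self._entries = OrderedDict()
        self._capacity = capacity
        self._release = release  # would tell the C++ registry to drop its reference

    def get(self, key):
        value = self._entries.get(key)
        if value is not None:
            self._entries.move_to_end(key)  # mark as recently used
        return value

    def put(self, key, graph_id, workspace_size):
        self._entries[key] = (graph_id, workspace_size)
        self._entries.move_to_end(key)
        while len(self._entries) > self._capacity:
            # Evict the least-recently-used entry and free its graph.
            _, (old_id, _) = self._entries.popitem(last=False)
            self._release(old_id)
```

Releasing on eviction is what closes the loop: without it, dropping the Python-side entry still leaves the C++ registry holding the graph forever.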
Description
This PR introduces an alternative code path for the FusedAttention backend for JAX.
The user can specify score_mod and score_mod_bprop functions, which get routed to the corresponding parameters of the sdpa and sdpa_backward calls to cuDNN FE.
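For orientation, a score_mod callback rewrites the pre-softmax attention logits elementwise. A plain-Python sketch of the semantics (names and signature are illustrative, not the PR's actual API, which operates on cuDNN frontend graph tensors):

```python
import math


def score_mod(score, q_idx, kv_idx, bias):
    # Example modification: add a bias and apply a causal mask.
    if q_idx >= kv_idx:
        return score + bias
    return -math.inf


def attention_with_score_mod(q, k, v, bias=0.0):
    """Reference semantics: softmax(score_mod(q @ k^T / sqrt(d))) @ v."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        # Pre-softmax logits, each rewritten elementwise by score_mod.
        logits = [score_mod(scale * sum(a * b for a, b in zip(qi, kj)), i, j, bias)
                  for j, kj in enumerate(k)]
        m = max(logits)
        weights = [math.exp(s - m) for s in logits]
        z = sum(weights)
        probs = [w / z for w in weights]
        out.append([sum(p * vj[d] for p, vj in zip(probs, v))
                    for d in range(len(v[0]))])
    return out
```

With the causal mask above, query 0 can attend only to key 0, so the first output row equals `v[0]`; the cuDNN path expresses the same per-element rewrite as graph nodes instead of a Python loop.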