[PyTorch] Support for cuDNN-backed flex attention#2984
Conversation
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Greptile Summary

This PR adds a Python-only code path for the FusedAttention backend that routes user-provided score_mod and score_mod_bprop callables to the cuDNN frontend's sdpa and sdpa_backward calls.
Confidence Score: 3/5

Safe to review but not to merge without addressing the callback cache-key correctness issue. The new FusedAttentionWithScoreModFunc caches compiled cuDNN graphs under a key derived from id() of the score_mod callable. For bound-method score_mods whose instances vary graph topology, a garbage-collected instance can be reallocated at the same address as a new instance of the same class, producing an identical cache key and silently returning a graph built for a different computation. No exception is raised; the wrong attention output is returned and back-propagated. See transformer_engine/pytorch/attention/dot_product_attention/backends.py, specifically _score_mod_callback_cache_key and the module-level _cudnn_score_mod_graph_cache.
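The hazard can be reproduced in plain Python, independent of cuDNN. The class below is hypothetical, and whether the freed address is actually reused depends on the CPython allocator, so the collision is likely in practice but not guaranteed:

```python
# Hypothetical reproduction of the cache-key hazard (no cuDNN involved).
class ScoreMod:
    def __init__(self, n_layers):
        self.n_layers = n_layers   # different values would build different graph topologies

    def apply(self, graph, score, *mod_tensors):
        return score               # placeholder body

a = ScoreMod(n_layers=2)
key_a = ("bound_method", id(a), id(a.apply.__func__))
del a                              # instance is garbage-collected, its address may be recycled

b = ScoreMod(n_layers=8)           # CPython may place b at the freed address
key_b = ("bound_method", id(b), id(b.apply.__func__))

print(key_a == key_b)              # True whenever the address was reused -> stale graph returned
```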
Sequence Diagram

sequenceDiagram
participant User
participant DPA as DotProductAttention.forward
participant FA as FusedAttention.forward
participant Func as FusedAttentionWithScoreModFunc
participant Cache as _cudnn_score_mod_graph_cache
participant cuDNN as cuDNN Python Frontend
User->>DPA: "forward(q, k, v, score_mod=fn, ...)"
DPA->>DPA: validate score_mod constraints
DPA->>DPA: get_fused_attn_backend (availability check)
DPA->>FA: "forward(..., score_mod=fn)"
FA->>FA: assert no FP8 / CP / dropout / masks
FA->>Func: apply(is_training, q, k, v, fmt, scale, score_mod, ...)
Func->>Func: allocate output_layer, stats_bhs1
Func->>Cache: _get_cudnn_score_mod_fwd_graph(key)
alt cache miss
Cache->>cuDNN: pygraph(dtype, device, handle)
cuDNN->>Func: call score_mod(graph, score_tensor, tensors)
cuDNN->>cuDNN: "sdpa(..., score_mod=wrapped)"
cuDNN->>cuDNN: validate / build_operation_graph / build_plans
Cache-->>Func: _CudnnScoreModFwdGraphEntry
else cache hit
Cache-->>Func: _CudnnScoreModFwdGraphEntry
end
Func->>cuDNN: graph.execute(variant_pack, workspace, handle)
cuDNN-->>Func: output_layer filled
Func->>Func: "ctx.save_for_backward(q, k, v, out, stats, *mod_tensors)"
Func-->>User: output_layer
User->>Func: backward(d_out)
Func->>Cache: _get_cudnn_score_mod_bwd_graph(key)
alt cache miss
Cache->>cuDNN: pygraph(dtype, device, handle)
cuDNN->>Func: call score_mod + score_mod_bprop
cuDNN->>cuDNN: sdpa_backward(...)
Cache-->>Func: _CudnnScoreModBwdGraphEntry
else cache hit
Cache-->>Func: _CudnnScoreModBwdGraphEntry
end
Func->>cuDNN: graph.execute(variant_pack, workspace, handle)
cuDNN-->>Func: dq, dk, dv filled
Func-->>User: (None, dq, dk, dv, None x8)
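The flow above is a standard torch.autograd.Function shape: forward executes a (cached) cuDNN graph and stashes what backward needs, and backward executes the backward graph and returns gradients only for q, k, v. The sketch below mirrors that structure with plain PyTorch ops standing in for the cuDNN graph execution, so it is runnable but intentionally not the PR's implementation; the names and the simplified score_mod signature (a tensor-in/tensor-out callable) are illustrative only.

```python
import torch

class ScoreModAttnSketch(torch.autograd.Function):
    """Minimal skeleton mirroring the diagram; plain PyTorch stands in for the cuDNN graph."""

    @staticmethod
    def forward(ctx, is_training, q, k, v, scale, score_mod):
        # q, k, v in bshd layout; stand-in for graph.execute(variant_pack, ...)
        scores = torch.einsum("bshd,bThd->bhsT", q, k) * scale
        scores = score_mod(scores)                           # user-supplied score modification
        probs = torch.softmax(scores.float(), dim=-1).to(q.dtype)
        out = torch.einsum("bhsT,bThd->bshd", probs, v)
        ctx.save_for_backward(q, k, v, probs)                # analogue of saving out/stats
        ctx.scale, ctx.score_mod = scale, score_mod
        return out

    @staticmethod
    def backward(ctx, d_out):
        q, k, v, probs = ctx.saved_tensors
        dv = torch.einsum("bhsT,bshd->bThd", probs, d_out)
        d_probs = torch.einsum("bshd,bThd->bhsT", d_out, v)
        # softmax backward; a real implementation would also route d_scores
        # through score_mod_bprop here
        d_scores = probs * (d_probs - (d_probs * probs).sum(-1, keepdim=True))
        dq = torch.einsum("bhsT,bThd->bshd", d_scores, k) * ctx.scale
        dk = torch.einsum("bhsT,bshd->bThd", d_scores, q) * ctx.scale
        # gradients for: is_training, q, k, v, scale, score_mod
        return None, dq, dk, dv, None, None
```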
def _score_mod_callback_cache_key(callback: Optional[Callable]) -> Optional[Tuple[Any, ...]]:
    """Create a stable cache key for a score_mod callable."""
    if callback is None:
        return None
    self_obj = getattr(callback, "__self__", None)
    func_obj = getattr(callback, "__func__", None)
    if self_obj is not None and func_obj is not None:
        return ("bound_method", id(self_obj), id(func_obj))
    return ("callable", id(callback))
id()-based cache key is unsafe for parameterized bound-method score_mods
id(self_obj) identifies a Python object by its memory address. When a bound-method instance is garbage-collected, Python may immediately reuse that memory for a new instance. If the new instance belongs to the same class (same id(func_obj)), the cache key is identical, so _get_cudnn_score_mod_fwd_graph returns the old compiled graph even though the new instance might construct a structurally different computation — e.g., a score_mod class whose forward loops self.n_layers times. The wrong graph is executed without any error, silently producing incorrect attention outputs.
For stateless module-level functions this is fine (they're never GC'd), but any stateful class-based score_mod where different instances produce different graph topologies can hit this bug in long-running programs. Consider using type(self_obj) and a per-class sequence counter, or requiring callers to provide an explicit cache key.
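One possible shape for the first suggestion, as a minimal sketch: assign each bound-method instance a process-unique serial number instead of keying on its memory address. It assumes score_mod instances are weak-referenceable (true for ordinary Python classes without restrictive __slots__); the plain-callable branch is left as in the PR, and all names besides the PR's _score_mod_callback_cache_key are illustrative.

```python
import itertools
import weakref
from typing import Any, Callable, Optional, Tuple

# Tag each score_mod instance with a serial number that is never reused,
# instead of keying on its (recyclable) memory address.
_instance_serials: "weakref.WeakKeyDictionary[Any, int]" = weakref.WeakKeyDictionary()
_next_serial = itertools.count()

def _score_mod_callback_cache_key(callback: Optional[Callable]) -> Optional[Tuple[Any, ...]]:
    """Cache key that stays unique across garbage collection and address reuse."""
    if callback is None:
        return None
    self_obj = getattr(callback, "__self__", None)
    func_obj = getattr(callback, "__func__", None)
    if self_obj is not None and func_obj is not None:
        serial = _instance_serials.get(self_obj)
        if serial is None:
            serial = next(_next_serial)
            _instance_serials[self_obj] = serial  # assumes self_obj supports weak references
        return ("bound_method", type(self_obj).__qualname__, func_obj.__qualname__, serial)
    # Unchanged: module-level functions are effectively never garbage-collected.
    return ("callable", id(callback))
```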
_flash_attn_varlen_fwd = None
_flash_attn_varlen_bwd = None
Unbounded module-level graph cache will grow indefinitely
_cudnn_score_mod_graph_cache is a plain dict with no eviction policy. Cache keys encode tensor shapes, strides, dtype, and device, so every new (batch, seq, heads, dim) combination — extremely common in training with variable-length sequences or multi-task workloads — inserts a permanent entry. Each cached cuDNN graph holds compiled CUDA kernels and associated state, which can be several tens of MB. Over a long training run this will silently consume increasing GPU/CPU memory. Consider a bounded LRU cache (e.g., functools.lru_cache or a collections.OrderedDict with a size cap).
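A minimal sketch of the suggested bound cache, assuming the cache stays at module level; the 64-entry cap and the _cache_get_or_build helper name are illustrative, not part of the PR.

```python
from collections import OrderedDict

_CUDNN_GRAPH_CACHE_MAX_ENTRIES = 64  # illustrative cap
_cudnn_score_mod_graph_cache: "OrderedDict[tuple, object]" = OrderedDict()

def _cache_get_or_build(key, build_fn):
    """Return the cached graph for ``key``, evicting the least recently used entry when full."""
    entry = _cudnn_score_mod_graph_cache.get(key)
    if entry is not None:
        _cudnn_score_mod_graph_cache.move_to_end(key)     # mark as most recently used
        return entry
    entry = build_fn()                                     # compile the cuDNN graph on miss
    _cudnn_score_mod_graph_cache[key] = entry
    if len(_cudnn_score_mod_graph_cache) > _CUDNN_GRAPH_CACHE_MAX_ENTRIES:
        _cudnn_score_mod_graph_cache.popitem(last=False)   # evict the oldest entry
    return entry
```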
fused_attention_backend = tex.get_fused_attn_backend(
    self.training,
    q_type,
    q_type,
    dpa_utils.QKVLayout["bshd_bshd_bshd"],
    dpa_utils.AttnBiasType["no_bias"],
    dpa_utils.AttnMaskType["no_mask"],
    dpa_utils.SoftmaxType["vanilla"],
get_fused_attn_backend availability check always uses bshd_bshd_bshd regardless of actual format
The score_mod path hard-codes dpa_utils.QKVLayout["bshd_bshd_bshd"] for the backend probe, even when the user passes qkv_format="sbhd". The result is only used to gate on NVTE_No_Backend, so in practice it likely works today because backend availability for a given dtype is layout-independent. However, if a future cuDNN version makes SBHD/BSHD support diverge, this probe would give a false-positive (accepts sbhd even though no backend supports it) or false-negative (rejects sbhd when it is actually supported). Using the real layout for the probe would make the check self-documenting and future-proof.
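A small sketch of the suggestion, assuming the user-supplied format string ("bshd" or "sbhd") is available as qkv_format at the probe site and that dpa_utils.QKVLayout defines the matching self-attention layout names:

```python
# Derive the probe layout from the actual qkv_format instead of hard-coding
# "bshd_bshd_bshd"; the remaining get_fused_attn_backend arguments stay as in
# the snippet above.
probe_layout = dpa_utils.QKVLayout["_".join([qkv_format] * 3)]  # e.g. "sbhd_sbhd_sbhd"
```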
Description
This PR introduces an alternative, Python-only code path for the FusedAttention backend for PyTorch.
The user can specify score_mod and score_mod_bprop functions, which get routed to the corresponding parameters of the sdpa and sdpa_backward calls to cuDNN FE.
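A hedged usage sketch inferred from this description and the sequence diagram above; the exact score_mod signature (cuDNN graph, score tensor, extra tensors) and the forward keyword name are assumptions taken from the diagram, so the PR's tests remain the authoritative reference.

```python
import torch
from transformer_engine.pytorch import DotProductAttention

def my_score_mod(graph, score, *mod_tensors):
    # Receives the cuDNN frontend graph being built plus the raw attention-score
    # tensor; build the modified score tensor from the graph's pointwise ops and
    # return it. Shown as a pass-through placeholder here.
    return score

dpa = DotProductAttention(num_attention_heads=16, kv_channels=64)

# bshd layout: (batch, seq, heads, head_dim)
q = torch.randn(2, 128, 16, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn(2, 128, 16, 64, dtype=torch.bfloat16, device="cuda")
v = torch.randn(2, 128, 16, 64, dtype=torch.bfloat16, device="cuda")

# score_mod (and, for training, score_mod_bprop) is routed to cuDNN FE's
# sdpa / sdpa_backward; FP8, context parallelism, dropout, and masks are
# rejected on this path per the assertions in FusedAttention.forward.
out = dpa(q, k, v, qkv_format="bshd", attn_mask_type="no_mask", score_mod=my_score_mod)
```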
Fixes # (issue)
Type of change
Changes
Checklist: