diff --git a/ci/jax.sh b/ci/jax.sh index 4804ecff3..0cb348ddb 100755 --- a/ci/jax.sh +++ b/ci/jax.sh @@ -57,8 +57,9 @@ run_test_config() { export NVTE_JAX_UNITTEST_LEVEL=L0 # this env variable controls parameters set for some tests run_default_fa 1 test_custom_call_compute.py run_default_fa 1 test_functions.py - run 1 test_fused_attn.py - NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py # Using FAv2 for forward and backward pass + run 1 test_fused_attn.py -k 'not test_fused_attn_small_seq_explicit_api' # skip smallseq in normal flow + XLA_FLAGS='--xla_gpu_enable_command_buffer=' run 1 test_fused_attn.py -k 'test_fused_attn_small_seq_explicit_api' # explicit small-seq API; + NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py -k 'not test_fused_attn_small_seq_explicit_api' # Using FAv2 for forward and backward pass run_default_fa 1 test_layer.py # it effectively always uses unfused attention run_default_fa 1 test_sanity_import.py run_default_fa 1 test_softmax.py diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 8639af79b..83d2c0006 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -9,6 +9,7 @@ from functools import partial from math import sqrt from typing import Tuple, Optional, Dict +import os import random import jax @@ -24,7 +25,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec from jax.typing import ArrayLike, DTypeLike -from transformer_engine.jax.cpp_extensions.misc import is_hip_extension +from transformer_engine.jax.cpp_extensions.misc import get_xla_flag, is_hip_extension from transformer_engine.jax import autocast from transformer_engine.jax.sharding import MeshResource from transformer_engine.jax.attention import ( @@ -36,6 +37,7 @@ reorder_causal_load_balancing, inverse_reorder_causal_load_balancing, fused_attn, + fused_attn_small_seq, run_length_fill, make_swa_mask, SequenceDescriptor, @@ -314,6 +316,34 @@ def customcall_fused_dpa( qkv_args, bias, sequence_descriptor, dropout_rng, softmax_offset=softmax_offset, **kwargs ).astype(query.dtype) +def customcall_small_seq_dpa( + query, + key, + value, + bias, + sequence_descriptor, + dropout_rng, + **kwargs, +): + """TE ROCm small-seq explicit API (separate Q, K, V only).""" + small_kwargs = { + "attn_bias_type": kwargs["attn_bias_type"], + "attn_mask_type": kwargs["attn_mask_type"], + "scaling_factor": kwargs["scaling_factor"], + "dropout_probability": kwargs["dropout_probability"], + "is_training": kwargs["is_training"], + "qkv_layout": kwargs["qkv_layout"], + "max_segments_per_seq": kwargs["max_segments_per_seq"], + "window_size": kwargs.get("window_size"), + } + return fused_attn_small_seq( + (query, key, value), + bias, + sequence_descriptor, + dropout_rng, + **small_kwargs, + ).astype(query.dtype) + class BiasShape(Enum): """ @@ -369,6 +399,9 @@ class FusedAttnRunner: cp_strategy: CPStrategy = CPStrategy.DEFAULT cp_load_balanced: bool = True + # THD segment layout for ROCm small-seq explicit API tests + use_small_seq_thd_setup: bool = False + # dictionary of expected collective comm bytes coll_count_ref: Optional[Dict[str, int]] = None @@ -384,6 +417,10 @@ def __post_init__(self): # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases. 
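For reference, `customcall_small_seq_dpa` above deliberately forwards only the keyword arguments that the ROCm small-seq path consumes; anything else the generic tests pass (e.g. `softmax_type` or the context-parallel options) never reaches `fused_attn_small_seq`. A minimal sketch of how the wrapper is driven, assuming `q`, `k`, `v`, `bias`, `seq_desc` (a `SequenceDescriptor`) and `dropout_rng` have been prepared the way `FusedAttnRunner` prepares them; the kwargs values below are illustrative only:

    from math import sqrt

    kwargs = {
        "attn_bias_type": AttnBiasType.NO_BIAS,
        "attn_mask_type": AttnMaskType.PADDING_MASK,
        "scaling_factor": 1.0 / sqrt(128),
        "dropout_probability": 0.0,
        "is_training": True,
        "qkv_layout": QKVLayout.THD_THD_THD,
        "max_segments_per_seq": 1,
        "window_size": None,
        # accepted by the generic customcall_fused_dpa but ignored by the small-seq wrapper
        "softmax_type": AttnSoftmaxType.VANILLA_SOFTMAX,
    }
    out = customcall_small_seq_dpa(q, k, v, bias, seq_desc, dropout_rng, **kwargs)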
def _get_max_segments_per_sequence(self): if self.qkv_layout.is_thd(): + # Small-seq explicit API tests use fixed segment counts; skip +1 slack used + # by generic fused THD to probe runtime_segments < max_segments on cuDNN. + if self.use_small_seq_thd_setup: + return self.num_segments_per_seq if 90400 <= get_cudnn_version() < 90500: return self.num_segments_per_seq else: @@ -497,8 +534,78 @@ def _check_configs(self): "the F16_arbitrary_seqlen backend." ) + def _setup_thd_segments_small_seq(self, generate_random_segment_ids): + """ + Build THD segment descriptors for ROCm small-seq cross-attention tests. + + Uses num_segments_per_seq = max_seqlen_q for both Q and KV. For Q: if max_seqlen_q == 1, + uses a fixed layout (one token per batch, cu_seqlens [0,1,...,batch_size]); otherwise + generates random segments. For KV: always generates random segments. + """ + num_segments_per_seq = self.max_seqlen_q + if self.max_seqlen_q == 1: + # Q: deterministic - one segment of length 1 per batch -> cu_seqlen [0,1,2,...,batch_size] + segment_ids_q = jnp.ones((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) + segment_pos_q = jnp.zeros((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) + pad_q = jnp.zeros((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) + seqlens_q = jnp.ones((self.batch_size, 1), dtype=jnp.int32) + offsets_q = jnp.concatenate( + [ + jnp.arange(self.batch_size, dtype=jnp.int32)[:, None], + jnp.full((self.batch_size, 1), -1, dtype=jnp.int32), + ], + axis=1, + ) + else: + segment_ids_q, segment_pos_q, pad_q = generate_random_segment_ids( + self.batch_size, self.max_seqlen_q, num_segments_per_seq, seed=42 + ) + # Compute seqlens/offsets directly instead of using get_seqlens_and_offsets. + # get_seqlens_and_offsets uses bincount(length=max_seqlen) which cannot capture + # segment IDs equal to max_seqlen (when num_segments == max_seqlen_q, segment + # IDs range from 1 to max_seqlen_q). The missing segment plus the appended + # sentinel causes _fix_len_take in impl() to leak entries across batches. + # Since each Q segment has exactly 1 token (max_segment_size = max_seqlen_q // + # num_segments_per_seq = 1), we build seqlens as all-ones with no sentinels. 
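The bincount limitation called out in the comment above is easy to reproduce in isolation; a toy sketch (illustrative only, using the same `length=max_seqlen` convention the existing `get_seqlens_and_offsets` helper relies on):

    import jax.numpy as jnp

    # num_segments_per_seq == max_seqlen_q == 4, so segment IDs run 1..4 (0 marks padding).
    segment_ids_row = jnp.array([1, 2, 3, 4], dtype=jnp.int32)

    # bincount(length=4) only has slots for IDs 0..3; ID 4 is out of range and is
    # silently dropped by JAX's out-of-bounds scatter handling, so the last segment
    # appears to have length 0 -> counts == [0, 1, 1, 1].
    counts = jnp.bincount(segment_ids_row, length=4)

Building `seqlens_q` and `offsets_q` by hand, as below, sidesteps that entirely.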
+ seqlens_q = jnp.ones((self.batch_size, num_segments_per_seq), dtype=jnp.int32) + offsets_q = jnp.concatenate( + [ + jnp.tile( + jnp.arange(num_segments_per_seq, dtype=jnp.int32)[None, :], + (self.batch_size, 1), + ), + jnp.full((self.batch_size, 1), -1, dtype=jnp.int32), + ], + axis=1, + ) + + + min_segment_len = None if self.window_size is None else seqlens_q + segment_ids_kv, segment_pos_kv, pad_kv = generate_random_segment_ids( + self.batch_size, + self.max_seqlen_kv, + num_segments_per_seq, + seed=2024, + min_segment_len=min_segment_len, + ) + seqlens_kv, offsets_kv = get_seqlens_and_offsets(segment_ids_kv) + return ( + num_segments_per_seq, + segment_ids_q, + segment_pos_q, + pad_q, + seqlens_q, + offsets_q, + segment_ids_kv, + segment_pos_kv, + pad_kv, + seqlens_kv, + offsets_kv, + ) + def _setup_inputs(self): - self._check_configs() + if not self.use_small_seq_thd_setup: + self._check_configs() # Create a mesh for distributed tests self.devices = np.asarray(jax.devices()[: self.number_of_devices]).reshape(*self.mesh_shape) @@ -625,30 +732,45 @@ def generate_random_segment_ids( return segment_ids, segment_pos, segment_pad if self.qkv_layout.is_thd(): - self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( - self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 - ) - self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q) - # TODO(rewang): record only self attention and find the reason of cross attention - if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv: - self.segment_ids_kv = self.segment_ids_q - self.segment_pos_kv = self.segment_pos_q - self.pad_kv = self.pad_q - else: - # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support - min_segment_len = None - if ( - self.window_size is not None or self.attn_mask_type.is_bottom_right() - ): # SWA or BRCM requires kv_len >= q_len - min_segment_len = self.seqlens_q - self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids( - self.batch_size, - self.max_seqlen_kv, + if self.use_small_seq_thd_setup: + ( self.num_segments_per_seq, - seed=2024, - min_segment_len=min_segment_len, + self.segment_ids_q, + self.segment_pos_q, + self.pad_q, + self.seqlens_q, + self.offsets_q, + self.segment_ids_kv, + self.segment_pos_kv, + self.pad_kv, + self.seqlens_kv, + self.offsets_kv, + ) = self._setup_thd_segments_small_seq(generate_random_segment_ids) + else: + self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( + self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 ) - self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv) + self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q) + # TODO(rewang): record only self attention and find the reason of cross attention + if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv: + self.segment_ids_kv = self.segment_ids_q + self.segment_pos_kv = self.segment_pos_q + self.pad_kv = self.pad_q + else: + # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support + min_segment_len = None + if ( + self.window_size is not None or self.attn_mask_type.is_bottom_right() + ): # SWA or BRCM requires kv_len >= q_len + min_segment_len = self.seqlens_q + self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids( + self.batch_size, + self.max_seqlen_kv, + self.num_segments_per_seq, + seed=2024, + min_segment_len=min_segment_len, + ) + 
self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv) else: self.segment_ids_q, self.pad_q = gen_valid( self.batch_size, self.max_seqlen_q, pad_ratio @@ -1092,6 +1214,163 @@ def check_dqkv(primitive, reference, pad, idx): target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text() assert_equal_collectives(target_hlo, self.coll_count_ref) + def test_backward_small_seq_api(self): + """Backward test using fused_attn_small_seq (explicit ROCm API), not generic fused_attn.""" + self._setup_inputs() + + def grad_func(func, *args, cp_reverse_out=False, **kwargs): + gradient_multiplier = self.max_seqlen_q * self.num_heads_q + if self.attn_mask_type.is_causal(): + gradient_multiplier /= 10 + if not cp_reverse_out: + ret_valid = jnp.where( + self.pad_q[..., jnp.newaxis, jnp.newaxis], + 0, + func(*args, **kwargs), + ) + else: + ret_valid = jnp.where( + self.pad_q[..., jnp.newaxis, jnp.newaxis], + 0, + self.cp_inverse_reorder_fn(func(*args, **kwargs)), + ) + return ( + jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier + ).astype(self.dtype) + + customcall_args = [ + jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding), + jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding), + jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding), + jax.device_put(self.bias, self.bias_sharding), + jax.device_put(self.sequence_desciptor, self.seq_desc_sharding), + jax.device_put(self.dropout_rng, self.dropout_rng_sharding), + ] + kwargs_small = { + "attn_bias_type": self.attn_bias_type, + "attn_mask_type": self.attn_mask_type, + "scaling_factor": self.scaling_factor, + "dropout_probability": self.dropout_prob, + "is_training": self.is_training, + "qkv_layout": self.qkv_layout, + "max_segments_per_seq": self._get_max_segments_per_sequence(), + "window_size": self.window_size, + } + + if self.bias_shape == BiasShape._1HSS: + arg_nums = (0, 1, 2, 3) + grad_shardings = ( + self.qkvo_sharding, + self.qkvo_sharding, + self.qkvo_sharding, + self.bias_sharding, + ) + else: + arg_nums = (0, 1, 2) + grad_shardings = (self.qkvo_sharding, self.qkvo_sharding, self.qkvo_sharding) + + jitted_primitive = jit( + value_and_grad( + lambda q, k, v, bias, sequence_descriptor, dropout_rng: grad_func( + customcall_small_seq_dpa, + q, + k, + v, + bias, + sequence_descriptor, + dropout_rng, + cp_reverse_out=True, + **kwargs_small, + ), + arg_nums, + ), + in_shardings=( + self.qkvo_sharding, + self.qkvo_sharding, + self.qkvo_sharding, + self.bias_sharding, + self.seq_desc_sharding, + self.dropout_rng_sharding, + ), + out_shardings=(None, grad_shardings), + ) + kwargs_ref = { + "attn_bias_type": self.attn_bias_type, + "attn_mask_type": self.attn_mask_type, + "softmax_type": self.softmax_type, + "scaling_factor": self.scaling_factor, + "dropout_probability": self.dropout_prob, + "is_training": self.is_training, + "qkv_layout": self.qkv_layout, + "max_segments_per_seq": self._get_max_segments_per_sequence(), + "window_size": self.window_size, + "context_parallel_strategy": self.cp_strategy, + "context_parallel_causal_load_balanced": self.cp_load_balanced, + } + jitted_reference = jit( + value_and_grad( + lambda q, k, v, bias, softmax_offset, mask, dropout_rng: grad_func( + jax_dpa, + q, + k, + v, + bias, + softmax_offset, + mask, + dropout_rng, + **kwargs_ref, + ), + arg_nums, + ), + in_shardings=( + self.qkvo_sharding, + self.qkvo_sharding, + self.qkvo_sharding, + self.bias_sharding, + self.softmax_offset_sharding, + self.mask_sharding, + 
self.dropout_rng_sharding, + ), + out_shardings=(None, grad_shardings), + ) + + with self.mesh, autocast(mesh_resource=self.mesh_resource): + primitive_out, primitive_dgrad = jitted_primitive(*customcall_args) + + reference_out, reference_dgrad = jitted_reference( + self.q, + self.k, + self.v, + self.bias, + self.softmax_offset, + self.mask, + self.dropout_rng, + ) + + if self.dropout_prob > 0.0: + return + + assert_allclose(primitive_out, reference_out, dtype=self.dtype) + + def check_dqkv(primitive, reference, pad, idx): + primitive_valid, primitive_invalid, reference_valid, reference_invalid = ( + _split_valid_and_invalid(primitive, reference, pad) + ) + assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype) + assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype) + assert_allclose(primitive_valid, reference_valid, dtype=self.dtype) + + primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3] + reference_dq, reference_dk, reference_dv = reference_dgrad[:3] + + primitive_dq = self.cp_inverse_reorder_fn(primitive_dq) + primitive_dk = self.cp_inverse_reorder_fn(primitive_dk) + primitive_dv = self.cp_inverse_reorder_fn(primitive_dv) + + check_dqkv(primitive_dq, reference_dq, self.pad_q, 0) + check_dqkv(primitive_dk, reference_dk, self.pad_kv, 1) + check_dqkv(primitive_dv, reference_dv, self.pad_kv, 2) + @pytest.mark.parametrize( "attn_mask_type", @@ -1507,3 +1786,68 @@ def test_jax_new_rng(): ) runner = FusedAttnRunner(**kwargs) runner.test_forward() + + +# ROCm small-seq varlen tests (explicit fused_attn_small_seq API). +@pytest.fixture +def xla_gpu_graph_disabled(): + if get_xla_flag("--xla_gpu_enable_command_buffer", default=None) != "": + pytest.skip( + "Test must set XLA_FLAGS with --xla_gpu_enable_command_buffer= " + ) + yield + +@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16], ids=["BF16", "FP16"]) +@pytest.mark.parametrize( + "b, s_q, s_kv, h_q, h_kv, d_qk, d_v", + [ + pytest.param(4000, 1, 2, 16, 16, 128, 128, id="4000-1-2-16-16-128-128"), + pytest.param(4000, 1, 4, 16, 16, 128, 128, id="4000-1-4-16-16-128-128"), + pytest.param(4000, 1, 6, 16, 16, 128, 128, id="4000-1-6-16-16-128-128"), + pytest.param(4000, 1, 8, 16, 16, 128, 128, id="4000-1-8-16-16-128-128"), + pytest.param(4000, 1, 12, 16, 16, 128, 128, id="4000-1-12-16-16-128-128"), + pytest.param(4000, 1, 16, 16, 16, 128, 128, id="4000-1-16-16-16-128-128"), + pytest.param(4000, 1, 2, 16, 16, 256, 256, id="4000-1-2-16-16-256-256"), + pytest.param(4000, 1, 4, 16, 16, 256, 256, id="4000-1-4-16-16-256-256"), + pytest.param(4000, 1, 6, 16, 16, 256, 256, id="4000-1-6-16-16-256-256"), + pytest.param(4000, 1, 8, 16, 16, 256, 256, id="4000-1-8-16-16-256-256"), + pytest.param(4000, 1, 12, 16, 16, 256, 256, id="4000-1-12-16-16-256-256"), + pytest.param(4000, 1, 16, 16, 16, 256, 256, id="4000-1-16-16-16-256-256"), + pytest.param(4000, 1, 2, 16, 16, 512, 512, id="4000-1-2-16-16-512-512"), + pytest.param(4000, 1, 4, 16, 16, 512, 512, id="4000-1-4-16-16-512-512"), + pytest.param(2048, 2, 4, 16, 16, 128, 128, id="seqpack-2048-2-4-16-16-128-128"), + pytest.param(2, 4096, 8192, 16, 16, 128, 128, id="seqpack-2-4096-8192-16-16-128-128"), + ], +) +@pytest.mark.skipif( + not is_hip_extension(), reason="Small-seq explicit API only available on AMD hardware" +) +def test_fused_attn_small_seq_explicit_api( + xla_gpu_graph_disabled, b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype +): + """ + Test nvte_fused_attn_small_seq / fused_attn_small_seq: THD_THD_THD varlen cross-attention + (head 
dims 128, 256, or 512). + """ + runner = FusedAttnRunner( + batch_size=b, + max_seqlen_q=s_q, + max_seqlen_kv=s_kv, + num_heads_q=h_q, + num_heads_kv=h_kv, + head_dim_qk=d_qk, + head_dim_v=d_v, + attn_bias_type=AttnBiasType.NO_BIAS, + attn_mask_type=AttnMaskType.PADDING_MASK, + softmax_type=AttnSoftmaxType.VANILLA_SOFTMAX, + dropout_prob=0.0, + use_old_rng=True, + dtype=dtype, + is_training=True, + qkv_layout=QKVLayout.THD_THD_THD, + bias_shape=BiasShape._B1SS, + window_size=None, + seq_desc_format=SeqDescFormat.Seqlens, + use_small_seq_thd_setup=True, + ) + runner.test_backward_small_seq_api() diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 2ec51746d..8019d0719 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -262,6 +262,7 @@ else() list(APPEND transformer_engine_cuda_sources fused_attn_rocm/fused_attn_aotriton.cpp fused_attn_rocm/fused_attn_ck.cpp + fused_attn_rocm/fused_attn_small_seq.cpp fused_attn_rocm/utils.cpp) endif() diff --git a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp index fae97d468..47528c020 100644 --- a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp +++ b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp @@ -168,6 +168,12 @@ hipError_t ck_attn_varlen_bwd( int how_v3_bf16_cvt, hipStream_t stream); +uint64_t get_runtime_max_seqlen(uint64_t b, + const void* cu_seqlen_ptr, + const void* cu_seqlen_padded_ptr, + void* workspace, + hipStream_t stream); + }//namespace ck_fused_attn #endif // CK_FUSED_ATTN_H diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index e787b31c8..e25865aa4 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -4,12 +4,14 @@ * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ +#include #include #include #include #include "transformer_engine/fused_attn.h" #include "fused_attn_aotriton.h" #include "fused_attn_ck.h" +#include "fused_attn_small_seq.h" #include "../common.h" #include "utils.h" @@ -894,6 +896,218 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } } +bool nvte_is_small_seq_attn_supported( + NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, + size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right) { + return transformer_engine::fused_attn_rocm::is_small_seq_attn_supported( + q_dtype, kv_dtype, qkv_layout, bias_type, attn_mask_type, dropout, num_attn_heads, + num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, + window_size_right); +} + +size_t nvte_fused_attn_small_seq_bwd_workspace_size(size_t batch, size_t attn_heads, + size_t max_seqlen_kv, NVTEDType dtype) { + return transformer_engine::fused_attn_rocm::fused_attn_small_seq_bwd_workspace_size( + batch, attn_heads, max_seqlen_kv, static_cast(dtype)); +} + +namespace { + +// Validate runtime max s_q == 1 and max s_kv in [2, 16]; returns runtime max KV length. 
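The helper defined just below delegates the actual reduction to `ck_fused_attn::get_runtime_max_seqlen`, which scans a cu_seqlens array on the device. For orientation, the quantity being validated is simply the largest adjacent difference in each cu_seqlens array; a host-side Python sketch of the same check, purely illustrative:

    import numpy as np

    def runtime_max_seqlen(cu_seqlens: np.ndarray) -> int:
        # cu_seqlens holds b+1 non-decreasing prefix sums; per-sequence lengths
        # are the adjacent differences.
        return int(np.max(np.diff(cu_seqlens)))

    # Small-seq eligibility at runtime:
    #   runtime_max_seqlen(cu_seqlens_q) == 1
    #   2 <= runtime_max_seqlen(cu_seqlens_kv) <= 16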
+size_t nvte_assert_small_seq_runtime_max_seqlen(uint64_t b, const void *dev_ptr_cu_seqlens_q, + const void *dev_ptr_cu_seqlens_kv, void *workspace, + size_t workspace_bytes, const char *log_tag, + cudaStream_t stream) { + constexpr size_t runtime_seqlen_bytes = sizeof(uint64_t); + NVTE_CHECK(workspace_bytes >= runtime_seqlen_bytes, log_tag, + "workspace too small to compute runtime max seqlen (need at least ", runtime_seqlen_bytes, + " bytes)."); + const size_t runtime_s_q = static_cast(ck_fused_attn::get_runtime_max_seqlen( + b, dev_ptr_cu_seqlens_q, nullptr, workspace, stream)); + const size_t runtime_s_kv = static_cast(ck_fused_attn::get_runtime_max_seqlen( + b, dev_ptr_cu_seqlens_kv, nullptr, workspace, stream)); + if (const char *env_ck = std::getenv("NVTE_LOG_CK_CONFIG"); + env_ck != nullptr && std::string(env_ck) == "1") { + std::cout << std::endl << log_tag << "b=" << b << ", runtime_max_seqlen_q=" << runtime_s_q + << ", runtime_max_seqlen_kv=" << runtime_s_kv << std::endl; + } + NVTE_CHECK(runtime_s_q == 1 && runtime_s_kv >= 2 && runtime_s_kv <= 16, log_tag, + "small-seq requires runtime s_q==1 and s_kv in [2,16]; got runtime_s_q=", runtime_s_q, + ", runtime_s_kv=", runtime_s_kv, "."); + return runtime_s_kv; +} + +} // namespace + +void nvte_fused_attn_small_seq_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, + size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_attn_small_seq_fwd); + using namespace transformer_engine; + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); + const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); + const Tensor *input_Q = convertNVTETensorCheck(Q); + const Tensor *input_K = convertNVTETensorCheck(K); + const Tensor *input_V = convertNVTETensorCheck(V); + convertNVTETensorCheck(Bias); + convertNVTETensorCheck(S); + Tensor *output_O = convertNVTETensorCheck(O); + Tensor *wkspace = convertNVTETensorCheck(workspace); + + auto ndim = input_Q->data.shape.size(); + size_t b = input_cu_seqlens_q->data.shape[0] - 1; + size_t h_q = input_Q->data.shape[ndim - 2]; + size_t h_kv = input_K->data.shape[ndim - 2]; + size_t d_qk = input_Q->data.shape[ndim - 1]; + size_t d_v = input_V->data.shape[ndim - 1]; + + const NVTEDType Q_type = static_cast(input_Q->data.dtype); + const NVTEDType KV_type = static_cast(input_K->data.dtype); + + log_fused_attn_config(__FUNCTION__, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, + dropout, b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, + window_size_left, window_size_right); + + std::tie(window_size_left, window_size_right) = + check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); + + NVTE_CHECK( + 
fused_attn_rocm::is_small_seq_attn_supported( + Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, + max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right), + "nvte_fused_attn_small_seq_fwd: configuration not supported for small-seq path."); + + Tensor *softmax_aux_tensor = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + void *attn_weights_buf = softmax_aux_tensor->data.dptr; + + if (wkspace->data.dptr == nullptr) { + wkspace->data.shape = {sizeof(uint64_t)}; + wkspace->data.dtype = DType::kByte; + return; + } + + void *dev_ptr_seed = input_rng_state->data.dptr; + void *dev_ptr_offset = + reinterpret_cast(reinterpret_cast(input_rng_state->data.dptr) + 1); + + size_t workspace_bytes = 1; + for (size_t i = 0; i < wkspace->data.shape.size(); ++i) { + workspace_bytes *= wkspace->data.shape[i]; + } + workspace_bytes *= fused_attn_rocm::nvte_dtype_size(wkspace->data.dtype); + + const size_t runtime_max_seqlen_kv = nvte_assert_small_seq_runtime_max_seqlen( + static_cast(b), input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr, + wkspace->data.dptr, workspace_bytes, "attn_fwd(small-seq kernel): ", stream); + + fused_attn_rocm::fused_attn_small_seq_fwd( + b, h_q, h_kv, runtime_max_seqlen_kv, d_qk, d_v, is_training, attn_scale, dropout, + input_Q->data.dptr, input_K->data.dptr, input_V->data.dptr, output_O->data.dptr, + attn_weights_buf, input_cu_seqlens_kv->data.dptr, input_cu_seqlens_kv_padded->data.dptr, + dev_ptr_seed, dev_ptr_offset, input_Q->data.dtype, wkspace->data.dptr, &workspace_bytes, + stream); + (void)page_table_k; + (void)page_table_v; +} + +void nvte_fused_attn_small_seq_bwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, + const NVTETensor dO, const NVTETensor S, NVTETensor dP, + const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, NVTETensor dV, + NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_attn_small_seq_bwd); + (void)deterministic; + using namespace transformer_engine; + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); + const Tensor *input_Q = convertNVTETensorCheck(Q); + const Tensor *input_K = convertNVTETensorCheck(K); + const Tensor *input_V = convertNVTETensorCheck(V); + const Tensor *input_O = convertNVTETensorCheck(O); + const Tensor *input_dO = convertNVTETensorCheck(dO); + convertNVTETensorCheck(S); + convertNVTETensorCheck(dP); + Tensor *output_dQ = convertNVTETensorCheck(dQ); + Tensor *output_dK = convertNVTETensorCheck(dK); + Tensor *output_dV = convertNVTETensorCheck(dV); + convertNVTETensorCheck(dBias); + Tensor *wkspace = convertNVTETensorCheck(workspace); + + const Tensor *attn_weights_tensor = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + + auto ndim = input_Q->data.shape.size(); + size_t b 
= input_cu_seqlens_q->data.shape[0] - 1; + size_t h_q = input_Q->data.shape[ndim - 2]; + size_t h_kv = input_K->data.shape[ndim - 2]; + size_t d_qk = input_Q->data.shape[ndim - 1]; + size_t d_v = input_V->data.shape[ndim - 1]; + + const NVTEDType Q_type = static_cast(input_Q->data.dtype); + const NVTEDType KV_type = static_cast(input_K->data.dtype); + + log_fused_attn_config(__FUNCTION__, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, + dropout, b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, + window_size_left, window_size_right); + + std::tie(window_size_left, window_size_right) = + check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); + + NVTE_CHECK( + fused_attn_rocm::is_small_seq_attn_supported( + Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, + max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right), + "nvte_fused_attn_small_seq_bwd: configuration not supported for small-seq path."); + + size_t req_bytes = fused_attn_rocm::fused_attn_small_seq_bwd_workspace_size( + b, h_q, max_seqlen_kv, input_Q->data.dtype); + + if (wkspace->data.dptr == nullptr) { + wkspace->data.shape = {req_bytes}; + wkspace->data.dtype = DType::kByte; + return; + } + + size_t workspace_bytes = 1; + for (size_t i = 0; i < wkspace->data.shape.size(); ++i) { + workspace_bytes *= wkspace->data.shape[i]; + } + workspace_bytes *= fused_attn_rocm::nvte_dtype_size(wkspace->data.dtype); + NVTE_CHECK(workspace_bytes >= req_bytes, "nvte_fused_attn_small_seq_bwd: workspace too small."); + + const size_t runtime_max_seqlen_kv = nvte_assert_small_seq_runtime_max_seqlen( + static_cast(b), input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr, + wkspace->data.dptr, workspace_bytes, "attn_bwd(small-seq kernel): ", stream); + + fused_attn_rocm::fused_attn_small_seq_bwd( + b, h_q, h_kv, runtime_max_seqlen_kv, d_qk, d_v, attn_scale, dropout, input_Q->data.dptr, + input_K->data.dptr, input_V->data.dptr, input_O->data.dptr, input_dO->data.dptr, + attn_weights_tensor->data.dptr, output_dQ->data.dptr, output_dK->data.dptr, + output_dV->data.dptr, input_cu_seqlens_kv->data.dptr, + input_cu_seqlens_kv_padded->data.dptr, input_Q->data.dtype, wkspace->data.dptr, + &workspace_bytes, stream); +} + uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t max_batch_size, cudaStream_t stream) { NVTE_API_CALL(nvte_get_runtime_num_segments); diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 5d4302db3..6c9ae383a 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -1796,7 +1796,7 @@ void fused_attn_ck_fwd( size_t max_tokens_q = std::accumulate((input_Q->data).shape.begin(), (input_Q->data).shape.end(), static_cast(1), std::multiplies())/h_q/d_qk; size_t max_tokens_kv = std::accumulate((input_K->data).shape.begin(), (input_K->data).shape.end(), static_cast(1), std::multiplies())/h_kv/d_qk; - bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD; + bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD; if (Aux_CTX_Tensors->size == 0) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Aux_CTX_Tensors->size = 3; @@ -1851,7 +1851,6 @@ void fused_attn_ck_fwd( bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == 
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - fused_attn_ck_fwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, bias_b, bias_h, max_tokens_q, max_tokens_kv, diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.cpp new file mode 100644 index 000000000..469279bb3 --- /dev/null +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.cpp @@ -0,0 +1,1121 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +/*! \file fused_attn_small_seq.cpp + * \brief small-seq (varlen) attention: seq_q=1, max_seqlen_kv<=16, THD only. + */ + +#include +#include +#include + +#include +#include +#include +#include + +#include "../common.h" +#include "../util/cuda_runtime.h" +#include "fused_attn_small_seq.h" +#include "utils.h" + +// Macros to avoid repeating dispatch switch cases for max_seqlen_kv in [2, 16]. +// T, bi, hi, d_qk and the pointer/scale args must be in scope where these are used. +#define SMALLSEQ_DISPATCH_FWD_CASE(N) \ + case N: \ + dispatch_fwd(bi, hi, Q_ptr, K_ptr, V_ptr, dropout_mask, dropout, \ + sqr_dk_scale, O_ptr, attn_workspace, cu_kv, cu_kv_p, \ + d_qk, stream); \ + break; +#define SMALLSEQ_DISPATCH_BWD_CASE(N) \ + case N: \ + dispatch_bwd(bi, hi, Q_ptr, K_ptr, V_ptr, dO_ptr, attn_ptr, \ + dropout_mask, dropout, sqr_dk_scale, dQ_ptr, dK_ptr, \ + dV_ptr, workspace_ptr, cu_kv, cu_kv_p, d_qk, stream); \ + break; + +namespace transformer_engine { +namespace fused_attn_rocm { + +enum class CausalMaskType { DISABLE = 0, TOP_LEFT = 1, BOTTOM_RIGHT = 2 }; + +template +struct SmallSeqConfig { + static constexpr int seq_q = 1; + static constexpr int max_seq_kv = MAX_SEQ_KV; + static constexpr int head_dim = HEAD_DIM; + static constexpr int step2_block_size = STEP2_BLOCK_SIZE; + static constexpr bool enable_dropout_mask = ENABLE_DROPOUT_MASK; + static constexpr CausalMaskType mask_type = MASK_TYPE; +}; + +/* MAX_SEQ_KV and HEAD_DIM are compile-time so kernels can use fixed stack arrays + * (e.g. float results[max_seq_kv], T attn[max_seq_kv]) and constexpr grid/block + * sizes. This matches varlen_attn/attn_fwd.cpp (FmhaKernelConfig<..., MAX_SEQ_KV, HEAD_DIM>). + * Dispatch supports head_dim 128, 256, 512 (d_qk == d_v). 
*/ + +// ----- Forward kernels (with runtime batch_size, head_num) ----- + +template +__global__ void compute_scores_kernel(const T* Q, + const T* K, + T* scores, + float scale, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + int batch_size, + int head_num) +{ + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int block_k = 64; + constexpr int thread_block_size = 64; + constexpr int tasks_per_block = TASKS_PER_BLOCK; + + int base_block_offset = blockIdx.x * thread_block_size * tasks_per_block; + int thread_id = threadIdx.x; + + for (int task = 0; task < tasks_per_block; task++) { + int cur_batch_idx = base_block_offset + task * thread_block_size + thread_id; + int batch_idx = cur_batch_idx / (seq_q * head_num); + int seq_head_idx = cur_batch_idx % (seq_q * head_num); + int seq_idx = seq_head_idx / head_num; + int head_idx = seq_head_idx % head_num; + + if (batch_idx >= batch_size) + continue; + + int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]; + int kv_offset = cu_seqlens_kv_padded[batch_idx]; + + float results[max_seq_kv]; + T fetch_Q[block_k]; + T fetch_K[block_k]; + T* Q_ptr = (T*)&Q[(batch_idx * seq_q * head_num + seq_idx * head_num + head_idx) * head_dim]; + T* K_ptr = (T*)&K[(kv_offset * head_num + head_idx) * head_dim]; + T* score_ptr = (T*)&scores[cur_batch_idx * max_seq_kv]; + uint4 ls_dwordx4_tmp_var; + for (int i = 0; i < seq_kv; i++) + results[i] = 0.0f; + for (int dim_offset = 0; dim_offset < head_dim; dim_offset += block_k) { + if constexpr (std::is_same::value || std::is_same::value) { + for (int k = 0; k < block_k / 8; k++) { + ls_dwordx4_tmp_var = *((uint4*)&Q_ptr[dim_offset + k * 8]); + fetch_Q[k * 8 + 0] = ((T*)&ls_dwordx4_tmp_var.x)[0]; + fetch_Q[k * 8 + 1] = ((T*)&ls_dwordx4_tmp_var.x)[1]; + fetch_Q[k * 8 + 2] = ((T*)&ls_dwordx4_tmp_var.y)[0]; + fetch_Q[k * 8 + 3] = ((T*)&ls_dwordx4_tmp_var.y)[1]; + fetch_Q[k * 8 + 4] = ((T*)&ls_dwordx4_tmp_var.z)[0]; + fetch_Q[k * 8 + 5] = ((T*)&ls_dwordx4_tmp_var.z)[1]; + fetch_Q[k * 8 + 6] = ((T*)&ls_dwordx4_tmp_var.w)[0]; + fetch_Q[k * 8 + 7] = ((T*)&ls_dwordx4_tmp_var.w)[1]; + } + for (int kv_idx = 0; kv_idx < seq_kv; kv_idx++) { + for (int k = 0; k < block_k / 8; k++) { + ls_dwordx4_tmp_var = + *((uint4*)&K_ptr[kv_idx * head_num * head_dim + dim_offset + k * 8]); + fetch_K[k * 8 + 0] = ((T*)&ls_dwordx4_tmp_var.x)[0]; + fetch_K[k * 8 + 1] = ((T*)&ls_dwordx4_tmp_var.x)[1]; + fetch_K[k * 8 + 2] = ((T*)&ls_dwordx4_tmp_var.y)[0]; + fetch_K[k * 8 + 3] = ((T*)&ls_dwordx4_tmp_var.y)[1]; + fetch_K[k * 8 + 4] = ((T*)&ls_dwordx4_tmp_var.z)[0]; + fetch_K[k * 8 + 5] = ((T*)&ls_dwordx4_tmp_var.z)[1]; + fetch_K[k * 8 + 6] = ((T*)&ls_dwordx4_tmp_var.w)[0]; + fetch_K[k * 8 + 7] = ((T*)&ls_dwordx4_tmp_var.w)[1]; + } +#pragma unroll + for (int k = 0; k < block_k; k++) + results[kv_idx] += static_cast(fetch_Q[k]) * static_cast(fetch_K[k]); + } + } else { + for (int k = 0; k < block_k / 4; k++) { + ls_dwordx4_tmp_var = *((uint4*)&Q_ptr[dim_offset + k * 4]); + fetch_Q[k * 4 + 0] = *((T*)&ls_dwordx4_tmp_var.x); + fetch_Q[k * 4 + 1] = *((T*)&ls_dwordx4_tmp_var.y); + fetch_Q[k * 4 + 2] = *((T*)&ls_dwordx4_tmp_var.z); + fetch_Q[k * 4 + 3] = *((T*)&ls_dwordx4_tmp_var.w); + } + for (int kv_idx = 0; kv_idx < seq_kv; kv_idx++) { + for (int k = 0; k < block_k / 4; k++) { + ls_dwordx4_tmp_var = + *((uint4*)&K_ptr[kv_idx * head_num * head_dim + dim_offset + k * 4]); + fetch_K[k * 4 + 0] = *((T*)&ls_dwordx4_tmp_var.x); + fetch_K[k * 
4 + 1] = *((T*)&ls_dwordx4_tmp_var.y); + fetch_K[k * 4 + 2] = *((T*)&ls_dwordx4_tmp_var.z); + fetch_K[k * 4 + 3] = *((T*)&ls_dwordx4_tmp_var.w); + } +#pragma unroll + for (int k = 0; k < block_k; k++) + results[kv_idx] += fetch_Q[k] * fetch_K[k]; + } + } + } + for (int i = 0; i < seq_kv; i++) + score_ptr[i] = T(results[i] * scale); + for (int i = seq_kv; i < max_seq_kv; i++) + score_ptr[i] = T(-1e9f); + } +} + +template +__global__ void apply_mask_and_softmax_kernel(T* scores, + const T* dropout_mask, + float dropout_scale, + const int* cu_seqlens_kv, + int batch_size, + int head_num) +{ + const uint32_t block_id = blockIdx.x; + const uint32_t thread_id = threadIdx.x; + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int block_size = Config::step2_block_size; + constexpr int per_score_size = seq_q * max_seq_kv; + constexpr int valid_thread_range = block_size / per_score_size * per_score_size; + const uint32_t cur_block_offset = block_id * valid_thread_range + thread_id; + const uint32_t total_elt = static_cast(batch_size) * head_num * seq_q * max_seq_kv; + bool is_tail = block_id * valid_thread_range + block_size >= total_elt; + int real_row_num = + is_tail ? (total_elt - block_id * valid_thread_range) / max_seq_kv + : valid_thread_range / max_seq_kv; + + if (cur_block_offset < total_elt && thread_id < valid_thread_range) { + __shared__ T tmp_scores[valid_thread_range]; + constexpr int row_num = valid_thread_range / max_seq_kv; + __shared__ T row_max[row_num]; + __shared__ T row_sum[row_num]; + + int global_row_idx = cur_block_offset / max_seq_kv; + int batch_idx = global_row_idx / (seq_q * head_num); + int k_idx = cur_block_offset % max_seq_kv; + + int seq_kv = (batch_idx < batch_size) + ? (cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]) + : max_seq_kv; + + T score_value = scores[cur_block_offset]; + tmp_scores[thread_id] = score_value; + + if constexpr (Config::mask_type == CausalMaskType::TOP_LEFT) { + int q_idx = (cur_block_offset % (seq_q * max_seq_kv)) / max_seq_kv; + if (k_idx > q_idx || k_idx >= seq_kv) + tmp_scores[thread_id] = T(-1e9f); + } else if constexpr (Config::mask_type == CausalMaskType::BOTTOM_RIGHT) { + int q_idx = (cur_block_offset % (seq_q * max_seq_kv)) / max_seq_kv; + if (k_idx < q_idx || k_idx >= seq_kv) + tmp_scores[thread_id] = T(-1e9f); + } else { + if (k_idx >= seq_kv) + tmp_scores[thread_id] = T(-1e9f); + } + __syncthreads(); + + if (thread_id < real_row_num) { + T max_val = T(-1e9f); +#pragma unroll + for (int i = 0; i < max_seq_kv; i++) + max_val = fmaxf(static_cast(max_val), + static_cast(tmp_scores[thread_id * max_seq_kv + i])); + row_max[thread_id] = max_val; + } + __syncthreads(); + + T exp_val = T(expf(static_cast(tmp_scores[thread_id] - + row_max[thread_id / max_seq_kv]))); + tmp_scores[thread_id] = exp_val; + __syncthreads(); + + if (thread_id < real_row_num) { + T sum = T(0.0f); +#pragma unroll + for (int i = 0; i < max_seq_kv; i++) + sum += tmp_scores[thread_id * max_seq_kv + i]; + row_sum[thread_id] = sum; + } + __syncthreads(); + + T attn_weight = tmp_scores[thread_id] / row_sum[thread_id / max_seq_kv]; + if constexpr (Config::enable_dropout_mask) { + attn_weight = attn_weight * dropout_mask[cur_block_offset] * dropout_scale; + } + scores[cur_block_offset] = attn_weight; + } +} + +template +__global__ void compute_output_kernel(const T* attn_weights, + const T* V, + T* O, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + int batch_size, + int head_num) +{ + constexpr int 
seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int block_k = BLOCK_K; + constexpr int dwordx4_load_elt = 16 / sizeof(T); + constexpr int warp_size = 64; + constexpr int process_head_per_warp = warp_size / (head_dim / block_k); + constexpr int tasks_per_block = TASKS_PER_BLOCK; + + int base_block_offset = blockIdx.x * process_head_per_warp * tasks_per_block; + int thread_id = threadIdx.x; + int thread_batch_offset = thread_id / (head_dim / block_k); + int thread_head_offset = thread_id % (head_dim / block_k) * block_k; + + uint4 load_dwordx4_tmp_var[block_k / dwordx4_load_elt], + store_dwordx4_tmp_var[block_k / dwordx4_load_elt]; + T attn[max_seq_kv]; + + for (int task = 0; task < tasks_per_block; task++) { + int block_batch_head_idx = base_block_offset + task * process_head_per_warp; + int cur_idx = block_batch_head_idx + thread_batch_offset; + + int batch_idx = cur_idx / (seq_q * head_num); + int seq_head_idx = cur_idx % (seq_q * head_num); + int seq_q_idx = seq_head_idx / head_num; + int head_idx = seq_head_idx % head_num; + + if (batch_idx >= batch_size) + continue; + + int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]; + int kv_offset = cu_seqlens_kv_padded[batch_idx]; + +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + store_dwordx4_tmp_var[i].x = 0; + store_dwordx4_tmp_var[i].y = 0; + store_dwordx4_tmp_var[i].z = 0; + store_dwordx4_tmp_var[i].w = 0; + } +#pragma unroll + for (int i = 0; i < max_seq_kv; i++) + attn[i] = attn_weights[cur_idx * max_seq_kv + i]; + for (int j = 0; j < seq_kv; j++) { +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + load_dwordx4_tmp_var[i] = + *((uint4*)&V[((kv_offset + j) * head_num + head_idx) * head_dim + thread_head_offset + + i * dwordx4_load_elt]); + } +#pragma unroll + for (int b = 0; b < block_k; b++) + ((T*)&store_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt] += + attn[j] * + ((T*)&load_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt]; + } +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) + *((uint4*)&O[(batch_idx * seq_q * head_num + seq_q_idx * head_num + head_idx) * head_dim + + thread_head_offset + i * dwordx4_load_elt]) = store_dwordx4_tmp_var[i]; + } +} + +// ----- Forward launcher ----- + +template +void run_attn_fwd_impl(int b, + int head_num, + const T* Q, + const T* K, + const T* V, + const T* dropout_mask, + float dropout_p, + float sqr_dk_scale, + T* O, + T* workspace, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + hipStream_t stream) +{ + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int warp_size = 64; + + int merge_bs = b * head_num; + float scale = sqr_dk_scale; + float dropout_scale = (dropout_p > 0.0f) ? 
(1.0f / (1.0f - dropout_p)) : 1.0f; + + constexpr int kernel1_threads = 64; + dim3 block(kernel1_threads); + dim3 grid((merge_bs + kernel1_threads - 1) / kernel1_threads); + compute_scores_kernel<<>>( + Q, K, workspace, scale, cu_seqlens_kv, cu_seqlens_kv_padded, b, head_num); + + constexpr int work_thread_num = + Config::step2_block_size / (seq_q * max_seq_kv) * (seq_q * max_seq_kv); + dim3 grid2((merge_bs * seq_q * max_seq_kv + work_thread_num - 1) / work_thread_num); + dim3 block2(Config::step2_block_size); + apply_mask_and_softmax_kernel<<>>( + workspace, dropout_mask, dropout_scale, cu_seqlens_kv, b, head_num); + + constexpr int kernel3_block_k = 8; + constexpr int kernel3_threads = 64; + constexpr int process_head_per_warp = warp_size / (head_dim / kernel3_block_k); + + dim3 block3(kernel3_threads); + dim3 grid3((merge_bs / process_head_per_warp + 2 - 1) / 2); + compute_output_kernel<<>>( + workspace, V, O, cu_seqlens_kv, cu_seqlens_kv_padded, b, head_num); +} + +// ----- Backward kernels (with runtime batch_size, head_num) ----- + +template +__global__ void compute_grad_v_kernel(const T* attn_weights, + const T* grad_O, + T* grad_V, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + int batch_size, + int head_num) +{ + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int block_k = BLOCK_K; + constexpr int dwordx4_load_elt = 16 / sizeof(T); + constexpr int warp_size = 64; + constexpr int process_head_per_warp = warp_size / (head_dim / block_k); + constexpr int tasks_per_block = TASKS_PER_BLOCK; + + int base_block_offset = blockIdx.x * process_head_per_warp * tasks_per_block; + int thread_id = threadIdx.x; + int thread_batch_offset = thread_id / (head_dim / block_k); + int thread_head_offset = thread_id % (head_dim / block_k) * block_k; + + uint4 load_dwordx4_tmp_var[block_k / dwordx4_load_elt]; + T attn[max_seq_kv]; + + for (int task = 0; task < tasks_per_block; task++) { + int block_batch_head_idx = base_block_offset + task * process_head_per_warp; + int cur_idx = block_batch_head_idx + thread_batch_offset; + + int batch_idx = cur_idx / (seq_q * head_num); + int seq_head_idx = cur_idx % (seq_q * head_num); + int seq_q_idx = seq_head_idx / head_num; + int head_idx = seq_head_idx % head_num; + + if (batch_idx >= batch_size) + continue; + + int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]; + +#pragma unroll + for (int i = 0; i < max_seq_kv; i++) + attn[i] = attn_weights[cur_idx * max_seq_kv + i]; + + for (int j = 0; j < seq_kv; j++) { + uint4 store_dwordx4_tmp_var[block_k / dwordx4_load_elt]; +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + store_dwordx4_tmp_var[i].x = 0; + store_dwordx4_tmp_var[i].y = 0; + store_dwordx4_tmp_var[i].z = 0; + store_dwordx4_tmp_var[i].w = 0; + } + +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + load_dwordx4_tmp_var[i] = + *((uint4*)&grad_O[(batch_idx * seq_q * head_num + seq_q_idx * head_num + head_idx) * + head_dim + + thread_head_offset + i * dwordx4_load_elt]); + } + +#pragma unroll + for (int b = 0; b < block_k; b++) { + ((T*)&store_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt] += + attn[j] * + ((T*)&load_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt]; + } + +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + int grad_v_idx = (cu_seqlens_kv_padded[batch_idx] + j) * head_num * head_dim + + head_idx * head_dim + 
thread_head_offset + i * dwordx4_load_elt; + *((uint4*)&grad_V[grad_v_idx]) = store_dwordx4_tmp_var[i]; + } + } + } +} + +template +__global__ void compute_grad_attn_kernel(const T* grad_O, + const T* V, + T* grad_attn, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + int batch_size, + int head_num) +{ + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int block_k = 64; + constexpr int thread_block_size = 64; + constexpr int tasks_per_block = TASKS_PER_BLOCK; + + int base_block_offset = blockIdx.x * thread_block_size * tasks_per_block; + int thread_id = threadIdx.x; + + for (int task = 0; task < tasks_per_block; task++) { + int cur_batch_idx = base_block_offset + task * thread_block_size + thread_id; + int batch_idx = cur_batch_idx / (seq_q * head_num); + int seq_head_idx = cur_batch_idx % (seq_q * head_num); + int seq_idx = seq_head_idx / head_num; + int head_idx = seq_head_idx % head_num; + + if (batch_idx >= batch_size) + continue; + + int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]; + + float results[max_seq_kv]; + T fetch_grad_O[block_k]; + T fetch_V[block_k]; + + T* grad_O_ptr = (T*)&grad_O[(batch_idx * seq_q * head_num + seq_idx * head_num + head_idx) * + head_dim]; + + const T* V_base = + &V[cu_seqlens_kv_padded[batch_idx] * head_num * head_dim + head_idx * head_dim]; + int V_stride = head_num * head_dim; + + T* grad_attn_ptr = (T*)&grad_attn[cur_batch_idx * max_seq_kv]; + + uint4 ls_dwordx4_tmp_var; + + for (int i = 0; i < seq_kv; i++) + results[i] = 0.0f; + + for (int dim_offset = 0; dim_offset < head_dim; dim_offset += block_k) { + if constexpr (std::is_same::value || std::is_same::value) { + for (int k = 0; k < block_k / 8; k++) { + ls_dwordx4_tmp_var = *((uint4*)&grad_O_ptr[dim_offset + k * 8]); + fetch_grad_O[k * 8 + 0] = ((T*)&ls_dwordx4_tmp_var.x)[0]; + fetch_grad_O[k * 8 + 1] = ((T*)&ls_dwordx4_tmp_var.x)[1]; + fetch_grad_O[k * 8 + 2] = ((T*)&ls_dwordx4_tmp_var.y)[0]; + fetch_grad_O[k * 8 + 3] = ((T*)&ls_dwordx4_tmp_var.y)[1]; + fetch_grad_O[k * 8 + 4] = ((T*)&ls_dwordx4_tmp_var.z)[0]; + fetch_grad_O[k * 8 + 5] = ((T*)&ls_dwordx4_tmp_var.z)[1]; + fetch_grad_O[k * 8 + 6] = ((T*)&ls_dwordx4_tmp_var.w)[0]; + fetch_grad_O[k * 8 + 7] = ((T*)&ls_dwordx4_tmp_var.w)[1]; + } + for (int kv_idx = 0; kv_idx < seq_kv; kv_idx++) { + for (int k = 0; k < block_k / 8; k++) { + ls_dwordx4_tmp_var = + *((uint4*)&V_base[kv_idx * V_stride + dim_offset + k * 8]); + fetch_V[k * 8 + 0] = ((T*)&ls_dwordx4_tmp_var.x)[0]; + fetch_V[k * 8 + 1] = ((T*)&ls_dwordx4_tmp_var.x)[1]; + fetch_V[k * 8 + 2] = ((T*)&ls_dwordx4_tmp_var.y)[0]; + fetch_V[k * 8 + 3] = ((T*)&ls_dwordx4_tmp_var.y)[1]; + fetch_V[k * 8 + 4] = ((T*)&ls_dwordx4_tmp_var.z)[0]; + fetch_V[k * 8 + 5] = ((T*)&ls_dwordx4_tmp_var.z)[1]; + fetch_V[k * 8 + 6] = ((T*)&ls_dwordx4_tmp_var.w)[0]; + fetch_V[k * 8 + 7] = ((T*)&ls_dwordx4_tmp_var.w)[1]; + } +#pragma unroll + for (int k = 0; k < block_k; k++) + results[kv_idx] += + static_cast(fetch_grad_O[k]) * static_cast(fetch_V[k]); + } + } else { + for (int k = 0; k < block_k / 4; k++) { + ls_dwordx4_tmp_var = *((uint4*)&grad_O_ptr[dim_offset + k * 4]); + fetch_grad_O[k * 4 + 0] = *((T*)&ls_dwordx4_tmp_var.x); + fetch_grad_O[k * 4 + 1] = *((T*)&ls_dwordx4_tmp_var.y); + fetch_grad_O[k * 4 + 2] = *((T*)&ls_dwordx4_tmp_var.z); + fetch_grad_O[k * 4 + 3] = *((T*)&ls_dwordx4_tmp_var.w); + } + for (int kv_idx = 0; kv_idx < seq_kv; kv_idx++) { + for (int k = 0; k < block_k / 4; 
k++) { + ls_dwordx4_tmp_var = + *((uint4*)&V_base[kv_idx * V_stride + dim_offset + k * 4]); + fetch_V[k * 4 + 0] = *((T*)&ls_dwordx4_tmp_var.x); + fetch_V[k * 4 + 1] = *((T*)&ls_dwordx4_tmp_var.y); + fetch_V[k * 4 + 2] = *((T*)&ls_dwordx4_tmp_var.z); + fetch_V[k * 4 + 3] = *((T*)&ls_dwordx4_tmp_var.w); + } +#pragma unroll + for (int k = 0; k < block_k; k++) + results[kv_idx] += fetch_grad_O[k] * fetch_V[k]; + } + } + } + for (int i = 0; i < seq_kv; i++) + grad_attn_ptr[i] = T(results[i]); + for (int i = seq_kv; i < max_seq_kv; i++) + grad_attn_ptr[i] = T(0.0f); + } +} + +template +__global__ void softmax_backward_kernel(const T* attn_weights, + const T* dropout_mask, + T* grad_attn, + float dropout_scale, + const int* cu_seqlens_kv, + int batch_size, + int head_num) +{ + const uint32_t block_id = blockIdx.x; + const uint32_t thread_id = threadIdx.x; + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int block_size = Config::step2_block_size; + constexpr int per_grad_attn_size = seq_q * max_seq_kv; + constexpr int valid_thread_range = block_size / per_grad_attn_size * per_grad_attn_size; + const uint32_t cur_block_offset = block_id * valid_thread_range + thread_id; + const uint32_t total_elt = static_cast(batch_size) * head_num * seq_q * max_seq_kv; + bool is_tail = block_id * valid_thread_range + block_size >= total_elt; + int real_row_num = + is_tail ? (total_elt - block_id * valid_thread_range) / max_seq_kv + : valid_thread_range / max_seq_kv; + + if (cur_block_offset < total_elt && thread_id < valid_thread_range) { + __shared__ T tmp_grad_score[valid_thread_range]; + constexpr int row_num = valid_thread_range / max_seq_kv; + __shared__ T reduce_grad_score[row_num]; + + int global_row_idx = cur_block_offset / max_seq_kv; + int batch_idx = global_row_idx / (seq_q * head_num); + int k_idx = cur_block_offset % max_seq_kv; + + int seq_kv = (batch_idx < batch_size) + ? 
(cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]) + : max_seq_kv; + + T grad_attn_value = grad_attn[cur_block_offset]; + if constexpr (Config::enable_dropout_mask) + grad_attn_value = grad_attn_value * dropout_mask[cur_block_offset] * dropout_scale; + T attn_weight = attn_weights[cur_block_offset]; + T grad_score = grad_attn_value * attn_weight; + tmp_grad_score[thread_id] = grad_score; + __syncthreads(); + + if (thread_id < real_row_num) { + T sum = T(0.0f); +#pragma unroll + for (int i = 0; i < max_seq_kv; i++) + sum += tmp_grad_score[thread_id * max_seq_kv + i]; + reduce_grad_score[thread_id] = sum; + } + __syncthreads(); + + grad_score -= attn_weight * reduce_grad_score[thread_id / max_seq_kv]; + + if constexpr (Config::mask_type == CausalMaskType::TOP_LEFT) { + int q_idx = (cur_block_offset % (seq_q * max_seq_kv)) / max_seq_kv; + if (k_idx > q_idx || k_idx >= seq_kv) + grad_score = T(0.0f); + } else if constexpr (Config::mask_type == CausalMaskType::BOTTOM_RIGHT) { + int q_idx = (cur_block_offset % (seq_q * max_seq_kv)) / max_seq_kv; + if (k_idx < q_idx || k_idx >= seq_kv) + grad_score = T(0.0f); + } else { + if (k_idx >= seq_kv) + grad_score = T(0.0f); + } + grad_attn[cur_block_offset] = grad_score; + } +} + +template +__global__ void compute_grad_qk_kernel(const T* grad_scores, + const T* Q, + const T* K, + T* grad_Q, + T* grad_K, + float scale, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + int batch_size, + int head_num) +{ + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int block_k = BLOCK_K; + constexpr int dwordx4_load_elt = 16 / sizeof(T); + constexpr int warp_size = 64; + constexpr int process_head_per_warp = warp_size / (head_dim / block_k); + constexpr int tasks_per_block = TASKS_PER_BLOCK; + + int base_block_offset = blockIdx.x * process_head_per_warp * tasks_per_block; + int thread_id = threadIdx.x; + int thread_batch_offset = thread_id / (head_dim / block_k); + int thread_head_offset = thread_id % (head_dim / block_k) * block_k; + + uint4 load_dwordx4_tmp_var[block_k / dwordx4_load_elt]; + T grad_score_vals[max_seq_kv]; + + for (int task = 0; task < tasks_per_block; task++) { + int block_batch_head_idx = base_block_offset + task * process_head_per_warp; + int cur_idx = block_batch_head_idx + thread_batch_offset; + + int batch_idx = cur_idx / (seq_q * head_num); + int seq_head_idx = cur_idx % (seq_q * head_num); + int seq_q_idx = seq_head_idx / head_num; + int head_idx = seq_head_idx % head_num; + + if (batch_idx >= batch_size) + continue; + + int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]; + +#pragma unroll + for (int i = 0; i < max_seq_kv; i++) + grad_score_vals[i] = grad_scores[cur_idx * max_seq_kv + i]; + + uint4 store_dwordx4_tmp_var[block_k / dwordx4_load_elt]; +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + store_dwordx4_tmp_var[i].x = 0; + store_dwordx4_tmp_var[i].y = 0; + store_dwordx4_tmp_var[i].z = 0; + store_dwordx4_tmp_var[i].w = 0; + } + for (int j = 0; j < seq_kv; j++) { +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + int k_idx = (cu_seqlens_kv_padded[batch_idx] + j) * head_num * head_dim + + head_idx * head_dim + thread_head_offset + i * dwordx4_load_elt; + load_dwordx4_tmp_var[i] = *((uint4*)&K[k_idx]); + } +#pragma unroll + for (int b = 0; b < block_k; b++) { + ((T*)&store_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt] += + grad_score_vals[j] * 
+ ((T*)&load_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt]; + } + } +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + T* grad_Q_ptr = &grad_Q[(batch_idx * seq_q * head_num + seq_q_idx * head_num + head_idx) * + head_dim + + thread_head_offset + i * dwordx4_load_elt]; + for (int b = 0; b < dwordx4_load_elt; b++) + grad_Q_ptr[b] = ((T*)&store_dwordx4_tmp_var[i])[b] * T(scale); + } +#pragma unroll + for (int i = 0; i < block_k / dwordx4_load_elt; i++) { + load_dwordx4_tmp_var[i] = + *((uint4*)&Q[(batch_idx * seq_q * head_num + seq_q_idx * head_num + head_idx) * head_dim + + thread_head_offset + i * dwordx4_load_elt]); + } + for (int j = 0; j < seq_kv; j++) { +#pragma unroll + for (int b = 0; b < block_k; b++) { + T val = grad_score_vals[j] * + ((T*)&load_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt] * + T(scale); + int grad_k_idx = (cu_seqlens_kv_padded[batch_idx] + j) * head_num * head_dim + + head_idx * head_dim + thread_head_offset + b; + grad_K[grad_k_idx] = val; + } + } + } +} + +template +void run_attn_bwd_impl(int b, + int head_num, + const T* Q, + const T* K, + const T* V, + const T* grad_O, + const T* attn_weights, + const T* dropout_mask, + float dropout_p, + float sqr_dk_scale, + T* grad_Q, + T* grad_K, + T* grad_V, + T* workspace, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + hipStream_t stream) +{ + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int warp_size = 64; + + int merge_bs = b * head_num; + float scale = sqr_dk_scale; + float dropout_scale = (dropout_p > 0.0f) ? (1.0f / (1.0f - dropout_p)) : 1.0f; + + dim3 block(warp_size); + constexpr int tasks_per_block_v = 16; + dim3 grid_v((b * seq_q * head_num + tasks_per_block_v - 1) / tasks_per_block_v); + compute_grad_v_kernel<<>>( + attn_weights, grad_O, grad_V, cu_seqlens_kv, cu_seqlens_kv_padded, b, head_num); + + constexpr int tasks_per_block_attn = 16; + constexpr int process_head_per_warp = warp_size / (head_dim / 64); + dim3 grid_grad_attn((b * seq_q * head_num + tasks_per_block_attn * process_head_per_warp - 1) / + (tasks_per_block_attn * process_head_per_warp)); + compute_grad_attn_kernel<<>>( + grad_O, V, workspace, cu_seqlens_kv, cu_seqlens_kv_padded, b, head_num); + + constexpr int work_thread_num = + Config::step2_block_size / (seq_q * max_seq_kv) * (seq_q * max_seq_kv); + dim3 grid_softmax((merge_bs * seq_q * max_seq_kv + work_thread_num - 1) / work_thread_num); + dim3 block_softmax(Config::step2_block_size); + softmax_backward_kernel<<>>( + attn_weights, dropout_mask, workspace, dropout_scale, cu_seqlens_kv, b, head_num); + + constexpr int tasks_per_block_qk = 4; + dim3 grid_qk((b * seq_q * head_num + tasks_per_block_qk - 1) / tasks_per_block_qk); + compute_grad_qk_kernel<<>>( + workspace, Q, K, grad_Q, grad_K, scale, cu_seqlens_kv, cu_seqlens_kv_padded, b, head_num); +} + +size_t fused_attn_small_seq_bwd_workspace_size(size_t b, + size_t h_q, + size_t max_seqlen_kv, + DType dtype) { + constexpr size_t elt_size = 2u; // BF16 and FP16 are 2 bytes + return b * h_q * 1 * std::min(max_seqlen_kv, size_t(16)) * elt_size; +} + +template +static void dispatch_fwd(int b, int h_q, const T* Q, const T* K, const T* V, const T* dropout_mask, + float dropout, float scale, T* O, T* workspace, const int* cu_kv, + const int* cu_kv_p, size_t d_qk, hipStream_t stream) { + switch (d_qk) { + case 128: + run_attn_fwd_impl>( + b, h_q, Q, K, V, dropout_mask, dropout, 
scale, O, workspace, cu_kv, cu_kv_p, stream); + break; + case 256: + run_attn_fwd_impl>( + b, h_q, Q, K, V, dropout_mask, dropout, scale, O, workspace, cu_kv, cu_kv_p, stream); + break; + case 512: + run_attn_fwd_impl>( + b, h_q, Q, K, V, dropout_mask, dropout, scale, O, workspace, cu_kv, cu_kv_p, stream); + break; + default: + NVTE_ERROR( + "Unsupported head dimension (d_qk) for small-seq attention: must be 128, 256, or " + "512."); + } +} + +template +static void dispatch_bwd(int b, int h_q, const T* Q, const T* K, const T* V, const T* grad_O, + const T* attn_weights, const T* dropout_mask, float dropout, float scale, + T* grad_Q, T* grad_K, T* grad_V, T* workspace, const int* cu_kv, + const int* cu_kv_p, size_t d_qk, hipStream_t stream) { + switch (d_qk) { + case 128: + run_attn_bwd_impl>( + b, h_q, Q, K, V, grad_O, attn_weights, dropout_mask, dropout, scale, + grad_Q, grad_K, grad_V, workspace, cu_kv, cu_kv_p, stream); + break; + case 256: + run_attn_bwd_impl>( + b, h_q, Q, K, V, grad_O, attn_weights, dropout_mask, dropout, scale, + grad_Q, grad_K, grad_V, workspace, cu_kv, cu_kv_p, stream); + break; + case 512: + run_attn_bwd_impl>( + b, h_q, Q, K, V, grad_O, attn_weights, dropout_mask, dropout, scale, + grad_Q, grad_K, grad_V, workspace, cu_kv, cu_kv_p, stream); + break; + default: + NVTE_ERROR( + "Unsupported head dimension (d_qk) for small-seq attention: must be 128, 256, or " + "512."); + } +} + +void fused_attn_small_seq_fwd(size_t b, + size_t h_q, + size_t h_kv, + size_t max_seqlen_kv, + size_t d_qk, + size_t d_v, + bool is_training, + float attn_scale, + float dropout, + const void* devPtrQ, + const void* devPtrK, + const void* devPtrV, + void* devPtrO, + void* attn_weights_buffer, + const void* devPtrCuSeqlensKV, + const void* devPtrSeqOffsetsKV, + const void* rng_seed, + const void* rng_offset, + DType qkv_dtype, + void* workspace, + size_t* workspace_size, + cudaStream_t stream) +{ + const char* nvte_smallseq = std::getenv("NVTE_LOG_CK_CONFIG"); + if (nvte_smallseq && std::string(nvte_smallseq) == "1") { + std::cout << std::endl << "attn_fwd(small-seq kernel): "; + std::cout << "b: " << b << ", "; + std::cout << "h_q: " << h_q << ", "; + std::cout << "h_kv: " << h_kv << ", "; + std::cout << "max_seqlen_kv: " << max_seqlen_kv << ", "; + std::cout << "d_qk: " << d_qk << ", "; + std::cout << "d_v: " << d_v << ", "; + std::cout << "is_training: " << is_training << ", "; + std::cout << "attn_scale: " << attn_scale << ", "; + std::cout << "dropout: " << dropout << ", "; + std::cout << "qkv_dtype: " + << (qkv_dtype == DType::kBFloat16 ? "BF16" : qkv_dtype == DType::kFloat16 ? 
"FP16" : "?") + << std::endl; + } + + float sqr_dk_scale = attn_scale; + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(qkv_dtype, T, + const T* Q_ptr = static_cast(devPtrQ); + const T* K_ptr = static_cast(devPtrK); + const T* V_ptr = static_cast(devPtrV); + T* O_ptr = static_cast(devPtrO); + T* attn_workspace = static_cast(attn_weights_buffer); + const int* cu_kv = static_cast(devPtrCuSeqlensKV); + const int* cu_kv_p = static_cast(devPtrSeqOffsetsKV); + const T* dropout_mask = nullptr; + int bi = static_cast(b); + int hi = static_cast(h_q); + + switch (max_seqlen_kv) { + SMALLSEQ_DISPATCH_FWD_CASE(2) + SMALLSEQ_DISPATCH_FWD_CASE(3) + SMALLSEQ_DISPATCH_FWD_CASE(4) + SMALLSEQ_DISPATCH_FWD_CASE(5) + SMALLSEQ_DISPATCH_FWD_CASE(6) + SMALLSEQ_DISPATCH_FWD_CASE(7) + SMALLSEQ_DISPATCH_FWD_CASE(8) + SMALLSEQ_DISPATCH_FWD_CASE(9) + SMALLSEQ_DISPATCH_FWD_CASE(10) + SMALLSEQ_DISPATCH_FWD_CASE(11) + SMALLSEQ_DISPATCH_FWD_CASE(12) + SMALLSEQ_DISPATCH_FWD_CASE(13) + SMALLSEQ_DISPATCH_FWD_CASE(14) + SMALLSEQ_DISPATCH_FWD_CASE(15) + SMALLSEQ_DISPATCH_FWD_CASE(16) + default: + NVTE_ERROR("Unsupported max_seqlen_kv for small-seq: max_seqlen_kv <= 16."); + } + ); + +} + +void fused_attn_small_seq_bwd(size_t b, + size_t h_q, + size_t h_kv, + size_t max_seqlen_kv, + size_t d_qk, + size_t d_v, + float attn_scale, + float dropout, + const void* devPtrQ, + const void* devPtrK, + const void* devPtrV, + const void* devPtrO, + const void* devPtrdO, + const void* attn_weights, + void* devPtrdQ, + void* devPtrdK, + void* devPtrdV, + const void* devPtrCuSeqlensKV, + const void* devPtrSeqOffsetsKV, + DType qkv_dtype, + void* workspace, + size_t* workspace_size, + cudaStream_t stream) +{ + const char* nvte_smallseq = std::getenv("NVTE_LOG_CK_CONFIG"); + if (nvte_smallseq && std::string(nvte_smallseq) == "1") { + std::cout << std::endl << "attn_bwd(ck small-seq kernel): "; + std::cout << "b: " << b << ", "; + std::cout << "h_q: " << h_q << ", "; + std::cout << "h_kv: " << h_kv << ", "; + std::cout << "max_seqlen_kv: " << max_seqlen_kv << ", "; + std::cout << "d_qk: " << d_qk << ", "; + std::cout << "d_v: " << d_v << ", "; + std::cout << "attn_scale: " << attn_scale << ", "; + std::cout << "dropout: " << dropout << ", "; + std::cout << "qkv_dtype: " + << (qkv_dtype == DType::kBFloat16 ? "BF16" : qkv_dtype == DType::kFloat16 ? 
"FP16" : "?") + << std::endl; + } + + float sqr_dk_scale = attn_scale; + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(qkv_dtype, T, + const T* Q_ptr = static_cast(devPtrQ); + const T* K_ptr = static_cast(devPtrK); + const T* V_ptr = static_cast(devPtrV); + const T* O_ptr = static_cast(devPtrO); + const T* dO_ptr = static_cast(devPtrdO); + const T* attn_ptr = static_cast(attn_weights); + T* dQ_ptr = static_cast(devPtrdQ); + T* dK_ptr = static_cast(devPtrdK); + T* dV_ptr = static_cast(devPtrdV); + T* workspace_ptr = static_cast(workspace); + const int* cu_kv = static_cast(devPtrCuSeqlensKV); + const int* cu_kv_p = static_cast(devPtrSeqOffsetsKV); + const T* dropout_mask = nullptr; + int bi = static_cast(b); + int hi = static_cast(h_q); + + switch (max_seqlen_kv) { + SMALLSEQ_DISPATCH_BWD_CASE(2) + SMALLSEQ_DISPATCH_BWD_CASE(3) + SMALLSEQ_DISPATCH_BWD_CASE(4) + SMALLSEQ_DISPATCH_BWD_CASE(5) + SMALLSEQ_DISPATCH_BWD_CASE(6) + SMALLSEQ_DISPATCH_BWD_CASE(7) + SMALLSEQ_DISPATCH_BWD_CASE(8) + SMALLSEQ_DISPATCH_BWD_CASE(9) + SMALLSEQ_DISPATCH_BWD_CASE(10) + SMALLSEQ_DISPATCH_BWD_CASE(11) + SMALLSEQ_DISPATCH_BWD_CASE(12) + SMALLSEQ_DISPATCH_BWD_CASE(13) + SMALLSEQ_DISPATCH_BWD_CASE(14) + SMALLSEQ_DISPATCH_BWD_CASE(15) + SMALLSEQ_DISPATCH_BWD_CASE(16) + default: + NVTE_ERROR("Unsupported max_seqlen_kv for small-seq: max_seqlen_kv <= 16."); + } + ); +} + +bool is_small_seq_attn_supported( + NVTEDType q_dtype, + NVTEDType kv_dtype, + NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + float dropout, + size_t num_attn_heads, + size_t num_gqa_groups, + size_t max_seqlen_q, + size_t max_seqlen_kv, + size_t head_dim_qk, + size_t head_dim_v, + int64_t window_size_left, + int64_t window_size_right) { + bool log = false; + if (const char* env_p = std::getenv("NVTE_LOG_CK_CONFIG")) { + log = (env_p != nullptr && std::string(env_p) == "1"); + } + + // CK small-seq path does not enable GQA/MQA yet + if (num_gqa_groups == 0 || num_attn_heads != num_gqa_groups) { + if (log) { + std::cout << "small-seq: GQA/MQA not supported; require num_attn_heads == num_kv_heads" + << std::endl; + } + return false; + } + + if (q_dtype != kv_dtype || + !(q_dtype == NVTEDType::kNVTEFloat16 || q_dtype == NVTEDType::kNVTEBFloat16)) { + if (log) { + std::cout << "small-seq: Q/K/V must be FP16 or BF16 and match" << std::endl; + } + return false; + } + + if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) { + if (log) { + std::cout << "small-seq: bias not supported" << std::endl; + } + return false; + } + + if (dropout != 0.f) { + if (log) { + std::cout << "small-seq: dropout not supported in kernel path" << std::endl; + } + return false; + } + + if (qkv_layout != NVTE_QKV_Layout::NVTE_THD_THD_THD) { + if (log) { + std::cout << "small-seq: layout must be NVTE_THD_THD_THD" << std::endl; + } + return false; + } + + // max_seqlen_q / max_seqlen_kv are compile-time or padded upper bounds; actual varlen + // lengths are in cu_seqlens / offsets. Do not reject here based on those maxima — callers + // must ensure runtime lengths and the launch-time max_seqlen_kv passed to fwd/bwd match + // the small-seq kernel contract. 
+ (void)max_seqlen_q; + (void)max_seqlen_kv; + + if (head_dim_qk != head_dim_v) { + if (log) { + std::cout << "small-seq: head_dim_qk and head_dim_v must match" << std::endl; + } + return false; + } + if (head_dim_qk != 128 && head_dim_qk != 256 && head_dim_qk != 512) { + if (log) { + std::cout << "small-seq: head_dim_qk must be 128, 256, or 512" << std::endl; + } + return false; + } + + bool is_causal = + (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); + if (is_causal) { + if (log) { + std::cout << "small-seq: causal mask types not supported" << std::endl; + } + return false; + } + + if (attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK) { + if (log) { + std::cout << "small-seq: only NO_MASK or PADDING_MASK supported" << std::endl; + } + return false; + } + + // Small-seq kernels never consume window bounds + // Full attention only: (-1,-1), or (-1,0) which nvte_fused_attn_small_seq_* normalizes via + // check_set_window_size (NO_MASK / PADDING_MASK). is_small_seq_attn_supported may run without + // that pass (e.g. JAX workspace sizing), so treat (-1,0) the same here. + const bool full_window = + (window_size_left == -1 && window_size_right == -1) || + (window_size_left == -1 && window_size_right == 0); + if (!full_window) { + if (log) { + std::cout << "small-seq: sliding/local window attention not supported" << std::endl; + } + return false; + } + + return true; +} + +} // namespace fused_attn_rocm +} // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.h b/transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.h new file mode 100644 index 000000000..c9e2cf134 --- /dev/null +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.h @@ -0,0 +1,98 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +/*! \file fused_attn_small_seq.h + * \brief Small-seq (varlen) attention for ROCm: seq_q=1, max_seqlen_kv<=16, THD only. + */ + +#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALL_SEQ_H_ +#define TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALL_SEQ_H_ + +#include +#include "transformer_engine/fused_attn.h" + +namespace transformer_engine { +namespace fused_attn_rocm { + +bool is_small_seq_attn_supported( + NVTEDType q_dtype, + NVTEDType kv_dtype, + NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, + float dropout, + size_t num_attn_heads, + size_t num_gqa_groups, + size_t max_seqlen_q, + size_t max_seqlen_kv, + size_t head_dim_qk, + size_t head_dim_v, + int64_t window_size_left, + int64_t window_size_right); + +/** Workspace size in bytes for small-seq backward path */ +size_t fused_attn_small_seq_bwd_workspace_size(size_t b, + size_t h_q, + size_t max_seqlen_kv, + DType dtype); + +/** Forward: Q,K,V -> O; attention weights written to attn_weights_buffer (same as output_S). + * attn_weights_buffer is also used as internal workspace (scores then overwritten by attn + * weights). 
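+ * The launch-time max_seqlen_kv must match a compiled configuration (2..16) and d_qk == d_v
+ * must be 128, 256, or 512; other values hit NVTE_ERROR in the dispatch switch.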
*/ +void fused_attn_small_seq_fwd(size_t b, + size_t h_q, + size_t h_kv, + size_t max_seqlen_kv, + size_t d_qk, + size_t d_v, + bool is_training, + float attn_scale, + float dropout, + const void* devPtrQ, + const void* devPtrK, + const void* devPtrV, + void* devPtrO, + void* attn_weights_buffer, + const void* devPtrCuSeqlensKV, + const void* devPtrSeqOffsetsKV, + const void* rng_seed, + const void* rng_offset, + DType qkv_dtype, + void* workspace, + size_t* workspace_size, + cudaStream_t stream); + +/** Backward: dO, O, attn_weights -> dQ, dK, dV. attn_weights is the buffer from forward + * (output_S). workspace must be at least fused_attn_small_seq_bwd_workspace_size. + * max_seqlen_kv is the runtime max KV length when invoked from nvte_fused_attn_small_seq_bwd. */ +void fused_attn_small_seq_bwd(size_t b, + size_t h_q, + size_t h_kv, + size_t max_seqlen_kv, + size_t d_qk, + size_t d_v, + float attn_scale, + float dropout, + const void* devPtrQ, + const void* devPtrK, + const void* devPtrV, + const void* devPtrO, + const void* devPtrdO, + const void* attn_weights, + void* devPtrdQ, + void* devPtrdK, + void* devPtrdV, + const void* devPtrCuSeqlensKV, + const void* devPtrSeqOffsetsKV, + DType qkv_dtype, + void* workspace, + size_t* workspace_size, + cudaStream_t stream); + +} // namespace fused_attn_rocm +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALL_SEQ_H_ diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 726cc4e47..0c98b5971 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -697,6 +697,43 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream); +#ifdef __HIP_PLATFORM_AMD__ +bool nvte_is_small_seq_attn_supported( + NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, + size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right); + +/*! \brief Device workspace bytes required for small-seq backward. */ +size_t nvte_fused_attn_small_seq_bwd_workspace_size(size_t batch, size_t attn_heads, + size_t max_seqlen_kv, NVTEDType dtype); + +/*! \brief ROCm small-seq forward: separate Q, K, V; same tensor roles as nvte_fused_attn_fwd. */ +void nvte_fused_attn_small_seq_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, + size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, + cudaStream_t stream); + +/*! \brief ROCm small-seq backward. 
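+ * Same tensor roles as nvte_fused_attn_bwd; S supplies the attention weights written by the
+ * small-seq forward (output_S).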
*/ +void nvte_fused_attn_small_seq_bwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, + const NVTETensor dO, const NVTETensor S, NVTETensor dP, + const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, NVTETensor dV, + NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, + cudaStream_t stream); +#endif // __HIP_PLATFORM_AMD__ + /*! \brief Update the RNG state with the seed and calculated offset. * * \warning This API is **experimental** and subject to change. diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 21db296c3..45d2319f6 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -1257,6 +1257,130 @@ def _fused_attn_bwd_rule( _fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule) +@partial( + jax.custom_vjp, + nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12), +) +def _fused_attn_small_seq( + qkv: Tuple[jnp.ndarray, ...], + bias: Optional[jnp.ndarray], + sequence_descriptor: SequenceDescriptor, + seed: Optional[jnp.ndarray], + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + qkv_layout: QKVLayout, + scaling_factor: float, + dropout_probability: float, + is_training: bool, + max_segments_per_seq: int, + window_size: Optional[Tuple[int, int]], + context_checkpoint_name: str = "context", +): + output, _ = _fused_attn_small_seq_fwd_rule( + qkv, + bias, + sequence_descriptor, + seed, + attn_bias_type, + attn_mask_type, + qkv_layout, + scaling_factor, + dropout_probability, + is_training, + max_segments_per_seq, + window_size, + context_checkpoint_name=context_checkpoint_name, + ) + return output + + +def _fused_attn_small_seq_fwd_rule( + qkv, + bias, + sequence_descriptor, + seed, + attn_bias_type, + attn_mask_type, + qkv_layout, + scaling_factor, + dropout_probability, + is_training, + max_segments_per_seq, + window_size, + context_checkpoint_name, +): + output, softmax_aux, rng_state = tex.fused_attn_small_seq_fwd( + qkv, + bias, + sequence_descriptor, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + max_segments_per_seq=max_segments_per_seq, + window_size=window_size, + ) + output = checkpoint_name(output, context_checkpoint_name) + softmax_aux = checkpoint_name(softmax_aux, context_checkpoint_name) + rng_state = checkpoint_name(rng_state, context_checkpoint_name) + return output, ( + qkv, + bias, + sequence_descriptor, + softmax_aux, + rng_state, + output, + ) + + +def _fused_attn_small_seq_bwd_rule( + attn_bias_type, + attn_mask_type, + qkv_layout, + scaling_factor, + dropout_probability, + is_training, + max_segments_per_seq, + window_size, + context_checkpoint_name, + ctx, + dz, +): + del context_checkpoint_name + (qkv, bias, sequence_descriptor, softmax_aux, rng_state, output) = ctx + grad_qkv, grad_bias = tex.fused_attn_small_seq_bwd( + qkv, + bias, + softmax_aux, + rng_state, + output, + dz, + sequence_descriptor, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + 
scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + max_segments_per_seq=max_segments_per_seq, + window_size=window_size, + ) + if attn_bias_type == AttnBiasType.NO_BIAS: + grad_bias = None + return ( + grad_qkv, + grad_bias, + None, + None, + ) + + +_fused_attn_small_seq.defvjp(_fused_attn_small_seq_fwd_rule, _fused_attn_small_seq_bwd_rule) + + def fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], @@ -1399,3 +1523,46 @@ def fused_attn( stripe_size=stripe_size, ) return output + + +def fused_attn_small_seq( + qkv: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], + bias: Optional[jnp.ndarray], + sequence_descriptor: SequenceDescriptor, + seed: Optional[jnp.ndarray], + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + qkv_layout: QKVLayout, + scaling_factor: float, + dropout_probability: float, + is_training: bool, + max_segments_per_seq: int = 1, + window_size: Optional[Tuple[int, int]] = None, + context_checkpoint_name: str = "context", +): + """ + ROCm small-sequence varlen cross-attention (explicit backend API). + + This entry point calls the dedicated small-seq kernels (not generic fused-attn backend + selection). Intended for THD layouts with compile-time head dimensions 128 / 256 / 512 (``d_qk == d_v``), FP16/BF16, + and ``AttnBiasType.NO_BIAS``. + Context parallelism is not supported. + """ + if sequence_descriptor is None or isinstance(sequence_descriptor, jnp.ndarray): + raise ValueError("fused_attn_small_seq requires a SequenceDescriptor.") + + return _fused_attn_small_seq( + qkv, + bias, + sequence_descriptor, + seed, + attn_bias_type, + attn_mask_type, + qkv_layout, + scaling_factor, + dropout_probability, + is_training, + max_segments_per_seq, + window_size, + context_checkpoint_name=context_checkpoint_name, + ) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 4d669bc46..1a2a4b5de 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -59,6 +59,8 @@ "FusedAttnHelper", "fused_attn_fwd", "fused_attn_bwd", + "fused_attn_small_seq_fwd", + "fused_attn_small_seq_bwd", ] @@ -373,11 +375,13 @@ def abstract( elif backend == NVTE_Fused_Attn_Backend.NVTE_CK: if config.qkv_layout.is_thd(): softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1) + softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1) - softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) + softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: raise ValueError(f"Unsupported {backend=}") + softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype) # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with @@ -784,9 +788,800 @@ def abstract( *, config, ): - """ - Fused attention bwd abstract - """ + """ + Fused attention bwd abstract + """ + del softmax_aux_aval, rng_state_aval, output_aval + + q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) + k_dtype = dtypes.canonicalize_dtype(k_aval.dtype) + v_dtype = dtypes.canonicalize_dtype(v_aval.dtype) + bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) + doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype) + assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype + assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype + + ( + batch_shape, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + 
num_gqa_groups, + qk_head_dim, + v_head_dim, + ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) + + if config.attn_bias_type == AttnBiasType.NO_BIAS: + bias_batch = bias_heads = 0 + else: + *bias_batch_shape, bias_heads, _, _ = bias_aval.shape + bias_batch = reduce(operator.mul, bias_batch_shape) + + deterministic = not FusedAttnHelper.is_non_deterministic_allowed() + + input_batch = reduce(operator.mul, batch_shape) + wkspace_shape, wkspace_dtype = transformer_engine_jax.get_fused_attn_bwd_workspace_sizes( + input_batch, + bias_batch, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + bias_heads, + qk_head_dim, + v_head_dim, + config.scaling_factor, + config.dropout_probability, + config.attn_bias_type.value, + config.attn_mask_type.value, + config.softmax_type.value, + config.qkv_layout.value, + jax_dtype_to_te_dtype(q_aval.dtype), + config.is_training, + deterministic, + config.max_segments_per_seq, + config.window_size[0], + config.window_size[1], + ) + + dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) + dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype) + dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype) + dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) + wkspace_aval = q_aval.update( + shape=wkspace_shape, dtype=te_dtype_to_jax_dtype(wkspace_dtype) + ) + + # Validate incoming softmax_offset shape and dtype + assert ( + softmax_offset_aval.dtype == jnp.float32 + ), f"Incorrect softmax_offset dtype: {softmax_offset_aval.dtype}, expected: {jnp.float32}" + if config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX: + assert softmax_offset_aval.shape == (1, attn_heads, 1, 1), ( + f"Incorrect softmax_offset shape for {config.softmax_type}:" + f" {softmax_offset_aval.shape}, expected: (1, {attn_heads}, 1, 1)" + ) + else: + assert softmax_offset_aval.shape == (0,), ( + f"Incorrect softmax_offset shape for {config.softmax_type}:" + f" {softmax_offset_aval.shape}, expected: (0,)" + ) + + if config.softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX: + dsoftmax_offset_aval = q_aval.update( + shape=softmax_offset_aval.shape, dtype=softmax_offset_aval.dtype + ) + else: + dsoftmax_offset_aval = q_aval.update(shape=(1, attn_heads, 1, 1), dtype=jnp.float32) + + return dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval, wkspace_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + Fused attention fwd outer primitive abstract + """ + dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval, _ = ( + FusedAttnBwdPrimitive.abstract(*args, **kwargs) + ) + return dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval + + @staticmethod + def lowering( + ctx, + q, + k, + v, + bias, + softmax_offset, + softmax_aux, + rng_state, + output, + doutput, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, + *, + config, + ): + """ + Fused attention bwd lowering rules + """ + q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in + + ( + batch_shape, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + qk_head_dim, + v_head_dim, + ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) + + input_batch = reduce(operator.mul, batch_shape) + + if config.attn_bias_type == AttnBiasType.NO_BIAS: + bias_batch = bias_heads = 0 + else: + *bias_batch_shape, bias_heads, _, _ = bias_aval.shape + bias_batch = reduce(operator.mul, bias_batch_shape) + + if config.cp_striped_window_size is not 
None: + window_size_left = config.cp_striped_window_size[0] + window_size_right = config.cp_striped_window_size[1] + else: + window_size_left = config.window_size[0] + window_size_right = config.window_size[1] + + return ffi.ffi_lowering(FusedAttnBwdPrimitive.name)( + ctx, + q, + k, + v, + bias, + softmax_offset, + softmax_aux, + rng_state, + output, + doutput, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, # ffi_lowering needs number of parameters meets primitive.lowering + input_batch=input_batch, + bias_batch=bias_batch, + q_max_seqlen=q_max_seqlen, + kv_max_seqlen=kv_max_seqlen, + attn_heads=attn_heads, + num_gqa_groups=num_gqa_groups, + bias_heads=bias_heads, + qk_head_dim=qk_head_dim, + v_head_dim=v_head_dim, + max_segments_per_seq=config.max_segments_per_seq, + scaling_factor=float(config.scaling_factor), + dropout_probability=float(config.dropout_probability), + bias_type=int(config.attn_bias_type.value), + mask_type=int(config.attn_mask_type.value), + qkv_layout=int(config.qkv_layout.value), + is_training=config.is_training, + deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), + window_size_left=window_size_left, + window_size_right=window_size_right, + softmax_type=int(config.softmax_type.value), + ) + + @staticmethod + def impl( + q, + k, + v, + bias, + softmax_offset, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + config, + ): + assert FusedAttnBwdPrimitive.inner_primitive is not None + + sequence_descriptor = SequenceDescriptor( + seqlens=(q_seqlen, kv_seqlen), + seq_offsets=(q_seq_offsets, k_seq_offsets), + segment_ids=(_q_segment_ids, _kv_segment_ids), + segment_pos=(_q_segment_pos, _kv_segment_pos), + ) + + (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = ( + sequence_descriptor.get_seqlens_and_offsets( + config.attn_mask_type, + config.qkv_layout, + config.window_size, + config.max_segments_per_seq, + ) + ) + + if config.qkv_layout.is_thd(): + + def _fix_len_take(x, condition, fill_value=-1): + x_shape = x.shape + x = x.flatten() + size = x.size + indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0] + # TODO(rewang): try indices_are_sorted + y = jnp.take(x, indices, fill_value=fill_value) + return jnp.reshape(y, x_shape) + + def convert_to_2d(offsets, batch, max_seqlen): + offsets_2d = jnp.where( + offsets >= 0, + offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis], + offsets, + ) + return offsets_2d + + batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval( + q, k, v, config.qkv_layout + ) + assert len(batch) == 1 + kv_batch = q_batch = batch[0] + + # Gather valid q_seqlen, which is greater than 0 + # cuDNN version < 9.3.0: + # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]] + # cuDNN version >= 9.3.0, which supports act_seqlen = 0 + # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]] + if get_cudnn_version() >= (9, 3, 0): + fill_value = 0 + else: + fill_value = -1 + q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value) + kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value) + + # Flatten the offset calculation + # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]] + q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen) + k_seq_offsets = 
convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen) + + # Gather valid q_seq_offsets, which is greater and equal to 0 + # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]] + # And set the unused position to max size (batch * max_seqlen) + # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]] + q_seq_offsets = _fix_len_take( + q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen + ) + k_seq_offsets = _fix_len_take( + k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen + ) + + q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten()) + kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten()) + + dq, dk, dv, dbias, dsoftmax_offset, _ = FusedAttnBwdPrimitive.inner_primitive.bind( + q, + k, + v, + bias, + softmax_offset, + softmax_aux, + rng_state, + output, + doutput, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + config=config, + ) + return dq, dk, dv, dbias, dsoftmax_offset + + @staticmethod + def batcher(batched_args, batch_dims, *, config): + check_valid_batch_dims(batch_dims) + assert FusedAttnBwdPrimitive.outer_primitive is not None + q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim, *_ = batch_dims + + out_bdims = q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim + return ( + FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config), + out_bdims, + ) + + @staticmethod + def infer_sharding_from_operands(config, mesh, arg_infos, result_infos): + del config, result_infos + q_spec = get_padded_spec(arg_infos[0]) + k_spec = get_padded_spec(arg_infos[1]) + v_spec = get_padded_spec(arg_infos[2]) + bias_spec = get_padded_spec(arg_infos[3]) + softmax_offset_spec = get_padded_spec(arg_infos[4]) + dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) + dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) + dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) + dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec)) + return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding, dsoftmax_offset_sharding) + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + del result_infos + q_spec = get_padded_spec(arg_infos[0]) + k_spec = get_padded_spec(arg_infos[1]) + v_spec = get_padded_spec(arg_infos[2]) + bias_spec = get_padded_spec(arg_infos[3]) + softmax_offset_spec = get_padded_spec(arg_infos[4]) + dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) + dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) + dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) + dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec)) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + arg_shardings[-1] = arg_shardings[-3] + arg_shardings[-2] = arg_shardings[-4] + arg_shardings = tuple(arg_shardings) + out_shardings = ( + dq_sharding, + dk_sharding, + dv_sharding, + dbias_sharding, + dsoftmax_offset_sharding, + ) + + def sharded_impl( + q, + k, + v, + bias, + softmax_offset, + softmax_aux, + rng_state, + output, + doutput, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + ): + local_dq, local_dk, local_dv, local_dbias, local_dsoftmax_offset = ( + FusedAttnBwdPrimitive.impl( + q, + k, + v, + bias, + 
softmax_offset, + softmax_aux, + rng_state, + output, + doutput, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + config=config, + ) + ) + global_dbias = local_dbias + if config.attn_bias_type is not AttnBiasType.NO_BIAS: + global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh) + + global_dsoftmax_offset = local_dsoftmax_offset + if config.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX: + global_dsoftmax_offset = all_reduce_sum_along_dp_fsdp(local_dsoftmax_offset, mesh) + + return local_dq, local_dk, local_dv, global_dbias, global_dsoftmax_offset + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule(config, mesh, value_types, result_types): + if version.parse(jax.__version__) < version.parse("0.5.0"): + raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") + del config, mesh + # Keep in sync with `infer_sharding_from_operands`. + input_spec = tuple((f"…{x}",) for x in range(len(value_types))) + output_spec = tuple((f"…{x}",) for x in range(len(result_types))) + return SdyShardingRule(input_spec, output_spec) + + +register_primitive(FusedAttnBwdPrimitive) + + +class SmallSeqAttnFwdPrimitive(BasePrimitive): + """ROCm small-sequence cross-attention forward""" + + name = "te_small_seq_attn_forward_ffi" + multiple_results = True + impl_static_args = (13,) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + q_aval, + k_aval, + v_aval, + bias_aval, + seed_aval, + q_seqlen_or_cu_seqlen_aval, + kv_seqlen_or_cu_seqlen_aval, + _q_seq_offsets, + _k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + *, + config: _FusedAttnConfig, + ): + if not is_hip_extension(): + raise ValueError( + "Small-seq attention requires Transformer Engine built for ROCm (HIP extension)." + ) + if not hasattr(transformer_engine_jax, "get_small_seq_attn_fwd_workspace_sizes"): + raise ValueError( + "Small-seq workspace helpers are unavailable; use a ROCm build of transformer_engine_jax." 
+ ) + if config.qkv_layout != QKVLayout.THD_THD_THD: + raise ValueError(f"Small-seq requires QKVLayout.THD_THD_THD, got {config.qkv_layout}") + if config.attn_bias_type != AttnBiasType.NO_BIAS: + raise ValueError("Small-seq does not support attention bias.") + if config.dropout_probability != 0: + raise ValueError("Small-seq kernel path does not support dropout.") + + q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) + k_dtype = dtypes.canonicalize_dtype(k_aval.dtype) + v_dtype = dtypes.canonicalize_dtype(v_aval.dtype) + bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) + assert q_dtype == k_dtype == v_dtype == bias_dtype + assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype + + ( + batch_shape, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + q_head_dim, + v_head_dim, + ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) + + if q_head_dim != v_head_dim or q_head_dim not in (128, 256, 512): + raise ValueError( + "Small-seq requires matching q/v head dims in {128, 256, 512}, " + f"got qk={q_head_dim} v={v_head_dim}" + ) + + output_shape = (*batch_shape, q_max_seqlen, attn_heads, v_head_dim) + out_aval = q_aval.update(shape=output_shape, dtype=q_dtype) + + kv_eff = min(kv_max_seqlen, 16) + softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_eff) + softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=q_dtype) + + checker = _FusedAttnRNGStateChecker() + seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype) + assert seed_dtype == checker.rng_state_dtype + rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) + rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype) + + bias_batch = bias_heads = 0 + input_batch = reduce(operator.mul, batch_shape) + wkspace_info = transformer_engine_jax.get_small_seq_attn_fwd_workspace_sizes( + input_batch, + bias_batch, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + bias_heads, + q_head_dim, + v_head_dim, + config.scaling_factor, + config.dropout_probability, + config.attn_bias_type.value, + config.attn_mask_type.value, + config.qkv_layout.value, + jax_dtype_to_te_dtype(q_aval.dtype), + config.is_training, + config.max_segments_per_seq, + config.window_size[0], + config.window_size[1], + ) + wkspace_aval = q_aval.update( + shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) + ) + + return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + out_aval, softmax_aux_aval, rng_state_aval, _ = SmallSeqAttnFwdPrimitive.abstract( + *args, **kwargs + ) + return out_aval, softmax_aux_aval, rng_state_aval + + @staticmethod + def lowering( + ctx, + q, + k, + v, + bias, + seed, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + *, + config: _FusedAttnConfig, + ): + q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in + ( + batch_shape, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + q_head_dim, + v_head_dim, + ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) + input_batch = reduce(operator.mul, batch_shape) + bias_batch = bias_heads = 0 + window_size_left = config.window_size[0] + window_size_right = config.window_size[1] + + return ffi.ffi_lowering(SmallSeqAttnFwdPrimitive.name)( + ctx, + q, + k, + v, + bias, + seed, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + 
_kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + input_batch=input_batch, + bias_batch=bias_batch, + q_max_seqlen=q_max_seqlen, + kv_max_seqlen=kv_max_seqlen, + attn_heads=attn_heads, + num_gqa_groups=num_gqa_groups, + bias_heads=bias_heads, + qk_head_dim=q_head_dim, + v_head_dim=v_head_dim, + max_segments_per_seq=config.max_segments_per_seq, + scaling_factor=float(config.scaling_factor), + dropout_probability=float(config.dropout_probability), + bias_type=int(config.attn_bias_type.value), + mask_type=int(config.attn_mask_type.value), + qkv_layout=int(config.qkv_layout.value), + is_training=config.is_training, + deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), + window_size_left=window_size_left, + window_size_right=window_size_right, + ) + + @staticmethod + def impl( + q, + k, + v, + bias, + seed, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + config: _FusedAttnConfig, + ): + assert SmallSeqAttnFwdPrimitive.inner_primitive is not None + sequence_descriptor = SequenceDescriptor( + seqlens=(q_seqlen, kv_seqlen), + seq_offsets=(q_seq_offsets, k_seq_offsets), + segment_ids=(_q_segment_ids, _kv_segment_ids), + segment_pos=(_q_segment_pos, _kv_segment_pos), + ) + (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = ( + sequence_descriptor.get_seqlens_and_offsets( + config.attn_mask_type, + config.qkv_layout, + config.window_size, + config.max_segments_per_seq, + ) + ) + if config.qkv_layout.is_thd(): + + def _fix_len_take(x, condition, fill_value=-1): + x_shape = x.shape + x = x.flatten() + size = x.size + indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0] + y = jnp.take(x, indices, fill_value=fill_value) + return jnp.reshape(y, x_shape) + + def convert_to_2d(offsets, batch, max_seqlen): + offsets_2d = jnp.where( + offsets >= 0, + offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis], + offsets, + ) + return offsets_2d + + batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval( + q, k, v, config.qkv_layout + ) + assert len(batch) == 1 + kv_batch = q_batch = batch[0] + if get_cudnn_version() >= (9, 3, 0): + fill_value = 0 + else: + fill_value = -1 + q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value) + kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value) + q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen) + k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen) + q_seq_offsets = _fix_len_take( + q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen + ) + k_seq_offsets = _fix_len_take( + k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen + ) + + q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten()) + kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten()) + + output, softmax_aux, rng_state, _ = SmallSeqAttnFwdPrimitive.inner_primitive.bind( + q, + k, + v, + bias, + seed, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + config=config, + ) + return output, softmax_aux, rng_state + + @staticmethod + def batcher(batched_args, batch_dims, *, config): + check_valid_batch_dims(batch_dims) + assert SmallSeqAttnFwdPrimitive.outer_primitive is not None + q_bdim, _, _, _, seed_bdim, *_ = batch_dims + out_bdims = q_bdim, q_bdim, seed_bdim + return ( + SmallSeqAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config), + out_bdims, + ) + + 
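+    # Layout note for the sharding rules below: abstract() shapes softmax_aux as
+    # (*batch, attn_heads, q_max_seqlen, min(kv_max_seqlen, 16)), i.e. with Q's head and
+    # sequence axes swapped, so the aux PartitionSpec is built from q_spec with those two
+    # axes exchanged and the trailing KV axis unsharded.
+    # e.g. q_spec = (dp, None, tp, None)  ->  aux spec = (dp, tp, None, None)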
@staticmethod + def infer_sharding_from_operands(config, mesh, arg_infos, result_infos): + del config, result_infos + q_spec = get_padded_spec(arg_infos[0]) + out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) + ) + rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) + return (out_sharding, softmax_aux_sharding, rng_state_sharding) + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + out_sharding = result_infos[0].sharding + softmax_aux_sharding = result_infos[1].sharding + rng_state_sharding = seed_sharding = NamedSharding( + mesh, PartitionSpec(get_all_mesh_axes(), None) + ) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + arg_shardings[4] = seed_sharding + arg_shardings[-1] = arg_shardings[-3] + arg_shardings[-2] = arg_shardings[-4] + arg_shardings = tuple(arg_shardings) + out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) + impl = partial(SmallSeqAttnFwdPrimitive.impl, config=config) + return mesh, impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule(config, mesh, value_types, result_types): + if version.parse(jax.__version__) < version.parse("0.5.0"): + raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") + del mesh, result_types + input_spec = [(f"…{x}",) for x in range(len(value_types))] + rng_sharding = (f"…{len(value_types)}",) + input_spec[0] = ("…0", "seqlen", "head", "hidden") + out_sharding = ("…0", "seqlen", "head", "hidden") + softmax_aux_sharding = ("…0", "head", "seqlen", "i") + return SdyShardingRule( + tuple(input_spec), (out_sharding, softmax_aux_sharding, rng_sharding) + ) + + +class SmallSeqAttnBwdPrimitive(BasePrimitive): + """ROCm small-sequence cross-attention backward.""" + + name = "te_small_seq_attn_backward_ffi" + multiple_results = True + impl_static_args = (16,) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + q_aval, + k_aval, + v_aval, + bias_aval, + softmax_aux_aval, + rng_state_aval, + output_aval, + doutput_aval, + q_seqlen_or_cu_seqlen_aval, + kv_seqlen_or_cu_seqlen_aval, + _q_seq_offsets, + _k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + *, + config, + ): + if not is_hip_extension(): + raise ValueError( + "Small-seq attention requires Transformer Engine built for ROCm (HIP extension)." + ) + if not hasattr(transformer_engine_jax, "get_small_seq_attn_bwd_workspace_sizes"): + raise ValueError( + "Small-seq workspace helpers are unavailable; use a ROCm build of transformer_engine_jax." 
+ ) del softmax_aux_aval, rng_state_aval, output_aval q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) @@ -798,7 +1593,7 @@ def abstract( assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype ( - batch_shape, + _, # logical batch_shape; small-seq NVTE batch is cu_seqlens_q.shape[0]-1 q_max_seqlen, kv_max_seqlen, attn_heads, @@ -807,17 +1602,14 @@ def abstract( v_head_dim, ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) - if config.attn_bias_type == AttnBiasType.NO_BIAS: - bias_batch = bias_heads = 0 - else: - *bias_batch_shape, bias_heads, _, _ = bias_aval.shape - bias_batch = reduce(operator.mul, bias_batch_shape) - + bias_batch = bias_heads = 0 deterministic = not FusedAttnHelper.is_non_deterministic_allowed() - - input_batch = reduce(operator.mul, batch_shape) - wkspace_shape, wkspace_dtype = transformer_engine_jax.get_fused_attn_bwd_workspace_sizes( - input_batch, + # NVTE uses b = cu_seqlens_q.shape[0] - 1 (one packed segment per slot), not + # reduce(batch_shape). E.g. seqpack with max_seqlen_q>1 yields cu length + # batch*segments+1 while Q still has leading logical batch only. + small_seq_workspace_batch = q_seqlen_or_cu_seqlen_aval.shape[0] - 1 + wkspace_shape, wkspace_dtype = transformer_engine_jax.get_small_seq_attn_bwd_workspace_sizes( + small_seq_workspace_batch, bias_batch, q_max_seqlen, kv_max_seqlen, @@ -830,7 +1622,6 @@ def abstract( config.dropout_probability, config.attn_bias_type.value, config.attn_mask_type.value, - config.softmax_type.value, config.qkv_layout.value, jax_dtype_to_te_dtype(q_aval.dtype), config.is_training, @@ -847,40 +1638,14 @@ def abstract( wkspace_aval = q_aval.update( shape=wkspace_shape, dtype=te_dtype_to_jax_dtype(wkspace_dtype) ) - - # Validate incoming softmax_offset shape and dtype - assert ( - softmax_offset_aval.dtype == jnp.float32 - ), f"Incorrect softmax_offset dtype: {softmax_offset_aval.dtype}, expected: {jnp.float32}" - if config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX: - assert softmax_offset_aval.shape == (1, attn_heads, 1, 1), ( - f"Incorrect softmax_offset shape for {config.softmax_type}:" - f" {softmax_offset_aval.shape}, expected: (1, {attn_heads}, 1, 1)" - ) - else: - assert softmax_offset_aval.shape == (0,), ( - f"Incorrect softmax_offset shape for {config.softmax_type}:" - f" {softmax_offset_aval.shape}, expected: (0,)" - ) - - if config.softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX: - dsoftmax_offset_aval = q_aval.update( - shape=softmax_offset_aval.shape, dtype=softmax_offset_aval.dtype - ) - else: - dsoftmax_offset_aval = q_aval.update(shape=(1, attn_heads, 1, 1), dtype=jnp.float32) - - return dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval, wkspace_aval + return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval @staticmethod def outer_abstract(*args, **kwargs): - """ - Fused attention fwd outer primitive abstract - """ - dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval, _ = ( - FusedAttnBwdPrimitive.abstract(*args, **kwargs) + dq_aval, dk_aval, dv_aval, dbias_aval, _ = SmallSeqAttnBwdPrimitive.abstract( + *args, **kwargs ) - return dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval + return dq_aval, dk_aval, dv_aval, dbias_aval @staticmethod def lowering( @@ -889,7 +1654,6 @@ def lowering( k, v, bias, - softmax_offset, softmax_aux, rng_state, output, @@ -905,11 +1669,7 @@ def lowering( *, config, ): - """ - Fused attention bwd lowering rules - """ q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in - ( batch_shape, 
q_max_seqlen, @@ -919,29 +1679,17 @@ def lowering( qk_head_dim, v_head_dim, ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) - input_batch = reduce(operator.mul, batch_shape) + bias_batch = bias_heads = 0 + window_size_left = config.window_size[0] + window_size_right = config.window_size[1] - if config.attn_bias_type == AttnBiasType.NO_BIAS: - bias_batch = bias_heads = 0 - else: - *bias_batch_shape, bias_heads, _, _ = bias_aval.shape - bias_batch = reduce(operator.mul, bias_batch_shape) - - if config.cp_striped_window_size is not None: - window_size_left = config.cp_striped_window_size[0] - window_size_right = config.cp_striped_window_size[1] - else: - window_size_left = config.window_size[0] - window_size_right = config.window_size[1] - - return ffi.ffi_lowering(FusedAttnBwdPrimitive.name)( + return ffi.ffi_lowering(SmallSeqAttnBwdPrimitive.name)( ctx, q, k, v, bias, - softmax_offset, softmax_aux, rng_state, output, @@ -953,7 +1701,7 @@ def lowering( q_segment_ids, kv_segment_ids, q_segment_pos, - kv_segment_pos, # ffi_lowering needs number of parameters meets primitive.lowering + kv_segment_pos, input_batch=input_batch, bias_batch=bias_batch, q_max_seqlen=q_max_seqlen, @@ -973,7 +1721,6 @@ def lowering( deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), window_size_left=window_size_left, window_size_right=window_size_right, - softmax_type=int(config.softmax_type.value), ) @staticmethod @@ -982,7 +1729,6 @@ def impl( k, v, bias, - softmax_offset, softmax_aux, rng_state, output, @@ -997,15 +1743,13 @@ def impl( _kv_segment_pos, config, ): - assert FusedAttnBwdPrimitive.inner_primitive is not None - + assert SmallSeqAttnBwdPrimitive.inner_primitive is not None sequence_descriptor = SequenceDescriptor( seqlens=(q_seqlen, kv_seqlen), seq_offsets=(q_seq_offsets, k_seq_offsets), segment_ids=(_q_segment_ids, _kv_segment_ids), segment_pos=(_q_segment_pos, _kv_segment_pos), ) - (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = ( sequence_descriptor.get_seqlens_and_offsets( config.attn_mask_type, @@ -1014,7 +1758,6 @@ def impl( config.max_segments_per_seq, ) ) - if config.qkv_layout.is_thd(): def _fix_len_take(x, condition, fill_value=-1): @@ -1022,7 +1765,6 @@ def _fix_len_take(x, condition, fill_value=-1): x = x.flatten() size = x.size indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0] - # TODO(rewang): try indices_are_sorted y = jnp.take(x, indices, fill_value=fill_value) return jnp.reshape(y, x_shape) @@ -1039,28 +1781,14 @@ def convert_to_2d(offsets, batch, max_seqlen): ) assert len(batch) == 1 kv_batch = q_batch = batch[0] - - # Gather valid q_seqlen, which is greater than 0 - # cuDNN version < 9.3.0: - # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]] - # cuDNN version >= 9.3.0, which supports act_seqlen = 0 - # [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]] if get_cudnn_version() >= (9, 3, 0): fill_value = 0 else: fill_value = -1 q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value) kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value) - - # Flatten the offset calculation - # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]] q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen) k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen) - - # Gather valid q_seq_offsets, which is greater and equal to 0 - # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 
13, -1, -1]] - # And set the unused position to max size (batch * max_seqlen) - # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]] q_seq_offsets = _fix_len_take( q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen ) @@ -1071,12 +1799,11 @@ def convert_to_2d(offsets, batch, max_seqlen): q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten()) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten()) - dq, dk, dv, dbias, dsoftmax_offset, _ = FusedAttnBwdPrimitive.inner_primitive.bind( + dq, dk, dv, dbias, _ = SmallSeqAttnBwdPrimitive.inner_primitive.bind( q, k, v, bias, - softmax_offset, softmax_aux, rng_state, output, @@ -1091,17 +1818,16 @@ def convert_to_2d(offsets, batch, max_seqlen): _kv_segment_pos, config=config, ) - return dq, dk, dv, dbias, dsoftmax_offset + return dq, dk, dv, dbias @staticmethod def batcher(batched_args, batch_dims, *, config): check_valid_batch_dims(batch_dims) - assert FusedAttnBwdPrimitive.outer_primitive is not None - q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim, *_ = batch_dims - - out_bdims = q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim + assert SmallSeqAttnBwdPrimitive.outer_primitive is not None + q_bdim, k_bdim, v_bdim, *_ = batch_dims + out_bdims = q_bdim, k_bdim, v_bdim, q_bdim return ( - FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config), + SmallSeqAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config), out_bdims, ) @@ -1112,13 +1838,11 @@ def infer_sharding_from_operands(config, mesh, arg_infos, result_infos): k_spec = get_padded_spec(arg_infos[1]) v_spec = get_padded_spec(arg_infos[2]) bias_spec = get_padded_spec(arg_infos[3]) - softmax_offset_spec = get_padded_spec(arg_infos[4]) dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec)) - return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding, dsoftmax_offset_sharding) + return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) @staticmethod def partition(config, mesh, arg_infos, result_infos): @@ -1127,30 +1851,21 @@ def partition(config, mesh, arg_infos, result_infos): k_spec = get_padded_spec(arg_infos[1]) v_spec = get_padded_spec(arg_infos[2]) bias_spec = get_padded_spec(arg_infos[3]) - softmax_offset_spec = get_padded_spec(arg_infos[4]) dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec)) arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings[-1] = arg_shardings[-3] arg_shardings[-2] = arg_shardings[-4] arg_shardings = tuple(arg_shardings) - out_shardings = ( - dq_sharding, - dk_sharding, - dv_sharding, - dbias_sharding, - dsoftmax_offset_sharding, - ) + out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) def sharded_impl( q, k, v, bias, - softmax_offset, softmax_aux, rng_state, output, @@ -1164,37 +1879,25 @@ def sharded_impl( _q_segment_pos, _kv_segment_pos, ): - local_dq, local_dk, local_dv, local_dbias, local_dsoftmax_offset = ( - FusedAttnBwdPrimitive.impl( - q, - k, - v, - bias, - softmax_offset, - softmax_aux, - rng_state, - 
output, - doutput, - q_cu_seqlen, - kv_cu_seqlen, - q_seq_offsets, - k_seq_offsets, - _q_segment_ids, - _kv_segment_ids, - _q_segment_pos, - _kv_segment_pos, - config=config, - ) + return SmallSeqAttnBwdPrimitive.impl( + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + config=config, ) - global_dbias = local_dbias - if config.attn_bias_type is not AttnBiasType.NO_BIAS: - global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh) - - global_dsoftmax_offset = local_dsoftmax_offset - if config.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX: - global_dsoftmax_offset = all_reduce_sum_along_dp_fsdp(local_dsoftmax_offset, mesh) - - return local_dq, local_dk, local_dv, global_dbias, global_dsoftmax_offset return mesh, sharded_impl, out_shardings, arg_shardings @@ -1203,13 +1906,13 @@ def shardy_sharding_rule(config, mesh, value_types, result_types): if version.parse(jax.__version__) < version.parse("0.5.0"): raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") del config, mesh - # Keep in sync with `infer_sharding_from_operands`. input_spec = tuple((f"…{x}",) for x in range(len(value_types))) output_spec = tuple((f"…{x}",) for x in range(len(result_types))) return SdyShardingRule(input_spec, output_spec) -register_primitive(FusedAttnBwdPrimitive) +register_primitive(SmallSeqAttnFwdPrimitive) +register_primitive(SmallSeqAttnBwdPrimitive) def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contiguous: bool): @@ -3641,3 +4344,127 @@ def fused_attn_bwd( config=fused_config, ) return tuple(qkv_grads[: len(qkv)]), bias_grad, softmax_offset_grad + + +def fused_attn_small_seq_fwd( + qkv: Tuple[jnp.ndarray, ...], + bias: Optional[jnp.ndarray], + sequence_descriptor: SequenceDescriptor, + seed: Optional[jnp.ndarray], + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + qkv_layout: QKVLayout, + scaling_factor: float, + dropout_probability: float, + is_training: bool, + max_segments_per_seq: int, + window_size: Optional[Tuple[int, int]] = None, + softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX, + stripe_size: int | None = None, +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """ + Forward pass for ROCm small-sequence cross-attention (explicit TE API). + """ + if not is_hip_extension(): + raise RuntimeError( + "fused_attn_small_seq_fwd requires Transformer Engine built for ROCm (HIP)." 
+        )
+    if qkv_layout != QKVLayout.THD_THD_THD:
+        raise ValueError(
+            f"fused_attn_small_seq_fwd requires QKVLayout.THD_THD_THD, got {qkv_layout}"
+        )
+    if len(qkv) != 3:
+        raise ValueError(f"fused_attn_small_seq_fwd expects qkv=(q, k, v), got len={len(qkv)}")
+
+    seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability, is_training)
+    if attn_bias_type == AttnBiasType.NO_BIAS:
+        assert bias is None
+        bias = jnp.zeros(0, dtype=qkv[0].dtype)
+    else:
+        raise ValueError("fused_attn_small_seq_fwd only supports AttnBiasType.NO_BIAS")
+
+    fused_config = _FusedAttnConfig(
+        attn_bias_type=attn_bias_type,
+        attn_mask_type=attn_mask_type,
+        softmax_type=softmax_type,
+        qkv_layout=qkv_layout,
+        scaling_factor=scaling_factor,
+        dropout_probability=dropout_probability,
+        is_training=is_training,
+        max_segments_per_seq=max_segments_per_seq,
+        window_size=(-1, -1) if window_size is None else window_size,
+        context_parallel_load_balanced=False,
+        cp_axis="",
+        cp_striped_window_size=None,
+        stripe_size=stripe_size,
+    )
+
+    seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
+    output, softmax_aux, rng_state = SmallSeqAttnFwdPrimitive.outer_primitive.bind(
+        *qkv,
+        bias,
+        seed,
+        *seq_desc_flatten,
+        config=fused_config,
+    )
+    rng_state = with_sharding_constraint(rng_state, PartitionSpec(get_all_mesh_axes(), None))
+    return (output, softmax_aux, rng_state)
+
+
+def fused_attn_small_seq_bwd(
+    qkv: Tuple[jnp.ndarray, ...],
+    bias: Optional[jnp.ndarray],
+    softmax_aux: jnp.ndarray,
+    rng_state: jnp.ndarray,
+    output: jnp.ndarray,
+    doutput: jnp.ndarray,
+    sequence_descriptor: SequenceDescriptor,
+    attn_bias_type: AttnBiasType,
+    attn_mask_type: AttnMaskType,
+    qkv_layout: QKVLayout,
+    scaling_factor: float,
+    dropout_probability: float,
+    is_training: bool,
+    max_segments_per_seq: int,
+    window_size: Optional[Tuple[int, int]] = None,
+    softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX,
+    stripe_size: int | None = None,
+):
+    if not is_hip_extension():
+        raise RuntimeError(
+            "fused_attn_small_seq_bwd requires Transformer Engine built for ROCm (HIP)."
+        )
+    if len(qkv) != 3:
+        raise ValueError(f"fused_attn_small_seq_bwd expects qkv=(q, k, v), got len={len(qkv)}")
+    if attn_bias_type == AttnBiasType.NO_BIAS:
+        assert bias is None
+        bias = jnp.zeros(0, dtype=qkv[0].dtype)
+
+    fused_config = _FusedAttnConfig(
+        attn_bias_type=attn_bias_type,
+        attn_mask_type=attn_mask_type,
+        softmax_type=softmax_type,
+        qkv_layout=qkv_layout,
+        scaling_factor=scaling_factor,
+        dropout_probability=dropout_probability,
+        is_training=is_training,
+        max_segments_per_seq=max_segments_per_seq,
+        window_size=(-1, -1) if window_size is None else window_size,
+        context_parallel_load_balanced=False,
+        cp_axis="",
+        cp_striped_window_size=None,
+        stripe_size=stripe_size,
+    )
+
+    seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
+    dq, dk, dv, _dbias = SmallSeqAttnBwdPrimitive.outer_primitive.bind(
+        *qkv,
+        bias,
+        softmax_aux,
+        rng_state,
+        output,
+        doutput,
+        *seq_desc_flatten,
+        config=fused_config,
+    )
+    return (dq, dk, dv), None
diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h
index 8c2798c68..6108327ff 100644
--- a/transformer_engine/jax/csrc/extensions.h
+++ b/transformer_engine/jax/csrc/extensions.h
@@ -137,6 +137,27 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
     DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq,
     int64_t window_size_left, int64_t window_size_right);

+#ifdef USE_ROCM
+XLA_FFI_DECLARE_HANDLER_SYMBOL(SmallSeqAttnForwardHandler);
+
+XLA_FFI_DECLARE_HANDLER_SYMBOL(SmallSeqAttnBackwardHandler);
+
+pybind11::tuple GetSmallSeqAttnForwardWorkspaceSizes(
+    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
+    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
+    size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
+    size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right);
+
+pybind11::tuple GetSmallSeqAttnBackwardWorkspaceSizes(
+    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
+    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
+    size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
+    bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
+    int64_t window_size_right);
+#endif // USE_ROCM
+
 // GEMM
 XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);
 XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler);
diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp
index 41347a85e..edcdc081d 100644
--- a/transformer_engine/jax/csrc/extensions/attention.cpp
+++ b/transformer_engine/jax/csrc/extensions/attention.cpp
@@ -351,6 +351,185 @@ static void FusedAttnForwardImpl(
   nvte_tensor_pack_destroy(&aux_output_tensors);
 }

+#ifdef USE_ROCM
+void PrepareSmallSeqAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch,
+                                          const size_t attn_heads, const size_t q_max_seqlen,
+                                          const size_t kv_max_seqlen, DType dtype,
+                                          void *softmax_buf, void *rng_state_buf) {
+  tensor_pack->size = 2;
+  NVTETensor &softmax_aux = tensor_pack->tensors[0];
+  NVTEBasicTensor softmax_aux_data;
+  softmax_aux_data.data_ptr = softmax_buf;
+  softmax_aux_data.shape.ndim = 4;
+  softmax_aux_data.shape.data[0] = input_batch;
+  softmax_aux_data.shape.data[1] = attn_heads;
+  softmax_aux_data.shape.data[2] = q_max_seqlen;
+  softmax_aux_data.shape.data[3] = std::min(kv_max_seqlen, size_t{16});
+  softmax_aux_data.dtype = static_cast<NVTEDType>(dtype);
+  nvte_set_tensor_param(&softmax_aux, kNVTERowwiseData, &softmax_aux_data);
+
+  NVTETensor &rng_state_aux = tensor_pack->tensors[1];
+  NVTEBasicTensor rng_state_aux_data;
+  rng_state_aux_data.data_ptr = rng_state_buf;
+  rng_state_aux_data.shape = {};
+  rng_state_aux_data.shape.ndim = 2;
+  rng_state_aux_data.dtype = static_cast<NVTEDType>(DType::kInt64);
+  nvte_set_tensor_param(&rng_state_aux, kNVTERowwiseData, &rng_state_aux_data);
+}
+
+void PrepareSmallSeqAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch,
+                                           const size_t attn_heads, const size_t q_max_seqlen,
+                                           const size_t kv_max_seqlen, DType dtype, void *softmax_buf,
+                                           void *rng_state_buf) {
+  PrepareSmallSeqAttnForwardAuxTensors(tensor_pack, input_batch, attn_heads, q_max_seqlen,
+                                       kv_max_seqlen, dtype, softmax_buf, rng_state_buf);
+}
+
+pybind11::tuple GetSmallSeqAttnForwardWorkspaceSizes(
+    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
+    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
+    size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
+    size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) {
+  (void)scaling_factor;
+  (void)is_training;
+  (void)max_segments_per_seq;
+  NVTE_CHECK(
+      nvte_is_small_seq_attn_supported(
+          static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
+          mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
+          qk_head_dim, v_head_dim, window_size_left, window_size_right),
+      "GetSmallSeqAttnForwardWorkspaceSizes: configuration not supported.");
+  NVTE_CHECK(bias_batch == 0 && bias_heads == 0,
+             "GetSmallSeqAttnForwardWorkspaceSizes: bias not supported for small-seq.");
+  // At least 8 bytes: nvte_fused_attn_small_seq_fwd uses the start of workspace for
+  // ck_fused_attn::get_runtime_max_seqlen
+  constexpr size_t k_small_seq_runtime_probe_bytes = 8;
+  TensorWrapper query_workspace_tensor(nullptr, std::vector<size_t>{k_small_seq_runtime_probe_bytes},
+                                       DType::kByte);
+  return pybind11::make_tuple(MakeShapeVector(query_workspace_tensor.shape()),
+                              query_workspace_tensor.dtype());
+}
+
+static void SmallSeqAttnForwardImpl(
+    cudaStream_t stream, void *q, void *k, void *v, void *bias, void *seed, void *q_cu_seqlens,
+    void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *output, void *softmax_aux,
+    void *rng_state, void *workspace, size_t input_batch, size_t bias_batch, size_t q_max_seqlen,
+    size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads,
+    size_t qk_head_dim, size_t v_head_dim, size_t max_segments_per_seq, size_t wkspace_size,
+    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
+    bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) {
+  (void)deterministic;
+  FUSED_ATTN_IMPL_COMMON_BLOCK;
+  NVTE_CHECK(layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD && is_ragged,
+             "SmallSeqAttnForwardImpl: requires THD separate Q, K, V.");
+  NVTE_CHECK(bias_batch == 0 && bias_heads == 0,
+             "SmallSeqAttnForwardImpl: bias not supported for small-seq.");
+
+  auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
+
+  if (is_ragged) {
+    auto output_size = input_batch * q_max_seqlen * attn_heads * v_head_dim;
+    (void)cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream);
+    size_t kv_eff = std::min(kv_max_seqlen, size_t{16});
+    auto softmax_aux_elems = input_batch * q_max_seqlen * attn_heads * kv_eff;
+    (void)cudaMemsetAsync(softmax_aux, 0, softmax_aux_elems * typeToSize(dtype), stream);
+  }
+
+  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
+  auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, v_head_dim};
+  auto o_tensor = TensorWrapper(output, o_shape, dtype);
+  auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
+
+  nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen,
+                                NVTE_Fused_Attn_Backend::NVTE_CK, stream);
+
+  NVTETensorPack aux_output_tensors;
+  nvte_tensor_pack_create(&aux_output_tensors);
+  PrepareSmallSeqAttnForwardAuxTensors(&aux_output_tensors, input_batch, attn_heads, q_max_seqlen,
+                                       kv_max_seqlen, dtype, softmax_aux, rng_state);
+
+  auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
+
+  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
+  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
+  auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
+  auto q_tensor = TensorWrapper(q, q_shape, dtype);
+  auto k_tensor = TensorWrapper(k, k_shape, dtype);
+  auto v_tensor = TensorWrapper(v, v_shape, dtype);
+
+  nvte_fused_attn_small_seq_fwd(
+      q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
+      o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
+      kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
+      dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(),
+      q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
+      bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
+
+  nvte_tensor_pack_destroy(&aux_output_tensors);
+}
+
+static void SmallSeqAttnBackwardImpl(
+    cudaStream_t stream, void *q, void *k, void *v, void *bias, void *softmax_aux, void *rng_state,
+    void *output, void *doutput, void *q_cu_seqlens, void *kv_cu_seqlens, void *q_seq_offsets,
+    void *k_seq_offsets, void *dq, void *dk, void *dv, void *dbias, void *workspace,
+    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
+    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
+    size_t v_head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
+    float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
+    bool deterministic, int64_t window_size_left, int64_t window_size_right) {
+  FUSED_ATTN_IMPL_COMMON_BLOCK;
+  NVTE_CHECK(layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD && is_ragged,
+             "SmallSeqAttnBackwardImpl: requires THD separate Q, K, V.");
+  NVTE_CHECK(bias_batch == 0 && bias_heads == 0,
+             "SmallSeqAttnBackwardImpl: bias not supported for small-seq.");
+
+  auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, v_head_dim};
+  auto output_tensor = TensorWrapper(output, output_shape, dtype);
+  auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
+  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
+  auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
+
+  NVTETensorPack aux_input_tensors;
+  nvte_tensor_pack_create(&aux_input_tensors);
+  PrepareSmallSeqAttnBackwardAuxTensors(&aux_input_tensors, input_batch, attn_heads, q_max_seqlen,
+                                        kv_max_seqlen, dtype, softmax_aux, rng_state);
+
+  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
+  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
+  auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
+  auto q_tensor = TensorWrapper(q, q_shape, dtype);
+  auto k_tensor = TensorWrapper(k, k_shape, dtype);
+  auto v_tensor = TensorWrapper(v, v_shape, dtype);
+  auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
+  auto dk_tensor = TensorWrapper(dk, k_shape, dtype);
+  auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
+
+  if (is_ragged) {
+    (void)cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype),
+                          stream);
+    (void)cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * typeToSize(dtype),
+                          stream);
+    (void)cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype),
+                          stream);
+  }
+
+  nvte_fused_attn_small_seq_bwd(
+      q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
+      doutput_tensor.data(), s_tensor.data(), s_tensor.data(), &aux_input_tensors,
+      dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(),
+      q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
+      k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
+      dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
+      deterministic, workspace_tensor.data(), stream);
+
+  nvte_tensor_pack_destroy(&aux_input_tensors);
+  (void)is_training;
+}
+#endif // USE_ROCM
+
 #define FUSED_ATTN_FFI_GET_ATTRS                                  \
   size_t input_batch = get_attr_value(attrs, "input_batch");      \
   size_t bias_batch = get_attr_value(attrs, "bias_batch");        \
@@ -501,9 +680,38 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
   nvte_tensor_pack_destroy(&aux_input_tensors);

   auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
+
   return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
 }

+#ifdef USE_ROCM
+pybind11::tuple GetSmallSeqAttnBackwardWorkspaceSizes(
+    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
+    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
+    size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
+    bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
+    int64_t window_size_right) {
+  (void)scaling_factor;
+  (void)is_training;
+  (void)deterministic;
+  (void)max_segments_per_seq;
+  NVTE_CHECK(
+      nvte_is_small_seq_attn_supported(
+          static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
+          mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
+          qk_head_dim, v_head_dim, window_size_left, window_size_right),
+      "GetSmallSeqAttnBackwardWorkspaceSizes: configuration not supported.");
+  NVTE_CHECK(bias_batch == 0 && bias_heads == 0,
+             "GetSmallSeqAttnBackwardWorkspaceSizes: bias not supported for small-seq.");
+  size_t bwd_bytes = nvte_fused_attn_small_seq_bwd_workspace_size(
+      input_batch, attn_heads, kv_max_seqlen, static_cast<NVTEDType>(dtype));
+  TensorWrapper query_workspace_tensor(nullptr, std::vector<size_t>{bwd_bytes}, DType::kByte);
+  return pybind11::make_tuple(MakeShapeVector(query_workspace_tensor.shape()),
+                              query_workspace_tensor.dtype());
+}
+#endif // USE_ROCM
+
 static void FusedAttnBackwardImpl(
     cudaStream_t stream, void *q, void *k, void *v, void *bias, void *softmax_offset,
     void *softmax_aux, void *rng_state, void *output, void *doutput, void *q_cu_seqlens,
@@ -680,5 +888,100 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI,
                                   .Attrs(),
                               FFI_CudaGraph_Traits);

+#ifdef USE_ROCM
+Error_Type SmallSeqAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf,
+                                  Buffer_Type v_buf, Buffer_Type bias_buf, Buffer_Type seed_buf,
+                                  Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
+                                  Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf,
+                                  Variadic_Buffer_Type _unused_args, Result_Type output_buf,
+                                  Result_Type softmax_aux_buf, Result_Type rng_state_buf,
+                                  Result_Type workspace_buf, Dictionary attrs) {
+  FUSED_ATTN_FFI_GET_ATTRS;
+
+  SmallSeqAttnForwardImpl(
+      stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(),
+      bias_buf.untyped_data(), seed_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(),
+      kv_cu_seqlens_buf.untyped_data(), is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr,
+      is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, output_buf->untyped_data(),
+      softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(), workspace_buf->untyped_data(),
+      input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
+      qk_head_dim, v_head_dim, max_segments_per_seq, wkspace_size, scaling_factor,
+      dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype, is_training,
+      deterministic, window_size_left, window_size_right);
+
+  return ffi_with_cuda_error_check();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(SmallSeqAttnForwardHandler, SmallSeqAttnForwardFFI,
+                              FFI::Bind()
+                                  .Ctx<FFI_Stream_Type>()
+                                  .Arg<Buffer_Type>()
+                                  .Arg<Buffer_Type>()
+                                  .Arg<Buffer_Type>()
+                                  .Arg<Buffer_Type>()
+                                  .Arg<Buffer_Type>()
+                                  .Arg<Buffer_Type>()
+                                  .Arg<Buffer_Type>()
+                                  .Arg<Buffer_Type>()
+                                  .Arg<Buffer_Type>()
+                                  .RemainingArgs()
+                                  .Ret<Buffer_Type>()
+                                  .Ret<Buffer_Type>()
+                                  .Ret<Buffer_Type>()
+                                  .Ret<Buffer_Type>()
+                                  .Attrs(),
+                              FFI_CudaGraph_Traits);
+
+Error_Type SmallSeqAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf,
+                                   Buffer_Type v_buf, Buffer_Type bias_buf,
+                                   Buffer_Type softmax_aux_buf, Buffer_Type rng_state_buf,
+                                   Buffer_Type output_buf, Buffer_Type doutput_buf,
+                                   Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
+                                   Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf,
+                                   Variadic_Buffer_Type _unused_args, Result_Type dq_buf,
+                                   Result_Type dk_buf, Result_Type dv_buf, Result_Type dbias_buf,
+                                   Result_Type workspace_buf, Dictionary attrs) {
+  FUSED_ATTN_FFI_GET_ATTRS;
+
+  SmallSeqAttnBackwardImpl(
+      stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(),
+      bias_buf.untyped_data(), softmax_aux_buf.untyped_data(), rng_state_buf.untyped_data(),
+      output_buf.untyped_data(), doutput_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(),
+      kv_cu_seqlens_buf.untyped_data(), is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr,
+      is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, dq_buf->untyped_data(),
+      dk_buf->untyped_data(), dv_buf->untyped_data(), dbias_buf->untyped_data(),
+      workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
+      attn_heads, num_gqa_groups, bias_heads, qk_head_dim, v_head_dim, max_segments_per_seq,
+      wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype,
+      wkspace_dtype, is_training, deterministic, window_size_left, window_size_right);
+
+  return ffi_with_cuda_error_check();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(SmallSeqAttnBackwardHandler, SmallSeqAttnBackwardFFI,
+                              FFI::Bind()
+                                  .Ctx<FFI_Stream_Type>()
+                                  .Arg<Buffer_Type>()
+                                  .Arg<Buffer_Type>()
+                                  .Arg<Buffer_Type>()
+                                  .Arg<Buffer_Type>()
+                                  .Arg<Buffer_Type>()
+                                  .Arg<Buffer_Type>()
+                                  .Arg<Buffer_Type>()
+                                  .Arg<Buffer_Type>()
+                                  .Arg<Buffer_Type>()
+                                  .Arg<Buffer_Type>()
+                                  .Arg<Buffer_Type>()
+                                  .Arg<Buffer_Type>()
+                                  .RemainingArgs()
+                                  .Ret<Buffer_Type>()
+                                  .Ret<Buffer_Type>()
+                                  .Ret<Buffer_Type>()
+                                  .Ret<Buffer_Type>()
+                                  .Ret<Buffer_Type>()
+                                  .Attrs(),
+                              FFI_CudaGraph_Traits);
+#endif // USE_ROCM
+
 }  // namespace jax
 }  // namespace transformer_engine
diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp
index 0c56bb088..05b3054b2 100644
--- a/transformer_engine/jax/csrc/extensions/pybind.cpp
+++ b/transformer_engine/jax/csrc/extensions/pybind.cpp
@@ -92,6 +92,8 @@ pybind11::dict Registrations() {
   // Attention
   dict["te_fused_attn_forward_ffi"] = EncapsulateFFI(FusedAttnForwardHandler);
   dict["te_fused_attn_backward_ffi"] = EncapsulateFFI(FusedAttnBackwardHandler);
+  dict["te_small_seq_attn_forward_ffi"] = EncapsulateFFI(SmallSeqAttnForwardHandler);
+  dict["te_small_seq_attn_backward_ffi"] = EncapsulateFFI(SmallSeqAttnBackwardHandler);

   dict["te_gemm_ffi"] = EncapsulateFFI(GemmHandler);
   dict["te_grouped_gemm_ffi"] = EncapsulateFFI(GroupedGemmHandler);
@@ -117,6 +119,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
   m.def("get_norm_bwd_workspace_sizes", &GetNormBackwardWorkspaceSizes);
   m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
   m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
+#ifdef USE_ROCM
+  m.def("get_small_seq_attn_fwd_workspace_sizes", &GetSmallSeqAttnForwardWorkspaceSizes);
+  m.def("get_small_seq_attn_bwd_workspace_sizes", &GetSmallSeqAttnBackwardWorkspaceSizes);
+#endif
  m.def("nvte_get_qkv_format", &nvte_get_qkv_format);
  m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported);
 #ifndef USE_ROCM
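For readers unfamiliar with the ragged THD bookkeeping the small-seq path consumes, the sketch below (plain JAX, outside the patch) shows the cumulative-sequence-length arrays for a decode-shaped batch: every query sequence contributes exactly one token while the KV side has variable lengths. The helper name decode_style_cu_seqlens is illustrative only; it mirrors the q_cu_seqlens / kv_cu_seqlens buffers passed to the FFI handlers above.

import jax.numpy as jnp

def decode_style_cu_seqlens(kv_lengths):
    """Illustrative helper: cu_seqlen arrays for a batch where each query
    sequence holds exactly one valid token (the small-seq / decode case).

    kv_lengths: int32 [batch] -- valid KV tokens per sequence.
    Returns (q_cu_seqlens, kv_cu_seqlens), both int32 [batch + 1].
    """
    batch = kv_lengths.shape[0]
    # One query token per sequence -> the prefix sums are simply 0..batch.
    q_cu_seqlens = jnp.arange(batch + 1, dtype=jnp.int32)
    # KV prefix sums start at 0 and accumulate the per-sequence lengths.
    kv_cu_seqlens = jnp.concatenate(
        [jnp.zeros((1,), dtype=jnp.int32), jnp.cumsum(kv_lengths, dtype=jnp.int32)]
    )
    return q_cu_seqlens, kv_cu_seqlens

# Example: four decode steps attending to KV caches of lengths 7, 3, 16, 1.
q_cu, kv_cu = decode_style_cu_seqlens(jnp.array([7, 3, 16, 1], dtype=jnp.int32))
# q_cu  -> [0, 1, 2, 3, 4]
# kv_cu -> [0, 7, 10, 26, 27]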
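Similarly, a minimal sketch of how the two explicit helpers added in this patch, fused_attn_small_seq_fwd and fused_attn_small_seq_bwd, fit together: run the forward once, then push a ones-like upstream gradient through the explicit backward. The static settings below (no bias, padding mask, zero dropout, one segment per sequence), the enum import path, and the assumption that a valid THD sequence_descriptor and the two helpers are already in scope are illustrative choices for this sketch, not requirements stated by the patch.

import jax.numpy as jnp
from transformer_engine.jax.attention import AttnBiasType, AttnMaskType, QKVLayout

def small_seq_attn_fwd_bwd(q, k, v, sequence_descriptor, seed=None):
    # Illustrative static config; the real caller decides these values.
    static = dict(
        attn_bias_type=AttnBiasType.NO_BIAS,
        attn_mask_type=AttnMaskType.PADDING_MASK,
        qkv_layout=QKVLayout.THD_THD_THD,
        scaling_factor=1.0 / (q.shape[-1] ** 0.5),
        dropout_probability=0.0,
        is_training=True,
        max_segments_per_seq=1,
    )
    out, softmax_aux, rng_state = fused_attn_small_seq_fwd(
        (q, k, v), None, sequence_descriptor, seed, **static
    )
    # Feed a ones-like upstream gradient to exercise the explicit backward.
    (dq, dk, dv), _ = fused_attn_small_seq_bwd(
        (q, k, v), None, softmax_aux, rng_state, out, jnp.ones_like(out),
        sequence_descriptor, **static
    )
    return out, (dq, dk, dv)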