diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 0258951d4..b372d3987 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 0258951d4d512f4714eb1574496f4d57669b1b93 +Subproject commit b372d39879d44c91a8d5b342022e74802b6a8da2 diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index d5e1cb291..d8b698973 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -2.12.0.dev0 +2.12.0 diff --git a/build_tools/hipify/custom_map.json b/build_tools/hipify/custom_map.json index 5fc3cded0..6525731f5 100644 --- a/build_tools/hipify/custom_map.json +++ b/build_tools/hipify/custom_map.json @@ -11,7 +11,6 @@ "__nv_fp8_e5m2" : "te_hip_fp8_e5m2", "__nv_fp8_e4m3" : "te_hip_fp8_e4m3", "cuda::getCurrentCUDAStream" : "hip::getCurrentHIPStreamMasqueradingAsCUDA", - "at::cuda::CUDAGuard" : "at::hip::HIPGuardMasqueradingAsCUDA", "__nv_fp4_e2m1" : "__hip_fp4_e2m1", "__nv_fp4x2_e2m1" : "__hip_fp4x2_e2m1", "__nv_fp4x4_e2m1" : "__hip_fp4x4_e2m1", diff --git a/examples/jax/datasets.txt b/examples/jax/datasets.txt new file mode 100644 index 000000000..fd3f5bc41 --- /dev/null +++ b/examples/jax/datasets.txt @@ -0,0 +1,3 @@ +# Datasets used by TE encoder tests. Pull these to pre-emptively cache datasets +ylecun/mnist +nyu-mll/glue \ No newline at end of file diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index ee9ce130a..3453e35d2 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -29,6 +29,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" mkdir -p "$XML_LOG_DIR" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*" +NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_with_determinism.xml $TE_PATH/tests/jax/test_fused_attn.py -k "TestFusedAttnWithDeterminism" || test_fail "tests/jax/test_fused_attn.py" pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist" diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 21eed2836..a13dfada7 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -45,6 +45,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_e python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" +NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || 
test_fail "test_kv_cache.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" diff --git a/qa/L1_pytorch_onnx_unittest/test.sh b/qa/L1_pytorch_onnx_unittest/test.sh index b3a520e12..6f9ff54e4 100644 --- a/qa/L1_pytorch_onnx_unittest/test.sh +++ b/qa/L1_pytorch_onnx_unittest/test.sh @@ -6,4 +6,5 @@ : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py +# NVTE_UnfusedDPA_Emulate_FP8=1 enables FP8 attention emulation when no native backend is available +NVTE_UnfusedDPA_Emulate_FP8=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 4e42fad92..66244444f 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -71,12 +71,16 @@ std::vector create_transpose(const InputType* const input, const size } // Compute the global encode scale factor for a given global amax -float compute_global_encode_scaling_factor_FP4(const float global_amax) { +float compute_global_encode_scaling_factor_FP4(const float global_amax, const bool use_fast_math) { constexpr float fp8_max = 448.0f; // 448.0f; constexpr float fp4_max = 6.0f; // 6.0f; float global_encode_scale = fp8_max * fp4_max / global_amax; - // If scale is infinity, return max value of float32 - global_encode_scale = fminf(global_encode_scale, Numeric_Traits::maxNorm); + // If scale is infinity, return the max normalized value + const float max_norm_clamp = use_fast_math + ? Numeric_Traits::maxNorm + : Numeric_Traits::maxNorm; + + global_encode_scale = fminf(global_encode_scale, max_norm_clamp); // If global amax is 0 or infinity, return 1 if (global_amax == 0.0f || global_encode_scale == 0.0f) { return 1.0f; @@ -93,10 +97,11 @@ void quantize_nvfp4_1d(float (*OP)(const float), const size_t rows, const size_t cols, const size_t scales_stride, - const float global_amax) { + const float global_amax, + const bool use_fast_math) { // Compute a global encoding/decoding scaling factor for all S_dec_b - const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math); constexpr size_t block_size_X = 16; const size_t blocks_X = divide_round_up(cols, block_size_X); @@ -131,14 +136,20 @@ void quantize_nvfp4_1d(float (*OP)(const float), const float S_dec_b = block_amax / 6.0f; // Scale & Store per-block decoding scaling factor - const float S_dec_b_fp8 = S_dec_b * S_enc; + const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + const float S_dec_b_fp32 = static_cast(S_dec_b_fp8); // Compute "correct" per-block encoding scaling factor - const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8; + const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 
0.f : S_enc / S_dec_b_fp32; const size_t scale_idx = i * scales_stride + block_X; - scales[scale_idx] = static_cast(S_dec_b_fp8); - const float scale_reciprocal = S_enc_b_fp8; + scales[scale_idx] = S_dec_b_fp8; + + float scale_reciprocal = S_enc_b_fp8; + if (use_fast_math) { + // Numerical truncation to match GPU implementation, if mixed precision FMA instruction is used + scale_reciprocal = static_cast(static_cast(scale_reciprocal)); + } for (size_t j = j_min; j < j_max; j += 2) { const int idx_pair = (i * cols + j) / 2; @@ -153,7 +164,7 @@ void quantize_nvfp4_1d(float (*OP)(const float), fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair); output[idx_pair] = casted_to_e2m1_pair; - // const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair); + const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair); } } } @@ -166,9 +177,10 @@ void compute_2d_mathematical_scales(float (*OP)(const float), const size_t rows, const size_t cols, const float global_amax, - std::vector>& math_scales) { + std::vector>& math_scales, + const bool use_fast_math) { - const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; const size_t blocks_Y = divide_round_up(rows, block_size_Y); @@ -212,13 +224,14 @@ void quantize_nvfp4_2d(float (*OP)(const float), const size_t rows, const size_t cols, const size_t scales_stride, - const float global_amax) { + const float global_amax, + const bool use_fast_math) { // Step 1: Compute mathematical 8x8 scaling factors std::vector> math_scales; - compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales); + compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math); - const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; const size_t blocks_Y = divide_round_up(rows, block_size_Y); @@ -299,11 +312,12 @@ void quantize_nvfp4(float (*OP)(const float), const size_t cols, const size_t scales_stride, const float global_amax, + const bool use_fast_math, const bool use_2d_quantization = false) { if (use_2d_quantization) { - quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax); + quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math); } else { - quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax); + quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math); } } @@ -319,6 +333,7 @@ void compute_ref(float (*OP)(const float), const size_t cols, const size_t scales_stride, const size_t scales_stride_t, + const bool use_fast_math, const bool use_2d_quantization = false) { std::vector input_t = create_transpose(input, rows, cols); @@ -326,7 +341,7 @@ void compute_ref(float (*OP)(const float), if (use_2d_quantization) { // Step 1: Compute mathematical 8×8 scaling factors std::vector> math_scales; - compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales); + compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; @@ -353,12 +368,16 @@ void compute_ref(float (*OP)(const float), 
// Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d // (This part processes the actual FP4 data using the mathematical scaling factors) - quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax); // scales already filled - quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax); // scales_t already filled + quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax, + use_fast_math); // scales already filled + quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax, + use_fast_math); // scales_t already filled } else { - quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_2d_quantization); - quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax, use_2d_quantization); + quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax, + use_fast_math, use_2d_quantization); + quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax, + use_fast_math, use_2d_quantization); } } @@ -366,6 +385,8 @@ void compare_nvfp4_tensors(const std::string& name, const fp4e2m1 *test_data, const fp4e2m1 *ref_data, const int rows, const int cols, double atol = 1e-5, double rtol = 1e-8) { + constexpr int max_mismatches_to_print = 3; + std::vector<std::string> mismatch_messages; size_t total_mismatches = 0; @@ -379,29 +400,30 @@ void compare_nvfp4_tensors(const std::string& name, const double t = (k == 0 ? test_data_pair.x : test_data_pair.y); const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y); +#ifdef __HIP_PLATFORM_AMD__ bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); - /* For Float32 the floating point comparison is enough to error out */ - bool assertion = false; - if (mismatch && !assertion) { - /* Check if it is just a failure of round to nearest choosing different - side of the real value */ + if (mismatch) { + // Check if it is just a failure of round to nearest choosing different + // side of the real value const double mean = (t + r) / 2; const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); const double cast_mean_p = static_cast<double>(static_cast<fp4e2m1>(mean_p)); const double cast_mean_m = static_cast<double>(static_cast<fp4e2m1>(mean_m)); - assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + mismatch = !(cast_mean_m == std::min(t, r) && cast_mean_p == std::max(t, r)); } - if (assertion) { +#else + const bool mismatch = fabs(t - r) > (atol + fabs(r) * rtol); +#endif + if (mismatch) { total_mismatches++; - std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " + - std::to_string(t) + " vs " + std::to_string(r) + - " (abs_diff: " + std::to_string(fabs(t - r)) + - ", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")"; - mismatch_messages.push_back(msg); - // Optional: limit number of detailed messages to avoid overwhelming output - if (mismatch_messages.size() <= 100) { + if (total_mismatches <= max_mismatches_to_print) { + std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " + + std::to_string(t) + " vs " + std::to_string(r) + + " (abs_diff: " + std::to_string(fabs(t - r)) + + ", rel_diff: " + std::to_string(r == 0 ? 
0.0 : fabs((t - r) / r)) + ")"; + mismatch_messages.push_back(msg); std::cout << "Error in tensor " << name << ": " << msg << std::endl; } } @@ -417,8 +439,9 @@ void compare_nvfp4_tensors(const std::string& name, std::cout << "STATUS: FAILED for output" << std::endl; std::cout << "Total mismatches found: " << total_mismatches << std::endl; std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl; - if (mismatch_messages.size() > 100) { - std::cout << "... and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl; + if (mismatch_messages.size() > max_mismatches_to_print) { + std::cout << "... and " << (mismatch_messages.size() - max_mismatches_to_print) + << " more mismatches (showing first " << max_mismatches_to_print << ")" << std::endl; } std::cout << "============================" << std::endl; @@ -536,7 +559,8 @@ void compareResults_nvfp4(const Tensor &test, template <typename InputType> void performTest(float (*OP)(const float), - const std::vector<size_t>& shape) { + const std::vector<size_t>& shape, + const bool use_fast_math) { using namespace test; DType itype = TypeInfo<InputType>::dtype; @@ -608,15 +632,16 @@ void performTest(float (*OP)(const float), cols, scales_stride, scales_stride_t, + use_fast_math, use_2d_quantization); - - QuantizationConfigWrapper quant_config; - // Initialize stochastic rounding Tensor rng_state("rng_state", std::vector<size_t>{2}, DType::kInt64); rng_state.rowwise_cpu_dptr<int64_t>()[0] = 123; // rng_seed rng_state.rowwise_cpu_dptr<int64_t>()[1] = 321; // rng_sequence rng_state.from_cpu(); + + QuantizationConfigWrapper quant_config; + quant_config.set_use_fast_math(use_fast_math); #ifdef __HIP_PLATFORM_AMD__ quant_config.set_stochastic_rounding(use_stochastic_rounding); #else @@ -651,8 +676,8 @@ void performTest(float (*OP)(const float), } ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - const double atol = 0.05; - const double rtol = 0.1; + const double atol = 1.0E-6; + const double rtol = 1.0E-6; // Set dump_data=true to enable dumping tensor data to files for analysis compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false); @@ -717,7 +742,8 @@ std::vector<ActivationType> Activation_types = { class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam<std::tuple<ActivationType, std::vector<size_t>, - transformer_engine::DType>> {}; + transformer_engine::DType, + bool>> {}; TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { #ifndef __HIP_PLATFORM_AMD__ @@ -733,6 +759,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { const ActivationType Act_type = std::get<0>(GetParam()); const auto tensor_dims = std::get<1>(GetParam()); const DType input_type = std::get<2>(GetParam()); + const bool use_fast_math = std::get<3>(GetParam()); // Skip tests if the input tensor is 1D if (tensor_dims.size() < 2) { @@ -750,7 +777,7 @@ } TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - performTest<InputType>(OP, tensor_dims); + performTest<InputType>(OP, tensor_dims, use_fast_math); ); } @@ -772,7 +799,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Combine( ::testing::ValuesIn(Activation_types), ::testing::ValuesIn(tensor_dims), - ::testing::Values(DType::kBFloat16)), + ::testing::Values(DType::kBFloat16), + ::testing::Values(false)), [](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) { std::string name = to_string(std::get<0>(info.param)); const auto& shape = std::get<1>(info.param); @@ -780,5 +808,8 @@ name += "X" + 
std::to_string(s); } name += "X" + test::typeName(std::get<2>(info.param)); + if (std::get<3>(info.param)) { + name += "X_FAST_SCALING"; + } return name; }); diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 8639af79b..26f5514d3 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -4,6 +4,7 @@ # # See LICENSE for license information. """Tests for fused attention""" +import os from enum import Enum, auto from dataclasses import dataclass, field from functools import partial @@ -52,6 +53,9 @@ from distributed_test_base import assert_equal_collectives from utils import assert_allclose, print_debug_tensor_stats +# Get determinism +_deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + @pytest.fixture(autouse=True, scope="module") def init(): @@ -417,16 +421,24 @@ def _check_configs(self): pytest.skip( "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN" ) - # TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support - if ( - get_device_compute_capability(0) >= 100 - and self.dropout_prob == 0.1 - and self.attn_bias_type is not AttnBiasType.NO_BIAS - and not is_hip_extension() - ): - pytest.skip( - "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported" - ) + if not is_hip_extension() and get_device_compute_capability(0) >= 100 and self.is_training: + if FusedAttnHelper.is_non_deterministic_allowed() and ( + (self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS) + or get_cudnn_version() < 90700 + ): + pytest.skip( + "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with" + " dropout" + ) + if not FusedAttnHelper.is_non_deterministic_allowed() and ( + self.dropout_prob != 0.0 + or self.attn_bias_type != AttnBiasType.NO_BIAS + or get_cudnn_version() < 91801 + ): + pytest.skip( + "For sm100+, deterministic bprop (cuDNN 9.18.1+) does not support bias or" + " dropout" + ) # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate(): @@ -1346,6 +1358,7 @@ def check_dqkv(primitive, reference, pad, idx): pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"), ], ) +@pytest.mark.skipif(_deterministic, reason="Test non-determinism only") class TestFusedAttn: """ Fused attention tester @@ -1507,3 +1520,183 @@ def test_jax_new_rng(): ) runner = FusedAttnRunner(**kwargs) runner.test_forward() + + + +@pytest.mark.parametrize( + "attn_mask_type", + [ + pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"), + pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"), + pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"), + pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"), + pytest.param( + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT" + ), + ], +) +@pytest.mark.parametrize( + "softmax_type", + [ + pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"), + ], +) +@pytest.mark.parametrize( + "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout", + [ + # large data size + fp16 + cross attn + gqa + diff hidden v dim + qkv separate + pytest.param( + 2, + 1024, + 2048, + 12, + 6, + 128, + 64, + jnp.bfloat16, + QKVLayout.BSHD_BSHD_BSHD, + id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-SEPARATE", + ), + pytest.param( + 2, + 1024, + 2048, + 12, + 6, + 128, + 64, + jnp.bfloat16, + 
QKVLayout.THD_THD_THD, + id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-RAGGED_SEPARATE", + ), + ], +) +@pytest.mark.parametrize( + "dropout_prob", + [ + pytest.param(0.0, id="DROP_0.0"), + ], +) +@pytest.mark.parametrize( + "swa", + [ + pytest.param(False, id="NO_SWA"), + ], +) +@pytest.mark.parametrize( + "seq_desc_format", + [ + pytest.param(SeqDescFormat.Seqlens, id="Seqlens"), + ], +) +@pytest.mark.skipif(not _deterministic, reason="Test determinism only") +class TestFusedAttnWithDeterminism: + """ + Fused attention tester with determinism + """ + + @staticmethod + @pytest.mark.parametrize( + "is_training", + [ + pytest.param(True, id="TRAINING"), + ], + ) + @pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + ], + ) + def _test_forward( + b, + s_q, + s_kv, + h_q, + h_kv, + d_qk, + d_v, + attn_bias_type, + attn_mask_type, + softmax_type, + dropout_prob, + dtype, + is_training, + qkv_layout, + bias_shape, + swa, + seq_desc_format, + ): + """ + Test forward with parameterized configs + This test is not intended to run automatically during CI as it is time-consuming + It is kept for development and debugging + """ + TestFusedAttn._test_forward( + b, + s_q, + s_kv, + h_q, + h_kv, + d_qk, + d_v, + attn_bias_type, + attn_mask_type, + softmax_type, + dropout_prob, + dtype, + is_training, + qkv_layout, + bias_shape, + swa, + seq_desc_format, + ) + + @staticmethod + @pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + ], + ) + def test_backward( + b, + s_q, + s_kv, + h_q, + h_kv, + d_qk, + d_v, + attn_bias_type, + attn_mask_type, + softmax_type, + dropout_prob, + dtype, + qkv_layout, + bias_shape, + swa, + seq_desc_format, + ): + """ + Test backward with parameterized configs + """ + TestFusedAttn.test_backward( + b, + s_q, + s_kv, + h_q, + h_kv, + d_qk, + d_v, + attn_bias_type, + attn_mask_type, + softmax_type, + dropout_prob, + dtype, + qkv_layout, + bias_shape, + swa, + seq_desc_format, + ) diff --git a/tests/jax/test_permutation.py b/tests/jax/test_permutation.py index 5bb59c6ed..138a81724 100644 --- a/tests/jax/test_permutation.py +++ b/tests/jax/test_permutation.py @@ -23,7 +23,7 @@ (128, 5, 128, 3), (1024, 8, 128, 8), (4096, 32, 1280, 2), - (4096, 256, 4096, 6), + (4096, 64, 4096, 6), ] DISPATCH_COMBINE_CASES = { "L0": ALL_DISPATCH_COMBINE_CASES[0:2], @@ -44,7 +44,7 @@ (128, 5, 128, 3, 8), (1024, 8, 128, 8, 16), (4096, 32, 1280, 2, 128), - (4096, 256, 4096, 6, 16), + (4096, 64, 4096, 6, 16), ] DISPATCH_COMBINE_PADDING_CASES = { "L0": ALL_DISPATCH_COMBINE_PADDING_CASES[0:2], diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 5eff52c45..7e2513aab 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -76,6 +76,14 @@ f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}" ) + +# Get determinism +_deterministic = ( + not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + or torch.are_deterministic_algorithms_enabled() +) + + # Reset RNG seed and states seed = 1234 reset_rng_states() @@ -223,6 +231,7 @@ def test_dot_product_attention( if config.window_size == (-1, -1) and swa: config.window_size = [2, 2] + config.window_size = 
check_set_window_size(config.attn_mask_type, config.window_size) qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0] if qkv_format == "thd" and "padding" not in config.attn_mask_type: @@ -238,8 +247,10 @@ def test_dot_product_attention( qkv_layout=qkv_layout, pad_between_seqs=pad_between_seqs, is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + if not fused_attn_supported: is_training = False available_backends, _, fused_attn_backends = get_available_attention_backends( @@ -248,6 +259,7 @@ def test_dot_product_attention( qkv_layout=qkv_layout, pad_between_seqs=pad_between_seqs, is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends @@ -529,6 +541,15 @@ def test_dpa_softmax(dtype, model_configs, model): ) +@pytest.mark.skipif(get_cudnn_version() < (9, 18, 0), reason="cuDNN 9.18.0+ is required.") +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("model_configs", [model_configs_softmax]) +@pytest.mark.parametrize("model", model_configs_softmax.keys()) +def test_dpa_softmax_thd(dtype, model_configs, model): + """Test DotProductAttention module with different softmax types""" + test_dot_product_attention(dtype, model_configs, model, True, True, "thd_thd_thd", False, False) + + model_configs_mla = { # test: ModelConfig(b, sq, hq, dqk) "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), @@ -792,9 +813,10 @@ def test_dpa_bias_shapes(dtype, model_configs, model): @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_swa]) @pytest.mark.parametrize("model", model_configs_swa.keys()) -def test_dpa_sliding_window(dtype, model_configs, model): +@pytest.mark.parametrize("qkv_layout", ["thd_thd_thd", "sbhd_sbhd_sbhd"]) +def test_dpa_sliding_window(dtype, model_configs, model, qkv_layout): """Test DotProductAttention module with sliding window attention""" - test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False) + test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, True, False) model_configs_alibi_slopes = { @@ -1018,11 +1040,14 @@ def _run_dot_product_attention( reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0" + if backend == "UnfusedDotProductAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True # Create seqlens @@ -1424,6 +1449,7 @@ def test_transformer_layer( qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd") ), is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported: @@ -1437,6 +1463,7 @@ def test_transformer_layer( else qkv_format.replace("hd", "3hd") ), is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends @@ -1625,10 +1652,13 @@ def _run_transformer_layer( reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" 
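+    # Reset all attention backend toggles; exactly one is re-enabled below for the requested backend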
+ os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" + if backend == "UnfusedDotProductAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True # Create input tensor @@ -1822,6 +1852,7 @@ def test_dpa_fp8_extra_state(model, dtype): qkv_dtype=torch.float8_e4m3fn, qkv_layout="sb3hd", is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported and not flash_attn_supported: @@ -2013,6 +2044,7 @@ def test_mha_fp8_vs_f16( fp8=True, fp8_meta=fp8_meta, is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends if flash_attn_supported + fused_attn_supported_fp8 < 1: @@ -2024,6 +2056,7 @@ def test_mha_fp8_vs_f16( qkv_dtype=dtype, qkv_layout=qkv_format.replace("hd", "h3d"), is_training=is_training, + deterministic=_deterministic, ) _, fused_attn_supported_f16, _ = available_backends if not fused_attn_supported_f16: @@ -2032,6 +2065,7 @@ def test_mha_fp8_vs_f16( if flash_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16( @@ -2041,6 +2075,7 @@ def test_mha_fp8_vs_f16( if fused_attn_supported_fp8: os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16( @@ -2050,6 +2085,7 @@ def test_mha_fp8_vs_f16( if fused_attn_supported_f16: os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False") fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16( @@ -2263,6 +2299,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal fp8=True, fp8_meta=fp8_meta, is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if flash_attn_supported + fused_attn_supported < 1: @@ -2273,6 +2310,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal qkv_dtype=dtype, qkv_layout=qkv_layout, is_training=is_training, + deterministic=_deterministic, ) _, fused_attn_supported, _ = available_backends if not fused_attn_supported: @@ -2283,6 +2321,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal if flash_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)") flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( @@ -2292,6 +2331,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal if 
unfused_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)") unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( @@ -2300,6 +2340,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)") fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( @@ -2308,6 +2349,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if config.dropout_p == 0.0: # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") @@ -2563,13 +2605,16 @@ def test_custom_mha_fp8_vs_f16(dtype, model): qkv_dtype=torch.float8_e4m3fn, qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd", is_training=is_training, + deterministic=_deterministic, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not (fused_attn_backends and unfused_attn_supported): pytest.skip("Not enough backends to run this test with.") fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention") - unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention") + unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16( + dtype, config, "UnfusedDotProductAttention" + ) atol = 5e-1 rtol = 5e-1 @@ -2602,10 +2647,13 @@ def _run_custom_mha_fp8(dtype, config, backend): reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" + if backend == "UnfusedDotProductAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True inp = 0.0001 * torch.randint( @@ -2656,10 +2704,13 @@ def _run_ref_mha_f16(dtype, config, backend): os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" + if backend == "UnfusedDotProductAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True inp = torch.load("qkv.pt").to(device="cuda") @@ -2947,7 +2998,7 @@ def forward( cu_seqlens, max_s, ) -> torch.Tensor: - with self.prepare_forward(inp, num_gemms=3) as inp: + with self.prepare_forward_ctx(inp, num_gemms=3) as inp: out = _custom_mha_fp8.apply( inp, self.qkv_weight, diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 3b7573b9a..9d7a7e3ba 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -155,7 +155,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 
4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" ), # MHA "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA - "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA + "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA "cp_2_2": ModelConfig( @@ -171,7 +171,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias" ), # GQA "cp_2_4": ModelConfig( - 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0) + 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512) ), # GQA "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA @@ -195,7 +195,16 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): qkv_formats = ["bshd", "sbhd", "thd"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: - configs = ["cp_1_0", "cp_1_1", "cp_1_4", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"] + configs = [ + "cp_1_0", + "cp_1_1", + "cp_1_4", + "cp_2_0", + "cp_2_2", + "cp_2_4", + "cp_3_2", + "cp_4_2", + ] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} dtypes = ["bf16", "fp8"] qkv_formats = ["sbhd", "thd"] @@ -293,9 +302,14 @@ def test_cp_with_fused_attention( pytest.skip( "CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!" ) - if config.softmax_type != "vanilla" and qkv_format == "thd": + if ( + get_cudnn_version() < (9, 18, 0) + and config.softmax_type != "vanilla" + and qkv_format == "thd" + ): pytest.skip( - "CP implementation does not support qkv_format=thd for non-vanilla softmax types!" + "Unless cudnn version >= 9.18.0, CP implementation does not support qkv_format=thd for" + " non-vanilla softmax types!" 
) dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} diff --git a/tests/pytorch/debug/test_sanity.py b/tests/pytorch/debug/test_sanity.py index aee5474e7..2bc4b3559 100644 --- a/tests/pytorch/debug/test_sanity.py +++ b/tests/pytorch/debug/test_sanity.py @@ -30,10 +30,17 @@ stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range] start_step : 0 end_step: 1 +""", + "log_fp8": """log_fp8: + layers: + layer_types: [linear] + enabled: + True + transformer_engine: LogFp8TensorStats: enabled: True tensors: [activation, gradient, weight] - stats: [underflows, overflows] + stats: [underflows%] start_step : 0 end_step: 1 """, @@ -46,22 +53,26 @@ FakeQuant: enabled: True gemms: [fprop, dgrad, wgrad] + tensors: [activation, weight, gradient] quant_format: FP8E5M2 """, } +# Configs that require FP8 to be enabled +fp8_required_configs = {"log_fp8"} + def _get_model(model_key): if model_key == "linear": - return te.Linear(D, D) + return te.Linear(D, D, name="layer") if model_key == "layernorm_linear": - return te.LayerNormLinear(D, D) + return te.LayerNormLinear(D, D, name="layer") if model_key == "layernorm_mlp": - return te.LayerNormMLP(D, D, D) + return te.LayerNormMLP(D, D, D, name="layer") if model_key == "mha_attention": - return te.MultiheadAttention(D, H) + return te.MultiheadAttention(D, H, name="layer") if model_key == "transformer_layer": - return te.TransformerLayer(D, D, H) + return te.TransformerLayer(D, D, H, name="layer") def _run_forward_backward(model, fp8): @@ -95,4 +106,6 @@ def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir): def test_sanity_debug(model_key, fp8, config_key, feature_dirs): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) + if not fp8 and config_key in fp8_required_configs: + pytest.skip(f"Config '{config_key}' requires FP8") _run_test(model_key, fp8, configs[config_key], feature_dirs) diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 50cd150c4..9aea3bc27 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -713,6 +713,14 @@ def test_export_layernorm_mlp_activation(seed_default_rng, activation): _test_export_layernorm_mlp(activation=activation) +# Quantization recipes with fp8_dpa=True for attention emulation export test +dpa_quantization_recipes = [None] # None = no quantization +if fp8_available: + dpa_quantization_recipes.append(recipe.DelayedScaling(fp8_dpa=True)) + dpa_quantization_recipes.append(recipe.Float8CurrentScaling(fp8_dpa=True)) + + +@pytest.mark.parametrize("fp8_recipe", dpa_quantization_recipes) @pytest.mark.parametrize( "precision, use_mask, attn_mask_type", [ @@ -730,6 +738,7 @@ def test_export_core_attention( precision: torch.dtype, use_mask: bool, attn_mask_type: str, + fp8_recipe: recipe.Recipe, ): # Set dimensions (these are arbitrary). 
seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64) @@ -749,22 +758,25 @@ mask_str = get_attn_mask_str(use_mask, attn_mask_type) high_prec_str = dtype2str(precision) - fname = f"te.core_attention{mask_str}{high_prec_str}.onnx" + fp8_str = "_fp8_dpa" if fp8_recipe is not None else "" + fname = f"te.core_attention{fp8_str}{mask_str}{high_prec_str}.onnx" + + is_fp8 = fp8_recipe is not None model = te.attention.DotProductAttention( num_attention_heads=num_attention_heads, kv_channels=kv_channels, - attention_dropout=0.5, qkv_format=qkv_format, attn_mask_type=attn_mask_type, ).to(device="cuda") - do_export(model, inp, fname, input_names=input_names, fp8_recipe=None) - te_outputs = te_infer(model, inp, is_fp8=False, fp8_recipe=None) + do_export(model, inp, fname, input_names=input_names, fp8_recipe=fp8_recipe) + te_outputs = te_infer(model, inp, is_fp8=is_fp8, fp8_recipe=fp8_recipe) serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) if precision in (torch.bfloat16,): return + atol = 5e-1 if is_fp8 else 1e-2 validate_result( - fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs + fname, inp, model, is_fp8=True, atol=atol, input_names=input_names, te_outputs=te_outputs ) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index c0398b801..70738440e 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -395,11 +395,11 @@ def test(): backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} if AttentionLogging._is_logging_setup is False: AttentionLogging.setup_logging() - with logging_context(highest_level=AttentionLogging._log_level): - for i in range(3): - os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) - _attention_backends["backend_selection_requires_update"] = True - available_backends, flash_attention_backend, fused_attention_backend = test() - if fused_attention_backend == FusedAttnBackend[backends[i]]: - fused_attn_backends.append(fused_attention_backend) + + for i in range(3): + os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) + _attention_backends["backend_selection_requires_update"] = True + available_backends, flash_attention_backend, fused_attention_backend = test() + if fused_attention_backend == FusedAttnBackend[backends[i]]: + fused_attn_backends.append(fused_attention_backend) return available_backends, flash_attention_backend, fused_attn_backends diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index 3f82798f9..f26abeb90 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -39,6 +39,14 @@ inline bool dimensions_supported_by_TMA(const Tensor *const t) { return cols % alignment_requirement == 0; } +#ifndef __HIP_PLATFORM_AMD__ +__device__ __forceinline__ unsigned char *align_smem_ptr_per_TMA_requirements(unsigned char *p) { + size_t addr = reinterpret_cast<size_t>(p); + addr = (addr + TMA_SHMEM_ALIGNMENT - 1) & ~(TMA_SHMEM_ALIGNMENT - 1); + return reinterpret_cast<unsigned char *>(addr); +} +#endif //#ifndef __HIP_PLATFORM_AMD__ + namespace kernel { constexpr size_t THREADS_PER_BLOCK = 256; diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 5da9cc5a5..61c6ba9ce 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -21,6 +21,7 @@ #include 
"../../util/ptx.cuh" #include "../../utils.cuh" #include "core_nvfp4.cuh" +#include "specialized/quantize_transpose_nvfp4_tuned_1D.cuh" namespace transformer_engine { namespace dispatch { @@ -1159,6 +1160,7 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, #if FP4_TYPE_SUPPORTED using namespace quantize_transpose_kernel; using namespace ptx; + bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to @@ -1166,6 +1168,11 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, // TODO(Frank): Is there a better way to do this? bool return_transpose = output->has_columnwise_data(); + // if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) { + // quantize_transpose_tuned_1D(input, noop, output, quant_config, stream); + // return; + // } + constexpr bool COMPUTE_ACTIVATIONS = false; using ParamOP = Empty; constexpr float (*OP)(float, const ParamOP &) = nullptr; diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh new file mode 100644 index 000000000..411900168 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -0,0 +1,789 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_transpose_nvfp4_tuned_1D.cuh + * \brief Tuned kernel to cast to NVFP4 and transpose. 
+ */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_TUNED_1D_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_TUNED_1D_CUH_ + +#include +#include +#include +#include + +#include "../../../common.h" +#include "../../../util/math.h" +#include "../../../util/ptx.cuh" +#include "../../../utils.cuh" +#include "../core_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { + +namespace quantize_transpose_tuned_kernel { + +using namespace quantization_and_transposition_SF; +using namespace core; +using namespace ptx; + +#if FP4_TYPE_SUPPORTED + +struct TunableConfig { + static constexpr int CHUNK_DIM_Y = 128; + static constexpr int CHUNK_DIM_X = 128; + static constexpr int PREFETCH_STAGES = 1; + static constexpr bool PERSISTENT = false; +}; + +constexpr int SCALE_DIM = 16; // NVFP4 block (x16 elts) +constexpr int THREADS_NUM = 128; +constexpr int ELTS_PER_THREAD = 16; +constexpr int TILE_DIM_Y = 64; +constexpr int TILE_DIM_X = 64; + +static_assert(ELTS_PER_THREAD == SCALE_DIM && "Hardcoded and fixed parameter\0"); + +static_assert((THREADS_NUM * ELTS_PER_THREAD <= TILE_DIM_Y * TILE_DIM_X) && + "Unbalanced threads workload\0"); + +static_assert((TunableConfig::CHUNK_DIM_Y % TILE_DIM_Y == 0) && + "Chunk size Y must be evenly divisible by the tile size Y\0"); +static_assert((TunableConfig::CHUNK_DIM_X % TILE_DIM_X == 0) && + "Chunk size X must be evenly divisible by the tile size X\0"); + +static_assert((TILE_DIM_Y % SCALE_DIM == 0) && + "Tile size Y must be evenly divisible by the scale dim\0"); +static_assert((TILE_DIM_X % SCALE_DIM == 0) && + "Tile size X must be evenly divisible by the scale dim\0"); + +constexpr int TILES_Y = TunableConfig::CHUNK_DIM_Y / TILE_DIM_Y; +constexpr int TILES_X = TunableConfig::CHUNK_DIM_X / TILE_DIM_X; + +constexpr int THREADS_PER_SCALE_ROWWISE = SCALE_DIM / ELTS_PER_THREAD; + +constexpr int SCALES_PER_CHUNK_Y = TunableConfig::CHUNK_DIM_Y / SCALE_DIM; +constexpr int SCALES_PER_CHUNK_X = TunableConfig::CHUNK_DIM_X / SCALE_DIM; + +constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; +constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; + +constexpr int STAGES_Y = TILES_Y; +constexpr int STAGES_X = TILES_X; +constexpr int STAGES = STAGES_Y * STAGES_X; + +constexpr int BUFFS_NUM = TunableConfig::PREFETCH_STAGES + 1; +constexpr int BUFFS_NUM_IN = BUFFS_NUM; +constexpr int BUFFS_NUM_OUT = BUFFS_NUM; +constexpr int BUFFS_NUM_OUT_TR = 2; +constexpr int BUFF_DIM_Y = TILE_DIM_Y; +constexpr int BUFF_DIM_X = TILE_DIM_X; +constexpr int BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X; +constexpr int BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM; + +// Input buffer (BF16) +constexpr int BUFF_IN_DIM_Y = BUFF_DIM_Y; +constexpr int BUFF_IN_DIM_X = BUFF_DIM_X; +constexpr int BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; +constexpr int BUFF_IN_ELTS_NUM = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; + +// Output buffer (NVFP4) +constexpr int BUFF_OUT_DIM_Y = BUFF_DIM_Y; +constexpr int BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8; +constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; + +// Output transpose buffer (NVFP4) +constexpr int BUFF_OUT_TR_DIM_Y = BUFF_DIM_X; +constexpr int BUFF_OUT_TR_DIM_X = (BUFF_DIM_Y * 4) / 8; +constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X; + +// Manual swizzling parameters to reduce SHMEM bank conflicts +constexpr int PACK_SIZE = 8; +constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE; + +constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD; +constexpr int THREADS_Y_ROWWISE = THREADS_NUM / 
THREADS_X_ROWWISE; + +constexpr int THREADS_X_TR = TILE_DIM_X / 2; +constexpr int THREADS_Y_TR = THREADS_NUM / THREADS_X_TR; + +constexpr int ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE; +constexpr int ITERATIONS_TR = SCALES_PER_TILE_Y / THREADS_Y_TR; +static_assert(ITERATIONS_TR >= 1 && "Number of transpose iterations should be >=1\0"); +static_assert((SCALES_PER_TILE_Y % THREADS_Y_TR == 0) && + "Partial transpose iterations are not supported\0"); + +constexpr int BUFF_OUT_IT_OFFSET = BUFF_OUT_TR_DIM_X / ITERATIONS_TR / STAGES; + +static_assert(BUFF_DIM_Y >= SCALE_DIM && + "Number of buffer rows must be greater or equal to the size of the columnwise " + "scaling block\0"); +static_assert(TunableConfig::CHUNK_DIM_Y >= BUFF_DIM_Y); +static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && + "Number of buffer rows must be greater or equal to the number of rowwise " + "processing threads in Y dimension\0"); + +// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory +constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr int THREADS_PER_BANK = TOTAL_BANKS_WIDTH / ELTS_PER_THREAD; + +using IType = bf16; +using IType2 = typename ptx::FPx2<IType>; +using IType3D = IType[BUFFS_NUM_IN][BUFF_IN_DIM_Y][BUFF_IN_DIM_X]; +using IType2x3D = IType2[BUFFS_NUM_IN][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2]; +using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X]; +using OType2xt3D = fp4e2m1x2[BUFFS_NUM_OUT_TR][BUFF_OUT_TR_DIM_Y][BUFF_OUT_TR_DIM_X]; +using ScalesType2D = nvfp4_scale_t[TunableConfig::CHUNK_DIM_Y][SCALES_PER_CHUNK_X]; +using ScalesTypeTr2D = nvfp4_scale_t[TunableConfig::CHUNK_DIM_X][SCALES_PER_CHUNK_Y]; +using RNG_t = typename transformer_engine::curanddx::detail::philox4x32_native_state<10>; + +template <bool USE_FAST_MATH> +struct SCALING_COEFFICIENT_TYPE {}; +template <> +struct SCALING_COEFFICIENT_TYPE<false> { + using type = float; +}; +template <> +struct SCALING_COEFFICIENT_TYPE<true> { + using type = bf16; +}; + +__device__ __forceinline__ float get_amax_of_pair(const IType2 pair) { + return static_cast<float>(__hmax(__habs(pair.x), __habs(pair.y))); +} + +// Compute "correct" per-block encoding scaling factor +template <typename SF_TYPE> +__device__ __forceinline__ SF_TYPE +compute_nvfp4_scaling_coefficient(const nvfp4_scale_t S_dec_block, const float S_enc) { + constexpr float float_max = detail::TypeExtrema<float>::max; + const float scale_rcp = fminf(S_enc / static_cast<float>(S_dec_block), float_max); + return static_cast<SF_TYPE>(scale_rcp); +} + +template <bool USE_STOCHASTIC_ROUNDING, bool USE_FAST_MATH> +__device__ __forceinline__ void colwise_scaling(const IType *__restrict__ sIn_ptr, + fp4e2m1x2 *__restrict__ sOut_tr_ptr, + nvfp4_scale_t *__restrict__ sSFcolwise_ptr, + const float S_enc_colwise, const int stage_Y, + const int stage_X, const int buff_in, + const int buff_out_tr, RNG_t &rng, + uint4 &random_uint4, int &rnd_idx) { + using scaling_coeff_type = typename SCALING_COEFFICIENT_TYPE<USE_FAST_MATH>::type; + + const auto &sIn2x = *reinterpret_cast<const IType2x3D *>(sIn_ptr); + auto &sOut_tr = *reinterpret_cast<OType2xt3D *>(sOut_tr_ptr); + auto &sSFcolwise = *reinterpret_cast<ScalesTypeTr2D *>(sSFcolwise_ptr); + + const int warp = threadIdx.x / THREADS_PER_WARP; + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + + const int tid_Y_colwise = (thread_lane % 4 + warp) % 4; + const int tid_X_colwise = thread_lane; + + const int thread_offset_Y_colwise = tid_Y_colwise * SCALE_DIM; + const int thread_offset_X_colwise = tid_X_colwise * 2; + + const int in_thread_offset_Y = thread_offset_Y_colwise; + const int in_thread_offset_X = 
thread_offset_X_colwise / 2; + + const int out_tr_thread_offset_Y = thread_offset_X_colwise; + const int out_tr_thread_offset_X = thread_offset_Y_colwise / 2; + + const int scale_tr_offset_Y = (stage_X * TILE_DIM_X) + 2 * tid_X_colwise; + const int scale_tr_offset_X = (stage_Y * SCALES_PER_TILE_Y) + tid_Y_colwise; + + __align__(8) IType rIn[2][SCALE_DIM]; + // Read (cache) a pair of input elements (S2R). Find NVFP4-block AMAX + IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)}; +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const IType2 elt_pair = + ptx::ld_shared_b32(&sIn2x[buff_in][in_thread_offset_Y + i][in_thread_offset_X]); + rIn[0][i] = elt_pair.x; + rIn[1][i] = elt_pair.y; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, elt_pair); + } + const float block_amax[2] = {static_cast<float>(__habs(thread_amax_2x.x)), + static_cast<float>(__habs(thread_amax_2x.y))}; +#pragma unroll + for (int w = 0; w < 2; ++w) { + const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax[w], S_enc_colwise); + + // Store scaling factors to SMEM buffer (R2S) + sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = S_dec_b_fp8; + + const scaling_coeff_type SFcoefficient = + compute_nvfp4_scaling_coefficient<scaling_coeff_type>(S_dec_b_fp8, S_enc_colwise); + + // Scale elements + __align__(8) uint32_t rOut[SCALE_DIM / 8]; +#pragma unroll + for (int e = 0; e < SCALE_DIM / 8; ++e) { + const uint64_t elts03 = *reinterpret_cast<const uint64_t *>(&rIn[w][8 * e]); + const uint64_t elts47 = *reinterpret_cast<const uint64_t *>(&rIn[w][8 * e + 4]); + if constexpr (USE_STOCHASTIC_ROUNDING) { + const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx); + const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx); + rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding( + elts03, elts47, SFcoefficient, rbits03, rbits47); + } else { + rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, + SFcoefficient); + } + } + uint64_t &out_pack_16x = *reinterpret_cast<uint64_t *>(rOut); + ptx::st_shared_b64(&sOut_tr[buff_out_tr][out_tr_thread_offset_Y + w][out_tr_thread_offset_X], + out_pack_16x); + } +} + +template <bool USE_STOCHASTIC_ROUNDING, bool USE_FAST_MATH> +__device__ __forceinline__ void rowwise_scaling(const IType *__restrict__ sIn_ptr, + fp4e2m1x2 *__restrict__ sOut_ptr, + nvfp4_scale_t *__restrict__ sSFrowwise_ptr, + const float S_enc_rowwise, const int stage_Y, + const int stage_X, const int buff_in, + const int buff_out, RNG_t &rng, uint4 &random_uint4, + int &rnd_idx) { + using scaling_coeff_type = typename SCALING_COEFFICIENT_TYPE<USE_FAST_MATH>::type; + + const auto &sIn = *reinterpret_cast<const IType3D *>(sIn_ptr); + auto &sOut = *reinterpret_cast<OType2x3D *>(sOut_ptr); + auto &sSFrowwise = *reinterpret_cast<ScalesType2D *>(sSFrowwise_ptr); + + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + + const int thread_offset_Y_rowwise = tid_Y_rowwise; + const int thread_offset_X_rowwise = tid_X_rowwise * ELTS_PER_THREAD; + + const int SF_thread_offset_rowwise_Y = tid_Y_rowwise; + const int SF_thread_offset_rowwise_X = tid_X_rowwise / THREADS_PER_SCALE_ROWWISE; + + const bool SF_storing_thread = (tid_X_rowwise % THREADS_PER_SCALE_ROWWISE == 0); + + const int stage_rowwise_scales_offset_Y = SF_thread_offset_rowwise_Y + stage_Y * TILE_DIM_Y; + const int stage_rowwise_scales_offset_X = + SF_thread_offset_rowwise_X + stage_X * SCALES_PER_TILE_X; +#pragma unroll + for (int it = 0; it < ITERATIONS_NORMAL; ++it) { + const int it_offset_Y_rowwise 
= thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + __align__(16) IType2 rIn[WAVES][PACK_SIZE / 2]; + + // Read (cache) input elements (S2R). Find NVFP4-block AMAX + IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + + // Load elements + __uint128_t &elts_8x = *reinterpret_cast<__uint128_t *>(&rIn[w]); + elts_8x = ptx::ld_shared_b128(&sIn[buff_in][it_offset_Y_rowwise][swizzled_thread_idx]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, rIn[w][e]); + } + } + const float block_amax = get_amax_of_pair(thread_amax_2x); + + const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + const scaling_coeff_type SFcoefficient = + compute_nvfp4_scaling_coefficient<scaling_coeff_type>(S_dec_b_fp8, S_enc_rowwise); + + // Store scaling factors to SMEM buffer (R2S) + if (SF_storing_thread) { + const int scales_offset_Y = stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; + const int scales_offset_X = stage_rowwise_scales_offset_X; + sSFrowwise[scales_offset_Y][scales_offset_X] = S_dec_b_fp8; + } + +// Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const uint64_t elts03 = *reinterpret_cast<const uint64_t *>(&rIn[w][0]); + const uint64_t elts47 = *reinterpret_cast<const uint64_t *>(&rIn[w][2]); + + uint32_t out_x8; + if constexpr (USE_STOCHASTIC_ROUNDING) { + const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx); + const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx); + out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding( + elts03, elts47, SFcoefficient, rbits03, rbits47); + } else { + out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(elts03, elts47, + SFcoefficient); + } + + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD; + const int swizzled_idx = (swizzled_group_idx + thread_offset_X_rowwise) / 2; + ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], out_x8); + } + } +} + +template <bool USE_STOCHASTIC_ROUNDING, bool USE_FAST_MATH, bool RETURN_TRANSPOSE> +__global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D_kernel( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const __grid_constant__ CUtensorMap tensor_map_output_t, nvfp4_scale_t *const scales_ptr, + nvfp4_scale_t *const scales_t_ptr, const float *noop, const float *const amax_rowwise_ptr, + const float *const amax_colwise_ptr, const size_t rows, const size_t cols, + const size_t scale_stride, const size_t scale_stride_t, const size_t *rng_state) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + const size_t rng_sequence = + threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + RNG_t rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0}; + // Index of the random number. 
It is incremented on each use and wraps back to 0 once all four random words have been consumed + int rnd_idx = 0; + + const bool leading_thread = (threadIdx.x == 0); + + constexpr int buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr int buff_elems_total_in = BUFFS_NUM_IN * buff_elems; + + constexpr int buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out_t = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT); + + constexpr int in_mem = buff_size_aligned_in; + + constexpr int out_mem_rowwise_data = buff_size_aligned_out; + constexpr int out_mem_colwise_data = RETURN_TRANSPOSE ? buff_size_aligned_out_t : 0; + constexpr int out_mem_rowwise_scales = DIVUP_TO_MULTIPLE( + TunableConfig::CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + extern __shared__ unsigned char dynamic_shmem[]; + unsigned char *dshmem = common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); + + IType *sIn_ptr = reinterpret_cast(dshmem); + fp4e2m1x2 *sOut_ptr = reinterpret_cast(dshmem + in_mem); + fp4e2m1x2 *sOut_tr_ptr = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + + auto &sIn = *reinterpret_cast(sIn_ptr); + auto &sOut = *reinterpret_cast(sOut_ptr); + auto &sOut_tr = *reinterpret_cast(sOut_tr_ptr); + + nvfp4_scale_t *sSFrowwise_ptr = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + nvfp4_scale_t *sSFcolwise_ptr = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + + auto &sSFrowwise = *reinterpret_cast(sSFrowwise_ptr); + auto &sSFcolwise = *reinterpret_cast(sSFcolwise_ptr); + + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + // Compute the global encoding/decoding scaling factors for all S_dec_b + const float S_enc_rowwise = + (amax_rowwise_ptr == nullptr) + ? 1.0f + : core::compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + + const float S_enc_colwise = + (amax_colwise_ptr == nullptr) + ? S_enc_rowwise + : core::compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); +
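+ // Note: the rowwise and colwise (transposed) outputs are quantized against independent + // global amax reductions; each encode factor falls back (to 1.0f, or to the rowwise + // factor) when the corresponding amax pointer is absent. As a rough sketch of the + // two-level scheme used below (helper internals are assumptions, not the exact library + // math): compute_decoding_scaling_factor() turns a 16-element block amax into an + // fp8e4m3 scaling factor SF, and compute_nvfp4_scaling_coefficient() folds SF together + // with the global encode factor into the single multiplier that the + // mul_cvt_bf16_to_fp4_8x_* conversions apply before rounding to FP4.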
+ __shared__ uint64_t workID_mbar; + __shared__ __uint128_t workID_response; + constexpr uint32_t workID_response_size = sizeof(workID_response); + static_assert(workID_response_size == 16); + + __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; + + // Coordinates of the first chunk (CTA) to process + int32_t ctaid_X = blockIdx.x; + int32_t ctaid_Y = blockIdx.y; + + // Initialize shared memory barriers with the number of threads participating in them + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); + } + ptx::mbarrier_init(&workID_mbar, 1); + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + + bool job_finished = false; + int buff_in = 0; + int buff_out = 0; + int buff_out_tr = 0; + int IN_buff_readable_parity[BUFFS_NUM] = {0, 0}; + int ctaid_parity = 0; + +// Prefetch input data only when processing the first chunk, +// which enables the one-iteration overlap throughout the entire kernel lifetime +#pragma unroll + for (int stage = 0; stage < TunableConfig::PREFETCH_STAGES; ++stage) { + const int buff_in = stage; + const int stage_Y = stage / STAGES_X; + const int stage_X = stage % STAGES_X; + + const int stage_offset_Y = stage_Y * TILE_DIM_Y; + const int stage_offset_X = stage_X * TILE_DIM_X; + + const int block_offset_Y = ctaid_Y * TunableConfig::CHUNK_DIM_Y; + const int block_offset_X = ctaid_X * TunableConfig::CHUNK_DIM_X; + + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X + stage_offset_X; + + uint64_t *barrier = &IN_buff_readable_mbar[buff_in]; + if (leading_thread) { + uint64_t *dst = reinterpret_cast(&sIn[buff_in]); + const uint64_t *src = reinterpret_cast(&tensor_map_input); + + // Arrive on the barrier and tell how many bytes are expected to come in + ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); + + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y, + barrier); + } + } + + while (!job_finished) { + const int block_offset_Y = ctaid_Y * TunableConfig::CHUNK_DIM_Y; + const int block_offset_X = ctaid_X * TunableConfig::CHUNK_DIM_X; + + const int block_offset_Y_tr = ctaid_X * TunableConfig::CHUNK_DIM_X; + const int block_offset_X_tr = ctaid_Y * TunableConfig::CHUNK_DIM_Y; + + const int chunk_rows = rows - block_offset_Y; + const int chunk_cols = cols - block_offset_X; + + const int scales_block_offset_Y_rowwise = ctaid_Y * TunableConfig::CHUNK_DIM_Y; + const int scales_block_offset_X_rowwise = ctaid_X * SCALES_PER_CHUNK_X; + const int scales_block_offset_Y_tr = ctaid_X * TunableConfig::CHUNK_DIM_X; + const int scales_block_offset_X_tr = ctaid_Y * SCALES_PER_CHUNK_Y; + + if constexpr (TunableConfig::PERSISTENT) { + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(&workID_mbar, workID_response_size); + ptx::try_cancel_cta(&workID_mbar, &workID_response); + } + } + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int stage_Y = stage / STAGES_X; + const int stage_X = stage % STAGES_X; + + const int stage_offset_Y = stage_Y * TILE_DIM_Y; + const int stage_offset_X = stage_X * TILE_DIM_X; + + if (stage == STAGES - TunableConfig::PREFETCH_STAGES) { + if constexpr (TunableConfig::PERSISTENT) { + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&workID_mbar, ctaid_parity); + 
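// Persistent scheduling: decode the 16-byte cancellation response into the + // coordinates of the next chunk acquired by this CTA. Mirroring the + // non-persistent branch below, (-1, -1) is taken to mean that no further + // chunk is available and the CTA can finish. +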
ptx::get_cancelled_cta_id_2D(&workID_response, ctaid_X, ctaid_Y); + ctaid_parity ^= 1; + } else { + ctaid_X = -1; + ctaid_Y = -1; + } + if (ctaid_X == -1 && ctaid_Y == -1) { + job_finished = true; + } + } + + // Prefetch the next stage's input data + if (!job_finished || (stage < STAGES - TunableConfig::PREFETCH_STAGES)) { + const int next_prefetch_buff = (buff_in + TunableConfig::PREFETCH_STAGES) % BUFFS_NUM; + const int next_prefetch_stage = (stage + TunableConfig::PREFETCH_STAGES) % STAGES; + const int next_prefetch_stage_Y = next_prefetch_stage / STAGES_X; + const int next_prefetch_stage_X = next_prefetch_stage % STAGES_X; + + const int next_prefetch_stage_offset_Y = next_prefetch_stage_Y * TILE_DIM_Y; + const int next_prefetch_stage_offset_X = next_prefetch_stage_X * TILE_DIM_X; + + // Offsets change because the coordinates of the next "to-be-prefetched" CTA also change + const int block_offset_Y = ctaid_Y * TunableConfig::CHUNK_DIM_Y; + const int block_offset_X = ctaid_X * TunableConfig::CHUNK_DIM_X; + + const int global_offset_Y = block_offset_Y + next_prefetch_stage_offset_Y; + const int global_offset_X = block_offset_X + next_prefetch_stage_offset_X; + + uint64_t *barrier = &IN_buff_readable_mbar[next_prefetch_buff]; + if (leading_thread) { + uint64_t *dst = reinterpret_cast(&sIn[next_prefetch_buff]); + const uint64_t *src = reinterpret_cast(&tensor_map_input); + + // Arrive on the barrier and tell how many bytes are expected to come in + ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); + + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y, + barrier); + } + ptx::fence_proxy_async_shared_cta(); + } + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + + // Wait for TMA transfer to have finished reading shared memory + // I.e.
the OUT buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read(); + + // NVFP4 Quantization + rowwise_scaling( + sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, stage_Y, stage_X, buff_in, buff_out, + rng, random_uint4, rnd_idx); + + if constexpr (RETURN_TRANSPOSE) { + colwise_scaling( + sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, S_enc_colwise, stage_Y, stage_X, buff_in, + buff_out_tr, rng, random_uint4, rnd_idx); + } + + // Wait for shared memory writes to be visible to TMA engine + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine + + // Initiate TMA transfer to copy shared memory to global memory + if (leading_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X + stage_offset_X; + const int global_offset_Y_tr = block_offset_Y_tr + stage_offset_X; + const int global_offset_X_tr = block_offset_X_tr + stage_offset_Y; + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), global_offset_X, + global_offset_Y, reinterpret_cast(&sOut[buff_out])); + + if constexpr (RETURN_TRANSPOSE) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_t), global_offset_X_tr, + global_offset_Y_tr, reinterpret_cast(&sOut_tr[buff_out_tr])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation + ptx::cp_async_bulk_commit_group(); + } + + buff_in = (buff_in + 1) % BUFFS_NUM_IN; + buff_out = (buff_out + 1) % BUFFS_NUM_OUT; + buff_out_tr = (buff_out_tr + 1) % BUFFS_NUM_OUT_TR; + } // end of stages + + // Vectorized store of scaling factors (S2G) + { + // Rowwise + { + using ScalesVec = Vec; + // number of scales in X dimension of this chunk + const int count = min(SCALES_PER_CHUNK_X, chunk_cols / SCALE_DIM); + + for (size_t row = threadIdx.x; row < TunableConfig::CHUNK_DIM_Y; row += THREADS_NUM) { + const size_t row_global = scales_block_offset_Y_rowwise + row; + if (row_global < rows) { + ScalesVec &scales_vec = *reinterpret_cast(sSFrowwise[row]); + const size_t scale_idx_global = + row_global * scale_stride + scales_block_offset_X_rowwise; + scales_vec.store_to_elts(&scales_ptr[scale_idx_global], 0, count); + } + } + } + + // Colwise + if constexpr (RETURN_TRANSPOSE) { + using ScalesVec = Vec; + // number of scales in Y dimension of this chunk + const int count = min(SCALES_PER_CHUNK_Y, chunk_rows / SCALE_DIM); + + for (size_t row_tr = threadIdx.x; row_tr < TunableConfig::CHUNK_DIM_X; + row_tr += THREADS_NUM) { + const size_t row_tr_global = scales_block_offset_Y_tr + row_tr; + if (row_tr_global < cols) { + ScalesVec &scales_vec = *reinterpret_cast(sSFcolwise[row_tr]); + const size_t scale_idx_global = + row_tr_global * scale_stride_t + scales_block_offset_X_tr; + scales_vec.store_to_elts(&scales_t_ptr[scale_idx_global], 0, count); + } + } + } + + if (!job_finished) { + // Ensures all reads from SFs buffer have completed and it's ready to be reused + __syncthreads(); + } + } + } + + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); + } + ptx::mbarrier_invalid(&workID_mbar); + } +#else + NVTE_DEVICE_ERROR("sm_100 or higher is required."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +#endif // FP4_TYPE_SUPPORTED +} // namespace quantize_transpose_tuned_kernel + +inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, Tensor *output, 
+ const QuantizationConfig *quant_config, + cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + using namespace quantize_transpose_tuned_kernel; + using namespace ptx; + + const bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; + const bool use_fast_math = quant_config ? quant_config->use_fast_math : false; + + // If the transposed output is allocated, return the transposed data. + // Otherwise, returning the transposed data is not necessary. + const bool return_transpose = output->has_columnwise_data(); + + checkCuDriverContext(stream); + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output", false); + + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + if (return_transpose) { + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Transposed output must have FP4 type."); + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Transposed scaling tensor must be allocated"); + } + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + NVTE_CHECK(rows % 32 == 0, + "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA + NVTE_CHECK(cols % 32 == 0, + "Number of tensor cols must be a multiple of 32"); // 16B alignment for TMA + + const int blocks_Y = DIVUP(rows, static_cast(TunableConfig::CHUNK_DIM_Y)); + const int blocks_X = DIVUP(cols, static_cast(TunableConfig::CHUNK_DIM_X)); + const dim3 grid(blocks_X, blocks_Y); + const int block_size = THREADS_NUM; + + const size_t scale_stride = output->scale_inv.shape[1]; + const size_t scale_stride_transpose = + return_transpose ? output->columnwise_scale_inv.shape[1] : 0; + + nvfp4_scale_t *const scales_ptr = reinterpret_cast(output->scale_inv.dptr); + nvfp4_scale_t *const scales_transpose_ptr = + reinterpret_cast(output->columnwise_scale_inv.dptr); + + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + const float *const amax_rowwise_ptr = reinterpret_cast(output->amax.dptr); + const float *const amax_colwise_ptr = + reinterpret_cast(output->columnwise_amax.dptr); + + const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr;
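+ // The RNG state is a pair of 64-bit values, {seed, offset}: the kernel seeds its + // per-thread generator from them and only consumes random bits when stochastic + // rounding is enabled (otherwise the state may be omitted entirely).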
+ const size_t *rng_state = nullptr; + if (rng_state_tensor != nullptr) { + Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor); + NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64, + "RNG state must contain 64-bit (int64) values."); + NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_te_tensor.data.dptr); + } + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + alignas(64) CUtensorMap tensor_map_output_transpose{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, + sizeof(IType) * 8); + + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, + 4); + if (return_transpose) { + create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows, + BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4); + } + + constexpr int buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr int buff_elems_total_in = BUFFS_NUM_IN * buff_elems; + constexpr int buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_aligned_out_t = + DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT); + + constexpr int buff_size_scales = DIVUP_TO_MULTIPLE( + TunableConfig::CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + constexpr int buff_size_scales_transpose = DIVUP_TO_MULTIPLE( + TunableConfig::CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT); + + const int in_mem = buff_size_aligned_in; + + const int out_data_mem = buff_size_aligned_out; + const int out_data_transpose_mem = return_transpose ? buff_size_aligned_out_t : 0; + const int out_scales_mem = buff_size_scales; + const int out_scales_transpose_mem = return_transpose ?
buff_size_scales_transpose : 0; + + const int out_mem = out_data_mem + out_data_transpose_mem; + + const int dshmem_size = + in_mem + out_mem + out_scales_transpose_mem + out_scales_mem + TMA_SHMEM_ALIGNMENT; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_fast_math, USE_FAST_MATH, + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = quantize_transpose_nvfp4_tuned_1D_kernel; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + }););); +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // FP4_TYPE_SUPPORTED +} + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_TUNED_1D_CUH_ diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index fde0d3892..4f8367aac 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -206,7 +206,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph) { + int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) { using namespace transformer_engine; NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; const int device_id = cuda::current_device(); @@ -406,9 +406,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (window_size_right == -1 || window_size_right == 0)) || // 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd} (cudnn_runtime_version >= 90200 && - ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || - ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + ((window_size_left == -1 && window_size_right == -1 && + attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) || + ((window_size_left == -1 || window_size_left >= 0) && window_size_right == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && max_seqlen_q == max_seqlen_kv)) && max_seqlen_q <= max_seqlen_kv && dropout == 0.0 && @@ -418,12 +420,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd} (cudnn_runtime_version >= 90600 && ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || - ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && + ((window_size_left >= 0 || window_size_left == -1) && + (window_size_right >= 0 || window_size_right == -1) && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && // TODO(cyang): fix bug for BRCM + cross-attention on sm100 (sm_arch_ 
< 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && cudnn_runtime_version <= 90700) || cudnn_runtime_version > 90700)))) || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && @@ -440,7 +444,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // 9.13.1+: vanilla, off-by-one, learnable (cudnn_runtime_version >= 91301 || (cudnn_runtime_version < 91301 && - softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX))) { + softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) && + // determinism on Blackwell + // pre-9.18.1: fwd: deterministic; bwd: non-deterministic + // 9.18.1+: fwd: deterministic; bwd: non-deterministic/deterministic + (sm_arch_ < 100 || + (sm_arch_ >= 100 && (!is_training || + (is_training && !deterministic && + (dropout == 0.0 || bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)) || + (is_training && deterministic && cudnn_runtime_version >= 91801 && + dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))))) { flag_arb = true; } if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { @@ -506,16 +519,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // NVTE fused attention FWD with packed QKV // DEPRECATED: This API is deprecated. // Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, - size_t max_seqlen, bool is_training, bool return_max_logit, - bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_fwd_qkvpacked( + const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, + bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); using namespace transformer_engine; @@ -553,7 +564,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit, - cuda_graph); + cuda_graph, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -589,13 +600,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, fused_attn_arbitrary_seqlen_fwd( b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, 
is_training, return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, - input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state, - wkspace, stream, handle); + window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view, + input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, + input_cu_seqlens, input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, + input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( - "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); + "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. " + "\n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) @@ -629,8 +641,8 @@ void nvte_fused_attn_bwd_qkvpacked( NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream) { + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); using namespace transformer_engine; @@ -669,7 +681,8 @@ void nvte_fused_attn_bwd_qkvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, - max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph); + max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph, + deterministic); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -725,10 +738,11 @@ void nvte_fused_attn_bwd_qkvpacked( fused_attn_arbitrary_seqlen_bwd( b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, &Q_view, - &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, &dQ_view, - &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens, - input_cu_seqlens_padded, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); + attn_mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + deterministic, &Q_view, &K_view, &V_view, input_O, input_dO, input_Bias, + input_SoftmaxOffset, output_S, &dQ_view, &dK_view, &dV_view, output_dBias, + output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens, input_cu_seqlens_padded, + input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.0 is required for BF16/FP16 fused attention " @@ -779,7 +793,8 @@ void nvte_fused_attn_fwd_kvpacked( size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, 
NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { + int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -855,7 +870,7 @@ void nvte_fused_attn_fwd_kvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, - return_max_logit, cuda_graph); + return_max_logit, cuda_graph, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -891,13 +906,14 @@ void nvte_fused_attn_fwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, input_Q, &K_view, &V_view, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, - input_page_table_v, input_rng_state, wkspace, stream, handle); + window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view, + input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, + input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( - "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); + "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. 
" + "\n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) @@ -933,8 +949,8 @@ void nvte_fused_attn_bwd_kvpacked( const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, - cudaStream_t stream) { + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -982,10 +998,10 @@ void nvte_fused_attn_bwd_kvpacked( const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_KV->data.dtype); - NVTE_Fused_Attn_Backend fused_attention_backend = - nvte_get_fused_attn_backend(true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, - softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - d, window_size_left, window_size_right, false, cuda_graph); + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, + h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false, + cuda_graph, deterministic); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -1040,11 +1056,11 @@ void nvte_fused_attn_bwd_kvpacked( fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, - input_Q, &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, - output_dQ, &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, - input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, - wkspace, stream, handle); + bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, deterministic, input_Q, &K_view, &V_view, input_O, input_dO, + input_Bias, input_SoftmaxOffset, output_S, output_dQ, &dK_view, &dV_view, output_dBias, + output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.3 is required for BF16/FP16 fused attention " @@ -1094,8 +1110,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -1166,7 +1182,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const 
NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, - return_max_logit, cuda_graph); + return_max_logit, cuda_graph, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -1183,13 +1199,14 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, - input_page_table_v, input_rng_state, wkspace, stream, handle); + window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V, + input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, + input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( - "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); + "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. " + "\n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) @@ -1215,8 +1232,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, - bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -1262,7 +1280,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, - cuda_graph); + cuda_graph, deterministic); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -1289,8 +1307,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, - deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias, - input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, + bottom_right_diagonal, deterministic, input_Q, 
input_K, input_V, input_O, input_dO, + input_Bias, input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index d3746fc04..53023361e 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -55,10 +55,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, - void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, - void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, - void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void *devPtrQ, + void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, + void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, + void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -75,6 +75,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( if (is_bottom_right && s_q == s_kv && !is_padding) { is_causal = true; is_bottom_right = false; + bottom_right_diagonal = false; } bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (is_training && dropout_probability != 0.0f); @@ -129,6 +130,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, true, tensorType, cudnn_frontend::DataType_t::NOT_SET, @@ -254,9 +256,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); + fe::DiagonalAlignment_t const &diagonal_alignment = + bottom_right_diagonal ? 
fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_options.set_diagonal_alignment(diagonal_alignment); if (cudnn_runtime_version >= 90200 && window_size_left != -1) { sdpa_options.set_diagonal_band_left_bound(window_size_left + 1); } + if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + sdpa_options.set_diagonal_band_right_bound(window_size_right); + } sdpa_options.set_alibi_mask(is_alibi); @@ -542,13 +551,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl( int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ, - void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, - void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV, - void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, - void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, - void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, - void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose, + void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrSoftmaxOffset, + void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, + void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, void *devPtrDropoutOffset, + void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, + void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, + size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -563,6 +573,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( if (is_bottom_right && s_q == s_kv && !is_padding) { is_causal = true; is_bottom_right = false; + bottom_right_diagonal = false; } bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (dropout_probability != 0.0f); @@ -621,6 +632,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, deterministic, tensorType, cudnn_frontend::DataType_t::NOT_SET, @@ -781,9 +793,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl( sdpa_backward_options.set_max_total_seq_len_kv(s_kv); } + fe::DiagonalAlignment_t const &diagonal_alignment = + bottom_right_diagonal ? 
fe::DiagonalAlignment_t::BOTTOM_RIGHT + : fe::DiagonalAlignment_t::TOP_LEFT; + sdpa_backward_options.set_diagonal_alignment(diagonal_alignment); + if (cudnn_runtime_version >= 90200 && window_size_left != -1) { sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); } + if (cudnn_runtime_version >= 90600 && window_size_right != -1) { + sdpa_backward_options.set_diagonal_band_right_bound(window_size_right); + } if (cudnn_runtime_version >= 90000) { sdpa_backward_options.set_deterministic_algorithm(deterministic); @@ -1044,8 +1064,8 @@ void fused_attn_arbitrary_seqlen_fwd( size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, @@ -1180,11 +1200,11 @@ void fused_attn_arbitrary_seqlen_fwd( max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, - devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, - devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); + window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, + devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1206,13 +1226,14 @@ void fused_attn_arbitrary_seqlen_bwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, - Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, - Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + int64_t 
window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, + Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; void *devPtrQ = input_Q->data.dptr; @@ -1273,8 +1294,8 @@ void fused_attn_arbitrary_seqlen_bwd( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, - deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, - devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, + bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, + devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index c34eae4e6..4dd7f3d1d 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -25,8 +25,8 @@ void fused_attn_arbitrary_seqlen_fwd( size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, @@ -37,13 +37,14 @@ void fused_attn_arbitrary_seqlen_bwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, - Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor 
*output_dBias, - Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, + Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); #endif // CUDNN_VERSION >= 8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 3630041cc..f886ec77f 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1707,6 +1707,7 @@ void fused_attn_fp8_fwd_impl_v1( 0, 0, true, + true, qkv_tensor_type, o_tensor_type, cudnn_frontend::DataType_t::NOT_SET, @@ -2035,6 +2036,7 @@ void fused_attn_fp8_bwd_impl_v1( NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, 0, 0, + true, false, qkv_tensor_type, o_tensor_type, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 7d23bb5c5..fdfc4abe8 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -110,6 +110,7 @@ struct FADescriptor_v1 { NVTE_Softmax_Type softmax_type; std::int64_t window_size_left; std::int64_t window_size_right; + bool bottom_right_diagonal; bool deterministic; cudnn_frontend::DataType_t qkv_tensor_type; cudnn_frontend::DataType_t o_tensor_type; @@ -121,15 +122,16 @@ struct FADescriptor_v1 { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type, - window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type, - o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) < + window_size_left, window_size_right, bottom_right_diagonal, deterministic, + bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type, + generate_max_sum_exp) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type, - rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, - rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, - rhs.dqkv_tensor_type, rhs.generate_max_sum_exp); + rhs.window_size_left, rhs.window_size_right, rhs.bottom_right_diagonal, + rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, + rhs.do_tensor_type, rhs.dqkv_tensor_type, rhs.generate_max_sum_exp); } }; diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index 
e787b31c8..2564f059d 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -228,6 +228,29 @@ const std::unordered_map mNVTEMaskTypeStr = { {NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK, "PADDING_CAUSAL_BOTTOM_RIGHT_MASK"}, }; +// True for the two *_BOTTOM_RIGHT_MASK variants, false otherwise. +inline bool implied_bottom_right_diagonal(NVTE_Mask_Type attn_mask_type) { + return attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK; +} + +// The ROCm/AITER fused-attn backend derives mask anchoring solely from the +// NVTE_Mask_Type enum and does not consume `bottom_right_diagonal`. Any +// divergence between the boolean and the alignment implied by the mask type +// would silently produce numerically incorrect attention, so we reject it +// here until AITER plumbs an explicit alignment parameter. +inline void check_bottom_right_diagonal(NVTE_Mask_Type attn_mask_type, + bool bottom_right_diagonal) { + if (bottom_right_diagonal != implied_bottom_right_diagonal(attn_mask_type)) { + NVTE_ERROR( + "ROCm fused attention does not support a `bottom_right_diagonal` value " + "that diverges from the alignment implied by `attn_mask_type`. Use " + "NVTE_CAUSAL_BOTTOM_RIGHT_MASK or NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK " + "for bottom-right alignment, or the corresponding non-bottom-right " + "mask types for top-left alignment."); + } +} + void log_fused_attn_config( const char* func_name, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t batch_size, @@ -279,7 +302,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph) { + int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) { using namespace transformer_engine; // TODO: Add return_max_logit support @@ -345,18 +368,20 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // NVTE fused attention FWD with packed QKV -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, +void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, + NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, bool is_training, bool return_max_logit, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - bool cuda_graph, NVTE_Bias_Type bias_type, + bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); using namespace transformer_engine; + 
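// A few concrete cases of the contract enforced by the check below (derived + // from implied_bottom_right_diagonal above): + //   NVTE_CAUSAL_BOTTOM_RIGHT_MASK with bottom_right_diagonal=true  -> accepted + //   NVTE_CAUSAL_MASK              with bottom_right_diagonal=false -> accepted + //   NVTE_CAUSAL_MASK              with bottom_right_diagonal=true  -> NVTE_ERROR +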
check_bottom_right_diagonal(attn_mask_type, bottom_right_diagonal); const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens); const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded); @@ -394,7 +419,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit, - cuda_graph); + cuda_graph, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd_qkvpacked( @@ -431,10 +456,11 @@ void nvte_fused_attn_bwd_qkvpacked( NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream) { + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); using namespace transformer_engine; + check_bottom_right_diagonal(attn_mask_type, bottom_right_diagonal); const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens); const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded); @@ -477,7 +503,8 @@ void nvte_fused_attn_bwd_qkvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, - max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph); + max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph, + deterministic); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)){ @@ -523,10 +550,12 @@ void nvte_fused_attn_fwd_kvpacked( size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { + int64_t window_size_right, bool bottom_right_diagonal, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; + check_bottom_right_diagonal(attn_mask_type, bottom_right_diagonal); const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); @@ -570,7 +599,7 @@ void nvte_fused_attn_fwd_kvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, - return_max_logit, cuda_graph); + return_max_logit, cuda_graph, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd_kvpacked( 
@@ -614,10 +643,11 @@ void nvte_fused_attn_bwd_kvpacked( const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, - cudaStream_t stream) { + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); using namespace transformer_engine; + check_bottom_right_diagonal(attn_mask_type, bottom_right_diagonal); const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); @@ -666,7 +696,8 @@ void nvte_fused_attn_bwd_kvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, - d, window_size_left, window_size_right, false, cuda_graph); + d, window_size_left, window_size_right, false, cuda_graph, + deterministic); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { @@ -720,10 +751,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; + check_bottom_right_diagonal(attn_mask_type, bottom_right_diagonal); const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); @@ -760,7 +793,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, - return_max_logit, cuda_graph); + return_max_logit, cuda_graph, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd( @@ -806,10 +839,12 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; + check_bottom_right_diagonal(attn_mask_type, bottom_right_diagonal); const 
Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); @@ -852,7 +887,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, - cuda_graph); + cuda_graph, deterministic); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 726cc4e47..358b4ef97 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -221,13 +221,14 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * \param[in] window_size_right Sliding window size (the right half). * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. + * \param[in] deterministic Whether determinism is required or not. */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph); + int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic); /*! \brief Compute dot product attention with packed QKV input. * @@ -288,22 +289,21 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ [[deprecated( "nvte_fused_attn_fwd_qkvpacked() is deprecated. 
Please use nvte_fused_attn_fwd() with separate " "Q, K, V tensors instead.")]] -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, - size_t max_seqlen, bool is_training, bool return_max_logit, - bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd_qkvpacked( + const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, + bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed QKV input. * @@ -356,6 +356,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] workspace Workspace tensor. @@ -370,8 +371,8 @@ void nvte_fused_attn_bwd_qkvpacked( NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, - NVTETensor workspace, cudaStream_t stream); + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, + bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with packed KV input. * @@ -439,6 +440,7 @@ void nvte_fused_attn_bwd_qkvpacked( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. 
*/ @@ -454,7 +456,8 @@ void nvte_fused_attn_fwd_kvpacked( size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); + int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, + cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed KV input. * @@ -513,6 +516,7 @@ void nvte_fused_attn_fwd_kvpacked( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] workspace Workspace tensor. @@ -529,8 +533,8 @@ void nvte_fused_attn_bwd_kvpacked( const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, - cudaStream_t stream); + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with separate Q, K and V. * @@ -602,19 +606,23 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. 
*/ -void nvte_fused_attn_fwd( - const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, - bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, + const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, + bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. * @@ -679,6 +687,7 @@ void nvte_fused_attn_fwd( * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). + * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix. * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] workspace Workspace tensor. @@ -694,8 +703,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, - bool cuda_graph, NVTETensor workspace, cudaStream_t stream); + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream); /*! \brief Update the RNG state with the seed and calculated offset. 
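Reviewer note (illustration, not part of the patch): the new bottom_right_diagonal flag added throughout this header only changes behaviour when seqlen_q != seqlen_kv. A minimal NumPy sketch of the two alignments the parameter docs describe; the function name and the mask convention (True = masked out) are assumptions of this sketch, not TE API.

import numpy as np

def causal_mask(seqlen_q, seqlen_kv, window_size=(-1, 0), bottom_right_diagonal=True):
    """Return a (seqlen_q, seqlen_kv) mask; True means the position is masked out."""
    q = np.arange(seqlen_q)[:, None]
    kv = np.arange(seqlen_kv)[None, :]
    # Bottom-right alignment shifts the diagonal so the last query token lines up
    # with the last KV token; top-left keeps query i attending to KV tokens 0..i.
    offset = seqlen_kv - seqlen_q if bottom_right_diagonal else 0
    left, right = window_size  # -1 means unbounded on that side
    mask = kv > q + offset + right
    if left >= 0:
        mask |= kv < q + offset - left  # sliding window lower bound
    return mask

print(causal_mask(2, 4, bottom_right_diagonal=False).astype(int))  # [[0 1 1 1], [0 0 1 1]]
print(causal_mask(2, 4, bottom_right_diagonal=True).astype(int))   # [[0 0 0 1], [0 0 0 0]]

The same offset applies to the ALiBi bias diagonal, which is why the parameter docs mention sliding window and ALiBi together.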
* diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index c724612cb..afb429bbf 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -461,9 +461,9 @@ class TensorAllocator { } void Free(NVTETensor t) { - std::lock_guard<std::mutex> lock(mutex); uintptr_t index = reinterpret_cast<uintptr_t>(t); if (index == 0) return; + std::lock_guard<std::mutex> lock(mutex); NVTE_CHECK(index <= memory.size(), "Invalid tensor."); free_list.push_back(index); // Clean up @@ -571,9 +571,9 @@ class GroupedTensorAllocator { } void Free(NVTEGroupedTensor t) { - std::lock_guard<std::mutex> lock(mutex); uintptr_t index = reinterpret_cast<uintptr_t>(t); if (index == 0) return; + std::lock_guard<std::mutex> lock(mutex); NVTE_CHECK(index <= memory.size(), "Invalid grouped tensor."); free_list.push_back(index); // Clean up diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index a186ed9d3..36cc8a952 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -137,6 +137,7 @@ constexpr bool is_supported_arch() { #endif //#ifndef __HIP_PLATFORM_AMD__ +#ifndef __HIP_PLATFORM_AMD__ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init __device__ __forceinline__ void mbarrier_init(uint64_t *mbar, const uint32_t count) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) @@ -178,6 +179,18 @@ __device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar, const #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +__device__ __forceinline__ void mbarrier_arrive_expect_tx_cta_relaxed_shared_cta( + uint64_t *mbar, const uint32_t tx_count) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.arrive.expect_tx.relaxed.cta.shared::cta.b64 _, [%0], %1;" ::"r"(mbar_ptr), + "r"(tx_count)); +#else + NVTE_DEVICE_ERROR( + "mbarrier_arrive_expect_tx_cta_relaxed_shared_cta is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + __device__ __forceinline__ void fence_mbarrier_init_release_cluster() { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm volatile("fence.mbarrier_init.release.cluster;"); @@ -257,6 +270,76 @@ __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint3 #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +__device__ __forceinline__ void mbarrier_wait_parity_acquire_cta_shared_cta(uint64_t *mbar, + uint32_t phase_parity) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile( + "{\n\t" + ".reg .b64 r1; \n\t" + ".reg .pred waitComplete; \n\t" // predicate representing if barrier condition is met + "WAIT: \n\t" // loop around barrier wait + "mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 waitComplete, [%0], %1; \n\t" + "@waitComplete bra DONE; \n\t" // mbarrier conditions are met + "bra WAIT; \n\t" // just a time-out, try again + "DONE: \n\t" + "}\n\t" + : + : "r"(mbar_ptr), "r"(phase_parity) + : "memory"); +#else + NVTE_DEVICE_ERROR("mbarrier_wait_parity_acquire_cta_shared_cta is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__device__ __forceinline__ void try_cancel_cta(uint64_t *mbar, __uint128_t *response_data_ptr) { + constexpr bool is_blackwell =
ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + uint32_t workID_response = __cvta_generic_to_shared(response_data_ptr); + asm volatile( + "clusterlaunchcontrol.try_cancel.async.mbarrier::complete_tx::bytes.multicast::cluster::" + "all.b128 " + "[%0], [%1];" ::"r"(workID_response), + "r"(mbar_ptr)); + } else { + NVTE_DEVICE_ERROR( + "Cluster Launch Control PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } +} + +__device__ __forceinline__ void get_cancelled_cta_id_2D(__uint128_t *response_data_ptr, + int32_t &ctaid_X, int32_t &ctaid_Y) { + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + uint32_t workID_response = __cvta_generic_to_shared(response_data_ptr); + asm volatile( + "{\n\t" + ".reg .s32 x_ctaid; \n\t" + ".reg .s32 y_ctaid; \n\t" + "mov .s32 x_ctaid, -1; \n\t" + "mov .s32 y_ctaid, -1; \n\t" + ".reg.b128 try_cancel_response; \n\t" + "ld.shared.b128 try_cancel_response, [%2]; \n\t" + ".reg .pred P1; \n\t" + "clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 P1, try_cancel_response; \n\t" + "@P1 clusterlaunchcontrol.query_cancel.get_first_ctaid.v4.b32.b128 {x_ctaid, y_ctaid, _, " + "_}, try_cancel_response; \n\t" + "mov .s32 %0, x_ctaid; \n\t" + "mov .s32 %1, y_ctaid; \n\t" + "}\n\t" + : "=r"(ctaid_X), "=r"(ctaid_Y) + : "r"(workID_response) + : "memory"); + } else { + NVTE_DEVICE_ERROR( + "Cluster Launch Control PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } +} +#endif //#ifndef __HIP_PLATFORM_AMD__ + constexpr uint32_t FP32_MANTISSA_BITS = 23; constexpr uint32_t FP32_EXPONENT_BIAS = 127; @@ -308,6 +391,7 @@ __device__ __forceinline__ e8m0_t float_to_e8m0(float val) { #endif //#ifndef __HIP_PLATFORM_AMD__ } +#ifndef __HIP_PLATFORM_AMD__ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor // shared::cta -> global __device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_t *dst_global_ptr, @@ -416,6 +500,7 @@ __device__ __forceinline__ void fence_proxy_async_shared_cta() { NVTE_DEVICE_ERROR("fence_proxy_async_shared_cta is only supported on SM 9.0+."); #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } +#endif //#ifndef __HIP_PLATFORM_AMD__ template struct alignas(2 * sizeof(T)) FPx2 { @@ -699,8 +784,184 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits); } } + +#ifndef __HIP_PLATFORM_AMD__ +template +__device__ __forceinline__ uint32_t mul_cvt_bf16_to_fp4_8x_round_to_nearest( + const uint64_t in03, const uint64_t in47, const SCALING_COEFFICIENT_TYPE scaling_coefficient) { + uint32_t out_8x = 0; + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + if constexpr (std::is_same::value) { + asm volatile( + "{\n" + ".reg.f32 zero; \n\t" + "mov.b32 zero, 0; \n\t" + ".reg.b16 scaling_coeff; \n\t" + "mov.b16 scaling_coeff, %3; \n\t" + ".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h; \n\t" + "mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1; \n\t" + "mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2; \n\t" + + ".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t" + "fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v3, v3_h, 
scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero; \n\t" + + ".reg.b8 f0, f1, f2, f3; \n\t" + // Elements reordered to match e2m1x4 packing order (v1,v0) + "cvt.rn.satfinite.e2m1x2.f32 f0, v1, v0;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v3, v2;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f2, v5, v4;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f3, v7, v6;\n\t" + "mov.b32 %0, {f0, f1, f2, f3};\n" + "}" + : "=r"(out_8x) + : "l"(in03), "l"(in47), "h"(reinterpret_cast(scaling_coefficient))); + } else if constexpr (std::is_same::value) { + asm volatile( + "{\n" + ".reg.b64 scaling_coeff_2x; \n\t" + "mov.b64 scaling_coeff_2x, {%3, %3}; \n\t" + ".reg.b16 v0_bf16, v1_bf16, v2_bf16, v3_bf16, v4_bf16, v5_bf16, v6_bf16, v7_bf16; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16}, %1; \n\t" + "mov.b64 {v4_bf16, v5_bf16, v6_bf16, v7_bf16}, %2; \n\t" + + ".reg.b32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "cvt.f32.bf16 v4, v4_bf16; \n\t" + "cvt.f32.bf16 v5, v5_bf16; \n\t" + "cvt.f32.bf16 v6, v6_bf16; \n\t" + "cvt.f32.bf16 v7, v7_bf16; \n\t" + + ".reg.b64 v01, v23, v45, v67; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mov.b64 v45, {v4, v5}; \n\t" + "mov.b64 v67, {v6, v7}; \n\t" + "mul.f32x2 v01, v01, scaling_coeff_2x; \n\t" + "mul.f32x2 v23, v23, scaling_coeff_2x; \n\t" + "mul.f32x2 v45, v45, scaling_coeff_2x; \n\t" + "mul.f32x2 v67, v67, scaling_coeff_2x; \n\t" + // Elements reordered to match the packing order (v1,v0) + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "mov.b64 {v5, v4}, v45; \n\t" + "mov.b64 {v7, v6}, v67; \n\t" + + ".reg.b8 f0, f1, f2, f3; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f2, v4, v5;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f3, v6, v7;\n\t" + "mov.b32 %0, {f0, f1, f2, f3};\n\t" + "}" + : "=r"(out_8x) + : "l"(in03), "l"(in47), "f"(scaling_coefficient)); + } else { + NVTE_DEVICE_ERROR("Not supported scaling coefficient type."); + } + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. 
" + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return out_8x; +} + +template +__device__ __forceinline__ uint32_t mul_cvt_bf16_to_fp4_8x_stochastic_rounding( + const uint64_t in03, const uint64_t in47, const SCALING_COEFFICIENT_TYPE scaling_coefficient, + const uint32_t rbits03, const uint32_t rbits47) { + uint32_t out_8x = 0; + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + if constexpr (std::is_same::value) { + asm volatile( + "{\n" + ".reg.f32 zero; \n\t" + "mov.b32 zero, 0; \n\t" + ".reg.b16 scaling_coeff; \n\t" + "mov.b16 scaling_coeff, %3; \n\t" + ".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h; \n\t" + "mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1; \n\t" + "mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2; \n\t" + + ".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t" + "fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v3, v3_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero; \n\t" + "fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero; \n\t" + + ".reg.b16 b03, b47; \n\t" + // Elements reordered to match e2m1x4 packing order (v3,v2,v1,v0) + "cvt.rs.satfinite.e2m1x4.f32 b03, {v3, v2, v1, v0}, %4; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 b47, {v7, v6, v5, v4}, %5; \n\t" + "mov.b32 %0, {b03, b47};\n" + "}" + : "=r"(out_8x) + : "l"(in03), "l"(in47), "h"(reinterpret_cast(scaling_coefficient)), + "r"(rbits03), "r"(rbits47)); + } else if constexpr (std::is_same::value) { + asm volatile( + "{\n" + ".reg.b16 v0_bf16, v1_bf16, v2_bf16, v3_bf16, v4_bf16, v5_bf16, v6_bf16, v7_bf16; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16}, %1; \n\t" + "mov.b64 {v4_bf16, v5_bf16, v6_bf16, v7_bf16}, %2; \n\t" + + ".reg.b32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "cvt.f32.bf16 v4, v4_bf16; \n\t" + "cvt.f32.bf16 v5, v5_bf16; \n\t" + "cvt.f32.bf16 v6, v6_bf16; \n\t" + "cvt.f32.bf16 v7, v7_bf16; \n\t" + + "mul.f32 v0, v0, %3; \n\t" + "mul.f32 v1, v1, %3; \n\t" + "mul.f32 v2, v2, %3; \n\t" + "mul.f32 v3, v3, %3; \n\t" + "mul.f32 v4, v4, %3; \n\t" + "mul.f32 v5, v5, %3; \n\t" + "mul.f32 v6, v6, %3; \n\t" + "mul.f32 v7, v7, %3; \n\t" + ".reg.b16 b03, b47; \n\t" + // Elements reordered to match e2m1x4 packing order (v3,v2,v1,v0) + "cvt.rs.satfinite.e2m1x4.f32 b03, {v3, v2, v1, v0}, %4; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 b47, {v7, v6, v5, v4}, %5; \n\t" + "mov.b32 %0, {b03, b47};\n" + "}" + : "=r"(out_8x) + : "l"(in03), "l"(in47), "f"(scaling_coefficient), "r"(rbits03), "r"(rbits47)); + } else { + NVTE_DEVICE_ERROR("Not supported scaling coefficient type."); + } + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. 
" + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return out_8x; +} +#endif //#ifndef __HIP_PLATFORM_AMD__ + #endif // FP4_TYPE_SUPPORTED +#ifndef __HIP_PLATFORM_AMD__ // SIMD like "Fused" cast + multiplication (x2) __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in, const floatx2 &scale) { @@ -868,7 +1129,6 @@ __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) } -#ifndef __HIP_PLATFORM_AMD__ __device__ __forceinline__ int32_t elect_one_sync(uint32_t mask = 0xFFFFFFFFu) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) int32_t pred = 0; @@ -1552,10 +1812,65 @@ __device__ __forceinline__ floatx4 up_cast(const bf16x4 &in) { } #endif //#ifndef __HIP_PLATFORM_AMD__ +#ifndef __HIP_PLATFORM_AMD__ +// Loads single BF16/FP16 element from shared memory state space +__device__ __forceinline__ bf16 ld_shared_b16(const bf16 *__restrict__ src_smem) { + const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem); + bf16 dst; + asm volatile("ld.shared.b16 %0, [%1];" + : "=h"(reinterpret_cast(dst)) + : "r"(src_smem_ptr)); + return dst; +} + +// Loads pair of BF16/FP16 values from shared memory state space +__device__ __forceinline__ bf16x2 ld_shared_b32(const bf16x2 *__restrict__ src_smem) { + const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem); + bf16x2 dst; + asm volatile("ld.shared.b32 %0, [%1];" + : "=r"(reinterpret_cast(dst)) + : "r"(src_smem_ptr)); + return dst; +} + +// Loads 8x BF16 values from shared memory state space +__device__ __forceinline__ __uint128_t ld_shared_b128(const bf16 *__restrict__ src_smem) { + uint64_t elts03, elts47; + const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem); + asm volatile( + "{\n\t" + ".reg.b128 xy; \n\t" + "ld.shared.b128 xy, [%2]; \n\t" + "mov.b128 {%0, %1}, xy; \n" + "}\n" + : "=l"(elts03), "=l"(elts47) + : "r"(src_smem_ptr)); + return (static_cast<__uint128_t>(elts47) << 64) | static_cast<__uint128_t>(elts03); +} + +#if FP4_TYPE_SUPPORTED +// Vectorized store of x8 FP4 elements into shared memory state space +__device__ __forceinline__ void st_shared_b32(fp4e2m1x2 *__restrict__ dst_smem, + uint32_t fp4_pack_x8) { + const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem); + asm volatile("st.shared.b32 [%0], %1;" : : "r"(dst_smem_ptr), "r"(fp4_pack_x8)); +} +#endif + +// Vectorized store of x16 FP4 elements into shared memory state space +#if FP4_TYPE_SUPPORTED +__device__ __forceinline__ void st_shared_b64(fp4e2m1x2 *__restrict__ dst_smem, + uint64_t fp4_pack_x16) { + const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem); + asm volatile("st.shared.b64 [%0], %1;" : : "r"(dst_smem_ptr), "l"(fp4_pack_x16)); +} +#endif +#endif //#ifndef __HIP_PLATFORM_AMD__ } // namespace ptx namespace { +#ifndef __HIP_PLATFORM_AMD__ template __forceinline__ __device__ void initialize_barriers(uint64_t *mbar, const bool is_master_thread) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) @@ -1692,6 +2007,7 @@ __forceinline__ __device__ void copy_2d_to_sharedx3( NVTE_DEVICE_ERROR("copy_2d_to_sharedx3 is only supported on SM 10.0+."); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +#endif //#ifndef __HIP_PLATFORM_AMD__ } // namespace } // namespace transformer_engine diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 4d669bc46..65857cbc4 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ 
b/transformer_engine/jax/cpp_extensions/attention.py @@ -75,6 +75,7 @@ "is_training", "max_segments_per_seq", "window_size", + "bottom_right_diagonal", "context_parallel_load_balanced", "cp_axis", "cp_striped_window_size", @@ -96,6 +97,7 @@ class _FusedAttnConfig: is_training: bool max_segments_per_seq: int window_size: Tuple[int, int] + bottom_right_diagonal: bool context_parallel_load_balanced: bool cp_axis: str cp_striped_window_size: Tuple[int, int] # Only for CP + Ring P2P + THD + SWA @@ -149,6 +151,7 @@ def get_fused_attn_backend(self): self.head_dim_v, self.window_size[0], self.window_size[1], + not self.is_non_deterministic_allowed(), ) @staticmethod @@ -394,6 +397,11 @@ def abstract( *bias_batch_shape, bias_heads, _, _ = bias_aval.shape bias_batch = reduce(operator.mul, bias_batch_shape) + bottom_right_diagonal = config.attn_mask_type in [ + AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, + ] + # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to # prepare for the active fused-attn backend input_batch = reduce(operator.mul, batch_shape) @@ -418,6 +426,7 @@ def abstract( config.max_segments_per_seq, config.window_size[0], config.window_size[1], + bottom_right_diagonal, ) wkspace_aval = q_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) @@ -526,6 +535,7 @@ def lowering( deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), window_size_left=window_size_left, window_size_right=window_size_right, + bottom_right_diagonal=config.bottom_right_diagonal, softmax_type=int(config.softmax_type.value), ) @@ -838,6 +848,7 @@ def abstract( config.max_segments_per_seq, config.window_size[0], config.window_size[1], + config.bottom_right_diagonal, ) dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) @@ -973,6 +984,7 @@ def lowering( deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), window_size_left=window_size_left, window_size_right=window_size_right, + bottom_right_diagonal=config.bottom_right_diagonal, softmax_type=int(config.softmax_type.value), ) @@ -1384,9 +1396,10 @@ def get_adjusted_max_segments_per_seq(self, max_seqlen, cp_size): def get_step_config(self) -> _FusedAttnConfig: """Returns a _FusedAttnConfig for single CP step call to fused attention.""" + adjusted_mask = self.get_adjusted_mask() return _FusedAttnConfig( attn_bias_type=self.config.attn_bias_type, - attn_mask_type=self.get_adjusted_mask(), + attn_mask_type=adjusted_mask, softmax_type=self.config.softmax_type, qkv_layout=self.config.qkv_layout, scaling_factor=self.config.scaling_factor, @@ -1394,6 +1407,7 @@ def get_step_config(self) -> _FusedAttnConfig: is_training=self.config.is_training, max_segments_per_seq=self.config.max_segments_per_seq, window_size=self.config.window_size, + bottom_right_diagonal=adjusted_mask.is_bottom_right(), context_parallel_load_balanced=self.config.context_parallel_load_balanced, cp_axis=self.config.cp_axis, cp_striped_window_size=None, @@ -1402,9 +1416,10 @@ def get_step_config(self) -> _FusedAttnConfig: def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig: """Returns a _FusedAttnConfig for single CP step call (made via a striped AG primitive) to fused attention.""" + adjusted_mask = self.get_adjusted_mask() return _FusedAttnConfig( attn_bias_type=self.config.attn_bias_type, - attn_mask_type=self.get_adjusted_mask(), + attn_mask_type=adjusted_mask, softmax_type=self.config.softmax_type, qkv_layout=self.config.qkv_layout, 
scaling_factor=self.config.scaling_factor, @@ -1412,6 +1427,7 @@ def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig: is_training=self.config.is_training, max_segments_per_seq=self.get_adjusted_max_segments_per_seq(max_seqlen, cp_size), window_size=self.config.window_size, + bottom_right_diagonal=adjusted_mask.is_bottom_right(), context_parallel_load_balanced=self.config.context_parallel_load_balanced, cp_axis=self.config.cp_axis, cp_striped_window_size=None, @@ -2457,6 +2473,7 @@ def get_step_config(self, attn_mask_type) -> _FusedAttnConfig: is_training=self.config.is_training, max_segments_per_seq=self.config.max_segments_per_seq, window_size=self.config.window_size, + bottom_right_diagonal=attn_mask_type.is_bottom_right(), context_parallel_load_balanced=self.config.context_parallel_load_balanced, cp_axis=self.config.cp_axis, cp_striped_window_size=None, @@ -3445,6 +3462,7 @@ def fused_attn_fwd( is_training=is_training, max_segments_per_seq=max_segments_per_seq, window_size=(-1, -1) if window_size is None else window_size, + bottom_right_diagonal=attn_mask_type.is_bottom_right(), context_parallel_load_balanced=context_parallel_causal_load_balanced, cp_axis=_maybe_context_parallel_axis(context_parallel_axis), cp_striped_window_size=None, @@ -3591,13 +3609,21 @@ def fused_attn_bwd( softmax_offset, (None, HEAD_AXES, None, None) ) - # TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on - # sm100+ - compute_capabilities = get_all_device_compute_capability() - if any(x >= 100 for x in compute_capabilities) and not is_hip_extension(): - assert not ( - attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0 - ), "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported" + compute_capabilities = get_all_device_compute_capability() if not is_hip_extension() else [] + if any(x >= 100 for x in compute_capabilities) and is_training: + assert ( + FusedAttnHelper.is_non_deterministic_allowed() + and get_cudnn_version() >= (9, 7, 0) + and (attn_bias_type == AttnBiasType.NO_BIAS or dropout_probability == 0.0) + ) or ( + not FusedAttnHelper.is_non_deterministic_allowed() + and get_cudnn_version() >= (9, 18, 1) + and attn_bias_type == AttnBiasType.NO_BIAS + and dropout_probability == 0.0 + ), ( + "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with dropout," + " and deterministic bprop (cuDNN 9.18.1+) does not support bias or dropout" + ) fused_config = _FusedAttnConfig( attn_bias_type=attn_bias_type, @@ -3609,6 +3635,7 @@ def fused_attn_bwd( is_training=is_training, max_segments_per_seq=max_segments_per_seq, window_size=(-1, -1) if window_size is None else window_size, + bottom_right_diagonal=attn_mask_type.is_bottom_right(), context_parallel_load_balanced=context_parallel_causal_load_balanced, cp_axis=_maybe_context_parallel_axis(context_parallel_axis), cp_striped_window_size=None, diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 8c2798c68..43e20a845 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -119,7 +119,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend( NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, - int64_t window_size_right); + int64_t 
window_size_right, bool deterministic); pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, @@ -127,7 +127,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left, - int64_t window_size_right); + int64_t window_size_right, bool bottom_right_diagonal); pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, @@ -135,7 +135,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq, - int64_t window_size_left, int64_t window_size_right); + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal); // GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 41347a85e..02efd1b38 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -18,12 +18,12 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend( NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, - int64_t window_size_right) { + int64_t window_size_right, bool deterministic) { auto backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false); + false, false, deterministic); return backend; } @@ -159,7 +159,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left, - int64_t window_size_right) { + int64_t window_size_right, bool bottom_right_diagonal) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; @@ -207,7 +207,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); + window_size_left, window_size_right, bottom_right_diagonal, query_workspace_tensor.data(), + nullptr); } nvte_tensor_pack_destroy(&aux_output_tensors); @@ -255,7 +256,7 @@ static void 
FusedAttnForwardImpl( size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, bool deterministic, - int64_t window_size_left, int64_t window_size_right) { + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) { FUSED_ATTN_IMPL_COMMON_BLOCK; /* Input tensors */ @@ -284,7 +285,7 @@ static void FusedAttnForwardImpl( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false); + false, false, deterministic); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -346,7 +347,7 @@ static void FusedAttnForwardImpl( k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, workspace_tensor.data(), stream); + window_size_left, window_size_right, bottom_right_diagonal, workspace_tensor.data(), stream); nvte_tensor_pack_destroy(&aux_output_tensors); } @@ -364,6 +365,7 @@ static void FusedAttnForwardImpl( size_t max_segments_per_seq = get_attr_value(attrs, "max_segments_per_seq"); \ auto window_size_left = get_attr_value(attrs, "window_size_left"); \ auto window_size_right = get_attr_value(attrs, "window_size_right"); \ + bool bottom_right_diagonal = get_attr_value(attrs, "bottom_right_diagonal"); \ float scaling_factor = get_attr_value(attrs, "scaling_factor"); \ float dropout_probability = get_attr_value(attrs, "dropout_probability"); \ NVTE_Bias_Type bias_type = \ @@ -402,7 +404,7 @@ Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Ty input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, qk_head_dim, v_head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type, softmax_type, qkv_layout, dtype, wkspace_dtype, - is_training, deterministic, window_size_left, window_size_right); + is_training, deterministic, window_size_left, window_size_right, bottom_right_diagonal); return ffi_with_cuda_error_check(); } @@ -433,7 +435,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq, - int64_t window_size_left, int64_t window_size_right) { + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); @@ -485,17 +487,18 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( auto dummy_ragged_offset_tensor = TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); - nvte_fused_attn_bwd( - q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - 
doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr); + nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), + dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, + qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, bottom_right_diagonal, deterministic, false, + query_workspace_tensor.data(), nullptr); } nvte_tensor_pack_destroy(&aux_input_tensors); @@ -514,7 +517,7 @@ static void FusedAttnBackwardImpl( size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, bool deterministic, - int64_t window_size_left, int64_t window_size_right) { + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) { FUSED_ATTN_IMPL_COMMON_BLOCK; /* Input tensors */ @@ -540,7 +543,7 @@ static void FusedAttnBackwardImpl( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false); + false, false, deterministic); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias, softmax_offset); @@ -611,16 +614,17 @@ static void FusedAttnBackwardImpl( } } - nvte_fused_attn_bwd( - q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(), - dsoftmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, deterministic, false, workspace_tensor.data(), stream); + nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), + dbias_tensor.data(), dsoftmax_offset_tensor.data(), + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), 
q_max_seqlen, + kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, deterministic, false, workspace_tensor.data(), stream); nvte_tensor_pack_destroy(&aux_input_tensors); } @@ -649,7 +653,7 @@ Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_T q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, qk_head_dim, v_head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type, softmax_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic, window_size_left, - window_size_right); + window_size_right, bottom_right_diagonal); return ffi_with_cuda_error_check(); } diff --git a/transformer_engine/jax/triton_extensions/permutation.py b/transformer_engine/jax/triton_extensions/permutation.py index bd8bd8ff1..0c80f9f18 100644 --- a/transformer_engine/jax/triton_extensions/permutation.py +++ b/transformer_engine/jax/triton_extensions/permutation.py @@ -65,8 +65,6 @@ class RowIdMapPass1Primitive(BasePrimitive): @staticmethod def abstract(routing_map_aval, *, num_tokens, num_experts, block_size): """Shape/dtype inference for pass 1.""" - del block_size # Only affects grid, not output shape - assert routing_map_aval.shape == ( num_tokens, num_experts, @@ -75,7 +73,7 @@ def abstract(routing_map_aval, *, num_tokens, num_experts, block_size): row_id_map_shape = (num_tokens, num_experts * 2 + 1) workspace_shape = ( num_experts, - triton.cdiv(num_tokens, DEFAULT_BLOCK_SIZE), + triton.cdiv(num_tokens, block_size), ) return ( @@ -134,9 +132,10 @@ def infer_sharding_from_operands( desc="RowIdMapPass1.row_id_map_sharding", ) # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE)) + # Second dim depends on num_tokens, so it must be sharded on the same axis as tokens workspace_sharding = NamedSharding( mesh, - PartitionSpec(None, None), + PartitionSpec(None, routing_map_spec[0]), desc="RowIdMapPass1.workspace_sharding", ) return [row_id_map_sharding, workspace_sharding] @@ -156,9 +155,11 @@ def partition(num_tokens, num_experts, block_size, mesh, arg_infos, result_infos PartitionSpec(routing_map_spec[0], None), desc="RowIdMapPass1.row_id_map_sharding", ) + # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE)) + # Second dim depends on num_tokens, so it must be sharded on the same axis as tokens workspace_sharding = NamedSharding( mesh, - PartitionSpec(None, None), + PartitionSpec(None, routing_map_spec[0]), desc="RowIdMapPass1.workspace_sharding", ) out_shardings = [row_id_map_sharding, workspace_sharding] @@ -186,7 +187,8 @@ def shardy_sharding_rule(num_tokens, num_experts, block_size, mesh, value_types, # Note: row_id_cols != experts since it's num_experts * 2 + 1 row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_row_id_cols") # workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE)) - workspace_spec = (f"{prefix}_experts", f"{prefix}_ws_blocks") + # Second dim depends on num_tokens, so use same factor to ensure same sharding + workspace_spec = (f"{prefix}_experts", f"{prefix}_tokens") return SdyShardingRule((input_spec,), (row_id_map_spec, workspace_spec)) @@ -208,10 +210,9 @@ class RowIdMapPass2Primitive(BasePrimitive): def abstract(row_id_map_aval, workspace_aval, *, num_tokens, num_experts, block_size): """Shape/dtype inference for pass 2 (in-place operation).""" del row_id_map_aval, workspace_aval - del block_size row_id_map_shape = (num_tokens, num_experts * 2 + 1) - 
workspace_shape = (num_experts, triton.cdiv(num_tokens, DEFAULT_BLOCK_SIZE)) + workspace_shape = (num_experts, triton.cdiv(num_tokens, block_size)) return ( jax.core.ShapedArray(row_id_map_shape, jnp.int32), @@ -270,9 +271,11 @@ def infer_sharding_from_operands( PartitionSpec(*row_id_map_spec), desc="RowIdMapPass2.row_id_map_sharding", ) + # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE)) + # Second dim depends on num_tokens, so it must be sharded on the same axis as tokens workspace_sharding = NamedSharding( mesh, - PartitionSpec(None, None), + PartitionSpec(None, row_id_map_spec[0]), desc="RowIdMapPass2.workspace_sharding", ) return [row_id_map_sharding, workspace_sharding] @@ -292,9 +295,11 @@ def partition(num_tokens, num_experts, block_size, mesh, arg_infos, result_infos PartitionSpec(*row_id_map_spec), desc="RowIdMapPass2.row_id_map_sharding", ) + # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE)) + # Second dim depends on num_tokens, so it must be sharded on the same axis as tokens workspace_sharding = NamedSharding( mesh, - PartitionSpec(None, None), + PartitionSpec(None, row_id_map_spec[0]), desc="RowIdMapPass2.workspace_sharding", ) out_shardings = [row_id_map_sharding, workspace_sharding] @@ -317,7 +322,9 @@ def shardy_sharding_rule(num_tokens, num_experts, block_size, mesh, value_types, del num_tokens, num_experts, block_size, mesh, value_types, result_types prefix = "RowIdMapPass2" row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_cols") - workspace_spec = (f"{prefix}_ws_experts", f"{prefix}_ws_blocks") + # workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE)) + # Second dim depends on num_tokens, so use same factor to ensure same sharding + workspace_spec = (f"{prefix}_ws_experts", f"{prefix}_tokens") return SdyShardingRule((row_id_map_spec, workspace_spec), (row_id_map_spec, workspace_spec)) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 6ea4092cb..2627a0892 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -36,6 +36,8 @@ from typing import Any, Callable, Mapping import zlib +from packaging import version + from jax import core import jax import jax.numpy as jnp @@ -274,13 +276,16 @@ def compile_triton( return _TRITON_KERNEL_CACHE[cache_key] # Compile kernel + cuda_option_kwargs = {} + if version.parse(_TRITON_VERSION) < version.parse("3.6.0"): + cuda_option_kwargs["cluster_dims"] = (1, 1, 1) options = cb.CUDAOptions( num_warps=num_warps, num_stages=num_stages, num_ctas=num_ctas, - cluster_dims=(1, 1, 1), debug=False, enable_fp_fusion=enable_fp_fusion, + **cuda_option_kwargs, ) # Mark constants as constexpr in signature @@ -303,8 +308,6 @@ def compile_triton( # Create kernel object for JAX # From jax/jaxlib/gpu/triton_kernels.cc: - from packaging import version - if version.parse(jax.__version__) >= version.parse("0.8.2"): kernel = gpu_triton.TritonKernel( compiled.name, # arg0: kernel_name (str) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 2d6f1da7e..560fc5f72 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -170,6 +170,11 @@ class FP8EmulationFunc(torch.autograd.Function): @staticmethod def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout): 
# pylint: disable=missing-function-docstring + if is_in_onnx_export_mode(): + return FP8EmulationFunc.onnx_forward( + tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout + ) + if quantizer_name == "QKV_quantizer": query_layer, key_layer, value_layer = [ x.contiguous() for x in [tensor1, tensor2, tensor3] @@ -208,6 +213,47 @@ def backward(ctx, grad1, grad2, grad3): tensors = grad1, grad2, grad3 return tensors[0], tensors[1], tensors[2], None, None, None + @staticmethod + def onnx_forward(tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout=None): + """ + ONNX-compatible forward for FP8 emulation using operations with defined ONNX translations. + """ + # pylint: disable=unused-argument + is_qkv_quantizer = quantizer_name == "QKV_quantizer" + assert isinstance( + quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + ), "ONNX FP8 emulation path supports only Float8 quantizers." + + if is_qkv_quantizer: + # Flatten + concatenate + quantize + split. Equivalent to combine_and_quantize Case 3. + orig_dtype = tensor1.dtype + shapes = [tensor1.shape, tensor2.shape, tensor3.shape] + numels = [tensor1.numel(), tensor2.numel(), tensor3.numel()] + + # Flatten and concatenate + combined = torch.cat( + [tensor1.reshape(-1), tensor2.reshape(-1), tensor3.reshape(-1)], dim=0 + ) + + # Quantize + dequantize combined tensor using quantizer's ONNX methods + combined_fp8 = quantizer.onnx_quantize(combined) + out = quantizer.onnx_dequantize(combined_fp8).to(orig_dtype) + + # Split back + out1 = out[: numels[0]].reshape(shapes[0]) + out2 = out[numels[0] : numels[0] + numels[1]].reshape(shapes[1]) + out3 = out[numels[0] + numels[1] :].reshape(shapes[2]) + + return out1, out2, out3 + if quantizer_name in ["S_quantizer", "O_quantizer"]: + # Emulate FP8 on single tensor using quantizer's ONNX methods + orig_dtype = tensor1.dtype + t_fp8 = quantizer.onnx_quantize(tensor1) + out = quantizer.onnx_dequantize(t_fp8).to(orig_dtype) + return out, tensor2, tensor3 + # Pass-through + return tensor1, tensor2, tensor3 + class UnfusedDotProductAttention(torch.nn.Module): """Parallel attention w/o QKV and Proj Gemms @@ -269,6 +315,7 @@ def forward( attn_mask_type: str = "causal", attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, @@ -354,6 +401,11 @@ def forward( attention_mask=attention_mask, window_size=window_size, attention_type=self.attention_type, + bottom_right_alignment=( + attn_mask_type not in ["causal", "padding_causal"] + if bottom_right_diagonal is None + else bottom_right_diagonal + ), ) ) @@ -457,7 +509,11 @@ def forward( actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None, actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None, alibi_slopes=alibi_slopes, - bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], + bottom_right_alignment=( + attn_mask_type not in ["causal", "padding_causal"] + if bottom_right_diagonal is None + else bottom_right_diagonal + ), ) matmul_result = torch.baddbmm( matmul_result, @@ -1118,6 +1174,7 @@ def forward( attn_mask_type, softmax_type, window_size, + bottom_right_diagonal, rng_gen, fused_attention_backend, use_FAv2_bwd, @@ -1221,6 +1278,7 @@ def forward( attn_mask_type, softmax_type, window_size, + 
bottom_right_diagonal, rng_gen, softmax_offset, cuda_graph=is_graph_capturing(), @@ -1298,6 +1356,7 @@ def forward( attn_mask_type, softmax_type, window_size, + bottom_right_diagonal, rng_gen, softmax_offset, return_max_logit, @@ -1385,6 +1444,7 @@ def forward( ctx.attn_mask_type = attn_mask_type ctx.softmax_type = softmax_type ctx.window_size = window_size + ctx.bottom_right_diagonal = bottom_right_diagonal ctx.fused_attention_backend = ( fused_attention_backend if (IS_HIP_EXTENSION or ctx.fp8) else FusedAttnBackend["F16_arbitrary_seqlen"] ) @@ -1535,6 +1595,7 @@ def backward(ctx, d_out, *_args): ctx.attn_mask_type, ctx.softmax_type, ctx.window_size, + ctx.bottom_right_diagonal, ctx.deterministic, is_graph_capturing(), ) @@ -1600,6 +1661,7 @@ def backward(ctx, d_out, *_args): ctx.attn_mask_type, ctx.softmax_type, ctx.window_size, + ctx.bottom_right_diagonal, ctx.deterministic, is_graph_capturing(), ) @@ -1639,6 +1701,7 @@ def backward(ctx, d_out, *_args): None, None, None, + None, d_softmax_offset, None, None, @@ -1738,6 +1801,7 @@ def forward( attn_mask_type: str = "causal", attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, @@ -1945,6 +2009,7 @@ def forward( attn_mask_type, self.softmax_type, window_size, + bottom_right_diagonal, None, # rng_gen fused_attention_backend, use_FAv2_bwd, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index f8fad6993..087b27f34 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -4029,28 +4029,30 @@ def attn_forward_func_with_cp( assert not sliding_window_attn or cp_comm_type in [ "a2a", "all_gather", - ], "Context parallelism does not support sliding window attention with {cp_comm_type=}!" + ], f"Context parallelism does not support sliding window attention with {cp_comm_type=}!" enable_mla = k.shape[-1] != v.shape[-1] assert not enable_mla or cp_comm_type in [ "p2p", "a2a+p2p", - ], "Context parallelism does not support MLA with {cp_comm_type=}!" + ], f"Context parallelism does not support MLA with {cp_comm_type=}!" if fp8 and fp8_meta is not None: if fp8_meta["recipe"].fp8_dpa: assert ( softmax_type == "vanilla" - ), "Context parallelism does not support {softmax_type=} with FP8 attention!" + ), f"Context parallelism does not support {softmax_type=} with FP8 attention!" assert ( softmax_type == "vanilla" or use_fused_attention - ), "Context parallelism only supports {softmax_type=} with FusedAttention backend!" + ), f"Context parallelism only supports {softmax_type=} with FusedAttention backend!" assert ( softmax_type == "vanilla" or cp_comm_type == "a2a" - ), "Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!" - assert ( - softmax_type == "vanilla" or qkv_format != "thd" - ), "Context parallelism does not support {softmax_type=} with qkv_format = 'thd'!" + ), f"Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!" 
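The ONNX path added to `FP8EmulationFunc` above reduces to a quantize-then-dequantize roundtrip expressed with ops that have defined ONNX translations. A minimal standalone sketch of that roundtrip, assuming PyTorch's native `float8_e4m3fn` dtype; the helper name and the fixed scale are illustrative, not TE's quantizer API:

```python
import torch

def fp8_roundtrip(t: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Quantize to FP8 E4M3, then dequantize back to the original dtype,
    # mirroring quantizer.onnx_quantize(...) followed by quantizer.onnx_dequantize(...).
    t_fp8 = (t.float() * scale).to(torch.float8_e4m3fn)
    return (t_fp8.to(torch.float32) / scale).to(t.dtype)

q = torch.randn(2, 8, 16, 64, dtype=torch.float16)
q_emulated = fp8_roundtrip(q)  # same shape/dtype, values rounded to FP8
```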
+ if get_cudnn_version() < (9, 18, 0): + assert softmax_type == "vanilla" or qkv_format != "thd", ( + f"Before cuDNN 9.18.0, context parallelism does not support {softmax_type=} with" + " qkv_format = 'thd'!" + ) args = [ is_training, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index fa4fb9a48..327bb8e4b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -232,6 +232,11 @@ class DotProductAttention(TransformerEngineBaseModule): map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on ``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can be overridden by :attr:`window_size` in ``forward`` as well. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. attention_type : str, default = "self" type of attention, either ``"self"`` and ``"cross"``. layer_number : int, default = None @@ -328,6 +333,7 @@ def __init__( qkv_format: str = "sbhd", attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, sequence_parallel: bool = False, tp_size: int = 1, get_rng_state_tracker: Optional[Callable] = None, @@ -356,6 +362,7 @@ def __init__( attn_mask_type = "padding_causal" self.attn_mask_type = attn_mask_type self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) + self.bottom_right_diagonal = bottom_right_diagonal if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -682,9 +689,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # assume attention uses the same fp8_group as GEMMs fp8_group = FP8GlobalStateManager.get_fp8_group() - self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() - self.fp8 = FP8GlobalStateManager.is_fp8_enabled() - self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() + self.fast_setattr("fp8_parameters", FP8GlobalStateManager.with_fp8_parameters()) + self.fast_setattr("fp8", FP8GlobalStateManager.is_fp8_enabled()) + self.fast_setattr("fp8_calibration", FP8GlobalStateManager.is_fp8_calibration()) fp8_enabled = self.fp8 or self.fp8_calibration self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration if self.fp8_parameters or fp8_enabled: @@ -709,7 +716,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ) else: # If fp8 isn't enabled, turn off and return. 
- self.fp8_initialized = False + self.fast_setattr("fp8_initialized", False) return if self.fp8_parameters and not self.fp8_initialized: @@ -727,7 +734,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # Allocate scales and amaxes self.init_fp8_meta_tensors(fp8_recipes) - self.fp8_initialized = True + self.fast_setattr("fp8_initialized", True) self.fp8_meta["recipe"] = fp8_recipe_dpa if fp8_recipe != fp8_recipe_dpa: @@ -817,6 +824,7 @@ def forward( max_seqlen_kv: int = None, attn_mask_type: Optional[str] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, checkpoint_core_attention: bool = False, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, @@ -969,6 +977,16 @@ def forward( causal masks are aligned to the bottom right corner. window_size: Optional[Tuple[int, int]], default = None Sliding window size for local attention. + bottom_right_diagonal: Optional[bool], default = None + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {'causal', 'padding_causal'} and `True` for other mask types. + Note: This parameter will be automatically overridden based on the + `attn_mask_type` - it will be forced to `False` for 'causal' and + 'padding_causal' mask types, and forced to `True` for mask types + containing 'bottom_right' (e.g., 'causal_bottom_right', + 'padding_causal_bottom_right'), regardless of the explicitly passed value. checkpoint_core_attention : bool, default = False If true, forward activations for attention are recomputed during the backward pass in order to save memory that would @@ -1006,7 +1024,7 @@ def forward( cases. It is ignored for other backends and when context parallelism is enabled. """ - with self.prepare_forward( + with self.prepare_forward_ctx( query_layer, num_gemms=3, allow_non_contiguous=True, @@ -1087,6 +1105,15 @@ def forward( if window_size is None: window_size = self.window_size window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) + if bottom_right_diagonal is None: + bottom_right_diagonal = self.bottom_right_diagonal + if attn_mask_type in {"causal", "padding_causal"}: + bottom_right_diagonal = False + if bottom_right_diagonal is None or attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + bottom_right_diagonal = True # checks for qkv_format if qkv_format is None: @@ -1150,11 +1177,14 @@ def forward( assert "padding" in attn_mask_type, "KV caching requires padding mask!" 
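For reference, the default-resolution rule implemented inline above can be written as a small standalone helper. This is a hypothetical sketch that folds the per-layer and per-call values into one argument; the real logic lives inline in `DotProductAttention.forward`:

```python
def resolve_bottom_right_diagonal(attn_mask_type: str, bottom_right_diagonal=None) -> bool:
    # Mask types with a fixed alignment override any user-supplied value.
    if attn_mask_type in ("causal", "padding_causal"):
        return False  # top-left diagonal is forced
    if attn_mask_type in ("causal_bottom_right", "padding_causal_bottom_right"):
        return True   # bottom-right diagonal is forced
    # For other mask types, None defaults to bottom-right alignment.
    return True if bottom_right_diagonal is None else bottom_right_diagonal
```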
if attn_mask_type == "padding_causal": attn_mask_type = attn_mask_type + "_bottom_right" + # since attention mask is changed, set `bottom_right_diagonal` to True + bottom_right_diagonal = True - self.attention_type = "cross" - self.flash_attention.attention_type = self.attention_type - self.fused_attention.attention_type = self.attention_type - self.unfused_attention.attention_type = self.attention_type + if self.attention_type != "cross": + self.fast_setattr("attention_type", "cross") + self.flash_attention.attention_type = self.attention_type + self.fused_attention.attention_type = self.attention_type + self.unfused_attention.attention_type = self.attention_type query_layer, key_layer, value_layer = [ x.contiguous() if not x.is_contiguous() else x @@ -1262,7 +1292,6 @@ def forward( if self.layer_number == 1: _alibi_cache["_alibi_slopes_require_update"] = True _alibi_cache["_alibi_bias_require_update"] = True - bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],) if core_attention_bias_type == "alibi": assert ( core_attention_bias is None @@ -1271,7 +1300,7 @@ def forward( _alibi_cache["_num_heads"] != query_layer.shape[-2] or _alibi_cache["_max_seqlen_q"] != max_seqlen_q or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv - or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment + or _alibi_cache["_bottom_right_alignment"] != bottom_right_diagonal or _alibi_cache["_alibi_slopes"] is None ): _alibi_cache["_alibi_slopes_require_update"] = True @@ -1328,6 +1357,7 @@ def forward( head_dim_v=head_dim_v, attn_mask_type=attn_mask_type, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None, core_attention_bias_type=core_attention_bias_type, core_attention_bias_shape=core_attention_bias_shape, @@ -1451,9 +1481,7 @@ def forward( if use_fused_attention: fu_core_attention_bias_type = core_attention_bias_type fu_core_attention_bias = core_attention_bias - if core_attention_bias_type == "alibi" and ( - alibi_slopes is not None or max_seqlen_q != max_seqlen_kv - ): + if core_attention_bias_type == "alibi" and (alibi_slopes is not None): fu_core_attention_bias_type = "post_scale_bias" _, fu_core_attention_bias = dpa_utils.get_alibi( _alibi_cache, @@ -1462,7 +1490,7 @@ def forward( max_seqlen_kv, alibi_slopes=alibi_slopes, bias_dtype=query_layer.dtype, - bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], + bottom_right_alignment=bottom_right_diagonal, ) if checkpoint_core_attention: return self._checkpointed_attention_forward( @@ -1480,6 +1508,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, fused_attention_backend=fused_attention_backend, core_attention_bias_type=fu_core_attention_bias_type, core_attention_bias=fu_core_attention_bias, @@ -1510,6 +1539,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, fused_attention_backend=fused_attention_backend, core_attention_bias_type=fu_core_attention_bias_type, core_attention_bias=fu_core_attention_bias, @@ -1528,7 +1558,9 @@ def forward( ) if use_unfused_attention: - allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" + allow_emulation = ( + os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode() + ) if checkpoint_core_attention: return 
self._checkpointed_attention_forward( self.unfused_attention, @@ -1544,6 +1576,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, @@ -1567,6 +1600,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 9b7147106..c272dd90f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -203,6 +203,9 @@ class AttentionParams: `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`} window_size : Tuple[int, int], default = None Sliding window attention size. + bottom_right_diagonal: bool, default = `None` + Whether to align sliding window and ALiBi diagonal to the bottom right corner + of the softmax matrix. alibi_slopes_shape : Optional[Union[torch.Size, List]], default = None Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`. core_attention_bias_type : str, default = no_bias @@ -252,6 +255,7 @@ class AttentionParams: head_dim_v: int = 64 attn_mask_type: str = "no_mask" window_size: Union[Tuple[int, int], None] = None + bottom_right_diagonal: bool = True alibi_slopes_shape: Union[torch.Size, List, None] = None core_attention_bias_type: str = "no_bias" core_attention_bias_shape: str = "1hss" @@ -330,6 +334,7 @@ def get_attention_backend( head_dim_v = attention_params.head_dim_v attn_mask_type = attention_params.attn_mask_type window_size = attention_params.window_size + bottom_right_diagonal = attention_params.bottom_right_diagonal alibi_slopes_shape = attention_params.alibi_slopes_shape core_attention_bias_type = attention_params.core_attention_bias_type core_attention_bias_shape = attention_params.core_attention_bias_shape @@ -479,7 +484,9 @@ def get_attention_backend( logger.debug("Disabling FlashAttention 3 for FP8 training") use_flash_attention_3 = False if use_unfused_attention: - allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" + allow_emulation = ( + os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode() + ) if not allow_emulation: logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") use_unfused_attention = False @@ -723,22 +730,14 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt ) use_unfused_attention = False if qkv_format == "thd": - logger.debug( - "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type - ) - use_fused_attention = False - logger.debug( - "Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd", - softmax_type, - ) - use_unfused_attention = False + if cudnn_version < (9, 18, 0): + logger.debug( + "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN" + " version < 9.18", + softmax_type, + ) + use_fused_attention = False if context_parallel: - logger.debug( - "Disabling UnfusedDotProductAttention for context parallelism with softmax_type" - " = %s", - softmax_type, - ) - use_unfused_attention = 
False
        if cp_comm_type != "a2a":
            logger.debug(
                "Disabling FusedAttention for context parallelism with softmax_type = %s and"
                " cp_comm_type = %s!",
                softmax_type,
                cp_comm_type,
            )
            use_fused_attention = False
@@ -874,39 +873,44 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
     # backend                    | window_size            | diagonal alignment
     # ---------------------------------------------------------------------------------
     # FlashAttention             | (-1, -1) or (>=0, >=0) | bottom right
-    # FusedAttention             | (-1, 0) or (>=0, 0)    | top left
-    # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both;
+    # FusedAttention             | (-1, 0) or (>=0, >=0)  | top left, bottom right
+    # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | top left, bottom right
     #                            |                        | converts window_size to an 'arbitrary' mask
     if window_size is None:
         window_size = check_set_window_size(attn_mask_type, window_size)
-    else:
-        if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
-            if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
-                logger.debug(
-                    "Disabling FusedAttention as it does not support sliding window attention"
-                    " for FP8"
-                )
-                use_fused_attention = False
-            elif (not IS_HIP_EXTENSION) and (window_size[1] != 0 or attention_dropout != 0.0):
-                logger.debug(
-                    "Disabling FusedAttention as it only supports sliding window attention "
-                    "with (left, 0) and no dropout"
-                )
-                use_fused_attention = False
-            elif max_seqlen_q > max_seqlen_kv:
-                logger.debug(
-                    "Disabling FusedAttention as it does not support sliding window attention "
-                    "with s_q > s_kv for cross-attention"
-                )
-                use_fused_attention = False
-        if use_flash_attention_2 and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
-            if not FlashAttentionUtils.is_installed:
-                FlashAttentionUtils.version_required = PkgVersion("2.3")
-            elif not FlashAttentionUtils.v2_3_plus:
-                logger.debug(
-                    "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
-                )
-                use_flash_attention_2 = False
+    if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
+        if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
+            logger.debug(
+                "Disabling FusedAttention as it does not support sliding window attention"
+                " for FP8"
+            )
+            use_fused_attention = False
+        elif (not IS_HIP_EXTENSION) and attention_dropout != 0.0:
+            logger.debug(
+                "Disabling FusedAttention as it does not support sliding window attention"
+                " with dropout"
+            )
+            use_fused_attention = False
+        elif max_seqlen_q > max_seqlen_kv:
+            logger.debug(
+                "Disabling FusedAttention as it does not support sliding window attention "
+                "with s_q > s_kv for cross-attention"
+            )
+            use_fused_attention = False
+    if use_flash_attention_2 and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
+        if not FlashAttentionUtils.is_installed:
+            FlashAttentionUtils.version_required = PkgVersion("2.3")
+        elif not FlashAttentionUtils.v2_3_plus:
+            logger.debug(
+                "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
+            )
+            use_flash_attention_2 = False
+        elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv:
+            logger.debug(
+                "Disabling FlashAttention as it only supports sliding window with bottom right"
+                " diagonal alignment for cross-attention"
+            )
+            use_flash_attention = False

     # Filter: Attention bias
     # backend                    | bias types             | ALiBi diagonal alignment
@@ -928,6 +932,12 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
             elif not FlashAttentionUtils.v2_4_plus:
                 logger.debug("Disabling FlashAttention as ALiBi requires
flash-attn 2.4+") use_flash_attention_2 = False + elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv: + logger.debug( + "Disabling FlashAttention as it only supports ALiBi with bottom right diagonal" + " alignment for cross-attention" + ) + use_flash_attention = False if ( core_attention_bias_type not in ["no_bias", "alibi"] @@ -945,13 +955,12 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if ( use_fused_attention and core_attention_bias_type == "alibi" - and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv) + and (alibi_slopes_shape is not None) ): fu_core_attention_bias_type = "post_scale_bias" fu_core_attention_bias_requires_grad = False - if alibi_slopes_shape is None: - fu_core_attention_bias_shape = "1hss" - elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads: + + if len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads: fu_core_attention_bias_shape = "1hss" elif ( len(alibi_slopes_shape) == 2 @@ -1003,6 +1012,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt window_size[1], return_max_logit, cuda_graph, + deterministic, ) if fused_attention_backend == FusedAttnBackend["No_Backend"]: logger.debug("Disabling FusedAttention as no backend supports the provided input") @@ -1057,6 +1067,15 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt ) use_flash_attention_2 = False if use_fused_attention and deterministic and (not IS_HIP_EXTENSION): + if softmax_type != "vanilla": + logger.debug( + "Disabling FusedAttention for determinism reasons with softmax_type = %s. " + "Sink attention (off-by-one and learnable softmax) requires " + "NVTE_ALLOW_NONDETERMINISTIC_ALGO=1", + softmax_type, + ) + use_fused_attention = False + fused_attention_backend = None if fused_attention_backend == FusedAttnBackend["FP8"] and is_training: logger.debug("Disabling FusedAttention for determinism reasons with FP8") use_fused_attention = False @@ -1073,10 +1092,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias") use_fused_attention = False fused_attention_backend = None - if is_training and device_compute_capability >= (10, 0): - logger.debug("Disabling FusedAttention for determinism reasons on Blackwell") - use_fused_attention = False - fused_attention_backend = None + # TODO: remove the filtering after ck team tells us how to enable more deterministic bwd kernels if use_fused_attention and deterministic and IS_HIP_EXTENSION: if ( diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index f875fd1e0..01c4955d7 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -8,7 +8,6 @@ from typing import Callable, List, Optional, Tuple, Union import torch -from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule @@ -32,6 +31,7 @@ from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention from transformer_engine.pytorch.attention.inference import InferenceParams from 
transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb +from transformer_engine.pytorch.attention.dot_product_attention import utils as dpa_utils from transformer_engine.pytorch.cpu_offload import start_offload, is_cpu_offload_enabled @@ -93,6 +93,11 @@ class MultiheadAttention(torch.nn.Module): map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on ``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can be overridden by :attr:`window_size` in :meth:`forward` as well. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. num_gqa_groups : int, default = None number of GQA groups in the transformer layer. Grouped Query Attention is described in @@ -248,6 +253,7 @@ def __init__( layer_number: Optional[int] = None, attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, tp_group: Optional[dist_group_type] = None, tp_size: int = 1, num_gqa_groups: Optional[int] = None, @@ -286,6 +292,7 @@ def __init__( self.qkv_format = qkv_format self.attn_mask_type = attn_mask_type self.window_size = window_size + self.bottom_right_diagonal = bottom_right_diagonal self.layer_number = 1 if layer_number is None else layer_number self.input_layernorm = input_layernorm self.attention_type = attention_type @@ -335,6 +342,7 @@ def __init__( self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups self.name = name + TransformerEngineBaseModule._validate_name(self) common_gemm_kwargs = { "fuse_wgrad_accumulation": fuse_wgrad_accumulation, @@ -621,6 +629,7 @@ def forward( encoder_output: Optional[torch.Tensor] = None, attn_mask_type: Optional[str] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, is_first_microbatch: Optional[bool] = None, checkpoint_core_attention: bool = False, inference_params: Optional[InferenceParams] = None, @@ -667,6 +676,11 @@ def forward( aligned to the bottom right corner. window_size: Optional[Tuple[int, int]], default = None sliding window size for local attention. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. encoder_output : Optional[torch.Tensor], default = None Output of the encoder block to be fed into the decoder block if using ``layer_type="decoder"``. @@ -731,6 +745,17 @@ def forward( if window_size is None: window_size = self.window_size + window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) + if bottom_right_diagonal is None: + bottom_right_diagonal = self.bottom_right_diagonal + if attn_mask_type in {"causal", "padding_causal"}: + bottom_right_diagonal = False + if bottom_right_diagonal is None or attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + }: + bottom_right_diagonal = True + if "padding" in attn_mask_type and attention_mask is not None: for mask in attention_mask: assert mask.dtype == torch.bool, "Attention mask must be in boolean type!" 
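A hedged usage sketch for the new `MultiheadAttention` argument; shapes and sizes are illustrative, and the keyword is the one added in this diff:

```python
import torch
import transformer_engine.pytorch as te

mha = te.MultiheadAttention(
    hidden_size=1024,
    num_attention_heads=16,
    attn_mask_type="no_mask",
    window_size=(256, 0),        # sliding window over the past 256 tokens
    bottom_right_diagonal=True,  # per-layer default added by this change
).cuda()

x = torch.randn(128, 2, 1024, device="cuda")  # (seq, batch, hidden)
y = mha(x, bottom_right_diagonal=False)       # per-call override in forward()
```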
@@ -739,9 +764,6 @@ def forward( core_attention_bias_type in AttnBiasTypes ), f"core_attention_bias_type {core_attention_bias_type} is not supported!" - if TEDebugState.debug_enabled: - TransformerEngineBaseModule._validate_name(self) - # ================================================= # Pre-allocate memory for key-value cache for inference # ================================================= @@ -1004,6 +1026,7 @@ def forward( attention_mask=attention_mask, attn_mask_type=attn_mask_type, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, checkpoint_core_attention=checkpoint_core_attention, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 788a9d7ef..e9c701080 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -146,6 +146,7 @@ def fused_attn_fwd( attn_mask_type: str = "padding", softmax_type: str = "vanilla", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = None, rng_gen: torch.Generator = None, softmax_offset: torch.Tensor = None, return_max_logit: bool = False, @@ -221,6 +222,9 @@ def fused_attn_fwd( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. + bottom_right_diagonal: bool, default = None + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. rng_gen : torch.Generator, default = None random number generator; if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen @@ -267,6 +271,12 @@ def fused_attn_fwd( if IS_HIP_EXTENSION: assert not return_max_logit, "ROCm does not support return_max_logit yet." + if bottom_right_diagonal is None: + bottom_right_diagonal = attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + } + if attn_scale is None: d = q.size(-1) attn_scale = 1.0 / math.sqrt(d) @@ -322,6 +332,7 @@ def fused_attn_fwd( AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], window_size, + bottom_right_diagonal, cu_seqlens_q, cu_seqlens_kv, q, @@ -386,6 +397,7 @@ def fused_attn_bwd( attn_mask_type: str = "padding", softmax_type: str = "vanilla", window_size: Tuple[int, int] = (-1, -1), + bottom_right_diagonal: bool = None, deterministic: bool = False, cuda_graph: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: @@ -458,6 +470,9 @@ def fused_attn_bwd( in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically. + bottom_right_diagonal: bool, default = None + whether to align sliding window and ALiBi diagonal to the top left (False) or + bottom right (True) corner of the softmax matrix. deterministic : bool, default = False whether to execute the backward pass with deterministic behaviours. cuda_graph : bool, default = False @@ -478,6 +493,12 @@ def fused_attn_bwd( gradient tensor of softmax offset of shape [1, h_q, 1, 1]. See softmax_type in DotProductAttention for details. 
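The sliding-window rule quoted in the docstring above ("for the i-th token, attend to tokens in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive") can be checked with a few lines of plain PyTorch. A sketch for intuition only, with -1 treated as unbounded:

```python
import torch

def sliding_window_mask(seqlen_q: int, seqlen_kv: int, window_size) -> torch.Tensor:
    left, right = window_size
    i = torch.arange(seqlen_q).unsqueeze(1)   # query positions
    j = torch.arange(seqlen_kv).unsqueeze(0)  # key positions
    diag = i + seqlen_kv - seqlen_q           # bottom-right aligned diagonal
    lo = diag - (seqlen_kv if left < 0 else left)
    hi = diag + (seqlen_kv if right < 0 else right)
    return (j >= lo) & (j <= hi)              # True = position may be attended

print(sliding_window_mask(3, 5, (1, 0)).int())
```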
""" + if bottom_right_diagonal is None: + bottom_right_diagonal = attn_mask_type in { + "causal_bottom_right", + "padding_causal_bottom_right", + } + if attn_scale is None: d = q.size(-1) attn_scale = 1.0 / math.sqrt(d) @@ -517,6 +538,7 @@ def fused_attn_bwd( AttnMaskType[attn_mask_type], SoftmaxType[softmax_type], window_size, + bottom_right_diagonal, deterministic, cu_seqlens_q, cu_seqlens_kv, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index df8b548cc..f5e2a2ab0 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -83,15 +83,16 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph); + int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic); std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - const std::vector window_size, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const at::ScalarType fake_dtype, const std::optional cu_seqlens_q_padded, + const std::vector window_size, bool bottom_right_diagonal, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const at::ScalarType fake_dtype, + const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, @@ -101,10 +102,10 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, const std::vector window_size, bool deterministic, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, - const py::handle K, const py::handle V, const py::handle O, const py::handle dO, - const at::ScalarType fake_dtype, const DType dqkv_type, + NVTE_Softmax_Type softmax_type, const std::vector window_size, + bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index b455e0375..72087a521 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * 
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
@@ -8,6 +10,13 @@
 #include "common.h"
 #include "pybind.h"

+#include <torch/version.h>
+#if USE_ROCM && TORCH_VERSION_MINOR < 11
+using TECUDAGuard = at::hip::HIPGuardMasqueradingAsCUDA;
+#else
+using TECUDAGuard = at::cuda::CUDAGuard;
+#endif
+
 namespace {

 constexpr int block_size = 512;
@@ -45,12 +54,12 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
     float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right, bool return_max_logit, bool cuda_graph) {
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) {
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
       bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups,
       max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right,
-      return_max_logit, cuda_graph);
+      return_max_logit, cuda_graph, deterministic);
   return fused_attention_backend;
 }

@@ -100,9 +109,10 @@ std::vector<py::object> fused_attn_fwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale,
     float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-    const std::vector<int64_t> window_size, const at::Tensor cu_seqlens_q,
-    const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
-    const at::ScalarType fake_dtype, const std::optional<at::Tensor> cu_seqlens_q_padded,
+    const std::vector<int64_t> window_size, bool bottom_right_diagonal,
+    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
+    const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
+    const std::optional<at::Tensor> cu_seqlens_q_padded,
     const std::optional<at::Tensor> cu_seqlens_kv_padded,
     const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
     py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
@@ -111,7 +121,7 @@ std::vector<py::object> fused_attn_fwd(
   // Ensure that cuDNN handle is created on the correct device,
   // overriding torch.cuda.set_device calls from user side.
   // Assumes all tensors passed are on the same device.
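The TECUDAGuard alias above keeps the RAII device guard usable on ROCm builds with older PyTorch, where the CUDA-named type is only available through the HIP masquerading class. For intuition, this is roughly the Python-level analog of what the guard does (hypothetical helper, not part of TE):

```python
import torch

def on_tensor_device(t: torch.Tensor, fn):
    # Scoped device switch: the previous current device is restored on exit,
    # which is the RAII behavior at::cuda::CUDAGuard provides in C++.
    with torch.cuda.device(t.device):
        return fn(t)
```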
-  at::cuda::CUDAGuard device_guard(cu_seqlens_q.device());
+  TECUDAGuard device_guard(cu_seqlens_q.device());

   auto none = py::none();
@@ -235,7 +245,7 @@ std::vector<py::object> fused_attn_fwd(
         te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
         te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
         return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
-        softmax_type, window_size[0], window_size[1], workspace.data(),
+        softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(),
         at::cuda::getCurrentCUDAStream());
   });
@@ -295,7 +305,7 @@ std::vector<py::object> fused_attn_fwd(
         te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
         te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
         return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
-        softmax_type, window_size[0], window_size[1], workspace.data(),
+        softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(),
         at::cuda::getCurrentCUDAStream());
   });
@@ -310,10 +320,10 @@ std::vector<py::object> fused_attn_bwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size, bool deterministic,
-    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
-    const py::handle K, const py::handle V, const py::handle O, const py::handle dO,
-    const at::ScalarType fake_dtype, const DType dqkv_type,
+    NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size,
+    bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q,
+    const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
+    const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type,
     const std::vector<at::Tensor> Aux_CTX_Tensors,
     const std::optional<at::Tensor> cu_seqlens_q_padded,
     const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
@@ -532,14 +542,14 @@ std::vector<py::object> fused_attn_bwd(
   // populate tensors with appropriate shapes and dtypes
   NVTE_SCOPED_GIL_RELEASE({
-    nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
-                        te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
-                        te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(),
-                        te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
-                        te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
-                        max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
-                        softmax_type, window_size[0], window_size[1], deterministic, cuda_graph,
-                        workspace.data(), at::cuda::getCurrentCUDAStream());
+    nvte_fused_attn_bwd(
+        te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
+        &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(),
+        te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
+        te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv,
+        attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
+        window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(),
+        at::cuda::getCurrentCUDAStream());
   });

   // allocate memory for workspace
@@ -549,14 +559,14 @@ std::vector<py::object> fused_attn_bwd(
   // execute kernel
   NVTE_SCOPED_GIL_RELEASE({
-    nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
-                        te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
-                        te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(),
-                        te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
-                        te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
-                        max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
-                        softmax_type, window_size[0], window_size[1], deterministic, cuda_graph,
-                        workspace.data(), at::cuda::getCurrentCUDAStream());
+    nvte_fused_attn_bwd(
+        te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
+        &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(),
+        te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
+        te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv,
+        attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
+        window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(),
+        at::cuda::getCurrentCUDAStream());
   });

   // destroy tensor wrappers
diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
index 941b88e36..6898ce387 100644
--- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
@@ -20,6 +20,13 @@
 #include "transformer_engine/transformer_engine.h"
 #include "util.h"

+#include <torch/version.h>
+#if USE_ROCM && TORCH_VERSION_MINOR < 11
+using TECUDAGuard = at::hip::HIPGuardMasqueradingAsCUDA;
+#else
+using TECUDAGuard = at::cuda::CUDAGuard;
+#endif
+
 namespace {

 void* get_data_ptr(transformer_engine::pytorch::MaybeTensor tensor) {
@@ -100,7 +107,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   // Ensure that cublasLt handle is created on the correct device,
   // overriding torch.cuda.set_device calls from user side.
   // Assumes all tensors passed are on the same device.
-  at::cuda::CUDAGuard device_guard(workspace.device());
+  TECUDAGuard device_guard(workspace.device());

   // Input tensors
   NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
@@ -388,7 +395,7 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
   // Ensure that cublasLt handle is created on the correct device,
   // overriding torch.cuda.set_device calls from user side.
   // Assumes all tensors passed are on the same device.
-  at::cuda::CUDAGuard device_guard(workspace.device());
+  TECUDAGuard device_guard(workspace.device());

   // TODO: Handle scaling modes
   NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING;
@@ -442,7 +449,7 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
   // Ensure that cublasLt handle is created on the correct device,
   // overriding torch.cuda.set_device calls from user side.
   // Assumes all tensors passed are on the same device.
-  at::cuda::CUDAGuard device_guard(workspace[0].device());
+  TECUDAGuard device_guard(workspace[0].device());

   void* output_data_ptr = nullptr;
   if (single_output) {
diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp
index b78982d4d..8f8eed2c3 100644
--- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp
@@ -10,6 +10,13 @@
 #include "common/util/system.h"
 #include "pybind.h"

+#include <torch/version.h>
+#if USE_ROCM && TORCH_VERSION_MINOR < 11
+using TECUDAGuard = at::hip::HIPGuardMasqueradingAsCUDA;
+#else
+using TECUDAGuard = at::cuda::CUDAGuard;
+#endif
+
 namespace transformer_engine::pytorch {

 std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
@@ -69,7 +76,7 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
   // Ensure that cuDNN handle is created on the correct device,
   // overriding torch.cuda.set_device calls from user side.
   // Assumes all tensors passed are on the same device.
-  at::cuda::CUDAGuard device_guard(input.cast<at::Tensor>().device());
+  TECUDAGuard device_guard(input.cast<at::Tensor>().device());

   // Input and param tensors
   auto none = py::none();
@@ -319,7 +326,7 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   // Ensure that cuDNN handle is created on the correct device,
   // overriding torch.cuda.set_device calls from user side.
   // Assumes all tensors passed are on the same device.
-  at::cuda::CUDAGuard device_guard(input.cast<at::Tensor>().device());
+  TECUDAGuard device_guard(input.cast<at::Tensor>().device());

   // Input and param tensors
   auto none = py::none();
diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index c881bbe08..341e407e9 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -731,8 +731,8 @@ def checkpoint(
     if isinstance(function, TransformerEngineBaseModule):
         # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
         # to scatter/gather activations that we will recompute anyway.
-        setattr(function, "fsdp_wrapped", False)
-        setattr(function, "fsdp_group", None)
+        function.fast_setattr("fsdp_wrapped", False)
+        function.fast_setattr("fsdp_group", None)

     # Otherwise discard unused te.utils.checkpoint.checkpoint() arguments
     # and execute TE's own checkpointing
@@ -2026,7 +2026,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
     )
     root_state = _get_module_fsdp_state(fsdp_root)
     assert root_state is not None, "Root module does not have a valid _FSDPState."
-    setattr(fsdp_root.module, "fsdp_group", root_state.process_group)
+    fsdp_root.module.fast_setattr("fsdp_group", root_state.process_group)

     # Iterate through all FSDP-wrapped submodules and inject FSDP information into TE modules
     fsdp_states, fsdp_modules = _get_fsdp_states_with_modules(fsdp_root)
@@ -2037,7 +2037,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
                 "TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
                 "Please initialize your model without the te.quantized_model_init(...) context."
) - setattr(fsdp_module.module, "fsdp_group", state.process_group) + fsdp_module.module.fast_setattr("fsdp_group", state.process_group) class FullyShardedDataParallel(FSDP): diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index e8cef56bd..4d52d9b92 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -49,17 +49,35 @@ def wrapper(*args, **kwargs): # Decorator to disable Torch Dynamo # See: https://github.com/NVIDIA/TransformerEngine/issues/308 -no_torch_dynamo = lambda recursive=True: lambda func: func if torch.__version__ >= "2": import torch._dynamo - if torch.__version__ >= "2.1": - no_torch_dynamo = lambda recursive=True: lambda f: ( - f if is_in_onnx_export_mode() else torch._dynamo.disable(f, recursive=recursive) - ) - else: - # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True - no_torch_dynamo = lambda recursive=True: torch._dynamo.disable + def no_torch_dynamo(recursive=True): + """Decorator to disable Torch Dynamo, except during ONNX export.""" + + def decorator(f): + # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True + disabled_f = ( + torch._dynamo.disable(f, recursive=recursive) + if torch.__version__ >= "2.1" + else torch._dynamo.disable(f) + ) + + @wraps(f) + def wrapper(*args, **kwargs): + if is_in_onnx_export_mode(): + return f(*args, **kwargs) + return disabled_f(*args, **kwargs) + + return wrapper + + return decorator + +else: + # Fallback for PyTorch < 2.0: no-op decorator + def no_torch_dynamo(recursive=True): # pylint: disable=unused-argument + """No-op decorator for PyTorch < 2.0.""" + return lambda func: func def set_jit_fusion_options() -> None: diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 2d8563729..d4feb7354 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -12,9 +12,8 @@ import warnings from enum import Enum from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Generator, List, Optional, Tuple, Union from contextlib import contextmanager -import logging from types import MethodType from itertools import chain @@ -52,7 +51,15 @@ from ..triton_kernels.cast import te_quantize_triton from ..tensor.storage.float8_tensor_storage import Float8TensorStorage from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage -from ..utils import get_device_compute_capability, is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype +from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage +from ..utils import ( + get_device_compute_capability, + is_non_tn_fp8_gemm_supported, + torch_get_autocast_gpu_dtype, + get_nvtx_range_context, + nvtx_range_push, + nvtx_range_pop, +) from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ...common.recipe import DelayedScaling, Recipe from ...debug.pytorch.debug_state import TEDebugState @@ -674,10 +681,10 @@ def fill_userbuffers_buffer_for_all_gather( class TransformerEngineBaseModule(torch.nn.Module, ABC): """Base TE module.""" - def __init__(self) -> None: + def __init__(self, name: Optional[str] = None) -> None: super().__init__() assert torch.cuda.is_available(), "TransformerEngine needs CUDA." 
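With the `no_torch_dynamo` rewrite above, the decision to bypass Dynamo is made per call rather than at decoration time, so a function decorated at import is still traceable if ONNX export is entered later. A hedged usage sketch:

```python
import torch
from transformer_engine.pytorch.jit import no_torch_dynamo

@no_torch_dynamo(recursive=False)
def add_bias(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Runs under torch._dynamo.disable(...) normally, but executes the plain
    # function during ONNX export so tracing sees ordinary PyTorch ops.
    return x + bias
```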
- self.name = None + self.name = name self.next_iter_when_debug_should_be_run = 0 self.fp8_initialized = False self.fp8 = False @@ -704,26 +711,22 @@ def __init__(self) -> None: if not TEDebugState.debug_enabled: TEDebugState.initialize() + self._validate_name() - # Names of attributes that can be set quickly (see __setattr__ - # method) - _fast_setattr_names: Set[str] = { - "activation_dtype", - "fp8", - "fp8_initialized", - "fp8_calibration", - "fp8_parameters", - } - - def __setattr__(self, name: str, value: Any) -> None: - if name in TransformerEngineBaseModule._fast_setattr_names: - # torch.nn.Module has a custom __setattr__ that handles - # modules, parameters, and buffers. This is unnecessary - # overhead when setting plain attrs. - self.__dict__[name] = value - else: - # Default case - super().__setattr__(name, value) + def fast_setattr(self, name: str, value: Any) -> None: + """ + Fast version of the Module's set attribute function. + Should be used for regular attributes, but not properties nor parameters/buffers. + """ + self.__dict__[name] = value + + def module_setattr(self, name: str, value: Any) -> None: + """ + Regular version of the Module's set attribute function. + Should be used only when the fast version cannot be used - for the properties, + parameters and buffers. + """ + super().__setattr__(name, value) def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: """ @@ -844,7 +847,7 @@ def init_fp8_meta_tensors(self, recipe: Recipe) -> None: self.set_meta_tensor(True, recipe) self.set_meta_tensor(False, recipe) - self.fp8_meta_tensors_initialized = True + self.fast_setattr("fp8_meta_tensors_initialized", True) def get_fp8_meta_tensors(self) -> None: """Get scales and amaxes.""" @@ -1001,7 +1004,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: """Get activation data type for AMP.""" # Native AMP (`torch.autocast`) gets highest priority if torch.is_autocast_enabled(): - self.activation_dtype = torch_get_autocast_gpu_dtype() + self.fast_setattr("activation_dtype", torch_get_autocast_gpu_dtype()) return # All checks after this have already been performed once, thus skip @@ -1016,7 +1019,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: "Data types for parameters must match when outside of autocasted region. " f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" ) - self.activation_dtype = dtype + self.fast_setattr("activation_dtype", dtype) def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: """ @@ -1028,8 +1031,8 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N tp_group : ProcessGroup, default = `None` tensor parallel process group. """ - self.tp_group = tp_group - self.tp_group_initialized = True + self.fast_setattr("tp_group", tp_group) + self.fast_setattr("tp_group_initialized", True) def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: """returns the FP8 weights.""" @@ -1045,53 +1048,56 @@ def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: # assume FP8 execution. 
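`fast_setattr` exists because `torch.nn.Module.__setattr__` consults the parameter, buffer, and submodule dicts on every assignment, which is wasted work for plain attributes set in hot paths. A quick way to see the difference; the toy class and loop counts are illustrative only:

```python
import time
import torch

class Toy(torch.nn.Module):
    def fast_setattr(self, name, value):
        self.__dict__[name] = value  # plain dict store, no nn.Module bookkeeping

m = Toy()
t0 = time.perf_counter()
for _ in range(100_000):
    m.flag = True                 # routed through nn.Module.__setattr__
t1 = time.perf_counter()
for _ in range(100_000):
    m.fast_setattr("flag", True)  # bypasses it
t2 = time.perf_counter()
print(f"__setattr__: {t1 - t0:.3f}s  fast_setattr: {t2 - t1:.3f}s")
```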
def init_fp8_metadata(self, num_gemms: int = 1) -> None: """Initialize fp8 related metadata and tensors during fprop.""" - _original_recipe = self.fp8_meta.get("recipe", None) + meta = self.fp8_meta - self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() - self.fp8 = FP8GlobalStateManager.is_fp8_enabled() - self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() - fp8_enabled = self.fp8 or self.fp8_calibration - self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration + fp8 = FP8GlobalStateManager.is_fp8_enabled() + fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() + fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() + self.fast_setattr("fp8_parameters", fp8_parameters) + self.fast_setattr("fp8", fp8) + self.fast_setattr("fp8_calibration", fp8_calibration) + fp8_enabled = fp8 or fp8_calibration + meta["fp8_checkpoint"] = fp8_enabled - if IS_HIP_EXTENSION and not FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2 and hasattr(self, 'use_fsdp2') and self.use_fsdp2: - FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2 = True + if IS_HIP_EXTENSION and not FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2 and hasattr(self, 'use_fsdp2') and self.use_fsdp2: + FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2 = True - if self.fp8_parameters or fp8_enabled: - if ( - self.fp8_initialized - and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"] - ): + _original_recipe = None + + if fp8_parameters or fp8_enabled: + _original_recipe = meta.get("recipe", None) + if self.fp8_initialized and FP8GlobalStateManager.get_fp8_recipe() == _original_recipe: # FP8 init has already been run and recipe is the same, don't do anything. return - self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() else: # If fp8 isn't enabled, turn off and return. 
- self.fp8_initialized = False + self.fast_setattr("fp8_initialized", False) return - if self.fp8_parameters and not self.fp8_initialized: - self.fp8_meta["num_gemms"] = num_gemms - self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) + if fp8_parameters and not self.fp8_initialized: + meta["num_gemms"] = num_gemms + self.init_fp8_meta_tensors(meta["recipe"]) if fp8_enabled: # Set FP8 and other FP8 metadata - self.fp8_meta["num_gemms"] = num_gemms - self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() + meta["num_gemms"] = num_gemms + meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() # Set FP8_MAX per tensor according to recipe - if hasattr(self.fp8_meta["recipe"], "fp8_format"): - self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd - self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd + if hasattr(meta["recipe"], "fp8_format"): + meta["fp8_max_fwd"] = meta["recipe"].fp8_format.value.max_fwd + meta["fp8_max_bwd"] = meta["recipe"].fp8_format.value.max_bwd # Allocate scales and amaxes - self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) - self.fp8_initialized = True + self.init_fp8_meta_tensors(meta["recipe"]) + self.fast_setattr("fp8_initialized", True) - self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() - if self.fp8_meta["recipe"].mxfp8(): - self.keep_fp8_weight_transpose_cache = True + meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + if meta["recipe"].mxfp8(): + self.keep_fp8_weight_transpose_cache = True - _current_recipe = self.fp8_meta["recipe"] + _current_recipe = meta["recipe"] if _original_recipe is not None and not ( issubclass(_current_recipe.__class__, _original_recipe.__class__) or issubclass(_original_recipe.__class__, _current_recipe.__class__) @@ -1104,22 +1110,19 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # Clear cached workspaces as they were created with the old recipe/quantizer type self._fp8_workspaces.clear() - @contextmanager def prepare_forward( self, inp: torch.Tensor, num_gemms: int = 1, allow_non_contiguous: bool = False, allow_different_data_and_param_types: bool = False, - ) -> Generator[torch.Tensor, None, None]: - """Checks and prep for FWD. - The context manager is needed because there isn't a way for a module to know - if it's the last FP8 module in the forward autocast. It is useful - to setup the forward aggregated amax reduction for every module - just in case. The autocast exit will pick up the most recent one. - """ - self.allow_different_data_and_param_types = allow_different_data_and_param_types - self.forwarded_at_least_once = True + ) -> torch.Tensor: + """Checks and prepares for FWD execution.""" + self.fast_setattr( + "allow_different_data_and_param_types", allow_different_data_and_param_types + ) + self.fast_setattr("forwarded_at_least_once", True) + # Activation recomputation is used and this is the second forward phase. 
        if self.fp8 and in_fp8_activation_recompute_phase():
            FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
@@ -1146,13 +1149,37 @@ def prepare_forward(
         if self.fp8 and self.training and is_fp8_activation_recompute_enabled():
             FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

-        with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
-            if not allow_non_contiguous and not inp.is_contiguous():
-                inp = inp.contiguous()
-            yield inp
+        nvtx_range_push(self.__class__.__name__ + " forward")
+        if not allow_non_contiguous and not inp.is_contiguous():
+            inp = inp.contiguous()
+        return inp

-        if self.fp8 and in_fp8_activation_recompute_phase():
+    def end_forward(self):
+        """
+        Must be called at the end of the forward function to properly handle
+        DelayedScaling metadata and close the NVTX range.
+        """
+        delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
+        if delayed_scaling_recipe and in_fp8_activation_recompute_phase():
             FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
+        nvtx_range_pop()
+
+    @contextmanager
+    def prepare_forward_ctx(
+        self,
+        inp: torch.Tensor,
+        num_gemms: int = 1,
+        allow_non_contiguous: bool = False,
+        allow_different_data_and_param_types: bool = False,
+    ) -> Generator[torch.Tensor, None, None]:
+        """Checks and prepares for FWD execution."""
+        inp = self.prepare_forward(
+            inp, num_gemms, allow_non_contiguous, allow_different_data_and_param_types
+        )
+        try:
+            yield inp
+        finally:
+            self.end_forward()

     def set_nccl_overlap_warning_if_tp(self) -> None:
         """When using TP, the NCCL communication needs to be scheduled
@@ -1409,9 +1436,9 @@ def clear(self):
             # Update the parameter based on its type
             if not is_dtensor:
-                setattr(self, name, param)
+                self.module_setattr(name, param)
             else:
-                setattr(self, name, dtensor_param)
+                self.module_setattr(name, dtensor_param)

     @abstractmethod
     def forward(self):
@@ -1611,7 +1638,6 @@ def is_debug_iter(self) -> bool:
         debug = TEDebugState.debug_enabled
         if not debug:
             return False
-        self._validate_name()

         # If layer is run first time in new iteration,
         # we need to check if the debug should be enabled for this layer
-
@@ -1625,13 +1651,19 @@ def is_debug_iter(self) -> bool:
                 debug = False
             else:
                 debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run
-            self.debug_last_iteration = TEDebugState.get_iteration()
-            self.debug_enabled_in_this_iteration = debug
+            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
+            self.fast_setattr("debug_enabled_in_this_iteration", debug)
         else:
             # If this is the same iteration as previous invocation of the module,
             # we use the debug value from the first invocation in the iteration.
             debug = self.debug_enabled_in_this_iteration
+            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
+
+        if self.wgrad_store is not None:
+            if debug and self.wgrad_store.delay_wgrad_compute():
+                raise RuntimeError("Delayed wgrad compute is not supported in debug mode.")
+
         return debug

     def no_debug_features_active(self, quantizers):
@@ -1642,7 +1674,9 @@ def no_debug_features_active(self, quantizers):
         # Sometimes features inform that they will not be enabled for particular layer
         # for multiple next iterations.
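Editorial note: with `prepare_forward` no longer a context manager, module forwards pair it with `end_forward` in a try/finally, while `prepare_forward_ctx` preserves the old `with`-style interface. A sketch of both calling conventions in a hypothetical subclass (`_compute` is a placeholder, not a TE API):

```python
def forward(self, inp: torch.Tensor) -> torch.Tensor:
    # New explicit pairing, replacing "with self.prepare_forward(inp) as inp:".
    inp = self.prepare_forward(inp, num_gemms=1)
    try:
        out = self._compute(inp)  # placeholder for the module's GEMM/fusion logic
    finally:
        # Mirrors the old context-manager cleanup: restores DelayedScaling
        # metadata during activation recompute and pops the NVTX range.
        self.end_forward()
    return out

def forward_with_ctx(self, inp: torch.Tensor) -> torch.Tensor:
    # Equivalent behavior through the backwards-compatible wrapper.
    with self.prepare_forward_ctx(inp, num_gemms=1) as inp:
        return self._compute(inp)
```

The explicit pair avoids generator overhead on every call; the wrapper exists for callers that relied on the context-manager form.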
-        self.next_iter_when_debug_should_be_run = next_iter_when_debug_should_be_run(quantizers)
+        self.fast_setattr(
+            "next_iter_when_debug_should_be_run", next_iter_when_debug_should_be_run(quantizers)
+        )

         if not run_current:
             return True
@@ -1654,22 +1688,13 @@ def no_debug_features_active(self, quantizers):
     def _validate_name(self):
         """
         Validate name passed to the module.
-        This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM.
-        If no name is assigned, it creates a default name with layer count as the variable.
+        It assigns a default name based on the layer count,
+        which may be overridden by the user of the module.
         """
         if self.name is not None:
             return
-        assert TEDebugState.debug_enabled
-        import nvdlfw_inspect.api as debug_api
-
-        if self.name is None:
-            debug_api.log_message(
-                "Names are not provided to debug modules. ",
-                "Creating and using generic names. Pass names to debug modules for better"
-                " insight. ",
-                level=logging.WARNING,
-            )
-            self.name = f"Layer_{TEDebugState.get_layer_count()}"
+
+        self.name = f"Layer_{TEDebugState.get_layer_count()}"

     def _check_weight_tensor_recipe_correspondence(self) -> None:
         """
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index 8c6aa8bde..a268f6352 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -668,7 +668,7 @@ def __init__(
         save_original_input: bool = False,
         name: Optional[str] = None,
     ) -> None:
-        super().__init__()
+        super().__init__(name)

         params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.num_gemms = num_gemms
@@ -687,7 +687,6 @@ def __init__(
         ), "GroupedLinear doesn't support Userbuffer overlap."
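Editorial note: with `name` now forwarded to the base-class constructor and `_validate_name` simplified, an unnamed module falls back to a counter-based default. A usage sketch; the exact counter value is illustrative, and `_validate_name` is normally invoked internally (e.g. by `TransformerLayer.__init__` after this change):

```python
import transformer_engine.pytorch as te

named = te.Linear(1024, 1024, name="decoder.fc1")
named._validate_name()
print(named.name)   # "decoder.fc1" -- user-provided names are kept as-is

anon = te.Linear(1024, 1024)
anon._validate_name()
print(anon.name)    # e.g. "Layer_5", derived from the debug-state layer counter
```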
self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name - self.name = name self.wgrad_store = WeightGradStore(delay_wgrad_compute) @@ -844,7 +843,8 @@ def forward( is_grad_enabled = torch.is_grad_enabled() - with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: + inp = self.prepare_forward(inp, num_gemms=self.num_gemms) + try: weight_tensors = self._get_weight_tensors() bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] @@ -900,6 +900,9 @@ def forward( ) out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) + finally: + self.end_forward() + if self.return_bias: return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors] return out diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 7347fc138..30dffccd7 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1207,11 +1207,11 @@ def __init__( ub_name: Optional[str] = None, delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, - name: str = None, + name: Optional[str] = None, keep_fp8_weight_transpose_cache: bool = True, use_fsdp2: bool = False ) -> None: - super().__init__() + super().__init__(name) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.in_features = in_features @@ -1231,7 +1231,7 @@ def __init__( self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) self.name = name - self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache if IS_HIP_EXTENSION else True + self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache if IS_HIP_EXTENSION else True self.use_fsdp2 = use_fsdp2 if IS_HIP_EXTENSION else False if tp_group is None: @@ -1561,10 +1561,11 @@ def forward( ).is_fp8_ubuf(): fp8_grad = True - with self.prepare_forward( + inp = self.prepare_forward( inp, allow_non_contiguous=False # removed .contiguous from inside the layer - ) as inp: + ) + try: # Get concatenated weight and bias tensors weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() @@ -1645,6 +1646,9 @@ def forward( non_tensor_args, ) + finally: + self.end_forward() + if self.return_layernorm_output: out, ln_out = out diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 49667b633..aa95b9521 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1853,7 +1853,7 @@ def __init__( zero_centered_gamma: bool = False, device: Union[torch.device, str] = "cuda", ub_overlap_ag: bool = False, - name: str = None, + name: Optional[str] = None, ub_overlap_rs: bool = False, ub_overlap_rs_dgrad: bool = False, ub_bulk_dgrad: bool = False, @@ -1864,7 +1864,7 @@ def __init__( keep_fp8_weight_transpose_cache: bool = True, use_fsdp2: bool = False, ) -> None: - super().__init__() + super().__init__(name) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.fuse_wgrad_accumulation = fuse_wgrad_accumulation @@ -1897,7 +1897,6 @@ def __init__( for use_fp8 in [False, True] ) ) - self.name = name self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) @@ -2117,8 +2116,9 @@ def forward( if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf(): fp8_output = True - with self.prepare_forward(inp, num_gemms=2) as inp: + 
inp = self.prepare_forward(inp, num_gemms=2) + try: quantizers = ( self._get_quantizers(fp8_output, is_grad_enabled) if not debug @@ -2156,9 +2156,9 @@ def forward( fc2_weight = fc2_weight.dequantize() # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode - if ( not IS_HIP_EXTENSION - and self.bias_gelu_nvfusion and not use_reentrant_activation_recompute() ): - self.bias_gelu_nvfusion = False + if (not IS_HIP_EXTENSION + and self.bias_gelu_nvfusion and not use_reentrant_activation_recompute()): + self.fast_setattr("bias_gelu_nvfusion", False) if is_grad_enabled: fwd_fn = _LayerNormMLP.apply @@ -2230,6 +2230,9 @@ def forward( non_tensor_args, ) + finally: + self.end_forward() + if self.return_layernorm_output: out, ln_out = out diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 01d07d91a..5d0ab12aa 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -441,8 +441,8 @@ def forward( # weights if weights are externally touched outside this module ctx.weight_object = weight - if cpu_offloading: mark_not_offload(weight, weightmat, bias) + # TODO(ksivamani): Check memory usage tensors_to_save, tensor_objects = prepare_for_saving( saved_inputmat, @@ -1136,7 +1136,7 @@ def __init__( keep_fp8_weight_transpose_cache: bool = True, use_fsdp2: bool = False ) -> None: - super().__init__() + super().__init__(name) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.in_features = in_features @@ -1149,7 +1149,6 @@ def __init__( self.rng_tracker_name = rng_tracker_name self.symmetric_ar_type = symmetric_ar_type self.save_original_input = save_original_input - self.name = name self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache if IS_HIP_EXTENSION else True @@ -1435,11 +1434,8 @@ def forward( ).is_fp8_ubuf(): fp8_grad = True - with self.prepare_forward( - inp, - allow_non_contiguous=isinstance(inp, QuantizedTensor), - ) as inp: - + inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor)) + try: weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() quantizers = ( @@ -1512,6 +1508,8 @@ def forward( bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, non_tensor_args, ) + finally: + self.end_forward() if self.gemm_bias_unfused_add: out = out + cast_if_needed(bias_tensor, self.activation_dtype) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 4f131c3c0..6f5f6a507 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -14,7 +14,6 @@ from transformer_engine.pytorch.torch_version import torch_version from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm -from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.jit import ( @@ -40,6 +39,8 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION +import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils + warnings.filterwarnings("module", category=DeprecationWarning, module="transformer") @@ -153,11 +154,21 @@ class TransformerLayer(torch.nn.Module): distinguishes them based on 
:attr:`self_attn_mask_type` or :attr:`enc_dec_attn_mask_type`. Similar to :attr:`self_attn_mask_type`, :attr:`window_size` can be overridden by :attr:`window_size` in :meth:`forward` as well. + bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the encoder. + If `None`, it will be set to `False` for `self_attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. enc_dec_attn_mask_type : {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'}, default = "no_mask" type of attention mask passed into softmax operation for decoder. enc_dec_window_size : Optional[Tuple[int, int]], default = None sliding window size for local attention in decoder. + enc_dec_bottom_right_diagonal: Optional[bool], default = `None` + Align sliding window and ALiBi diagonal to the top left (`False`) + or bottom right (`True`) corner of the softmax matrix in the decoder. + If `None`, it will be set to `False` for `enc_dec_attn_mask_type` = + {`causal`, `padding_causal`} and `True` for other mask types. zero_centered_gamma : bool, default = False if set to ``True``, gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to @@ -306,7 +317,9 @@ def __init__( kv_channels: Optional[int] = None, self_attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, enc_dec_attn_mask_type: str = "no_mask", + enc_dec_bottom_right_diagonal: Optional[bool] = None, enc_dec_window_size: Optional[Tuple[int, int]] = None, tp_group: Optional[dist_group_type] = None, tp_size: int = 1, @@ -348,8 +361,10 @@ def __init__( self.self_attn_mask_type = self_attn_mask_type self.window_size = window_size + self.bottom_right_diagonal = bottom_right_diagonal self.enc_dec_attn_mask_type = enc_dec_attn_mask_type self.enc_dec_window_size = enc_dec_window_size + self.enc_dec_bottom_right_diagonal = enc_dec_bottom_right_diagonal params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad @@ -402,6 +417,7 @@ def __init__( self.softmax_type = softmax_type self.name = name + TransformerEngineBaseModule._validate_name(self) attention_args = ( hidden_size, @@ -450,7 +466,7 @@ def __init__( qk_norm_type=qk_norm_type, qk_norm_eps=qk_norm_eps, qk_norm_before_rope=qk_norm_before_rope, - name=name + ".self_attention" if name is not None else None, + name=self.name + ".self_attention" if self.name is not None else None, ) if layer_type == "decoder": @@ -467,7 +483,7 @@ def __init__( qk_norm_type=qk_norm_type, qk_norm_eps=qk_norm_eps, qk_norm_before_rope=qk_norm_before_rope, - name=name + ".inter_attention" if name is not None else None, + name=self.name + ".inter_attention" if self.name is not None else None, ) # LayerNorm -> activation(Linear + Bias) -> Linear @@ -503,7 +519,7 @@ def __init__( activation_params=activation_params, normalization=normalization, device=device, - name=name + ".layernorm_mlp" if name is not None else None, + name=self.name + ".layernorm_mlp" if self.name is not None else None, ) self.hidden_dropout = hidden_dropout @@ -610,10 +626,12 @@ def forward( attention_mask: Optional[torch.Tensor] = None, self_attn_mask_type: Optional[str] = None, window_size: Optional[Tuple[int, int]] = None, + bottom_right_diagonal: Optional[bool] = None, encoder_output: 
Optional[torch.Tensor] = None,
         enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
         enc_dec_attn_mask_type: Optional[str] = None,
         enc_dec_window_size: Optional[Tuple[int, int]] = None,
+        enc_dec_bottom_right_diagonal: Optional[bool] = None,
         is_first_microbatch: Optional[bool] = None,
         checkpoint_core_attention: bool = False,
         inference_params: Optional[InferenceParams] = None,
@@ -658,6 +676,11 @@ def forward(
             causal masks are aligned to the bottom right corner.
         window_size: Optional[Tuple[int, int]], default = None
                     Sliding window size for local attention in encoder.
+        bottom_right_diagonal: Optional[bool], default = `None`
+                    Align sliding window and ALiBi diagonal to the top left (`False`)
+                    or bottom right (`True`) corner of the softmax matrix in the encoder.
+                    If `None`, it will be set to `False` for `self_attn_mask_type` =
+                    {`causal`, `padding_causal`} and `True` for other mask types.
         encoder_output : Optional[torch.Tensor], default = None
                         Output of the encoder block to be fed into the decoder block if using
                         :attr:`layer_type` = ``"decoder"``.
@@ -674,6 +697,11 @@ def forward(
             Type of attention mask passed into softmax operation for decoder.
         enc_dec_window_size: Optional[Tuple[int, int]], default = None
                             Sliding window size for local attention in decoder.
+        enc_dec_bottom_right_diagonal: Optional[bool], default = `None`
+                    Align sliding window and ALiBi diagonal to the top left (`False`)
+                    or bottom right (`True`) corner of the softmax matrix in the decoder.
+                    If `None`, it will be set to `False` for `enc_dec_attn_mask_type` =
+                    {`causal`, `padding_causal`} and `True` for other mask types.
         is_first_microbatch : {True, False, None}, default = None
                              During training using either gradient accumulation or
                              pipeline parallelism a minibatch of data is further split
@@ -740,10 +768,35 @@ def forward(
             self_attn_mask_type = self.self_attn_mask_type
         if window_size is None:
             window_size = self.window_size
+        window_size = dpa_utils.check_set_window_size(self_attn_mask_type, window_size)
+
         if enc_dec_attn_mask_type is None:
             enc_dec_attn_mask_type = self.enc_dec_attn_mask_type
         if enc_dec_window_size is None:
             enc_dec_window_size = self.enc_dec_window_size
+        enc_dec_window_size = dpa_utils.check_set_window_size(
+            enc_dec_attn_mask_type, enc_dec_window_size
+        )
+
+        if bottom_right_diagonal is None:
+            bottom_right_diagonal = self.bottom_right_diagonal
+        if self_attn_mask_type in {"causal", "padding_causal"}:
+            bottom_right_diagonal = False
+        if bottom_right_diagonal is None or self_attn_mask_type in {
+            "causal_bottom_right",
+            "padding_causal_bottom_right",
+        }:
+            bottom_right_diagonal = True
+
+        if enc_dec_bottom_right_diagonal is None:
+            enc_dec_bottom_right_diagonal = self.enc_dec_bottom_right_diagonal
+        if enc_dec_attn_mask_type in {"causal", "padding_causal"}:
+            enc_dec_bottom_right_diagonal = False
+        if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in {
+            "causal_bottom_right",
+            "padding_causal_bottom_right",
+        }:
+            enc_dec_bottom_right_diagonal = True

         assert (
             self_attn_mask_type in AttnMaskTypes
@@ -772,9 +825,6 @@ def forward(
                 enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask))
             ), "Encoder-decoder attention mask must be boolean tensor(s)"
-        if TEDebugState.debug_enabled:
-            TransformerEngineBaseModule._validate_name(self)
-
         # For AMP
         if torch.is_autocast_enabled():
             hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype())
@@ -785,6 +835,7 @@ def forward(
                 attention_mask=attention_mask,
attn_mask_type=self_attn_mask_type, window_size=window_size, + bottom_right_diagonal=bottom_right_diagonal, inference_params=inference_params, is_first_microbatch=is_first_microbatch, checkpoint_core_attention=checkpoint_core_attention, @@ -820,6 +871,7 @@ def forward( attention_mask=enc_dec_attn_mask, attn_mask_type=enc_dec_attn_mask_type, window_size=enc_dec_window_size, + bottom_right_diagonal=enc_dec_bottom_right_diagonal, encoder_output=encoder_output, inference_params=inference_params, is_first_microbatch=is_first_microbatch,
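Editorial note: a short usage sketch for the new diagonal-alignment arguments introduced in this diff. Whether `check_set_window_size` keeps this exact window for `no_mask` is an assumption here, so treat the values as illustrative:

```python
import torch
from transformer_engine.pytorch import TransformerLayer

layer = TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    self_attn_mask_type="no_mask",
    window_size=(128, 128),       # local (sliding-window) attention
    bottom_right_diagonal=False,  # pin the window diagonal to the top-left corner
)

x = torch.randn(2048, 2, 1024, device="cuda")  # (seq, batch, hidden) in "sbhd"
y = layer(x)
# Per-call override, analogous to self_attn_mask_type and window_size:
y = layer(x, bottom_right_diagonal=True)
```

For `causal`/`padding_causal` masks the flag is forced to `False`, and for the `*_bottom_right` mask types it is forced to `True`, matching the defaulting rules documented above.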