From cfd75e8610a645d5a136659adcd78c29c0ad4018 Mon Sep 17 00:00:00 2001
From: Veera Rajasekhar Reddy Gopu
Date: Fri, 17 Apr 2026 05:36:01 +0000
Subject: [PATCH 1/7] Fix cpplint violations in common and PyTorch extension code

transformer_engine/common/amd_detail/hip_float8.h
 -Host constructor: multi-statement if/else now uses braces (readability/braces).

transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh
 -Include <cstdint>; typedef for gfx950 vector type uses int16_t instead of short (runtime/int).

transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp
 -dladdr: avoid ill-formed function-pointer-to-void* cast via a small union (readability/casting / portable POSIX).
 -get_ck_log_stream: else branch restructured with nested if so else/brace pairing satisfies cpplint (readability/braces).

transformer_engine/common/fused_attn_rocm/fused_attn.cpp
 -check_set_window_size: replace std::make_pair(...) with std::pair(...) (build/explicit_make_pair).
 -Replace alternative tokens `or` with || (readability/alt_tokens).
 -log_fused_attn_config: same for sliding-window condition.

transformer_engine/common/gemm/rocm_gemm.cu
 -ObjCache / NameMapper: mark single-argument constructors explicit (runtime/explicit).
 -HIPBLASLT scaling_mode check: split #if/#else branches so each if has its own braced body; use static_cast<int> instead of C-style cast (readability/braces, readability/casting).
 -Debug logging: (int) casts -> static_cast<int> for hipDataType fields (readability/casting).
 -ServiceStreamKey: use std::uint64_t alias instead of unsigned long long (runtime/int).

transformer_engine/common/normalization/common.cpp
 -getNormalizationPlan: after optional CUDNN plan, use if (!plan) { ... } for TE plans instead of } else #endif if (readability/braces across preprocessor).

transformer_engine/common/normalization/layernorm/ln_api.cpp
 -Forward/backward: default norm_backend to Te; optional CUDNN path only under #ifndef __HIP_PLATFORM_AMD__; set is_aligned only when backend is Te, so preprocessor does not split if/else from its braces (readability/braces).

transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp
 -Same pattern as ln_api for forward (including HIP constexpr gamma_in_weight_dtype) and backward cudnn vs Te (readability/braces).

transformer_engine/common/permutation/permutation.cu
 -MoE unpermute kernel: functional-style float(...) casts replaced with static_cast<float>(...) (readability/casting).

transformer_engine/common/util/logging.h
 -NVTE_CHECK_HIPBLASLT macro: std::to_string((int)status) -> std::to_string(static_cast<int>(status)) (readability/casting).

transformer_engine/pytorch/csrc/extensions/gemm.cpp
 -Comm overlap RS path: HIP p2p vs split_overlap_rs restructured with proper #else for non-HIP so } else #endif { does not confuse brace rules (readability/braces).
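
For reference, the dladdr change relies on the usual POSIX idiom of reading a function pointer back through a union member of object-pointer type, since ISO C++ provides no static_cast (or implicit) conversion from a function pointer to void*. A minimal self-contained sketch of that idiom follows; the names here are placeholders, not the exact code in the diff below:

    #include <dlfcn.h>   // dladdr, Dl_info (POSIX; link with -ldl on glibc)
    #include <cstdio>

    static void placeholder_function() {}

    int main() {
      // Store the function pointer, then hand dladdr() the object-pointer view.
      union {
        void (*fn)();
        void *addr;
      } sym{};
      sym.fn = placeholder_function;
      Dl_info info{};
      if (dladdr(sym.addr, &info) && info.dli_fname) {
        std::printf("defined in: %s\n", info.dli_fname);
      }
      return 0;
    }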
--- commit.txt | 57 +++++++++++++++++++ .../common/amd_detail/hip_float8.h | 6 +- .../common/cast/mxfp8/rocm_quantize_mxfp8.cuh | 4 +- .../ck_fused_attn/src/ck_fused_attn_utils.cpp | 15 +++-- .../common/fused_attn_rocm/fused_attn.cpp | 10 ++-- transformer_engine/common/gemm/rocm_gemm.cu | 25 ++++---- .../common/normalization/common.cpp | 28 ++++----- .../common/normalization/layernorm/ln_api.cpp | 14 ++--- .../normalization/rmsnorm/rmsnorm_api.cpp | 14 ++--- .../common/permutation/permutation.cu | 8 +-- transformer_engine/common/util/logging.h | 2 +- .../pytorch/csrc/extensions/gemm.cpp | 14 +++-- 12 files changed, 137 insertions(+), 60 deletions(-) create mode 100644 commit.txt diff --git a/commit.txt b/commit.txt new file mode 100644 index 000000000..8c8064588 --- /dev/null +++ b/commit.txt @@ -0,0 +1,57 @@ +Fix cpplint violations in common and PyTorch extension code + +transformer_engine/common/amd_detail/hip_float8.h + -Host constructor: multi-statement if/else now uses braces (readability/braces). + +transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh + -Include ; typedef for gfx950 vector type uses int16_t instead of + short (runtime/int). + +transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp + -dladdr: avoid ill-formed function-pointer-to-void* cast via a small union + (readability/casting / portable POSIX). + -get_ck_log_stream: else branch restructured with nested if so else/brace + pairing satisfies cpplint (readability/braces). + +transformer_engine/common/fused_attn_rocm/fused_attn.cpp + -check_set_window_size: replace std::make_pair(...) with + std::pair(...) (build/explicit_make_pair). + -Replace alternative tokens `or` with || (readability/alt_tokens). + -log_fused_attn_config: same for sliding-window condition. + +transformer_engine/common/gemm/rocm_gemm.cu + -ObjCache / NameMapper: mark single-argument constructors explicit + (runtime/explicit). + -HIPBLASLT scaling_mode check: split #if/#else branches so each if has its + own braced body; use static_cast instead of C-style cast + (readability/braces, readability/casting). + -Debug logging: (int) casts -> static_cast for hipDataType fields + (readability/casting). + -ServiceStreamKey: use std::uint64_t alias instead of unsigned long long + (runtime/int). + +transformer_engine/common/normalization/common.cpp + -getNormalizationPlan: after optional CUDNN plan, use if (!plan) { ... } for + TE plans instead of } else #endif if (readability/braces across preprocessor). + +transformer_engine/common/normalization/layernorm/ln_api.cpp + -Forward/backward: default norm_backend to Te; optional CUDNN path only under + #ifndef __HIP_PLATFORM_AMD__; set is_aligned only when backend is Te, so + preprocessor does not split if/else from its braces (readability/braces). + +transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp + -Same pattern as ln_api for forward (including HIP constexpr + gamma_in_weight_dtype) and backward cudnn vs Te (readability/braces). + +transformer_engine/common/permutation/permutation.cu + -MoE unpermute kernel: functional-style float(...) casts replaced with + static_cast(...) (readability/casting). + +transformer_engine/common/util/logging.h + -NVTE_CHECK_HIPBLASLT macro: std::to_string((int)status) -> + std::to_string(static_cast(status)) (readability/casting). 
+ +transformer_engine/pytorch/csrc/extensions/gemm.cpp + -Comm overlap RS path: HIP p2p vs split_overlap_rs restructured with proper + #else for non-HIP so } else #endif { does not confuse brace rules + (readability/braces). \ No newline at end of file diff --git a/transformer_engine/common/amd_detail/hip_float8.h b/transformer_engine/common/amd_detail/hip_float8.h index 0e4de3294..d1b687130 100644 --- a/transformer_engine/common/amd_detail/hip_float8.h +++ b/transformer_engine/common/amd_detail/hip_float8.h @@ -61,7 +61,11 @@ union _te_hip_fp8 { __device__ operator float() const; __host__ _te_hip_fp8(const float& v) { - if (te_fp8_fnuz()) fnuz=v; else ocp=v; + if (te_fp8_fnuz()) { + fnuz = v; + } else { + ocp = v; + } } __device__ _te_hip_fp8(const float& v); }; diff --git a/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh index 40913cc03..cfa9bf312 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh @@ -7,6 +7,8 @@ // drop-in replacement for rocm quantize_mxfp8 kernels //#include "hip/hip_runtime.h" //dummy include to prevent hipification adding this header +#include + constexpr size_t MXFP8_CHUNK_DIM_Y = 64; constexpr size_t MXFP8_CHUNK_DIM_X = 64; constexpr size_t MXFP8_THREADS_PER_CHUNK = 64; @@ -15,7 +17,7 @@ constexpr size_t ELEMS_PER_THREAD = 16; constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported #if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ -typedef short mxfp8_v2i16_t __attribute__((ext_vector_type(2))); +typedef int16_t mxfp8_v2i16_t __attribute__((ext_vector_type(2))); #endif template (fn) (not ISO C++). + union { + void (*fn)(); + void *addr; + } sym{}; + sym.fn = set_aiter_asm_dir; + dladdr(sym.addr, &info); const char* log_ck_config_env = std::getenv("NVTE_LOG_CK_CONFIG"); bool log_ck_config = log_ck_config_env && std::string(log_ck_config_env) == "1"; // Check if user has set AITER_ASM_DIR, if yes, skip auto setting and log @@ -130,9 +136,10 @@ std::ostream* get_ck_log_stream() { if (!log_dir_str.empty() && log_dir_str != "0") { if (log_dir_str == "1") { log_stream = &std::cout; - } - else if (open_ck_fused_attn_log_file(log_file, "ck_fused_attn", log_dir_str)) { - log_stream = &log_file; + } else { + if (open_ck_fused_attn_log_file(log_file, "ck_fused_attn", log_dir_str)) { + log_stream = &log_file; + } } } } diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index e787b31c8..ebfc20a8a 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -146,26 +146,26 @@ std::pair check_set_window_size(NVTE_Mask_Type attn_mask_type, nvte_log_fused_attn_config = true; } if(attn_mask_type==NVTE_CAUSAL_MASK || attn_mask_type==NVTE_PADDING_CAUSAL_MASK || attn_mask_type==NVTE_CAUSAL_BOTTOM_RIGHT_MASK || attn_mask_type==NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK){ - if(window_size==std::make_pair(-1, -1) || (window_size.first >=0 && window_size.second!=0)){ + if(window_size==std::pair(-1, -1) || (window_size.first >=0 && window_size.second!=0)){ //TODO: better INFO logging if(nvte_log_fused_attn_config){ std::cout<<"window_size should be (-1, 0) or (>=0, 0) for attn_mask_type="<(-1, 0) && (window_size.first < 0 || window_size.second != 0)){ + }else if( window_size!=std::pair(-1, 0) && (window_size.first < 0 || window_size.second != 0)){ NVTE_ERROR("window_size 
should be (-1, 0) or (>=0, 0) for attn_mask_type=" + std::to_string(attn_mask_type)); } }else if(attn_mask_type==NVTE_NO_MASK || attn_mask_type==NVTE_PADDING_MASK){ //no_mask and padding mask - if(window_size==std::make_pair(-1, 0)){ + if(window_size==std::pair(-1, 0)){ //TODO: better INFO logging if(nvte_log_fused_attn_config){ std::cout<<"window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type="<(-1, -1) && (window_size.first < 0 or window_size.second < 0)){ + }else if(window_size!=std::pair(-1, -1) && (window_size.first < 0 || window_size.second < 0)){ NVTE_ERROR("window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + std::to_string(attn_mask_type)); } }else{ @@ -267,7 +267,7 @@ void log_fused_attn_config( std::cout<<"d_qk: "<0 or window_size_right >0){ + if(window_size_left >0 || window_size_right >0){ std::cout<<", (sliding window)"; } std::cout< class NameMapper { public: - NameMapper(const std::unordered_map& name_map): map(name_map) {} + explicit NameMapper(const std::unordered_map& name_map): map(name_map) {} const std::string_view &getName(const T &val) { return map.at(val); } @@ -769,14 +769,17 @@ protected: } #if HIPBLASLT_VERSION_MAJOR > 0 || HIPBLASLT_VERSION_MINOR >= 15 - if (cfg.scaling_mode < 0 || cfg.scaling_mode >= (int)HIPBLASLT_MATMUL_MATRIX_SCALE_END) + if (cfg.scaling_mode < 0 || + cfg.scaling_mode >= static_cast(HIPBLASLT_MATMUL_MATRIX_SCALE_END)) { + std::cout << "[WARNING] Unsupported scaling mode at " << line << "\n"; + continue; + } #else - if (cfg.scaling_mode != 0) -#endif - { + if (cfg.scaling_mode != 0) { std::cout << "[WARNING] Unsupported scaling mode at " << line << "\n"; continue; } +#endif auto fp8_filter = te_fp8_fnuz() ? [](const hipDataType& val) @@ -966,10 +969,10 @@ void hipblaslt_gemm(const Tensor *inputA, std::cout << "m=" << m << " k=" << k << " n=" << n << " transa=" << (param.transA == HIPBLAS_OP_T ? "T" : "N") << " transb=" << (param.transB == HIPBLAS_OP_T ? 
"T" : "N") - << " A_type=" << (int)(param.Atype) - << " B_type=" << (int)(param.Btype) - << " D_type=" << (int)outputD->data.dtype - << " bias_type=" << (int)inputBias->data.dtype + << " A_type=" << static_cast(param.Atype) + << " B_type=" << static_cast(param.Btype) + << " D_type=" << static_cast(outputD->data.dtype) + << " bias_type=" << static_cast(inputBias->data.dtype) << " grad=" << grad << " bias=" << (inputBias->data.dptr != nullptr) << " gelu=" << (outputPreGelu->data.dptr != nullptr) @@ -1386,7 +1389,7 @@ void hipblaslt_gemm(const Tensor *inputA, } -typedef unsigned long long ServiceStreamKey; +using ServiceStreamKey = std::uint64_t; ServiceStreamKey make_service_stream_key(const int device_id, const int cu_count) { return (static_cast(device_id) << 32) | static_cast(cu_count); diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index d6aa55b37..ee5876b5a 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -544,24 +544,26 @@ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan( plan = std::make_unique(NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, zero_centered_gamma, mode, training); - } else + } #endif - if (NormStage == NVTE_Norm_Stage::Forward) { - plan = std::make_unique>( - NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, - zero_centered_gamma, is_tuned + if (!plan) { + if (NormStage == NVTE_Norm_Stage::Forward) { + plan = std::make_unique>( + NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, + zero_centered_gamma, is_tuned #ifdef __HIP_PLATFORM_AMD__ - , mode, training + , mode, training #endif - ); - } else { - plan = std::make_unique>( - NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, - zero_centered_gamma, is_tuned + ); + } else { + plan = std::make_unique>( + NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, + zero_centered_gamma, is_tuned #ifdef __HIP_PLATFORM_AMD__ - , mode, training + , mode, training #endif - ); + ); + } } normalizationPlanMap.insert({key, std::move(plan)}); return normalizationPlanMap[key].get(); diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index b621727c4..855aa1838 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -68,7 +68,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size CheckOutputTensor(*rsigma, "rsigma"); } - NVTE_Norm_Backend norm_backend; + NVTE_Norm_Backend norm_backend = NVTE_Norm_Backend::Te; bool is_aligned = true; #ifndef __HIP_PLATFORM_AMD__ bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode); @@ -85,10 +85,9 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size // TODO: add check for GPU ARCH norm_backend = NVTE_Norm_Backend::Cudnn; gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype(); - } else + } #endif //__HIP_PLATFORM_AMD__ - { - norm_backend = NVTE_Norm_Backend::Te; + if (norm_backend == NVTE_Norm_Backend::Te) { is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, beta.data.dptr, mu->data.dptr, rsigma->data.dptr); } @@ -169,7 +168,7 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te CheckOutputTensor(*dbeta, 
"dbeta"); } - NVTE_Norm_Backend norm_backend; + NVTE_Norm_Backend norm_backend = NVTE_Norm_Backend::Te; bool is_aligned = true; bool gamma_in_weight_dtype = false; #ifndef __HIP_PLATFORM_AMD__ @@ -177,10 +176,9 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te // TODO: add check for GPU ARCH norm_backend = NVTE_Norm_Backend::Cudnn; gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype(); - } else + } #endif - { - norm_backend = NVTE_Norm_Backend::Te; + if (norm_backend == NVTE_Norm_Backend::Te) { is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, mu.data.dptr, rsigma.data.dptr, dx->data.dptr, dz.data.dptr, dbeta->data.dptr, dgamma->data.dptr); } diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index bd085453b..4b133fd5c 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -54,7 +54,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens CheckOutputTensor(*rsigma, "rsigma"); } - NVTE_Norm_Backend norm_backend; + NVTE_Norm_Backend norm_backend = NVTE_Norm_Backend::Te; bool is_aligned = true; #ifndef __HIP_PLATFORM_AMD__ bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode); @@ -76,10 +76,9 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens // TODO: add check for GPU ARCH norm_backend = NVTE_Norm_Backend::Cudnn; gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype(); - } else + } #endif - { - norm_backend = NVTE_Norm_Backend::Te; + if (norm_backend == NVTE_Norm_Backend::Te) { is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, rsigma->data.dptr); } @@ -148,7 +147,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const CheckOutputTensor(*dgamma, "dgamma"); } - NVTE_Norm_Backend norm_backend; + NVTE_Norm_Backend norm_backend = NVTE_Norm_Backend::Te; bool is_aligned = true; bool gamma_in_weight_dtype = false; #ifndef __HIP_PLATFORM_AMD__ @@ -156,10 +155,9 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const // TODO: add check for GPU ARCH norm_backend = NVTE_Norm_Backend::Cudnn; gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype(); - } else + } #endif - { - norm_backend = NVTE_Norm_Backend::Te; + if (norm_backend == NVTE_Norm_Backend::Te) { is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr, dz.data.dptr, dgamma->data.dptr); } diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index 5906e4fff..55ba66aee 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -81,12 +81,12 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const #endif for (int e = 0; e < kElementsPerAccess; e++) { - frag_sum[e] = float(TCompute(frag_load_store_ptr[e])); + frag_sum[e] = static_cast(TCompute(frag_load_store_ptr[e])); } if (hasProb) { for (int e = 0; e < kElementsPerAccess; e++) { - frag_sum[e] = frag_sum[e] * float(s_prob[0]); + frag_sum[e] = frag_sum[e] * static_cast(s_prob[0]); } } } else { @@ -120,7 +120,7 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const } for (int e = 0; e < kElementsPerAccess; e++) { - frag_sum[e] += float(frag_elem[e]); + 
frag_sum[e] += static_cast(frag_elem[e]); } } @@ -129,7 +129,7 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const for (int e = 0; e < kElementsPerAccess; e++) { if constexpr ((std::is_same_v || std::is_same_v) && (!hasProb)) { - frag_sum[e] = frag_sum[e] / float(TCompute(topK)); + frag_sum[e] = frag_sum[e] / static_cast(TCompute(topK)); } frag_load_store_ptr[e] = T(TCompute(frag_sum[e])); } diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index a70ae4398..0cb22af59 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -68,7 +68,7 @@ const hipblasStatus_t status_NVTE_CHECK_CUBLAS = (expr); \ if (status_NVTE_CHECK_CUBLAS != CUBLAS_STATUS_SUCCESS) { \ NVTE_ERROR("HIPBLASLT Error: ", \ - std::to_string((int)status_NVTE_CHECK_CUBLAS)); \ + std::to_string(static_cast(status_NVTE_CHECK_CUBLAS))); \ } \ } while (false) #else //cublas diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 941b88e36..b6b630523 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -325,16 +325,22 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans accumulate, use_split_accumulator, extra_output_tensor, main_stream); }); - } else -#endif - { + } else { + NVTE_SCOPED_GIL_RELEASE({ + comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, + bias_tensor, te_pre_gelu_out, te_workspace, grad, + accumulate, use_split_accumulator, extra_output_tensor, + main_stream); + }); + } +#else NVTE_SCOPED_GIL_RELEASE({ comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); }); - } +#endif } } } else { From 4ea0d55028758b5fd90299b1c7d5759f000b1c3a Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Tue, 21 Apr 2026 05:55:47 +0000 Subject: [PATCH 2/7] Addressed review --- commit.txt | 57 ------------------------------------------------------ 1 file changed, 57 deletions(-) delete mode 100644 commit.txt diff --git a/commit.txt b/commit.txt deleted file mode 100644 index 8c8064588..000000000 --- a/commit.txt +++ /dev/null @@ -1,57 +0,0 @@ -Fix cpplint violations in common and PyTorch extension code - -transformer_engine/common/amd_detail/hip_float8.h - -Host constructor: multi-statement if/else now uses braces (readability/braces). - -transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh - -Include ; typedef for gfx950 vector type uses int16_t instead of - short (runtime/int). - -transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp - -dladdr: avoid ill-formed function-pointer-to-void* cast via a small union - (readability/casting / portable POSIX). - -get_ck_log_stream: else branch restructured with nested if so else/brace - pairing satisfies cpplint (readability/braces). - -transformer_engine/common/fused_attn_rocm/fused_attn.cpp - -check_set_window_size: replace std::make_pair(...) with - std::pair(...) (build/explicit_make_pair). - -Replace alternative tokens `or` with || (readability/alt_tokens). - -log_fused_attn_config: same for sliding-window condition. - -transformer_engine/common/gemm/rocm_gemm.cu - -ObjCache / NameMapper: mark single-argument constructors explicit - (runtime/explicit). 
- -HIPBLASLT scaling_mode check: split #if/#else branches so each if has its - own braced body; use static_cast instead of C-style cast - (readability/braces, readability/casting). - -Debug logging: (int) casts -> static_cast for hipDataType fields - (readability/casting). - -ServiceStreamKey: use std::uint64_t alias instead of unsigned long long - (runtime/int). - -transformer_engine/common/normalization/common.cpp - -getNormalizationPlan: after optional CUDNN plan, use if (!plan) { ... } for - TE plans instead of } else #endif if (readability/braces across preprocessor). - -transformer_engine/common/normalization/layernorm/ln_api.cpp - -Forward/backward: default norm_backend to Te; optional CUDNN path only under - #ifndef __HIP_PLATFORM_AMD__; set is_aligned only when backend is Te, so - preprocessor does not split if/else from its braces (readability/braces). - -transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp - -Same pattern as ln_api for forward (including HIP constexpr - gamma_in_weight_dtype) and backward cudnn vs Te (readability/braces). - -transformer_engine/common/permutation/permutation.cu - -MoE unpermute kernel: functional-style float(...) casts replaced with - static_cast(...) (readability/casting). - -transformer_engine/common/util/logging.h - -NVTE_CHECK_HIPBLASLT macro: std::to_string((int)status) -> - std::to_string(static_cast(status)) (readability/casting). - -transformer_engine/pytorch/csrc/extensions/gemm.cpp - -Comm overlap RS path: HIP p2p vs split_overlap_rs restructured with proper - #else for non-HIP so } else #endif { does not confuse brace rules - (readability/braces). \ No newline at end of file From a87ccf4903bd2039ed4a180acd405693256ab759 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Tue, 21 Apr 2026 07:20:24 +0000 Subject: [PATCH 3/7] JAX/Python pylint - common: trailing space, encoding, docstrings, wheel file handle naming - recipe: Format helper docstrings, whitespace - jax/util: PEP8, module docstring, subprocess check, def over lambdas - jax/setup: group build_tools.hipify imports before pybind11 - jax/quantize/helper: tejax + CUDA/ROCm helpers, no-else-return - cpp_extensions: attention/normalization lazy SdyShardingRule; base is_hip_extension(); gemm conditional cGEMM imports + stubs - pylintrc: align disables (e.g. wrong-import-position) with CI. 
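
For reference, the quantize/helper CUDA/ROCm helpers and the conditional cGEMM imports in gemm follow the same stub pattern: alias the CUDA-only symbols from the extension when the build is not ROCm, otherwise define same-named stubs that fail loudly if ever called, so both build flavors expose identical names to pylint and callers. A minimal sketch with placeholder names (not the exact TE code):

    def _is_rocm_build() -> bool:
        """Placeholder for the real is_hip_extension() check."""
        return True

    if not _is_rocm_build():
        # CUDA builds alias the real helpers from the compiled extension, e.g.
        #   get_cuda_version = transformer_engine_jax.get_cuda_version
        def get_cuda_version():
            return 12040  # stand-in value for the sketch
    else:
        def get_cuda_version():
            """CUDA-only helper; the ROCm stub raises instead of silently returning."""
            raise RuntimeError("get_cuda_version is not available on ROCm builds")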
--- pylintrc | 5 +- transformer_engine/common/__init__.py | 21 +++--- transformer_engine/common/recipe/__init__.py | 4 +- .../jax/cpp_extensions/attention.py | 3 +- transformer_engine/jax/cpp_extensions/base.py | 5 +- transformer_engine/jax/cpp_extensions/gemm.py | 17 +++-- .../jax/cpp_extensions/normalization.py | 8 ++- transformer_engine/jax/quantize/helper.py | 31 ++++---- transformer_engine/jax/setup.py | 4 +- transformer_engine/jax/util.py | 71 ++++++++++++------- 10 files changed, 105 insertions(+), 64 deletions(-) diff --git a/pylintrc b/pylintrc index 50f85fad9..1775bb7f1 100644 --- a/pylintrc +++ b/pylintrc @@ -31,7 +31,10 @@ disable=too-many-locals, redefined-argument-from-local, line-too-long, too-many-return-statements, - too-many-nested-blocks + too-many-nested-blocks, + import-outside-toplevel, + possibly-used-before-assignment, + wrong-import-position [TYPECHECK] ignored-modules=torch diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 662bc504c..531ad67aa 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -51,8 +51,8 @@ def _is_package_installed_from_wheel(package) -> bool: if not te_wheel_file: return False - with te_wheel_file.open("r") as f: - for line in f: + with te_wheel_file.open("r") as wheel_f: + for line in wheel_f: if line.startswith("Root-Is-Purelib:"): return line.strip().split(":")[1].strip().lower() == "true" return False @@ -138,7 +138,7 @@ def get_te_core_package_info() -> Tuple[bool, str, str]: Check if Tranformer Engine core package is installed. Returns the module name and version if found. """ - + te_core_packages = ("transformer-engine-cu12", "transformer-engine-cu13") if te_rocm_build: te_core_packages = ("transformer-engine-rocm7",) @@ -366,6 +366,7 @@ def _load_cuda_library(lib_name: str): @functools.cache def is_fp8_fnuz(): + """Return True when TE was built with FP8 FNUZ mode (ROCm path).""" if te_rocm_build: _TE_LIB_CTYPES.nvte_uses_fp8_fnuz.restype = ctypes.c_bool return _TE_LIB_CTYPES.nvte_uses_fp8_fnuz() @@ -411,13 +412,17 @@ def _load_core_library(): for rocm_path in (os.getenv("ROCM_PATH"), "/opt/rocm/core", "/opt/rocm"): if rocm_path and os.path.exists(os.path.join(rocm_path, ".info/version")): break - with open(os.path.join(rocm_path, ".info/version"), "r") as f: - rocm_version= f.read().strip().split('.')[:2] + with open(os.path.join(rocm_path, ".info/version"), "r", encoding="utf-8") as ver_file: + rocm_version = ver_file.read().strip().split(".")[:2] # Get ROCm version from the build info file - with open(Path(transformer_engine.__path__[0]).parent / "transformer_engine" / "build_info.txt", 'r') as f: - build_info = f.read().split('\n') - build_rocm_version = list(filter(lambda f: f.startswith("ROCM_VERSION:"), build_info)) + with open( + Path(transformer_engine.__path__[0]).parent / "transformer_engine" / "build_info.txt", + "r", + encoding="utf-8", + ) as build_file: + build_info = build_file.read().split("\n") + build_rocm_version = list(filter(lambda line: line.startswith("ROCM_VERSION:"), build_info)) if build_rocm_version: build_rocm_version = build_rocm_version[0].split(":")[1].strip().split('.')[:2] assert (rocm_version[0] == build_rocm_version[0]), f"ROCm {'.'.join(rocm_version)} is detected but the library is built for {'.'.join(build_rocm_version)}" diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 366f43d1f..b62b0b11e 100644 --- a/transformer_engine/common/recipe/__init__.py 
+++ b/transformer_engine/common/recipe/__init__.py @@ -29,11 +29,13 @@ class _FormatHelperFP8(NamedTuple): bwd: tuple @property - def max_fwd(self) -> float: + def max_fwd(self) -> float: + """Max FP8 forward value for the active FP8 variant (OCP vs FNUZ).""" return self.fwd[is_fp8_fnuz()] @property def max_bwd(self) -> float: + """Max FP8 backward value for the active FP8 variant (OCP vs FNUZ).""" return self.bwd[is_fp8_fnuz()] class _FormatMaxVals(Enum): diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 4d669bc46..89510e88e 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -9,9 +9,10 @@ import warnings from dataclasses import dataclass, replace from functools import partial, reduce -from packaging import version from typing import Optional, Tuple +from packaging import version + import jax import jax.numpy as jnp from jax import dtypes, lax, ffi diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index e65215bec..27455c716 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -17,9 +17,10 @@ from jax._src import dispatch from jax import ffi -from .misc import is_hip_extension import transformer_engine_jax +from .misc import is_hip_extension + class BasePrimitive(metaclass=ABCMeta): """ @@ -223,7 +224,7 @@ def name_of_wrapper_p(): for _name, _value in transformer_engine_jax.registrations().items(): - ffi.register_ffi_target(_name, _value, platform="ROCM" if is_hip_extension else "CUDA") + ffi.register_ffi_target(_name, _value, platform="ROCM" if is_hip_extension() else "CUDA") def manage_primitives(enable_names=None, disable_names=None, disable_all_first=False): diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 4e369ebb3..94d78a0c1 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -20,24 +20,17 @@ from jax.sharding import NamedSharding, PartitionSpec from jax.experimental.custom_partitioning import SdyShardingRule -from ..util import is_hip_extension - from transformer_engine_jax import ( get_num_compute_streams, JAXX_Collective_Op, get_device_compute_capability, ) -if not is_hip_extension(): - from transformer_engine_jax import ( - initialize_cgemm_communicator, - get_cgemm_num_max_streams, - ) + +from ..util import is_hip_extension, get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type from .base import BasePrimitive, register_primitive from .quantization import grouped_quantize -from ..util import is_hip_extension, get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type - from ..quantize import ( AbstractBaseTensor, NoScaleTensor, @@ -63,6 +56,12 @@ dp_or_fsdp_axis_size, ) +if not is_hip_extension(): + from transformer_engine_jax import ( + initialize_cgemm_communicator, + get_cgemm_num_max_streams, + ) + __all__ = [ "CollectiveOp", diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 5ec0c4b4e..6ef35a134 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -8,16 +8,18 @@ import warnings import operator from functools import partial, cache, reduce -from packaging import version from typing import Optional, Union +from packaging import version + import jax 
import jax.numpy as jnp from jax import dtypes, ffi -from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING +if version.parse(jax.__version__) >= version.parse("0.5.0"): + from jax.experimental.custom_partitioning import SdyShardingRule +from jax.experimental.custom_partitioning import BATCHING from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec -from .misc import is_hip_extension import transformer_engine_jax from transformer_engine_jax import NVTE_Norm_Type diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 366c31726..ebb99d5ee 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -26,14 +26,8 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict -from ..util import is_hip_extension, get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type +import transformer_engine_jax as tejax -from transformer_engine_jax import DType -if not is_hip_extension(): - from transformer_engine_jax import ( - get_cublasLt_version, - get_cuda_version, - ) from transformer_engine.common.recipe import ( Recipe, DelayedScaling, @@ -50,9 +44,23 @@ with_sharding_constraint, ) +from ..util import is_hip_extension, get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type +from .device_utils import get_device_compute_capability from .metadata import QuantizeMeta from .scaling_modes import ScalingMode -from .device_utils import get_device_compute_capability + +if not is_hip_extension(): + get_cublasLt_version = tejax.get_cublasLt_version + get_cuda_version = tejax.get_cuda_version +else: + + def get_cublasLt_version(): + """CUDA-only; not used on ROCm code paths.""" + raise RuntimeError("get_cublasLt_version is not available on ROCm") + + def get_cuda_version(): + """CUDA-only; not used on ROCm code paths.""" + raise RuntimeError("get_cuda_version is not available on ROCm") __all__ = [ "get_global_quantize_recipe", @@ -99,8 +107,7 @@ def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: if is_hip_extension(): if gpu_arch in [94, 95]: return True, "" - else: - return False, "Device arch gfx94x or gfx95x required for FP8 execution." + return False, "Device arch gfx94x or gfx95x required for FP8 execution." if gpu_arch < 89: # pre-ada return False, "Device compute capability 8.9 or higher required for FP8 execution." 
if get_cublasLt_version() < 120103: @@ -305,8 +312,8 @@ class BaseQuantizeConfig(ABC): INITIALIZED = False MARGIN: float = 0.0 COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME - FWD_DTYPE: DType = None - BWD_DTYPE: DType = None + FWD_DTYPE: tejax.DType = None + BWD_DTYPE: tejax.DType = None FP8_2X_ACC_FPROP: bool = False FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_WGRAD: bool = False diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index 47710a1f6..574398a7b 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -51,11 +51,11 @@ from build_tools.te_version import te_version from build_tools.jax import setup_jax_extension, install_requirements, test_requirements -from pybind11.setup_helpers import build_ext as BuildExtension - if rocm_build(): from build_tools.hipify.hipify import copy_hipify_tools, clear_hipify_tools_copy +from pybind11.setup_helpers import build_ext as BuildExtension + os.environ["NVTE_PROJECT_BUILDING"] = "1" CMakeBuildExtension = get_build_ext(BuildExtension, True) diff --git a/transformer_engine/jax/util.py b/transformer_engine/jax/util.py index c1b591768..7f9d33182 100644 --- a/transformer_engine/jax/util.py +++ b/transformer_engine/jax/util.py @@ -1,37 +1,58 @@ # Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. See LICENSE for more information -from functools import cache +"""Small JAX-side helpers shared across TE JAX code (ROCm detection, FP8 dtypes).""" + import importlib.metadata import re -import subprocess, sys +import subprocess +import sys +from functools import cache + import jax.numpy as jnp -# check whether ROCm is supported by JAX + @cache def is_hip_extension() -> bool: - if any(re.match(r'jax-rocm\d+-plugin', d.metadata['Name']) - for d in importlib.metadata.distributions()): - return True - return False + """Return True when the JAX ROCm plugin distribution is installed.""" + return any( + re.match(r"jax-rocm\d+-plugin", d.metadata["Name"]) + for d in importlib.metadata.distributions() + ) + if is_hip_extension(): - @cache - def is_mi200(): - import jax - """check whether this machine is mi200/210/250""" - return (re.search('AMD Instinct MI2.0', jax.devices()[0].device_kind) is not None) - + + @cache + def is_mi200(): + """Return True when running on AMD Instinct MI200-class hardware.""" + import jax + + return re.search(r"AMD Instinct MI2\.0", jax.devices()[0].device_kind) is not None + + @cache def is_fp8_fnuz(): - if not is_hip_extension(): - return False - ret = subprocess.run( - [sys.executable, "-c", - "import sys; sys.path[:] = [p for p in sys.path if p not in ['', '.']]; "+ - "import os; os.environ['NVTE_FRAMEWORK']='none'; "+ - "import transformer_engine as te; exit(not te.common.is_fp8_fnuz())"] - ).returncode - return ret == 0 - -get_jnp_float8_e4m3_type = lambda: jnp.float8_e4m3fnuz if is_fp8_fnuz() else jnp.float8_e4m3fn -get_jnp_float8_e5m2_type = lambda: jnp.float8_e5m2fnuz if is_fp8_fnuz() else jnp.float8_e5m2 + """Return True when TE core reports FP8 FNUZ usage (matches subprocess TE check).""" + if not is_hip_extension(): + return False + proc = subprocess.run( + [ + sys.executable, + "-c", + "import sys; sys.path[:] = [p for p in sys.path if p not in ['', '.']]; " + "import os; os.environ['NVTE_FRAMEWORK']='none'; " + "import transformer_engine as te; exit(not te.common.is_fp8_fnuz())", + ], + check=False, + ) + return proc.returncode == 0 + + +def get_jnp_float8_e4m3_type(): + """JAX FP8 e4m3 dtype for this 
platform (FNUZ on ROCm when applicable).""" + return jnp.float8_e4m3fnuz if is_fp8_fnuz() else jnp.float8_e4m3fn + + +def get_jnp_float8_e5m2_type(): + """JAX FP8 e5m2 dtype for this platform (FNUZ on ROCm when applicable).""" + return jnp.float8_e5m2fnuz if is_fp8_fnuz() else jnp.float8_e5m2 From 6748cd8c2c0552ee39018a3bfdcb08415286bbe3 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Tue, 21 Apr 2026 22:35:44 +0000 Subject: [PATCH 4/7] CI: Fixed python pylint issues --- pylintrc | 10 +- .../attention/dot_product_attention/utils.py | 16 +- transformer_engine/pytorch/constants.py | 21 ++- transformer_engine/pytorch/jit.py | 38 ++--- transformer_engine/pytorch/module/__init__.py | 28 ++-- transformer_engine/pytorch/module/base.py | 17 ++- .../pytorch/module/layernorm_linear.py | 12 +- .../pytorch/module/layernorm_mlp.py | 6 +- transformer_engine/pytorch/module/linear.py | 4 +- transformer_engine/pytorch/quantization.py | 11 +- .../pytorch/quantized_tensor.py | 7 +- transformer_engine/pytorch/setup.py | 8 +- .../pytorch/tensor/float8_tensor.py | 16 +- .../pytorch/tensor/fsdp2_allgather_tensor.py | 140 ++++++++++++------ .../pytorch/tensor/mxfp4_tensor.py | 4 +- .../pytorch/tensor/mxfp8_tensor.py | 10 +- .../tensor/storage/mxfp4_tensor_storage.py | 6 +- .../pytorch/triton_kernels/cast.py | 64 ++++---- .../pytorch/triton_kernels/cast_transpose.py | 140 +++++++++--------- .../pytorch/triton_kernels/common.py | 30 ++-- .../pytorch/triton_kernels/gmm/gmm_common.py | 16 +- .../pytorch/triton_kernels/gmm/gmm_kernels.py | 20 +-- .../pytorch/triton_kernels/gmm/gmm_wrapper.py | 14 +- .../pytorch/triton_kernels/grouped_gemm.py | 97 ++++++------ .../pytorch/triton_kernels/rmsnorm.py | 29 ++-- .../pytorch/triton_kernels/utils.py | 8 +- transformer_engine/pytorch/utils.py | 58 +++++--- 27 files changed, 451 insertions(+), 379 deletions(-) diff --git a/pylintrc b/pylintrc index 1775bb7f1..faba34d5c 100644 --- a/pylintrc +++ b/pylintrc @@ -5,6 +5,11 @@ extension-pkg-whitelist=flash_attn_2_cuda, transformer_engine_jax disable=too-many-locals, + missing-module-docstring, + missing-function-docstring, + wrong-import-order, + ungrouped-imports, + fixme, too-few-public-methods, too-many-public-methods, too-many-positional-arguments, @@ -34,7 +39,10 @@ disable=too-many-locals, too-many-nested-blocks, import-outside-toplevel, possibly-used-before-assignment, - wrong-import-position + wrong-import-position, + unnecessary-lambda-assignment, + use-dict-literal, + redefined-builtin [TYPECHECK] ignored-modules=torch diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 9b7147106..ef678e39e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -623,15 +623,13 @@ def get_attention_backend( ) use_fused_attention = False - if use_flash_attention_2 and ( - head_dim_qk > 256 - or head_dim_qk % 8 != 0 - or ( - not IS_HIP_EXTENSION - and head_dim_qk > 192 - and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0)) - ) - ): + fa2_hd_check = head_dim_qk > 256 or head_dim_qk % 8 != 0 + fa2_hd_dcc_check = ( + not IS_HIP_EXTENSION + and head_dim_qk > 192 + and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0)) + ) + if use_flash_attention_2 and (fa2_hd_check or fa2_hd_dcc_check): if FlashAttentionUtils.is_installed: logger.debug( "Disabling FlashAttention 2 due to unsupported 
head_dim_qk and head_dim_v. " diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index 120e63c04..c6cea77b6 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -5,9 +5,13 @@ # See LICENSE for license information. """Enums for e2e transformer""" + import torch import torch.distributed +from torch.utils.cpp_extension import IS_HIP_EXTENSION + import transformer_engine_torch as tex + from .utils import get_torch_float8_e4m3_type, get_torch_float8_e5m2_type """ @@ -25,24 +29,31 @@ torch.half: tex.DType.kFloat16, torch.bfloat16: tex.DType.kBFloat16, } -from torch.utils.cpp_extension import IS_HIP_EXTENSION + if IS_HIP_EXTENSION: - TE_DType.update({torch.float8_e4m3fnuz: tex.DType.kFloat8E4M3, - torch.float8_e5m2fnuz: tex.DType.kFloat8E5M2}) + TE_DType.update( + { + torch.float8_e4m3fnuz: tex.DType.kFloat8E4M3, + torch.float8_e5m2fnuz: tex.DType.kFloat8E5M2, + } + ) _FP8_KEYS = (tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2) class Custom_DType_Dict(dict): + """Torch dtype lookup with lazy population for FP8 aliases.""" + def __missing__(self, key): if key in _FP8_KEYS: value = ( get_torch_float8_e4m3_type() if key is tex.DType.kFloat8E4M3 else get_torch_float8_e5m2_type() ) - self[key] = value + self[key] = value return value raise KeyError(key) - + + TE_DType_To_Torch = Custom_DType_Dict({ tex.DType.kByte: torch.uint8, tex.DType.kInt32: torch.int32, diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index e8cef56bd..2f201236c 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -63,26 +63,26 @@ def wrapper(*args, **kwargs): def set_jit_fusion_options() -> None: - if not IS_HIP_EXTENSION: """Set PyTorch JIT layer fusion options.""" - # flags required to enable jit fusion kernels - if torch_version() >= (2, 2, 0): - pass - elif torch_version() >= (1, 10, 0): - # nvfuser - torch._C._jit_set_profiling_executor(True) - torch._C._jit_set_profiling_mode(True) - torch._C._jit_override_can_fuse_on_cpu(False) - torch._C._jit_override_can_fuse_on_gpu(False) - torch._C._jit_set_texpr_fuser_enabled(False) - torch._C._jit_set_nvfuser_enabled(True) - torch._C._debug_set_autodiff_subgraph_inlining(False) - else: - # legacy pytorch fuser - torch._C._jit_set_profiling_mode(False) - torch._C._jit_set_profiling_executor(False) - torch._C._jit_override_can_fuse_on_cpu(True) - torch._C._jit_override_can_fuse_on_gpu(True) + if not IS_HIP_EXTENSION: + # flags required to enable jit fusion kernels + if torch_version() >= (2, 2, 0): + pass + elif torch_version() >= (1, 10, 0): + # nvfuser + torch._C._jit_set_profiling_executor(True) + torch._C._jit_set_profiling_mode(True) + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(True) + torch._C._debug_set_autodiff_subgraph_inlining(False) + else: + # legacy pytorch fuser + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) @jit_fuser diff --git a/transformer_engine/pytorch/module/__init__.py b/transformer_engine/pytorch/module/__init__.py index 3cf15efc1..7c76cacb8 100644 --- a/transformer_engine/pytorch/module/__init__.py +++ b/transformer_engine/pytorch/module/__init__.py @@ -1,14 +1,14 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. -# -# See LICENSE for license information. - -"""Module level PyTorch APIs""" -from .layernorm_linear import LayerNormLinear -from .linear import Linear -from .grouped_linear import GroupedLinear -from .layernorm_mlp import LayerNormMLP -from .layernorm import LayerNorm -from .rmsnorm import RMSNorm -from .fp8_padding import Fp8Padding -from .fp8_unpadding import Fp8Unpadding -from .base import initialize_ub, destroy_ub, UserBufferQuantizationMode +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Module level PyTorch APIs""" +from .layernorm_linear import LayerNormLinear +from .linear import Linear +from .grouped_linear import GroupedLinear +from .layernorm_mlp import LayerNormMLP +from .layernorm import LayerNorm +from .rmsnorm import RMSNorm +from .fp8_padding import Fp8Padding +from .fp8_unpadding import Fp8Unpadding +from .base import initialize_ub, destroy_ub, UserBufferQuantizationMode diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 2d8563729..b859df4c5 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -88,7 +88,7 @@ def get_cublas_workspace_size_bytes() -> None: """Return 64 MiB for gfx50x, 32 MiB for all other architectures.""" if get_device_compute_capability() == (9, 5): return 67_108_864 - return 33_554_432 + return 33_554_432 """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9: # 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales @@ -664,6 +664,7 @@ def fill_userbuffers_buffer_for_all_gather( columnwise_scale_inv=columnwise_scale_inv, fp8_dtype=local_tensor._fp8_dtype, quantizer=quantizer, + with_gemm_swizzled_scales=local_tensor._with_gemm_swizzled_scales, ) return global_tensor, local_tensor @@ -1053,8 +1054,8 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: fp8_enabled = self.fp8 or self.fp8_calibration self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration - if IS_HIP_EXTENSION and not FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2 and hasattr(self, 'use_fsdp2') and self.use_fsdp2: - FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2 = True + if IS_HIP_EXTENSION and not FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2 and hasattr(self, 'use_fsdp2') and self.use_fsdp2: + FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2 = True if self.fp8_parameters or fp8_enabled: if ( @@ -1088,8 +1089,8 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: self.fp8_initialized = True self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() - if self.fp8_meta["recipe"].mxfp8(): - self.keep_fp8_weight_transpose_cache = True + if self.fp8_meta["recipe"].mxfp8(): + self.keep_fp8_weight_transpose_cache = True _current_recipe = self.fp8_meta["recipe"] if _original_recipe is not None and not ( @@ -1357,9 +1358,9 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: if IS_HIP_EXTENSION and self.use_fsdp2 and not self.primary_weights_in_fp8 and fp8_meta_index is not None: self.keep_fp8_weight_transpose_cache = False param = FSDPAGTensor( - param, - module=self, - fp8_meta_index=fp8_meta_index, + param, + module=self, + fp8_meta_index=fp8_meta_index, keep_fp8_weight_transpose_cache=self.keep_fp8_weight_transpose_cache ) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py 
b/transformer_engine/pytorch/module/layernorm_linear.py index 7347fc138..de7af2d45 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1165,14 +1165,14 @@ class LayerNormLinear(TransformerEngineBaseModule): keep_fp8_weight_transpose_cache: bool, default = `True` Controls whether to cache the FP8 weight transpose buffer during training. - - If set to `True` (default), the FP8 weight transpose buffer is cached to avoid recomputation, + - If set to `True` (default), the FP8 weight transpose buffer is cached to avoid recomputation, which can improve performance but significantly increases memory usage. - - If set to `False`, the buffer is not cached and the FP8 weight transpose is recomputed as needed. + - If set to `False`, the buffer is not cached and the FP8 weight transpose is recomputed as needed. This reduces memory consumption, especially during checkpoint loading and runtime. - **Recommendation**: Set this to `False` when using Fully Sharded Data Parallel (FSDP) training. - Caching FP8 weight transposes can double memory usage for modules such as `Linear`, - `LayerNormLinear`, and `LayerNormMLP`, which may lead to excessive memory pressure and + **Recommendation**: Set this to `False` when using Fully Sharded Data Parallel (FSDP) training. + Caching FP8 weight transposes can double memory usage for modules such as `Linear`, + `LayerNormLinear`, and `LayerNormMLP`, which may lead to excessive memory pressure and reduced efficiency of PyTorch's caching allocator. Use this setting to balance memory usage and performance based on your training configuration. @@ -1231,7 +1231,7 @@ def __init__( self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) self.name = name - self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache if IS_HIP_EXTENSION else True + self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache if IS_HIP_EXTENSION else True self.use_fsdp2 = use_fsdp2 if IS_HIP_EXTENSION else False if tp_group is None: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 49667b633..3243eace9 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -555,7 +555,7 @@ def _forward( gemm_gelu_fusion = False if debug: gemm_gelu_fusion = False - + if IS_HIP_EXTENSION and fp8 and not keep_fp8_weight_transpose_cache: assert fc1_weight_final._transpose is None or fc1_weight_final._transpose.numel() == 0, "Expected _transpose to be None or an empty tensor when transpose cache is disabled." 
@@ -1400,7 +1400,7 @@ def fc2_wgrad_gemm( # Overlap FC1 DGRAD reduce-scatter with WGRAD compute ub_obj_fc1_wgrad = get_ub("fc1_wgrad", ctx.fp8) ub_type_fc1_wgrad = tex.CommOverlapType.RS - + # -------------------------------------------------- # FC1 DGRAD @@ -1676,7 +1676,7 @@ def fc1_wgrad_gemm( if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) if ctx.autocast_fp8_reduction_skipped: - FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True) + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True) # FIX THIS # Scatter Fp8 tranposed-weight buffers diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 01d07d91a..7a6ad7ad8 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -326,7 +326,9 @@ def forward( # Note: y = x * w^T # ------------------------------------------------------ if IS_HIP_EXTENSION and fp8 and not keep_fp8_weight_transpose_cache: - assert weightmat._transpose is None or weightmat._transpose.numel() == 0, "Expected _transpose to be None or an empty tensor when transpose cache is disabled." + assert ( + weightmat._transpose is None or weightmat._transpose.numel() == 0 + ), "Expected _transpose to be None or an empty tensor when transpose cache is disabled." nvtx_range_push(f"{nvtx_label}.gemm") gemm_out, *_, reduce_scatter_out = general_gemm( diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 25c2ff7f3..ceb1cf24b 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -17,6 +17,7 @@ from typing import Callable, List, Optional, Dict, Any, Tuple, Union import torch +from torch.utils.cpp_extension import IS_HIP_EXTENSION import transformer_engine_torch as tex from transformer_engine.common.recipe import ( Recipe, @@ -30,12 +31,9 @@ ) from .constants import dist_group_type -from .utils import get_device_compute_capability +from .utils import get_device_compute_capability, get_torch_float8_e4m3_type, get_torch_float8_e5m2_type from .jit import jit_fuser -from torch.utils.cpp_extension import IS_HIP_EXTENSION -from .utils import get_torch_float8_e4m3_type, get_torch_float8_e5m2_type - __all__ = [ "autocast", "quantized_model_init", @@ -55,8 +53,7 @@ def check_fp8_support() -> Tuple[bool, str]: gpu_arch = get_device_compute_capability() if gpu_arch in ((9, 4), (9, 5)): return True, "" - else: - return False, "Device arch gfx94x or gfx95x required for FP8 execution." + return False, "Device arch gfx94x or gfx95x required for FP8 execution." if get_device_compute_capability() >= (9, 0): # hopper and above return True, "" if get_device_compute_capability() < (8, 9): # pre-ada @@ -87,9 +84,9 @@ def check_mxfp8_support() -> Tuple[bool, str]: @functools.lru_cache(maxsize=None) def check_nvfp4_support() -> Tuple[bool, str]: + """Return whether NVFP4 support is available.""" if IS_HIP_EXTENSION: return False, "ROCm TE currently not supporting NVFP4" - """Return if nvfp4 support is available""" if get_device_compute_capability() >= (10, 0): # blackwell and above return True, "" return False, "Device compute capability 10.0 or higher required for NVFP4 execution." 
diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 5c1d7290e..cf9bdc6a5 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -7,12 +7,11 @@ """Pure Python base classes for quantization.""" from __future__ import annotations -from torch.utils.cpp_extension import IS_HIP_EXTENSION -import os -from typing import Optional, Tuple, Iterable, Any, Dict, Union + import abc -import warnings import math +import warnings +from typing import Any, Dict, Iterable, Optional, Tuple, Union import torch from torch.utils._pytree import tree_map diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index a80dcacb8..c76f8d32c 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -130,8 +130,7 @@ def run(self): return super().run() if FORCE_BUILD: - super().run() - return + return super().run() wheel_url, wheel_filename = get_wheel_url() print("Guessing wheel URL: ", wheel_url) @@ -150,12 +149,11 @@ def run(self): wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") print("Raw wheel path", wheel_path) os.rename(wheel_filename, wheel_path) - return + return None except (urllib.error.HTTPError, urllib.error.URLError): print("Precompiled wheel not found. Building from source...") # If the wheel could not be downloaded, build from source - super().run() - return + return super().run() if __name__ == "__main__": diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 8ea81d912..dfbce4aca 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -120,12 +120,12 @@ def update_quantized( def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" if IS_HIP_EXTENSION: - from ..triton_kernels.cast import te_quantize_triton - use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) ) + use_cast_transpose_triton = bool( + int(os.environ.get("NVTE_USE_CAST_TRANSPOSE_TRITON", "0")) + ) quantize_func = te_quantize_triton if use_cast_transpose_triton else tex.quantize return quantize_func(tensor, self) - else: - return tex.quantize(tensor, self) + return tex.quantize(tensor, self) def make_empty( self, @@ -348,12 +348,12 @@ def update_quantized( def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" if IS_HIP_EXTENSION: - from ..triton_kernels.cast import te_quantize_triton - use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) ) + use_cast_transpose_triton = bool( + int(os.environ.get("NVTE_USE_CAST_TRANSPOSE_TRITON", "0")) + ) quantize_func = te_quantize_triton if use_cast_transpose_triton else tex.quantize return quantize_func(tensor, self) - else: - return tex.quantize(tensor, self) + return tex.quantize(tensor, self) def make_empty( self, diff --git a/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py b/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py index 763fd4419..b17ce4ce6 100644 --- a/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py +++ b/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py @@ -2,13 +2,17 @@ # Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # See LICENSE for license information. 
-from typing import Any, Optional, Tuple +"""FSDP2 all-gather wrapper tensor for FP8/MXFP8 parameter transport.""" + +from typing import Any, Optional, Tuple, cast + import torch -import torch.nn as nn -from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from torch import nn import torch.utils._pytree as pytree +from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer + _ops_to_preserve_subclass = { torch.ops.aten.empty_like.default, torch.ops.aten.new_zeros.default, @@ -23,8 +27,8 @@ } -# A wrapper subclass for stateful FSDP transport class FSDPAGTensor(torch.Tensor): + """Tensor subclass carrying FSDP metadata for quantized all-gather.""" @staticmethod def __new__(cls, elem: torch.Tensor, **kwargs): @@ -48,7 +52,7 @@ def __init__( fp8_meta_index: str, keep_fp8_weight_transpose_cache: bool, ): - #The underlying tensor + # The underlying tensor self._data = tensor # Where quantizers are present self._module = module @@ -60,37 +64,37 @@ def __init__( @property def data(self) -> torch.Tensor: return self._data.detach() - + def __repr__(self): - return ( - f"FSDPAGTensor(" - f"elem={self._data}, " - f"module={self._module.__class__.__name__}, " - f"fp8_meta_index={self._fp8_meta_index})" - ) - + """String form for debugging.""" + return ( + f"FSDPAGTensor(" + f"elem={self._data}, " + f"module={self._module.__class__.__name__}, " + f"fp8_meta_index={self._fp8_meta_index})" + ) + def __tensor_flatten__(self): - """ - Makes some ops (view/as_strided, etc.) and serialization friendlier for wrapper subclasses. - Return (names_of_inner_tensors, flatten_spec_metadata). - """ - # We only carry the one inner tensor. - # We store (module, fp8_meta_index, keep_fp8_weight_transpose_cache) as metadata to reconstruct. - return ["_data"], (self._module, self._fp8_meta_index, self._keep_fp8_weight_transpose_cache) - - + """ + Makes some ops (view/as_strided, etc.) and serialization friendlier for wrapper subclasses. + Return (names_of_inner_tensors, flatten_spec_metadata). + """ + # We only carry the one inner tensor. + # We store (module, fp8_meta_index, keep_fp8_weight_transpose_cache) as metadata to reconstruct. 
+ return ["_data"], (self._module, self._fp8_meta_index, self._keep_fp8_weight_transpose_cache) + @staticmethod - def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): + def __tensor_unflatten__(inner_tensors, flatten_spec, _outer_size, _outer_stride): module, fp8_meta_index, keep_fp8_weight_transpose_cache = flatten_spec return FSDPAGTensor( inner_tensors["_data"], module=module, fp8_meta_index=fp8_meta_index, - keep_fp8_weight_transpose_cache=keep_fp8_weight_transpose_cache + keep_fp8_weight_transpose_cache=keep_fp8_weight_transpose_cache, ) @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): + def __torch_dispatch__(cls, func, _types, args, kwargs=None): if kwargs is None: kwargs = {} @@ -99,7 +103,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): t = args[0] assert isinstance(t, cls), f"Unexpected detach input type: {type(t)}" detached = t._data.detach() - return cls(detached, module=t._module, fp8_meta_index=t._fp8_meta_index, keep_fp8_weight_transpose_cache=t._keep_fp8_weight_transpose_cache) + return cls( + detached, + module=t._module, + fp8_meta_index=t._fp8_meta_index, + keep_fp8_weight_transpose_cache=t._keep_fp8_weight_transpose_cache, + ) # Unwrap only our subclass; capture shared metadata for rewrapping meta: Optional[tuple[nn.Module, str, bool]] = None @@ -121,24 +130,31 @@ def unwrap(x): if func not in _ops_to_preserve_subclass or meta is None: return out + narrowed_meta = cast(tuple[nn.Module, str, bool], meta) + def rewrap(x): if isinstance(x, torch.Tensor): - mod, idx, keep_transpose = meta - return cls(x, module=mod, fp8_meta_index=idx, keep_fp8_weight_transpose_cache=keep_transpose) + return cls( + x, + module=narrowed_meta[0], + fp8_meta_index=narrowed_meta[1], + keep_fp8_weight_transpose_cache=narrowed_meta[2], + ) return x out = pytree.tree_map_only(torch.Tensor, rewrap, out) return out # Must return (list_of_tensors_to_all_gather, user_metadata) - def fsdp_pre_all_gather(self, mesh): + def fsdp_pre_all_gather(self, _mesh): + """Return sharded FP8/MXFP8 pieces and metadata for FSDP all-gather.""" # If metadata isn't initialized yet, we can't access the quantizers if not self._module.fp8: - module_class_name = self._module.__class__.__name__ - if "LayerNormMLP" in module_class_name: - num_gemms = 2 - else: # Linear, LayerNormLinear, etc. - num_gemms = 1 + module_class_name = self._module.__class__.__name__ + if "LayerNormMLP" in module_class_name: + num_gemms = 2 + else: # Linear, LayerNormLinear, etc. 
+ num_gemms = 1 self._module.init_fp8_metadata(num_gemms=num_gemms) if not self._module.fp8: @@ -153,13 +169,34 @@ def fsdp_pre_all_gather(self, mesh): quantizer.with_amax_reduction = True sharded_fp8_tensor = quantizer(base) if isinstance(quantizer, MXFP8Quantizer): - rowwise_data = sharded_fp8_tensor._rowwise_data if quantizer.rowwise_usage else torch.empty(0, dtype=torch.uint8, device=base.device) - rowwise_scale_inv = sharded_fp8_tensor._rowwise_scale_inv if quantizer.rowwise_usage else torch.empty(0, dtype=torch.uint8, device=base.device) - columnwise_data = sharded_fp8_tensor._columnwise_data if quantizer.columnwise_usage else torch.empty(0, dtype=torch.uint8, device=base.device) - columnwise_scale_inv = sharded_fp8_tensor._columnwise_scale_inv if quantizer.columnwise_usage else torch.empty(0, dtype=torch.uint8, device=base.device) - return (rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, ), (base.requires_grad,) + rowwise_data = ( + sharded_fp8_tensor._rowwise_data + if quantizer.rowwise_usage + else torch.empty(0, dtype=torch.uint8, device=base.device) + ) + rowwise_scale_inv = ( + sharded_fp8_tensor._rowwise_scale_inv + if quantizer.rowwise_usage + else torch.empty(0, dtype=torch.uint8, device=base.device) + ) + columnwise_data = ( + sharded_fp8_tensor._columnwise_data + if quantizer.columnwise_usage + else torch.empty(0, dtype=torch.uint8, device=base.device) + ) + columnwise_scale_inv = ( + sharded_fp8_tensor._columnwise_scale_inv + if quantizer.columnwise_usage + else torch.empty(0, dtype=torch.uint8, device=base.device) + ) + return ( + rowwise_data, + rowwise_scale_inv, + columnwise_data, + columnwise_scale_inv, + ), (base.requires_grad,) return (sharded_fp8_tensor._data,), (base.requires_grad,) - + def fsdp_post_all_gather( self, all_gather_outputs: Tuple[torch.Tensor, ...], @@ -168,17 +205,26 @@ def fsdp_post_all_gather( *, out: Optional[torch.Tensor] = None, ): - (requires_grad, ) = metadata + """Reconstruct the full-quantized parameter after all-gather.""" + requires_grad = metadata[0] if not self._module.fp8: (data,) = all_gather_outputs return data, all_gather_outputs # Retrieve the same quantizer you used in pre_all_gather quantizer = self._module.quantizers["scaling_fwd"][self._fp8_meta_index] shape = None - if not isinstance(quantizer, MXFP8Quantizer) and not self._keep_fp8_weight_transpose_cache: + if ( + not isinstance(quantizer, MXFP8Quantizer) + and not self._keep_fp8_weight_transpose_cache + ): quantizer.set_usage(columnwise=False) if isinstance(quantizer, MXFP8Quantizer): - (rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv,) = all_gather_outputs + ( + rowwise_data, + rowwise_scale_inv, + columnwise_data, + columnwise_scale_inv, + ) = all_gather_outputs shape = rowwise_data.shape else: (data,) = all_gather_outputs @@ -186,13 +232,13 @@ def fsdp_post_all_gather( # Construct a new low precision tensor subclass that will wrap the gathered data if out is None: - out = quantizer.make_empty(shape = shape, dtype=param_dtype, requires_grad=requires_grad) + out = quantizer.make_empty(shape=shape, dtype=param_dtype, requires_grad=requires_grad) if isinstance(quantizer, MXFP8Quantizer): out._rowwise_data = rowwise_data - out._rowwise_scale_inv = rowwise_scale_inv + out._rowwise_scale_inv = rowwise_scale_inv out._columnwise_data = None if columnwise_data.numel() == 0 else columnwise_data - out._columnwise_scale_inv = None if columnwise_scale_inv.numel() == 0 else columnwise_scale_inv + out._columnwise_scale_inv = None if 
columnwise_scale_inv.numel() == 0 else columnwise_scale_inv else: out._scale_inv = 1 / quantizer.scale out._data = data diff --git a/transformer_engine/pytorch/tensor/mxfp4_tensor.py b/transformer_engine/pytorch/tensor/mxfp4_tensor.py index adbe9802b..8e85746a9 100644 --- a/transformer_engine/pytorch/tensor/mxfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp4_tensor.py @@ -124,7 +124,7 @@ def make_empty( # Allocate FP4 data: [M, K/2] rowwise_data = torch.empty(M, K // 2, dtype=torch.uint8, device=device) - + # Allocate PADDED scale tensors for shuffle compatibility rowwise_scale_K = math.ceil(K / MXFP4_BLOCK_SCALING_SIZE) rowwise_scale_inv = torch.zeros( @@ -252,7 +252,7 @@ def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: Construct plain PyTorch tensor from MXFP4Tensor By default the resulting tensor's dtype is the MXFP4Tensor's nominal dtype. - + Note: For MXFP4 forward-only training, this is typically not needed as backward pass uses high-precision activations. """ diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index bd3d93e9f..c1e86d28a 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -98,12 +98,12 @@ def update_quantized( def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" if IS_HIP_EXTENSION: - from ..triton_kernels.cast import te_quantize_triton - use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) ) + use_cast_transpose_triton = bool( + int(os.environ.get("NVTE_USE_CAST_TRANSPOSE_TRITON", "0")) + ) quantize_func = te_quantize_triton if use_cast_transpose_triton else tex.quantize return quantize_func(tensor, self) - else: - return tex.quantize(tensor, self) + return tex.quantize(tensor, self) def is_quantizable(self, inp: torch.Tensor) -> bool: """Returns whether or not given inp can be quantized""" @@ -145,7 +145,7 @@ def make_empty( # ROCm TE does not implement fuse padding zeros so use zero tensor here if IS_HIP_EXTENSION: scale_inv = torch.zeros( - math.prod(shape[:-1]), + math.prod(shape[:-1]), math.ceil(shape[-1] / MXFP8_BLOCK_SCALING_SIZE), dtype=torch.uint8, device=device, diff --git a/transformer_engine/pytorch/tensor/storage/mxfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp4_tensor_storage.py index 9e3aa3fdd..e57b69e1f 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp4_tensor_storage.py @@ -29,7 +29,7 @@ def forward( if hasattr(tensor, '_data') and tensor._data is not None: # Return cached high-precision data (used during model initialization/teardown) return tensor._data.to(dtype) if tensor._data.dtype != dtype else tensor._data - + # If no cached data, we would need to dequantize from rowwise FP4 data # This path should not be hit in forward-only MXFP4 training @@ -77,7 +77,7 @@ def __new__( columnwise_data: Optional[torch.Tensor], columnwise_scale_inv: torch.Tensor, fp4_dtype: TE_DType, - quantizer: Optional[Quantizer] = None, + quantizer: Optional[Quantizer], *args, **kwargs, ): @@ -175,7 +175,7 @@ def update_usage( ): """ Update the usage of the MXFP4TensorStorage. 
- + """ # Default usage is based on available data diff --git a/transformer_engine/pytorch/triton_kernels/cast.py b/transformer_engine/pytorch/triton_kernels/cast.py index 4404ac00b..0320ab8d8 100644 --- a/transformer_engine/pytorch/triton_kernels/cast.py +++ b/transformer_engine/pytorch/triton_kernels/cast.py @@ -2,11 +2,10 @@ # License for AMD contributions = MIT. See LICENSE for more information """Python interface for cast extensions""" -import os -from typing import Iterable, List, Optional, Tuple, Union +from typing import Optional + import functools import torch -import warnings from ..utils import is_non_tn_fp8_gemm_supported @@ -22,36 +21,34 @@ def _empty_tensor() -> torch.Tensor: """Get tensor with no entries and no data""" return torch.Tensor().cuda() -def _setup_conditional_transpose_storage( - tensor: QuantizedTensor, - ) -> QuantizedTensor: - shape = tensor.shape - quantizer = tensor._get_quantizer() - - # Allocate FP8 data transpose if needed - data_transpose = None - create_transpose = quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported(); - if quantizer.columnwise_usage and create_transpose: - if tensor.ndim == 0: - # If the original tensor is a scalar, its transpose is also a scalar. - data_transpose = torch.empty((), dtype=torch.uint8, device=tensor.device) - else: - transposed_shape = (shape[-1],) + shape[:-1] - data_transpose = torch.empty( - transposed_shape, - dtype=torch.uint8, - device=tensor.device, - ) +def _setup_conditional_transpose_storage(tensor: QuantizedTensor) -> None: + shape = tensor.shape + quantizer = tensor._get_quantizer() + + # Allocate FP8 data transpose if needed + data_transpose = None + create_transpose = quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported() + if quantizer.columnwise_usage and create_transpose: + if tensor.ndim == 0: + # If the original tensor is a scalar, its transpose is also a scalar. + data_transpose = torch.empty((), dtype=torch.uint8, device=tensor.device) + else: + transposed_shape = (shape[-1],) + shape[:-1] + data_transpose = torch.empty( + transposed_shape, + dtype=torch.uint8, + device=tensor.device, + ) - # Construct FP8 tensor - tensor._transpose = data_transpose - tensor._transpose_invalid = tensor._transpose is None + # Construct FP8 tensor + tensor._transpose = data_transpose + tensor._transpose_invalid = tensor._transpose is None def te_quantize_triton( tensor: torch.Tensor, quantizer: Quantizer, output: Optional[torch.Tensor] = None, - noop_flag: torch.Tensor = None + noop_flag: torch.Tensor = None ) -> torch.Tensor: """ Quantizes the input tensor using a specified quantizer, @@ -61,7 +58,7 @@ def te_quantize_triton( fake_tensor_type = input_tensor.dtype if not fake_tensor_type.is_floating_point: fake_tensor_type = torch.float32 - + out: QuantizedTensor = None if output is None: assert quantizer is not None, "Quantizer object cannot be None. Please provide a valid quantizer." 
@@ -78,7 +75,7 @@ def te_quantize_triton( else: # Create a QuantizedTensor from the provided output tensor out = output - + # Construct no-op flag if needed if noop_flag is None: noop_flag = _empty_tensor() @@ -86,7 +83,7 @@ def te_quantize_triton( if (isinstance(out, MXFP8TensorStorage) and out._rowwise_data is None and out._columnwise_data is None) or (not isinstance(out, MXFP8TensorStorage) and out.size().numel() == 0): # Return empty output if the quantized tensor has no elements return out - + if isinstance(out, Float8TensorStorage): if input_tensor.nelement() > 0: if not out._transpose_invalid: @@ -114,7 +111,7 @@ def te_quantize_triton( eps = getattr(quantizer, "amax_epsilon", 0.0), force_pow_2_scales = getattr(quantizer, "force_pow_2_scales", False), ) - + else: out.remove_caches() #Make sure to remove transpose if it is marked as invalid out = tex.quantize(input_tensor, quantizer, out, noop_flag) @@ -130,7 +127,6 @@ def te_quantize_triton( def te_dequantize_triton(input, dtype: tex.DType): if isinstance(input, MXFP8TensorStorage): return te_dequantize_mxfp8_triton(input, dtype) - elif isinstance(input, Float8TensorStorage): + if isinstance(input, Float8TensorStorage): return tex.dequantize(input, dtype) - else: - raise NotImplementedError(f"Not implemented for tensor type: '{type(input).__name__}'") + raise NotImplementedError(f"Not implemented for tensor type: '{type(input).__name__}'") diff --git a/transformer_engine/pytorch/triton_kernels/cast_transpose.py b/transformer_engine/pytorch/triton_kernels/cast_transpose.py index ace2a49ae..eec01a62c 100644 --- a/transformer_engine/pytorch/triton_kernels/cast_transpose.py +++ b/transformer_engine/pytorch/triton_kernels/cast_transpose.py @@ -4,7 +4,6 @@ import torch from ..constants import MXFP8_BLOCK_SCALING_SIZE -import transformer_engine_torch as tex import triton import triton.language as tl from .common import ( @@ -77,6 +76,7 @@ def _compute_scale_from_amax_triton( a = tl.where(a < epsilon, epsilon, a) # bad amax (NaN, inf, 0.0) -> scale = 1.0 + # pylint: disable-next=comparison-with-itself bad = (a != a) | (tl.abs(a) == float('inf')) | (a == 0.0) if bad: @@ -112,13 +112,13 @@ def _cast_transpose_triton(A, noop_ptr, C, T, stride_am, stride_an, stride_bn, s grid_m = (M + BLOCK_M - 1) // BLOCK_M grid_n = (N + BLOCK_N - 1) // BLOCK_N - + width = GROUP_M * grid_n group_id = pid // width group_size = min(grid_m - group_id * GROUP_M, GROUP_M) pid_m = group_id * GROUP_M + (pid % group_size) pid_n = (pid % width) // group_size - + rm = pid_m.to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n.to(tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N) A = A + rm[:, None] * stride_am + rn[None, :] * stride_an @@ -131,7 +131,7 @@ def _cast_transpose_triton(A, noop_ptr, C, T, stride_am, stride_an, stride_bn, s fp8_a = scaled_a.to(C.type.element_ty) C = C + rm[:, None] * stride_am + rn[None, :] * stride_an tl.store(C, fp8_a, mask=mask) - + # rematerialize to save registers rm = pid_m.to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n.to(tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N) @@ -271,6 +271,7 @@ def _amax_reduce_and_compute_scale_triton( a = tl.where(a < epsilon, epsilon, a) # bad amax (NaN, inf, 0.0) -> scale = 1.0 + # pylint: disable-next=comparison-with-itself bad = (a != a) | (tl.abs(a) == float('inf')) | (a == 0.0) if bad: @@ -298,7 +299,7 @@ def exp2f_rcp_triton(biased_exp: tl.uint8) -> tl.float32: @triton.jit def float_to_e8m0_triton(val: tl.float32) -> tl.uint8: - is_nan = (val != val) + is_nan = val != val # pylint: 
disable=comparison-with-itself is_inf = (tl.abs(val) == float('inf')) is_zero = val == 0.0 @@ -306,22 +307,22 @@ def float_to_e8m0_triton(val: tl.float32) -> tl.uint8: val_u32 = tl.cast(val, tl.uint32, bitcast=True) # Extract exponent and mantissa - exponent_raw = (val_u32 >> FP32_MANTISSA_BITS) & 0xFF + exponent_raw = (val_u32 >> FP32_MANTISSA_BITS) & 0xFF mantissa = val_u32 & 0x7FFFFF # Round up exponent and deal with satfinite. # (mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000) cond1 = mantissa > 0 - cond2 = exponent_raw != 0xFE + cond2 = exponent_raw != 0xFE cond3_part1 = exponent_raw == 0 cond3_part2 = mantissa <= 0x400000 - cond3 = cond3_part1 & cond3_part2 - + cond3 = cond3_part1 & cond3_part2 + round_up_condition = (cond1 & cond2) & ~cond3 # Increment exponent if the condition is true calculated_exponent = tl.where(round_up_condition, exponent_raw + 1, exponent_raw) - + # Priority: NaN -> Inf -> Zero -> Calculated Exponent result_e8m0 = tl.where(is_nan, tl.full(val.shape, 0xFF, dtype=tl.uint8), result_e8m0) result_e8m0 = tl.where(~is_nan & is_inf, tl.full(val.shape, 0xFE, dtype=tl.uint8), result_e8m0) @@ -332,15 +333,15 @@ def float_to_e8m0_triton(val: tl.float32) -> tl.uint8: @triton.jit def _cast_transpose_triton_mxfp8( - x_ptr, rowwise_y_ptr, colwise_y_ptr, - stride_rowwise_row, stride_rowwise_col, - n_rows, n_cols, + x_ptr, rowwise_y_ptr, colwise_y_ptr, + stride_rowwise_row, stride_rowwise_col, + n_rows, n_cols, rowwise_scale_inv_ptr, stride_rowwise_scale_inv_row, stride_rowwise_scale_inv_col, rowwise_scale_M, rowwise_scale_N, colwise_scale_inv_ptr, stride_colwise_scale_inv_row, stride_colwise_scale_inv_col, colwise_scale_M, colwise_scale_N, - max_fp8: tl.constexpr, BLOCK_X: tl.constexpr, BLOCK_Y: tl.constexpr, GROUP_Y: tl.constexpr, MXFP8_BLOCK_SCALING_SIZE: tl.constexpr, USE_ROWWISE_SCALING: tl.constexpr, USE_COLWISE_SCALING: tl.constexpr): - + max_fp8: tl.constexpr, BLOCK_X: tl.constexpr, BLOCK_Y: tl.constexpr, GROUP_Y: tl.constexpr, _MXFP8_BLOCK_SCALING_SIZE: tl.constexpr, USE_ROWWISE_SCALING: tl.constexpr, USE_COLWISE_SCALING: tl.constexpr): + pid = tl.program_id(0) num_pid_along_Y = tl.cdiv(n_rows, BLOCK_Y) @@ -354,30 +355,30 @@ def _cast_transpose_triton_mxfp8( global_offset_Y_base = pid_m.to(tl.int64) * BLOCK_Y global_offset_X_base = pid_n.to(tl.int64) * BLOCK_X - - num_chunks_in_block_Y = BLOCK_Y // MXFP8_BLOCK_SCALING_SIZE - num_chunks_in_block_X = BLOCK_X // MXFP8_BLOCK_SCALING_SIZE + + num_chunks_in_block_Y = BLOCK_Y // _MXFP8_BLOCK_SCALING_SIZE + num_chunks_in_block_X = BLOCK_X // _MXFP8_BLOCK_SCALING_SIZE max_norm_rcp = 1.0 / max_fp8 for chunk_id_y in range(0, num_chunks_in_block_Y): - offsets_Y = global_offset_Y_base + chunk_id_y * MXFP8_BLOCK_SCALING_SIZE + tl.arange(0, MXFP8_BLOCK_SCALING_SIZE) + offsets_Y = global_offset_Y_base + chunk_id_y * _MXFP8_BLOCK_SCALING_SIZE + tl.arange(0, _MXFP8_BLOCK_SCALING_SIZE) for chunk_id_x in range(0, num_chunks_in_block_X): - offsets_X = global_offset_X_base + chunk_id_x * MXFP8_BLOCK_SCALING_SIZE + tl.arange(0, MXFP8_BLOCK_SCALING_SIZE) + offsets_X = global_offset_X_base + chunk_id_x * _MXFP8_BLOCK_SCALING_SIZE + tl.arange(0, _MXFP8_BLOCK_SCALING_SIZE) x_ptr_current_chunk = x_ptr + offsets_Y[:, None] * stride_rowwise_row + offsets_X[None, :] * stride_rowwise_col mask = (offsets_Y < n_rows)[:, None] & (offsets_X < n_cols)[None, :] - # (MXFP8_BLOCK_SCALING_SIZE, MXFP8_BLOCK_SCALING_SIZE) + # (_MXFP8_BLOCK_SCALING_SIZE, _MXFP8_BLOCK_SCALING_SIZE) x_chunk = tl.load(x_ptr_current_chunk, mask=mask, 
other=0.0).to(tl.float32) # Rowwise if USE_ROWWISE_SCALING: subwarp_amax_rowwise = tl.max(tl.abs(x_chunk), axis=-1, keep_dims=True) biased_exponent_rowwise = float_to_e8m0_triton(subwarp_amax_rowwise * max_norm_rcp) - + scale_offset_X = (pid_n * num_chunks_in_block_X) + chunk_id_x - rowwise_scale_inv_store_offsets = (offsets_Y[:, None] * stride_rowwise_scale_inv_row) + scale_offset_X * stride_rowwise_scale_inv_col + rowwise_scale_inv_store_offsets = (offsets_Y[:, None] * stride_rowwise_scale_inv_row) + scale_offset_X * stride_rowwise_scale_inv_col rowwise_scale_inv_store_mask = (offsets_Y < rowwise_scale_M)[:, None] & (scale_offset_X < rowwise_scale_N) tl.store(rowwise_scale_inv_ptr + rowwise_scale_inv_store_offsets, biased_exponent_rowwise, mask = rowwise_scale_inv_store_mask) - + block_inverse_scale_rowwise = exp2f_rcp_triton(biased_exponent_rowwise) y_chunk_rowwise_scaled = x_chunk * block_inverse_scale_rowwise rowwise_y_ptr_current_chunk = rowwise_y_ptr + offsets_Y[:, None] * stride_rowwise_row + offsets_X[None, :] * stride_rowwise_col @@ -389,10 +390,10 @@ def _cast_transpose_triton_mxfp8( biased_exponent_colwise = float_to_e8m0_triton(subwarp_amax_colwise * max_norm_rcp) scale_offset_Y = (pid_m * num_chunks_in_block_Y) + chunk_id_y - colwise_scale_inv_store_offsets = scale_offset_Y * stride_colwise_scale_inv_row + (offsets_X[None, :] * stride_colwise_scale_inv_col) + colwise_scale_inv_store_offsets = scale_offset_Y * stride_colwise_scale_inv_row + (offsets_X[None, :] * stride_colwise_scale_inv_col) colwise_scale_inv_store_mask = (scale_offset_Y < colwise_scale_M) & (offsets_X < colwise_scale_N)[None, :] tl.store(colwise_scale_inv_ptr + colwise_scale_inv_store_offsets, biased_exponent_colwise, mask = colwise_scale_inv_store_mask) - + block_inverse_scale_colwise = exp2f_rcp_triton(biased_exponent_colwise) y_chunk_colwise_scaled = x_chunk * block_inverse_scale_colwise colwise_y_ptr_current_chunk = colwise_y_ptr + offsets_Y[:, None] * stride_rowwise_row + offsets_X[None, :] * stride_rowwise_col @@ -401,12 +402,12 @@ def _cast_transpose_triton_mxfp8( @triton.jit def _dequantize_mxfp8_triton( x_ptr, y_ptr, - stride_row, stride_col, - n_rows, n_cols, + stride_row, stride_col, + n_rows, n_cols, scale_inv_ptr, stride_scale_inv_row, stride_scale_inv_col, scale_n_rows, scale_n_cols, - BLOCK_X: tl.constexpr, BLOCK_Y: tl.constexpr, GROUP_Y: tl.constexpr, USE_ROWWISE_SCALING: tl.constexpr, MXFP8_BLOCK_SCALING_SIZE: tl.constexpr): - + BLOCK_X: tl.constexpr, BLOCK_Y: tl.constexpr, GROUP_Y: tl.constexpr, USE_ROWWISE_SCALING: tl.constexpr, _MXFP8_BLOCK_SCALING_SIZE: tl.constexpr): + pid = tl.program_id(0) num_pid_along_Y = tl.cdiv(n_rows, BLOCK_Y) @@ -420,27 +421,27 @@ def _dequantize_mxfp8_triton( global_offset_Y_base = pid_m.to(tl.int64) * BLOCK_Y global_offset_X_base = pid_n.to(tl.int64) * BLOCK_X - - num_chunks_in_block_Y = BLOCK_Y // MXFP8_BLOCK_SCALING_SIZE - num_chunks_in_block_X = BLOCK_X // MXFP8_BLOCK_SCALING_SIZE + + num_chunks_in_block_Y = BLOCK_Y // _MXFP8_BLOCK_SCALING_SIZE + num_chunks_in_block_X = BLOCK_X // _MXFP8_BLOCK_SCALING_SIZE for chunk_id_y in range(0, num_chunks_in_block_Y): - offsets_Y = global_offset_Y_base + chunk_id_y * MXFP8_BLOCK_SCALING_SIZE + tl.arange(0, MXFP8_BLOCK_SCALING_SIZE) + offsets_Y = global_offset_Y_base + chunk_id_y * _MXFP8_BLOCK_SCALING_SIZE + tl.arange(0, _MXFP8_BLOCK_SCALING_SIZE) for chunk_id_x in range(0, num_chunks_in_block_X): - offsets_X = global_offset_X_base + chunk_id_x * MXFP8_BLOCK_SCALING_SIZE + tl.arange(0, MXFP8_BLOCK_SCALING_SIZE) + 
offsets_X = global_offset_X_base + chunk_id_x * _MXFP8_BLOCK_SCALING_SIZE + tl.arange(0, _MXFP8_BLOCK_SCALING_SIZE) x_ptr_current_chunk = x_ptr + offsets_Y[:, None] * stride_row + offsets_X[None, :] * stride_col mask = (offsets_Y < n_rows)[:, None] & (offsets_X < n_cols)[None, :] x_chunk = tl.load(x_ptr_current_chunk, mask=mask) if USE_ROWWISE_SCALING: scale_offset_X = (pid_n * num_chunks_in_block_X) + chunk_id_x - scale_inv_store_offsets = (offsets_Y[:, None] * stride_scale_inv_row) + scale_offset_X * stride_scale_inv_col + scale_inv_store_offsets = (offsets_Y[:, None] * stride_scale_inv_row) + scale_offset_X * stride_scale_inv_col scale_inv_store_mask = (offsets_Y < scale_n_rows)[:, None] & (scale_offset_X < scale_n_cols) else: scale_offset_Y = (pid_m * num_chunks_in_block_Y) + chunk_id_y - scale_inv_store_offsets = scale_offset_Y * stride_scale_inv_row + (offsets_X[None, :] * stride_scale_inv_col) + scale_inv_store_offsets = scale_offset_Y * stride_scale_inv_row + (offsets_X[None, :] * stride_scale_inv_col) scale_inv_store_mask = (scale_offset_Y < scale_n_rows) & (offsets_X < scale_n_cols)[None, :] - + biased_exponent = tl.load(scale_inv_ptr + scale_inv_store_offsets, mask=scale_inv_store_mask, other=127) block_scale = tl.exp2(biased_exponent.to(tl.float32) - 127) y_chunk_scaled = x_chunk.to(tl.float32) * block_scale @@ -625,7 +626,7 @@ def _cast_transpose_triton_mxfp4( i3 = scale_offset_x // 8 i4 = (scale_offset_x % 8) // 4 i5 = scale_offset_x % 4 - + # rowwise_scale_N_pad is already (N/32) rounded up to multiple of 8 bs_offs = ( i0 * (rowwise_scale_N_pad // 8 * 256) + @@ -704,7 +705,7 @@ def _cast_transpose_triton_mxfp4( i3 = scale_chunk // 8 i4 = (scale_chunk % 8) // 4 i5 = scale_chunk % 4 - + # colwise_scale_N_pad is already (M/32) rounded up to multiple of 8 bs_offs = ( i0 * (colwise_scale_N_pad // 8 * 256) + @@ -735,7 +736,7 @@ def _cast_transpose_triton_mxfp4( mask=scale_mask, ) -# Reshapes input of any given shape to 2D for processing, +# Reshapes input of any given shape to 2D for processing, # then uses the Triton kernel to perform casting and transposition efficiently. 
def te_cast_transpose_noop_triton(input, noop_flag, input_scale, cast_out, trans_out, amax_out, scale_inv_out, otype, current_scaling, eps, force_pow_2_scales): @@ -750,14 +751,11 @@ def te_cast_transpose_noop_triton(input, noop_flag, input_scale, cast_out, trans trans_out_stride_M = trans_out_2d_view.stride(0) trans_out_stride_N = trans_out_2d_view.stride(1) - + tl_dtype = te_dtype_to_triton_dtype(otype) - - if noop_flag.nelement() > 0: - use_noop = True - else: - use_noop = False - + + use_noop = noop_flag.nelement() > 0 + grid = lambda META: (triton.cdiv(num_rows, META['BLOCK_M']) * triton.cdiv(row_length, META['BLOCK_N']),) if current_scaling: @@ -820,7 +818,7 @@ def te_cast_transpose_noop_triton(input, noop_flag, input_scale, cast_out, trans # Delayed scaling _cast_transpose_triton[grid](input_2d_view, noop_flag, triton.reinterpret(cast_out_2d_view, tl_dtype), triton.reinterpret(trans_out_2d_view, tl_dtype), input_stride_M, input_stride_N, trans_out_stride_M, trans_out_stride_N, num_rows, row_length, input_scale, amax_out, scale_inv_out, get_fp8_max(otype), use_noop) -def te_cast_transpose_mxfp8_triton(input, out, noop_flag=None): +def te_cast_transpose_mxfp8_triton(input, out, noop_flag=None): # pylint: disable=unused-argument row_length = input.shape[-1] if len(input.shape) > 0 else 1 num_rows = input.numel() // row_length input_2d_view = input.reshape(num_rows, row_length) @@ -828,10 +826,10 @@ def te_cast_transpose_mxfp8_triton(input, out, noop_flag=None): USE_ROWWISE_SCALING = out_metadata["rowwise_data"] is not None USE_COLWISE_SCALING = out_metadata["columnwise_data"] is not None - + fp8_dtype = out_metadata["fp8_dtype"] tl_dtype = te_dtype_to_triton_dtype(fp8_dtype) - + rowwise_y_ptr, rowwise_scale_inv_ptr = None, None rowwise_scale_M, rowwise_scale_N = 1, 1 rowwise_scale_stride_M, rowwise_scale_stride_N = 1, 1 @@ -841,7 +839,7 @@ def te_cast_transpose_mxfp8_triton(input, out, noop_flag=None): rowwise_scale_inv_ptr = out_metadata["rowwise_scale_inv"] rowwise_scale_M, rowwise_scale_N = rowwise_scale_inv_ptr.shape rowwise_scale_stride_M, rowwise_scale_stride_N = rowwise_scale_inv_ptr.stride(0), rowwise_scale_inv_ptr.stride(1) - + colwise_y_ptr, colwise_scale_inv_ptr = None, None colwise_scale_M, colwise_scale_N = 1, 1 colwise_scale_stride_M, colwise_scale_stride_N = 1, 1 @@ -851,17 +849,17 @@ def te_cast_transpose_mxfp8_triton(input, out, noop_flag=None): colwise_scale_inv_ptr = out_metadata["columnwise_scale_inv"] colwise_scale_M, colwise_scale_N = colwise_scale_inv_ptr.shape colwise_scale_stride_M, colwise_scale_stride_N = colwise_scale_inv_ptr.stride(0), colwise_scale_inv_ptr.stride(1) - - + + BLOCK_X = 64 BLOCK_Y = 64 GROUP_Y = MXFP8_BLOCK_SCALING_SIZE max_fp8 = get_fp8_max(fp8_dtype) grid = lambda META: (triton.cdiv(num_rows, META['BLOCK_Y']) * triton.cdiv(row_length, META['BLOCK_X']),) _cast_transpose_triton_mxfp8[grid]( - input_2d_view, rowwise_y_ptr, colwise_y_ptr, - input_2d_view.stride(0), input_2d_view.stride(1), - num_rows, row_length, + input_2d_view, rowwise_y_ptr, colwise_y_ptr, + input_2d_view.stride(0), input_2d_view.stride(1), + num_rows, row_length, rowwise_scale_inv_ptr, rowwise_scale_stride_M, rowwise_scale_stride_N, rowwise_scale_M, rowwise_scale_N, colwise_scale_inv_ptr, colwise_scale_stride_M, colwise_scale_stride_N, @@ -873,7 +871,7 @@ def te_dequantize_mxfp8_triton(input, dtype): use_rowwise_scaling = input_metadata["rowwise_data"] is not None x_ptr = None scale_inv_ptr = None - + if use_rowwise_scaling: x_ptr = input_metadata["rowwise_data"] row_length 
= x_ptr.shape[-1] if len(x_ptr.shape) > 0 else 1 @@ -886,7 +884,7 @@ def te_dequantize_mxfp8_triton(input, dtype): num_rows = x_ptr.numel() // row_length x_ptr = x_ptr.reshape(num_rows, row_length) scale_inv_ptr = input_metadata["columnwise_scale_inv"] - + fp8_dtype = input_metadata["fp8_dtype"] scale_M, scale_N = scale_inv_ptr.shape dtype = te_dtype_to_torch_dtype(dtype) @@ -900,15 +898,15 @@ def te_dequantize_mxfp8_triton(input, dtype): grid = lambda META: (triton.cdiv(num_rows, META['BLOCK_Y']) * triton.cdiv(row_length, META['BLOCK_X']),) _dequantize_mxfp8_triton[grid]( triton.reinterpret(x_ptr, tl_dtype), out, - x_ptr.stride(0), x_ptr.stride(1), - num_rows, row_length, + x_ptr.stride(0), x_ptr.stride(1), + num_rows, row_length, scale_inv_ptr, scale_inv_ptr.stride(0), scale_inv_ptr.stride(1), scale_M, scale_N, BLOCK_X, BLOCK_Y, GROUP_Y, use_rowwise_scaling, MXFP8_BLOCK_SCALING_SIZE) return out -def te_cast_transpose_mxfp4_triton(input, out, noop_flag=None): +def te_cast_transpose_mxfp4_triton(input, out, noop_flag=None): # pylint: disable=unused-argument # Reshape input to 2D: (M, N) logical N = input.shape[-1] if len(input.shape) > 0 else 1 M = input.numel() // N @@ -920,7 +918,7 @@ def te_cast_transpose_mxfp4_triton(input, out, noop_flag=None): USE_COLWISE_SCALING = out_metadata["columnwise_data"] is not None SHUFFLE_ROWWISE_SCALING = shuffle_B_matrix_for_aiter and USE_ROWWISE_SCALING - SHUFFLE_COLWISE_SCALING = shuffle_B_matrix_for_aiter and USE_COLWISE_SCALING + SHUFFLE_COLWISE_SCALING = shuffle_B_matrix_for_aiter and USE_COLWISE_SCALING # pylint: disable=unused-variable MXFP4_BLOCK_SIZE = 32 BLOCK_M = 128 @@ -1018,13 +1016,13 @@ def _transpose_triton_dbias(A, C, T, stride_am, stride_an, stride_bn, stride_bm, grid_m = (M + BLOCK_M - 1) // BLOCK_M grid_n = (N + BLOCK_N - 1) // BLOCK_N - + width = GROUP_M * grid_n group_id = pid // width group_size = min(grid_m - group_id * GROUP_M, GROUP_M) pid_m = group_id * GROUP_M + (pid % group_size) pid_n = (pid % width) // group_size - + rm = pid_m.to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n.to(tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N) A = A + rm[:, None] * stride_am + rn[None, :] * stride_an @@ -1034,14 +1032,14 @@ def _transpose_triton_dbias(A, C, T, stride_am, stride_an, stride_bn, stride_bm, partial_sum_a = tl.sum(a, axis=0) partial_dbias = partial_dbias + pid_m.to(tl.int64) * stride_am + rn * stride_an - tl.store(partial_dbias, partial_sum_a, mask=(rn 0, f"Number of bias_grad rows K must be positive (K = {K})." assert G > 0, f"Number of groups G must be positive (G = {G})." return torch.empty((G, K), device=device, dtype=torch.float32) - else: - # Return dummy pointer when bias_grad is not needed. - # Must be float32 because atomic_add does not support bf16/fp16, - # and Triton validates the pointer dtype even in dead branches. - return torch.tensor([], device=device, dtype=torch.float32) + # Return dummy pointer when bias_grad is not needed. + # Must be float32 because atomic_add does not support bf16/fp16, + # and Triton validates the pointer dtype even in dead branches. 
+ return torch.tensor([], device=device, dtype=torch.float32) def gen_tgmm_tensors( @@ -520,7 +519,7 @@ def gen_tgmm_tensors( input_type: torch.dtype = DTYPE, output_type: torch.dtype = DTYPE, trans_lhs: bool = TRANS_LHS, - trans_rhs: bool = False, + trans_rhs: bool = False, # pylint: disable=unused-argument rng_seed: int | None = RNG_SEED, unif_group_sizes: bool = False, use_bias: bool = False, @@ -666,8 +665,7 @@ def get_tgmm_bias_grad( return existing_bias_grad - else: - return gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=False) + return gen_tgmm_bias_grad(K, G, device=device, with_bias_grad=False) def get_tgmm_transposition(lhs: Tensor, rhs: Tensor, out: Tensor) -> tuple[bool, int]: diff --git a/transformer_engine/pytorch/triton_kernels/gmm/gmm_kernels.py b/transformer_engine/pytorch/triton_kernels/gmm/gmm_kernels.py index e26d55140..4c91d4991 100644 --- a/transformer_engine/pytorch/triton_kernels/gmm/gmm_kernels.py +++ b/transformer_engine/pytorch/triton_kernels/gmm/gmm_kernels.py @@ -49,11 +49,11 @@ def get_config( assert os.path.exists(config_filename) and os.path.isfile( config_filename ), f"'{config_filename}' isn't an existent file." - with open(config_filename, "r") as config_file: + with open(config_filename, "r", encoding="utf-8") as config_file: get_config._config_dict = json.load(config_file) assert all( gmm_type in get_config._config_dict - for gmm_type in {"gmm", "ptgmm", "nptgmm"} + for gmm_type in ("gmm", "ptgmm", "nptgmm") ), "Not all GMM variants are present in the configuration file." # Heuristic-based config selection for gmm @@ -61,7 +61,7 @@ def get_config( if fwd: k_n_ratio = K / N if N > 0 else 1.0 n_k_ratio = N / K if K > 0 else 1.0 - + # Prioritize small shapes first (before ratio checks) if M < 10000 and (N <= 2048 or K <= 2048): key = "tiny_shapes" @@ -85,7 +85,7 @@ def get_config( key = "balanced_large_n" else: key = "default" - + bwd = gmm_type == "gmm" and not trans_rhs if bwd: k_n_ratio = K / N if N > 0 else 1.0 @@ -109,7 +109,7 @@ def get_config( elif K < 5000 and N > 10000: key = "small_k_large_n_bwd" else: - key = "default" + key = "default" # Heuristic-based config selection for ptgmm elif gmm_type == "ptgmm": @@ -153,15 +153,15 @@ def get_config( assert ( key in get_config._config_dict[gmm_type] ), f"Configuration key '{key}' is absent for {gmm_type}." - + config = get_config._config_dict[gmm_type][key].copy() - + # Adapt block sizes to fit within hardware shared memory limits if dtype == torch.float32: config["BLOCK_SIZE_M"] = max(1, config["BLOCK_SIZE_M"] // 2) config["BLOCK_SIZE_K"] = max(1, config["BLOCK_SIZE_K"] // 2) config["BLOCK_SIZE_N"] = max(1, config["BLOCK_SIZE_N"] // 2) - + return config # Common code shared by GMM and TGMM kernels. @@ -255,7 +255,7 @@ def gmm_kernel( tl.device_assert(num_tiles >= 0, "num_tiles < 0") # Loop through tiles of current MM problem. - while tile >= last_mm_tile and tile < last_mm_tile + num_tiles: + while last_mm_tile <= tile < last_mm_tile + num_tiles: # Figure out tile coordinates in current MM problem. tile_in_mm = tile - last_mm_tile tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") @@ -424,7 +424,7 @@ def tgmm_persistent_kernel( tl.device_assert(m >= 0, "m < 0") # Loop through tiles of current MM problem. - while tile >= last_mm_tile and tile < last_mm_tile + num_tiles: + while last_mm_tile <= tile < last_mm_tile + num_tiles: # Figure out tile coordinates in current MM problem. 
tile_in_mm = tile - last_mm_tile tl.device_assert(tile_in_mm >= 0, "tile_in_mm < 0") diff --git a/transformer_engine/pytorch/triton_kernels/gmm/gmm_wrapper.py b/transformer_engine/pytorch/triton_kernels/gmm/gmm_wrapper.py index 5b8a5acc9..e613e6b75 100644 --- a/transformer_engine/pytorch/triton_kernels/gmm/gmm_wrapper.py +++ b/transformer_engine/pytorch/triton_kernels/gmm/gmm_wrapper.py @@ -56,7 +56,7 @@ def _gmm_grid( math.ceil(gs / block_size_m) for gs in group_sizes_list ) * num_n_tiles num_programs = min(grid_dim, num_tiles) - + return (num_programs,) @@ -163,13 +163,13 @@ def gmm( if key.startswith("BLOCK_SIZE_") else config[key] > 0 ) - for key in { + for key in ( "BLOCK_SIZE_M", "BLOCK_SIZE_K", "BLOCK_SIZE_N", "GROUP_SIZE", "GRID_DIM", - } + ) ), "Invalid GMM kernel config." group_sizes_list = group_sizes_list if group_sizes_list is not None else group_sizes.tolist() @@ -337,13 +337,13 @@ def ptgmm( if key.startswith("BLOCK_SIZE_") else config[key] > 0 ) - for key in { + for key in ( "BLOCK_SIZE_M", "BLOCK_SIZE_K", "BLOCK_SIZE_N", "GROUP_SIZE", "GRID_DIM", - } + ) ), "Invalid PTGMM kernel config." # Bias gradient handling. @@ -531,12 +531,12 @@ def nptgmm( if key.startswith("BLOCK_SIZE_") else config[key] > 0 ) - for key in { + for key in ( "BLOCK_SIZE_M", "BLOCK_SIZE_K", "BLOCK_SIZE_N", "GROUP_SIZE", - } + ) ), "Invalid NPTGMM kernel config." grid = _nptgmm_grid( diff --git a/transformer_engine/pytorch/triton_kernels/grouped_gemm.py b/transformer_engine/pytorch/triton_kernels/grouped_gemm.py index 84cac374c..286711856 100644 --- a/transformer_engine/pytorch/triton_kernels/grouped_gemm.py +++ b/transformer_engine/pytorch/triton_kernels/grouped_gemm.py @@ -3,20 +3,14 @@ """Triton kernels for grouped GEMM""" -import triton -import triton.language as tl import torch -from typing import Iterable, Optional, Tuple, Union, List -import functools -import json -import os.path -import sys -from pathlib import Path - -from .gmm.gmm_wrapper import gmm, ptgmm, nptgmm +from typing import List, Optional + +from .gmm.gmm_wrapper import gmm, ptgmm import transformer_engine_torch as tex -def general_grouped_gemm_triton( + +def general_grouped_gemm_triton( # pylint: disable=unused-argument A: List[torch.Tensor], B: List[torch.Tensor], out: List[torch.Tensor], @@ -36,12 +30,12 @@ def general_grouped_gemm_triton( ) -> list: """ Drop-in replacement for general_grouped_gemm using AITER's Triton kernels. 
- + Supports: - Forward pass (layout="TN"): C = B @ A^T (where A=weights, B=inputs, C=outputs) - Backward pass dgrad (layout="NN", grad=True): C = B @ A (where A=weights, B=grad_output, C=dgrad) - Backward pass wgrad (layout="NT", grad=True): C = B^T @ A (where A=inputs, B=grad_output, C=wgrad) - + Args: A: Left-hand side matrices (weights for forward/dgrad, inputs for wgrad) B: Right-hand side matrices (inputs for forward, grad_outputs for backward) @@ -56,7 +50,7 @@ def general_grouped_gemm_triton( layout: "TN" for forward pass, "NN" for dgrad backward pass, "NT" for wgrad backward pass grad: True for backward pass accumulate: Whether to accumulate into C (for wgrad only) - + Returns: Tuple of (outputs, bias_or_grad_bias, gelu_input) to match C++ backend signature - bias_or_grad_bias: List of bias/grad_bias tensors (or list of bias if passed in) @@ -70,21 +64,21 @@ def general_grouped_gemm_triton( # Triton kernel needs GPU tensor if m_splits_tensor is None: m_splits_tensor = torch.tensor(m_splits, dtype=torch.int32, device=out[0].device) - + if is_wgrad: # WGRAD: ptgmm expects lhs=(K,M), rhs=(M,N), out=(G,K,N) # A=inputs (list of (m_i, in_features)), B=grad_outputs (list of (m_i, out_features)) A_tensor = A[0] if len(A) == 1 else torch.cat(A, dim=0) # (M, in_features) B_tensor = B[0] if len(B) == 1 else torch.cat(B, dim=0) # (M, out_features) out_tensor_3d = out # (G, out_features, in_features) - + # Allocate bias_grad OUTPUT buffer if needed (kernel writes to this) bias_grad_tensor = None if use_bias: G = len(m_splits) K = B_tensor.shape[1] # out_features bias_grad_tensor = torch.zeros(G, K, dtype=torch.float32, device=B_tensor.device) - + # Backward pass: C = B^T @ A (wgrad = grad_output^T @ input) # ptgmm expects lhs shape (K, M), so we need to transpose ptgmm( @@ -97,29 +91,29 @@ def general_grouped_gemm_triton( bias_grad=bias_grad_tensor, # OUTPUT: (G, out_features) or None accumulate=accumulate, ) - + # Convert bias_grad to list to match C++ backend signature if use_bias and bias_grad_tensor is not None: grad_biases = list(torch.unbind(bias_grad_tensor, dim=0)) else: grad_biases = [None] * len(out) if bias is None else bias - + # Return appropriate output format return_out = out_tensor_3d.view(-1, out_tensor_3d.shape[-1]) if single_output else out return return_out, grad_biases, None - elif is_dgrad: + if is_dgrad: # DGRAD: gmm expects lhs=(M,K), rhs=(G,K,N), out=(M,N) # A=weights (list of (out_features, in_features)), B=grad_outputs (list of (m_i, out_features)) A_tensor_3d = torch.stack(A, dim=0) # (G, out_features, in_features) B_tensor = B[0] if len(B) == 1 else torch.cat(B, dim=0) # (M, out_features) out_tensor = out[0] if len(out) == 1 else torch.cat(out, dim=0) # (M, in_features) - + # Stack bias into 3D if provided bias_tensor = None if bias is not None and len(bias) > 0 and bias[0].numel() > 0: bias_tensor = torch.stack(bias, dim=0) # (G, in_features) - + # Backward pass: C = B @ A (dgrad = grad_output @ weight) gmm( lhs=B_tensor, # (M, out_features) @@ -131,36 +125,35 @@ def general_grouped_gemm_triton( bias=bias_tensor, group_sizes_list=m_splits, ) - - grad_biases = [None] * len(m_splits) if bias is None else bias - return_out = out_tensor if single_output else out - return return_out, grad_biases, None - - else: - # FORWARD: gmm expects lhs=(M,K), rhs=(G,K,N), out=(M,N) - # Forward pass: C = B @ A^T (output = input @ weight^T + bias) - # A=weights (list of (out_features, in_features)), B=inputs (list of (m_i, in_features)) - A_tensor_3d = torch.stack(A, dim=0) # (G, 
out_features, in_features) - A_tensor_3d = A_tensor_3d.transpose(1, 2) # (G, in_features, out_features) for TN layout - B_tensor = B[0] if len(B) == 1 else torch.cat(B, dim=0) # (M, in_features) - out_tensor = out[0] if len(out) == 1 else torch.cat(out, dim=0) # (M, out_features) - - # Stack bias into 3D if provided - bias_tensor = None - if bias is not None and len(bias) > 0 and bias[0].numel() > 0: - bias_tensor = torch.stack(bias, dim=0) # (G, out_features) - - gmm( - lhs=B_tensor, # (M, in_features) - rhs=A_tensor_3d, # (G, in_features, out_features) - group_sizes=m_splits_tensor, - preferred_element_type=out_dtype, - existing_out=out_tensor, # (M, out_features) - config=None, - bias=bias_tensor, - group_sizes_list=m_splits, - ) - + grad_biases = [None] * len(m_splits) if bias is None else bias return_out = out_tensor if single_output else out return return_out, grad_biases, None + + # FORWARD: gmm expects lhs=(M,K), rhs=(G,K,N), out=(M,N) + # Forward pass: C = B @ A^T (output = input @ weight^T + bias) + # A=weights (list of (out_features, in_features)), B=inputs (list of (m_i, in_features)) + A_tensor_3d = torch.stack(A, dim=0) # (G, out_features, in_features) + A_tensor_3d = A_tensor_3d.transpose(1, 2) # (G, in_features, out_features) for TN layout + B_tensor = B[0] if len(B) == 1 else torch.cat(B, dim=0) # (M, in_features) + out_tensor = out[0] if len(out) == 1 else torch.cat(out, dim=0) # (M, out_features) + + # Stack bias into 3D if provided + bias_tensor = None + if bias is not None and len(bias) > 0 and bias[0].numel() > 0: + bias_tensor = torch.stack(bias, dim=0) # (G, out_features) + + gmm( + lhs=B_tensor, # (M, in_features) + rhs=A_tensor_3d, # (G, in_features, out_features) + group_sizes=m_splits_tensor, + preferred_element_type=out_dtype, + existing_out=out_tensor, # (M, out_features) + config=None, + bias=bias_tensor, + group_sizes_list=m_splits, + ) + + grad_biases = [None] * len(m_splits) if bias is None else bias + return_out = out_tensor if single_output else out + return return_out, grad_biases, None diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index 5ecb48eb7..438022b47 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -1,7 +1,6 @@ # Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. 
See LICENSE for more information -import torch import triton import triton.language as tl from itertools import product @@ -97,7 +96,7 @@ def _rmsnorm_fwd_triton_impl( x = tl.load(input_ptrs).to(tl.float32) g_ptrs = g_ptr + cols g = tl.load(g_ptrs).to(tl.float32) - if (ZERO_CENTERED_GAMMA): + if ZERO_CENTERED_GAMMA: g += 1 rms_norm = x * norm_factor * g output_ptrs = row_output_ptr + cols @@ -122,7 +121,7 @@ def _rmsnorm_fwd_triton_impl( x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) g_ptrs = g_ptr + cols g = tl.load(g_ptrs, mask=mask, other=0.0).to(tl.float32) - if (ZERO_CENTERED_GAMMA): + if ZERO_CENTERED_GAMMA: g += 1 rms_norm = x * norm_factor * g output_ptrs = row_output_ptr + cols @@ -154,7 +153,7 @@ def _rmsnorm_fwd_triton_impl( rsigma_output_ptr = rsigma_ptr + row_idx tl.store(rsigma_output_ptr, norm_factor) - if (ZERO_CENTERED_GAMMA): + if ZERO_CENTERED_GAMMA: g += 1 rms_norm = row * norm_factor * g @@ -219,7 +218,7 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d grad_output = tl.load(grad_output_ptrs).to(tl.float32) g_ptrs = g_ptr + cols g = tl.load(g_ptrs).to(tl.float32) - if (ZERO_CENTERED_GAMMA): + if ZERO_CENTERED_GAMMA: g += 1. grad_sum += tl.sum(grad_output * x * g, axis=0) @@ -232,7 +231,7 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d grad_output = tl.load(grad_output_ptrs, mask=mask, other=0.0).to(tl.float32) g_ptrs = g_ptr + cols g = tl.load(g_ptrs, mask=mask, other=0.0).to(tl.float32) - if (ZERO_CENTERED_GAMMA): + if ZERO_CENTERED_GAMMA: g += 1. grad_sum += tl.sum(grad_output * x * g, axis=0) @@ -254,7 +253,7 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d g_ptrs = g_ptr + cols g = tl.load(g_ptrs).to(tl.float32) - if (ZERO_CENTERED_GAMMA): + if ZERO_CENTERED_GAMMA: g += 1. grad_input = grad_output * norm_factor * g - (norm_factor * norm_factor * norm_factor) * x * (grad_sum / n_cols) @@ -280,7 +279,7 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d grad_output = tl.load(grad_output_ptrs, mask=mask, other=0.0).to(tl.float32) g_ptrs = g_ptr + cols g = tl.load(g_ptrs, mask=mask, other=0.0).to(tl.float32) - if (ZERO_CENTERED_GAMMA): + if ZERO_CENTERED_GAMMA: g += 1. grad_input = grad_output * norm_factor * g - (norm_factor * norm_factor * norm_factor) * x * (grad_sum / n_cols) @@ -315,7 +314,7 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d x = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32) grad_output = tl.load(grad_output_ptrs, mask=mask, other=0.0).to(tl.float32) g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) - if (ZERO_CENTERED_GAMMA): + if ZERO_CENTERED_GAMMA: g += 1. 
norm_factor = tl.load(rsigma_ptr + row_idx).to(tl.float32) @@ -332,8 +331,15 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d @triton.jit -def _rmsnorm_bwd_dg_reduce_triton(dg_in_ptr, dg_out_ptr, dg_in_stride, n_rows, n_cols, BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr): +def _rmsnorm_bwd_dg_reduce_triton( + dg_in_ptr, + dg_out_ptr, + dg_in_stride, # pylint: disable=unused-argument + n_rows, + n_cols, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): # we want parallelism in N direction # if N is small, we will just use one CU, # otherwise, it can be split by N/BLOCK_SIZE @@ -348,4 +354,3 @@ def _rmsnorm_bwd_dg_reduce_triton(dg_in_ptr, dg_out_ptr, dg_in_stride, n_rows, n sum_dg = tl.sum(acc, axis=0) tl.store(dg_out_ptr + cols, sum_dg.to(dg_out_ptr.type.element_ty), mask=cols < n_cols) - diff --git a/transformer_engine/pytorch/triton_kernels/utils.py b/transformer_engine/pytorch/triton_kernels/utils.py index 15a733ce9..50cd63eb2 100644 --- a/transformer_engine/pytorch/triton_kernels/utils.py +++ b/transformer_engine/pytorch/triton_kernels/utils.py @@ -47,8 +47,8 @@ def num_programs(x, sm_margin=None): def block_size(x, norm="layer"): max_fused_size = (65536 if norm=="rms" else 16384) // x.element_size() - block_size = min(max_fused_size, triton.next_power_of_2(x.shape[1])) - return block_size + _block_size = min(max_fused_size, triton.next_power_of_2(x.shape[1])) + return _block_size def use_blocked(x): @@ -59,7 +59,9 @@ def make_ln_out(ln_out, quantizer=None, input_shape=None, out_dtype=torch.float3 if ln_out is None: # TODO(micky774): Remove corresponding FP8Quantizer check when kernels properly support MXFP8/float8_current_scaling as a fused operation - if quantizer is None or isinstance(quantizer, MXFP8Quantizer) or isinstance(quantizer, Float8CurrentScalingQuantizer): + if quantizer is None or isinstance( + quantizer, (MXFP8Quantizer, Float8CurrentScalingQuantizer) + ): return torch.empty(input_shape, dtype=out_dtype, device='cuda') return quantizer.make_empty(input_shape, dtype=out_dtype) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 7131d45e6..b58d8b294 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -454,22 +454,45 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None: if IS_HIP_EXTENSION: @functools.lru_cache(maxsize=None) def is_mi200(): - """check whether this machine is mi200/210/250""" - import re - return (re.search('AMD Instinct MI2.0', torch.cuda.get_device_name(torch.cuda.current_device())) is not None) - + """Return True if the current GPU is MI200-class (MI2xx).""" + import re + + return ( + re.search( + "AMD Instinct MI2.0", + torch.cuda.get_device_name(torch.cuda.current_device()), + ) + is not None + ) + @functools.lru_cache(maxsize=None) def is_mi308(): - """check whether this machine is mi308""" - import re - return (re.search('AMD Instinct MI308', torch.cuda.get_device_name(torch.cuda.current_device())) is not None) + """Return True if the current GPU is MI308.""" + import re + + return ( + re.search( + "AMD Instinct MI308", + torch.cuda.get_device_name(torch.cuda.current_device()), + ) + is not None + ) + @functools.lru_cache(maxsize=None) -def is_fp8_fnuz(): +def is_fp8_fnuz() -> bool: + """True when using FP8 FNUZ dtypes (ROCm FP8 path).""" return IS_HIP_EXTENSION and get_device_compute_capability() == (9, 4) -get_torch_float8_e4m3_type = lambda: torch.float8_e4m3fnuz if is_fp8_fnuz() else 
torch.float8_e4m3fn -get_torch_float8_e5m2_type = lambda: torch.float8_e5m2fnuz if is_fp8_fnuz() else torch.float8_e5m2 + +def get_torch_float8_e4m3_type(): + """E4M3 dtype for current platform (FNUZ on ROCm gfx94x when applicable).""" + return torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn + + +def get_torch_float8_e5m2_type(): + """E5M2 dtype for current platform (FNUZ on ROCm gfx94x when applicable).""" + return torch.float8_e5m2fnuz if is_fp8_fnuz() else torch.float8_e5m2 def assert_dim_for_all_gather( tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer @@ -480,18 +503,13 @@ def assert_dim_for_all_gather( "All-gather requires quantizable tensor for quantizer " + quantizer.__class__.__name__ ) -def is_bf16_compatible() -> None: +def is_bf16_compatible() -> bool: + """Whether BF16 tensor cores / ops are usable on the current device.""" if IS_HIP_EXTENSION: # only MI200 and newer machines support bf16 - if get_device_compute_capability() in [(9, 4), (9, 5)] or is_mi200(): - return True - else: - return False - else: - """Replaces torch.cuda.is_bf16_compatible() with an explicit - check on device compute capability to enforce sm_80 or higher. - """ - return torch.cuda.get_device_capability()[0] >= 8 + return get_device_compute_capability() in [(9, 4), (9, 5)] or is_mi200() + # CUDA: require sm_80 or higher (replaces torch.cuda.is_bf16_compatible heuristic). + return torch.cuda.get_device_capability()[0] >= 8 def is_bf16_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: """ From b60c7f97afe3716a1bcc249f31d074c1aea33d76 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Thu, 23 Apr 2026 03:23:48 +0000 Subject: [PATCH 5/7] Addressed reviews --- pylintrc | 7 +--- transformer_engine/common/__init__.py | 12 +++--- transformer_engine/common/gemm/rocm_gemm.cu | 13 +++---- .../common/normalization/common.cpp | 37 +++++++++---------- .../jax/cpp_extensions/attention.py | 9 +---- transformer_engine/jax/cpp_extensions/gemm.py | 4 +- .../jax/cpp_extensions/normalization.py | 8 +--- transformer_engine/jax/quantize/helper.py | 27 +++++--------- transformer_engine/jax/setup.py | 6 +-- transformer_engine/jax/util.py | 15 ++------ .../dot_product_attention.py | 3 +- .../pytorch/cpp_extensions/gemm.py | 4 +- transformer_engine/pytorch/module/__init__.py | 28 +++++++------- .../pytorch/module/grouped_linear.py | 4 +- transformer_engine/pytorch/module/linear.py | 2 +- .../pytorch/ops/basic/layer_norm.py | 8 ++-- .../pytorch/ops/basic/rmsnorm.py | 14 ++++--- transformer_engine/pytorch/ops/fuser.py | 11 +++--- .../pytorch/optimizers/fused_adam.py | 5 ++- .../pytorch/tensor/float8_blockwise_tensor.py | 3 +- .../pytorch/tensor/float8_tensor.py | 6 ++- .../pytorch/tensor/mxfp4_tensor.py | 2 +- .../pytorch/tensor/mxfp8_tensor.py | 5 ++- .../tensor/storage/mxfp8_tensor_storage.py | 5 ++- transformer_engine/pytorch/transformer.py | 3 +- .../pytorch/triton_kernels/cast.py | 11 ++++-- .../pytorch/triton_kernels/cast_transpose.py | 7 ++-- .../pytorch/triton_kernels/common.py | 3 +- .../pytorch/triton_kernels/gmm/gmm_kernels.py | 3 +- .../pytorch/triton_kernels/grouped_gemm.py | 6 ++- .../pytorch/triton_kernels/norms_common.py | 2 +- .../pytorch/triton_kernels/rmsnorm.py | 3 +- 32 files changed, 131 insertions(+), 145 deletions(-) diff --git a/pylintrc b/pylintrc index faba34d5c..11e0fbc4b 100644 --- a/pylintrc +++ b/pylintrc @@ -7,9 +7,6 @@ extension-pkg-whitelist=flash_attn_2_cuda, disable=too-many-locals, missing-module-docstring, 
missing-function-docstring, - wrong-import-order, - ungrouped-imports, - fixme, too-few-public-methods, too-many-public-methods, too-many-positional-arguments, @@ -37,11 +34,11 @@ disable=too-many-locals, line-too-long, too-many-return-statements, too-many-nested-blocks, - import-outside-toplevel, possibly-used-before-assignment, - wrong-import-position, + fixme, unnecessary-lambda-assignment, use-dict-literal, + redefined-outer-name, redefined-builtin [TYPECHECK] diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 531ad67aa..b60ccd14c 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -51,8 +51,8 @@ def _is_package_installed_from_wheel(package) -> bool: if not te_wheel_file: return False - with te_wheel_file.open("r") as wheel_f: - for line in wheel_f: + with te_wheel_file.open("r") as f: + for line in f: if line.startswith("Root-Is-Purelib:"): return line.strip().split(":")[1].strip().lower() == "true" return False @@ -412,16 +412,16 @@ def _load_core_library(): for rocm_path in (os.getenv("ROCM_PATH"), "/opt/rocm/core", "/opt/rocm"): if rocm_path and os.path.exists(os.path.join(rocm_path, ".info/version")): break - with open(os.path.join(rocm_path, ".info/version"), "r", encoding="utf-8") as ver_file: - rocm_version = ver_file.read().strip().split(".")[:2] + with open(os.path.join(rocm_path, ".info/version"), "r", encoding="utf-8") as f: + rocm_version = f.read().strip().split(".")[:2] # Get ROCm version from the build info file with open( Path(transformer_engine.__path__[0]).parent / "transformer_engine" / "build_info.txt", "r", encoding="utf-8", - ) as build_file: - build_info = build_file.read().split("\n") + ) as f: + build_info = f.read().split("\n") build_rocm_version = list(filter(lambda line: line.startswith("ROCM_VERSION:"), build_info)) if build_rocm_version: build_rocm_version = build_rocm_version[0].split(":")[1].strip().split('.')[:2] diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 5d0252c79..ba691fe83 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -769,17 +769,16 @@ protected: } #if HIPBLASLT_VERSION_MAJOR > 0 || HIPBLASLT_VERSION_MINOR >= 15 - if (cfg.scaling_mode < 0 || - cfg.scaling_mode >= static_cast<int>(HIPBLASLT_MATMUL_MATRIX_SCALE_END)) { - std::cout << "[WARNING] Unsupported scaling mode at " << line << "\n"; - continue; - } + const bool scaling_mode_unsupported = + cfg.scaling_mode < 0 || + cfg.scaling_mode >= static_cast<int>(HIPBLASLT_MATMUL_MATRIX_SCALE_END); #else - if (cfg.scaling_mode != 0) { + const bool scaling_mode_unsupported = (cfg.scaling_mode != 0); +#endif + if (scaling_mode_unsupported) { std::cout << "[WARNING] Unsupported scaling mode at " << line << "\n"; continue; } -#endif auto fp8_filter = te_fp8_fnuz() ?
[](const hipDataType& val) diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index ee5876b5a..4517e5f06 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -544,27 +544,26 @@ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan( plan = std::make_unique(NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, zero_centered_gamma, mode, training); + } else if (NormStage == NVTE_Norm_Stage::Forward) { + plan = std::make_unique>( + NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, + zero_centered_gamma, is_tuned); + } else { + plan = std::make_unique>( + NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, + zero_centered_gamma, is_tuned); } -#endif - if (!plan) { - if (NormStage == NVTE_Norm_Stage::Forward) { - plan = std::make_unique>( - NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, - zero_centered_gamma, is_tuned -#ifdef __HIP_PLATFORM_AMD__ - , mode, training -#endif - ); - } else { - plan = std::make_unique>( - NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, - zero_centered_gamma, is_tuned -#ifdef __HIP_PLATFORM_AMD__ - , mode, training -#endif - ); - } +#else + if (NormStage == NVTE_Norm_Stage::Forward) { + plan = std::make_unique>( + NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, + zero_centered_gamma, is_tuned, mode, training); + } else { + plan = std::make_unique>( + NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, + zero_centered_gamma, is_tuned, mode, training); } +#endif normalizationPlanMap.insert({key, std::move(plan)}); return normalizationPlanMap[key].get(); } diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 89510e88e..59fb1c4db 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -11,14 +11,11 @@ from functools import partial, reduce from typing import Optional, Tuple -from packaging import version - import jax import jax.numpy as jnp from jax import dtypes, lax, ffi from jax.sharding import PartitionSpec, NamedSharding -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax.experimental.custom_partitioning import SdyShardingRule +from jax.experimental.custom_partitioning import SdyShardingRule import transformer_engine_jax from transformer_engine_jax import NVTE_Fused_Attn_Backend @@ -719,8 +716,6 @@ def partition(config, mesh, arg_infos, result_infos): @staticmethod def shardy_sharding_rule(config, mesh, value_types, result_types): - if version.parse(jax.__version__) < version.parse("0.5.0"): - raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") del mesh, result_types # Keep in sync with `infer_sharding_from_operands`. @@ -1201,8 +1196,6 @@ def sharded_impl( @staticmethod def shardy_sharding_rule(config, mesh, value_types, result_types): - if version.parse(jax.__version__) < version.parse("0.5.0"): - raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") del config, mesh # Keep in sync with `infer_sharding_from_operands`. 
input_spec = tuple((f"…{x}",) for x in range(len(value_types))) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 94d78a0c1..8fad59136 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -26,8 +26,6 @@ get_device_compute_capability, ) -from ..util import is_hip_extension, get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type - from .base import BasePrimitive, register_primitive from .quantization import grouped_quantize @@ -56,6 +54,8 @@ dp_or_fsdp_axis_size, ) +from ..util import is_hip_extension, get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type + if not is_hip_extension(): from transformer_engine_jax import ( initialize_cgemm_communicator, diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 6ef35a134..2813b4b03 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -10,14 +10,10 @@ from functools import partial, cache, reduce from typing import Optional, Union -from packaging import version - import jax import jax.numpy as jnp from jax import dtypes, ffi -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax.experimental.custom_partitioning import SdyShardingRule -from jax.experimental.custom_partitioning import BATCHING +from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec @@ -650,8 +646,6 @@ def shardy_sharding_rule( value_types, result_types, ): - if version.parse(jax.__version__) < version.parse("0.5.0"): - raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") del ( zero_centered_gamma, epsilon, diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index ebb99d5ee..06a15bdaf 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -26,8 +26,7 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict -import transformer_engine_jax as tejax - +from transformer_engine_jax import DType from transformer_engine.common.recipe import ( Recipe, DelayedScaling, @@ -44,23 +43,17 @@ with_sharding_constraint, ) -from ..util import is_hip_extension, get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type -from .device_utils import get_device_compute_capability from .metadata import QuantizeMeta from .scaling_modes import ScalingMode +from .device_utils import get_device_compute_capability -if not is_hip_extension(): - get_cublasLt_version = tejax.get_cublasLt_version - get_cuda_version = tejax.get_cuda_version -else: - - def get_cublasLt_version(): - """CUDA-only; not used on ROCm code paths.""" - raise RuntimeError("get_cublasLt_version is not available on ROCm") +from ..util import is_hip_extension, get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type - def get_cuda_version(): - """CUDA-only; not used on ROCm code paths.""" - raise RuntimeError("get_cuda_version is not available on ROCm") +if not is_hip_extension(): + from transformer_engine_jax import ( # pylint: disable=ungrouped-imports + get_cublasLt_version, + get_cuda_version, + ) __all__ = [ "get_global_quantize_recipe", @@ -312,8 +305,8 @@ class BaseQuantizeConfig(ABC): INITIALIZED = False MARGIN: float = 0.0 COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME - FWD_DTYPE: tejax.DType = None - BWD_DTYPE: tejax.DType = None + 
FWD_DTYPE: DType = None + BWD_DTYPE: DType = None FP8_2X_ACC_FPROP: bool = False FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_WGRAD: bool = False diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index 574398a7b..7aa07d378 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -51,11 +51,11 @@ from build_tools.te_version import te_version from build_tools.jax import setup_jax_extension, install_requirements, test_requirements -if rocm_build(): - from build_tools.hipify.hipify import copy_hipify_tools, clear_hipify_tools_copy - from pybind11.setup_helpers import build_ext as BuildExtension +if rocm_build(): + from build_tools.hipify.hipify import copy_hipify_tools, clear_hipify_tools_copy # pylint: disable=ungrouped-imports + os.environ["NVTE_PROJECT_BUILDING"] = "1" CMakeBuildExtension = get_build_ext(BuildExtension, True) diff --git a/transformer_engine/jax/util.py b/transformer_engine/jax/util.py index 7f9d33182..6aa2390aa 100644 --- a/transformer_engine/jax/util.py +++ b/transformer_engine/jax/util.py @@ -1,6 +1,6 @@ # Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. See LICENSE for more information -"""Small JAX-side helpers shared across TE JAX code (ROCm detection, FP8 dtypes).""" +"""JAX-side helpers shared across TE JAX code (ROCm detection, FP8 dtypes).""" import importlib.metadata import re @@ -32,7 +32,7 @@ def is_mi200(): @cache def is_fp8_fnuz(): - """Return True when TE core reports FP8 FNUZ usage (matches subprocess TE check).""" + """Return True when TE core reports FP8 FNUZ usage.""" if not is_hip_extension(): return False proc = subprocess.run( @@ -47,12 +47,5 @@ def is_fp8_fnuz(): ) return proc.returncode == 0 - -def get_jnp_float8_e4m3_type(): - """JAX FP8 e4m3 dtype for this platform (FNUZ on ROCm when applicable).""" - return jnp.float8_e4m3fnuz if is_fp8_fnuz() else jnp.float8_e4m3fn - - -def get_jnp_float8_e5m2_type(): - """JAX FP8 e5m2 dtype for this platform (FNUZ on ROCm when applicable).""" - return jnp.float8_e5m2fnuz if is_fp8_fnuz() else jnp.float8_e5m2 +get_jnp_float8_e4m3_type = lambda: jnp.float8_e4m3fnuz if is_fp8_fnuz() else jnp.float8_e4m3fn +get_jnp_float8_e5m2_type = lambda: jnp.float8_e5m2fnuz if is_fp8_fnuz() else jnp.float8_e5m2 diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index fa4fb9a48..b7c234a84 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -14,6 +14,7 @@ import torch from torch.nn.parameter import Parameter +from torch.utils.cpp_extension import IS_HIP_EXTENSION import transformer_engine_torch as tex from transformer_engine.common.recipe import ( @@ -63,8 +64,6 @@ FlashAttention, ) -from torch.utils.cpp_extension import IS_HIP_EXTENSION - # Setup Attention Logging attn_log.setup_logging() diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 35fae5ac1..412e90ada 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -14,8 +14,6 @@ import transformer_engine_torch as tex from ..constants import TE_DType from ..utils import get_sm_count, _empty_tensor -if IS_HIP_EXTENSION: - from ..utils import 
get_device_compute_capability from ..quantized_tensor import Quantizer from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage @@ -23,6 +21,8 @@ from ..custom_recipes.gemm import custom_gemm from ...debug.pytorch.debug_quantization import DebugQuantizer +if IS_HIP_EXTENSION: + from ..utils import get_device_compute_capability __all__ = [ "general_gemm", diff --git a/transformer_engine/pytorch/module/__init__.py b/transformer_engine/pytorch/module/__init__.py index 7c76cacb8..3cf15efc1 100644 --- a/transformer_engine/pytorch/module/__init__.py +++ b/transformer_engine/pytorch/module/__init__.py @@ -1,14 +1,14 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Module level PyTorch APIs""" -from .layernorm_linear import LayerNormLinear -from .linear import Linear -from .grouped_linear import GroupedLinear -from .layernorm_mlp import LayerNormMLP -from .layernorm import LayerNorm -from .rmsnorm import RMSNorm -from .fp8_padding import Fp8Padding -from .fp8_unpadding import Fp8Unpadding -from .base import initialize_ub, destroy_ub, UserBufferQuantizationMode +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Module level PyTorch APIs""" +from .layernorm_linear import LayerNormLinear +from .linear import Linear +from .grouped_linear import GroupedLinear +from .layernorm_mlp import LayerNormMLP +from .layernorm import LayerNorm +from .rmsnorm import RMSNorm +from .fp8_padding import Fp8Padding +from .fp8_unpadding import Fp8Unpadding +from .base import initialize_ub, destroy_ub, UserBufferQuantizationMode diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 8c6aa8bde..723dc065b 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -10,7 +10,9 @@ import warnings import functools +import os import torch +from torch.utils.cpp_extension import IS_HIP_EXTENSION import transformer_engine_torch as tex @@ -54,11 +56,9 @@ ) from ...debug.pytorch.debug_quantization import DebugQuantizer from ...debug.pytorch.debug_state import TEDebugState -from torch.utils.cpp_extension import IS_HIP_EXTENSION if IS_HIP_EXTENSION: from transformer_engine.pytorch.triton_kernels.grouped_gemm import general_grouped_gemm_triton - import os __all__ = ["GroupedLinear"] diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 7a6ad7ad8..6cbdec0b2 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -11,6 +11,7 @@ import warnings import torch +from torch.utils.cpp_extension import IS_HIP_EXTENSION import transformer_engine_torch as tex @@ -77,7 +78,6 @@ mark_activation_offload, ) from ...debug.pytorch.debug_state import TEDebugState -from torch.utils.cpp_extension import IS_HIP_EXTENSION __all__ = ["Linear"] diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index 340c2b895..4df3a7f21 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -13,11 +13,10 @@ from typing import Optional import torch +from torch.utils.cpp_extension import IS_HIP_EXTENSION from transformer_engine_torch import layernorm_bwd, layernorm_fwd -from torch.utils.cpp_extension import 
IS_HIP_EXTENSION -if IS_HIP_EXTENSION: - from ...triton_kernels.norms_common import te_layernorm_fwd_triton, te_layernorm_bwd_triton + from ...constants import TE_DType from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...export import is_in_onnx_export_mode @@ -31,6 +30,9 @@ from ..op import BasicOperation, OperationContext from .._common import maybe_autocast_dtype, maybe_dequantize +if IS_HIP_EXTENSION: + from ...triton_kernels.norms_common import te_layernorm_fwd_triton, te_layernorm_bwd_triton + class LayerNorm(BasicOperation): r"""Layer Normalization diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index a9bbeab8c..77daa99fa 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -13,14 +13,10 @@ from typing import Optional import torch +from torch.utils.cpp_extension import IS_HIP_EXTENSION from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd -from torch.utils.cpp_extension import IS_HIP_EXTENSION -if IS_HIP_EXTENSION: - from ...triton_kernels.norms_common import ( - te_rmsnorm_bwd_triton, - te_rmsnorm_fwd_triton - ) + from ...constants import TE_DType from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...export import is_in_onnx_export_mode @@ -34,6 +30,12 @@ from ..op import BasicOperation, OperationContext from .._common import maybe_autocast_dtype, maybe_dequantize +if IS_HIP_EXTENSION: + from ...triton_kernels.norms_common import ( + te_rmsnorm_bwd_triton, + te_rmsnorm_fwd_triton, + ) + class RMSNorm(BasicOperation): r"""Root Mean Square Layer Normalization diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 86b279759..e52beb0a2 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -12,6 +12,7 @@ import itertools import torch +from torch.utils.cpp_extension import IS_HIP_EXTENSION from transformer_engine.pytorch.quantization import FP8GlobalStateManager, Recipe, DelayedScaling from transformer_engine.pytorch.ops.op import ( @@ -19,7 +20,6 @@ FusibleOperation, OperationContext, ) -from torch.utils.cpp_extension import IS_HIP_EXTENSION from transformer_engine.pytorch.ops.fused import ( fuse_backward_activation_bias, fuse_backward_add_rmsnorm, @@ -29,15 +29,16 @@ fuse_forward_linear_bias_add, fuse_forward_linear_scale_add, ) +from transformer_engine.pytorch.quantized_tensor import ( + prepare_for_saving, + restore_from_saved, +) + if not IS_HIP_EXTENSION: from transformer_engine.pytorch.ops.fused import ( fuse_userbuffers_backward_linear, fuse_userbuffers_forward_linear, ) -from transformer_engine.pytorch.quantized_tensor import ( - prepare_for_saving, - restore_from_saved, -) def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]: diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index b2b78f3eb..01909be5c 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -14,12 +14,13 @@ import torch from torch.distributed._tensor import DTensor +from torch.utils.cpp_extension import IS_HIP_EXTENSION + import transformer_engine_torch as tex from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.pytorch.quantized_tensor import QuantizedTensor -from .multi_tensor_apply import multi_tensor_applier -from 
torch.utils.cpp_extension import IS_HIP_EXTENSION from transformer_engine.pytorch.utils import is_fp8_fnuz +from .multi_tensor_apply import multi_tensor_applier def get_fp8_meta(fp8_tensor): diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 180cc6f25..fc8792b1a 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -11,6 +11,7 @@ from typing import Any, Optional, Tuple, Union import torch +from torch.utils.cpp_extension import IS_HIP_EXTENSION import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType @@ -20,8 +21,6 @@ from ._quantization_helpers import _IdentityFunc from ..utils import devices_match, round_up_to_nearest_multiple -from torch.utils.cpp_extension import IS_HIP_EXTENSION - aten = torch.ops.aten diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index dfbce4aca..a1a0b1f27 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -8,9 +8,13 @@ from __future__ import annotations from typing import Any, Optional, Tuple, Iterable, Union +import os import warnings + import torch from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState +from torch.utils.cpp_extension import IS_HIP_EXTENSION + import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType @@ -21,10 +25,8 @@ from ._quantization_helpers import _IdentityFunc from ..constants import dist_group_type -from torch.utils.cpp_extension import IS_HIP_EXTENSION if IS_HIP_EXTENSION: from ..triton_kernels.cast import te_quantize_triton - import os aten = torch.ops.aten diff --git a/transformer_engine/pytorch/tensor/mxfp4_tensor.py b/transformer_engine/pytorch/tensor/mxfp4_tensor.py index 8e85746a9..394d0b1bf 100644 --- a/transformer_engine/pytorch/tensor/mxfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp4_tensor.py @@ -9,7 +9,6 @@ from typing import Optional, Tuple, Union import torch -from ..triton_kernels.cast import te_quantize_triton import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType @@ -17,6 +16,7 @@ from transformer_engine.common.recipe import MXFP4BlockScaling, Recipe from ..constants import MXFP8_BLOCK_SCALING_SIZE # MXFP4 uses same block size from ..utils import devices_match, round_up_to_nearest_multiple +from ..triton_kernels.cast import te_quantize_triton from .storage.mxfp4_tensor_storage import MXFP4TensorStorage, _FromMXFP4Func from ..quantized_tensor import QuantizedTensor, Quantizer diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index c1e86d28a..72a8709dd 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -8,12 +8,15 @@ from __future__ import annotations from collections.abc import Iterable import math +import os from typing import Optional, Tuple, Union, Any import warnings import torch from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState +from torch.utils.cpp_extension import IS_HIP_EXTENSION + import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType @@ -24,9 +27,7 @@ from ..quantized_tensor import QuantizedTensor, Quantizer from ._quantization_helpers import _IdentityFunc 
-from torch.utils.cpp_extension import IS_HIP_EXTENSION if IS_HIP_EXTENSION: - import os from ..triton_kernels.cast import te_quantize_triton aten = torch.ops.aten diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index 3f9aaa4d0..aa5c71f33 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -10,12 +10,13 @@ from typing import Optional, Dict, Any, Tuple from collections.abc import Iterable import math -import torch import os +import torch +from torch.utils.cpp_extension import IS_HIP_EXTENSION + import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from torch.utils.cpp_extension import IS_HIP_EXTENSION from ...quantized_tensor import QuantizedTensorStorage, Quantizer diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 4f131c3c0..ddfba1aed 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -11,6 +11,7 @@ from typing import Callable, List, Optional, Tuple, Union import torch +from torch.utils.cpp_extension import IS_HIP_EXTENSION from transformer_engine.pytorch.torch_version import torch_version from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm @@ -38,8 +39,6 @@ from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.module.base import TransformerEngineBaseModule -from torch.utils.cpp_extension import IS_HIP_EXTENSION - warnings.filterwarnings("module", category=DeprecationWarning, module="transformer") diff --git a/transformer_engine/pytorch/triton_kernels/cast.py b/transformer_engine/pytorch/triton_kernels/cast.py index 0320ab8d8..f2df55edd 100644 --- a/transformer_engine/pytorch/triton_kernels/cast.py +++ b/transformer_engine/pytorch/triton_kernels/cast.py @@ -7,11 +7,16 @@ import functools import torch -from ..utils import is_non_tn_fp8_gemm_supported +import transformer_engine_torch as tex +from ..utils import is_non_tn_fp8_gemm_supported from ..tensor.storage.float8_tensor_storage import Float8TensorStorage -from .cast_transpose import te_cast_transpose_mxfp4_triton, te_cast_transpose_mxfp8_triton, te_cast_transpose_noop_triton, te_dequantize_mxfp8_triton -import transformer_engine_torch as tex +from .cast_transpose import ( + te_cast_transpose_mxfp4_triton, + te_cast_transpose_mxfp8_triton, + te_cast_transpose_noop_triton, + te_dequantize_mxfp8_triton, +) from ..quantized_tensor import QuantizedTensor, Quantizer from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from ..tensor.storage.mxfp4_tensor_storage import MXFP4TensorStorage diff --git a/transformer_engine/pytorch/triton_kernels/cast_transpose.py b/transformer_engine/pytorch/triton_kernels/cast_transpose.py index eec01a62c..73d798224 100644 --- a/transformer_engine/pytorch/triton_kernels/cast_transpose.py +++ b/transformer_engine/pytorch/triton_kernels/cast_transpose.py @@ -1,17 +1,18 @@ # Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. 
See LICENSE for more information -import torch +import os -from ..constants import MXFP8_BLOCK_SCALING_SIZE +import torch import triton import triton.language as tl + +from ..constants import MXFP8_BLOCK_SCALING_SIZE from .common import ( te_dtype_to_triton_dtype, te_dtype_to_torch_dtype, get_fp8_max, ) -import os ########################################## #### cast_transpose diff --git a/transformer_engine/pytorch/triton_kernels/common.py b/transformer_engine/pytorch/triton_kernels/common.py index e99a06063..cd8fc376c 100644 --- a/transformer_engine/pytorch/triton_kernels/common.py +++ b/transformer_engine/pytorch/triton_kernels/common.py @@ -1,11 +1,12 @@ # Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. See LICENSE for more information +from functools import cache + import torch import triton import triton.language as tl import transformer_engine_torch as tex -from functools import cache @cache def get_arch(): diff --git a/transformer_engine/pytorch/triton_kernels/gmm/gmm_kernels.py b/transformer_engine/pytorch/triton_kernels/gmm/gmm_kernels.py index 4c91d4991..4cd224c48 100644 --- a/transformer_engine/pytorch/triton_kernels/gmm/gmm_kernels.py +++ b/transformer_engine/pytorch/triton_kernels/gmm/gmm_kernels.py @@ -12,13 +12,14 @@ import os.path # Triton -from ..common import get_arch import triton import triton.language as tl # PyTorch import torch +from ..common import get_arch + # AITER from .pid_preprocessing import pid_grid, remap_xcd diff --git a/transformer_engine/pytorch/triton_kernels/grouped_gemm.py b/transformer_engine/pytorch/triton_kernels/grouped_gemm.py index 286711856..cef3ec9c3 100644 --- a/transformer_engine/pytorch/triton_kernels/grouped_gemm.py +++ b/transformer_engine/pytorch/triton_kernels/grouped_gemm.py @@ -3,12 +3,14 @@ """Triton kernels for grouped GEMM""" -import torch from typing import List, Optional -from .gmm.gmm_wrapper import gmm, ptgmm +import torch + import transformer_engine_torch as tex +from .gmm.gmm_wrapper import gmm, ptgmm + def general_grouped_gemm_triton( # pylint: disable=unused-argument A: List[torch.Tensor], diff --git a/transformer_engine/pytorch/triton_kernels/norms_common.py b/transformer_engine/pytorch/triton_kernels/norms_common.py index 87cfa722e..dc8cda79c 100644 --- a/transformer_engine/pytorch/triton_kernels/norms_common.py +++ b/transformer_engine/pytorch/triton_kernels/norms_common.py @@ -1,9 +1,9 @@ # Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. See LICENSE for more information +import warnings import torch import triton -import warnings import transformer_engine_torch as tex from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index 438022b47..c93d7488c 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -1,9 +1,10 @@ # Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. 
See LICENSE for more information +from itertools import product + import triton import triton.language as tl -from itertools import product def get_autotune_config(): return [triton.Config({'waves_per_eu': we}, num_warps=nw) for (we, nw) in product([0, 1, 2, 4], [4, 8, 16])] From caceac3ee48de58b20d58a0d3d662237c7277875 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Sat, 25 Apr 2026 03:08:38 +0000 Subject: [PATCH 6/7] Addressed reviews --- transformer_engine/pytorch/module/base.py | 5 ++++- transformer_engine/pytorch/utils.py | 10 ++-------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index b859df4c5..c31762b18 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -50,6 +50,8 @@ if IS_HIP_EXTENSION: from ..tensor.fsdp2_allgather_tensor import FSDPAGTensor from ..triton_kernels.cast import te_quantize_triton + +# pylint: disable=wrong-import-position from ..tensor.storage.float8_tensor_storage import Float8TensorStorage from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from ..utils import get_device_compute_capability, is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype @@ -58,6 +60,7 @@ from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled +# pylint: enable=wrong-import-position __all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"] @@ -664,7 +667,7 @@ def fill_userbuffers_buffer_for_all_gather( columnwise_scale_inv=columnwise_scale_inv, fp8_dtype=local_tensor._fp8_dtype, quantizer=quantizer, - with_gemm_swizzled_scales=local_tensor._with_gemm_swizzled_scales, + with_gemm_swizzled_scales=False, ) return global_tensor, local_tensor diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index b58d8b294..1d2463011 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -485,14 +485,8 @@ def is_fp8_fnuz() -> bool: return IS_HIP_EXTENSION and get_device_compute_capability() == (9, 4) -def get_torch_float8_e4m3_type(): - """E4M3 dtype for current platform (FNUZ on ROCm gfx94x when applicable).""" - return torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn - - -def get_torch_float8_e5m2_type(): - """E5M2 dtype for current platform (FNUZ on ROCm gfx94x when applicable).""" - return torch.float8_e5m2fnuz if is_fp8_fnuz() else torch.float8_e5m2 +get_torch_float8_e4m3_type = lambda: torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn +get_torch_float8_e5m2_type = lambda: torch.float8_e5m2fnuz if is_fp8_fnuz() else torch.float8_e5m2 def assert_dim_for_all_gather( tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer From 6bc015cced63ac315dc54bc0f2bea66f61c6dec6 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Mon, 27 Apr 2026 15:09:19 +0000 Subject: [PATCH 7/7] Addressed reviews for python files --- transformer_engine/pytorch/constants.py | 2 -- transformer_engine/pytorch/optimizers/fused_adam.py | 7 ++++--- transformer_engine/pytorch/quantized_tensor.py | 6 +++--- transformer_engine/pytorch/setup.py | 8 +++++--- .../pytorch/tensor/fsdp2_allgather_tensor.py | 2 +- .../pytorch/tensor/storage/mxfp8_tensor_storage.py | 3 +-- transformer_engine/pytorch/triton_kernels/grouped_gemm.py | 1 - 
transformer_engine/pytorch/utils.py | 1 - 8 files changed, 14 insertions(+), 16 deletions(-) diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index c6cea77b6..737633fb2 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -5,7 +5,6 @@ # See LICENSE for license information. """Enums for e2e transformer""" - import torch import torch.distributed from torch.utils.cpp_extension import IS_HIP_EXTENSION @@ -53,7 +52,6 @@ def __missing__(self, key): return value raise KeyError(key) - TE_DType_To_Torch = Custom_DType_Dict({ tex.DType.kByte: torch.uint8, tex.DType.kInt32: torch.int32, diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 01909be5c..dfbcc4c89 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -14,13 +14,14 @@ import torch from torch.distributed._tensor import DTensor -from torch.utils.cpp_extension import IS_HIP_EXTENSION - import transformer_engine_torch as tex from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.pytorch.quantized_tensor import QuantizedTensor -from transformer_engine.pytorch.utils import is_fp8_fnuz from .multi_tensor_apply import multi_tensor_applier +# pylint: disable=wrong-import-order,ungrouped-imports +from torch.utils.cpp_extension import IS_HIP_EXTENSION +from transformer_engine.pytorch.utils import is_fp8_fnuz +# pylint: enable=wrong-import-order,ungrouped-imports def get_fp8_meta(fp8_tensor): diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index cf9bdc6a5..f66b9e211 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -7,11 +7,11 @@ """Pure Python base classes for quantization.""" from __future__ import annotations - +import os +from typing import Optional, Tuple, Iterable, Any, Dict, Union import abc -import math import warnings -from typing import Any, Dict, Iterable, Optional, Tuple, Union +import math import torch from torch.utils._pytree import tree_map diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index c76f8d32c..a80dcacb8 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -130,7 +130,8 @@ def run(self): return super().run() if FORCE_BUILD: - return super().run() + super().run() + return wheel_url, wheel_filename = get_wheel_url() print("Guessing wheel URL: ", wheel_url) @@ -149,11 +150,12 @@ def run(self): wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") print("Raw wheel path", wheel_path) os.rename(wheel_filename, wheel_path) - return None + return except (urllib.error.HTTPError, urllib.error.URLError): print("Precompiled wheel not found. 
Building from source...") # If the wheel could not be downloaded, build from source - return super().run() + super().run() + return if __name__ == "__main__": diff --git a/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py b/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py index b17ce4ce6..3974cc476 100644 --- a/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py +++ b/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py @@ -28,7 +28,7 @@ class FSDPAGTensor(torch.Tensor): - """Tensor subclass carrying FSDP metadata for quantized all-gather.""" + """A wrapper subclass for stateful FSDP transport""" @staticmethod def __new__(cls, elem: torch.Tensor, **kwargs): diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index aa5c71f33..a57a92d94 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -7,11 +7,10 @@ """Mixin class holding data specific for MXFP8Tensor""" from __future__ import annotations +import os from typing import Optional, Dict, Any, Tuple from collections.abc import Iterable import math -import os - import torch from torch.utils.cpp_extension import IS_HIP_EXTENSION diff --git a/transformer_engine/pytorch/triton_kernels/grouped_gemm.py b/transformer_engine/pytorch/triton_kernels/grouped_gemm.py index cef3ec9c3..c47ea23ba 100644 --- a/transformer_engine/pytorch/triton_kernels/grouped_gemm.py +++ b/transformer_engine/pytorch/triton_kernels/grouped_gemm.py @@ -11,7 +11,6 @@ from .gmm.gmm_wrapper import gmm, ptgmm - def general_grouped_gemm_triton( # pylint: disable=unused-argument A: List[torch.Tensor], B: List[torch.Tensor], diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 1d2463011..ca5073f94 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -484,7 +484,6 @@ def is_fp8_fnuz() -> bool: """True when using FP8 FNUZ dtypes (ROCm FP8 path).""" return IS_HIP_EXTENSION and get_device_compute_capability() == (9, 4) - get_torch_float8_e4m3_type = lambda: torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn get_torch_float8_e5m2_type = lambda: torch.float8_e5m2fnuz if is_fp8_fnuz() else torch.float8_e5m2