From 56ff4c6331d0cd8dd1ea86c9dff38ea06b6599ed Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Fri, 8 May 2026 17:16:26 -0700 Subject: [PATCH 1/2] [PyTorch] Remove internal PyTorch testing helper (#2969) * Remove internal PyTorch testing helper Signed-off-by: Tim Moon * Review suggestion from @greptile-apps Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_fused_optimizer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index e72cad9db1..a2863cba98 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -8,7 +8,6 @@ import pytest import torch from torch import nn -from torch.testing._internal.common_device_type import largeTensorTest import transformer_engine.pytorch as te from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8BlockScaling from transformer_engine.pytorch import MultiheadAttention, quantized_model_init, is_bf16_available @@ -1053,8 +1052,13 @@ def test_native(self): self.model_.load_state_dict(copy.deepcopy(self.model.state_dict())) - @largeTensorTest("60GB", "cuda") def test_large_tensor(self): + import gc + + gc.collect() + torch.cuda.empty_cache() + if torch.cuda.memory.mem_get_info()[0] < 60 * 1024**3: + pytest.skip("Insufficient available memory") t = torch.zeros(2359332864, dtype=torch.half, device="cuda") t2 = 
torch.zeros(2359332864, dtype=torch.half, device="cuda") grad = torch.randn_like(t) From 0e289534985c95192d2e48e6d2447f7c53feff2f Mon Sep 17 00:00:00 2001 From: Zhang Haitao Date: Sat, 9 May 2026 08:44:08 +0800 Subject: [PATCH 2/2] Fix nvfp4 convert_and_update_tensor shape check (#2670) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix nvfp4 convert_and_update_tensor shape check Signed-off-by: 乙划 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add headers and check 2D shapes Signed-off-by: 乙划 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Przemyslaw Tredak * add unittest and docstring Signed-off-by: 乙划 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [PyTorch] Fix NVFP4 shape check for N-D tensors in convert_and_update_tensor Introduce get_2d_dims() in common.h/cpp to flatten an N-D shape to 2D dims (flat_first, flat_last), replacing the ad-hoc compressShapeTo2D helper from the contributor PR. The helper takes NVTEShape as its core argument (stack-allocated) with a header-only vector overload, and supports a transpose flag for the shape[1:] flattening direction. Use get_2d_dims in NVFP4Quantizer::convert_and_update_tensor to compare row-wise and column-wise shapes under 2D equivalence — fixing a false mismatch when the logical shape is 3D (columnwise data is always stored 2D). Also restructure the if-block to treat row-wise data as the ground truth when present. 
Fixes #2607 Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Tim Moon * [PyTorch] Add test for updating N-D quantized tensors via copy_ Replaces the NVFP4-only test_nvfp4_3d_shape_quantization in test_nvfp4_quantize_exact.py with a broader test_update_nd_tensor in TestQuantizedTensor that covers all quantization formats. The test constructs an N-D quantized tensor, updates it with copy_, and checks both shape preservation and numerical accuracy. The "nvfp4_2d" variant is appended to the parametrize list inline to cover both NVFP4 quantization modes without affecting the shared _quantization_list. Also adds "fp8_blockwise" to quantization_tols in utils.py. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Tim Moon * [PyTorch] Propagate get_2d_dims helper across C++ extensions Replace ad-hoc loops computing flat_first_dim/flat_last_dim and equivalent product(shape)/shape.back() patterns in quantizer.cpp, cast.cpp, gemm.cpp, normalization.cpp, swizzle.cpp, and transpose.cpp. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Tim Moon * [PyTorch] Drop intermediate variable in cast.cpp get_2d_dims call Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Check that tensor shape is not too large Suggestion from @greptile-apps. 
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: 乙划 Signed-off-by: Przemyslaw Tredak Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: 乙划 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Co-authored-by: Tim Moon Co-authored-by: Claude Sonnet 4.6 Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/test_quantized_tensor.py | 58 ++++++++++++++++ tests/pytorch/utils.py | 1 + transformer_engine/pytorch/csrc/common.cpp | 14 ++++ transformer_engine/pytorch/csrc/common.h | 16 +++++ .../pytorch/csrc/extensions/cast.cpp | 10 +-- .../pytorch/csrc/extensions/gemm.cpp | 6 +- .../pytorch/csrc/extensions/normalization.cpp | 6 +- .../pytorch/csrc/extensions/swizzle.cpp | 18 +---- .../pytorch/csrc/extensions/transpose.cpp | 3 +- transformer_engine/pytorch/csrc/quantizer.cpp | 69 +++++-------------- 10 files changed, 116 insertions(+), 85 deletions(-) diff --git a/tests/pytorch/test_quantized_tensor.py b/tests/pytorch/test_quantized_tensor.py index 23ce93319b..526045e43e 100644 --- a/tests/pytorch/test_quantized_tensor.py +++ b/tests/pytorch/test_quantized_tensor.py @@ -28,6 +28,7 @@ import transformer_engine_torch as tex from references.ref_per_tensor_cs import ref_per_tensor_cs_cast +from utils import assert_close, quantization_tols # PyTorch tensor dtypes _dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16] @@ -702,6 +703,63 @@ def test_shape_with_none_data( f"after setting data to None on {type(x_test).__name__}" ) + @pytest.mark.parametrize( + "quantization", + _quantization_list + (["nvfp4_2d"] if nvfp4_available else []), + ) + def 
test_update_nd_tensor( + self, + *, + quantization: str, + shape: Iterable[int] = (32, 4, 128), + dtype: torch.dtype = torch.bfloat16, + device: str = "cuda", + ) -> None: + """Check that an N-D quantized tensor can be updated.""" + + # Construct quantizer + if quantization == "fp8": + quantizer = Float8Quantizer( + scale=torch.ones(1, dtype=torch.float32, device=device).squeeze(), + amax=torch.zeros(1, dtype=torch.float32, device=device), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + elif quantization == "fp8_blockwise": + quantizer = Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, + force_pow_2_scales=True, + amax_epsilon=0.0, + block_scaling_dim=1, + ) + elif quantization == "mxfp8": + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + elif quantization in ("nvfp4", "nvfp4_2d"): + quantizer = NVFP4Quantizer( + rowwise=True, + columnwise=True, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=(quantization == "nvfp4_2d"), + ) + quantization = "nvfp4" + else: + raise ValueError(f"Unknown quantization: {quantization}") + + # Construct quantized tensor + x = torch.randn(list(shape), dtype=dtype, device=device) + q_x = quantizer(x) + + # Update tensor + x_new = torch.randn(list(shape), dtype=dtype, device=device) + q_x.copy_(x_new) + + # Check results + assert q_x.shape == torch.Size(shape) + tols = quantization_tols(quantization) + assert_close(q_x, x_new, **tols) + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) class TestMXFP8Tensor: diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 32e44be2af..2ee18aaf57 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -113,6 +113,7 @@ def quantization_tols(name: str) -> dict[str, float]: "fp8", "fp8_delayed_scaling", "fp8_current_scaling", + "fp8_blockwise", "mxfp8", "mxfp8_block_scaling", ): diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 
b06f6f5619..66bb2dc40e 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -26,6 +26,20 @@ std::vector convert_shape_back_from_fp4(const std::vector& shape return ret; } +std::array get_2d_dims(NVTEShape shape, bool transpose) { + if (!transpose) { + size_t flat_first = 1; + for (size_t i = 0; i + 1 < shape.ndim; ++i) flat_first *= shape.data[i]; + const size_t flat_last = shape.ndim > 0 ? shape.data[shape.ndim - 1] : 1; + return {flat_first, flat_last}; + } else { + const size_t flat_first = shape.ndim > 0 ? shape.data[0] : 1; + size_t flat_last = 1; + for (size_t i = 1; i < shape.ndim; ++i) flat_last *= shape.data[i]; + return {flat_first, flat_last}; + } +} + std::vector getTensorShape(const at::Tensor& t) { std::vector shape; for (auto s : t.sizes()) { diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 8f5b8294e8..35a459351b 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -45,6 +45,7 @@ #include #include +#include #include #include #include @@ -523,6 +524,21 @@ NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose); +// Flatten an N-D shape to 2D: {product(shape[:-1]), shape[-1]}. +// With transpose=true: {shape[0], product(shape[1:])}. 
+std::array get_2d_dims(NVTEShape shape, bool transpose = false); + +template +inline std::array get_2d_dims(const std::vector& shape, bool transpose = false) { + NVTEShape s{}; + s.ndim = shape.size(); + constexpr size_t max_ndim = sizeof(s.data) / sizeof(size_t); + NVTE_CHECK(s.ndim <= max_ndim, "Shape has too many dimensions (got ", s.ndim, ", max ", max_ndim, + ")."); + for (size_t i = 0; i < shape.size(); ++i) s.data[i] = static_cast(shape[i]); + return get_2d_dims(s, transpose); +} + // unpack the PhiloxCudaState into CUDA tensor void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 9e1f381bfe..00f4383ab6 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1243,14 +1243,8 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, // Create a wrapper for the columnwise output, as the rowwise output. Input is in transposed layout. 
TensorWrapper out_transpose(output_list[i].scaling_mode()); if (!is_empty_split) { - auto colwise_data_shape = out_columnwise_data.shape; - std::vector colwise_data_shape_2d; - colwise_data_shape_2d.push_back(colwise_data_shape.data[0]); - size_t last_dim = 1; - for (size_t j = 1; j < colwise_data_shape.ndim; ++j) { - last_dim *= colwise_data_shape.data[j]; - } - colwise_data_shape_2d.push_back(last_dim); + auto [cw_first, cw_last] = get_2d_dims(out_columnwise_data.shape, true); + std::vector colwise_data_shape_2d = {cw_first, cw_last}; out_transpose.set_rowwise_data(out_columnwise_data.data_ptr, static_cast(out_columnwise_data.dtype), diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 427eb7934e..9cb1fb7f54 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -41,10 +41,8 @@ bool is_low_precision(const DType type) { std::vector getGemmOutputShape(const NVTEShape& A_shape, const bool transa, const NVTEShape& B_shape, const bool transb) { // Flatten outer dims to get 2D matrices - const size_t A0 = A_shape.ndim > 0 ? product(A_shape, 0, A_shape.ndim - 1) : 1; - const size_t A1 = A_shape.ndim > 0 ? A_shape.data[A_shape.ndim - 1] : 1; - const size_t B0 = B_shape.ndim > 0 ? product(B_shape, 0, B_shape.ndim - 1) : 1; - const size_t B1 = B_shape.ndim > 0 ? B_shape.data[B_shape.ndim - 1] : 1; + const auto [A0, A1] = get_2d_dims(A_shape); + const auto [B0, B1] = get_2d_dims(B_shape); // Check matrix dims NVTE_CHECK((transa ? A1 : A0) == (transb ? 
B0 : B1), "Invalid matrix dimensions for GEMM (A=(", diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 4887b59c28..c3dec944e4 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -80,8 +80,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Tensor dimensions const auto shape = nvte_shape_to_vector(input_nvte.shape()); - const auto outer_size = product(shape) / shape.back(); - const auto inner_size = shape.back(); + const auto [outer_size, inner_size] = get_2d_dims(shape); // Tensors to save for backward pass at::Tensor mu_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); @@ -320,8 +319,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Tensor dimensions const auto shape = nvte_shape_to_vector(input_nvte.shape()); - const auto outer_size = product(shape) / shape.back(); - const auto inner_size = shape.back(); + const auto [outer_size, inner_size] = get_2d_dims(shape); // Tensors to save for backward pass at::Tensor rsigma_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp index 7f7f8a4351..193aed29e6 100644 --- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp +++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp @@ -301,22 +301,8 @@ at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapp "Input tensor must be a block scaling tensor"); // Get tensor data - NVTEBasicTensor data; - size_t data_flat_first_dim = 1; - size_t data_flat_last_dim = 1; - if (rowwise) { - data = input.get_rowwise_data(); - for (size_t i = 0; i < data.shape.ndim - 1; ++i) { - data_flat_first_dim *= data.shape.data[i]; - } - data_flat_last_dim = 
data.shape.data[data.shape.ndim - 1]; - } else { - data = input.get_columnwise_data(); - data_flat_first_dim = data.shape.data[0]; - for (size_t i = 1; i < data.shape.ndim; ++i) { - data_flat_last_dim *= data.shape.data[i]; - } - } + NVTEBasicTensor data = rowwise ? input.get_rowwise_data() : input.get_columnwise_data(); + const auto [data_flat_first_dim, data_flat_last_dim] = get_2d_dims(data.shape, !rowwise); NVTEShape data_shape{}; data_shape.data[0] = data_flat_first_dim; data_shape.data[1] = data_flat_last_dim; diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index aaa27a104a..0318978195 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -29,8 +29,7 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional 0 ? product(shape) / shape.back() : 1; - const size_t N = shape.size() > 0 ? shape.back() : 1; + const auto [M, N] = get_2d_dims(shape); // Output tensor at::Tensor out; diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 8f2de325ae..2b29f260e7 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1342,13 +1342,7 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve // Tensor dimensions const std::vector shape_int64(shape.begin(), shape.end()); - size_t flat_first_dim = 1; - if (shape.size() > 0) { - for (size_t i = 0; i < shape.size() - 1; ++i) { - flat_first_dim *= shape[i]; - } - } - const size_t flat_last_dim = shape.size() > 0 ? 
shape.back() : 1; + const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0, "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, " (got shape=", shape, ")"); @@ -1736,13 +1730,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve // Tensor dimensions const std::vector shape_int64(shape.begin(), shape.end()); - size_t flat_first_dim = 1; - if (shape.size() > 0) { - for (size_t i = 0; i < shape.size() - 1; ++i) { - flat_first_dim *= shape[i]; - } - } - const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; + const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, "First dim for NVFP4 must be divisible by ", NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0, @@ -2040,24 +2028,18 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( // Tensor dimensions, shape means original shape std::vector shape; - if (columnwise_data) { - shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); - if (rowwise_data) { - auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); - NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape, - ") and column-wise data (shape=", shape, ") do not match"); - } - } else { // Already checked columnwise_data_tensor == true + if (rowwise_data) { shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); - } - - size_t flat_first_dim = 1; - if (shape.size() > 0) { - for (size_t i = 0; i < shape.size() - 1; ++i) { - flat_first_dim *= shape[i]; + if (columnwise_data) { + auto col_shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); + NVTE_CHECK(get_2d_dims(shape) == get_2d_dims(col_shape), "NVFP4 row-wise data (shape=", shape, + ") and column-wise data (shape=", col_shape, ") do not match"); } + } 
else { + shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); } - const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; + + const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); @@ -2205,24 +2187,16 @@ void NVFP4Quantizer::quantize_with_rht_unfused_helper( // NOTE: should already be populated. auto out_columnwise_amax = out.get_columnwise_amax(); + // Flatten column-wise data shape to 2D to avoid problems when + // converting between FP4 tensor shape and byte tensor shape + // (involves dividing last dim by 2). + auto [flat_first_dim, flat_last_dim] = get_2d_dims(out_columnwise_data.shape, true); + std::vector colwise_data_shape_2d = {flat_first_dim, flat_last_dim}; + // Create a wrapper for the columnwise output, as the rowwise output. // The reason is due to the input `rht_output_t` is already in the transposed layout. // Thus, we only need a rowwise quantization to generate the columnwise output. 
TensorWrapper out_transpose(out.scaling_mode()); - // Note: since we are faking columnwise tensor into rowwise, the flat first dim check will fail - // need to convert the shape to 2D here - auto colwise_data_shape = out_columnwise_data.shape; - std::vector colwise_data_shape_2d; - // shape could be [512, 32, 64], that's actually 512, 32, 128 because 2 FP4 take 1 byte - // the 2D shape should be [512, 32*128], but columnwise data shape expect last dim to be halved again - // so the multiple 2 get cancelled out - colwise_data_shape_2d.push_back(colwise_data_shape.data[0]); - size_t last_dim = 1; - for (size_t i = 1; i < colwise_data_shape.ndim; ++i) { - last_dim *= colwise_data_shape.data[i]; - } - colwise_data_shape_2d.push_back(last_dim); - out_transpose.set_rowwise_data(out_columnwise_data.data_ptr, static_cast(out_columnwise_data.dtype), colwise_data_shape_2d); @@ -2234,7 +2208,6 @@ void NVFP4Quantizer::quantize_with_rht_unfused_helper( out_columnwise_amax.shape); // Invoking fallback RHT kernel unfused. - NVTE_SCOPED_GIL_RELEASE({ // Perform the RHT(input.t), and write to rht_output_cpp.columnwise. nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0, @@ -2483,13 +2456,7 @@ void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out std::vector NVFP4Quantizer::get_scale_shape(const std::vector& shape, bool columnwise) const { - size_t numel = 1; - for (auto s : shape) { - numel *= s; - } - - auto last_dim = shape.back(); - auto flat_first_dim = numel / last_dim; + const auto [flat_first_dim, last_dim] = get_2d_dims(shape); NVTE_CHECK(last_dim % NVFP4_BLOCK_SIZE == 0, "Last dim for NVFP4 must be divisible by ", NVFP4_BLOCK_SIZE, " (got dim=", last_dim, ")");