diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index e72cad9db1..a2863cba98 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -8,7 +8,6 @@ import pytest import torch from torch import nn -from torch.testing._internal.common_device_type import largeTensorTest import transformer_engine.pytorch as te from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8BlockScaling from transformer_engine.pytorch import MultiheadAttention, quantized_model_init, is_bf16_available @@ -1053,8 +1052,13 @@ def test_native(self): self.model_.load_state_dict(copy.deepcopy(self.model.state_dict())) - @largeTensorTest("60GB", "cuda") def test_large_tensor(self): + import gc + + gc.collect() + torch.cuda.empty_cache() + if torch.cuda.memory.mem_get_info()[0] < 60 * 1024**3: + pytest.skip("Insufficient available memory") t = torch.zeros(2359332864, dtype=torch.half, device="cuda") t2 = torch.zeros(2359332864, dtype=torch.half, device="cuda") grad = torch.randn_like(t) diff --git a/tests/pytorch/test_quantized_tensor.py b/tests/pytorch/test_quantized_tensor.py index 23ce93319b..526045e43e 100644 --- a/tests/pytorch/test_quantized_tensor.py +++ b/tests/pytorch/test_quantized_tensor.py @@ -28,6 +28,7 @@ import transformer_engine_torch as tex from references.ref_per_tensor_cs import ref_per_tensor_cs_cast +from utils import assert_close, quantization_tols # PyTorch tensor dtypes _dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16] @@ -702,6 +703,63 @@ def test_shape_with_none_data( f"after setting data to None on {type(x_test).__name__}" ) + @pytest.mark.parametrize( + "quantization", + _quantization_list + (["nvfp4_2d"] if nvfp4_available else []), + ) + def test_update_nd_tensor( + self, + *, + quantization: str, + shape: Iterable[int] = (32, 4, 128), + dtype: torch.dtype = torch.bfloat16, + device: str = "cuda", + ) -> None: + """Check that an N-D quantized tensor can be updated.""" + + # Construct quantizer + if quantization == "fp8": + quantizer = Float8Quantizer( + scale=torch.ones(1, dtype=torch.float32, device=device).squeeze(), + amax=torch.zeros(1, dtype=torch.float32, device=device), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + elif quantization == "fp8_blockwise": + quantizer = Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, + force_pow_2_scales=True, + amax_epsilon=0.0, + block_scaling_dim=1, + ) + elif quantization == "mxfp8": + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + elif quantization in ("nvfp4", "nvfp4_2d"): + quantizer = NVFP4Quantizer( + rowwise=True, + columnwise=True, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=(quantization == "nvfp4_2d"), + ) + quantization = "nvfp4" + else: + raise ValueError(f"Unknown quantization: {quantization}") + + # Construct quantized tensor + x = torch.randn(list(shape), dtype=dtype, device=device) + q_x = quantizer(x) + + # Update tensor + x_new = torch.randn(list(shape), dtype=dtype, device=device) + q_x.copy_(x_new) + + # Check results + assert q_x.shape == torch.Size(shape) + tols = quantization_tols(quantization) + assert_close(q_x, x_new, **tols) + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) class TestMXFP8Tensor: diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 32e44be2af..2ee18aaf57 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -113,6 +113,7 @@ def quantization_tols(name: str) -> dict[str, float]: "fp8", "fp8_delayed_scaling", "fp8_current_scaling", + "fp8_blockwise", "mxfp8", "mxfp8_block_scaling", ): diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index b06f6f5619..66bb2dc40e 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -26,6 +26,20 @@ std::vector convert_shape_back_from_fp4(const std::vector& shape return ret; } +std::array get_2d_dims(NVTEShape shape, bool transpose) { + if (!transpose) { + size_t flat_first = 1; + for (size_t i = 0; i + 1 < shape.ndim; ++i) flat_first *= shape.data[i]; + const size_t flat_last = shape.ndim > 0 ? shape.data[shape.ndim - 1] : 1; + return {flat_first, flat_last}; + } else { + const size_t flat_first = shape.ndim > 0 ? shape.data[0] : 1; + size_t flat_last = 1; + for (size_t i = 1; i < shape.ndim; ++i) flat_last *= shape.data[i]; + return {flat_first, flat_last}; + } +} + std::vector getTensorShape(const at::Tensor& t) { std::vector shape; for (auto s : t.sizes()) { diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 8f5b8294e8..35a459351b 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -45,6 +45,7 @@ #include #include +#include #include #include #include @@ -523,6 +524,21 @@ NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose); +// Flatten an N-D shape to 2D: {product(shape[:-1]), shape[-1]}. +// With transpose=true: {shape[0], product(shape[1:])}. +std::array get_2d_dims(NVTEShape shape, bool transpose = false); + +template +inline std::array get_2d_dims(const std::vector& shape, bool transpose = false) { + NVTEShape s{}; + s.ndim = shape.size(); + constexpr size_t max_ndim = sizeof(s.data) / sizeof(size_t); + NVTE_CHECK(s.ndim <= max_ndim, "Shape has too many dimensions (got ", s.ndim, ", max ", max_ndim, + ")."); + for (size_t i = 0; i < shape.size(); ++i) s.data[i] = static_cast(shape[i]); + return get_2d_dims(s, transpose); +} + // unpack the PhiloxCudaState into CUDA tensor void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 9e1f381bfe..00f4383ab6 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1243,14 +1243,8 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, // Create a wrapper for the columnwise output, as the rowwise output. Input is in transposed layout. TensorWrapper out_transpose(output_list[i].scaling_mode()); if (!is_empty_split) { - auto colwise_data_shape = out_columnwise_data.shape; - std::vector colwise_data_shape_2d; - colwise_data_shape_2d.push_back(colwise_data_shape.data[0]); - size_t last_dim = 1; - for (size_t j = 1; j < colwise_data_shape.ndim; ++j) { - last_dim *= colwise_data_shape.data[j]; - } - colwise_data_shape_2d.push_back(last_dim); + auto [cw_first, cw_last] = get_2d_dims(out_columnwise_data.shape, true); + std::vector colwise_data_shape_2d = {cw_first, cw_last}; out_transpose.set_rowwise_data(out_columnwise_data.data_ptr, static_cast(out_columnwise_data.dtype), diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 427eb7934e..9cb1fb7f54 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -41,10 +41,8 @@ bool is_low_precision(const DType type) { std::vector getGemmOutputShape(const NVTEShape& A_shape, const bool transa, const NVTEShape& B_shape, const bool transb) { // Flatten outer dims to get 2D matrices - const size_t A0 = A_shape.ndim > 0 ? product(A_shape, 0, A_shape.ndim - 1) : 1; - const size_t A1 = A_shape.ndim > 0 ? A_shape.data[A_shape.ndim - 1] : 1; - const size_t B0 = B_shape.ndim > 0 ? product(B_shape, 0, B_shape.ndim - 1) : 1; - const size_t B1 = B_shape.ndim > 0 ? B_shape.data[B_shape.ndim - 1] : 1; + const auto [A0, A1] = get_2d_dims(A_shape); + const auto [B0, B1] = get_2d_dims(B_shape); // Check matrix dims NVTE_CHECK((transa ? A1 : A0) == (transb ? B0 : B1), "Invalid matrix dimensions for GEMM (A=(", diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 4887b59c28..c3dec944e4 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -80,8 +80,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Tensor dimensions const auto shape = nvte_shape_to_vector(input_nvte.shape()); - const auto outer_size = product(shape) / shape.back(); - const auto inner_size = shape.back(); + const auto [outer_size, inner_size] = get_2d_dims(shape); // Tensors to save for backward pass at::Tensor mu_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); @@ -320,8 +319,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Tensor dimensions const auto shape = nvte_shape_to_vector(input_nvte.shape()); - const auto outer_size = product(shape) / shape.back(); - const auto inner_size = shape.back(); + const auto [outer_size, inner_size] = get_2d_dims(shape); // Tensors to save for backward pass at::Tensor rsigma_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp index 7f7f8a4351..193aed29e6 100644 --- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp +++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp @@ -301,22 +301,8 @@ at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapp "Input tensor must be a block scaling tensor"); // Get tensor data - NVTEBasicTensor data; - size_t data_flat_first_dim = 1; - size_t data_flat_last_dim = 1; - if (rowwise) { - data = input.get_rowwise_data(); - for (size_t i = 0; i < data.shape.ndim - 1; ++i) { - data_flat_first_dim *= data.shape.data[i]; - } - data_flat_last_dim = data.shape.data[data.shape.ndim - 1]; - } else { - data = input.get_columnwise_data(); - data_flat_first_dim = data.shape.data[0]; - for (size_t i = 1; i < data.shape.ndim; ++i) { - data_flat_last_dim *= data.shape.data[i]; - } - } + NVTEBasicTensor data = rowwise ? input.get_rowwise_data() : input.get_columnwise_data(); + const auto [data_flat_first_dim, data_flat_last_dim] = get_2d_dims(data.shape, !rowwise); NVTEShape data_shape{}; data_shape.data[0] = data_flat_first_dim; data_shape.data[1] = data_flat_last_dim; diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index aaa27a104a..0318978195 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -29,8 +29,7 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional 0 ? product(shape) / shape.back() : 1; - const size_t N = shape.size() > 0 ? shape.back() : 1; + const auto [M, N] = get_2d_dims(shape); // Output tensor at::Tensor out; diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 8f2de325ae..2b29f260e7 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1342,13 +1342,7 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve // Tensor dimensions const std::vector shape_int64(shape.begin(), shape.end()); - size_t flat_first_dim = 1; - if (shape.size() > 0) { - for (size_t i = 0; i < shape.size() - 1; ++i) { - flat_first_dim *= shape[i]; - } - } - const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; + const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0, "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, " (got shape=", shape, ")"); @@ -1736,13 +1730,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve // Tensor dimensions const std::vector shape_int64(shape.begin(), shape.end()); - size_t flat_first_dim = 1; - if (shape.size() > 0) { - for (size_t i = 0; i < shape.size() - 1; ++i) { - flat_first_dim *= shape[i]; - } - } - const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; + const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, "First dim for NVFP4 must be divisible by ", NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0, @@ -2040,24 +2028,18 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( // Tensor dimensions, shape means original shape std::vector shape; - if (columnwise_data) { - shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); - if (rowwise_data) { - auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); - NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape, - ") and column-wise data (shape=", shape, ") do not match"); - } - } else { // Already checked columnwise_data_tensor == true + if (rowwise_data) { shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); - } - - size_t flat_first_dim = 1; - if (shape.size() > 0) { - for (size_t i = 0; i < shape.size() - 1; ++i) { - flat_first_dim *= shape[i]; + if (columnwise_data) { + auto col_shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); + NVTE_CHECK(get_2d_dims(shape) == get_2d_dims(col_shape), "NVFP4 row-wise data (shape=", shape, + ") and column-wise data (shape=", col_shape, ") do not match"); } + } else { + shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); } - const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; + + const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); @@ -2205,24 +2187,16 @@ void NVFP4Quantizer::quantize_with_rht_unfused_helper( // NOTE: should already be populated. auto out_columnwise_amax = out.get_columnwise_amax(); + // Flatten column-wise data shape to 2D to avoid problems when + // converting between FP4 tensor shape and byte tensor shape + // (involves dividing last dim by 2). + auto [flat_first_dim, flat_last_dim] = get_2d_dims(out_columnwise_data.shape, true); + std::vector colwise_data_shape_2d = {flat_first_dim, flat_last_dim}; + // Create a wrapper for the columnwise output, as the rowwise output. // The reason is due to the input `rht_output_t` is already in the transposed layout. // Thus, we only need a rowwise quantization to generate the columnwise output. TensorWrapper out_transpose(out.scaling_mode()); - // Note: since we are faking columnwise tensor into rowwise, the flat first dim check will fail - // need to convert the shape to 2D here - auto colwise_data_shape = out_columnwise_data.shape; - std::vector colwise_data_shape_2d; - // shape could be [512, 32, 64], that's actually 512, 32, 128 because 2 FP4 take 1 byte - // the 2D shape should be [512, 32*128], but columnwise data shape expect last dim to be halved again - // so the multiple 2 get cancelled out - colwise_data_shape_2d.push_back(colwise_data_shape.data[0]); - size_t last_dim = 1; - for (size_t i = 1; i < colwise_data_shape.ndim; ++i) { - last_dim *= colwise_data_shape.data[i]; - } - colwise_data_shape_2d.push_back(last_dim); - out_transpose.set_rowwise_data(out_columnwise_data.data_ptr, static_cast(out_columnwise_data.dtype), colwise_data_shape_2d); @@ -2234,7 +2208,6 @@ void NVFP4Quantizer::quantize_with_rht_unfused_helper( out_columnwise_amax.shape); // Invoking fallback RHT kernel unfused. - NVTE_SCOPED_GIL_RELEASE({ // Perform the RHT(input.t), and write to rht_output_cpp.columnwise. nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0, @@ -2483,13 +2456,7 @@ void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out std::vector NVFP4Quantizer::get_scale_shape(const std::vector& shape, bool columnwise) const { - size_t numel = 1; - for (auto s : shape) { - numel *= s; - } - - auto last_dim = shape.back(); - auto flat_first_dim = numel / last_dim; + const auto [flat_first_dim, last_dim] = get_2d_dims(shape); NVTE_CHECK(last_dim % NVFP4_BLOCK_SIZE == 0, "Last dim for NVFP4 must be divisible by ", NVFP4_BLOCK_SIZE, " (got dim=", last_dim, ")");