Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions tests/pytorch/test_fused_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import pytest
import torch
from torch import nn
from torch.testing._internal.common_device_type import largeTensorTest
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8BlockScaling
from transformer_engine.pytorch import MultiheadAttention, quantized_model_init, is_bf16_available
Expand Down Expand Up @@ -1053,8 +1052,13 @@ def test_native(self):

self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))

@largeTensorTest("60GB", "cuda")
def test_large_tensor(self):
import gc

gc.collect()
torch.cuda.empty_cache()
if torch.cuda.memory.mem_get_info()[0] < 60 * 1024**3:
pytest.skip("Insufficient available memory")
t = torch.zeros(2359332864, dtype=torch.half, device="cuda")
t2 = torch.zeros(2359332864, dtype=torch.half, device="cuda")
grad = torch.randn_like(t)
Expand Down
58 changes: 58 additions & 0 deletions tests/pytorch/test_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import transformer_engine_torch as tex

from references.ref_per_tensor_cs import ref_per_tensor_cs_cast
from utils import assert_close, quantization_tols

# PyTorch tensor dtypes
_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16]
Expand Down Expand Up @@ -702,6 +703,63 @@ def test_shape_with_none_data(
f"after setting data to None on {type(x_test).__name__}"
)

@pytest.mark.parametrize(
    "quantization",
    _quantization_list + (["nvfp4_2d"] if nvfp4_available else []),
)
def test_update_nd_tensor(
    self,
    *,
    quantization: str,
    shape: Iterable[int] = (32, 4, 128),
    dtype: torch.dtype = torch.bfloat16,
    device: str = "cuda",
) -> None:
    """Copy new values into an N-D quantized tensor and check the round-trip."""

    def _build_quantizer(quantization: str):
        # Construct the quantizer for the requested quantization scheme.
        if quantization == "fp8":
            return Float8Quantizer(
                scale=torch.ones(1, dtype=torch.float32, device=device).squeeze(),
                amax=torch.zeros(1, dtype=torch.float32, device=device),
                fp8_dtype=tex.DType.kFloat8E4M3,
            )
        if quantization == "fp8_blockwise":
            return Float8BlockQuantizer(
                fp8_dtype=tex.DType.kFloat8E4M3,
                rowwise=True,
                columnwise=True,
                force_pow_2_scales=True,
                amax_epsilon=0.0,
                block_scaling_dim=1,
            )
        if quantization == "mxfp8":
            return MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
        if quantization in ("nvfp4", "nvfp4_2d"):
            return NVFP4Quantizer(
                rowwise=True,
                columnwise=True,
                with_rht=False,
                with_post_rht_amax=False,
                with_2d_quantization=(quantization == "nvfp4_2d"),
            )
        raise ValueError(f"Unknown quantization: {quantization}")

    quantizer = _build_quantizer(quantization)
    # Both NVFP4 variants share the same tolerance entry.
    tols_name = "nvfp4" if quantization in ("nvfp4", "nvfp4_2d") else quantization

    # Quantize an initial random N-D tensor.
    reference = torch.randn(list(shape), dtype=dtype, device=device)
    quantized = quantizer(reference)

    # Overwrite the quantized tensor in place with fresh values.
    updated = torch.randn(list(shape), dtype=dtype, device=device)
    quantized.copy_(updated)

    # Shape must be preserved and values must match within quantization error.
    assert quantized.shape == torch.Size(shape)
    assert_close(quantized, updated, **quantization_tols(tols_name))


@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
class TestMXFP8Tensor:
Expand Down
1 change: 1 addition & 0 deletions tests/pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def quantization_tols(name: str) -> dict[str, float]:
"fp8",
"fp8_delayed_scaling",
"fp8_current_scaling",
"fp8_blockwise",
"mxfp8",
"mxfp8_block_scaling",
):
Expand Down
14 changes: 14 additions & 0 deletions transformer_engine/pytorch/csrc/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,20 @@ std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape
return ret;
}

std::array<size_t, 2> get_2d_dims(NVTEShape shape, bool transpose) {
  // Flatten an N-D shape to 2D. Default: {product(shape[:-1]), shape[-1]}.
  // With transpose=true: {shape[0], product(shape[1:])}.
  // A 0-D shape flattens to {1, 1} in either mode.
  size_t first = 1;
  size_t last = 1;
  if (shape.ndim > 0) {
    if (transpose) {
      // Keep the leading dim, collapse everything after it.
      first = shape.data[0];
      for (size_t i = 1; i < shape.ndim; ++i) last *= shape.data[i];
    } else {
      // Keep the trailing dim, collapse everything before it.
      last = shape.data[shape.ndim - 1];
      for (size_t i = 0; i + 1 < shape.ndim; ++i) first *= shape.data[i];
    }
  }
  return {first, last};
}

std::vector<size_t> getTensorShape(const at::Tensor& t) {
std::vector<size_t> shape;
for (auto s : t.sizes()) {
Expand Down
16 changes: 16 additions & 0 deletions transformer_engine/pytorch/csrc/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include <transformer_engine/utils.h>

#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <array>
#include <cassert>
#include <cstring>
#include <iostream>
Expand Down Expand Up @@ -523,6 +524,21 @@ NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);

std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape, bool transpose);

// Flatten an N-D shape to 2D: {product(shape[:-1]), shape[-1]}.
// With transpose=true: {shape[0], product(shape[1:])}.
std::array<size_t, 2> get_2d_dims(NVTEShape shape, bool transpose = false);

// Overload accepting a std::vector shape: copies the dims into an NVTEShape
// and delegates to the NVTEShape overload above.
// Raises (via NVTE_CHECK) if the shape has more dims than NVTEShape can hold.
// Note: signed T values are static_cast to size_t; negative dims would wrap.
template <typename T>
inline std::array<size_t, 2> get_2d_dims(const std::vector<T>& shape, bool transpose = false) {
  NVTEShape s{};
  s.ndim = shape.size();
  // Capacity of the fixed-size buffer; use sizeof(s.data[0]) rather than
  // sizeof(size_t) so this stays correct if the element type ever changes.
  constexpr size_t max_ndim = sizeof(s.data) / sizeof(s.data[0]);
  NVTE_CHECK(s.ndim <= max_ndim, "Shape has too many dimensions (got ", s.ndim, ", max ", max_ndim,
             ").");
  for (size_t i = 0; i < shape.size(); ++i) s.data[i] = static_cast<size_t>(shape[i]);
  return get_2d_dims(s, transpose);
}

// unpack the PhiloxCudaState into CUDA tensor
void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr);

Expand Down
10 changes: 2 additions & 8 deletions transformer_engine/pytorch/csrc/extensions/cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1243,14 +1243,8 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input,
// Create a wrapper for the columnwise output, as the rowwise output. Input is in transposed layout.
TensorWrapper out_transpose(output_list[i].scaling_mode());
if (!is_empty_split) {
auto colwise_data_shape = out_columnwise_data.shape;
std::vector<size_t> colwise_data_shape_2d;
colwise_data_shape_2d.push_back(colwise_data_shape.data[0]);
size_t last_dim = 1;
for (size_t j = 1; j < colwise_data_shape.ndim; ++j) {
last_dim *= colwise_data_shape.data[j];
}
colwise_data_shape_2d.push_back(last_dim);
auto [cw_first, cw_last] = get_2d_dims(out_columnwise_data.shape, true);
std::vector<size_t> colwise_data_shape_2d = {cw_first, cw_last};

out_transpose.set_rowwise_data(out_columnwise_data.data_ptr,
static_cast<DType>(out_columnwise_data.dtype),
Expand Down
6 changes: 2 additions & 4 deletions transformer_engine/pytorch/csrc/extensions/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,8 @@ bool is_low_precision(const DType type) {
std::vector<size_t> getGemmOutputShape(const NVTEShape& A_shape, const bool transa,
const NVTEShape& B_shape, const bool transb) {
// Flatten outer dims to get 2D matrices
const size_t A0 = A_shape.ndim > 0 ? product(A_shape, 0, A_shape.ndim - 1) : 1;
const size_t A1 = A_shape.ndim > 0 ? A_shape.data[A_shape.ndim - 1] : 1;
const size_t B0 = B_shape.ndim > 0 ? product(B_shape, 0, B_shape.ndim - 1) : 1;
const size_t B1 = B_shape.ndim > 0 ? B_shape.data[B_shape.ndim - 1] : 1;
const auto [A0, A1] = get_2d_dims(A_shape);
const auto [B0, B1] = get_2d_dims(B_shape);

// Check matrix dims
NVTE_CHECK((transa ? A1 : A0) == (transb ? B0 : B1), "Invalid matrix dimensions for GEMM (A=(",
Expand Down
6 changes: 2 additions & 4 deletions transformer_engine/pytorch/csrc/extensions/normalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe

// Tensor dimensions
const auto shape = nvte_shape_to_vector(input_nvte.shape());
const auto outer_size = product(shape) / shape.back();
const auto inner_size = shape.back();
const auto [outer_size, inner_size] = get_2d_dims(shape);

// Tensors to save for backward pass
at::Tensor mu_py = at::empty({static_cast<int64_t>(outer_size)}, at::CUDA(at::kFloat));
Expand Down Expand Up @@ -320,8 +319,7 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w

// Tensor dimensions
const auto shape = nvte_shape_to_vector(input_nvte.shape());
const auto outer_size = product(shape) / shape.back();
const auto inner_size = shape.back();
const auto [outer_size, inner_size] = get_2d_dims(shape);

// Tensors to save for backward pass
at::Tensor rsigma_py = at::empty({static_cast<int64_t>(outer_size)}, at::CUDA(at::kFloat));
Expand Down
18 changes: 2 additions & 16 deletions transformer_engine/pytorch/csrc/extensions/swizzle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,22 +301,8 @@ at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapp
"Input tensor must be a block scaling tensor");

// Get tensor data
NVTEBasicTensor data;
size_t data_flat_first_dim = 1;
size_t data_flat_last_dim = 1;
if (rowwise) {
data = input.get_rowwise_data();
for (size_t i = 0; i < data.shape.ndim - 1; ++i) {
data_flat_first_dim *= data.shape.data[i];
}
data_flat_last_dim = data.shape.data[data.shape.ndim - 1];
} else {
data = input.get_columnwise_data();
data_flat_first_dim = data.shape.data[0];
for (size_t i = 1; i < data.shape.ndim; ++i) {
data_flat_last_dim *= data.shape.data[i];
}
}
NVTEBasicTensor data = rowwise ? input.get_rowwise_data() : input.get_columnwise_data();
const auto [data_flat_first_dim, data_flat_last_dim] = get_2d_dims(data.shape, !rowwise);
NVTEShape data_shape{};
data_shape.data[0] = data_flat_first_dim;
data_shape.data[1] = data_flat_last_dim;
Expand Down
3 changes: 1 addition & 2 deletions transformer_engine/pytorch/csrc/extensions/transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor
transpose_shape_int64.push_back(shape[i]);
}
}
const size_t M = shape.size() > 0 ? product(shape) / shape.back() : 1;
const size_t N = shape.size() > 0 ? shape.back() : 1;
const auto [M, N] = get_2d_dims(shape);

// Output tensor
at::Tensor out;
Expand Down
69 changes: 18 additions & 51 deletions transformer_engine/pytorch/csrc/quantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1342,13 +1342,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve

// Tensor dimensions
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
size_t flat_first_dim = 1;
if (shape.size() > 0) {
for (size_t i = 0; i < shape.size() - 1; ++i) {
flat_first_dim *= shape[i];
}
}
const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1;
const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape);
NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0,
"MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE,
" (got shape=", shape, ")");
Expand Down Expand Up @@ -1736,13 +1730,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve

// Tensor dimensions
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
size_t flat_first_dim = 1;
if (shape.size() > 0) {
for (size_t i = 0; i < shape.size() - 1; ++i) {
flat_first_dim *= shape[i];
}
}
const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1;
const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape);
NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, "First dim for NVFP4 must be divisible by ",
NVFP4_BLOCK_SIZE, " (got shape=", shape, ")");
NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0,
Expand Down Expand Up @@ -2040,24 +2028,18 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(

// Tensor dimensions, shape means original shape
std::vector<size_t> shape;
if (columnwise_data) {
shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true);
if (rowwise_data) {
auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false);
NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape,
") and column-wise data (shape=", shape, ") do not match");
}
} else { // Already checked columnwise_data_tensor == true
if (rowwise_data) {
shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false);
}

size_t flat_first_dim = 1;
if (shape.size() > 0) {
for (size_t i = 0; i < shape.size() - 1; ++i) {
flat_first_dim *= shape[i];
if (columnwise_data) {
auto col_shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true);
NVTE_CHECK(get_2d_dims(shape) == get_2d_dims(col_shape), "NVFP4 row-wise data (shape=", shape,
") and column-wise data (shape=", col_shape, ") do not match");
}
} else {
shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true);
}
const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1;

const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape);
const bool row_scaled_nvfp4 = this->row_scaled_nvfp4;
if (row_scaled_nvfp4) {
NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage.");
Expand Down Expand Up @@ -2205,24 +2187,16 @@ void NVFP4Quantizer::quantize_with_rht_unfused_helper(
// NOTE: should already be populated.
auto out_columnwise_amax = out.get_columnwise_amax();

// Flatten column-wise data shape to 2D to avoid problems when
// converting between FP4 tensor shape and byte tensor shape
// (involves dividing last dim by 2).
auto [flat_first_dim, flat_last_dim] = get_2d_dims(out_columnwise_data.shape, true);
std::vector<size_t> colwise_data_shape_2d = {flat_first_dim, flat_last_dim};

// Create a wrapper for the columnwise output, as the rowwise output.
// The reason is due to the input `rht_output_t` is already in the transposed layout.
// Thus, we only need a rowwise quantization to generate the columnwise output.
TensorWrapper out_transpose(out.scaling_mode());
// Note: since we are faking columnwise tensor into rowwise, the flat first dim check will fail
// need to convert the shape to 2D here
auto colwise_data_shape = out_columnwise_data.shape;
std::vector<size_t> colwise_data_shape_2d;
// shape could be [512, 32, 64], that's actually 512, 32, 128 because 2 FP4 take 1 byte
// the 2D shape should be [512, 32*128], but columnwise data shape expect last dim to be halved again
// so the multiple 2 get cancelled out
colwise_data_shape_2d.push_back(colwise_data_shape.data[0]);
size_t last_dim = 1;
for (size_t i = 1; i < colwise_data_shape.ndim; ++i) {
last_dim *= colwise_data_shape.data[i];
}
colwise_data_shape_2d.push_back(last_dim);

out_transpose.set_rowwise_data(out_columnwise_data.data_ptr,
static_cast<DType>(out_columnwise_data.dtype),
colwise_data_shape_2d);
Expand All @@ -2234,7 +2208,6 @@ void NVFP4Quantizer::quantize_with_rht_unfused_helper(
out_columnwise_amax.shape);

// Invoking fallback RHT kernel unfused.

NVTE_SCOPED_GIL_RELEASE({
// Perform the RHT(input.t), and write to rht_output_cpp.columnwise.
nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0,
Expand Down Expand Up @@ -2483,13 +2456,7 @@ void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out

std::vector<size_t> NVFP4Quantizer::get_scale_shape(const std::vector<size_t>& shape,
bool columnwise) const {
size_t numel = 1;
for (auto s : shape) {
numel *= s;
}

auto last_dim = shape.back();
auto flat_first_dim = numel / last_dim;
const auto [flat_first_dim, last_dim] = get_2d_dims(shape);

NVTE_CHECK(last_dim % NVFP4_BLOCK_SIZE == 0, "Last dim for NVFP4 must be divisible by ",
NVFP4_BLOCK_SIZE, " (got dim=", last_dim, ")");
Expand Down
Loading