41 changes: 1 addition & 40 deletions tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py
@@ -228,13 +228,6 @@ def test_fused_adam_fp8_master_weights_no_meta(recipe_name):
"""
recipe = get_recipe_from_string(recipe_name)

if recipe_name in ("MXFP8BlockScaling", "Float8BlockScaling", "NVFP4BlockScaling"):
pytest.xfail(
f"{recipe_name}: FSDP2 all-gather hooks for block-scaling QuantizedTensor "
"subclasses fail when parameters are initialized on CUDA. "
"Use device='meta' + reset_parameters() after sharding."
)

world_size, device = _get_dist_info()

model = _build_model(fp8_init=True, recipe=recipe, use_meta_device=False)
@@ -604,12 +597,6 @@ def test_safetensors_fp32_export(recipe_name):
- Saved tensor shapes match expected (unsharded) shapes
"""
recipe = get_recipe_from_string(recipe_name)
if recipe_name == "MXFP8BlockScaling":
pytest.xfail(
"MXFP8BlockScaling: FusedAdam CUDA kernel does not support "
"MXFP8 quantized tensors, causing illegal memory access. "
"Fixed by https://github.com/NVIDIA/TransformerEngine/pull/2789."
)

from safetensors.torch import load_file, save_file
from torch.distributed.checkpoint.state_dict import (
@@ -692,40 +679,14 @@ def test_dcp_output_parity(recipe_name, async_save):
"""
recipe = get_recipe_from_string(recipe_name)

if recipe_name == "MXFP8BlockScaling":
pytest.xfail(
"MXFP8BlockScaling: FusedAdam CUDA kernel does not support "
"MXFP8 quantized tensors, causing illegal memory access: "
"/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh:92 in function "
"multi_tensor_apply: CUDA Error: an illegal memory access was encountered. "
"Fixed by https://github.com/NVIDIA/TransformerEngine/pull/2789."
)

if recipe_name == "NVFP4BlockScaling":
pytest.xfail(
"NVFP4BlockScaling: DCP load_state_dict triggers reset_sharded_param() "
"which calls data_ptr() on NVFP4Tensor wrapper subclass with invalid storage"
)

if (
recipe_name == "Float8BlockScaling"
and not async_save
and torch.cuda.get_device_capability()[0] == 12
):
if recipe_name == "Float8BlockScaling" and torch.cuda.get_device_capability()[0] == 12:
pytest.xfail(
"Float8BlockScaling is failing on SM120 with RuntimeError: "
"transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu:534 "
"in function quantize_transpose_vector_blockwise: Assertion failed: pow2_scale. On "
"Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8, which "
"requires using power of two scaling factors."
)
if recipe_name == "Float8BlockScaling" and async_save:
pytest.xfail(
"Float8BlockScaling: async DCP save/load round-trip produces different model "
"outputs — quantization metadata (scales) is not correctly persisted through "
"async distributed checkpointing. On SM120, additionally fails with pow2_scale "
"assertion in quantize_transpose_vector_blockwise."
)

import torch.distributed.checkpoint as dcp

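For orientation, the save/load round-trip that test_dcp_output_parity exercises has roughly this shape. This is a hedged sketch: the state-dict helpers are real torch.distributed.checkpoint APIs, while model, optimizer, and the checkpoint path are placeholders, not the test's actual code.

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

# Gather FSDP2-aware (DTensor-sharded) state dicts for model and optimizer.
model_sd, optim_sd = get_state_dict(model, optimizer)
dcp.save({"model": model_sd, "optim": optim_sd}, checkpoint_id="/tmp/ckpt")

# Reload into the same structure, then push back into the live objects.
state = {"model": model_sd, "optim": optim_sd}
dcp.load(state, checkpoint_id="/tmp/ckpt")
set_state_dict(
    model,
    optimizer,
    model_state_dict=state["model"],
    optim_state_dict=state["optim"],
)
```

The async_save variant swaps dcp.save for dcp.async_save, whose internal staging step is what hits the torch.load(weights_only=True) path addressed in transformer_engine/pytorch/__init__.py below.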
13 changes: 0 additions & 13 deletions tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py
@@ -380,19 +380,6 @@ def test_distributed(recipe_name, fp8_init, sharding_dims, layer_type):
"data and scale_inv into a single buffer in pre_all_gather, split in post."
)

if recipe_name == "Float8BlockScaling" and fp8_init:
pytest.xfail(
"Float8BlockScaling + fp8_init: scale inverse padding is not handled "
"correctly during FSDP2 all-gather slice ops."
)
if recipe_name == "NVFP4BlockScaling" and fp8_init and layer_type == "TransformerLayer":
pytest.xfail(
"NVFP4BlockScaling + fp8_init + TransformerLayer: "
"_check_fp8_fsdp2_allgather numerical error compounds across multiple "
"linear layers in the transformer block (up to ~1e-2 max abs diff). "
"LayerNormLinear passes with relaxed tolerances. "
"NVFP4 + FSDP2 training is validated by run_fsdp2_fused_adam.py."
)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

70 changes: 70 additions & 0 deletions transformer_engine/pytorch/__init__.py
@@ -86,8 +86,78 @@
from transformer_engine.pytorch.tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor
from transformer_engine.pytorch.tensor import NVFP4Tensor
from transformer_engine.pytorch.tensor.float8_tensor import (
_make_float8_tensor_in_reduce_ex,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import (
_make_mxfp8_tensor_in_reduce_ex,
)
from transformer_engine.pytorch.tensor.nvfp4_tensor import (
_make_nvfp4_tensor_in_reduce_ex,
)
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
_make_float8_blockwise_tensor_in_reduce_ex,
)

try:
torch._dynamo.config.error_on_nested_jit_trace = False
except AttributeError:
pass # error_on_nested_jit_trace was added in PyTorch 2.2.0


# Allow QuantizedTensor subclasses (and the metadata they pickle) to
# round-trip through ``torch.load(weights_only=True)``. DCP async-staging
# writes a torch.save / torch.load step internally, so without this the
# default safe-unpickler rejects our custom classes.
#
# The ``_make_*_in_reduce_ex`` reconstructors are defined as module-level
# functions (not classmethods) so they pickle as a single ``GLOBAL`` opcode
# rather than a ``(getattr, (cls, name))`` reduction. Their ``fp8_dtype`` /
# ``fp4_dtype`` arguments are passed as plain ``int`` values (converted back
# to the pybind11 ``transformer_engine_torch.DType`` enum on reconstruction)
# and ``Quantizer.__getstate__`` similarly serializes its embedded ``dtype``
# as an ``int``. Together these keep the pickle stream free of pybind11-enum
# reductions and bound-classmethod references, so we don't need to allow-list
# ``builtins.getattr`` or the enum type itself for ``weights_only=True``.
try:
from torch.serialization import add_safe_globals

add_safe_globals(
[
# Wrapper subclasses
QuantizedTensor,
Float8Tensor,
MXFP8Tensor,
NVFP4Tensor,
Float8BlockwiseQTensor,
# Storage mixins (used during pickling of internal-only tensors)
QuantizedTensorStorage,
Float8TensorStorage,
MXFP8TensorStorage,
NVFP4TensorStorage,
Float8BlockwiseQTensorStorage,
# Quantizer types embedded in metadata
Quantizer,
Float8Quantizer,
Float8CurrentScalingQuantizer,
MXFP8Quantizer,
NVFP4Quantizer,
Float8BlockQuantizer,
# __reduce_ex__ reconstructors (module-level functions).
_make_float8_tensor_in_reduce_ex,
_make_mxfp8_tensor_in_reduce_ex,
_make_nvfp4_tensor_in_reduce_ex,
_make_float8_blockwise_tensor_in_reduce_ex,
]
)
except (ImportError, AttributeError):
import warnings as _warnings

_warnings.warn(
"transformer_engine: torch.serialization.add_safe_globals is "
"unavailable on this PyTorch version (added in 2.4). DCP "
"checkpointing of QuantizedTensor weights with FSDP2 will not "
"work; upgrade to PyTorch >= 2.4 to enable it.",
RuntimeWarning,
stacklevel=2,
)
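As a quick illustration of what the allow-list buys, this hedged sketch (assuming a CUDA build with FP8 support; fp8_model_init is Transformer Engine's documented way to get quantized parameters, and the sizes are arbitrary) round-trips a quantized weight through the safe unpickler:

```python
import io
import torch
import transformer_engine.pytorch as te

with te.fp8_model_init(enabled=True):
    linear = te.Linear(128, 128, device="cuda")  # weight is a Float8Tensor

buf = io.BytesIO()
torch.save(linear.state_dict(), buf)
buf.seek(0)
# Without the add_safe_globals() registration above, the safe unpickler
# rejects the QuantizedTensor classes with an UnpicklingError.
state = torch.load(buf, weights_only=True)
```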
5 changes: 4 additions & 1 deletion transformer_engine/pytorch/module/base.py
@@ -41,6 +41,7 @@
from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.nvfp4_tensor import NVFP4Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
@@ -1466,7 +1467,9 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
raise RuntimeError("Weight quantizer has not been initialized")
quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
quantizer.internal = False
if is_dtensor and isinstance(quantizer, Float8CurrentScalingQuantizer):
if is_dtensor and isinstance(
quantizer, (Float8CurrentScalingQuantizer, NVFP4Quantizer)
):
device_mesh = dtensor_param.device_mesh
amax_reduction_group = (
device_mesh.get_group(mesh_dim="shard")
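For context on the hunk above, a minimal sketch of where the shard group comes from. The mesh construction here is illustrative; only get_group(mesh_dim="shard") mirrors the diff.

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# A 1-D FSDP mesh; with HSDP the mesh would be 2-D ("replicate", "shard").
world_size = dist.get_world_size()  # assumes an initialized process group
mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("shard",))

# Every rank holding a shard of a weight must agree on its amax, so the
# shard group becomes the reduction group for current-scaling and NVFP4
# quantizers.
amax_reduction_group = mesh.get_group(mesh_dim="shard")
```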
88 changes: 83 additions & 5 deletions transformer_engine/pytorch/quantized_tensor.py
@@ -254,6 +254,34 @@ def __init__(self, *, rowwise: bool, columnwise: bool) -> None:
self.internal = False
self.optimize_for_gemm = False

def __getstate__(self):
"""Custom pickling.

FP8/FP4 quantizer subclasses store ``self.dtype`` as a
``transformer_engine_torch.DType`` (pybind11 enum). Pybind11
enums reduce as ``(getattr, (Enum, name))``, which would force
callers using ``torch.load(weights_only=True)`` (e.g. DCP
async-staging) to allow-list ``builtins.getattr``. Serialize
``dtype`` as an ``int`` here so the pickle stream stays free of
those enum reductions. Subclass overrides should call
``super().__getstate__()`` rather than ``self.__dict__.copy()``
to preserve this behavior.
"""
from transformer_engine_torch import DType as _TE_DType

state = self.__dict__.copy()
if isinstance(state.get("dtype"), _TE_DType):
state["dtype"] = int(state["dtype"])
return state

def __setstate__(self, state):
"""Reconstruct ``dtype`` from its serialized ``int`` form."""
from transformer_engine_torch import DType as _TE_DType

if isinstance(state.get("dtype"), int):
state["dtype"] = _TE_DType(state["dtype"])
self.__dict__.update(state)
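# Illustrative only, not part of the diff: a subclass that adds its own
# state should extend these hooks (as the docstring above advises) so the
# int round-trip of ``dtype`` is preserved, e.g.:
#
#     def __getstate__(self):
#         state = super().__getstate__()  # ``dtype`` already an int here
#         state["extra"] = self._extra    # hypothetical extra field
#         return state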

def __repr__(self):
return (
f"{self.__class__.__name__}("
@@ -529,9 +557,26 @@ def half(self) -> torch.Tensor:
# pylint: disable=missing-function-docstring
return self.dequantize(dtype=torch.float16)

def cpu(self, memory_format=torch.preserve_format) -> torch.Tensor:
def cpu(self, memory_format=torch.preserve_format) -> QuantizedTensor:
"""Move tensor to CPU while preserving the QuantizedTensor type.

Routes through ``aten._to_copy.default`` so the subclass-preserving
handler in ``__torch_dispatch__`` runs (rather than dequantizing).

"""
# pylint: disable=missing-function-docstring
return self.dequantize().cpu(memory_format=memory_format)
return self.to(device=torch.device("cpu"), memory_format=memory_format)

def untyped_storage(self) -> torch.UntypedStorage:
"""Return an empty UntypedStorage on the tensor's device.

``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real
backing storage of its own; the actual bytes live in the inner
buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are
an implementation detail of the quantization scheme. Defining this
method avoids DCP staging errors with FSDP2.
"""
return torch.UntypedStorage(0, device=self.device)
Comment on lines +570 to +579

Contributor

P1 Empty storage breaks shared-storage detection in existing callers

QuantizedTensor.untyped_storage() now returns a freshly allocated zero-byte storage on every call. Code in module/_common.py:128 compares tensors[0].untyped_storage().nbytes() against the expected size to decide between a no-op view and an out-of-place torch.cat. With 0 bytes returned, that condition is always true, silently disabling the in-place fast path for any QuantizedTensor through ConcatMerge.forward. More critically, utils.py:403-412 in SplitAlongDim.backward uses data_ptr() for noop detection: if all zero-size CUDA allocations return data_ptr() == 0, every QuantizedTensor pair incorrectly appears co-located, setting noop_ok = True and crashing on ret.set_() against a 0-byte storage.

Collaborator

The correct behavior for these functions is to fall back to the slow path for QuantizedTensors, unless there is a dedicated implementation to handle quantized data.

Member

Yeah, while I don't think we ever use QuantizedTensors in SplitAlongDim, the concat path sounds plausible to hit.
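To pin down the intended cpu() semantics discussed above, a hedged sketch (fp8_weight is a stand-in for any Float8Tensor on CUDA; nothing here comes from the diff itself):

```python
host_copy = fp8_weight.cpu()                 # routes through aten._to_copy.default
assert type(host_copy) is type(fp8_weight)   # subclass preserved, not dequantized
assert host_copy.device.type == "cpu"        # inner data/scale buffers moved too
restored = host_copy.to(device="cuda")       # same handler moves it back
```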


def expand_as(self, other: torch.Tensor) -> torch.Tensor:
# pylint: disable=missing-function-docstring
@@ -585,6 +630,34 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
dst.copy_(src)
return None

# _to_copy op (used by .to(device=...), .cpu(), DCP staging).
# Preserve the QuantizedTensor subclass and move all internal
# buffers (data, scales, etc.) to the requested device.
if func == torch.ops.aten._to_copy.default:
tensor = args[0]
kw = dict(kwargs) if kwargs else {}
target_device = kw.get("device", tensor.device) or tensor.device
target_device = torch.device(target_device)
target_dtype = kw.get("dtype", tensor.dtype) or tensor.dtype
pin_memory = bool(kw.get("pin_memory", False))
non_blocking = bool(kw.get("non_blocking", False))

new_metadata = {}
for key, value in tensor.get_metadata().items():
if isinstance(value, torch.Tensor):
value = value.to(device=target_device, non_blocking=non_blocking)
if pin_memory and target_device.type == "cpu":
value = value.pin_memory()
new_metadata[key] = value
new_metadata["fake_dtype"] = target_dtype
return type(tensor)(
shape=tensor.shape,
dtype=target_dtype,
requires_grad=tensor.requires_grad,
device=target_device,
**new_metadata,
)

# View op
if func == torch.ops.aten.view.default:
raise NotImplementedError(f"{cls.__name__} class does not support tensor views")
@@ -725,14 +798,19 @@ def make_like(
"""Create new quantized tensor

By default, new tensor has the same attributes and underlying
data. This function is intended to create view of tensors.

data. This function is intended to create a view of ``tensor``.
"""
shape = shape if shape is not None else tensor.shape
dtype = dtype if dtype is not None else tensor.dtype
kwargs = tensor.get_metadata()
kwargs["fake_dtype"] = dtype
return cls(shape=shape, dtype=dtype, requires_grad=requires_grad, **kwargs)
return cls(
shape=shape,
dtype=dtype,
requires_grad=requires_grad,
device=tensor.device,
**kwargs,
)

def to_dtype(self, dtype: torch.dtype) -> QuantizedTensor:
"""Create `QuantizedTensor` with given nominal dtype
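A hedged usage sketch of make_like with device now passed through explicitly (fp8_weight is a placeholder Float8Tensor; the flatten is one plausible view, not a case taken from the diff):

```python
# Build a reshaped wrapper over the same quantized buffers: metadata is
# shared via get_metadata(), and the new wrapper now also inherits the
# source tensor's device rather than defaulting elsewhere.
flat = Float8Tensor.make_like(fp8_weight, shape=(fp8_weight.numel(),))
assert flat.device == fp8_weight.device
```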
@@ -61,6 +61,7 @@ def forward(
kwargs = tensor.get_metadata()
for key, val in init_kwargs.items():
kwargs[key] = val
kwargs["device"] = tensor.device
return type(tensor)(tensor.shape, tensor.dtype, **kwargs)

@staticmethod