From 4197bee78341999183108c9d09d2cb88346196dc Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 13 May 2026 03:55:02 +0000 Subject: [PATCH 1/4] all changes in Signed-off-by: Varun Thumbe --- .../fsdp2_tests/run_fsdp2_fused_adam.py | 36 ----------- .../fsdp2_tests/run_fsdp2_model.py | 13 ---- transformer_engine/pytorch/__init__.py | 57 +++++++++++++++++ transformer_engine/pytorch/module/base.py | 5 +- .../pytorch/quantized_tensor.py | 63 +++++++++++++++++-- .../pytorch/tensor/_quantization_helpers.py | 1 + .../pytorch/tensor/float8_blockwise_tensor.py | 14 ----- .../pytorch/tensor/float8_tensor.py | 15 +++-- .../tensor/storage/float8_tensor_storage.py | 5 -- 9 files changed, 128 insertions(+), 81 deletions(-) diff --git a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py index ecda481ed9..92cebf2b53 100644 --- a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py +++ b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py @@ -228,13 +228,6 @@ def test_fused_adam_fp8_master_weights_no_meta(recipe_name): """ recipe = get_recipe_from_string(recipe_name) - if recipe_name in ("MXFP8BlockScaling", "Float8BlockScaling", "NVFP4BlockScaling"): - pytest.xfail( - f"{recipe_name}: FSDP2 all-gather hooks for block-scaling QuantizedTensor " - "subclasses fail when parameters are initialized on CUDA. " - "Use device='meta' + reset_parameters() after sharding." - ) - world_size, device = _get_dist_info() model = _build_model(fp8_init=True, recipe=recipe, use_meta_device=False) @@ -604,12 +597,6 @@ def test_safetensors_fp32_export(recipe_name): - Saved tensor shapes match expected (unsharded) shapes """ recipe = get_recipe_from_string(recipe_name) - if recipe_name == "MXFP8BlockScaling": - pytest.xfail( - "MXFP8BlockScaling: FusedAdam CUDA kernel does not support " - "MXFP8 quantized tensors, causing illegal memory access. " - "Fixed by https://github.com/NVIDIA/TransformerEngine/pull/2789." - ) from safetensors.torch import load_file, save_file from torch.distributed.checkpoint.state_dict import ( @@ -692,24 +679,8 @@ def test_dcp_output_parity(recipe_name, async_save): """ recipe = get_recipe_from_string(recipe_name) - if recipe_name == "MXFP8BlockScaling": - pytest.xfail( - "MXFP8BlockScaling: FusedAdam CUDA kernel does not support " - "MXFP8 quantized tensors, causing illegal memory access: " - "/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh:92 in function " - "multi_tensor_apply: CUDA Error: an illegal memory access was encountered. " - "Fixed by https://github.com/NVIDIA/TransformerEngine/pull/2789." - ) - - if recipe_name == "NVFP4BlockScaling": - pytest.xfail( - "NVFP4BlockScaling: DCP load_state_dict triggers reset_sharded_param() " - "which calls data_ptr() on NVFP4Tensor wrapper subclass with invalid storage" - ) - if ( recipe_name == "Float8BlockScaling" - and not async_save and torch.cuda.get_device_capability()[0] == 12 ): pytest.xfail( @@ -719,13 +690,6 @@ def test_dcp_output_parity(recipe_name, async_save): "Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8, which " "requires using power of two scaling factors." ) - if recipe_name == "Float8BlockScaling" and async_save: - pytest.xfail( - "Float8BlockScaling: async DCP save/load round-trip produces different model " - "outputs — quantization metadata (scales) is not correctly persisted through " - "async distributed checkpointing. 
On SM120, additionally fails with pow2_scale " - "assertion in quantize_transpose_vector_blockwise." - ) import torch.distributed.checkpoint as dcp diff --git a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py index 6342e63e75..9383355fcc 100644 --- a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py +++ b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py @@ -380,19 +380,6 @@ def test_distributed(recipe_name, fp8_init, sharding_dims, layer_type): "data and scale_inv into a single buffer in pre_all_gather, split in post." ) - if recipe_name == "Float8BlockScaling" and fp8_init: - pytest.xfail( - "Float8BlockScaling + fp8_init: scale inverse padding is not handled " - "correctly during FSDP2 all-gather slice ops." - ) - if recipe_name == "NVFP4BlockScaling" and fp8_init and layer_type == "TransformerLayer": - pytest.xfail( - "NVFP4BlockScaling + fp8_init + TransformerLayer: " - "_check_fp8_fsdp2_allgather numerical error compounds across multiple " - "linear layers in the transformer block (up to ~1e-2 max abs diff). " - "LayerNormLinear passes with relaxed tolerances. " - "NVFP4 + FSDP2 training is validated by run_fsdp2_fused_adam.py." - ) torch.manual_seed(42) torch.cuda.manual_seed(42) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index d145cf0a21..6284e96cef 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -91,3 +91,60 @@ torch._dynamo.config.error_on_nested_jit_trace = False except AttributeError: pass # error_on_nested_jit_trace was added in PyTorch 2.2.0 + + +# Allow QuantizedTensor subclasses (and the metadata they pickle) to +# round-trip through ``torch.load(weights_only=True)``. DCP async-staging +# writes a torch.save / torch.load step internally, so without this the +# default safe-unpickler rejects our custom classes. +try: + from torch.serialization import add_safe_globals + from transformer_engine_torch import DType as _TE_DType + + add_safe_globals( + [ + # Wrapper subclasses + QuantizedTensor, + Float8Tensor, + MXFP8Tensor, + NVFP4Tensor, + Float8BlockwiseQTensor, + # Storage mixins (used during pickling of internal-only tensors) + QuantizedTensorStorage, + Float8TensorStorage, + MXFP8TensorStorage, + NVFP4TensorStorage, + Float8BlockwiseQTensorStorage, + # Quantizer types embedded in metadata + Quantizer, + Float8Quantizer, + Float8CurrentScalingQuantizer, + MXFP8Quantizer, + NVFP4Quantizer, + Float8BlockQuantizer, + # __reduce_ex__ constructors (bound classmethods). + Float8Tensor._make_in_reduce_ex, + MXFP8Tensor._make_in_reduce_ex, + NVFP4Tensor._make_in_reduce_ex, + Float8BlockwiseQTensor._make_in_reduce_ex, + # The pickle stream produced by ``__reduce_ex__`` references + # the pybind11 enum ``transformer_engine_torch.DType`` (e.g. + # the ``fp8_dtype`` argument) and uses ``builtins.getattr`` to + # resolve both the enum members and the bound-classmethod + # ``_make_in_reduce_ex`` callables above. Both must be + # allow-listed for ``torch.load(weights_only=True)`` (used + # internally by DCP async-staging) to accept the stream. + _TE_DType, + getattr, + ] + ) +except (ImportError, AttributeError): + import warnings as _warnings + _warnings.warn( + "transformer_engine: torch.serialization.add_safe_globals is " + "unavailable on this PyTorch version (added in 2.4). 
DCP " + "checkpointing of QuantizedTensor weights with FSDP2 will not " + "work; upgrade to PyTorch >= 2.4 to enable it.", + RuntimeWarning, + stacklevel=2, + ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index e6bedee0c0..873c980579 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -41,6 +41,7 @@ from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor.nvfp4_tensor import NVFP4Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.storage.float8_tensor_storage import Float8TensorStorage from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage @@ -1466,7 +1467,9 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: raise RuntimeError("Weight quantizer has not been initialized") quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) quantizer.internal = False - if is_dtensor and isinstance(quantizer, Float8CurrentScalingQuantizer): + if is_dtensor and isinstance( + quantizer, (Float8CurrentScalingQuantizer, NVFP4Quantizer) + ): device_mesh = dtensor_param.device_mesh amax_reduction_group = ( device_mesh.get_group(mesh_dim="shard") diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index a7722f777e..1470f5eca2 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -529,9 +529,26 @@ def half(self) -> torch.Tensor: # pylint: disable=missing-function-docstring return self.dequantize(dtype=torch.float16) - def cpu(self, memory_format=torch.preserve_format) -> torch.Tensor: + def cpu(self, memory_format=torch.preserve_format) -> QuantizedTensor: + """Move tensor to CPU while preserving the QuantizedTensor type. + + Routes through ``aten._to_copy.default`` so the subclass-preserving + handler in ``__torch_dispatch__`` runs (rather than dequantizing). + + """ # pylint: disable=missing-function-docstring - return self.dequantize().cpu(memory_format=memory_format) + return self.to(device=torch.device("cpu"), memory_format=memory_format) + + def untyped_storage(self) -> torch.UntypedStorage: + """Return an empty UntypedStorage on the tensor's device. + + ``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real + backing storage of its own; the actual bytes live in the inner + buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are + an implementation detail of the quantization scheme. Need to define + this method to avoid DCP staging errors with FSDP2. + """ + return torch.UntypedStorage(0, device=self.device) def expand_as(self, other: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -585,6 +602,34 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): dst.copy_(src) return None + # _to_copy op (used by .to(device=...), .cpu(), DCP staging). + # Preserve the QuantizedTensor subclass and move all internal + # buffers (data, scales, etc.) to the requested device. 
+ if func == torch.ops.aten._to_copy.default: + tensor = args[0] + kw = dict(kwargs) if kwargs else {} + target_device = kw.get("device", tensor.device) or tensor.device + target_device = torch.device(target_device) + target_dtype = kw.get("dtype", tensor.dtype) or tensor.dtype + pin_memory = bool(kw.get("pin_memory", False)) + non_blocking = bool(kw.get("non_blocking", False)) + + new_metadata = {} + for key, value in tensor.get_metadata().items(): + if isinstance(value, torch.Tensor): + value = value.to(device=target_device, non_blocking=non_blocking) + if pin_memory and target_device.type == "cpu": + value = value.pin_memory() + new_metadata[key] = value + new_metadata["fake_dtype"] = target_dtype + return type(tensor)( + shape=tensor.shape, + dtype=target_dtype, + requires_grad=tensor.requires_grad, + device=target_device, + **new_metadata, + ) + # View op if func == torch.ops.aten.view.default: raise NotImplementedError("{cls.__name__} class does not support tensor views") @@ -725,14 +770,24 @@ def make_like( """Create new quantized tensor By default, new tensor has the same attributes and underlying - data. This function is intended to create view of tensors. + data. This function is intended to create a view of ``tensor``, + so the new tensor lives on the same device. To move quantized + data across devices use ``.to(device=...)`` / ``.cpu()`` / + ``aten._to_copy.default`` instead, which actually copies the + inner buffers. """ shape = shape if shape is not None else tensor.shape dtype = dtype if dtype is not None else tensor.dtype kwargs = tensor.get_metadata() kwargs["fake_dtype"] = dtype - return cls(shape=shape, dtype=dtype, requires_grad=requires_grad, **kwargs) + return cls( + shape=shape, + dtype=dtype, + requires_grad=requires_grad, + device=tensor.device, + **kwargs, + ) def to_dtype(self, dtype: torch.dtype) -> QuantizedTensor: """Create `QuantizedTensor` with given nominal dtype diff --git a/transformer_engine/pytorch/tensor/_quantization_helpers.py b/transformer_engine/pytorch/tensor/_quantization_helpers.py index ba3407e13b..56cf503630 100644 --- a/transformer_engine/pytorch/tensor/_quantization_helpers.py +++ b/transformer_engine/pytorch/tensor/_quantization_helpers.py @@ -61,6 +61,7 @@ def forward( kwargs = tensor.get_metadata() for key, val in init_kwargs.items(): kwargs[key] = val + kwargs["device"] = tensor.device return type(tensor)(tensor.shape, tensor.dtype, **kwargs) @staticmethod diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 914397b9b6..0beca32a18 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -389,20 +389,6 @@ def reshape(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor: # pylint: disable=missing-function-docstring return _ReshapeFunc.apply(self, shape) - def untyped_storage(self) -> torch.UntypedStorage: - """Return the underlying UntypedStorage of the FP8 data. - - Note that FP8 block-scaled tensor may involve multiple - buffers: row-wise FP8 data, row-wise scales, column-wise FP8 - data, column-wise scales. The UntypedStorage of the row-wise - FP8 data is returned if it exists, and otherwise the - UntypedStorage of the column-wise FP8 data. 
- - """ - data = self._rowwise_data if self._rowwise_data is not None else self._columnwise_data - if data is not None: - return data.untyped_storage() - return torch.UntypedStorage(0, device=self.device) @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index ed6091c85b..ddee20b32e 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -1007,16 +1007,14 @@ def _make_in_reduce_ex( ) def __reduce_ex__(self, protocol: int) -> tuple: - """Custom pickling to remove references to FP8 metadata objects + """Custom pickling to remove references to FP8 metadata objects. - CPU Float8Tensors are serialized as dequantized plain tensors - for compatibility with torch.load(weights_only=True), which is - used by DCP async save staging. + Always serializes the underlying FP8 buffers (no dequantization + fallback for CPU tensors) so that DCP async-staging round-trips + preserve bitwise-identical data. ``Float8Tensor`` is registered + with ``torch.serialization.add_safe_globals`` to keep + ``torch.load(weights_only=True)`` compatibility. """ - data_is_cpu = self._data is not None and self._data.is_cpu - transpose_is_cpu = self._transpose is not None and self._transpose.is_cpu - if data_is_cpu or transpose_is_cpu: - return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol) return ( Float8Tensor._make_in_reduce_ex, (self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape), @@ -1177,3 +1175,4 @@ def backward( ) -> Tuple[Optional[torch.Tensor], ...]: # pylint: disable=missing-function-docstring return grad.reshape(ctx.shape), None + diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index de7f8f58e2..3a72ec5d1a 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -139,11 +139,6 @@ def get_metadata(self) -> Dict[str, Any]: "fp8_dtype": self._fp8_dtype, "data_transpose": self._transpose, "quantizer": self._quantizer, - "device": ( - self._data.device - if self._data is not None - else (self._transpose.device if self._transpose is not None else None) - ), "fake_dtype": self._dtype, } From 8496440690f401d5efadf87bd29c60d349847536 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 May 2026 04:01:37 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py | 5 +---- transformer_engine/pytorch/__init__.py | 1 + transformer_engine/pytorch/tensor/float8_blockwise_tensor.py | 1 - transformer_engine/pytorch/tensor/float8_tensor.py | 1 - 4 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py index 92cebf2b53..1abb49e98c 100644 --- a/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py +++ b/tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py @@ -679,10 +679,7 @@ def test_dcp_output_parity(recipe_name, async_save): """ recipe = get_recipe_from_string(recipe_name) - if ( - recipe_name == "Float8BlockScaling" - and torch.cuda.get_device_capability()[0] 
== 12
-    ):
+    if recipe_name == "Float8BlockScaling" and torch.cuda.get_device_capability()[0] == 12:
         pytest.xfail(
             "Float8BlockScaling is failing on SM120 with RuntimeError: "
             "transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu:534 "
diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py
index 6284e96cef..0c019c209c 100644
--- a/transformer_engine/pytorch/__init__.py
+++ b/transformer_engine/pytorch/__init__.py
@@ -140,6 +140,7 @@
     )
 except (ImportError, AttributeError):
     import warnings as _warnings
+
     _warnings.warn(
         "transformer_engine: torch.serialization.add_safe_globals is "
         "unavailable on this PyTorch version (added in 2.4). DCP "
diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
index 0beca32a18..6928543f43 100644
--- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
@@ -389,7 +389,6 @@ def reshape(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor:
         # pylint: disable=missing-function-docstring
         return _ReshapeFunc.apply(self, shape)
 
-
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py
index ddee20b32e..c7b69581a0 100644
--- a/transformer_engine/pytorch/tensor/float8_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_tensor.py
@@ -1175,4 +1175,3 @@ def backward(
     ) -> Tuple[Optional[torch.Tensor], ...]:
         # pylint: disable=missing-function-docstring
         return grad.reshape(ctx.shape), None
-

From dde83666cade2cb8b5686224d770eb58769352b2 Mon Sep 17 00:00:00 2001
From: Varun Thumbe
Date: Wed, 13 May 2026 04:02:10 +0000
Subject: [PATCH 3/4] simplify

Signed-off-by: Varun Thumbe
---
 transformer_engine/pytorch/quantized_tensor.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py
index 1470f5eca2..71d59c529f 100644
--- a/transformer_engine/pytorch/quantized_tensor.py
+++ b/transformer_engine/pytorch/quantized_tensor.py
@@ -771,11 +771,6 @@ def make_like(
 
     By default, new tensor has the same attributes and underlying
-    data. This function is intended to create a view of ``tensor``,
-    so the new tensor lives on the same device. To move quantized
-    data across devices use ``.to(device=...)`` / ``.cpu()`` /
-    ``aten._to_copy.default`` instead, which actually copies the
-    inner buffers.
+    data. This function is intended to create a view of ``tensor``.
- """ shape = shape if shape is not None else tensor.shape dtype = dtype if dtype is not None else tensor.dtype From 9cdfe7a3078881d6c11c096b129a4bc196347f0f Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 13 May 2026 04:19:13 +0000 Subject: [PATCH 4/4] address review comment Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/__init__.py | 42 ++++++---- .../pytorch/quantized_tensor.py | 28 +++++++ .../pytorch/tensor/float8_blockwise_tensor.py | 69 ++++++++-------- .../pytorch/tensor/float8_tensor.py | 54 +++++++------ .../pytorch/tensor/mxfp8_tensor.py | 67 ++++++++-------- .../pytorch/tensor/nvfp4_tensor.py | 79 ++++++++++--------- 6 files changed, 192 insertions(+), 147 deletions(-) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 0c019c209c..3da552833b 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -86,6 +86,18 @@ from transformer_engine.pytorch.tensor import MXFP8Tensor from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor from transformer_engine.pytorch.tensor import NVFP4Tensor +from transformer_engine.pytorch.tensor.float8_tensor import ( + _make_float8_tensor_in_reduce_ex, +) +from transformer_engine.pytorch.tensor.mxfp8_tensor import ( + _make_mxfp8_tensor_in_reduce_ex, +) +from transformer_engine.pytorch.tensor.nvfp4_tensor import ( + _make_nvfp4_tensor_in_reduce_ex, +) +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + _make_float8_blockwise_tensor_in_reduce_ex, +) try: torch._dynamo.config.error_on_nested_jit_trace = False @@ -97,9 +109,18 @@ # round-trip through ``torch.load(weights_only=True)``. DCP async-staging # writes a torch.save / torch.load step internally, so without this the # default safe-unpickler rejects our custom classes. +# +# The ``_make_*_in_reduce_ex`` reconstructors are defined as module-level +# functions (not classmethods) so they pickle as a single ``GLOBAL`` opcode +# rather than a ``(getattr, (cls, name))`` reduction. Their ``fp8_dtype`` / +# ``fp4_dtype`` arguments are passed as plain ``int`` values (converted back +# to the pybind11 ``transformer_engine_torch.DType`` enum on reconstruction) +# and ``Quantizer.__getstate__`` similarly serializes its embedded ``dtype`` +# as an ``int``. Together these keep the pickle stream free of pybind11-enum +# reductions and bound-classmethod references, so we don't need to allow-list +# ``builtins.getattr`` or the enum type itself for ``weights_only=True``. try: from torch.serialization import add_safe_globals - from transformer_engine_torch import DType as _TE_DType add_safe_globals( [ @@ -122,20 +143,11 @@ MXFP8Quantizer, NVFP4Quantizer, Float8BlockQuantizer, - # __reduce_ex__ constructors (bound classmethods). - Float8Tensor._make_in_reduce_ex, - MXFP8Tensor._make_in_reduce_ex, - NVFP4Tensor._make_in_reduce_ex, - Float8BlockwiseQTensor._make_in_reduce_ex, - # The pickle stream produced by ``__reduce_ex__`` references - # the pybind11 enum ``transformer_engine_torch.DType`` (e.g. - # the ``fp8_dtype`` argument) and uses ``builtins.getattr`` to - # resolve both the enum members and the bound-classmethod - # ``_make_in_reduce_ex`` callables above. Both must be - # allow-listed for ``torch.load(weights_only=True)`` (used - # internally by DCP async-staging) to accept the stream. - _TE_DType, - getattr, + # __reduce_ex__ reconstructors (module-level functions). 
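+            # Module-level functions pickle as a single GLOBAL opcode,
+            # so neither ``builtins.getattr`` nor the pybind11 enum type
+            # needs allow-listing. Sketch of the round-trip this enables
+            # (``buf`` here is a stand-in for DCP's staging buffer):
+            #     buf = io.BytesIO()
+            #     torch.save(model.state_dict(), buf)
+            #     buf.seek(0)
+            #     state = torch.load(buf, weights_only=True)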
+ _make_float8_tensor_in_reduce_ex, + _make_mxfp8_tensor_in_reduce_ex, + _make_nvfp4_tensor_in_reduce_ex, + _make_float8_blockwise_tensor_in_reduce_ex, ] ) except (ImportError, AttributeError): diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 71d59c529f..f261a2ddaa 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -254,6 +254,34 @@ def __init__(self, *, rowwise: bool, columnwise: bool) -> None: self.internal = False self.optimize_for_gemm = False + def __getstate__(self): + """Custom pickling. + + FP8/FP4 quantizer subclasses store ``self.dtype`` as a + ``transformer_engine_torch.DType`` (pybind11 enum). Pybind11 + enums reduce as ``(getattr, (Enum, name))``, which would force + callers using ``torch.load(weights_only=True)`` (e.g. DCP + async-staging) to allow-list ``builtins.getattr``. Serialize + ``dtype`` as an ``int`` here so the pickle stream stays free of + those enum reductions. Subclass overrides should call + ``super().__getstate__()`` rather than ``self.__dict__.copy()`` + to preserve this behavior. + """ + from transformer_engine_torch import DType as _TE_DType + + state = self.__dict__.copy() + if isinstance(state.get("dtype"), _TE_DType): + state["dtype"] = int(state["dtype"]) + return state + + def __setstate__(self, state): + """Reconstruct ``dtype`` from its serialized ``int`` form.""" + from transformer_engine_torch import DType as _TE_DType + + if isinstance(state.get("dtype"), int): + state["dtype"] = _TE_DType(state["dtype"]) + self.__dict__.update(state) + def __repr__(self): return ( f"{self.__class__.__name__}(" diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 6928543f43..c70c30c4ae 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -473,49 +473,17 @@ def contiguous( return self raise ValueError("Float8BlockwiseQTensor does not support different memory formats!") - @classmethod - def _make_in_reduce_ex( - cls, - shape: torch.Size, - rowwise_data: torch.Tensor, - rowwise_scale_inv: torch.Tensor, - columnwise_data: torch.Tensor, - columnwise_scale_inv: torch.Tensor, - fp8_dtype: TE_DType, - dtype: torch.dtype, - quantizer: Quantizer, - is_2D_scaled: bool, - data_format: Any = None, # pylint: disable=unused-argument - ) -> Float8BlockwiseQTensor: - """Build Float8BlockwiseQTensor, for use in __reduce__ - - __reduce_ex__ assumes object constructor has positional - arguments. 
- - """ - return Float8BlockwiseQTensor( - shape=shape, - rowwise_data=rowwise_data, - rowwise_scale_inv=rowwise_scale_inv, - fp8_dtype=fp8_dtype, - columnwise_data=columnwise_data, - columnwise_scale_inv=columnwise_scale_inv, - dtype=dtype, - quantizer=quantizer, - is_2D_scaled=is_2D_scaled, - ) - def __reduce_ex__(self, protocol: int) -> tuple: """Custom pickling to remove references to FP8 metadata objects""" return ( - Float8BlockwiseQTensor._make_in_reduce_ex, + _make_float8_blockwise_tensor_in_reduce_ex, ( self.shape, self._rowwise_data, self._rowwise_scale_inv, self._columnwise_data, self._columnwise_scale_inv, - self._fp8_dtype, + int(self._fp8_dtype), self.dtype, self._quantizer, self._is_2D_scaled, @@ -709,6 +677,39 @@ def fsdp_post_all_gather( return out, all_gather_outputs +def _make_float8_blockwise_tensor_in_reduce_ex( + shape: torch.Size, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + fp8_dtype: int, + dtype: torch.dtype, + quantizer: Quantizer, + is_2D_scaled: bool, + data_format: Any = None, # pylint: disable=unused-argument +) -> Float8BlockwiseQTensor: + """Reconstruct a ``Float8BlockwiseQTensor`` from ``__reduce_ex__``. + + Defined at module level so the pickle stream uses a single + ``GLOBAL`` opcode rather than the ``(getattr, (cls, name))`` + reduction that bound classmethods produce. ``fp8_dtype`` is passed + as an ``int`` and converted back to the pybind11 ``TE_DType`` enum + here. + """ + return Float8BlockwiseQTensor( + shape=shape, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + fp8_dtype=TE_DType(fp8_dtype), + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + dtype=dtype, + quantizer=quantizer, + is_2D_scaled=is_2D_scaled, + ) + + class _ViewFunc(torch.autograd.Function): """View function diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index c7b69581a0..d81458b3c4 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -287,7 +287,7 @@ def __init__( def __getstate__(self): """Exclude unpicklable process group from serialized state.""" - state = self.__dict__.copy() + state = super().__getstate__() state["amax_reduction_group"] = None return state @@ -983,29 +983,6 @@ def is_cpu(self): return self._transpose.is_cpu raise RuntimeError("Both data and transpose are None") - @classmethod - def _make_in_reduce_ex( - cls, - data: torch.Tensor, - fp8_dtype: TE_DType, - fp8_scale_inv: torch.Tensor, - dtype: torch.dtype, - shape: torch.shape, - ) -> Float8Tensor: - """Build Float8Tensor, for use in __reduce__ - - __reduce_ex__ assumes object constructor has positional - arguments. - - """ - return Float8Tensor( - data=data, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv, - dtype=dtype, - shape=shape, - ) - def __reduce_ex__(self, protocol: int) -> tuple: """Custom pickling to remove references to FP8 metadata objects. @@ -1016,8 +993,8 @@ def __reduce_ex__(self, protocol: int) -> tuple: ``torch.load(weights_only=True)`` compatibility. 
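+        The ``fp8_dtype`` slot in the payload is a plain ``int``;
+        ``_make_float8_tensor_in_reduce_ex`` converts it back to the
+        pybind11 enum on load.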
""" return ( - Float8Tensor._make_in_reduce_ex, - (self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape), + _make_float8_tensor_in_reduce_ex, + (self._data, int(self._fp8_dtype), self._scale_inv, self.dtype, self.shape), ) def _get_data(self) -> Float8Tensor: @@ -1083,6 +1060,31 @@ def _set_data(self, tensor: torch.Tensor) -> None: data = property(_get_data, _set_data) +def _make_float8_tensor_in_reduce_ex( + data: torch.Tensor, + fp8_dtype: int, + fp8_scale_inv: torch.Tensor, + dtype: torch.dtype, + shape: torch.Size, +) -> Float8Tensor: + """Reconstruct a ``Float8Tensor`` from its ``__reduce_ex__`` payload. + + Defined at module level (not as a classmethod) so the pickle stream + references it via a single ``GLOBAL`` opcode rather than the + ``(getattr, (cls, name))`` reduction that bound classmethods/static + methods produce. ``fp8_dtype`` is passed as an ``int`` and converted + back to the pybind11 ``TE_DType`` enum here so the pickle stream + stays free of enum reductions as well. + """ + return Float8Tensor( + data=data, + fp8_dtype=TE_DType(fp8_dtype), + fp8_scale_inv=fp8_scale_inv, + dtype=dtype, + shape=shape, + ) + + class _ViewFunc(torch.autograd.Function): """View function diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 5cab519c79..96f4b9554a 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -760,47 +760,16 @@ def fsdp_post_all_gather( out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage) return out, all_gather_outputs - @classmethod - def _make_in_reduce_ex( - cls, - rowwise_data: torch.Tensor, - rowwise_scale_inv: torch.Tensor, - columnwise_data: torch.Tensor, - columnwise_scale_inv: torch.Tensor, - fp8_dtype: TE_DType, - dtype: torch.dtype, - shape: torch.shape, - quantizer: Optional[Quantizer] = None, - with_gemm_swizzled_scales: bool = False, - ) -> MXFP8Tensor: - """Build MXFP8Tensor, for use in __reduce__ - - __reduce_ex__ assumes object constructor has positional - arguments. - - """ - return MXFP8Tensor( - rowwise_data=rowwise_data, - rowwise_scale_inv=rowwise_scale_inv, - fp8_dtype=fp8_dtype, - columnwise_data=columnwise_data, - columnwise_scale_inv=columnwise_scale_inv, - dtype=dtype, - shape=shape, - quantizer=quantizer, - with_gemm_swizzled_scales=with_gemm_swizzled_scales, - ) - def __reduce_ex__(self, protocol: int) -> tuple: """Custom pickling""" return ( - MXFP8Tensor._make_in_reduce_ex, + _make_mxfp8_tensor_in_reduce_ex, ( self._rowwise_data, self._rowwise_scale_inv, self._columnwise_data, self._columnwise_scale_inv, - self._fp8_dtype, + int(self._fp8_dtype), self.dtype, self.shape, self._quantizer, @@ -896,6 +865,38 @@ def is_cuda(self): raise RuntimeError("MXFP8Tensor has no data!") +def _make_mxfp8_tensor_in_reduce_ex( + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + fp8_dtype: int, + dtype: torch.dtype, + shape: torch.Size, + quantizer: Optional[Quantizer] = None, + with_gemm_swizzled_scales: bool = False, +) -> MXFP8Tensor: + """Reconstruct an ``MXFP8Tensor`` from its ``__reduce_ex__`` payload. + + Defined at module level so the pickle stream uses a single + ``GLOBAL`` opcode rather than the ``(getattr, (cls, name))`` + reduction that bound classmethods produce. ``fp8_dtype`` is passed + as an ``int`` and converted back to the pybind11 ``TE_DType`` enum + here. 
+ """ + return MXFP8Tensor( + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + fp8_dtype=TE_DType(fp8_dtype), + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + dtype=dtype, + shape=shape, + quantizer=quantizer, + with_gemm_swizzled_scales=with_gemm_swizzled_scales, + ) + + class _ViewFunc(torch.autograd.Function): """View function diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 285a7f030a..0ec472c592 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -165,7 +165,7 @@ def __init__( def __getstate__(self): """Exclude unpicklable process group from serialized state.""" - state = self.__dict__.copy() + state = super().__getstate__() state["amax_reduction_group"] = None return state @@ -820,46 +820,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # Default case return super().__torch_dispatch__(func, types, args, kwargs) - @classmethod - def _make_in_reduce_ex( - cls, - shape: torch.Size, - rowwise_data: torch.Tensor, - rowwise_scale_inv: torch.Tensor, - columnwise_data: torch.Tensor, - columnwise_scale_inv: torch.Tensor, - amax_rowwise: torch.Tensor, - amax_columnwise: torch.Tensor, - fp4_dtype: TE_DType, - dtype: torch.dtype, - quantizer: Quantizer, - with_gemm_swizzled_scales: bool = False, - ) -> NVFP4Tensor: - """Build NVFP4Tensor, for use in __reduce__ - - __reduce_ex__ assumes object constructor has positional - arguments. - - """ - return NVFP4Tensor( - shape=shape, - dtype=dtype, - fp4_dtype=fp4_dtype, - rowwise_data=rowwise_data, - rowwise_scale_inv=rowwise_scale_inv, - columnwise_data=columnwise_data, - columnwise_scale_inv=columnwise_scale_inv, - amax_rowwise=amax_rowwise, - amax_columnwise=amax_columnwise, - quantizer=quantizer, - requires_grad=False, - with_gemm_swizzled_scales=with_gemm_swizzled_scales, - ) - def __reduce_ex__(self, protocol: int) -> tuple: """Custom pickling""" return ( - NVFP4Tensor._make_in_reduce_ex, + _make_nvfp4_tensor_in_reduce_ex, ( self.shape, self._rowwise_data, @@ -868,7 +832,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: self._columnwise_scale_inv, self._amax_rowwise, self._amax_columnwise, - self._fp4_dtype, + int(self._fp4_dtype), self.dtype, self._quantizer, self._with_gemm_swizzled_scales, @@ -965,6 +929,43 @@ def is_cuda(self): raise RuntimeError("NVFP4Tensor has no data!") +def _make_nvfp4_tensor_in_reduce_ex( + shape: torch.Size, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + amax_rowwise: torch.Tensor, + amax_columnwise: torch.Tensor, + fp4_dtype: int, + dtype: torch.dtype, + quantizer: Quantizer, + with_gemm_swizzled_scales: bool = False, +) -> NVFP4Tensor: + """Reconstruct an ``NVFP4Tensor`` from its ``__reduce_ex__`` payload. + + Defined at module level so the pickle stream uses a single + ``GLOBAL`` opcode rather than the ``(getattr, (cls, name))`` + reduction that bound classmethods produce. ``fp4_dtype`` is passed + as an ``int`` and converted back to the pybind11 ``TE_DType`` enum + here. 
+ """ + return NVFP4Tensor( + shape=shape, + dtype=dtype, + fp4_dtype=TE_DType(fp4_dtype), + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=amax_rowwise, + amax_columnwise=amax_columnwise, + quantizer=quantizer, + requires_grad=False, + with_gemm_swizzled_scales=with_gemm_swizzled_scales, + ) + + class _ViewFunc(torch.autograd.Function): """View function