Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 55 additions & 15 deletions transformer_engine/common/recipe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

"""This module provides predefined FP8 recipes."""
from __future__ import annotations
import abc
import os
from enum import Enum
from typing import Any, Literal, Optional, Union, Callable, NamedTuple
Expand Down Expand Up @@ -60,6 +61,16 @@ class MMParams:

use_split_accumulator: bool = True

def __post_init__(self) -> None:
    """Precompute the repr string once, right after field initialisation.

    The text is stored under ``_cached_repr`` so that ``__repr__`` can
    return it without re-formatting on every call.
    """
    rendered = f"MMParams(use_split_accumulator={self.use_split_accumulator})"
    # Written through ``object.__setattr__`` rather than plain assignment;
    # presumably the enclosing dataclass is frozen -- confirm against the
    # class decorator.
    object.__setattr__(self, "_cached_repr", rendered)

def __repr__(self) -> str:
    """Return the representation string stored under ``_cached_repr``."""
    cached_text = self._cached_repr
    return cached_text


@dataclass(frozen=True)
class QParams:
Expand All @@ -76,21 +87,50 @@ class QParams:
stochastic_rounding: bool = False
fp4_2d_quantization: bool = False

def __repr__(self) -> str:
return (
def __post_init__(self) -> None:
object.__setattr__(
self,
"_cached_repr",
f"Qparams(\npower_2_scale={self.power_2_scale},\n"
f"amax_epsilon={self.amax_epsilon},\n"
f"random_hadamard_transform={self.random_hadamard_transform},\n"
f"stochastic_rounding={self.stochastic_rounding},\n"
f"fp4_2d_quantization={self.fp4_2d_quantization}\n)"
f"fp4_2d_quantization={self.fp4_2d_quantization}\n)",
)

def __repr__(self) -> str:
return self._cached_repr


class Recipe:
"""
Base recipe class.
"""

# Cached string representation. Lazily populated by ``__repr__`` (which
# calls the subclass's ``_make_repr``) and reset to ``None`` by
# ``__setattr__`` whenever any other attribute is assigned. This makes
# repeated ``str(recipe)`` calls much cheaper.
_cached_repr: Optional[str] = None

def __setattr__(self, name: str, value: Any) -> None:
    """Assign ``value`` to ``name`` and drop any stale cached repr.

    Every assignment except one to ``_cached_repr`` itself first resets
    ``_cached_repr`` to ``None``, so the next ``__repr__`` call rebuilds
    the string from the current attribute values.
    """
    touches_state = name != "_cached_repr"
    if touches_state:
        # Bypass this override to avoid recursing while clearing the cache.
        object.__setattr__(self, "_cached_repr", None)
    object.__setattr__(self, name, value)

@abc.abstractmethod
def _make_repr(self) -> str:
    """Build the string representation for this recipe.

    Subclasses must override this method. The result is cached by
    ``__repr__`` and reused until any attribute is mutated.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    # ``abc.abstractmethod`` only blocks instantiation when the class uses
    # ``abc.ABCMeta``; the enclosing class is declared as a plain class, so
    # a non-overriding subclass can still reach this body. Fail loudly here
    # instead of returning ``None``, which would surface later as an opaque
    # "__repr__ returned non-string" TypeError.
    raise NotImplementedError(f"{type(self).__name__} must implement _make_repr()")

def __repr__(self) -> str:
    """Return the recipe's string form, building it at most once per state.

    When ``_cached_repr`` is ``None`` the string is produced by
    ``_make_repr`` and stored; otherwise the stored value is returned.
    """
    text = self._cached_repr
    if text is None:
        text = self._make_repr()
        self._cached_repr = text
    return text

@classmethod
def nvfp4(cls):
"""Whether the given recipe is NVFP4 1D block scaling."""
Expand Down Expand Up @@ -127,7 +167,7 @@ def custom(cls):
return issubclass(cls, CustomRecipe)


@dataclass()
@dataclass(repr=False)
class DelayedScaling(Recipe):
"""
Use the delayed scaling factor strategy. Use scale factor from previous
Expand Down Expand Up @@ -227,7 +267,7 @@ def __post_init__(self) -> None:
self.backward_override is None
), "Delayed scaling only supports backward_override=None."

def __repr__(self) -> str:
def _make_repr(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"margin={self.margin}, "
Expand All @@ -240,7 +280,7 @@ def __repr__(self) -> str:
)


@dataclass()
@dataclass(repr=False)
class Float8CurrentScaling(Recipe):
"""
Use the per-tensor current scaling factor strategy.
Expand Down Expand Up @@ -275,7 +315,7 @@ def __post_init__(self) -> None:
self.backward_override in _BACKWARD_OVERRIDES
), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'."

def __repr__(self) -> str:
def _make_repr(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"format={str(self.fp8_format).split('.')[1]}, "
Expand All @@ -291,7 +331,7 @@ def __repr__(self) -> str:
)


@dataclass()
@dataclass(repr=False)
class MXFP8BlockScaling(Recipe):
"""
Use the MXFP8 scaling factor strategy.
Expand Down Expand Up @@ -333,7 +373,7 @@ def __post_init__(self) -> None:
self.backward_override in _BACKWARD_OVERRIDES
), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'."

def __repr__(self) -> str:
def _make_repr(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"margin={self.margin}, "
Expand All @@ -342,7 +382,7 @@ def __repr__(self) -> str:
)


@dataclass()
@dataclass(repr=False)
class Float8BlockScaling(Recipe):
"""
Use block-wise scaling for FP8 tensors.
Expand Down Expand Up @@ -414,7 +454,7 @@ def __post_init__(self) -> None:
self.backward_override in _BACKWARD_OVERRIDES
), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'."

def __repr__(self) -> str:
def _make_repr(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"format={str(self.fp8_format).split('.')[1]}, "
Expand All @@ -433,7 +473,7 @@ def __repr__(self) -> str:
)


@dataclass()
@dataclass(repr=False)
class NVFP4BlockScaling(Recipe):
"""
Use the NVFP4 scaling strategy.
Expand Down Expand Up @@ -531,7 +571,7 @@ def __post_init__(self) -> None:
fp4_2d_quantization=False,
)

def __repr__(self) -> str:
def _make_repr(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"fp4_format={str(self.fp4_format).split('.')[1]}, "
Expand All @@ -546,7 +586,7 @@ def __repr__(self) -> str:
)


@dataclass()
@dataclass(repr=False)
class CustomRecipe(Recipe):
"""
Custom recipe that allows users to provide quantizer factories.
Expand Down Expand Up @@ -608,7 +648,7 @@ def __post_init__(self) -> None:
self.backward_override in _BACKWARD_OVERRIDES
), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'."

def __repr__(self) -> str:
def _make_repr(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"qfactory={self.qfactory}, "
Expand Down
97 changes: 63 additions & 34 deletions transformer_engine/pytorch/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,9 +686,8 @@ def reduce_and_update_fp8_tensors(
amax_history, scale, get_fp8_max(recipe, forward), recipe
)

@classmethod
@staticmethod
def get_unique_autocast_key(
cls,
recipe: Optional[Recipe] = None,
group: Optional[dist_group_type] = None,
):
Expand All @@ -697,7 +696,11 @@ def get_unique_autocast_key(
Object identity is sufficient since autocast contexts never outlive a single
training session.
"""
return str((str(recipe), id(group) if group is not None else None))
recipe_repr = recipe.__dict__.get("_cached_repr") if recipe is not None else None
if recipe_repr is None:
recipe_repr = str(recipe)
group_id = id(group) if group is not None else None
return f"recipe={recipe_repr},group={group_id}"

@classmethod
def autocast_enter(
Expand Down Expand Up @@ -911,14 +914,13 @@ def quantized_model_init(
qstate.high_precision_init_val = _high_precision_init_val


@contextmanager
def fp8_autocast(
enabled: bool = True,
calibrating: bool = False,
fp8_recipe: Optional[Recipe] = None,
fp8_group: Optional[dist_group_type] = None,
_graph: bool = False,
) -> None:
) -> "autocast":
"""
.. warning::

Expand All @@ -934,25 +936,16 @@ def fp8_autocast(
stacklevel=2,
)

# Call new implementation.
with autocast(
return autocast(
enabled=enabled,
calibrating=calibrating,
recipe=fp8_recipe,
amax_reduction_group=fp8_group,
_graph=_graph,
):
yield
)


@contextmanager
def autocast(
enabled: bool = True,
calibrating: bool = False,
recipe: Optional["Recipe"] = None,
amax_reduction_group: Optional["dist_group_type"] = None,
_graph: bool = False,
) -> None:
class autocast:
"""
Context manager for quantization schemes like FP8 or FP4.

Expand Down Expand Up @@ -991,24 +984,60 @@ def autocast(
are reduced at the end of each training step.
"""

if enabled:
check_recipe_support(recipe)

# Save current state so we always restore it on exit.
fp8_state = FP8GlobalStateManager.get_autocast_state()

FP8GlobalStateManager.autocast_enter(
enabled=enabled,
calibrating=calibrating,
fp8_recipe=recipe,
fp8_group=amax_reduction_group,
_graph=_graph,
# Class-based context manager (instead of ``@contextmanager`` from contextlib)
# to avoid overheads.
__slots__ = (
"_enabled",
"_calibrating",
"_recipe",
"_amax_reduction_group",
"_graph",
"_fp8_state",
)
try:
yield
finally:
FP8GlobalStateManager.set_autocast_state(fp8_state)
FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph)

def __init__(
    self,
    enabled: bool = True,
    calibrating: bool = False,
    recipe: Optional["Recipe"] = None,
    amax_reduction_group: Optional["dist_group_type"] = None,
    _graph: bool = False,
) -> None:
    """Record the autocast configuration; no global state is touched here.

    The arguments are only stored on the instance. ``_fp8_state`` starts
    as ``None`` and is filled in by ``__enter__``.
    """
    # Not yet entered: nothing has been saved for restoration.
    self._fp8_state = None
    # Configuration forwarded to the state manager on entry.
    self._recipe = recipe
    self._amax_reduction_group = amax_reduction_group
    self._enabled = enabled
    self._calibrating = calibrating
    self._graph = _graph

def __enter__(self) -> "autocast":
    """Enter the autocast region and return this context manager.

    Validates the recipe when autocast is enabled, saves the current
    autocast state for ``__exit__`` to restore, and then hands the stored
    configuration to ``FP8GlobalStateManager.autocast_enter``.

    Raises:
        RuntimeError: If this instance is already active, i.e. it was
            entered and has not been exited yet.
    """
    # Disallow nested re-entry of the same instance. A non-None saved state
    # means a matching ``__exit__`` has not run yet.
    # NOTE(review): this relies on ``get_autocast_state()`` never returning
    # ``None`` -- confirm against FP8GlobalStateManager.
    if self._fp8_state is not None:
        raise RuntimeError(
            "autocast context manager cannot be entered more than once concurrently"
        )
    if self._enabled:
        check_recipe_support(self._recipe)
    # Save current state so we always restore it on exit.
    self._fp8_state = FP8GlobalStateManager.get_autocast_state()
    try:
        FP8GlobalStateManager.autocast_enter(
            enabled=self._enabled,
            calibrating=self._calibrating,
            fp8_recipe=self._recipe,
            fp8_group=self._amax_reduction_group,
            _graph=self._graph,
        )
    except BaseException:
        # ``__exit__`` is not invoked when ``__enter__`` raises, so clear
        # the marker here. Otherwise a failed entry would leave the
        # instance permanently "active" and every later attempt to enter
        # it would hit the RuntimeError above.
        self._fp8_state = None
        raise
    return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
# Leave the autocast region. The three exception arguments are accepted
# for the context-manager protocol but are not inspected, and nothing is
# returned, so an exception raised inside the ``with`` body propagates.
try:
# Order matters: first put back the state that ``__enter__`` saved, then
# notify the state manager that this region has ended. If the restore
# call raises, ``autocast_exit`` is not reached.
FP8GlobalStateManager.set_autocast_state(self._fp8_state)
FP8GlobalStateManager.autocast_exit(self._enabled, _graph=self._graph)
finally:
# Clear the saved state so the instance can be entered again
# sequentially (and so a failure inside the restore path does not
# permanently mark the instance as "active").
self._fp8_state = None


def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
Expand Down
Loading