
[pull] main from NVIDIA:main#602

Merged
pull[bot] merged 3 commits into phu0ngng:main from NVIDIA:main
May 11, 2026

Conversation


pull[bot] commented May 11, 2026

See Commits and Changes for more details.


Created by pull[bot] (v2.0.0-alpha.4)


negvet and others added 3 commits May 11, 2026 19:26
* Enable semantic roles emitted by module/op and consumed by custom recipe state

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Update quantization factories

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Fix tests

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Swap tensor:module

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Better naming

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Introduce QuantizerRole frozen data class instead of a string

Signed-off-by: Evgeny <etsykunov@nvidia.com>
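
A frozen dataclass role key, as opposed to a raw string, might look like the following minimal sketch; the field names are illustrative, not the actual QuantizerRole API:

```python
from dataclasses import dataclass

# Hypothetical sketch: field names are illustrative, not TE's actual API.
@dataclass(frozen=True)
class QuantizerRole:
    module_type: str   # e.g. "linear"
    tensor: str        # e.g. "input", "weight", "grad_output"

# frozen=True makes instances immutable and hashable, so a role can key a
# recipe mapping and be looked up by value:
recipe_map = {QuantizerRole("linear", "weight"): "nvfp4"}
print(recipe_map[QuantizerRole("linear", "weight")])
```

Unlike a bare string, a frozen dataclass gives typo-proof field access and a place to grow new attributes without changing every consumer.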

* Shrink module_type vocabulary

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix numerics exact test

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Set defaults, make custom recipe forward compatible

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* remove position from QuantizerRole

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Set good defaults

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Resolve naming: make every module/op distinguishable via name

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Configure output/grad_input roles, defaults to None

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Remove is_gemm()

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Enable base recipes via CustomRecipe and quantization factories

Signed-off-by: Evgeny <etsykunov@gmail.com>

* Add factory example - NVFP4 for Linear, MXFP8 for GroupedLinear

Signed-off-by: Evgeny <etsykunov@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix custom recipe test

Signed-off-by: Evgeny <etsykunov@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Test fine-grained quantization targets

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add quantizer roles for attention (attn is wip)

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Enable stateful recipes in the Custom recipe - Delayed Scaling support

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Fix save_original_input for custom delayed scaling

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Enable custom recipe for attn

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Make boundary role setting more explicit in MHA

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Make dpa role setting more intuitive

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Docstring for get_quantizer_roles() in base module

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Fix lint

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Restrict None roles

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Linter

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Minor fixes

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Test debug tools compat

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix pylint

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* fix test

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Fix lint

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Constructor takes roles kwarg + test fix

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Constructor takes roles kwarg + test fix (quantization.py)

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Fix attention: MXFP8, w/o CP

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Add test custom recipe

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Make Float8BlockScalingRecipeState and NVFP4BlockScalingRecipeState aware of QuantizerRole; dispatch on it, with a positional fallback if get_quantizer_roles() is not defined by the module/op

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix linter

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Fix CI

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Preserve delayed scaling state (buffers) when rebuild is triggered

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Fix test, minor

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix distributed tests

Signed-off-by: Evgeny <etsykunov@nvidia.com>

---------

Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Evgeny <etsykunov@gmail.com>
[Common][PyTorch] Fix int32 overflow and -1 sentinel handling in moe_permute (#2907)

* [Common][PyTorch] Fix int32 overflow and -1 sentinel handling in moe_permute

Two independent bugs in transformer_engine/common/permutation/permutation.cu
and the PyTorch extension caller reproduce on main (264da2b) and v2.13:

1. int32 overflow in moe_unpermute_kernel and moe_permute_kernel.
   `source_token * num_cols` and `source_row * num_cols` are computed with
   int, so for long-sequence MoE workloads where num_out_tokens * num_cols
   reaches 2**31 (e.g. 2**18 tokens x 2**13 hidden), the pointer offset
   wraps and the kernel either reads garbage or raises
   `an illegal memory access was encountered`.
   Widening source_token, source_row and dest_row to int64_t inside the
   kernels keeps the index arithmetic in 64 bits without changing any
   public types.

2. Incorrect handling of -1 sentinels in the routing indices.
   Libraries such as DeepEP (and any expert-parallel mask that sets
   non-local (token, slot) pairs to -1) feed a routing_map that contains
   -1 entries. `cub::DeviceRadixSort::SortPairs` is signed ascending, so
   those sentinels land at the HEAD of sorted_row_id, not the tail.
   moe_permute_row_map currently writes -1 only for idx >= num_out_tokens
   and reads the sentinel prefix as if it were a valid sorted id,
   producing bogus row_id_map writes (for instance
   `source_row / topK == 0, source_row % topK == -1`).

   The caller now advances sorted_row_id_ptr past the num_minus_ones
   prefix and pre-fills row_id_map with -1 via torch::full, so the
   kernel only processes the valid suffix and never dereferences a
   sentinel.  The launcher's grid switches from num_rows*topK blocks
   to num_out_tokens blocks to match the new valid range.
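
The int32 wraparound in bug 1 can be reproduced with plain integer arithmetic; `to_int32` below simply reinterprets a Python int as a 32-bit two's-complement value, standing in for the kernel's `int` index math:

```python
def to_int32(x):
    """Reinterpret a Python int as a 32-bit two's-complement value."""
    x &= 0xFFFFFFFF
    return x - 2**32 if x >= 2**31 else x

source_row, num_cols = 2**18, 2**13          # 262144 tokens x 8192 hidden
wrapped = to_int32(source_row * num_cols)    # int offset: wraps to -2**31
widened = source_row * num_cols              # int64_t offset: stays at 2**31
print(wrapped, widened)                      # negative vs. correct offset
```

At exactly 2**31 the 32-bit product becomes negative, which is why the pointer offset corrupts or faults; widening to int64_t leaves the product intact.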

No behaviour change on happy-path routing_map (no -1, no overflow).
Reproducers:

- 8-token, topK=2 routing_map with -1 masking: max |TE - ref| = 4.5e0
  on bf16 with current main; 0.0 with this patch.
- num_tokens=2**18+1, num_cols=2**13, topK=1: current main raises
  CUDA illegal memory access at permutation.cu:252; with this patch
  it succeeds.

Signed-off-by: Jingyi Xi <flotherxi@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Guard against invalid num_out_tokens in moe_permute_fwd

Add an NVTE_CHECK that num_out_tokens <= num_tokens * topK and cast
num_minus_ones to size_t before the pointer advance, so a negative
num_minus_ones (from an invalid num_out_tokens) cannot silently wrap
into a huge pointer offset.

Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
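
The failure mode this check guards against is ordinary unsigned wraparound; a minimal Python analogue (using ctypes to model the size_t cast, with names made up for illustration):

```python
import ctypes

def minus_ones_offset(num_tokens, top_k, num_out_tokens):
    # NVTE_CHECK analogue: an invalid num_out_tokens would make the
    # sentinel-prefix length negative, which a size_t cast silently wraps.
    assert num_out_tokens <= num_tokens * top_k, "invalid num_out_tokens"
    return ctypes.c_size_t(num_tokens * top_k - num_out_tokens).value

print(minus_ones_offset(8, 2, 10))   # 6 sentinels in the sorted prefix
print(ctypes.c_size_t(-1).value)     # what a negative count wraps to
```

Without the check, a count of -1 reinterpreted as size_t becomes the largest representable offset, turning a logic error into wild pointer arithmetic.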

* Switch radix sort keys to uint32_t to fix -1 sentinel ordering

The MoE permute path was correct for the existing capacity-drop convention
(drops encoded as a large positive expert id, sorted to the tail by the
signed cub::DeviceRadixSort), but it broke for callers that mark dropped
(token, slot) pairs with -1 (expert-parallel rank masking, e.g. DeepEP).
With signed sort the -1 sentinels land at the HEAD of sorted_row_id, while
moe_permute_row_map's `idx >= num_out_tokens` branch assumes drops are at
the tail.

Reinterpret the keys as uint32_t inside nvte_device_radix_sort_pairs so
-1 (= UINT_MAX) sorts to the tail, unifying the EP-mask case with the
existing capacity-drop convention. The kernel and caller sides are
unchanged - this is a one-place fix that makes both drop conventions
land in the existing drop branch.
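
The two sort orders can be checked with a Python stand-in for the radix sort, where `& 0xFFFFFFFF` plays the role of the uint32_t reinterpretation:

```python
keys = [2, -1, 0, -1, 1]     # expert ids; -1 marks dropped (token, slot) pairs
rows = list(range(len(keys)))

# Signed order, as cub::DeviceRadixSort::SortPairs on int32 keys:
signed = [k for k, _ in sorted(zip(keys, rows))]
# Unsigned order, reinterpreting each key as uint32_t (-1 -> UINT_MAX):
unsigned = [k for k, _ in sorted(zip(keys, rows),
                                 key=lambda kr: kr[0] & 0xFFFFFFFF)]
print(signed)    # sentinels at the HEAD
print(unsigned)  # sentinels at the tail, matching the drop branch
```
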

Also widen the loop-carried indices in moe_unpermute_kernel and
moe_permute_kernel to int64_t (`source_token`, `source_row`, `dest_row`)
to keep `row * num_cols` in 64 bits. We hit this on DeepSeek-V3 long-
context training (hidden = 7168, topK = 8): once `num_out_tokens *
num_cols` reaches 2**31 the int product wraps and the kernel either
silently corrupts rows or raises CUDA `illegal memory access`.

Signed-off-by: Jingyi Xi <flotherxi@gmail.com>

* Widen num_rows * topK products in moe_permute_row_map for consistency

Per reviewer feedback in #2907, promote the
int * int multiplications in moe_permute_row_map and its launcher to
int64_t. These are not the overflow path this PR was originally
fixing (DeepSeek-V3 long-context hits row * num_cols, where num_cols
is the hidden dim ~ 7-8k), and num_rows * topK only crosses 2**31 at
unrealistic per-rank token counts (>= 268M at topK=8). The change is
purely defensive but keeps the index arithmetic in this kernel
consistent with the int64_t source_token / source_row / dest_row
widening already applied to moe_unpermute_kernel and moe_permute_kernel.

Signed-off-by: Jingyi Xi <flotherxi@gmail.com>

---------

Signed-off-by: Jingyi Xi <flotherxi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Teddy Do <tdophung@nvidia.com>
* [PyTorch] Linear: minor cleanups for compile-friendliness

Three small refactors that make the module easier to reason about
and pave the way for the dataclass / saved-tensor refactors:

- Add a TensorOrQuantized type alias (Union[Tensor, QuantizedTensorStorage])
  used pervasively in helper signatures.
- Hoist the conditional bias argument into a local linear_bias_tensor
  variable instead of an inline expression at the linear_fn() call site.
- Only forward self.wgrad_store into the autograd Function when it is
  actually active (delay_wgrad_compute() is True); pass None otherwise so
  the autograd graph does not carry an unused Python object.

Pure rename / hoisting; no behavioural change.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>

* [PyTorch] Linear: pack forward/backward state into dataclasses

Replace the loosely typed ``non_tensor_args`` tuple and the ad-hoc
``ctx.<attr>`` plumbing with two dataclasses, ``LinearFwdArgs`` and
``LinearBwdArgs``, that act as the single argument to every helper
in the forward/backward pipeline.

What changes:

* ``LinearFwdArgs`` carries the (positional) tensors ``weight``, ``inp``
  and ``bias`` plus all quantizers, ``requires_grad`` flags, the cached
  ``weight_workspace`` and every former ``non_tensor_args`` knob.
  ``_Linear.forward`` still takes ``weight/inp/bias`` as positional
  Tensor inputs so autograd tracks them, then immediately re-attaches
  them to ``fwd_args`` so every downstream helper has a single-argument
  signature.
* ``LinearBwdArgs`` mirrors that on the backward side: it owns the
  saved tensors (``inputmat``, ``weight_fp8``, ``saved_weight``,
  ``bias``), the per-call quantizers, every flag previously stored
  directly on ``ctx`` and a ``setup_saved_tensors(saved_tensors,
  tensor_objects)`` helper that rehydrates the saved-tensor fields.
* ``ctx.backward_objects = bwd_args`` is now the single attribute the
  autograd context needs (besides ``saved_tensors``/``tensor_objects``).
* ``weight_workspace`` is no longer a positional Tensor arg of the
  autograd Function; it is read from ``fwd_args.weight_workspace`` and
  the freshly produced workspace is returned alongside ``out`` so the
  module can refresh its cache without autograd tracking the cache.
* ``prepare_for_saving`` now lives at the autograd boundary in
  ``_Linear.forward``; ``_linear_setup_ctx`` only returns the merged
  list of tensors that should be saved.
* ``grad_output_preprocess`` is invoked with ``bwd_args`` directly
  (it is duck-typed on the same attribute names) so backward never
  reaches into ``ctx.<attr>`` for non-tensor state.

Behaviour preserved (verified numerically against ``torch.nn.Linear``
and on FP8 + workspace-cache paths).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
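
The shape of the refactor, reduced to a minimal sketch: the real LinearFwdArgs/LinearBwdArgs carry many more fields (quantizers, flags, workspaces), and `Ctx` here merely stands in for the autograd context.

```python
from dataclasses import dataclass
from typing import Any, List, Optional

@dataclass
class LinearFwdArgs:          # abbreviated sketch, not the real field set
    weight: Any = None
    inp: Any = None
    bias: Any = None
    fp8: bool = False

@dataclass
class LinearBwdArgs:          # abbreviated sketch, not the real field set
    saved: Optional[List[Any]] = None
    fp8: bool = False

class Ctx:
    """Stand-in for the torch.autograd.Function ctx object."""

def forward(ctx, weight, inp, bias, fwd_args):
    # Tensors stay positional so autograd tracks them, then are re-attached
    # to fwd_args so every downstream helper takes a single argument.
    fwd_args.weight, fwd_args.inp, fwd_args.bias = weight, inp, bias
    ctx.backward_objects = LinearBwdArgs(saved=[inp, weight], fp8=fwd_args.fp8)
    return [w * x for w, x in zip(weight, inp)]

def backward(ctx):
    bwd = ctx.backward_objects
    ctx.backward_objects = None   # release state as soon as backward returns
    return bwd.saved
```

The single `ctx.backward_objects` attribute replaces a dozen ad-hoc `ctx.<attr>` assignments, and nulling it out after backward releases the saved state promptly (relevant under retain_graph=True).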

* [PyTorch] Linear: deduplicate saved tensors that alias forward inputs

When ``saved_inputmat is inp``, ``wt_save is weight`` or ``bias`` is the
exact bias passed in, there is no point asking ``prepare_for_saving`` to
serialize the same Python object twice. Make ``_linear_forward_impl``
emit ``None`` in those slots (and a parallel ``saved_tensor_aliases``
tuple in ``ctx_attrs`` describing which slot points where), and have
``_linear_setup_ctx`` rebuild the tuple with the original references
before handing it to ``prepare_for_saving``.

Saves a Python ref per alias in eager and, more importantly, keeps the
forward helper from "returning" a tensor that aliases its own inputs --
a pattern ``torch.compile`` would otherwise need to reason about when
the helper is wrapped in an opaque op.

Numerically equivalent (validated against ``torch.nn.Linear`` and on a
multi-iteration FP8 path with workspace caching).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
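
The alias bookkeeping can be sketched in a few lines; the helper names are made up for illustration, and the real code threads the alias tuple through ctx_attrs:

```python
def dedup_for_saving(inp, weight, saved_inputmat, wt_save):
    """Emit None for slots that alias a forward input, plus an alias map."""
    slots, aliases = [], []
    for tensor in (saved_inputmat, wt_save):
        if tensor is inp:
            slots.append(None); aliases.append("inp")
        elif tensor is weight:
            slots.append(None); aliases.append("weight")
        else:
            slots.append(tensor); aliases.append(None)
    return slots, tuple(aliases)

def rebuild(slots, aliases, inp, weight):
    """Restore the original references before prepare_for_saving."""
    refs = {"inp": inp, "weight": weight}
    return [refs[a] if a is not None else t for t, a in zip(slots, aliases)]
```

Identity (`is`) rather than equality is the right test here: the point is to avoid serializing the same Python object twice, not to deduplicate equal values.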

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [PyTorch] Linear: tighten LinearFwdArgs/BwdArgs and trim ctx_attrs

Follow-up cleanups on top of the dataclass refactor:

* Sort ``LinearFwdArgs`` / ``LinearBwdArgs`` fields into labelled groups
  (tensors, requires_grad flags, quantizers, dtype/numerical config,
  parallelism, userbuffers, FSDP, wgrad scheduling, misc) and mirror that
  ordering in their construction sites.
* Add ``slots=True`` to both dataclasses so typos in
  ``fwd_args.X`` / ``bwd_args.X`` raise ``AttributeError`` immediately
  instead of silently creating a new attribute.
* Inline single-use ``args.X`` aliases in ``_linear_forward_impl``
  (``weight_workspace``, ``fp8_calibration``, ``tp_size``,
  ``tensor_parallel``, ``cache_weight``, ``skip_fp8_weight_update``,
  ``custom``, ``backward_input_needs_gather``) so the prelude only keeps
  aliases that are actually reused.
* Shrink ``ctx_attrs`` to ``{fsdp_shapes, saved_tensor_aliases}``:
  ``weight_quantizer`` is re-derived in ``_linear_setup_ctx`` from
  ``fwd_args.weight`` (matching the resolution done in forward),
  ``is_fsdp2`` already lives on ``fwd_args``, and ``owns_input`` is
  equivalent to ``saved_tensor_aliases[0] != "inp"``.
* Replace ``setup_saved_tensors(saved_tensors, tensor_objects)`` with
  ``setup_saved_tensors(ctx)`` backed by ``restore_from_func_ctx``,
  matching ``layernorm_mlp`` / ``layernorm_linear`` /
  ``grouped_linear`` and dropping the manual
  ``ctx.tensor_objects = None`` cleanup.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [PyTorch] tests: snapshot backward ctx state from LinearBwdArgs

After packing the Linear backward state into ``LinearBwdArgs`` the
attributes the test was reading (``backward_override``, ``fp8``,
``grad_output_quantizer``, ``reduce_and_update_bwd_fp8_tensors``) no
longer live directly on ``grad_fn``. Read them from
``grad_fn.backward_objects`` when present, falling back to ``grad_fn``
for the linear-like modules that have not been refactored yet
(``layernorm_linear``, ``ops_linear``).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [PyTorch] Linear: add docstrings to LinearFwdArgs / LinearBwdArgs

Restore the one-line class docstrings dropped during the field
reorganization so pylint stops warning about C0115.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>

* [PyTorch] Linear: drop ctx.backward_objects after backward

Saved tensors, quantizers, weakrefs and main_grad closures referenced
from LinearBwdArgs survived until ctx GC, extending peak GPU memory
under retain_graph=True. Null out ctx.backward_objects right after
_linear_backward so they are released as soon as backward returns.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull[bot] locked and limited conversation to collaborators May 11, 2026
pull[bot] added the ⤵️ pull label May 11, 2026
pull[bot] merged commit 282b4fb into phu0ngng:main May 11, 2026
pull[bot] had a problem deploying to github-pages May 11, 2026 22:33 (Failure)

3 participants