
[Common/PyTorch/JAX] make offset of ClampedSwiGLU configurable#2938

Open
hxbai wants to merge 7 commits into NVIDIA:main from hxbai:swiglu_offset

Conversation

@hxbai (Contributor) commented Apr 28, 2026

Description

The existing ClampedSwiGLU follows GPT-OSS, which hard-codes the offset to 1.0.
DeepSeek-V4 uses ClampedSwiGLU without the alpha scaling and without the offset.
This PR makes the offset of ClampedSwiGLU configurable to support DeepSeek-V4.
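For reference, the activation in question can be sketched as a scalar function in plain Python. This is a minimal sketch, not the library implementation: the clamp limit 7.0 and alpha 1.702 are the GPT-OSS defaults assumed here, and glu_linear_offset is the new knob (1.0 reproduces GPT-OSS, 0.0 the DeepSeek-V4 variant):

```python
import math

def clamped_swiglu(x_glu, x_linear, limit=7.0, alpha=1.702, glu_linear_offset=1.0):
    # Clamp the gate branch from above and the linear branch on both sides.
    x_glu = min(x_glu, limit)
    x_linear = max(-limit, min(x_linear, limit))
    # Sigmoid-gated activation; GPT-OSS fixes the offset at 1.0, DeepSeek-V4 uses 0.0.
    gate = x_glu / (1.0 + math.exp(-alpha * x_glu))
    return gate * (x_linear + glu_linear_offset)
```

With glu_linear_offset=0.0 the output vanishes whenever the linear branch is zero, which the fixed +1 offset of GPT-OSS deliberately avoids.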

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add a configurable glu_linear_offset (default 1.0) to ClampedSwiGLUParam and propagate it through the forward and backward CUDA kernels
  • Add nvte_clamped_swiglu_v2 / nvte_clamped_dswiglu_v2 C APIs, deprecate the old ones, and update the PyTorch/JAX bindings, the grouped-MLP fusion guard, and tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps (bot) commented Apr 28, 2026

Greptile Summary

This PR makes the linear (gate) component offset of ClampedSwiGLU configurable, enabling DeepSeek-V4 support where offset should be 0.0 instead of the GPT-OSS hard-coded 1.0. A new v2 C API (nvte_clamped_swiglu_v2 / nvte_clamped_dswiglu_v2) is introduced to avoid breaking the public ABI.

  • Core struct change: ClampedSwiGLUParam gains a glu_linear_offset field (default 1.0f), propagated through all CUDA kernels (vectorized_pointwise.h, gated_fp8.cuh, gated_mxfp8.cuh) and both the forward and backward passes.
  • API versioning: The existing nvte_clamped_swiglu / nvte_clamped_dswiglu functions are preserved with hard-coded 1.0f and marked deprecated; new _v2 variants expose the offset parameter, and all Python/JAX/PyTorch binding layers are updated to call the v2 functions.
  • Fusion guard: fuse_grouped_mlp_ops in _common.py now blocks MXFP8 grouped-MLP fusion when glu_linear_offset != 1.0, correctly falling back to the non-fused path for the new variant.

Confidence Score: 5/5

Safe to merge. All code paths — FP32, FP8, MXFP8, PyTorch ops, JAX FFI, and the Python fallback — consistently propagate the new offset, and the default of 1.0 preserves existing behavior.

The change is well-scoped and consistent across all backends. Default values of glu_linear_offset = 1.0 in both C++ (ClampedSwiGLUParam) and Python/JAX layers guarantee backward compatibility. Old public C API symbols are preserved unchanged with the new semantics living in _v2 variants. The fusion guard in _common.py prevents the MXFP8 grouped-MLP fused kernel from being invoked with a non-default offset it was never designed to handle. Forward and backward CUDA math is correct: gate_in = clamp(x_linear) + offset is the right multiplier in the backward gradient for the activation branch. Tests cover both offset=1.0 and offset=0.0 for both ClampedSwiGLU and ScaledClampedQGeGLU.

No files require special attention.
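The backward claim above can be checked numerically with a scalar sketch. This is an illustrative derivation under the assumed formulation out = act(clamp(x_glu)) * (clamp(x_linear) + offset), not the kernel code; function names here are hypothetical:

```python
import math

def clamped_swiglu(x_glu, x_linear, limit=7.0, alpha=1.702, offset=1.0):
    g = min(x_glu, limit)
    l = max(-limit, min(x_linear, limit))
    return g / (1.0 + math.exp(-alpha * g)) * (l + offset)

def dswiglu_dglu(x_glu, x_linear, limit=7.0, alpha=1.702, offset=1.0):
    # Gradient w.r.t. the gate input: (clamp(x_linear) + offset) multiplies
    # d/dg [g * sigmoid(alpha * g)], and the clamp zeroes it past the limit.
    if x_glu > limit:
        return 0.0
    sig = 1.0 / (1.0 + math.exp(-alpha * x_glu))
    dact = sig + x_glu * alpha * sig * (1.0 - sig)
    l = max(-limit, min(x_linear, limit))
    return dact * (l + offset)

# Central finite difference agrees with the analytic gradient.
eps = 1e-6
num = (clamped_swiglu(0.5 + eps, 0.3, offset=0.0)
       - clamped_swiglu(0.5 - eps, 0.3, offset=0.0)) / (2 * eps)
assert abs(num - dswiglu_dglu(0.5, 0.3, offset=0.0)) < 1e-4
```

Note that the offset enters the activation-branch gradient only as the multiplier (clamp(x_linear) + offset), so a correct default of 1.0 reproduces the old backward pass exactly.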

Important Files Changed

  • transformer_engine/common/include/transformer_engine/activation.h: Adds new v2 public API functions with glu_linear_offset and marks the old functions deprecated, preserving ABI compatibility.
  • transformer_engine/common/util/math.h: Adds a glu_linear_offset = 1.0f field to ClampedSwiGLUParam; the default preserves backward compatibility.
  • transformer_engine/common/activation/swiglu.cu: Old functions hard-code the offset to 1.0f; new v2 functions pass the configurable offset through ClampedSwiGLUParam.
  • transformer_engine/common/util/vectorized_pointwise.h: Both forward and backward kernel paths updated from a hard-coded + 1 to + p.glu_linear_offset; the math is correct.
  • transformer_engine/common/cast/fp8/gated_fp8.cuh: Replaces the hard-coded + 1 with + p.glu_linear_offset in the FP8 quantization kernel; the backward gradient derivation is correct.
  • transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh: Both forward kernel paths (non-interleaved and interleaved) updated to use p.glu_linear_offset consistently.
  • transformer_engine/pytorch/ops/_common.py: Fusion guard blocks MXFP8 grouped-MLP fusion when glu_linear_offset != 1.0, preventing incorrect fused behavior.
  • transformer_engine/pytorch/ops/basic/swiglu.py: Adds a glu_linear_offset parameter to ClampedSwiGLU and ScaledClampedQGeGLU with the correct default and propagation.
  • transformer_engine/jax/cpp_extensions/activation.py: JAX ClampedSwigluParams dataclass gains the new field; the custom hash and FFI lowering dict both include glu_linear_offset.
  • transformer_engine/jax/csrc/extensions.h: ClampedSwigluConfig struct and XLA FFI struct attr decoding updated to include the new glu_linear_offset member.
  • transformer_engine/pytorch/module/layernorm_mlp.py: Python-level fallback _clamped_swiglu reads glu_linear_offset from the activation_params dict with a correct default of 1.0.
  • tests/pytorch/test_fusible_ops.py: Tests parametrized over glu_linear_offset values 1.0 and 0.0 for both test_clamped_swiglu and test_scaled_clamped_qgeglu.

Reviews (8). Last reviewed commit: "fix test".

Comment on lines 339 to 341
* \param[in] glu_linear_offset Offset added to the linear component after clamping (default 1.0).
* \param[in] stream CUDA stream used for the operation.
*/
Contributor
P1 Breaking public C API change

nvte_clamped_swiglu and nvte_clamped_dswiglu are public symbols declared in a versioned public header. Inserting glu_linear_offset before cudaStream_t is an ABI-breaking change: any external binary or shared library compiled against the old header will silently pass the stream pointer as the offset and a garbage value as the stream, leading to undefined behavior at runtime rather than a clean compile error if called via a pre-compiled library. This should be acknowledged as a breaking change in the PR checklist, and — if this library follows semantic versioning or a compatibility guarantee — a deprecation/transition path or version bump is needed.

@timmoon10 (Collaborator) left a comment
The fused op for grouped MLP is hard-coded for GPT-OSS, so we should make sure not to fuse if glu_linear_offset != 1:

elif isinstance(window[1], ScaledClampedQGeGLU) and (
    abs(window[1]._clamped.alpha - 1.702) > 0.001
    or not _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu()
):
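Applied to this PR, that guard would also need to reject a non-default offset. A standalone sketch of the predicate follows; can_fuse_scaled_clamped_qgeglu, its parameters, and the 0.001 tolerance are illustrative, modeled on the snippet above rather than taken from the actual _common.py code:

```python
def can_fuse_scaled_clamped_qgeglu(alpha, glu_linear_offset, cudnn_supported):
    # Fuse only in the exact GPT-OSS configuration the grouped-MLP kernel
    # was written for: alpha close to 1.702, offset exactly 1.0, and
    # cuDNN frontend support available; everything else falls back.
    return (
        abs(alpha - 1.702) <= 0.001
        and glu_linear_offset == 1.0
        and cudnn_supported
    )
```

With this shape of check, the DeepSeek-V4 configuration (offset 0.0) cleanly takes the unfused path instead of running a kernel that silently assumes + 1.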

@timmoon10 (Collaborator) commented:

/te-ci

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
@hxbai hxbai marked this pull request as draft April 29, 2026 00:28
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
@hxbai hxbai marked this pull request as ready for review April 29, 2026 01:01

  void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
-                          cudaStream_t stream) {
+                          float glu_linear_offset, cudaStream_t stream) {
Collaborator
Can we define new APIs named nvte_clamped_swiglu_v2 and nvte_clamped_dswiglu_v2
and deprecate this API here to not break backward compatibility?

Contributor (Author)

Rewrote this part.

vthumbe1503 and others added 3 commits May 6, 2026 11:38
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
@vthumbe1503 (Collaborator) commented:

/te-ci

hxbai added 2 commits May 12, 2026 15:13
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

3 participants