19 changes: 17 additions & 2 deletions tests/pytorch/test_fusible_ops.py
@@ -1809,6 +1809,7 @@ def test_interleaved_swiglu(self):
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantize_forward", (False, True))
@pytest.mark.parametrize("quantize_backward", (False, True))
@pytest.mark.parametrize("glu_linear_offset", (1.0, 0.0))
def test_clamped_swiglu(
self,
*,
Expand All @@ -1819,6 +1820,7 @@ def test_clamped_swiglu(
quantization: Optional[str],
quantize_forward: bool,
quantize_backward: bool,
glu_linear_offset: float,
limit: float = 0.75,
alpha: float = 1.702,
):
@@ -1861,7 +1863,7 @@ def test_clamped_swiglu(
x_glu = x_glu.clamp(min=None, max=limit)
x_linear = x_linear.clamp(min=-limit, max=limit)
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
y_ref = out_glu * (x_linear + 1)
y_ref = out_glu * (x_linear + glu_linear_offset)
y_ref.backward(dy_ref)

# Implementation with fusible operation
@@ -1872,6 +1874,7 @@ def test_clamped_swiglu(
te_ops.ClampedSwiGLU(
limit=limit,
alpha=alpha,
glu_linear_offset=glu_linear_offset,
glu_interleave_size=glu_interleave_size,
),
te_ops.Quantize(forward=quantize_forward, backward=False),
@@ -1901,6 +1904,7 @@ def test_interleaved_clamped_swiglu(self):
quantize_forward=False,
quantize_backward=False,
glu_interleave_size=32,
glu_linear_offset=1.0,
)

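For reference, the ground truth these tests check can be written as a small standalone helper. A minimal PyTorch sketch (`clamped_swiglu_ref` is an illustrative name, not part of the test suite): `glu_linear_offset=1.0` reproduces the original GPT-OSS behavior, while `0.0` drops the shift on the linear component.

```python
import torch

def clamped_swiglu_ref(x, limit=0.75, alpha=1.702, glu_linear_offset=1.0):
    # Split the last dimension into activation and linear (gate) halves
    x_glu, x_linear = x.chunk(2, dim=-1)
    x_glu = x_glu.clamp(min=None, max=limit)          # activation is clamped from above only
    x_linear = x_linear.clamp(min=-limit, max=limit)  # gate is clamped on both sides
    out_glu = x_glu * torch.sigmoid(alpha * x_glu)    # SiLU variant with sigmoid(alpha * x)
    return out_glu * (x_linear + glu_linear_offset)
```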
@pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
@@ -2492,6 +2496,7 @@ def test_interleaved_scaled_swiglu(self):
@pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128)))
@pytest.mark.parametrize("input_requires_grad", (False, True))
@pytest.mark.parametrize("scales_requires_grad", (False, True))
@pytest.mark.parametrize("glu_linear_offset", (1.0, 0.0))
def test_scaled_clamped_qgeglu(
self,
*,
@@ -2501,6 +2506,7 @@ def test_scaled_clamped_qgeglu(
device: torch.device = "cuda",
input_requires_grad: bool,
scales_requires_grad: bool,
glu_linear_offset: float,
limit: float = 7.0,
alpha: float = 1.702,
) -> None:
@@ -2545,7 +2551,7 @@ def test_scaled_clamped_qgeglu(
x_glu = x_glu.clamp(min=None, max=limit)
x_linear = x_linear.clamp(min=-limit, max=limit)
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
y = out_glu * (x_linear + 1)
y = out_glu * (x_linear + glu_linear_offset)
y_ref = scales_ref.unsqueeze(-1) * y
if input_requires_grad or scales_requires_grad:
y_ref.backward(dy_ref)
@@ -2554,6 +2560,7 @@
glu_interleave_size=glu_interleave_size,
limit=limit,
alpha=alpha,
glu_linear_offset=glu_linear_offset,
)
y_test = op(x_test, scales_test)
if input_requires_grad or scales_requires_grad:
@@ -2572,6 +2579,7 @@ def test_interleaved_scaled_clamped_qgeglu(self):
glu_interleave_size=32,
input_requires_grad=True,
scales_requires_grad=True,
glu_linear_offset=1.0,
)


@@ -3589,6 +3597,13 @@ def test_grouped_mlp(
activation: str,
) -> None:
"""GroupedLinear + ScaledSwiGLU / ScaledClampedQGeGLU + GroupedLinear"""
if os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0") == "0" and (
single_grouped_weight or single_grouped_bias
):
pytest.skip(
"single_grouped_weight/single_grouped_bias requires"
" NVTE_GROUPED_LINEAR_SINGLE_PARAM=1"
)

# Split sizes
split_sizes = [split_alignment * (i) for i in range(group_size)]
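Note that the skip guard above means the single-grouped-parameter cases only run when the feature flag is visible at test time; a sketch:

```python
import os

# The skip guard reads this via os.environ.get at test time, so it must be
# exported (or set in-process) before the test runs.
os.environ["NVTE_GROUPED_LINEAR_SINGLE_PARAM"] = "1"
```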
24 changes: 22 additions & 2 deletions transformer_engine/common/activation/swiglu.cu
@@ -88,15 +88,35 @@ void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit,
cudaStream_t stream) {
NVTE_API_CALL(nvte_clamped_swiglu);
using namespace transformer_engine;
ClampedSwiGLUParam param = {limit, alpha};
// Preserve original behavior: linear (gate) component offset is hard-coded to 1.0f.
ClampedSwiGLUParam param = {limit, alpha, /*glu_linear_offset=*/1.0f};
gated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>>(input, output, param, stream);
}

void nvte_clamped_swiglu_v2(const NVTETensor input, NVTETensor output, float limit, float alpha,
float glu_linear_offset, cudaStream_t stream) {
NVTE_API_CALL(nvte_clamped_swiglu_v2);
using namespace transformer_engine;
ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset};
gated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>>(input, output, param, stream);
}

void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
float limit, float alpha, cudaStream_t stream) {
NVTE_API_CALL(nvte_clamped_dswiglu);
using namespace transformer_engine;
ClampedSwiGLUParam param = {limit, alpha};
// Preserve original behavior: linear (gate) component offset is hard-coded to 1.0f.
ClampedSwiGLUParam param = {limit, alpha, /*glu_linear_offset=*/1.0f};
dgated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>, clamped_dsilu<fp32, fp32>>(
grad, input, output, param, stream);
}

void nvte_clamped_dswiglu_v2(const NVTETensor grad, const NVTETensor input, NVTETensor output,
float limit, float alpha, float glu_linear_offset,
cudaStream_t stream) {
NVTE_API_CALL(nvte_clamped_dswiglu_v2);
using namespace transformer_engine;
ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset};
dgated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>, clamped_dsilu<fp32, fp32>>(
grad, input, output, param, stream);
}
5 changes: 2 additions & 3 deletions transformer_engine/common/cast/fp8/gated_fp8.cuh
@@ -169,9 +169,8 @@
float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
bool dgate_elt = true; // gating is ideally an identity function
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1;
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;
gate_elt = min(max(-p.limit, gate_elt), p.limit) + p.glu_linear_offset;
}

if constexpr (IS_BWD) {
10 changes: 4 additions & 6 deletions transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh
@@ -245,9 +245,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float after_gate_elt;
bool dgate_elt = true; // gating is ideally an identity function
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;
gate_elt = min(max(-p.limit, gate_elt), p.limit) + p.glu_linear_offset;
}
if constexpr (IS_BWD) {
float grad_elt = static_cast<float>(in_grad_sh[shmem_offset_colwise]);
@@ -510,9 +509,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float after_gate_elt;
bool dgate_elt = true;
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;
gate_elt = min(max(-p.limit, gate_elt), p.limit) + p.glu_linear_offset;
}
if constexpr (IS_BWD) {
float grad_elt = static_cast<float>(in_grad.data.elt[e]);
@@ -322,6 +322,11 @@ void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);

/*! \brief Computes the gated Swish activation of the input used in GPT OSS.
*
* \deprecated This function has been deprecated in favor of nvte_clamped_swiglu_v2,
* which exposes a configurable offset for the linear (gate) component.
* This API is preserved for backward compatibility and is equivalent to
* calling nvte_clamped_swiglu_v2 with glu_linear_offset = 1.0.
*
* See https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
* This gated activation differs from the original SwiGLU in two ways:
@@ -341,6 +346,28 @@ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
cudaStream_t stream);

/*! \brief Computes the gated Swish activation of the input used in GPT OSS, with a configurable
* offset for the linear (gate) component after clamping.
*
* See https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
* This gated activation differs from the original SwiGLU in two ways:
* 1. Both the gate and the pre-activation are clipped according to the limit parameter.
* 2. The activation uses sigmoid(alpha * x) instead of the sigmoid(x) of standard Swish,
*    inspired by the original GELU paper https://arxiv.org/pdf/1606.08415
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* block quantization (MXFP8) with the specified block shape is used.
*
* \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H].
* It computes Act(input[N, :H]) x (input[N, H:] + glu_linear_offset)
* \param[in] limit Clipping limits for gate and pre-activation.
* \param[in] alpha Scaling factor for the sigmoid function used in the activation.
* \param[in] glu_linear_offset Offset added to the linear component after clamping (typically 1.0).
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_clamped_swiglu_v2(const NVTETensor input, NVTETensor output, float limit, float alpha,
float glu_linear_offset, cudaStream_t stream);

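In plain NumPy terms, the v2 forward semantics are roughly the following sketch (mirroring the PyTorch reference in the tests above, not the CUDA implementation; the helper name is illustrative):

```python
import numpy as np

def clamped_swiglu_v2_ref(x, limit, alpha, glu_linear_offset):
    # x has shape [N, 2H]: first half is the activation input, second half the gate
    h = x.shape[-1] // 2
    act, gate = x[..., :h], x[..., h:]
    act = np.minimum(act, limit)               # activation clamped from above
    gate = np.clip(gate, -limit, limit)        # gate clamped on both sides
    silu = act / (1.0 + np.exp(-alpha * act))  # act * sigmoid(alpha * act)
    return silu * (gate + glu_linear_offset)   # output shape [N, H]
```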
/*! \brief Computes the gated ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
@@ -399,6 +426,11 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp
cudaStream_t stream);

/*! \brief Computes the gradient of gated Swish activation of the input used in GPT OSS.
*
* \deprecated This function has been deprecated in favor of nvte_clamped_dswiglu_v2,
* which exposes a configurable offset for the linear (gate) component.
* This API is preserved for backward compatibility and is equivalent to
* calling nvte_clamped_dswiglu_v2 with glu_linear_offset = 1.0.
*
* https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
* This activation differs from the original SwiGLU in two ways:
@@ -418,6 +450,29 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp
void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
float limit, float alpha, cudaStream_t stream);

/*! \brief Computes the gradient of gated Swish activation of the input used in GPT OSS, with a
* configurable offset for the linear (gate) component after clamping.
*
* https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
* This activation differs from the original SwiGLU in two ways:
* 1. Both the gate and the pre-activation are clipped according to the limit parameter.
* 2. The activation uses sigmoid(alpha * x) instead of the sigmoid(x) of standard Swish,
*    inspired by the original GELU paper https://arxiv.org/pdf/1606.08415
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* block quantization (MXFP8) with the specified block shape is used.
*
* \param[in] grad Incoming gradient of shape [N, H].
* \param[in] input Forward input tensor of shape [N, H * 2].
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
* \param[in] limit Clipping limits for gate and pre-activation.
* \param[in] alpha Scaling factor for the sigmoid function used in the activation.
* \param[in] glu_linear_offset Offset added to the linear component after clamping (typically 1.0).
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_clamped_dswiglu_v2(const NVTETensor grad, const NVTETensor input, NVTETensor output,
float limit, float alpha, float glu_linear_offset,
cudaStream_t stream);

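And a matching sketch of the gradient semantics, with the clamp derivatives expressed as masks (again illustrative NumPy, not the CUDA kernel):

```python
import numpy as np

def clamped_dswiglu_v2_ref(grad, x, limit, alpha, glu_linear_offset):
    # grad has shape [N, H]; x has shape [N, 2H]; result has shape [N, 2H]
    h = x.shape[-1] // 2
    act, gate = x[..., :h], x[..., h:]
    act_mask = act <= limit                          # derivative of the one-sided clamp
    gate_mask = (gate >= -limit) & (gate <= limit)   # derivative of the two-sided clamp
    act_c = np.minimum(act, limit)
    gate_c = np.clip(gate, -limit, limit) + glu_linear_offset
    sig = 1.0 / (1.0 + np.exp(-alpha * act_c))
    dsilu = sig * (1.0 + alpha * act_c * (1.0 - sig))  # d/dx of x * sigmoid(alpha * x)
    d_act = grad * gate_c * dsilu * act_mask
    d_gate = grad * (act_c * sig) * gate_mask
    return np.concatenate([d_act, d_gate], axis=-1)
```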
/*! \brief Computes the gated ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
3 changes: 2 additions & 1 deletion transformer_engine/common/util/math.h
@@ -13,7 +13,8 @@ struct Empty {};

struct ClampedSwiGLUParam {
float limit;
float alpha = 1.702f;           // Default value for QuickGELU
float glu_linear_offset = 1.0f; // Offset added to the linear (gate) component after clamping
};

template <typename OType, typename IType>
8 changes: 3 additions & 5 deletions transformer_engine/common/util/vectorized_pointwise.h
@@ -434,9 +434,8 @@ __launch_bounds__(unary_kernel_threads) __global__
ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]);

if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
// Clamp the gated value and add 1 at the end
ComputeType limit = p.limit;
val2 = std::min(std::max(-limit, val2), limit) + 1;
val2 = std::min(std::max(-limit, val2), limit) + p.glu_linear_offset;
}
ComputeType temp = static_cast<ComputeType>(Activation(val, p) * val2);
if (requires_amax) {
@@ -542,10 +541,9 @@ __launch_bounds__(unary_kernel_threads) __global__
bool dgate_in = true;

if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
const ComputeType limit = p.limit;
dgate_in = gate_in <= limit && gate_in >= -limit; // Derivative of clamp
gate_in = std::min(std::max(-limit, gate_in), limit) + 1.0f;
dgate_in = gate_in <= limit && gate_in >= -limit;
gate_in = std::min(std::max(-limit, gate_in), limit) + p.glu_linear_offset;
}

ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in;
15 changes: 9 additions & 6 deletions transformer_engine/jax/cpp_extensions/activation.py
@@ -64,14 +64,15 @@ class ClampedSwigluParams:

limit: float = 7.0
alpha: float = 1.702
glu_linear_offset: float = 1.0

def __hash__(self):
"""Custom hash function to ensure dataclass is hashable for jax jit to work.

Returns:
int: Hash value of the dataclass instance.
"""
return hash((self.limit, self.alpha))
return hash((self.limit, self.alpha, self.glu_linear_offset))
Reviewer comment (Collaborator): Can you update one of the tests here to use a non-default value of glu_linear_offset?


def to_ffi_lowering_dict(self):
"""Convert the activation parameters to a dictionary format for FFI lowering.
@@ -80,7 +81,11 @@ def to_ffi_lowering_dict(self):
dict: A dictionary representation of the activation parameters consumable by
XLA FFI bindings for activation functions.
"""
return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)}
return {
"limit": np.float32(self.limit),
"alpha": np.float32(self.alpha),
"glu_linear_offset": np.float32(self.glu_linear_offset),
}

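For illustration, a short sketch of how the frozen dataclass behaves (values arbitrary): equal instances hash equally, so it is safe as a static argument under jax.jit, and the lowering dict carries the new field.

```python
import numpy as np

params = ClampedSwigluParams(limit=7.0, alpha=1.702, glu_linear_offset=0.0)
assert hash(params) == hash(ClampedSwigluParams(7.0, 1.702, 0.0))
params.to_ffi_lowering_dict()
# -> {"limit": np.float32(7.0), "alpha": np.float32(1.702), "glu_linear_offset": np.float32(0.0)}
```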

@dataclass(frozen=True)
@@ -121,11 +126,9 @@ def _convert_to_activation_function(fn_or_string, act_params: ActivationParams):
if fn_or_string == "linear":
return lambda x: x
if fn_or_string == "clamped_linear":
# This function is used for ClampedSwiGLU
# used in GPT OSS where the gates are not only clamped
# but also shifted by +1
limit = act_params.clamped_swiglu.limit
return lambda x: jnp.clip(x, min=-limit, max=limit) + 1
# ClampedSwiGLU (GPT OSS): clamp the gate, then shift by the configurable offset
offset = act_params.clamped_swiglu.glu_linear_offset
return lambda x: jnp.clip(x, min=-limit, max=limit) + offset
if fn_or_string == "quick_gelu":
return lambda x: jax.nn.sigmoid(1.702 * x) * x
if fn_or_string == "squared_relu":
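Putting the branches together, a minimal JAX sketch of the composed clamped SwiGLU (the clamp on the activation half follows the CUDA reference above; `clamped_swiglu` is an illustrative helper, not the library API):

```python
import jax
import jax.numpy as jnp

def clamped_swiglu(x_act, x_gate, limit=7.0, alpha=1.702, glu_linear_offset=1.0):
    act = jnp.minimum(x_act, limit)          # activation clamped from above
    act = jax.nn.sigmoid(alpha * act) * act  # the "quick_gelu"-style branch
    gate = jnp.clip(x_gate, min=-limit, max=limit) + glu_linear_offset  # "clamped_linear"
    return act * gate
```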
4 changes: 3 additions & 1 deletion transformer_engine/jax/csrc/extensions.h
@@ -39,6 +39,7 @@ namespace jax {
struct ClampedSwigluConfig {
float limit;
float alpha;
float glu_linear_offset;
};

struct ActivationConfig {
@@ -208,7 +209,8 @@ pybind11::tuple GetTopkWorkspaceSizes(int batch_size, int seq_len, int k);

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConfig,
::xla::ffi::StructMember<float>("limit"),
::xla::ffi::StructMember<float>("alpha"));
::xla::ffi::StructMember<float>("alpha"),
::xla::ffi::StructMember<float>("glu_linear_offset"));
Reviewer comment (Collaborator): Can we add a default value for users on HLO from a previous version? Would glu_linear_offset=1 be the same as the current behavior on main?


XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(
transformer_engine::jax::ActivationConfig,