Skip to content

feat: add more activation func#329

Open
RuibinCheung wants to merge 3 commits into
mainfrom
feat/zhangrb/add_act_func
Open

feat: add more activation func#329
RuibinCheung wants to merge 3 commits into
mainfrom
feat/zhangrb/add_act_func

Conversation

@RuibinCheung
Copy link
Copy Markdown
Collaborator

Description

Add a comprehensive set of fused activation functions (GELU, GEGLU, SwiGLU, Quick-GEGLU) with Triton kernel implementations, aligned with Megatron-LM's fused_bias_gelu, fused_bias_geglu, and fused_bias_swiglu interfaces. All functions support optional bias, per-token weights, and row_mask with mask computation integrated directly into Triton kernels.

Also fixes existing bugs in the geglu_bwd_kernel (incorrect gradient computation) and with_mask kernel variants (out-of-bounds memory access).

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Code refactoring

Changes

New Triton Kernels

  • bias_gelu_kernel.py: GELU with optional bias and row_mask (fwd/bwd, mask/no-mask variants)
  • bias_geglu_kernel.py: GEGLU with optional bias and row_mask
  • bias_swiglu_kernel.py: SwiGLU with optional bias and row_mask
  • gelu_kernel.py: GELU with per-token weights (probs) and row_mask
  • quick_geglu_kernel.py: Quick-GEGLU (sigmoid approximation) with optional bias, weights, linear_offset, and row_mask. Uses HAS_BIAS/HAS_WEIGHTS constexpr flags to avoid kernel duplication.

New Triton Utilities

  • gelu.py: Added quick_gelu and quick_gelu_bwd JIT primitives for sigmoid-approximated GELU

New Python Wrappers (pytorch/kernels/activation/)

  • gelu_impl.py: bias_gelu_fwd / bias_gelu_bwd
  • quick_geglu_impl.py: quick_geglu_fwd / quick_geglu_bwd with full bias/weights/row_mask dispatch
  • Merged bias_geglu_impl.py into geglu_impl.py and bias_swiglu_impl.py into swiglu_impl.py

Public API (pytorch/ops/activation.py)

Aligned with Megatron-LM interface naming:

API Description
bias_gelu_impl(input, bias, row_mask) GELU with optional bias
bias_geglu_impl(input, bias, row_mask, clamp_value) GEGLU with optional bias
bias_swiglu_impl(input, bias, row_mask, clamp_value) SwiGLU with optional bias
weighted_bias_geglu_impl(input, bias, weights, row_mask, clamp_value) Weighted GEGLU with optional bias (dispatches to 4 autograd Functions)
weighted_bias_swiglu_impl(input, bias, weights, row_mask, clamp_value) Weighted SwiGLU with optional bias
weighted_bias_quick_geglu_impl(input, bias, weights, row_mask, linear_offset, clamp_value) Weighted Quick-GEGLU (sigmoid approx) with optional bias

Bug Fixes

  • geglu_bwd_kernel gradient bug: Fixed grad_up computation that had an extra grad_out factor (grad_out² instead of grad_out). The bug was masked by the old test using torch.ones_like as gradient.
  • with_mask kernel OOB access: Added mask=row_mask to scalar loads (probs, weights) in all with_mask kernel variants across geglu_kernel.py, swiglu_kernel.py, gelu_kernel.py, and quick_geglu_kernel.py. Without mask, BLOCK_SIZE > num_tokens caused out-of-bounds GPU memory access.
  • Uninitialized output for masked rows: Changed torch.emptytorch.zeros for output and gradient buffers when row_mask is present, preventing NaN/garbage in masked-out rows.
  • linear_offset type error: Changed from torch.tensor(linear_offset, ...) to float(linear_offset) to avoid Triton treating it as a pointer instead of a scalar.

Refactoring

  • Renamed *_with_probsweighted_* to align with Megatron naming (probsweights)
  • Split GLUWithProbs into separate WeightedSwiGLUFunction / WeightedGeGLUFunction classes
  • Extracted _validate_weight_and_row_mask and _clamp_glu_input as shared helpers
  • Normalized all *_impl functions to follow a consistent pattern (shape handling, validation, dispatch)

Tests

  • Rewrote test_activation.py with comprehensive tests for all new APIs
  • Tests cover: forward precision, backward gradient correctness, optional bias, optional weights, optional row_mask, multiple dtypes (fp16/bf16), multiple sizes
  • Reference implementations use fp32 intermediate computation to match Triton kernel precision

Checklist:

  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Copilot AI review requested due to automatic review settings May 7, 2026 11:13
Copy link
Copy Markdown

@cursor cursor Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

安全审查完成:未发现需要复核的历史自动化线程。本次发现 3 个高可信内存安全问题,均来自新增 activation API 对调用方可控 tensor 边界约束不足,可能导致 GPU 内存越界读取或未初始化内存泄露。未修改代码。

Open in Web View Automation 

Sent by Cursor Automation: Find vulnerabilities


x = tl.load(x_ptr + row_idx * stride_x_token + col_off, mask=mask).to(compute_type)
if HAS_BIAS:
b = tl.load(bias_ptr + col_off, mask=col_mask).to(compute_type)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

严重性:高。bias 进入这里前没有校验 device/dtype/numel,HAS_BIAS 只取决于非 None;短 CUDA bias 会被按 stride_x_token 通过 tl.load(bias_ptr + col_off) 越界读取,并把 GPU 内存混入输出/梯度。同类 bias kernel 也需要在 launch 前强制 bias 形状精确匹配。

row_mask: Optional[torch.Tensor],
):
if weights is not None:
assert input.size(0) == weights.size(0), "first dimension of input and weights must be the same"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

严重性:高。这里仅校验 weights.size(0) 和 dtype,torch.empty(N, 0, device='cuda', dtype=torch.float32) 也能通过;随后 weights.view(-1) 被传给 Triton kernel 按 token tl.load,会越过 allocation 读取 GPU 内存并影响输出/梯度。请校验 weights 为 CUDA 且 numel()==N(如 [N]/[N,1])。

row_mask: Optional[torch.Tensor] = None,
):
num_tokens, hidden_size = grad_out.size()
grad_x = torch.empty_like(x)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

严重性:高。当 row_mask 存在时,下游 quick_geglu_with_mask_bwd_kernel 只对 extra_mask != 0 的行写 grad_x;这里仍用 empty_like,被 mask 掉的行会作为 x.grad 返回未初始化 GPU 内存。请像其他 masked backward 一样在 row_mask is not None 时使用 zeros_like

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR expands Primus-Turbo’s fused activation support by adding Triton kernels and PyTorch wrappers for GELU/GEGLU/SwiGLU/Quick-GEGLU with optional bias, token-wise weights, and integrated row_mask handling, while also fixing correctness/OOB issues in existing masked kernels and GEGLU backward.

Changes:

  • Added new Triton kernels for bias-fused and Quick-GEGLU activations plus utility JIT primitives (quick_gelu).
  • Refactored/renamed activation APIs to align with Megatron-LM-style bias_* and weighted_bias_* interfaces.
  • Updated tests to cover the new activation APIs and adjusted attention test parameterization.

Reviewed changes

Copilot reviewed 15 out of 15 changed files in this pull request and generated 13 comments.

Show a summary per file
File Description
tests/pytorch/ops/test_attention.py Adjusts attention test parameterization for enable_sink.
tests/pytorch/ops/test_activation.py Rewrites activation tests to cover new bias_* / weighted_bias_* APIs.
primus_turbo/triton/utils/gelu.py Adds Triton JIT quick_gelu and its backward primitive.
primus_turbo/triton/activation/swiglu_kernel.py Fixes masked variants to avoid OOB scalar loads by masking probs loads.
primus_turbo/triton/activation/quick_geglu_kernel.py Introduces Quick-GEGLU Triton fwd/bwd kernels (masked/unmasked).
primus_turbo/triton/activation/gelu_kernel.py Adds weighted GELU Triton fwd/bwd kernels (masked/unmasked).
primus_turbo/triton/activation/geglu_kernel.py Fixes GEGLU backward gradient computation and masks scalar loads in masked kernels.
primus_turbo/triton/activation/bias_swiglu_kernel.py Adds bias-fused SwiGLU Triton fwd/bwd kernels (masked/unmasked).
primus_turbo/triton/activation/bias_gelu_kernel.py Adds bias-fused GELU Triton fwd/bwd kernels (masked/unmasked).
primus_turbo/triton/activation/bias_geglu_kernel.py Adds bias-fused GEGLU Triton fwd/bwd kernels (masked/unmasked).
primus_turbo/pytorch/ops/activation.py New public activation API + autograd Functions wiring to Triton kernels.
primus_turbo/pytorch/kernels/activation/swiglu_impl.py Refactors weighted SwiGLU wrappers and adds bias-SwiGLU wrapper dispatch.
primus_turbo/pytorch/kernels/activation/quick_geglu_impl.py Adds Quick-GEGLU wrapper dispatch for bias/weights/mask.
primus_turbo/pytorch/kernels/activation/gelu_impl.py Adds bias-GELU wrapper dispatch for mask/no-mask kernels.
primus_turbo/pytorch/kernels/activation/geglu_impl.py Refactors weighted GeGLU wrappers and adds bias-GeGLU wrapper dispatch.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +56 to +60
def backward(ctx, grad_output):
input, bias, row_mask = ctx.saved_tensors
grad_x = bias_swiglu_bwd(grad_output, input, bias, row_mask)
return grad_x, (grad_x if bias is not None else None), None

def backward(ctx, grad_output):
input, bias, row_mask = ctx.saved_tensors
grad_x = bias_gelu_bwd(grad_output, input, bias, row_mask)
return grad_x, (grad_x if bias is not None else None), None
Comment on lines +106 to +111
@staticmethod
def backward(ctx, grad_output):
input, bias, row_mask = ctx.saved_tensors
grad_x = bias_geglu_bwd(grad_output, input, bias, row_mask)
return grad_x, (grad_x if bias is not None else None), None

def backward(ctx, grad_output):
input, bias, row_mask = ctx.saved_tensors
grad_x, _ = quick_geglu_bwd(grad_output, input, ctx.linear_offset, bias=bias, row_mask=row_mask)
return grad_x, (grad_x if bias is not None else None), None, None
Comment on lines 260 to +268
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
assert grad_output.ndim == 2
def backward(ctx, grad_output):
input, bias, weights, row_mask = ctx.saved_tensors

x, probs, row_mask = ctx.saved_tensors
y = (input + bias).view(-1, input.size(-1))
w = weights.view(-1)
grad_y, grad_w = swiglu_bwd(grad_output, y, w, row_mask)
grad_y = grad_y.view(input.shape)
return grad_y, grad_y, grad_w.view(weights.shape), None
pid = tl.program_id(0)
compute_type = tl.float32
grad_x_dtype = grad_x_ptr.dtype.element_ty
tl.int64
Comment on lines +99 to +105
def test_bias_gelu_impl(num_tokens, hidden_size, dtype, has_bias, has_row_mask):
_reset_seed()
device = "cuda"

x = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype, requires_grad=True)
bias = torch.randn(hidden_size, device=device, dtype=dtype) if has_bias else None
row_mask = make_row_mask(num_tokens, device) if has_row_mask else None
Comment thread tests/pytorch/ops/test_attention.py Outdated
@pytest.mark.parametrize("config", test_cases)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("enable_sink", [False, True])
@pytest.mark.parametrize("enable_sink", [True])
Comment on lines +49 to +53
class BiasSwiGLUFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, bias, row_mask):
ctx.save_for_backward(input, bias, row_mask)
return bias_swiglu_fwd(input, bias, row_mask)
Comment on lines +197 to +205
class WeightedSwiGLUFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weights, row_mask):
ctx.save_for_backward(input, weights, row_mask)

x = input.view(-1, input.size(-1))
w = weights.view(-1)
return swiglu_fwd(x, w, row_mask)

Copilot AI review requested due to automatic review settings May 8, 2026 02:49
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 14 out of 14 changed files in this pull request and generated 23 comments.

Comment on lines +56 to +60
def backward(ctx, grad_output):
input, bias, row_mask = ctx.saved_tensors
grad_x = bias_swiglu_bwd(grad_output, input, bias, row_mask)
return grad_x, (grad_x if bias is not None else None), None

Comment on lines +93 to +98
@staticmethod
def backward(ctx, grad_output):
input, bias, row_mask = ctx.saved_tensors
grad_x = bias_gelu_bwd(grad_output, input, bias, row_mask)
return grad_x, (grad_x if bias is not None else None), None

Comment on lines +106 to +111
@staticmethod
def backward(ctx, grad_output):
input, bias, row_mask = ctx.saved_tensors
grad_x = bias_geglu_bwd(grad_output, input, bias, row_mask)
return grad_x, (grad_x if bias is not None else None), None

Comment on lines +123 to +124
grad_x, _ = quick_geglu_bwd(grad_output, input, ctx.linear_offset, bias=bias, row_mask=row_mask)
return grad_x, (grad_x if bias is not None else None), None, None
Comment on lines 260 to +268
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
assert grad_output.ndim == 2
def backward(ctx, grad_output):
input, bias, weights, row_mask = ctx.saved_tensors

x, probs, row_mask = ctx.saved_tensors
y = (input + bias).view(-1, input.size(-1))
w = weights.view(-1)
grad_y, grad_w = swiglu_bwd(grad_output, y, w, row_mask)
grad_y = grad_y.view(input.shape)
return grad_y, grad_y, grad_w.view(weights.shape), None
Comment on lines +82 to +85
grad_x = torch.empty_like(x)
has_bias = bias is not None
has_weights = weights is not None
grad_weights = torch.empty_like(weights) if has_weights else None
pid = tl.program_id(0)
compute_type = tl.float32
grad_x_dtype = grad_x_ptr.dtype.element_ty
tl.int64
g = tl.load(grad_out_ptr + row_idx * stride_grad_out_token + col_off, mask=mask).to(compute_type)

grad_down = g * gelu_up
grad_up = g * down * gelu_bwd_none(up, tl.full(up.shape, 1.0, dtype=compute_type))
Comment on lines +108 to +114
def test_bias_gelu_impl(num_tokens, hidden_size, dtype, has_bias, has_row_mask):
_reset_seed()
device = "cuda"

x = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype, requires_grad=True)
bias = torch.randn(hidden_size, device=device, dtype=dtype) if has_bias else None
row_mask = make_row_mask(num_tokens, device) if has_row_mask else None
Comment on lines +112 to +115
x = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype, requires_grad=True)
bias = torch.randn(hidden_size, device=device, dtype=dtype) if has_bias else None
row_mask = make_row_mask(num_tokens, device) if has_row_mask else None

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants