feat: add more activation func#329
Conversation
|
|
||
| x = tl.load(x_ptr + row_idx * stride_x_token + col_off, mask=mask).to(compute_type) | ||
| if HAS_BIAS: | ||
| b = tl.load(bias_ptr + col_off, mask=col_mask).to(compute_type) |
There was a problem hiding this comment.
严重性:高。bias 进入这里前没有校验 device/dtype/numel,HAS_BIAS 只取决于非 None;短 CUDA bias 会被按 stride_x_token 通过 tl.load(bias_ptr + col_off) 越界读取,并把 GPU 内存混入输出/梯度。同类 bias kernel 也需要在 launch 前强制 bias 形状精确匹配。
| row_mask: Optional[torch.Tensor], | ||
| ): | ||
| if weights is not None: | ||
| assert input.size(0) == weights.size(0), "first dimension of input and weights must be the same" |
There was a problem hiding this comment.
严重性:高。这里仅校验 weights.size(0) 和 dtype,torch.empty(N, 0, device='cuda', dtype=torch.float32) 也能通过;随后 weights.view(-1) 被传给 Triton kernel 按 token tl.load,会越过 allocation 读取 GPU 内存并影响输出/梯度。请校验 weights 为 CUDA 且 numel()==N(如 [N]/[N,1])。
| row_mask: Optional[torch.Tensor] = None, | ||
| ): | ||
| num_tokens, hidden_size = grad_out.size() | ||
| grad_x = torch.empty_like(x) |
There was a problem hiding this comment.
严重性:高。当 row_mask 存在时,下游 quick_geglu_with_mask_bwd_kernel 只对 extra_mask != 0 的行写 grad_x;这里仍用 empty_like,被 mask 掉的行会作为 x.grad 返回未初始化 GPU 内存。请像其他 masked backward 一样在 row_mask is not None 时使用 zeros_like。
There was a problem hiding this comment.
Pull request overview
This PR expands Primus-Turbo’s fused activation support by adding Triton kernels and PyTorch wrappers for GELU/GEGLU/SwiGLU/Quick-GEGLU with optional bias, token-wise weights, and integrated row_mask handling, while also fixing correctness/OOB issues in existing masked kernels and GEGLU backward.
Changes:
- Added new Triton kernels for bias-fused and Quick-GEGLU activations plus utility JIT primitives (
quick_gelu). - Refactored/renamed activation APIs to align with Megatron-LM-style
bias_*andweighted_bias_*interfaces. - Updated tests to cover the new activation APIs and adjusted attention test parameterization.
Reviewed changes
Copilot reviewed 15 out of 15 changed files in this pull request and generated 13 comments.
Show a summary per file
| File | Description |
|---|---|
tests/pytorch/ops/test_attention.py |
Adjusts attention test parameterization for enable_sink. |
tests/pytorch/ops/test_activation.py |
Rewrites activation tests to cover new bias_* / weighted_bias_* APIs. |
primus_turbo/triton/utils/gelu.py |
Adds Triton JIT quick_gelu and its backward primitive. |
primus_turbo/triton/activation/swiglu_kernel.py |
Fixes masked variants to avoid OOB scalar loads by masking probs loads. |
primus_turbo/triton/activation/quick_geglu_kernel.py |
Introduces Quick-GEGLU Triton fwd/bwd kernels (masked/unmasked). |
primus_turbo/triton/activation/gelu_kernel.py |
Adds weighted GELU Triton fwd/bwd kernels (masked/unmasked). |
primus_turbo/triton/activation/geglu_kernel.py |
Fixes GEGLU backward gradient computation and masks scalar loads in masked kernels. |
primus_turbo/triton/activation/bias_swiglu_kernel.py |
Adds bias-fused SwiGLU Triton fwd/bwd kernels (masked/unmasked). |
primus_turbo/triton/activation/bias_gelu_kernel.py |
Adds bias-fused GELU Triton fwd/bwd kernels (masked/unmasked). |
primus_turbo/triton/activation/bias_geglu_kernel.py |
Adds bias-fused GEGLU Triton fwd/bwd kernels (masked/unmasked). |
primus_turbo/pytorch/ops/activation.py |
New public activation API + autograd Functions wiring to Triton kernels. |
primus_turbo/pytorch/kernels/activation/swiglu_impl.py |
Refactors weighted SwiGLU wrappers and adds bias-SwiGLU wrapper dispatch. |
primus_turbo/pytorch/kernels/activation/quick_geglu_impl.py |
Adds Quick-GEGLU wrapper dispatch for bias/weights/mask. |
primus_turbo/pytorch/kernels/activation/gelu_impl.py |
Adds bias-GELU wrapper dispatch for mask/no-mask kernels. |
primus_turbo/pytorch/kernels/activation/geglu_impl.py |
Refactors weighted GeGLU wrappers and adds bias-GeGLU wrapper dispatch. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| def backward(ctx, grad_output): | ||
| input, bias, row_mask = ctx.saved_tensors | ||
| grad_x = bias_swiglu_bwd(grad_output, input, bias, row_mask) | ||
| return grad_x, (grad_x if bias is not None else None), None | ||
|
|
| def backward(ctx, grad_output): | ||
| input, bias, row_mask = ctx.saved_tensors | ||
| grad_x = bias_gelu_bwd(grad_output, input, bias, row_mask) | ||
| return grad_x, (grad_x if bias is not None else None), None |
| @staticmethod | ||
| def backward(ctx, grad_output): | ||
| input, bias, row_mask = ctx.saved_tensors | ||
| grad_x = bias_geglu_bwd(grad_output, input, bias, row_mask) | ||
| return grad_x, (grad_x if bias is not None else None), None | ||
|
|
| def backward(ctx, grad_output): | ||
| input, bias, row_mask = ctx.saved_tensors | ||
| grad_x, _ = quick_geglu_bwd(grad_output, input, ctx.linear_offset, bias=bias, row_mask=row_mask) | ||
| return grad_x, (grad_x if bias is not None else None), None, None |
| @staticmethod | ||
| def backward(ctx, grad_output: torch.Tensor): | ||
| assert grad_output.ndim == 2 | ||
| def backward(ctx, grad_output): | ||
| input, bias, weights, row_mask = ctx.saved_tensors | ||
|
|
||
| x, probs, row_mask = ctx.saved_tensors | ||
| y = (input + bias).view(-1, input.size(-1)) | ||
| w = weights.view(-1) | ||
| grad_y, grad_w = swiglu_bwd(grad_output, y, w, row_mask) | ||
| grad_y = grad_y.view(input.shape) | ||
| return grad_y, grad_y, grad_w.view(weights.shape), None |
| pid = tl.program_id(0) | ||
| compute_type = tl.float32 | ||
| grad_x_dtype = grad_x_ptr.dtype.element_ty | ||
| tl.int64 |
| def test_bias_gelu_impl(num_tokens, hidden_size, dtype, has_bias, has_row_mask): | ||
| _reset_seed() | ||
| device = "cuda" | ||
|
|
||
| x = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype, requires_grad=True) | ||
| bias = torch.randn(hidden_size, device=device, dtype=dtype) if has_bias else None | ||
| row_mask = make_row_mask(num_tokens, device) if has_row_mask else None |
| @pytest.mark.parametrize("config", test_cases) | ||
| @pytest.mark.parametrize("causal", [True, False]) | ||
| @pytest.mark.parametrize("enable_sink", [False, True]) | ||
| @pytest.mark.parametrize("enable_sink", [True]) |
| class BiasSwiGLUFunction(torch.autograd.Function): | ||
| @staticmethod | ||
| def forward(ctx, input, bias, row_mask): | ||
| ctx.save_for_backward(input, bias, row_mask) | ||
| return bias_swiglu_fwd(input, bias, row_mask) |
| class WeightedSwiGLUFunction(torch.autograd.Function): | ||
| @staticmethod | ||
| def forward(ctx, input, weights, row_mask): | ||
| ctx.save_for_backward(input, weights, row_mask) | ||
|
|
||
| x = input.view(-1, input.size(-1)) | ||
| w = weights.view(-1) | ||
| return swiglu_fwd(x, w, row_mask) | ||
|
|
| def backward(ctx, grad_output): | ||
| input, bias, row_mask = ctx.saved_tensors | ||
| grad_x = bias_swiglu_bwd(grad_output, input, bias, row_mask) | ||
| return grad_x, (grad_x if bias is not None else None), None | ||
|
|
| @staticmethod | ||
| def backward(ctx, grad_output): | ||
| input, bias, row_mask = ctx.saved_tensors | ||
| grad_x = bias_gelu_bwd(grad_output, input, bias, row_mask) | ||
| return grad_x, (grad_x if bias is not None else None), None | ||
|
|
| @staticmethod | ||
| def backward(ctx, grad_output): | ||
| input, bias, row_mask = ctx.saved_tensors | ||
| grad_x = bias_geglu_bwd(grad_output, input, bias, row_mask) | ||
| return grad_x, (grad_x if bias is not None else None), None | ||
|
|
| grad_x, _ = quick_geglu_bwd(grad_output, input, ctx.linear_offset, bias=bias, row_mask=row_mask) | ||
| return grad_x, (grad_x if bias is not None else None), None, None |
| @staticmethod | ||
| def backward(ctx, grad_output: torch.Tensor): | ||
| assert grad_output.ndim == 2 | ||
| def backward(ctx, grad_output): | ||
| input, bias, weights, row_mask = ctx.saved_tensors | ||
|
|
||
| x, probs, row_mask = ctx.saved_tensors | ||
| y = (input + bias).view(-1, input.size(-1)) | ||
| w = weights.view(-1) | ||
| grad_y, grad_w = swiglu_bwd(grad_output, y, w, row_mask) | ||
| grad_y = grad_y.view(input.shape) | ||
| return grad_y, grad_y, grad_w.view(weights.shape), None |
| grad_x = torch.empty_like(x) | ||
| has_bias = bias is not None | ||
| has_weights = weights is not None | ||
| grad_weights = torch.empty_like(weights) if has_weights else None |
| pid = tl.program_id(0) | ||
| compute_type = tl.float32 | ||
| grad_x_dtype = grad_x_ptr.dtype.element_ty | ||
| tl.int64 |
| g = tl.load(grad_out_ptr + row_idx * stride_grad_out_token + col_off, mask=mask).to(compute_type) | ||
|
|
||
| grad_down = g * gelu_up | ||
| grad_up = g * down * gelu_bwd_none(up, tl.full(up.shape, 1.0, dtype=compute_type)) |
| def test_bias_gelu_impl(num_tokens, hidden_size, dtype, has_bias, has_row_mask): | ||
| _reset_seed() | ||
| device = "cuda" | ||
|
|
||
| x = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype, requires_grad=True) | ||
| bias = torch.randn(hidden_size, device=device, dtype=dtype) if has_bias else None | ||
| row_mask = make_row_mask(num_tokens, device) if has_row_mask else None |
| x = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype, requires_grad=True) | ||
| bias = torch.randn(hidden_size, device=device, dtype=dtype) if has_bias else None | ||
| row_mask = make_row_mask(num_tokens, device) if has_row_mask else None | ||
|
|


Description
Add a comprehensive set of fused activation functions (GELU, GEGLU, SwiGLU, Quick-GEGLU) with Triton kernel implementations, aligned with Megatron-LM's
fused_bias_gelu,fused_bias_geglu, andfused_bias_swigluinterfaces. All functions support optionalbias, per-tokenweights, androw_maskwith mask computation integrated directly into Triton kernels.Also fixes existing bugs in the
geglu_bwd_kernel(incorrect gradient computation) andwith_maskkernel variants (out-of-bounds memory access).Type of change
Changes
New Triton Kernels
bias_gelu_kernel.py: GELU with optional bias and row_mask (fwd/bwd, mask/no-mask variants)bias_geglu_kernel.py: GEGLU with optional bias and row_maskbias_swiglu_kernel.py: SwiGLU with optional bias and row_maskgelu_kernel.py: GELU with per-token weights (probs) and row_maskquick_geglu_kernel.py: Quick-GEGLU (sigmoid approximation) with optional bias, weights,linear_offset, and row_mask. UsesHAS_BIAS/HAS_WEIGHTSconstexpr flags to avoid kernel duplication.New Triton Utilities
gelu.py: Addedquick_geluandquick_gelu_bwdJIT primitives for sigmoid-approximated GELUNew Python Wrappers (
pytorch/kernels/activation/)gelu_impl.py:bias_gelu_fwd/bias_gelu_bwdquick_geglu_impl.py:quick_geglu_fwd/quick_geglu_bwdwith fullbias/weights/row_maskdispatchbias_geglu_impl.pyintogeglu_impl.pyandbias_swiglu_impl.pyintoswiglu_impl.pyPublic API (
pytorch/ops/activation.py)Aligned with Megatron-LM interface naming:
bias_gelu_impl(input, bias, row_mask)bias_geglu_impl(input, bias, row_mask, clamp_value)bias_swiglu_impl(input, bias, row_mask, clamp_value)weighted_bias_geglu_impl(input, bias, weights, row_mask, clamp_value)weighted_bias_swiglu_impl(input, bias, weights, row_mask, clamp_value)weighted_bias_quick_geglu_impl(input, bias, weights, row_mask, linear_offset, clamp_value)Bug Fixes
geglu_bwd_kernelgradient bug: Fixedgrad_upcomputation that had an extragrad_outfactor (grad_out²instead ofgrad_out). The bug was masked by the old test usingtorch.ones_likeas gradient.with_maskkernel OOB access: Addedmask=row_maskto scalar loads (probs,weights) in allwith_maskkernel variants acrossgeglu_kernel.py,swiglu_kernel.py,gelu_kernel.py, andquick_geglu_kernel.py. Without mask,BLOCK_SIZE > num_tokenscaused out-of-bounds GPU memory access.torch.empty→torch.zerosfor output and gradient buffers whenrow_maskis present, preventing NaN/garbage in masked-out rows.linear_offsettype error: Changed fromtorch.tensor(linear_offset, ...)tofloat(linear_offset)to avoid Triton treating it as a pointer instead of a scalar.Refactoring
*_with_probs→weighted_*to align with Megatron naming (probs→weights)GLUWithProbsinto separateWeightedSwiGLUFunction/WeightedGeGLUFunctionclasses_validate_weight_and_row_maskand_clamp_glu_inputas shared helpers*_implfunctions to follow a consistent pattern (shape handling, validation, dispatch)Tests
test_activation.pywith comprehensive tests for all new APIsChecklist: