[NPU] optimize cross_entropy for ASCEND NPU #1206
sunyi0505 wants to merge 1 commit into linkedin:main
Conversation
@Tcc0403 This PR is ready for review.
    X_offsets = i + tl.arange(0, BLOCK_SIZE)
    tl.store(dX_ptr + dX_ptr_offset + X_offsets, 0.0, mask=X_offsets < n_cols)
else:
    if DERIVE_LSE_FROM_LOSS:
if not HAS_LSE: # derive from loss if not present
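For reference, this derivation presumably relies on the plain-CE fast path (no weights, label smoothing, softcap, or z-loss), where the stored loss and the logsumexp are related as below (illustrative, not the PR's exact code):

```python
# forward (plain cross entropy):            loss[i] = lse[i] - x[i, target[i]]
# backward, when only loss_1d was stored:   lse[i]  = loss[i] + x[i, target[i]]
```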
    ignore_index,
    reduction: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    HAS_GRAD_OUTPUT_VECTOR: tl.constexpr,
Isn't it equivalent to reduction=="none"?
HAS_GRAD_OUTPUT_VECTOR: When False, grad_output is a 0-dimensional scalar (common with reduction="mean" or "sum"). When True, grad_output is a 1-dimensional vector (common with reduction="none", one gradient per row).
Inside the kernel, this determines which branch the code takes:
- True: read grad_output[row_idx] for each row.
- False: read the same scalar grad_output[0], broadcast to all rows.
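For illustration, a minimal fragment of what that branch could look like inside the backward kernel (identifiers follow the discussion above; this is a sketch, not the PR's exact kernel code):

```python
# Sketch of the grad_output read inside the Triton backward kernel
if HAS_GRAD_OUTPUT_VECTOR:
    grad_out = tl.load(grad_output_ptr + row_idx)  # reduction="none": one value per row
else:
    grad_out = tl.load(grad_output_ptr)            # reduction="mean"/"sum": same scalar for every row
# the per-row logits gradient is then scaled by grad_out before being stored
```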
Thanks for the explanation. If I understand correctly, we can reuse reduction for the branching condition, and add a comment next to it noting that grad_output is a vector when reduction != "none".
The branching here is not for choosing whether grad_output is a per-row vector or a scalar; instead, it selects which backward kernel to use.
has_grad_output_vector only indicates whether the incoming upstream grad_output is one scalar per token, or a single scalar for the entire loss. This is orthogonal to whether class weights, label smoothing, softcap, or z-loss (lse_square_scale) are enabled.
Only plain cross-entropy — with no class weights, no softcap, no label smoothing, and no z-loss term — is eligible for this optimized lightweight kernel. Once any of these features are enabled, the backward formulation changes entirely, and we must fall back to the general liger_cross_entropy_backward_kernel.
If we instead branch externally on reduction:
When reduction="mean" and class weights are present, has_grad_output_vector is False; dispatching on reduction alone would send this case into the no-weight kernel and produce wrong gradients.
When reduction="none" and weight or label smoothing are enabled, has_grad_output_vector is True; routing on that flag alone into the no-weight kernel likewise yields incorrect gradients.
Therefore, reduction cannot replace use_no_weight_fast_path: the former does not imply whether the backward pass can be implemented with plain cross-entropy logic.
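To make the dispatch concrete, here is a hedged host-side sketch; the helper name and flag spellings are illustrative, and only the two kernel names come from the PR:

```python
def select_backward_kernel(weight, label_smoothing, softcap, lse_square_scale):
    """Illustrative dispatch only, not the PR's exact code.

    Eligibility for the lightweight kernel depends on which CE features are enabled,
    not on reduction: reduction only describes the shape of grad_output.
    """
    plain_ce = (
        weight is None
        and label_smoothing == 0.0
        and softcap is None
        and lse_square_scale == 0.0
    )
    if plain_ce:
        return "liger_cross_entropy_backward_kernel_no_weight"
    return "liger_cross_entropy_backward_kernel"


# reduction="mean" with class weights  -> general kernel (weights change the backward math)
# reduction="none" without any extras  -> no-weight fast path is still valid
```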
force-pushed from b306d1d to 10d66bd
force-pushed from ff9e663 to 8ef1421
Co-authored-by: zheliuyu <zheliuyu@users.noreply.github.com>





Summary
Refactors Ascend cross_entropy from a single fused forward+backward kernel to dedicated forward and backward Triton kernels, with device-side reduction stats and an optimized no-weight backward path.

Motivation
- Removes the debug_barrier coupling between backward-style writes and loss computation.
- Avoids .item() before kernel launch by passing precomputed inverse scales via a small ce_stats buffer on device.

Key changes
- Split into liger_cross_entropy_forward_kernel + liger_cross_entropy_backward_kernel / liger_cross_entropy_backward_kernel_no_weight.
- ce_stats: [inv_n_scale, inv_sum_weight_scale, weight_sum] built without .item() for launch-friendly execution on NPU (see the sketch below).
- get_no_weight_fast_path_block_size; eviction_policy="evict_first" on logits loads; hoist ls_eps to host.
- ce_stats, conditional loss_1d vs lse, derive_lse_from_loss; apply grad_output inside backward kernels.
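A hedged sketch of how such a ce_stats buffer could be assembled on device without .item(); the layout follows the bullet above, but the helper name and exact semantics are assumptions, not the PR's code:

```python
import torch

def build_ce_stats(target, ignore_index, weight=None):
    """Hypothetical helper: device-side [inv_n_scale, inv_sum_weight_scale, weight_sum].

    Everything stays on the device, so the kernel launch never forces a host sync
    via .item(); this mirrors the description above, not the PR's exact code.
    """
    valid = target != ignore_index
    n_non_ignore = valid.sum().to(torch.float32)          # count of non-ignored tokens
    zero = torch.zeros((), device=target.device)
    inv_n_scale = torch.where(n_non_ignore > 0, 1.0 / n_non_ignore, zero)

    if weight is not None:
        # sum of class weights over non-ignored targets (clamp guards a negative ignore_index)
        weight_sum = weight[target.clamp_min(0)][valid].sum().to(torch.float32)
    else:
        weight_sum = n_non_ignore                          # unweighted: every token counts as 1
    inv_sum_weight_scale = torch.where(weight_sum > 0, 1.0 / weight_sum, zero)

    return torch.stack([inv_n_scale, inv_sum_weight_scale, weight_sum])
```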
Testing
- make test to ensure correctness
- make checkstyle to ensure code style
- make test-convergence to ensure convergence