[NPU] optimize cross_entropy for ASCEND NPU #1206

Open
sunyi0505 wants to merge 1 commit into linkedin:main from sunyi0505:main

Conversation

sunyi0505 (Contributor) commented Apr 28, 2026

Summary

Refactors Ascend cross_entropy from a single fused forward+backward kernel to dedicated forward and backward Triton kernels, with device-side reduction stats and an optimized no-weight backward path.

Motivation

  • Avoid overwriting logits with gradients in forward and remove debug_barrier coupling between backward-style writes and loss computation.
  • Reduce host sync from .item() before kernel launch by passing precomputed inverse scales via a small ce_stats buffer on device.
  • Improve locality by assigning contiguous row ranges per program instead of interleaved grid-stride scheduling.
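The scheduling change in the last bullet can be modeled in plain host-side Python (a hypothetical sketch of the row assignment, not the actual Triton kernel code):

```python
def rows_grid_stride(pid, num_progs, n_rows):
    # Old scheme: program `pid` visits rows pid, pid + num_progs, pid + 2*num_progs, ...
    return list(range(pid, n_rows, num_progs))

def rows_contiguous_chunk(pid, num_progs, n_rows):
    # New scheme: program `pid` owns one contiguous range of rows, which
    # improves locality and lets inner loops advance pointers by +1 instead
    # of recomputing row_idx * stride each iteration.
    row_chunk = (n_rows + num_progs - 1) // num_progs  # ceil(n_rows / num_progs)
    start = pid * row_chunk
    return list(range(start, min(start + row_chunk, n_rows)))
```

Both schemes cover every row exactly once; only the per-program access pattern differs.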

Key changes

  • Split kernels: liger_cross_entropy_forward_kernel + liger_cross_entropy_backward_kernel / liger_cross_entropy_backward_kernel_no_weight.
  • ce_stats: [inv_n_scale, inv_sum_weight_scale, weight_sum] built without .item() for launch-friendly execution on NPU.
  • Optional LSE skip: For fp32 logits on the simple CE path (no weight, label smoothing, softcap, or z-loss), skip allocating a dedicated LSE tensor and derive LSE in the no-weight backward kernel from per-row loss plus the target logit where applicable.
  • Tuning: Tiered block sizes / get_no_weight_fast_path_block_size; eviction_policy="evict_first" on logits loads; hoist ls_eps to host.
  • Autograd: Save ce_stats, conditional loss_1d vs lse, derive_lse_from_loss; apply grad_output inside backward kernels.
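As a rough host-side model of what the ce_stats buffer holds (a scalar Python stand-in; the actual buffer is built on device without .item(), and the exact zero/reduction handling here is an assumption):

```python
def make_ce_stats(n_non_ignore, sum_non_ignore_weight, weight_sum):
    # [inv_n_scale, inv_sum_weight_scale, weight_sum]: inverse scales are
    # precomputed so the kernel multiplies instead of dividing per row,
    # and no scalar ever has to round-trip through the host.
    inv_n_scale = 1.0 / n_non_ignore if n_non_ignore > 0 else 0.0
    inv_sum_weight_scale = 1.0 / sum_non_ignore_weight if sum_non_ignore_weight > 0 else 0.0
    return [inv_n_scale, inv_sum_weight_scale, weight_sum]
```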

Testing

  • Existing Ascend / transformers cross-entropy tests should cover numerical parity for enabled dtypes and feature flags.
  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

sunyi0505 (Contributor, Author) commented Apr 28, 2026

  1. Graph Split: Forward / Backward Separation

    • Old version: A single liger_cross_entropy_kernel uses HAS_GRADIENTS to both write the loss and write gradients back in-place into X_ptr inside the same kernel, relying on tl.debug_barrier() to ensure read/write order.
    • New version:
      • liger_cross_entropy_forward_kernel only performs forward computation (loss, optional z-loss, accuracy, predicted tokens, optional LSE).
      • liger_cross_entropy_backward_kernel and liger_cross_entropy_backward_kernel_no_weight are dedicated to computing gradients, writing them into a separate dX without overwriting the logits.
  2. Grid & Memory Access: Contiguous Row Chunks vs Interleaved Stride

    • Old version: Uses pid interleaved traversal of rows with stride num_progs (grid-stride loop).
    • New version: Uses contiguous row intervals by row_chunk = ceil(n_rows / num_progs). The comments indicate this improves MTE memory locality on Ascend and reduces repeated address calculations like row_idx * stride. Inner loops move along rows using pointer arithmetic (+1).
  3. Statistics ce_stats: Avoid .item() Before Launch

    • Old version: Passes n_non_ignore, sum_non_ignore_weight, weight_sum to the kernel via .item(), which forces host synchronization on NPU.
    • New version: _make_ce_stats_buffer constructs [inv_n_scale, inv_sum_weight_scale, weight_sum] on the device. The kernel loads ce_stats_ptr + ... via tl.load(), without synchronously fetching scalars from Python.
  4. No-Weight Common Path: Dedicated Backward Kernel + Optional LSE Omission

    • The new version uses backward_kernel_no_weight when there is no weight / softcap / label smoothing / z-loss, and uses get_no_weight_fast_path_block_size (larger block size for small vocab) to optimize bandwidth.
    • _skip_lse_buffer_for_backward: Under the simplified conditions above with fp32 logits, the forward pass can skip allocating and writing a separate fp32 LSE buffer. The backward pass restores LSE from loss_row and x_y via DERIVE_LSE_FROM_LOSS, reducing memory usage and forward stores (bf16/fp16 logits still retain LSE to avoid numerical issues and test mismatches).
  5. Forward No Longer Uses Logits as Gradient Buffer

    • Old version: Ignored classes would zero out the entire row in X to achieve gradient semantics.
    • New version: The forward pass only writes loss, etc. Gradients are only written to grad_input in the backward kernel, making the semantics clearer and facilitating the above LSE omission and fast path.
  6. Constants & Minor Optimizations

    • ls_eps: label_smoothing / V is computed in Python and passed in; the kernel no longer divides by n_cols each row.
    • eviction_policy="evict_first": Added eviction policy for large-block logits loads (targeting Ascend).
    • get_optimal_block_size: When no gradients are needed in the forward pass, uses segmented defaults based on vocab size (e.g., ≤32k → 1024, etc.), then falls back to UB tiling. The default fallback is increased from 2048 to 4096 (when no tiling is used).
  7. Other Behavior / API

    • Token accuracy uses tensor operations like (argmax_idx == y).to(tl.float32), and the cross-block argmax merge is adjusted accordingly.
    • Removes dependency on element_mul_kernel in backward: scalar/vector grad_output is multiplied inside the backward kernel.
    • Autograd saves ce_stats, derive_lse_from_loss, and saves loss_1d instead of LSE when needed.
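The LSE reconstruction in point 4 relies on the plain-CE identity loss = lse − x[y], so lse = loss + x[y]. A minimal numeric check (illustrative values, not taken from the kernel):

```python
import math

def logsumexp(xs):
    # Numerically stable log-sum-exp over one row of logits.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

xs = [0.5, 2.0, -1.0]       # one row of logits
y = 1                       # target class
lse = logsumexp(xs)
loss = lse - xs[y]          # what the forward pass stores per row
derived_lse = loss + xs[y]  # what the no-weight backward kernel rebuilds
```

This identity only holds for plain cross-entropy, which is why the skip is gated on the no-weight / no-smoothing / no-softcap / no-z-loss path.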

forward: [benchmark image]
backward: [benchmark image]
full: [benchmark image]
no_grad-forward: [benchmark image]
ut: [benchmark image]

sunyi0505 (Contributor, Author) commented:
@Tcc0403 This PR is ready for review.

Comment thread src/liger_kernel/ops/backends/_ascend/ops/cross_entropy.py Outdated
```python
    X_offsets = i + tl.arange(0, BLOCK_SIZE)
    tl.store(dX_ptr + dX_ptr_offset + X_offsets, 0.0, mask=X_offsets < n_cols)
else:
    if DERIVE_LSE_FROM_LOSS:
```
Tcc0403 (Collaborator) suggested:

```python
if not HAS_LSE:  # derive from loss if not present
```

```python
    ignore_index,
    reduction: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    HAS_GRAD_OUTPUT_VECTOR: tl.constexpr,
```

Tcc0403 (Collaborator): Isn't it equivalent to `reduction == "none"`?

sunyi0505 (Contributor, Author) replied Apr 29, 2026:

HAS_GRAD_OUTPUT_VECTOR: When False, grad_output is a 0-dimensional scalar (common with reduction="mean" or "sum"). When True, grad_output is a 1-dimensional vector (common with reduction="none", one gradient per row).
Inside the kernel, this flag determines which branch the code takes:
True: read per row: grad_output[row_idx]
False: read the same scalar: grad_output[0] (broadcast to all rows)
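The branch described above can be modeled as (a hypothetical Python stand-in for the kernel's load, not the Triton code):

```python
def load_grad_output(grad_output, row_idx, has_grad_output_vector):
    # reduction="none": one upstream gradient per row.
    # reduction="mean"/"sum": a single scalar broadcast to every row.
    return grad_output[row_idx] if has_grad_output_vector else grad_output[0]
```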

Tcc0403 (Collaborator) replied May 3, 2026:
Thanks for the explanation. If I understand correctly, we can reuse reduction for branching conditions. We can comment grad_output is a vector if reduction!="none" next to it.

sunyi0505 (Contributor, Author) replied May 6, 2026:

The branching here is not for choosing whether grad_output is a per-row vector or a scalar; instead, it selects which backward kernel to use.
has_grad_output_vector only indicates whether the incoming upstream grad_output is one scalar per token, or a single scalar for the entire loss. This is orthogonal to whether class weights, label smoothing, softcap, or z-loss (lse_square_scale) are enabled.
Only plain cross-entropy — with no class weights, no softcap, no label smoothing, and no z-loss term — is eligible for this optimized lightweight kernel. Once any of these features are enabled, the backward formulation changes entirely, and we must fall back to the general liger_cross_entropy_backward_kernel.
If we instead branch externally on reduction:
When reduction="mean" and class weights are present: has_grad_output_vector is False, causing an incorrect branch dispatch. Entering the no-weight kernel in this case produces wrong gradients.
When reduction="none" and weight / label smoothing are enabled: has_grad_output_vector is True. Routing solely based on this flag into the no-weight kernel also yields incorrect gradients.
Therefore, reduction cannot replace use_no_weight_fast_path: the former does not imply whether the backward pass can be implemented with plain cross-entropy logic.
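The dispatch being argued for might be sketched as follows (function and parameter names are illustrative, not the actual code):

```python
def select_backward_kernel(has_weight, softcap_enabled, label_smoothing, lse_square_scale):
    # Only plain cross-entropy qualifies for the lightweight no-weight kernel;
    # enabling any extra feature changes the backward formulation entirely,
    # independently of whether grad_output is a scalar or a per-row vector.
    use_no_weight_fast_path = not (
        has_weight or softcap_enabled or label_smoothing > 0.0 or lse_square_scale != 0.0
    )
    if use_no_weight_fast_path:
        return "liger_cross_entropy_backward_kernel_no_weight"
    return "liger_cross_entropy_backward_kernel"
```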

Comment thread src/liger_kernel/ops/backends/_ascend/ops/cross_entropy.py
Comment thread src/liger_kernel/ops/backends/_ascend/ops/cross_entropy.py Outdated
zheliuyu (Contributor) commented:
The execution method for benchmark_cross_entropy has changed recently. Could you update the code and re-run the performance test?

@sunyi0505 sunyi0505 force-pushed the main branch 4 times, most recently from b306d1d to 10d66bd Compare April 29, 2026 08:39
@sunyi0505 sunyi0505 requested a review from Tcc0403 April 29, 2026 08:55
@sunyi0505 sunyi0505 force-pushed the main branch 3 times, most recently from ff9e663 to 8ef1421 Compare May 6, 2026 01:57
Co-authored-by: zheliuyu <zheliuyu@users.noreply.github.com>
3 participants