
feat: Optimize memory footprint of long-context training via fused kernel and chunking#4312

Open
terminator123 wants to merge 1 commit into NVIDIA:main from 021ai:chunk_fused_cross_entropy

Conversation

@terminator123

What does this PR do?

Introduces a fused CrossEntropy kernel and output chunking strategy to reduce the peak memory consumption of logits during long-context training.

Technical Details

This PR addresses the high VRAM usage bottleneck in large-scale training by targeting the logits tensor memory footprint.

  • Fused Kernel: Utilizes the Liger-Kernel's fused CrossEntropy implementation to reduce intermediate memory overhead.

  • Output Chunking: Implements an output chunking mechanism where the model's output is processed in blocks.

  • Memory-Specific Optimization: Peak logits memory is reduced by a factor proportional to the number of chunks (roughly 1/N for N chunks): the more chunks the output is divided into, the lower the peak memory.
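To illustrate the chunking idea described above, here is a minimal numpy sketch (hypothetical, not the PR's actual kernel): the output projection and cross-entropy are computed one block of the sequence at a time, so only a (chunk, vocab) slice of the logits ever exists in memory instead of the full (seq_len, vocab) tensor.

```python
import numpy as np

def chunked_cross_entropy(hidden, weight, targets, num_chunks=4):
    """Mean cross-entropy over a sequence, computed block by block so the
    full (seq_len, vocab) logits tensor is never materialized.

    hidden:  (seq_len, d_model) final hidden states
    weight:  (d_model, vocab) output projection matrix
    targets: (seq_len,) integer class labels
    """
    seq_len = hidden.shape[0]
    total = 0.0
    for chunk in np.array_split(np.arange(seq_len), num_chunks):
        # Only a (len(chunk), vocab) logits block lives at a time, so peak
        # logits memory shrinks by roughly 1/num_chunks.
        logits = hidden[chunk] @ weight
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        log_z = np.log(np.exp(logits).sum(axis=1))   # log partition per token
        total += (log_z - logits[np.arange(len(chunk)), targets[chunk]]).sum()
    return total / seq_len
```

Because the per-chunk losses are summed and divided by the full sequence length at the end, the result is identical to the unchunked computation; only the peak memory changes.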

@terminator123 terminator123 requested review from a team as code owners April 15, 2026 07:12
@copy-pr-bot

copy-pr-bot Bot commented Apr 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@svcnvidia-nemo-ci svcnvidia-nemo-ci marked this pull request as draft April 15, 2026 07:12
@github-actions
Contributor

This PR has been automatically converted to draft because all PRs must start as drafts.

When you are ready for review, click Ready for Review to begin the review process. This will:

  1. Add the oncall reviewer (optional reviewer)
  2. Add required review teams based on your changes

See the contribution guide for more details.

@terminator123 terminator123 marked this pull request as ready for review April 16, 2026 06:55
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team April 16, 2026 06:55
@chtruong814 chtruong814 added the needs-follow-up Issue needs follow-up label Apr 17, 2026
@Phlip79
Member

Phlip79 commented Apr 17, 2026

We are in the process of developing this same feature: #2206. @Jianbing-D can you please take a look at this?

@chtruong814 chtruong814 removed the needs-follow-up Issue needs follow-up label Apr 17, 2026
@Jianbing-D

Jianbing-D commented Apr 20, 2026

Hi @terminator123,

We already have a similar feature merged into the dev branch: #2256.

And regarding your PR, here are some questions:

  1. Are there any measurement numbers for your feature, such as forward- and backward-pass latency as well as memory usage? Like what we did here: [Dev] Feature: linear cross entropy fusion #2256
  2. Your kernels are written with OAI Triton, but that library fails to achieve good performance on Blackwell GPUs. If you could provide any perf numbers, that would help us determine whether your kernels are good enough.
  3. Your feature does not seem to support reduction=none (please correct me if I have misread it). Without reduction=none, how would users handle token masking and padding, where invalid tokens must have zero grad while valid tokens keep their real grad values?
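For context on the reduction=none concern above: with a per-token loss vector, masking is typically handled as in this hypothetical numpy sketch (not code from the PR) — padded tokens are zeroed out before averaging over valid tokens only, so they contribute nothing to the gradient.

```python
import numpy as np

def masked_token_loss(per_token_loss, valid_mask):
    """Average a reduction='none' loss over valid tokens only.

    per_token_loss: (seq_len,) loss per token
    valid_mask:     (seq_len,) 1.0 for real tokens, 0.0 for padding
    """
    masked = per_token_loss * valid_mask          # padding contributes zero
    n_valid = np.maximum(valid_mask.sum(), 1.0)   # guard against all-padding
    return masked.sum() / n_valid
```

A fused kernel that always reduces internally cannot expose this per-token hook, which is why the reviewer raises the question.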

@chtruong814 chtruong814 added the waiting-on-customer Waiting on the original author to respond label Apr 20, 2026

Labels

community-request waiting-on-customer Waiting on the original author to respond


5 participants