Conversation

@wujingyue (Collaborator) commented Jan 29, 2026

I have been pushing the implementation forward using several hacks to uncover edge cases. Below are the key issues identified and their current status:

1. Sharding Propagation Rework

  • Issue: NVIDIA/Fuser#5901
  • Status: Workaround applied.
  • Details: I’ve implemented a temporary fix in PR #5890, but this is not a long-term solution. Sharding propagation needs a fundamental rework to establish better defaults and cleaner logic.

2. Multi-Dimensional Sharding & getCommunicationInfo

  • Status: Implementation in progress.
  • Details: Updated getCommunicationInfo to support multi-dimensional sharding. The current logic reuses haveDifferentShardings to identify, per parallel type, inconsistencies between the input and output TensorView objects (see the sketch after this item).
  • Technical Debt: Per #3987, haveDifferentShardings is currently bottlenecked by the expensive ExpressionSimplifier. We need to transition this to be IdModel-based in a future iteration.
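
To make the per-parallel-type check concrete, here is a toy Python model of the idea, not nvFuser's implementation: the type alias `Sharding` and the helpers `have_different_shardings` and `classify_mismatch` are made-up names for illustration. A tensor's sharding is modeled as a map from device-parallel type to the logical axis it shards, and the parallel type on which producer and consumer disagree hints at the kind of communication needed.

    from typing import Optional

    ParallelType = str  # e.g. "DIDx", "DIDy", "DIDz"
    # parallel type -> sharded logical axis (None = not sharded on that type)
    Sharding = dict[ParallelType, Optional[int]]

    def have_different_shardings(producer: Sharding, consumer: Sharding, pt: ParallelType) -> bool:
        # Producer and consumer disagree on `pt` if they shard different axes,
        # or if only one of them is sharded on it.
        return producer.get(pt) != consumer.get(pt)

    def classify_mismatch(producer: Sharding, consumer: Sharding, pt: ParallelType) -> str:
        # Very rough classification of the communication implied by a mismatch.
        p_axis, c_axis = producer.get(pt), consumer.get(pt)
        if p_axis is not None and c_axis is None:
            return "gather-like (producer sharded, consumer replicated)"
        if p_axis is None and c_axis is not None:
            return "scatter-like (producer replicated, consumer sharded)"
        return "redistribution (sharded on different axes)"

    # Example on a 3D mesh: only DIDy changes between producer and consumer.
    producer = {"DIDx": 0, "DIDy": 1, "DIDz": 2}
    consumer = {"DIDx": 0, "DIDy": None, "DIDz": 2}
    for pt in ("DIDx", "DIDy", "DIDz"):
        if have_different_shardings(producer, consumer, pt):
            print(pt, "->", classify_mismatch(producer, consumer, pt))
    # Prints: DIDy -> gather-like (producer sharded, consumer replicated)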

3. Misaligned Memory Access in Transpose Kernels

  • Repro: Commit Link
  • Details: A generated transpose kernel is hitting misaligned memory access errors. This occurs during the transposition between the local Einsum and the downstream ReduceScatter. For context, this transposition was introduced by ReorderShardedAxisPass to ensure the scattered axis of the ReduceScatter is allocated outermost (the sketch below illustrates that layout requirement).
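
For intuition only (the actual repro is in the linked commit), here is a small numpy sketch with made-up shapes of why the scattered axis needs to be allocated outermost: scattering hands each rank a contiguous chunk of the buffer, which is why the pass transposes the local Einsum output before the ReduceScatter. This is not nvFuser code, just an illustration of the layout constraint.

    import numpy as np

    num_ranks = 4
    m, n = 8, num_ranks * 3  # local Einsum output: [m, n], to be scattered along n

    local_out = np.arange(m * n, dtype=np.float32).reshape(m, n)

    # Scattering along axis 1 of a row-major [m, n] buffer would give each rank a
    # strided, non-contiguous slice. Transposing so the scattered axis is
    # outermost makes each rank's chunk contiguous.
    reordered = np.ascontiguousarray(local_out.transpose(1, 0))  # [n, m]
    chunks = np.split(reordered, num_ranks, axis=0)              # one chunk per rank

    assert all(c.flags["C_CONTIGUOUS"] for c in chunks)
    print([c.shape for c in chunks])  # num_ranks chunks of shape (3, 8)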

4. Performance Bottleneck: AllGather Memory

  • Area: Communication Optimization
  • Details: The current naive AllGather preceding the Einsum is functional but consumes too much memory for AlphaFold3 workloads due to long sequence lengths.
  • Proposed Fix: We need to implement stream-parallelization to enable:
    • Ring-based AllGather (with Swizzle; a conceptual sketch follows below), or
    • Broadcast-based communication (without Swizzle). AFAICT, fast broadcast requires multicasting and therefore symmetric memory.
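
As a rough illustration of the ring-based direction, and not the proposed nvFuser implementation, the sketch below simulates the ranks in-process and uses a plain matmul in place of the Einsum. Each ring step consumes one remote shard and computes the corresponding partial result, so peak memory stays at one shard instead of the fully gathered operand; all names and shapes are made up.

    import numpy as np

    num_ranks = 4
    s_local, s_total, d = 16, 16 * num_ranks, 32

    rng = np.random.default_rng(0)
    # Shard of A held by each (simulated) rank: A is [s_total, d], split along dim 0.
    a_shards = [rng.standard_normal((s_local, d)).astype(np.float32) for _ in range(num_ranks)]
    # B is replicated for simplicity.
    b = rng.standard_normal((d, d)).astype(np.float32)

    def ring_allgather_einsum(my_rank: int) -> np.ndarray:
        # Compute A @ B for the full A without ever holding more than one shard
        # of A at a time on `my_rank`.
        out = np.empty((s_total, d), dtype=np.float32)
        buf = a_shards[my_rank]  # the shard currently held
        for step in range(num_ranks):
            src = (my_rank - step) % num_ranks  # whose shard `buf` is
            out[src * s_local : (src + 1) * s_local] = buf @ b  # compute to overlap
            # "Receive" the next shard from the neighbor; this stands in for the
            # stream-parallelized send/recv of a real ring.
            buf = a_shards[(src - 1) % num_ranks]
        return out

    reference = np.concatenate(a_shards, axis=0) @ b
    assert np.allclose(ring_allgather_einsum(my_rank=2), reference, atol=1e-4)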

cc @DejunL

github-actions bot commented Jan 29, 2026

Review updated until commit 54bbceb

Description

  • Add 3D sharding support for multi-GPU AlphaFold3 triangle updates

  • Implement new test cases for triangle updates (incoming/outgoing directions)

  • Add transpose test with 3D device mesh sharding

  • Improve error handling and validation in communication lowering

  • Add debug logging for fusion transforms after sharding propagation

Changes walkthrough

Relevant files

Tests

test_alphafold3.py (tests/python/multidevice/test_alphafold3.py, +225/-0)
AlphaFold3 triangle updates multi-GPU test

  • Add comprehensive test for AlphaFold3 triangle updates
  • Support both incoming and outgoing direction triangle updates
  • Implement 3D device mesh sharding (dp_size x cp_size x cp_size)
  • Test layer normalization, gating, and einsum operations
  • Validate multi-GPU execution with proper tensor sharding

test_multidevice.py (tests/python/multidevice/test_multidevice.py, +58/-1)
Multi-device transpose and pointwise tests

  • Add transpose test with 3D device mesh sharding
  • Update pointwise test to define mesh after FusionDefinition
  • Implement complex tensor reshaping and allocation domain settings

Bug fix

lower_to_communication.cpp (csrc/host_ir/lower_to_communication.cpp, +24/-18)
Communication lowering error handling improvements

  • Replace NVF_ERROR with NVF_ERROR_EQ for rank validation
  • Add input/output size validation in getCommunicationInfo
  • Improve error messages and add sharding validation
  • Update includes for better modularity

Enhancement

propagate_shardings.cpp (csrc/preseg_passes/propagate_shardings.cpp, +7/-0)
Debug logging for sharding propagation

  • Add debug logging for fusion transforms after pass execution
  • Include pre-segmenter logging when enabled

fusion_segmenter.cpp (csrc/fusion_segmenter.cpp, +4/-3)
Segmented fusion print method update

  • Change print() to use completeFusion()->print() instead of printMath()
  • Update debug output formatting with proper line breaks

Miscellaneous

base.h (csrc/base.h, +0/-1)
Header cleanup

  • Remove unused #include header

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Error Handling Robustness

The change from continue to NVF_THROW in the sharding validation loop (line 393) could be overly strict. This throws an error when a parallel type is not sharded, but the original continue behavior suggests this might be a valid case. Need to verify this doesn't break legitimate use cases where some parallel types aren't sharded.

    if (!haveDifferentShardings(producer, consumer, {pt})) {
      continue;
    }
    
    IterDomain* p_loop_did = getOrDefault(producer_pt_to_did, pt);
    IterDomain* c_loop_did = getOrDefault(consumer_pt_to_did, pt);
    
    if (p_loop_did == nullptr && c_loop_did == nullptr) {
      // Not sharded on this parallel type
      NVF_THROW("Not sharded on this parallel type: ", pt);
    }

Communication Info Validation

The input/output size validation added to getCommunicationInfo (lines 349-354) assumes exactly one input and one output. This should be verified against all supported operation types (LoadStoreOp, ReductionOp, SqueezeOp) to ensure the assumption holds in every case.

    NVF_ERROR_EQ(
        e->inputs().size(), 1, "Expected 1 input, but got ", e->toString());
    auto* producer = e->inputs().at(0)->as<TensorView>();
    NVF_ERROR_EQ(
        e->outputs().size(), 1, "Expected 1 output, but got ", e->toString());
    auto* consumer = e->outputs().at(0)->as<TensorView>();
