I have been pushing the implementation forward using several hacks to uncover edge cases. Below are the key issues identified and their current status:
1. Sharding Propagation Rework
Details: I’ve implemented a temporary fix in PR #5890, but this is not a long-term solution. Sharding propagation needs a fundamental rework to establish better defaults and cleaner logic.
2. Multi-Dimensional Sharding & `getCommunicationInfo`
Details: Updated `getCommunicationInfo` to support multi-dimensional sharding. The current logic reuses `haveDifferentShardings` to identify inconsistencies between input and output `TensorView` objects.
Technical Debt: Per #3987, `haveDifferentShardings` is currently bottlenecked by the expensive `ExpressionSimplifier`. We need to transition this to be IdModel-based in a future iteration.
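To illustrate the per-parallel-type comparison, here is a minimal, self-contained sketch of the idea. The types and the helper names below are stand-ins I made up for illustration, not nvFuser's actual `TensorView`/`IterDomain` machinery:

```cpp
// Stand-in sketch: per parallel type, compare which logical axis the producer
// and the consumer shard on, and flag mismatches.
#include <cstdint>
#include <iostream>
#include <map>
#include <optional>

enum class ParallelType { DIDx, DIDy };  // assumed device-parallel types

// Hypothetical stand-in: parallel type -> index of the sharded logical axis.
// nvFuser derives this from loop IterDomains; here it is written by hand.
using ShardingMap = std::map<ParallelType, int64_t>;

std::optional<int64_t> shardedAxis(const ShardingMap& m, ParallelType pt) {
  auto it = m.find(pt);
  if (it == m.end()) {
    return std::nullopt;  // not sharded on this parallel type
  }
  return it->second;
}

// True when producer and consumer disagree on `pt`, i.e. resolving this edge
// requires a communication (gather/scatter/redistribute) along that type.
bool differsOnParallelType(
    const ShardingMap& producer, const ShardingMap& consumer, ParallelType pt) {
  return shardedAxis(producer, pt) != shardedAxis(consumer, pt);
}

int main() {
  // Multi-dimensional sharding: producer sharded on axis 0 (DIDx) and axis 1
  // (DIDy); consumer keeps the DIDx sharding but is replicated on DIDy.
  ShardingMap producer{{ParallelType::DIDx, 0}, {ParallelType::DIDy, 1}};
  ShardingMap consumer{{ParallelType::DIDx, 0}};

  for (auto pt : {ParallelType::DIDx, ParallelType::DIDy}) {
    std::cout << (pt == ParallelType::DIDx ? "DIDx" : "DIDy") << ": "
              << (differsOnParallelType(producer, consumer, pt)
                      ? "needs communication"
                      : "consistent")
              << "\n";
  }
  return 0;
}
```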
3. Misaligned Memory Access in Transpose Kernels
Details: A generated transpose kernel is hitting misaligned memory access errors. This occurs during the transposition between the local Einsum and the downstream `ReduceScatter`. For context, this transposition was introduced by `ReorderShardedAxisPass` to ensure the scattered axis of the `ReduceScatter` is allocated outermost.
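For intuition on why `ReorderShardedAxisPass` wants the scattered axis outermost in the first place (and therefore inserts the transposition that is now misbehaving), here is a small standalone sketch, separate from nvFuser's actual code, with made-up sizes:

```cpp
// A ReduceScatter-style collective hands each rank one contiguous chunk of the
// buffer, which only matches the logical shard when the scattered axis is the
// slowest-varying (outermost) allocation dimension.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t world = 4;           // number of ranks (assumed)
  const int64_t rows = 8, cols = 8;  // row-major [rows, cols] allocation

  // Scattered axis outermost: rank r owns one contiguous run of elements.
  const int64_t outer_chunk = (rows / world) * cols;
  std::cout << "scatter on outer axis: each rank owns one contiguous run of "
            << outer_chunk << " elements\n";

  // Scattered axis innermost: rank 0 owns cols/world elements out of every
  // row, so its shard is split into `rows` strided, non-contiguous pieces.
  const int64_t per_row = cols / world;
  std::cout << "scatter on inner axis: rank 0 owns offsets";
  for (int64_t r = 0; r < rows; ++r) {
    std::cout << " [" << r * cols << "," << r * cols + per_row << ")";
  }
  std::cout << " -- " << rows << " separate runs\n";
  return 0;
}
```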
4. Performance Bottleneck: AllGather memory
Area: Communication Optimization
Details: The current naive `AllGather` preceding the Einsum is functional but consumes too much memory for AlphaFold3 workloads due to long sequence lengths.
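To make the memory concern concrete, a rough back-of-envelope sketch. Every size below is a placeholder I picked for illustration (not AlphaFold3's actual dimensions), and the ring-style peak is only a coarse model (local shard plus a couple of in-flight chunks):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  const int64_t world = 8;           // number of devices (placeholder)
  const int64_t seq = 1 << 15;       // sequence length (placeholder)
  const int64_t hidden = 1 << 12;    // hidden size (placeholder)
  const int64_t bytes_per_elem = 2;  // bf16

  // Each device holds a sequence-sharded slice of the activation.
  const int64_t local = (seq / world) * hidden * bytes_per_elem;

  // Naive AllGather materializes the whole unsharded tensor on every device.
  const int64_t naive_peak = local * world;

  // Ring/stream-parallel scheme: keep the local shard plus ~2 chunks in
  // flight, consuming each remote chunk in the Einsum as soon as it arrives.
  const int64_t ring_peak = local + 2 * local;

  auto mib = [](int64_t b) { return static_cast<double>(b) / (1 << 20); };
  std::cout << "naive AllGather peak   ~ " << mib(naive_peak) << " MiB\n";
  std::cout << "ring / stream-parallel ~ " << mib(ring_peak) << " MiB\n";
  return 0;
}
```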
Proposed Fix: We need to implement stream-parallelization to enable:
- Ring-based `AllGather` (with Swizzle), or
- Broadcast-based communication (without Swizzle). AFAICT, fast broadcast requires multicasting and therefore symmetric memory.

cc @DejunL
The change from `continue` to `NVF_THROW` in the sharding validation loop (line 393) could be overly strict: it now throws when neither the producer nor the consumer is sharded on a given parallel type, whereas the original `continue` suggests that can be a valid case. We should verify this doesn't break legitimate use cases where some parallel types aren't sharded.
```cpp
if (!haveDifferentShardings(producer, consumer, {pt})) {
  continue;
}
IterDomain* p_loop_did = getOrDefault(producer_pt_to_did, pt);
IterDomain* c_loop_did = getOrDefault(consumer_pt_to_did, pt);
if (p_loop_did == nullptr && c_loop_did == nullptr) {
  // Not sharded on this parallel type
  NVF_THROW("Not sharded on this parallel type: ", pt);
}
```
The added input/output size validation in `getCommunicationInfo` (lines 349-354) assumes exactly one input and one output. This assumption should be verified against all supported operation types (`LoadStoreOp`, `ReductionOp`, `SqueezeOp`) to ensure it holds in every case.
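For reference, one possible shape for that guard, as a sketch only: I'm assuming the usual `NVF_ERROR` macro and `Expr::inputs()`/`outputs()` accessors, and `expr` stands in for whatever expression the function is inspecting; this is not quoting the actual diff.

```cpp
// Sketch of the single-input/single-output check described above (assumed
// names; not the actual lines from the PR).
NVF_ERROR(
    expr->inputs().size() == 1 && expr->outputs().size() == 1,
    "getCommunicationInfo expects a single-input, single-output expression ",
    "(LoadStoreOp, ReductionOp, or SqueezeOp), but got: ",
    expr->toString());
```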