add extra domain restriction to transpose scheduler #5884
!test --diff
Review updated until commit 2be55a2

Description

| Relevant files |
|---|
| Bug fix |
| Tests |
PR Reviewer Guide
Here are some key observations to aid the review process:
- 🧪 PR contains tests
- ⚡ Recommended focus areas for review: Logic validation
Test failures (partial, pipeline still running)

- (Medium, 1) Scalar numerical mismatches in Thunder nanoGPT autograd tests (test_networks.py)

| Test Name | GB200 | Source |
|---|---|---|
| thunder.tests.test_networks.test_nanogpt_complete_autograd_nvfuser_cuda_thunder.dtypes.float32 | ❌ | |

- (Medium, 1) Shape mismatch in Thunder inplace alias update (CUDA)

| Test Name | GB200 | Source |
|---|---|---|
| thunder.tests.test_update_aliases.test_higher_order_inplace_alias_update_nvfuser_cuda_thunder.dtypes.float32 | ❌ | |
!test
!test
!test
Greptile Overview
Greptile Summary
This PR adds an additional domain restriction to the transpose scheduler to improve performance for broadcast-only fusions. The key change prevents the transpose scheduler from being selected when two reference tensors have all their allocation domains mapped to each other, which indicates the grouping is due to broadcast rather than actual transposition.
Key Changes:
Performance Impact:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant Scheduler as Scheduler Selection
participant TDM as TransposeDomainMap
participant CA as ComputeAtMap
Scheduler->>TDM: hasAtLeastTwoValidGroups(fusion)
TDM->>TDM: groupInputsOutputsByInnerDim()
alt Less than 2 groups
TDM-->>Scheduler: false (not transpose)
else 2+ groups found
TDM->>TDM: findReferenceFor(group1)
TDM->>TDM: findReferenceFor(group2)
alt No valid references
TDM-->>Scheduler: false (not transpose)
else Valid references found
TDM->>TDM: getMappedAllocDimIn(ref1, innermost2)
alt No mapped dimension
TDM-->>Scheduler: false (not transpose)
else Mapped dimension exists
TDM->>TDM: getMaybeAllocationDomain() for both refs
TDM->>CA: Check if all domains mapped
CA-->>TDM: all_mapped result
alt all_mapped == true
TDM->>TDM: Check for broadcast dimensions
TDM->>TDM: Validate assumption (NVF_ERROR)
TDM-->>Scheduler: false (broadcast case, use pointwise)
else all_mapped == false
TDM-->>Scheduler: true (valid transpose)
end
end
end
end
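
The decision flow in the diagram above can be sketched in plain C++ as follows. This is a minimal sketch under stated assumptions, not the code in domain_map.cpp: `GroupAnalysis` and `hasAtLeastTwoValidGroupsSketch` are hypothetical names, and each field only stands in for the result of one call shown in the diagram. The broadcast assumption check (the NVF_ERROR step) is left out here.

```cpp
#include <cassert>

// Hypothetical summary of what the scheduler has learned about a fusion.
// These fields are illustrative stand-ins, not nvFuser types.
struct GroupAnalysis {
  int num_groups;                 // groups from groupInputsOutputsByInnerDim()
  bool has_valid_references;      // findReferenceFor() succeeded for both groups
  bool has_mapped_inner_dim;      // getMappedAllocDimIn(ref1, innermost2) found a dim
  bool all_alloc_domains_mapped;  // every allocation ID of ref1 maps to ref2's
};

// Returns true when the transpose scheduler should still be considered.
bool hasAtLeastTwoValidGroupsSketch(const GroupAnalysis& a) {
  assert(a.num_groups >= 0);  // precondition of this sketch only
  if (a.num_groups < 2) {
    return false;  // fewer than two groups: not a transpose
  }
  if (!a.has_valid_references) {
    return false;  // no valid reference tensors
  }
  if (!a.has_mapped_inner_dim) {
    return false;  // innermost dim of group 2 is not mapped in ref1
  }
  if (a.all_alloc_domains_mapped) {
    // Restriction added by this PR: the two groups differ only because of
    // a broadcast, so the pointwise scheduler should be used instead.
    return false;
  }
  return true;  // grouping comes from a real permutation
}
```

For a broadcast-only fusion (two groups, valid references, a mapped inner dimension, and all allocation domains mapped) the sketch returns false, which corresponds to the "broadcast case, use pointwise" branch of the diagram.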
2 files reviewed, no comments
csrc/scheduler/tools/domain_map.cpp (Outdated)
return false;
}

// For grouping caused by permutation, the corresponding loop domains should
Allocation domains instead of loop domains?
yes, revised.
csrc/scheduler/tools/domain_map.cpp (Outdated)
}

// For grouping caused by permutation, the corresponding loop domains should
// not be all mapped to each other. If they are, it means the two groups are
Can you add an example pattern here as a comment?
added.
// For example, in TransposeTest.NoTransposeMaverick17B, two inputs
// are tv0[i0, i1] and tv1[i2, b3] where i0/i2 and i1/b3 are mapped to each
// other. However, tv0 and tv1 are in two different groups because of the
// broadcast. In this case, we should use the pointwise scheduler instead of
// the transpose scheduler.
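
The pattern in this comment can be illustrated with a toy model. This is a sketch only: `ToyId`, `ToyDomain`, and the helper functions are hypothetical and not part of nvFuser, the integer `map_group` stands in for the ComputeAtMap equivalence class of an ID, and the position-by-position comparison is an assumption about how "all mapped" is decided.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy stand-in for an IterDomain (hypothetical, for illustration only).
struct ToyId {
  std::string name;
  int map_group;      // IDs with the same value are considered mapped
  bool is_broadcast;  // true for broadcast IDs such as b3
};
using ToyDomain = std::vector<ToyId>;

// True when both allocation domains have the same length and every ID is
// mapped to the ID at the same position in the other domain.
bool allAllocIdsMapped(const ToyDomain& a, const ToyDomain& b) {
  if (a.size() != b.size()) {
    return false;
  }
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (a[i].map_group != b[i].map_group) {
      return false;
    }
  }
  return true;
}

bool hasBroadcast(const ToyDomain& d) {
  for (const ToyId& id : d) {
    if (id.is_broadcast) {
      return true;
    }
  }
  return false;
}

// True when two references from different groups differ only by broadcast,
// in which case the pointwise scheduler is the better choice.
bool groupingIsBroadcastOnly(const ToyDomain& a, const ToyDomain& b) {
  if (!allAllocIdsMapped(a, b)) {
    return false;  // a real permutation separates the two groups
  }
  // Mirrors the assumption check in the PR: fully mapped references should
  // only land in different groups when a broadcast is involved.
  assert(hasBroadcast(a) || hasBroadcast(b));
  return true;
}
```

With tv0[i0, i1] and tv1[i2, b3] modeled so that i0/i2 share one group and i1/b3 share another, `groupingIsBroadcastOnly` returns true. For a permuted pair such as [i0, i1] and [i1, i0], the positional comparison fails and the function returns false.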
2 files reviewed, no comments
!test
2 files reviewed, no comments
naoyam left a comment
LGTM
// are tv0[i0, i1] and tv1[i2, b3] where i0/i2 and i1/b3 are mapped to each
// other. However, tv0 and tv1 are in two different groups because of the
// broadcast. In this case, we should use the pointwise scheduler instead of
// the transpose scheduler.
I'm ok with this but feel like the original grouping should be changed instead.
IIUC, the grouping identifies whether there exist two differing groups. The added logic here basically overrides the grouping in the broadcast case. I think it would be clearer if the initial grouping itself handled this case and did not produce two groups.
I see your point. This check is currently inside hasAtLeastTwoValidGroups(). The function filters out invalid groups. We could consider inlining this logic into groupInputsOutputsByInnerDim(), but keeping the functions separate is also reasonable.
To solve issue #5883 for better performance