Enable graph trainer passes for agnostic accelerators #2968
Conversation
The following ciflow label(s) have been added but CI has not been triggered yet because the workflows are awaiting approval:

Once a maintainer approves the workflows (scroll to the bottom of the PR page), the corresponding CI jobs will be triggered automatically. Please ping one of the reviewers if you do not have access to approve and run workflows.
```python
        ngpu=8,
        disabled=_JIT_AOT_DISABLED,
    ),
),
```
```python
import torch._inductor.fx_passes.node_runtime_estimation as node_runtime_estimation
import torch._inductor.fx_passes.overlap_scheduling as overlap_scheduling

# Save the original estimation, logging, and alignment hooks so they can be restored later.
original_estimate_roofline_runtime_ms = getattr(
    overlap_scheduling, "estimate_roofline_runtime_ms", None
)
original_estimate_runtime_analytical = getattr(
    overlap_scheduling, "estimate_runtime_analytical", None
)
original_log_compute_estimations = getattr(
    node_runtime_estimation, "_log_compute_estimations", None
)
scheduler_cls = getattr(overlap_scheduling, "OverlapScheduler", None)
original_align = None
if scheduler_cls is not None:
    original_align = getattr(
        scheduler_cls,
        "_align_compute_nodes_runtime_estimations_across_all_distributed_ranks",
        None,
    )

try:
    # Replace runtime estimation with constant stubs and turn estimation logging
    # and cross-rank alignment into no-ops.
    if original_estimate_roofline_runtime_ms is not None:
        overlap_scheduling.estimate_roofline_runtime_ms = lambda node: 1e-3
    if original_estimate_runtime_analytical is not None:
        overlap_scheduling.estimate_runtime_analytical = lambda node: 1e-3
    if original_log_compute_estimations is not None:
        node_runtime_estimation._log_compute_estimations = (
            lambda compute_nodes, benchmarked_estimations, analytical_estimations: None
        )
    if scheduler_cls is not None and original_align is not None:
        scheduler_cls._align_compute_nodes_runtime_estimations_across_all_distributed_ranks = (
            lambda self: None
        )
```
This seems much more like a change that should happen upstream in OverlapScheduler; patching it from here is brittle. Consider making the change there.

If not, please at least extract all of this into a helper. It's okay to assume OverlapScheduler will contain those methods, since upstream refactors will break this code anyway.
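A minimal sketch of what such a helper might look like as a context manager, following the suggestion above. The helper name `stub_runtime_estimations` is hypothetical, and the sketch assumes the installed torch build exposes the module attributes referenced in the diff above:

```python
import contextlib

import torch._inductor.fx_passes.node_runtime_estimation as node_runtime_estimation
import torch._inductor.fx_passes.overlap_scheduling as overlap_scheduling


@contextlib.contextmanager
def stub_runtime_estimations(stub_ms: float = 1e-3):
    """Temporarily replace runtime-estimation hooks with constant stubs.

    Hypothetical helper; assumes OverlapScheduler and the estimation functions
    exist with the names used in the diff above.
    """
    scheduler_cls = overlap_scheduling.OverlapScheduler
    align_name = "_align_compute_nodes_runtime_estimations_across_all_distributed_ranks"
    # Record the original callables so they can be restored afterwards.
    originals = {
        "roofline": overlap_scheduling.estimate_roofline_runtime_ms,
        "analytical": overlap_scheduling.estimate_runtime_analytical,
        "log": node_runtime_estimation._log_compute_estimations,
        "align": getattr(scheduler_cls, align_name),
    }
    try:
        # Constant estimates and no-op logging/alignment for the duration.
        overlap_scheduling.estimate_roofline_runtime_ms = lambda node: stub_ms
        overlap_scheduling.estimate_runtime_analytical = lambda node: stub_ms
        node_runtime_estimation._log_compute_estimations = lambda *args, **kwargs: None
        setattr(scheduler_cls, align_name, lambda self: None)
        yield
    finally:
        # Always restore, even if compilation raises.
        overlap_scheduling.estimate_roofline_runtime_ms = originals["roofline"]
        overlap_scheduling.estimate_runtime_analytical = originals["analytical"]
        node_runtime_estimation._log_compute_estimations = originals["log"]
        setattr(scheduler_cls, align_name, originals["align"])
```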
Agreed; this belongs in the overlap scheduler rather than being patched in torchtitan. I removed the duplicated inline patching and opened a PR against torch. Once it's merged, I will ping for further review.
This PR keeps the existing CUDA path and adds non-CUDA fallbacks for graph trainer passes to avoid scheduler estimation/alignment issues.
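A minimal sketch of how such a fallback gate might look, assuming the CUDA path is detected via `torch.cuda.is_available()` (the actual detection logic in the PR may differ, and `_needs_estimation_fallback` is an illustrative name):

```python
import torch


def _needs_estimation_fallback() -> bool:
    # Keep the default behavior on CUDA; take the fallback on every other
    # accelerator backend (hypothetical gate, name is illustrative).
    return not torch.cuda.is_available()
```

When the gate returns True, the compile step could be wrapped in a stubbing helper like the one sketched above so the overlap scheduler skips benchmarking and cross-rank alignment on backends where those are unsupported.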