
Enable graph trainer passes for agnostic accelerators #2968

Draft

jemitche1 wants to merge 12 commits into pytorch:main from jemitche1:jerome_m/fix/graph-xpu-compile-passes

Conversation


@jemitche1 jemitche1 commented Apr 14, 2026

This PR keeps the existing CUDA path and adds non-CUDA fallbacks for graph trainer passes to avoid scheduler estimation/alignment issues.
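
For context, a rough sketch of the device-agnostic routing such a change implies. This is illustrative only, not code from the diff; the helper names are assumptions, and `torch.accelerator` is only present in recent PyTorch releases:

```python
# Illustrative sketch only, not the actual diff. The general idea: keep the
# CUDA path when CUDA is present, otherwise take a fallback that avoids
# CUDA-only scheduler estimation/alignment.
import torch

def _accelerator_type() -> str:
    # torch.accelerator exists in recent PyTorch; hasattr guards older builds.
    if hasattr(torch, "accelerator") and torch.accelerator.is_available():
        return torch.accelerator.current_accelerator().type  # e.g. "cuda", "xpu"
    return "cpu"

def use_cuda_graph_passes() -> bool:
    # Existing CUDA path is preserved; XPU and other accelerators fall back.
    return _accelerator_type() == "cuda"
```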

@meta-cla meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) Apr 14, 2026

pytorch-bot Bot commented Apr 14, 2026

The following ciflow label(s) have been added but CI has not been triggered yet because the workflows are awaiting approval:

  • ciflow/8gpu

Once a maintainer approves the workflows (scroll to the bottom of the PR page), the corresponding CI jobs will be triggered automatically. Please ping one of the reviewers if you do not have access to approve and run workflows.

@jemitche1 jemitche1 changed the title from "graph trainer compile passes for xpu" to "Enable graph trainer passes accelerator agnostic" Apr 14, 2026
@jemitche1 jemitche1 changed the title from "Enable graph trainer passes accelerator agnostic" to "Enable graph trainer passes for agnostic accelerators" Apr 14, 2026
@jemitche1 jemitche1 marked this pull request as ready for review April 26, 2026 23:51
@jemitche1 jemitche1 marked this pull request as draft April 27, 2026 01:17
@jemitche1 jemitche1 marked this pull request as ready for review April 28, 2026 20:12
```python
        ngpu=8,
        disabled=_JIT_AOT_DISABLED,
    ),
),_``
```
Member

typo?

Comment on lines +244 to +277
```python
import torch._inductor.fx_passes.node_runtime_estimation as node_runtime_estimation
import torch._inductor.fx_passes.overlap_scheduling as overlap_scheduling

original_estimate_roofline_runtime_ms = getattr(
    overlap_scheduling, "estimate_roofline_runtime_ms", None
)
original_estimate_runtime_analytical = getattr(
    overlap_scheduling, "estimate_runtime_analytical", None
)
original_log_compute_estimations = getattr(
    node_runtime_estimation, "_log_compute_estimations", None
)
scheduler_cls = getattr(overlap_scheduling, "OverlapScheduler", None)
original_align = None
if scheduler_cls is not None:
    original_align = getattr(
        scheduler_cls,
        "_align_compute_nodes_runtime_estimations_across_all_distributed_ranks",
        None,
    )

try:
    if original_estimate_roofline_runtime_ms is not None:
        overlap_scheduling.estimate_roofline_runtime_ms = lambda node: 1e-3
    if original_estimate_runtime_analytical is not None:
        overlap_scheduling.estimate_runtime_analytical = lambda node: 1e-3
    if original_log_compute_estimations is not None:
        node_runtime_estimation._log_compute_estimations = (
            lambda compute_nodes, benchmarked_estimations, analytical_estimations: None
        )
    if scheduler_cls is not None and original_align is not None:
        scheduler_cls._align_compute_nodes_runtime_estimations_across_all_distributed_ranks = (
            lambda self: None
        )
```
Member

This seems much more like a change that should happen upstream in OverlapScheduler; patching it from here is brittle. Consider making the change there.

If not, please at least extract all of this into a helper. It's okay to assume OverlapScheduler will contain those methods, since upstream refactors would break this code anyway.
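
For illustration, the suggested extraction could look roughly like this. A sketch only: per the reviewer's note it assumes the attributes exist on OverlapScheduler, and the helper name is hypothetical:

```python
import contextlib

import torch._inductor.fx_passes.node_runtime_estimation as node_runtime_estimation
import torch._inductor.fx_passes.overlap_scheduling as overlap_scheduling

@contextlib.contextmanager
def stubbed_runtime_estimations():
    """Temporarily stub CUDA-only estimation/alignment hooks, then restore them."""
    sched = overlap_scheduling.OverlapScheduler
    saved = (
        overlap_scheduling.estimate_roofline_runtime_ms,
        overlap_scheduling.estimate_runtime_analytical,
        node_runtime_estimation._log_compute_estimations,
        sched._align_compute_nodes_runtime_estimations_across_all_distributed_ranks,
    )
    try:
        # Constant estimates sidestep device benchmarking on non-CUDA backends.
        overlap_scheduling.estimate_roofline_runtime_ms = lambda node: 1e-3
        overlap_scheduling.estimate_runtime_analytical = lambda node: 1e-3
        node_runtime_estimation._log_compute_estimations = lambda *args: None
        sched._align_compute_nodes_runtime_estimations_across_all_distributed_ranks = (
            lambda self: None
        )
        yield
    finally:
        # Restore originals so the patch never leaks past the compile pass.
        (
            overlap_scheduling.estimate_roofline_runtime_ms,
            overlap_scheduling.estimate_runtime_analytical,
            node_runtime_estimation._log_compute_estimations,
            sched._align_compute_nodes_runtime_estimations_across_all_distributed_ranks,
        ) = saved
```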

Author

Agreed; this belongs in the overlap scheduler rather than being patched in torchtitan. I removed the duplicated inline patching and opened a PR against torch. Once it's merged, I'll ping for further review.

@jemitche1 jemitche1 marked this pull request as draft May 7, 2026 23:12

Labels

ciflow/8gpu, CLA Signed
