Refactor: extract compute_task_fanin / register_task_outputs from submit_task#725
Open
ChaoWao wants to merge 1 commit intohw-native-sys:mainfrom
Open
Refactor: extract compute_task_fanin / register_task_outputs from submit_task#725ChaoWao wants to merge 1 commit intohw-native-sys:mainfrom
ChaoWao wants to merge 1 commit intohw-native-sys:mainfrom
Conversation
…mit_task
Pull the per-tensor dependency-computation portion of submit_task out
into a header-only template in pto_dep_compute.h, shared between the
a5 and a2a3 tensormap_and_ringbuffer runtimes (they used to keep
identical inline copies).
Two functions:
compute_task_fanin(inputs, tensor_map, in_manual_scope, emit)
— Runs STEP 3 of the original submit_task: per-tensor creator
retention (Step A) + tensormap lookup for INPUT/INOUT (Step B).
For each producer it identifies, calls back into the user-supplied
``emit`` (a single template lambda — not std::function — so it
inlines all the way through and adds zero hot-path overhead vs
the previous inline code).
register_task_outputs(inputs, task_id, tensor_map, in_manual_scope)
— Runs STEP 4: tensormap.insert for INOUT and OUTPUT_EXISTING.
STEP 1 (the explicit_dep loop) is intentionally left at the runtime
call site. Its ``last_task_alive`` shortcut + unchecked slot lookup is
subtly different from STEP 3's slot-reuse check; unifying them would
either require two emit semantics or marginally change behavior in
transients. The replay path (future PR) will handle STEP 1 with a
trivial one-line emit loop of its own.
The refactor is bit-equivalent: same control flow, same conditionals,
same calls into PTO2TensorMap, same ``append_fanin_or_fail`` semantics
through the runtime emit lambda. Verified on a5sim (mixed_example,
spmd_starvation, spmd_basic) and a2a3sim (test_l3_dependency).
DepInputs view struct decouples Arg's user-facing accessors (which
carry launch_spec / scalars / has_error etc. that dep computation does
not need) from the dep-computation interface, so future replay code
can build a DepInputs from captured trace data without faking an Arg.
Adds two const accessors:
- TensorTagMixin::tag_data() — exposes the per-tensor tag array start
- Arg::explicit_deps_data() — exposes the explicit_deps storage start
There was a problem hiding this comment.
Code Review
This pull request introduces a new header file, pto_dep_compute.h, which provides template-based primitives for dependency computation. It refactors the submit_task method in PTO2OrchestratorState to utilize these new functions, improving code modularity. Additionally, helper methods explicit_deps_data and tag_data were added to the Arg and TensorTagMixin structures to support this refactoring. I have no feedback to provide as there were no review comments.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Pull the per-tensor dependency-computation logic of
submit_taskoutinto a small header-only template in
pto_dep_compute.h, shared betweenthe a5 and a2a3
tensormap_and_ringbufferruntimes (they had identicalinline copies). Bit-equivalent refactor — no behavior change.
This is groundwork for an upcoming
dep_genfeature (offline replay ofthe orchestrator's dependency-discovery logic to capture a complete
dep graph without the race window in the in-record
fanout[]snapshot).Both submit_task and the future replay path will share the same
compute_task_faninbody, so dep computation has a single source oftruth.
What moved
compute_task_faninsubmit_task(per-tensor creator retention + tensormap lookup for INPUT/INOUT)register_task_outputssubmit_task(tensormap.insert for INOUT/OUTPUT_EXISTING)DepInputsview struct decouplesArg's user-facing API from thedep-computation interface so the future replay code can hand-roll a
DepInputswithout faking anArg.What did not move (intentional)
STEP 1 (the explicit_dep loop) stays inline at the runtime call site.
Its
last_task_aliveshort-circuit + unchecked slot lookup is subtlydifferent from STEP 3's slot-reuse check; unifying them would either
need two emit semantics or change transient behavior. The future replay
path will handle STEP 1 in a one-line emit loop of its own — duplication
is small and structural.
Performance
Emitis a template parameter, notstd::function. Both runtime(lambda capturing
fanin_builder+sm_header) and future replay(lambda capturing edge vector) instantiate at the call site and inline
all the way through. Do not replace with
std::function— that wouldadd ~5 ns per call in the orch hot path. The header has a comment
warning future maintainers about this.
Cosmetic adds
TensorTagMixin::tag_data()insrc/common/task_interface/task_args.hArg::explicit_deps_data()insrc/a{5,2a3}/runtime/tensormap_and_ringbuffer/runtime/pto_types.hBoth are simple
const T*accessors over already-public POD storage,needed to build a
DepInputsview without per-element copies.Testing
pre-commit run --files <changed>(clang-tidy, clang-format,cpplint, headers all pass on a5 and a2a3)
mixed_example(20 tasks, mixed AIC+AIV, RTOL=1e-3 goldencheck) PASSED
spmd_starvation(252 tasks scheduled stress) PASSEDspmd_basic(basic dispatch) PASSEDtest_l3_dependency(5 tasks dep test) PASSED