
Refactor: extract compute_task_fanin / register_task_outputs from submit_task#725

Open
ChaoWao wants to merge 1 commit into hw-native-sys:main from ChaoWao:refactor/extract-compute-task-deps

Conversation


@ChaoWao ChaoWao commented May 9, 2026

Summary

Pull the per-tensor dependency-computation logic of submit_task out
into a small header-only template in pto_dep_compute.h, shared between
the a5 and a2a3 tensormap_and_ringbuffer runtimes (they had identical
inline copies). Bit-equivalent refactor — no behavior change.

This is groundwork for an upcoming dep_gen feature (offline replay of
the orchestrator's dependency-discovery logic to capture a complete
dep graph without the race window in the in-record fanout[] snapshot).
Both submit_task and the future replay path will share the same
compute_task_fanin body, so dep computation has a single source of
truth.

What moved

| New entry point | Original location |
| --- | --- |
| `compute_task_fanin` | STEP 3 in `submit_task` (per-tensor creator retention + tensormap lookup for INPUT/INOUT) |
| `register_task_outputs` | STEP 4 in `submit_task` (`tensormap.insert` for INOUT/OUTPUT_EXISTING) |

DepInputs view struct decouples Arg's user-facing API from the
dep-computation interface so the future replay code can hand-roll a
DepInputs without faking an Arg.

What did not move (intentional)

STEP 1 (the explicit_dep loop) stays inline at the runtime call site.
Its last_task_alive short-circuit + unchecked slot lookup is subtly
different from STEP 3's slot-reuse check; unifying them would either
need two emit semantics or change transient behavior. The future replay
path will handle STEP 1 in a one-line emit loop of its own — duplication
is small and structural.

Performance

Emit is a template parameter, not std::function. Both runtime
(lambda capturing fanin_builder + sm_header) and future replay
(lambda capturing edge vector) instantiate at the call site and inline
all the way through. Do not replace with std::function — that would
add ~5 ns per call in the orch hot path. The header has a comment
warning future maintainers about this.

Cosmetic adds

  • TensorTagMixin::tag_data() in src/common/task_interface/task_args.h
  • Arg::explicit_deps_data() in
    src/a{5,2a3}/runtime/tensormap_and_ringbuffer/runtime/pto_types.h

Both are simple const T* accessors over already-public POD storage,
needed to build a DepInputs view without per-element copies.

Testing

  • pre-commit run --files <changed> (clang-tidy, clang-format,
    cpplint, headers all pass on a5 and a2a3)
  • a5sim: mixed_example (20 tasks, mixed AIC+AIV, RTOL=1e-3 golden
    check) PASSED
  • a5sim: spmd_starvation (252 tasks scheduled stress) PASSED
  • a5sim: spmd_basic (basic dispatch) PASSED
  • a2a3sim: test_l3_dependency (5 tasks dep test) PASSED
  • Linux CI green

…mit_task

Pull the per-tensor dependency-computation portion of submit_task out
into a header-only template in pto_dep_compute.h, shared between the
a5 and a2a3 tensormap_and_ringbuffer runtimes (they used to keep
identical inline copies).

Two functions:

  compute_task_fanin(inputs, tensor_map, in_manual_scope, emit)
    — Runs STEP 3 of the original submit_task: per-tensor creator
      retention (Step A) + tensormap lookup for INPUT/INOUT (Step B).
      For each producer it identifies, calls back into the user-supplied
      ``emit`` (a single template lambda — not std::function — so it
      inlines all the way through and adds zero hot-path overhead vs
      the previous inline code).

  register_task_outputs(inputs, task_id, tensor_map, in_manual_scope)
    — Runs STEP 4: tensormap.insert for INOUT and OUTPUT_EXISTING.

STEP 1 (the explicit_dep loop) is intentionally left at the runtime
call site. Its ``last_task_alive`` shortcut + unchecked slot lookup is
subtly different from STEP 3's slot-reuse check; unifying them would
either require two emit semantics or marginally change behavior in
transients. The replay path (future PR) will handle STEP 1 with a
trivial one-line emit loop of its own.

The refactor is bit-equivalent: same control flow, same conditionals,
same calls into PTO2TensorMap, same ``append_fanin_or_fail`` semantics
through the runtime emit lambda. Verified on a5sim (mixed_example,
spmd_starvation, spmd_basic) and a2a3sim (test_l3_dependency).

DepInputs view struct decouples Arg's user-facing accessors (which
carry launch_spec / scalars / has_error etc. that dep computation does
not need) from the dep-computation interface, so future replay code
can build a DepInputs from captured trace data without faking an Arg.

Adds two const accessors:
  - TensorTagMixin::tag_data() — exposes the per-tensor tag array start
  - Arg::explicit_deps_data() — exposes the explicit_deps storage start

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a new header file, pto_dep_compute.h, which provides template-based primitives for dependency computation. It refactors the submit_task method in PTO2OrchestratorState to utilize these new functions, improving code modularity. Additionally, helper methods explicit_deps_data and tag_data were added to the Arg and TensorTagMixin structures to support this refactoring. I have no feedback to provide as there were no review comments.

