
frontend: aic/aiv_initialize_pipe rejects local_slot_num on gm_slot_tensor form, forcing LocalSlotNum=8 #649

@chenshengxin2026

Description


Summary

pto.aic_initialize_pipe / pto.aiv_initialize_pipe reject local_slot_num when used with the address-based gm_slot_tensor operand:

'pto.aic_initialize_pipe' op globaltensor pipe init does not use 'local_slot_num'
'pto.aiv_initialize_pipe' op globaltensor pipe init does not use 'local_slot_num'

This makes it impossible for kernels to author the LocalSlotNum=2 template instantiation that the runtime headers default to and that the manual flash-attention kernel uses. Without an IR-level override, ptoas's address-based pipe lowering emits TPipe<..., 8, 8, false> (LocalSlotNum=SlotNum) because buildTPipeTokenFromInitOp in lib/PTO/Transforms/PTOToEmitC.cpp falls back to initOp.getSlotNum() when no local_slot_num attribute is present.

Symptom

The hand-written kernel in kernels/manual/common/flash_atten/fa_performance_kernel.cpp declares its QK/PV/P pipes as:

TPipe<FlagID, DirType, SlotSize, 8, 2, false> qkPipe;  // SlotNum=8, LocalSlotNum=2

ptoas-generated code from a kernel using gm_slot_tensor instead emits:

TPipe<..., 1024, 8, 8, false>(...);   // LocalSlotNum=8
TFREE<TPipe<..., 1024, 8, 8, false>, GlobalTensor<...>, ...>(pipe, tensor);

Because LocalSlotNum is part of the C++ template type, the two pipes are different types even though they refer to the same hardware FIFO. The 8/8 form inflates the FFTS/event multiplex (8 local × 8 global per pipe instead of 2 × 8) and overruns --enable-insert-sync's 8-event pool at long sequences. In flash-attention, this manifests as kernel timeouts at S1>=4096 on a3, even when address-based TFREE is otherwise correct.
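The type mismatch can be illustrated with a minimal C++17 sketch. MockTPipe is a hypothetical, stripped-down stand-in that keeps only the FlagID/DirType/SlotSize/SlotNum/LocalSlotNum parameters named in this issue. The real TPipe in include/pto/npu/a2a3/TPush.hpp has further parameters and an actual implementation, and the FlagID/DirType values here are arbitrary. The sketch shows that the 8/2 and 8/8 instantiations are unrelated types. It also reproduces the local × global products quoted above (16 vs 64).

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical, simplified stand-in for TPipe: only the template parameters
// discussed in this issue, no behavior. Not the real PTO header.
template <uint8_t FlagID, uint8_t DirType, uint32_t SlotSize, uint32_t SlotNum,
          uint32_t LocalSlotNum>
struct MockTPipe {
  static constexpr uint32_t kSlotNum = SlotNum;
  static constexpr uint32_t kLocalSlotNum = LocalSlotNum;
  // local x global slot combinations, as counted in the issue text
  static constexpr uint32_t kMultiplex = LocalSlotNum * SlotNum;
};

using ManualPipe = MockTPipe<0, 1, 1024, 8, 2>;     // manual flash-attention kernel
using GeneratedPipe = MockTPipe<0, 1, 1024, 8, 8>;  // current ptoas output

// Same FIFO parameters, but LocalSlotNum differs, so the types differ.
static_assert(!std::is_same<ManualPipe, GeneratedPipe>::value,
              "8/2 and 8/8 pipes are distinct C++ types");
static_assert(ManualPipe::kMultiplex == 16, "2 local x 8 global");
static_assert(GeneratedPipe::kMultiplex == 64, "8 local x 8 global");
```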

Root cause (current state)

  1. verifyFrontendInitCommon in lib/PTO/IR/PTO.cpp (around line 10680) explicitly rejects local_slot_num on the gm_slot_tensor branch:
    if (op.getLocalSlotNumAttr())
      return op.emitOpError(
          "globaltensor pipe init does not use 'local_slot_num'");
  2. InitializeL2G2LPipeOp::verify (around line 11440) further enforces that local_slot_num requires local_addr, which is empty for the gm_slot_tensor form.
  3. createFrontendPipe in lib/PTO/Transforms/PTOLowerFrontendPipeOpsPass.cpp (around line 128) hard-codes IntegerAttr{} for local_slot_num when lowering the gm_slot_tensor branch, so even if the verifiers were relaxed, the attribute could not flow through.
  4. buildTPipeTokenFromInitOp in lib/PTO/Transforms/PTOToEmitC.cpp (around line 628) falls back to initOp.getSlotNum() when the attribute is missing.
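The attribute flow through steps 3 and 4 can be modeled in a few lines of C++17. This is not MLIR code. FrontendInit, L2G2LInit, lowerCurrent, lowerProposed, and emittedLocalSlotNum are hypothetical names, std::optional stands in for an optional IntegerAttr, and slot_num is assumed to resolve to 8, as in the generated code above. The model shows why relaxing the verifiers alone is not enough: as long as lowering drops the attribute, the EmitC fallback still produces LocalSlotNum=8.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical model of the attribute flow; std::optional plays the role of
// an optional IntegerAttr on the ops.
struct FrontendInit {  // pto.aic/aiv_initialize_pipe (gm_slot_tensor form)
  uint32_t slotNum;
  std::optional<uint32_t> localSlotNum;
};
struct L2G2LInit {  // the InitializeL2G2LPipeOp produced by lowering
  uint32_t slotNum;
  std::optional<uint32_t> localSlotNum;
};

// Step 3 today: the gm_slot_tensor branch passes an empty attribute.
constexpr L2G2LInit lowerCurrent(FrontendInit op) {
  return {op.slotNum, std::nullopt};
}

// Proposed: forward the frontend attribute unchanged.
constexpr L2G2LInit lowerProposed(FrontendInit op) {
  return {op.slotNum, op.localSlotNum};
}

// Step 4: when the attribute is absent, fall back to slot_num.
constexpr uint32_t emittedLocalSlotNum(L2G2LInit op) {
  return op.localSlotNum.value_or(op.slotNum);
}

// The reproducer's request: slot_num 8 (assumed), local_slot_num 2.
constexpr FrontendInit kRepro{8, 2};
static_assert(emittedLocalSlotNum(lowerCurrent(kRepro)) == 8, "attribute dropped");
static_assert(emittedLocalSlotNum(lowerProposed(kRepro)) == 2, "attribute kept");
```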

Reproducer

module {
  func.func @cube_kernel(%gm_slot_buffer : !pto.ptr<f32>)
      attributes {pto.kernel_kind = #pto.kernel_kind<cube>} {
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %gm_slots = pto.make_tensor_view %gm_slot_buffer,
      shape = [%c16, %c16], strides = [%c16, %c1]
      : !pto.tensor_view<16x16xf32>
    pto.aic_initialize_pipe {id = 0, dir_mask = 1, slot_size = 1024, local_slot_num = 2}
      (gm_slot_tensor = %gm_slots : !pto.tensor_view<16x16xf32>)
    func.return
  }
}
$ ptoas --pto-arch=a3 reproducer.mlir
error: 'pto.aic_initialize_pipe' op globaltensor pipe init does not use 'local_slot_num'

Proposed fix

Allow local_slot_num on the gm_slot_tensor form so kernels can opt into the manual-aligned LocalSlotNum=2 template. Concretely:

  1. Frontend verifier (lib/PTO/IR/PTO.cpp, verifyFrontendInitCommon): drop the blanket rejection. Validate the attribute the same way the legacy gm_slot_buffer branch does (must be > 0, must be ≤ 8 for dir_mask=1/2 or ≤ 4 for dir_mask=3).
  2. InitializeL2G2LPipeOp verifier: allow local_slot_num on the no-local_addr path when gm_addr is a tensor_view, i.e. the gm_slot_tensor form. The bounds check (> 0, ≤ slot_num) still runs.
  3. Lowering (lib/PTO/Transforms/PTOLowerFrontendPipeOpsPass.cpp, createFrontendPipe gm_slot_tensor branch): propagate initOp.getLocalSlotNumAttr() into the InitializeL2G2LPipeOp it creates instead of IntegerAttr{}.
  4. Lit test: add test/lit/pto/tpush_tpop_globaltensor_local_slot_num_a3.pto exercising local_slot_num=2 on both aic/aiv and asserting the lowered TPipe<..., 8, 2, ...> template.
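The validation rules in items 1 and 2 amount to a small bounds check. The sketch below uses a hypothetical standalone helper: checkLocalSlotNum is not a ptoas function, and its error strings are illustrative. The helper returns nullptr when the value is accepted. It assumes dir_mask has already been validated as 1, 2, or 3 elsewhere.

```cpp
#include <cstdint>

// Hypothetical helper combining the bounds from the proposed fix.
// Returns nullptr if localSlotNum is valid, otherwise an error message.
constexpr const char* checkLocalSlotNum(uint32_t localSlotNum, uint32_t dirMask,
                                        uint32_t slotNum) {
  if (localSlotNum == 0)
    return "'local_slot_num' must be > 0";
  // Per-direction cap, as in the legacy gm_slot_buffer branch:
  // dir_mask 1 or 2 -> at most 8, dir_mask 3 -> at most 4.
  const uint32_t cap = (dirMask == 3) ? 4u : 8u;
  if (localSlotNum > cap)
    return "'local_slot_num' exceeds the limit for this dir_mask";
  // InitializeL2G2LPipeOp check: cannot exceed slot_num.
  if (localSlotNum > slotNum)
    return "'local_slot_num' must be <= 'slot_num'";
  return nullptr;
}

// The reproducer's value (local_slot_num = 2, dir_mask = 1, slot_num 8) passes.
static_assert(checkLocalSlotNum(2, 1, 8) == nullptr, "reproducer accepted");
```

In the actual proposal, the per-direction cap belongs to verifyFrontendInitCommon and the ≤ slot_num check to InitializeL2G2LPipeOp::verify. They are combined here only for illustration.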

A draft PR with these changes has been opened against this repository and is linked below.

Open question — should the default change too?

The C++ TPipe template defaults to LocalSlotNum=2 (see include/pto/npu/a2a3/TPush.hpp):

template <uint8_t FlagID, uint8_t DirType, uint32_t SlotSize, uint32_t SlotNum, uint32_t LocalSlotNum = 2, ...>

ptoas currently defaults to SlotNum (=8) on both the legacy and gm_slot_tensor paths when the attribute is absent. Aligning ptoas's default with the C++ template default would silently fix the flash-attention case without requiring kernels to write local_slot_num=2, but it would change existing lit-test expectations under test/lit/pto/tpush_tpop_globaltensor_*.pto, test/lit/pto/tpush_tpop_emitc.pto, etc.

I'm filing the verifier-relaxation PR first because it is non-breaking and unblocks the downstream kernel. The default-change discussion can happen separately.

Related
