Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions test/samples/FlashAttention/compile_and_run/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# FlashAttention compile and benchmark

This directory contains two PTO FlashAttention variants:

- `fa_140tflops.pto`
- `fa_patched_s1_256_q3072_s0_8192.pto`

## Requirements

- Run inside the configured Ascend/CANN container environment.
- `ptoas` and `bisheng` must already be available in `PATH`.
- `/sources/pto-isa/include` must exist.
- Python benchmark requires `torch_npu==2.9.0`.

## Compile

From this directory, run:

```bash
bash compile_flashattention.sh
```

This builds:

- `/tmp/fa_140tflops.so`
- `/tmp/compiler_team_fa.so`

## Benchmark

After compiling, run:

```bash
python3 benchmark_flashattention.py
```

The benchmark compares both PTO kernels against `torch_npu.npu_fused_infer_attention_score`, checks correctness against both fp32 reference attention and torch_npu output, and reports latency, TFLOP/s, and speedup.
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

import ctypes
import math

import torch
import torch_npu # noqa: F401

# Kernel table: (display name, shared-object path, f32 scratch elements per
# core, whether call_kernel additionally takes (s0, s1) sequence lengths).
# Paths must match the outputs of compile_flashattention.sh.
KERNELS = [
    ("fa_140tflops", "/tmp/fa_140tflops.so", 524288, True),
    ("patched", "/tmp/compiler_team_fa.so", 229376, False),
]
DEVICE = "npu:0"
WARMUP_ITERS = 10  # untimed iterations before each measurement
BENCH_ITERS = 100  # timed iterations averaged into one latency figure
NUM_CUBE_CORES = 24  # cube cores on the target die — TODO confirm for this arch
RTOL = 1e-3  # correctness tolerances for torch.testing.assert_close
ATOL = 1e-3

# Problem size: single-head attention, Q is (3072, 128), K/V are (8192, 128).
Q_ROWS = 3072
HEAD = 128
S1_TOTAL = 8192
NUM_Q_BLOCKS = Q_ROWS // 32  # presumably the kernel tiles Q in 32-row blocks — verify


def load_lib(lib_path, pass_shape):
    """Load a compiled PTO kernel .so and declare its C calling convention.

    Args:
        lib_path: Filesystem path to a shared object exposing ``call_kernel``.
        pass_shape: If True, the kernel takes two trailing int64 sequence
            lengths in addition to the pointer arguments.

    Returns:
        The ``ctypes.CDLL`` handle with ``call_kernel`` argtypes/restype set.
    """
    lib = ctypes.CDLL(lib_path)
    # blockDim (u32), stream, then five device pointers: gm, q, k, v, o.
    signature = [ctypes.c_uint32] + [ctypes.c_void_p] * 6
    if pass_shape:
        signature.extend([ctypes.c_int64, ctypes.c_int64])
    lib.call_kernel.argtypes = signature
    lib.call_kernel.restype = None
    return lib


def ptr(t):
    """Wrap a tensor's raw data address as a ctypes void pointer."""
    address = t.data_ptr()
    return ctypes.c_void_p(address)


def fused_attention(q_bsh, k_bsh, v_bsh):
    """Run torch_npu's fused inference attention on BSH-layout tensors.

    Inputs are (B, S, H) with a single head; the scale is the usual
    1/sqrt(head_dim). NOTE(review): next_tokens=65535 presumably allows
    attention over the full context (no causal band) — confirm against the
    operator documentation.
    """
    head_dim = q_bsh.shape[-1]
    attn_out, _ = torch_npu.npu_fused_infer_attention_score(
        q_bsh,
        k_bsh,
        v_bsh,
        num_heads=1,
        input_layout="BSH",
        scale=1.0 / math.sqrt(head_dim),
        next_tokens=65535,
    )
    return attn_out


def fa_reference(q, k, v):
    """Fp32 reference attention: softmax(q @ k.T * scale) @ v.

    q is (M, D); k and v are (N, D). Everything is upcast to float32 so the
    result serves as the numerical ground truth for the fp16 kernels.
    """
    scale = 1.0 / math.sqrt(q.shape[1])
    logits = torch.matmul(q.float(), k.float().t()) * scale
    weights = torch.softmax(logits, dim=-1)
    return torch.matmul(weights, v.float())


def run_pto_kernel(lib, pass_shape, block_dim, gm, q, k, v, o):
    """Launch a PTO kernel on the current NPU stream.

    Passes the stream handle, the scratch buffer and the Q/K/V/O tensors as
    raw device pointers; appends (s0, s1) row counts when the kernel's ABI
    expects explicit shapes (pass_shape=True).
    """
    stream_handle = torch.npu.current_stream()._as_parameter_
    call_args = [block_dim, stream_handle]
    call_args.extend(ptr(t) for t in (gm, q, k, v, o))
    if pass_shape:
        call_args.extend([q.shape[0], k.shape[0]])
    lib.call_kernel(*call_args)


def check_close(out_pto, out_fp32, out_torch_npu):
    """Validate a PTO kernel output against both references.

    Returns:
        A tuple ("PASSED" or "FAILED", max abs error vs the fp32 reference,
        max abs error vs the torch_npu output). The status is PASSED only
        when the output is close to both references within RTOL/ATOL.
    """
    err_vs_fp32 = (out_pto - out_fp32).abs().max().item()
    err_vs_npu = (out_pto - out_torch_npu).abs().max().item()
    status = "PASSED"
    try:
        torch.testing.assert_close(out_pto, out_fp32, rtol=RTOL, atol=ATOL)
        torch.testing.assert_close(out_pto, out_torch_npu, rtol=RTOL, atol=ATOL)
    except AssertionError:
        status = "FAILED"
    return status, err_vs_fp32, err_vs_npu


def bench(fn):
    """Return the mean latency of fn in microseconds over BENCH_ITERS runs.

    Warms up for WARMUP_ITERS calls, synchronizes, then times a batch of
    BENCH_ITERS calls with NPU events (elapsed_time is in ms, hence the
    *1000 conversion to us).
    """
    for _ in range(WARMUP_ITERS):
        fn()
    torch.npu.synchronize()

    tic = torch.npu.Event(enable_timing=True)
    toc = torch.npu.Event(enable_timing=True)
    tic.record()
    for _ in range(BENCH_ITERS):
        fn()
    toc.record()
    torch.npu.synchronize()
    elapsed_ms = tic.elapsed_time(toc)
    return elapsed_ms * 1000.0 / BENCH_ITERS


def main():
    """Benchmark both PTO FlashAttention kernels against torch_npu's fused op.

    Builds fp16 Q/K/V on the NPU, establishes a torch_npu baseline (output
    and latency), then for each PTO kernel checks numerical closeness against
    both an fp32 reference and the torch_npu output, and prints latency,
    TFLOP/s, and speedup relative to the baseline.
    """
    device = torch.device(DEVICE)
    # One block per 32-row Q tile, capped at the available cube cores.
    block_dim = min(NUM_Q_BLOCKS, NUM_CUBE_CORES)
    # Two matmuls (Q@K^T and P@V) at 2*M*N*K flops each -> 4*Q*H*S1 total.
    flops = 4 * Q_ROWS * HEAD * S1_TOTAL

    torch.manual_seed(0)  # deterministic inputs across runs
    q = torch.randn((Q_ROWS, HEAD), dtype=torch.float16, device=device)
    k = torch.randn((S1_TOTAL, HEAD), dtype=torch.float16, device=device)
    v = torch.randn((S1_TOTAL, HEAD), dtype=torch.float16, device=device)
    # torch_npu's fused op expects a batch dimension (BSH layout, B=1).
    q_bsh = q.unsqueeze(0)
    k_bsh = k.unsqueeze(0)
    v_bsh = v.unsqueeze(0)

    def run_torch_npu():
        # Benchmark closure: the result is discarded, only latency matters.
        fused_attention(q_bsh, k_bsh, v_bsh)

    # Reference outputs (fp32, on CPU) used for correctness checks below.
    out_torch_npu = fused_attention(q_bsh, k_bsh, v_bsh).squeeze(0).float().cpu()
    out_fp32 = fa_reference(q, k, v).float().cpu()
    torch.npu.synchronize()

    torch_npu_us = bench(run_torch_npu)
    torch_npu_tflops = flops / (torch_npu_us * 1e-6) / 1e12

    print(
        f"PTO FA variants vs torch_npu fused attention: Q={Q_ROWS} S1={S1_TOTAL} H={HEAD} "
        f"blockDim={block_dim}"
    )
    print(f" torch_npu: {torch_npu_us:8.2f} us {torch_npu_tflops:7.3f} TFLOP/s")

    for name, lib_path, gm_elems_per_block, pass_shape in KERNELS:
        lib = load_lib(lib_path, pass_shape)
        # Per-core fp32 scratch buffer; size comes from the KERNELS table.
        gm = torch.zeros(
            (gm_elems_per_block * block_dim,), dtype=torch.float32, device=device
        )
        o = torch.zeros((Q_ROWS, HEAD), dtype=torch.float32, device=device)

        def run_pto():
            run_pto_kernel(lib, pass_shape, block_dim, gm, q, k, v, o)

        # Correctness check against torch_npu fused attention.
        gm.zero_()
        o.zero_()
        run_pto()
        torch.npu.synchronize()
        out_pto = o.float().cpu()
        correctness, max_err_fp32, max_err_torch_npu = check_close(
            out_pto, out_fp32, out_torch_npu
        )

        pto_us = bench(run_pto)
        pto_tflops = flops / (pto_us * 1e-6) / 1e12
        print(
            f" {name:12s}: {pto_us:8.2f} us {pto_tflops:7.3f} TFLOP/s "
            f"speedup={torch_npu_us / pto_us:.2f}x {correctness} "
            f"max_err(fp32={max_err_fp32:.3e}, torch_npu={max_err_torch_npu:.3e})"
        )


if __name__ == "__main__":
    main()
29 changes: 29 additions & 0 deletions test/samples/FlashAttention/compile_and_run/caller.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

#include <cstdint>

#ifndef KERNEL_CPP
#error "KERNEL_CPP must be defined at compile time."
#endif

extern "C" int rtGetC2cCtrlAddr(uint64_t *ctrlAddr, uint32_t *ctrlLen);

#include KERNEL_CPP

extern "C" void call_kernel(uint32_t blockDim, void *stream, uint8_t *gmSlotBuffer, uint8_t *q, uint8_t *k, uint8_t *v,
uint8_t *o)
{
void *fftsAddr = nullptr;
uint32_t fftsLen = 0;
(void)rtGetC2cCtrlAddr(reinterpret_cast<uint64_t *>(&fftsAddr), &fftsLen);
(void)fftsLen;
Comment on lines +24 to +25
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The return value of rtGetC2cCtrlAddr is ignored. If this runtime call fails, fftsAddr will remain nullptr, which can lead to a crash or undefined behavior when passed to the kernel. It is safer to check the return code and handle the failure.

    if (rtGetC2cCtrlAddr(reinterpret_cast<uint64_t *>(&fftsAddr), &fftsLen) != 0) {
        return;
    }


call_both<<<blockDim, nullptr, stream>>>((__gm__ int64_t *)fftsAddr, (__gm__ float *)gmSlotBuffer,
(__gm__ half *)q, (__gm__ half *)k, (__gm__ half *)v, (__gm__ float *)o);
}
30 changes: 30 additions & 0 deletions test/samples/FlashAttention/compile_and_run/caller_140tflops.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

#include <cstdint>

#ifndef KERNEL_CPP
#error "KERNEL_CPP must be defined at compile time."
#endif

extern "C" int rtGetC2cCtrlAddr(uint64_t *ctrlAddr, uint32_t *ctrlLen);

#include KERNEL_CPP

extern "C" void call_kernel(uint32_t blockDim, void *stream, uint8_t *gmSlotBuffer, uint8_t *q, uint8_t *k, uint8_t *v,
uint8_t *o, int64_t s0, int64_t s1)
{
void *fftsAddr = nullptr;
uint32_t fftsLen = 0;
(void)rtGetC2cCtrlAddr(reinterpret_cast<uint64_t *>(&fftsAddr), &fftsLen);
(void)fftsLen;
Comment on lines +24 to +25
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The return value of rtGetC2cCtrlAddr is ignored. If this runtime call fails, fftsAddr will remain nullptr, which can lead to a crash or undefined behavior when passed to the kernel. It is safer to check the return code and handle the failure.

    if (rtGetC2cCtrlAddr(reinterpret_cast<uint64_t *>(&fftsAddr), &fftsLen) != 0) {
        return;
    }


call_both<<<blockDim, nullptr, stream>>>((__gm__ int64_t *)fftsAddr, (__gm__ float *)gmSlotBuffer,
(__gm__ half *)gmSlotBuffer, (__gm__ half *)q, (__gm__ half *)k,
(__gm__ half *)v, (__gm__ float *)o, s0, s1);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#!/usr/bin/env bash
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

set -euo pipefail

cd "$(dirname "${BASH_SOURCE[0]}")"

# Paths are overridable for portability; defaults match the expected
# Ascend/CANN container layout used by benchmark_flashattention.py.
PTO_ISA_INCLUDE="${PTO_ISA_INCLUDE:-/sources/pto-isa/include}"
OUT_DIR="${OUT_DIR:-/tmp}"

# compile_kernel <generated_kernel_cpp> <caller_cpp> <output_so>
# Compiles one PTO-generated kernel plus its host caller into a shared object.
compile_kernel() {
    local kernel_cpp="$1" caller_cpp="$2" out_so="$3"
    bisheng \
        -I"${PTO_ISA_INCLUDE}" \
        -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \
        -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \
        -xcce -Xhost-start -Xhost-end \
        -mllvm -cce-aicore-stack-size=0x8000 \
        -mllvm -cce-aicore-function-stack-size=0x8000 \
        -mllvm -cce-aicore-record-overflow=true \
        -mllvm -cce-aicore-addr-transform \
        -mllvm -cce-aicore-dcci-insert-for-scalar=false \
        -cce-enable-mix \
        --npu-arch=dav-2201 -DMEMORY_BASE \
        -std=gnu++17 \
        -DKERNEL_CPP="\"${kernel_cpp}\"" \
        "${caller_cpp}" \
        -o "${out_so}"
}

# Patched variant (level3 pipeline).
ptoas --pto-arch=a3 --pto-level=level3 --enable-insert-sync \
    fa_patched_s1_256_q3072_s0_8192.pto \
    >"${OUT_DIR}/compiler_team_fa.cpp"
compile_kernel "${OUT_DIR}/compiler_team_fa.cpp" caller.cpp "${OUT_DIR}/compiler_team_fa.so"

# 140 TFLOP/s variant.
ptoas --pto-arch=a3 --enable-insert-sync \
    fa_140tflops.pto \
    >"${OUT_DIR}/fa_140tflops.cpp"
compile_kernel "${OUT_DIR}/fa_140tflops.cpp" caller_140tflops.cpp "${OUT_DIR}/fa_140tflops.so"
Loading
Loading