-
Notifications
You must be signed in to change notification settings - Fork 49
Benchmark flash attention DSL 140tflops and ptoas smoke test version #635
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| # FlashAttention compile and benchmark | ||
|
|
||
| This directory contains two PTO FlashAttention variants: | ||
|
|
||
| - `fa_140tflops.pto` | ||
| - `fa_patched_s1_256_q3072_s0_8192.pto` | ||
|
|
||
| ## Requirements | ||
|
|
||
| - Run inside the configured Ascend/CANN container environment. | ||
| - `ptoas` and `bisheng` must already be available in `PATH`. | ||
| - `/sources/pto-isa/include` must exist. | ||
| - Python benchmark requires `torch_npu==2.9.0`. | ||
|
|
||
| ## Compile | ||
|
|
||
| From this directory, run: | ||
|
|
||
| ```bash | ||
| bash compile_flashattention.sh | ||
| ``` | ||
|
|
||
| This builds: | ||
|
|
||
| - `/tmp/fa_140tflops.so` | ||
| - `/tmp/compiler_team_fa.so` | ||
|
|
||
| ## Benchmark | ||
|
|
||
| After compiling, run: | ||
|
|
||
| ```bash | ||
| python3 benchmark_flashattention.py | ||
| ``` | ||
|
|
||
| The benchmark compares both PTO kernels against `torch_npu.npu_fused_infer_attention_score`, checks correctness against both fp32 reference attention and torch_npu output, and reports latency, TFLOP/s, and speedup. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,167 @@ | ||
| # Copyright (c) 2026 Huawei Technologies Co., Ltd. | ||
| # This program is free software, you can redistribute it and/or modify it under the terms and conditions of | ||
| # CANN Open Software License Agreement Version 2.0 (the "License"). | ||
| # Please refer to the License for details. You may not use this file except in compliance with the License. | ||
| # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, | ||
| # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. | ||
| # See LICENSE in the root of the software repository for the full text of the License. | ||
|
|
||
| import ctypes | ||
| import math | ||
|
|
||
| import torch | ||
| import torch_npu # noqa: F401 | ||
|
|
||
# Kernel registry: (display name, compiled shared-library path,
# float32 scratch elements allocated per launched block,
# whether call_kernel additionally takes (s0, s1) sequence-length arguments).
KERNELS = [
    ("fa_140tflops", "/tmp/fa_140tflops.so", 524288, True),
    ("patched", "/tmp/compiler_team_fa.so", 229376, False),
]
DEVICE = "npu:0"  # target NPU device string
WARMUP_ITERS = 10  # un-timed warm-up launches before each measurement
BENCH_ITERS = 100  # timed launches averaged per measurement
NUM_CUBE_CORES = 24  # cube cores on the target SoC -- TODO confirm for this part
RTOL = 1e-3  # relative tolerance for correctness checks
ATOL = 1e-3  # absolute tolerance for correctness checks

# Problem shape: single-head attention with Q_ROWS queries of HEAD dims
# attending over S1_TOTAL keys/values.
Q_ROWS = 3072
HEAD = 128
S1_TOTAL = 8192
# One work unit per 32 query rows -- presumably the kernel's tile height; verify.
NUM_Q_BLOCKS = Q_ROWS // 32
|
|
||
|
|
||
def load_lib(lib_path, pass_shape):
    """Load a compiled kernel shared library and declare its ctypes signature.

    The library must export ``call_kernel(blockDim, stream, gm, q, k, v, o
    [, s0, s1])``; ``pass_shape`` selects the variant with the two trailing
    int64 sequence-length arguments.
    """
    lib = ctypes.CDLL(lib_path)
    # blockDim followed by the stream handle and five device pointers.
    signature = [ctypes.c_uint32] + [ctypes.c_void_p] * 6
    if pass_shape:
        signature.extend([ctypes.c_int64, ctypes.c_int64])
    lib.call_kernel.argtypes = signature
    lib.call_kernel.restype = None
    return lib
|
|
||
|
|
||
def ptr(t):
    """Return tensor ``t``'s data address wrapped as a ctypes ``c_void_p``."""
    address = t.data_ptr()
    return ctypes.c_void_p(address)
|
|
||
|
|
||
def fused_attention(q_bsh, k_bsh, v_bsh):
    """Run torch_npu's fused infer-attention on BSH-layout tensors.

    Returns only the attention output (the second return value is dropped).
    """
    head_dim = q_bsh.shape[-1]
    # next_tokens=65535 -- presumably allows attending to all positions
    # (non-causal); confirm against the torch_npu operator documentation.
    out, _ = torch_npu.npu_fused_infer_attention_score(
        q_bsh,
        k_bsh,
        v_bsh,
        num_heads=1,
        input_layout="BSH",
        scale=1.0 / math.sqrt(head_dim),
        next_tokens=65535,
    )
    return out
|
|
||
|
|
||
def fa_reference(q, k, v):
    """Float32 reference attention: ``softmax(q @ k.T / sqrt(d)) @ v``.

    Args:
        q: query tensor of shape (s0, d).
        k: key tensor of shape (s1, d).
        v: value tensor of shape (s1, d).

    Returns:
        Float32 attention output of shape (s0, d).
    """
    # Scale by the head (last) dimension, consistent with fused_attention's
    # use of shape[-1]; identical for 2-D inputs but robust to batched shapes.
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = q.float() @ k.float().T * scale
    return torch.softmax(scores, dim=-1) @ v.float()
|
|
||
|
|
||
def run_pto_kernel(lib, pass_shape, block_dim, gm, q, k, v, o):
    """Launch a PTO kernel on the current NPU stream through its C entry point."""
    stream = torch.npu.current_stream()._as_parameter_
    call_args = [block_dim, stream, ptr(gm), ptr(q), ptr(k), ptr(v), ptr(o)]
    if pass_shape:
        # Variant that takes the query and key sequence lengths explicitly.
        call_args.extend([q.shape[0], k.shape[0]])
    lib.call_kernel(*call_args)
|
|
||
|
|
||
def check_close(out_pto, out_fp32, out_torch_npu):
    """Compare PTO output to both references.

    Returns a ("PASSED"/"FAILED", max_abs_err_vs_fp32, max_abs_err_vs_npu)
    tuple; the status is "PASSED" only if both comparisons are within
    RTOL/ATOL.
    """
    err_fp32 = (out_pto - out_fp32).abs().max().item()
    err_npu = (out_pto - out_torch_npu).abs().max().item()
    status = "PASSED"
    try:
        torch.testing.assert_close(out_pto, out_fp32, rtol=RTOL, atol=ATOL)
        torch.testing.assert_close(out_pto, out_torch_npu, rtol=RTOL, atol=ATOL)
    except AssertionError:
        status = "FAILED"
    return status, err_fp32, err_npu
|
|
||
|
|
||
def bench(fn):
    """Return the average latency of ``fn()`` in microseconds.

    Runs WARMUP_ITERS un-timed iterations, then times BENCH_ITERS iterations
    with NPU events and averages.
    """
    for _ in range(WARMUP_ITERS):
        fn()
    torch.npu.synchronize()

    begin_evt = torch.npu.Event(enable_timing=True)
    end_evt = torch.npu.Event(enable_timing=True)
    begin_evt.record()
    for _ in range(BENCH_ITERS):
        fn()
    end_evt.record()
    torch.npu.synchronize()
    # elapsed_time reports milliseconds; convert to microseconds per iteration.
    return begin_evt.elapsed_time(end_evt) * 1000.0 / BENCH_ITERS
|
|
||
|
|
||
def main():
    """Benchmark both PTO FlashAttention kernels against torch_npu and report
    latency, TFLOP/s, speedup, and correctness for each."""
    device = torch.device(DEVICE)
    # One launched block per query tile, capped at the available cube cores.
    block_dim = min(NUM_Q_BLOCKS, NUM_CUBE_CORES)
    # 4*S0*D*S1: two matmuls (QK^T and PV) at 2 FLOPs per multiply-accumulate.
    flops = 4 * Q_ROWS * HEAD * S1_TOTAL

    torch.manual_seed(0)
    q = torch.randn((Q_ROWS, HEAD), dtype=torch.float16, device=device)
    k = torch.randn((S1_TOTAL, HEAD), dtype=torch.float16, device=device)
    v = torch.randn((S1_TOTAL, HEAD), dtype=torch.float16, device=device)
    # torch_npu's fused op expects a leading batch dimension (BSH layout, B=1).
    q_bsh = q.unsqueeze(0)
    k_bsh = k.unsqueeze(0)
    v_bsh = v.unsqueeze(0)

    def run_torch_npu():
        fused_attention(q_bsh, k_bsh, v_bsh)

    # Reference outputs for correctness checks, both as fp32 CPU tensors.
    out_torch_npu = fused_attention(q_bsh, k_bsh, v_bsh).squeeze(0).float().cpu()
    out_fp32 = fa_reference(q, k, v).float().cpu()
    torch.npu.synchronize()

    torch_npu_us = bench(run_torch_npu)
    torch_npu_tflops = flops / (torch_npu_us * 1e-6) / 1e12

    print(
        f"PTO FA variants vs torch_npu fused attention: Q={Q_ROWS} S1={S1_TOTAL} H={HEAD} "
        f"blockDim={block_dim}"
    )
    print(f" torch_npu: {torch_npu_us:8.2f} us {torch_npu_tflops:7.3f} TFLOP/s")

    for name, lib_path, gm_elems_per_block, pass_shape in KERNELS:
        lib = load_lib(lib_path, pass_shape)
        # Per-block float32 scratch buffer required by the PTO kernel.
        gm = torch.zeros(
            (gm_elems_per_block * block_dim,), dtype=torch.float32, device=device
        )
        o = torch.zeros((Q_ROWS, HEAD), dtype=torch.float32, device=device)

        def run_pto():
            run_pto_kernel(lib, pass_shape, block_dim, gm, q, k, v, o)

        # Correctness check against torch_npu fused attention.
        # Zero scratch and output first so stale data cannot mask errors.
        gm.zero_()
        o.zero_()
        run_pto()
        torch.npu.synchronize()
        out_pto = o.float().cpu()
        correctness, max_err_fp32, max_err_torch_npu = check_close(
            out_pto, out_fp32, out_torch_npu
        )

        pto_us = bench(run_pto)
        pto_tflops = flops / (pto_us * 1e-6) / 1e12
        print(
            f" {name:12s}: {pto_us:8.2f} us {pto_tflops:7.3f} TFLOP/s "
            f"speedup={torch_npu_us / pto_us:.2f}x {correctness} "
            f"max_err(fp32={max_err_fp32:.3e}, torch_npu={max_err_torch_npu:.3e})"
        )
|
|
||
|
|
||
# Script entry point: run the full benchmark when executed directly.
if __name__ == "__main__":
    main()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| // Copyright (c) 2026 Huawei Technologies Co., Ltd. | ||
| // This program is free software, you can redistribute it and/or modify it under the terms and conditions of | ||
| // CANN Open Software License Agreement Version 2.0 (the "License"). | ||
| // Please refer to the License for details. You may not use this file except in compliance with the License. | ||
| // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, | ||
| // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. | ||
| // See LICENSE in the root of the software repository for the full text of the License. | ||
|
|
||
| #include <cstdint> | ||
|
|
||
| #ifndef KERNEL_CPP | ||
| #error "KERNEL_CPP must be defined at compile time." | ||
| #endif | ||
|
|
||
| extern "C" int rtGetC2cCtrlAddr(uint64_t *ctrlAddr, uint32_t *ctrlLen); | ||
|
|
||
| #include KERNEL_CPP | ||
|
|
||
| extern "C" void call_kernel(uint32_t blockDim, void *stream, uint8_t *gmSlotBuffer, uint8_t *q, uint8_t *k, uint8_t *v, | ||
| uint8_t *o) | ||
| { | ||
| void *fftsAddr = nullptr; | ||
| uint32_t fftsLen = 0; | ||
| (void)rtGetC2cCtrlAddr(reinterpret_cast<uint64_t *>(&fftsAddr), &fftsLen); | ||
| (void)fftsLen; | ||
|
|
||
| call_both<<<blockDim, nullptr, stream>>>((__gm__ int64_t *)fftsAddr, (__gm__ float *)gmSlotBuffer, | ||
| (__gm__ half *)q, (__gm__ half *)k, (__gm__ half *)v, (__gm__ float *)o); | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| // Copyright (c) 2026 Huawei Technologies Co., Ltd. | ||
| // This program is free software, you can redistribute it and/or modify it under the terms and conditions of | ||
| // CANN Open Software License Agreement Version 2.0 (the "License"). | ||
| // Please refer to the License for details. You may not use this file except in compliance with the License. | ||
| // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, | ||
| // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. | ||
| // See LICENSE in the root of the software repository for the full text of the License. | ||
|
|
||
| #include <cstdint> | ||
|
|
||
| #ifndef KERNEL_CPP | ||
| #error "KERNEL_CPP must be defined at compile time." | ||
| #endif | ||
|
|
||
| extern "C" int rtGetC2cCtrlAddr(uint64_t *ctrlAddr, uint32_t *ctrlLen); | ||
|
|
||
| #include KERNEL_CPP | ||
|
|
||
| extern "C" void call_kernel(uint32_t blockDim, void *stream, uint8_t *gmSlotBuffer, uint8_t *q, uint8_t *k, uint8_t *v, | ||
| uint8_t *o, int64_t s0, int64_t s1) | ||
| { | ||
| void *fftsAddr = nullptr; | ||
| uint32_t fftsLen = 0; | ||
| (void)rtGetC2cCtrlAddr(reinterpret_cast<uint64_t *>(&fftsAddr), &fftsLen); | ||
| (void)fftsLen; | ||
|
Comment on lines
+24
to
+25
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The return value of if (rtGetC2cCtrlAddr(reinterpret_cast<uint64_t *>(&fftsAddr), &fftsLen) != 0) {
return;
} |
||
|
|
||
| call_both<<<blockDim, nullptr, stream>>>((__gm__ int64_t *)fftsAddr, (__gm__ float *)gmSlotBuffer, | ||
| (__gm__ half *)gmSlotBuffer, (__gm__ half *)q, (__gm__ half *)k, | ||
| (__gm__ half *)v, (__gm__ float *)o, s0, s1); | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
#!/usr/bin/env bash
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

# Compiles both PTO FlashAttention variants into host-callable shared
# libraries under /tmp. Requires ptoas and bisheng in PATH and the
# pto-isa headers under /sources/pto-isa/include.

set -euo pipefail

cd "$(dirname "${BASH_SOURCE[0]}")"

# Build one host-callable .so: wrap the ptoas-generated kernel C++ in the
# given caller stub and compile with bisheng. The invocation was previously
# duplicated verbatim for both kernels; factored out to keep flags in sync.
# usage: compile_caller <generated-kernel-cpp> <caller-cpp> <output-so>
compile_caller() {
    local kernel_cpp="$1" caller_cpp="$2" out_so="$3"
    bisheng \
        -I/sources/pto-isa/include \
        -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \
        -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \
        -xcce -Xhost-start -Xhost-end \
        -mllvm -cce-aicore-stack-size=0x8000 \
        -mllvm -cce-aicore-function-stack-size=0x8000 \
        -mllvm -cce-aicore-record-overflow=true \
        -mllvm -cce-aicore-addr-transform \
        -mllvm -cce-aicore-dcci-insert-for-scalar=false \
        -cce-enable-mix \
        --npu-arch=dav-2201 -DMEMORY_BASE \
        -std=gnu++17 \
        -DKERNEL_CPP="\"${kernel_cpp}\"" \
        "${caller_cpp}" \
        -o "${out_so}"
}

# Patched variant: lowered at level3 (the 140tflops variant deliberately
# uses the default level -- do not add --pto-level here without rechecking).
ptoas --pto-arch=a3 --pto-level=level3 --enable-insert-sync \
    fa_patched_s1_256_q3072_s0_8192.pto \
    >/tmp/compiler_team_fa.cpp

compile_caller /tmp/compiler_team_fa.cpp caller.cpp /tmp/compiler_team_fa.so

ptoas --pto-arch=a3 --enable-insert-sync \
    fa_140tflops.pto \
    >/tmp/fa_140tflops.cpp

compile_caller /tmp/fa_140tflops.cpp caller_140tflops.cpp /tmp/fa_140tflops.so
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The return value of
rtGetC2cCtrlAddris ignored. If this runtime call fails,fftsAddrwill remainnullptr, which can lead to a crash or undefined behavior when passed to the kernel. It is safer to check the return code and handle the failure.