Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ MLASCALL
ArmKleidiAI::MlasDynamicQGemmBatch(
const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape,
const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams,
const size_t BatchSize,
const size_t BatchN,
MLAS_THREADPOOL* ThreadPool
) {

Expand All @@ -112,7 +112,7 @@ ArmKleidiAI::MlasDynamicQGemmBatch(
size_t m_step = qgemm_gemm.ukernel.get_m_step();
size_t n_step = qgemm_gemm.ukernel.get_n_step();

if (BatchSize == 0 || Shape.M == 0 || Shape.N == 0 || Shape.K == 0) {
if (BatchN == 0 || Shape.M == 0 || Shape.N == 0 || Shape.K == 0) {
return;
}

Expand All @@ -123,7 +123,7 @@ ArmKleidiAI::MlasDynamicQGemmBatch(
MLAS_THROW_EX(std::runtime_error, "Dynamic QGEMM requires valid DataParams.");
}

for (size_t batch_idx = 0; batch_idx < BatchSize; ++batch_idx) {
for (size_t batch_idx = 0; batch_idx < BatchN; ++batch_idx) {
const auto& params = DataParams[batch_idx];

if (params.A == nullptr) {
Expand Down Expand Up @@ -151,24 +151,24 @@ ArmKleidiAI::MlasDynamicQGemmBatch(
const size_t LhsPackedStride = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr);
std::byte* LhsPackedData = nullptr;

if (g_kai_tls_qgemm.lhs_packed.capacity() < LhsPackedStride * BatchSize) {
if (g_kai_tls_qgemm.lhs_packed.capacity() < LhsPackedStride * BatchN) {

g_kai_tls_qgemm.lhs_packed.reserve(LhsPackedStride * BatchSize);
g_kai_tls_qgemm.lhs_packed.reserve(LhsPackedStride * BatchN);
}
g_kai_tls_qgemm.lhs_packed.resize(LhsPackedStride * BatchSize);
g_kai_tls_qgemm.lhs_packed.resize(LhsPackedStride * BatchN);
LhsPackedData = g_kai_tls_qgemm.lhs_packed.data();

// Per-batch table of LHS base pointers.
if (g_kai_tls_qgemm.lhs_base_table.capacity() < BatchSize) {
if (g_kai_tls_qgemm.lhs_base_table.capacity() < BatchN) {

g_kai_tls_qgemm.lhs_base_table.reserve(BatchSize);
g_kai_tls_qgemm.lhs_base_table.reserve(BatchN);
}
g_kai_tls_qgemm.lhs_base_table.resize(BatchSize);
g_kai_tls_qgemm.lhs_base_table.resize(BatchN);
// Capture the shared batch table pointer so worker threads use the same backing storage.
const std::byte** tls_lhs_base = g_kai_tls_qgemm.lhs_base_table.data();
// B batches require no packing.
// We have already decided the matmul variant we are using before having values for M, N, and K.
MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) {
MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t batch_idx) {

std::byte* lhs = nullptr;
if (DataParams[batch_idx].Workspace && DataParams[batch_idx].WorkspaceSize >= LhsPackedStride) {
Expand All @@ -184,7 +184,7 @@ ArmKleidiAI::MlasDynamicQGemmBatch(

// Tile iteration dimensions.
std::array<size_t, 3> dim;
dim[0] = BatchSize; // B
dim[0] = BatchN; // B
dim[1] = MlasDivRoundup(Shape.M, m_step); // M
dim[2] = MlasDivRoundup(Shape.N, n_step); // N

Expand Down
6 changes: 5 additions & 1 deletion onnxruntime/test/contrib_ops/attention_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -521,8 +521,12 @@ TEST(ContribOpAttentionTest, AttentionBatch1_Float16) {
3.154296875, 0.1082763671875, 4.25, 5.6484375,
3.970703125, 0.072998046875, 4.25, 5.6484375};

// WebGPU Attention does not support mask_index input.
constexpr bool disable_webgpu = true;
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads, true);
batch_size, sequence_length, hidden_size, number_of_heads, true /*use_float16*/,
false, false, 0, nullptr, nullptr, AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0, 0,
false /*disable_cpu*/, false /*disable_cuda*/, false /*disable_dml*/, disable_webgpu);
}

TEST(ContribOpAttentionTest, AttentionBatch2) {
Expand Down
41 changes: 22 additions & 19 deletions onnxruntime/test/mlas/unittest/test_conv2d.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

#include "test_util.h"

#if defined(MLAS_TARGET_AMD64)
#if defined(MLAS_TARGET_AMD64) || defined(USE_KLEIDIAI)
#include "core/mlas/lib/mlasi.h"
#endif

Expand Down Expand Up @@ -304,6 +304,25 @@ class MlasConv2DTest : public MlasTestBase {

MlasConv2DTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {}

void TestKleidiAiLhsPointerCacheByPadGrowth() {
#if defined(USE_KLEIDIAI)
if (GetMlasPlatform().MlasConvPrepareOverride == nullptr) {
return;
}

// KleidiAI Conv caches LHS indirection tables by the per-thread padding
// buffer address. Grow the padding buffer, then reuse the original shape so
// stale pad pointers in the old cache entry would corrupt the result.
for (int i = 0; i < 4; ++i) {
Test(1, 1, 64, 11, 11, 32, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1);
Test(1, 1, 320, 11, 11, 32, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1);
Test(1, 1, 64, 11, 11, 32, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1);
}
#else
return;
#endif
}

#if defined(MLAS_TARGET_AMD64)
void TestMobileClipAvx512DispatchSelection(size_t GroupCount,
size_t InputHeight,
Expand Down Expand Up @@ -705,24 +724,7 @@ class MlasConv2DTest : public MlasTestBase {
}
}

//
// Regression test: exercise a KleidiAI Conv2D path when KleidiAI is enabled.
// See https://github.com/microsoft/onnxruntime/issues/26669.
//
// The KleidiAI implementation uses an internal per-thread padding buffer for out-of-bounds pixels
// when constructing the LHS indirection table. Historically, if the buffer was too small for a later
// convolution (larger CI), resizing could invalidate cached indirection pointers and lead to
// non-deterministic corruption.
//
// This sequence forces pad-buffer growth by running a smaller-CI convolution followed by a larger-CI
// convolution (with padding to ensure pad pointers are used), then runs the smaller-CI convolution again.
// Repeat a few times to increase the likelihood of triggering a reallocation and verify the path.
//
for (int i = 0; i < 4; ++i) {
Test(1, 1, 64, 11, 11, 32, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1); // smaller CI
Test(1, 1, 320, 11, 11, 32, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1); // larger CI forces pad buffer growth
Test(1, 1, 64, 11, 11, 32, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1); // sanity: back to smaller CI after growth
}
TestKleidiAiLhsPointerCacheByPadGrowth();
}

void ExecuteShort(void) override {
Expand All @@ -735,5 +737,6 @@ class MlasConv2DTest : public MlasTestBase {
TestMobileClipBetaActivationRegression(128, 32, 32);
TestMobileClipBetaActivationRegression(256, 16, 16);
TestBatchedConv3DWorkingBufferUsesThreadTileSize();
TestKleidiAiLhsPointerCacheByPadGrowth();
}
};