From 923422fea2556b7c19f3b4ef7166173a612aa0d7 Mon Sep 17 00:00:00 2001 From: Cathal Lawlor Date: Tue, 12 May 2026 13:26:49 +0100 Subject: [PATCH 1/8] Add Arm64 fp16 MLAS and KleidiAI support Add native fp16 HGEMM and halfconv IMATMUL support for Arm64 MLAS/KleidiAI, including CPU fp16 MatMul/Gemm exposure, prepack handling, backend selector routing, and focused test coverage. Also include Apple Arm64 fp16 build enablement, insert-cast mixed-EP coverage, and fixes for fp16 routing edge cases such as zero-K halfgemm handling and backend-native packed-B contracts. Signed-off-by: Cathal Lawlor --- cmake/onnxruntime_mlas.cmake | 67 +- .../onnxruntime_session_options_config_keys.h | 18 + onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc | 3 +- onnxruntime/core/framework/session_state.cc | 11 +- onnxruntime/core/mlas/inc/mlas.h | 124 ++- onnxruntime/core/mlas/lib/halfconv.cpp | 158 +++ onnxruntime/core/mlas/lib/halfgemm.cpp | 142 ++- onnxruntime/core/mlas/lib/halfgemm.h | 67 +- .../core/mlas/lib/halfgemm_kernel_neon.cpp | 2 +- .../core/mlas/lib/kai_ukernel_interface.cpp | 34 + .../core/mlas/lib/kai_ukernel_interface.h | 18 +- .../mlas/lib/kleidiai/convolve_kleidiai.cpp | 6 +- .../mlas/lib/kleidiai/halfconv_kleidiai.cpp | 999 ++++++++++++++++++ .../mlas/lib/kleidiai/halfgemm_kleidiai.cpp | 278 +++++ .../core/mlas/lib/kleidiai/mlasi_kleidiai.h | 136 +++ .../core/mlas/lib/kleidiai/qgemm_kleidiai.cpp | 22 +- onnxruntime/core/mlas/lib/mlasi.h | 93 ++ onnxruntime/core/mlas/lib/platform.cpp | 9 +- .../core/optimizer/insert_cast_transformer.cc | 468 +++++++- .../core/optimizer/insert_cast_transformer.h | 48 +- .../providers/cpu/cpu_execution_provider.cc | 20 + .../core/providers/cpu/fp16/fp16_conv.cc | 133 ++- onnxruntime/core/providers/cpu/math/gemm.cc | 20 +- onnxruntime/core/providers/cpu/math/matmul.cc | 165 ++- onnxruntime/core/providers/cpu/math/matmul.h | 24 + onnxruntime/core/session/inference_session.cc | 26 +- onnxruntime/core/util/math_cpu.cc | 29 + .../framework/insert_cast_transformer_test.cc | 692 ++++++++++++ .../test/mlas/unittest/test_conv2d.cpp | 58 + .../test/mlas/unittest/test_halfgemm.cpp | 620 +++++++++++ .../test/mlas/unittest/test_halfgemm.h | 144 ++- onnxruntime/test/mlas/unittest/test_util.h | 2 + .../test/providers/cpu/math/gemm_test.cc | 46 + .../test/providers/cpu/math/matmul_test.cc | 63 ++ .../test/providers/cpu/nn/conv_fp16_test.cc | 314 +++++- 35 files changed, 4908 insertions(+), 151 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/halfconv.cpp create mode 100644 onnxruntime/core/mlas/lib/kleidiai/halfconv_kleidiai.cpp create mode 100644 onnxruntime/core/mlas/lib/kleidiai/halfgemm_kleidiai.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 71558bf4d201a..f15247e12bc11 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -23,6 +23,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/threading.cpp ${MLAS_SRC_DIR}/sgemm.cpp ${MLAS_SRC_DIR}/halfgemm.cpp + ${MLAS_SRC_DIR}/halfconv.cpp ${MLAS_SRC_DIR}/qgemm.cpp ${MLAS_SRC_DIR}/qdwconv.cpp ${MLAS_SRC_DIR}/convolve.cpp @@ -302,6 +303,8 @@ function(setup_kleidiai) target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/kai_ukernel_interface.cpp ${MLAS_SRC_DIR}/kleidiai/sgemm_kleidiai.cpp + ${MLAS_SRC_DIR}/kleidiai/halfgemm_kleidiai.cpp + ${MLAS_SRC_DIR}/kleidiai/halfconv_kleidiai.cpp ${MLAS_SRC_DIR}/kleidiai/sbgemm_kleidiai.cpp ${MLAS_SRC_DIR}/kleidiai/convolve_kleidiai.cpp ${MLAS_SRC_DIR}/kleidiai/qgemm_kleidiai.cpp @@ -356,7 +359,7 @@ function (setup_arm_neon_nchwc) ${MLAS_SRC_DIR}/aarch64/SconvNchwcKernelNeon.S ${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeon.S ${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeon.S - ) + ) endif() list(APPEND mlas_private_compile_definitions MLAS_USE_ARM_NEON_NCHWC) set(mlas_private_compile_definitions ${mlas_private_compile_definitions} PARENT_SCOPE) @@ -536,8 +539,7 @@ else() set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") - if (NOT APPLE) - set(mlas_platform_srcs + set(mlas_platform_srcs ${mlas_platform_srcs} ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S @@ -563,37 +565,36 @@ else() ${MLAS_SRC_DIR}/gelu_neon_fp16.h ${MLAS_SRC_DIR}/gelu_neon_fp16.cpp ) - if (onnxruntime_USE_ARM_NEON_NCHWC) - list(APPEND mlas_platform_srcs - ${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeonBf16.S - ${MLAS_SRC_DIR}/aarch64/SconvKernelNeonBf16.S - ${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeonBf16.S - ) - endif() - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16_8bit.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - if (onnxruntime_USE_ARM_NEON_NCHWC) - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - endif() - set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/gelu_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + if (onnxruntime_USE_ARM_NEON_NCHWC AND NOT APPLE) + list(APPEND mlas_platform_srcs + ${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeonBf16.S + ${MLAS_SRC_DIR}/aarch64/SconvKernelNeonBf16.S + ${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeonBf16.S + ) + endif() + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16_8bit.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + if (onnxruntime_USE_ARM_NEON_NCHWC AND NOT APPLE) + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") endif() + set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/gelu_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") if(ONNXRUNTIME_MLAS_MULTI_ARCH) onnxruntime_add_static_library(onnxruntime_mlas_arm64 ${mlas_platform_srcs}) diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 9d61165927d8c..2b1d4d8701563 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -462,6 +462,24 @@ static const char* const kOrtSessionOptionsEpContextModelExternalInitializersFil // - "1": Gemm FastMath mode is enabled. static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16"; +// Enables opt-in CPU EP float16 execution for supported ops. +// This allows InsertCastTransformer to preserve eligible fp16 CPU nodes instead of inserting fp16<->fp32 casts. +// The fp32 fallback heuristic below is still enabled by default and may keep selected shapes on the fp32 path when +// native fp16 is not expected to be profitable for the active MLAS backend. +// Option values: +// - "0": CPU fp16 is not explicitly enabled, and default cast/fallback behavior is used. [DEFAULT] +// - "1": Enable CPU fp16 preservation for the current opt-in scope. +static const char* const kOrtSessionOptionsEnableCpuFp16 = "session.enable_cpu_fp16"; + +// Controls the CPU fp16 -> fp32 fallback heuristic used when session.enable_cpu_fp16 is "1". +// The heuristic keeps native CPU fp16 only for cases expected to be profitable, such as supported GEMV-like shapes or +// constant-RHS MatMul when the active MLAS backend reports a native packed-B path. Other eligible nodes are assigned to +// CPU through fp32 casts when a valid CPU fp32 kernel exists. This is the recommended/default CPU fp16 mode. +// Option values: +// - "0": Do not use the heuristic; preserve native CPU fp16 more broadly for eligible CPU fp16 kernels. +// - "1": Use the fp32 fallback heuristic. [DEFAULT] +static const char* const kOrtSessionOptionsCpuFp16UseFp32FallbackHeuristic = "session.cpu_fp16_use_fp32_fallback_heuristic"; + // Use LUT (Lookup Table) based GEMM for quantized models when available. // Option values: // - "0": Do not use LUT based GEMM. [DEFAULT] diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc index c3194f1c57ba7..08fab6d2c110a 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc @@ -551,7 +551,8 @@ Status MoE::ComputeGEMM(const MLFloat16* A, const MLFloat16* B, MLFlo params.ldb = static_cast(N); } - MlasHalfGemmBatch(static_cast(M), static_cast(N), static_cast(K), 1, ¶ms, nullptr); + MlasHalfGemmBatch(static_cast(M), static_cast(N), static_cast(K), 1, ¶ms, nullptr, + &mlas_backend_kernel_selector_config_); return Status::OK(); } diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 6ef2319c1d3f4..69cfe16c20ff5 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -517,13 +517,10 @@ Status SessionState::PrepackConstantInitializedTensors( is_packed, &weights_to_be_filled_in)); - if (is_packed) { - // BUG CHECK: Ensure that the kernel has filled in the pre-packed weight - // to be cached if the weight was pre-packed - ORT_ENFORCE(weights_to_be_filled_in.buffers_.size() > 0, - "The kernel corresponding to the node ", node.Name(), - " doesn't have an implementation that can cache computed pre-packed weights"); - + // Some kernels pre-pack for their own session but intentionally do not + // expose those buffers for sharing, for example when the packed layout is + // backend-specific and the shared container has no layout metadata. + if (is_packed && !weights_to_be_filled_in.buffers_.empty()) { const auto& op_type = node.OpType(); // Sanity check diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index ddb9daa5e244b..1552c15db2e6d 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -92,15 +92,7 @@ Module Name: #if (!defined(_MSC_VER)) || (_MSC_VER >= 1930) #if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC) -#if !defined(__APPLE__) -// Had to temporary disable fp16 under APPLE ARM64, as compiling -// the source files require a hardware specific compilation flag. -// When building an universial binary for APPLE, this flag would -// cause trouble for x64 target. - #define MLAS_F16VEC_INTRINSICS_SUPPORTED - -#endif // #endif // ARM64 #endif // Visual Studio 16 or earlier does not support fp16 intrinsic @@ -906,6 +898,7 @@ struct MLAS_CONV_PARAMETERS { float Beta; MLAS_CONV_ALGORITHM Algorithm; ptrdiff_t ThreadCount; + bool InputOutputChannelsLast; const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr; const void* PackedFilter = nullptr; size_t PackedFilterGroupStride = 0; @@ -1898,11 +1891,20 @@ struct MLAS_HALF_GEMM_DATA_PARAMS { const MLAS_HALF_GEMM_POSTPROCESSOR* OutputProcessor = nullptr; bool AIsfp32 = false; /**< matrix A is fp32, needs to be casted into fp16*/ bool BIsfp32 = false; /**< matrix B is fp32, needs to be casted into fp16*/ + bool BIsPacked = false; /**< matrix B is pre-packed by MlasHalfGemmPackB/MlasHalfGemmConvertPackB */ + /** + * Matrix B uses a backend-specific direct-consumption packed layout. + * When true, B must be produced by MlasHalfGemmNativePackB, ldb must be 0, + * Bias must be nullptr, and OutputProcessor must be nullptr. + */ + bool BIsBackendNativePacked = false; }; /** * @brief Half precision Batched GEMM: C = A * B + Bias - * Either A or B can be fp32 or fp16 + * Either A or B can be fp32 or fp16. + * Backend-native packed B is a constrained direct-consumption layout + * and does not support runtime Bias or OutputProcessor. * * Note: We only support uniform batching, so shapes and types of the * input must be same across all parameter blocks. @@ -1913,6 +1915,7 @@ struct MLAS_HALF_GEMM_DATA_PARAMS { * @param[in] BatchN number of batches * @param[inout] DataParams An array (size BatchN) of parameter blocks * @param[in] ThreadPool + * @param[in] BackendKernelSelectorConfig Optional backend selector config * @return */ void @@ -1923,7 +1926,8 @@ MlasHalfGemmBatch( const size_t K, const size_t BatchN, const MLAS_HALF_GEMM_DATA_PARAMS* DataParams, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr ); /** @@ -1964,6 +1968,106 @@ MlasHalfGemmPackB( void* PackedB ); +/** + * @brief For half precision GEMM, returns the size of a backend-native + * direct-consumption packing buffer for right hand side B. + * + * The returned layout is only valid when MlasHalfGemmBatch consumes + * MLAS_HALF_GEMM_DATA_PARAMS with BIsBackendNativePacked set to true, ldb set + * to 0, Bias set to nullptr, and OutputProcessor set to nullptr. Returns 0 + * when no backend-native format is available for the given transpose/shape/config. + */ +size_t +MLASCALL +MlasHalfGemmNativePackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr + ); + +/** + * @brief For half precision GEMM, pack right hand side B into the + * backend-native direct-consumption format. + * + * Returns false when the backend-native format is unavailable for the given + * transpose/shape/config. The resulting buffer must be passed to + * MlasHalfGemmBatch with ldb set to 0, BIsBackendNativePacked set to true, + * Bias set to nullptr, and OutputProcessor set to nullptr. + */ +bool +MLASCALL +MlasHalfGemmNativePackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const MLAS_FP16* B, + size_t ldb, + void* PackedB, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr + ); + +bool +MLASCALL +MlasHalfConvPrepare( + MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + bool InputOutputChannelsLast, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr); + +bool +MLASCALL +MlasHalfConv( + const MLAS_CONV_PARAMETERS* Parameters, + const MLAS_FP16* Input, + const MLAS_FP16* Filter, + // When true, Filter must point to a packed weights+bias buffer produced by + // MlasHalfConvPackWeightsAndBias and Bias must be nullptr. + bool FilterAndBiasArePacked, + const MLAS_FP16* Bias, + MLAS_FP16* WorkingBuffer, + MLAS_FP16* Output, + MLAS_THREADPOOL* ThreadPool); + +size_t +MLASCALL +MlasHalfConvPackWeightsAndBiasSize( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr); + +bool +MLASCALL +MlasHalfConvPackWeightsAndBias( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape, + const MLAS_FP16* Filter, + // Optional bias to bake into the packed direct-consumption buffer. + const MLAS_FP16* Bias, + void* PackedWeightsAndBias, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr); + /** * @brief For half precision GEMM, convert the float matrix B * to half precision and pack it into a packing buffer diff --git a/onnxruntime/core/mlas/lib/halfconv.cpp b/onnxruntime/core/mlas/lib/halfconv.cpp new file mode 100644 index 0000000000000..22fa48dbc2874 --- /dev/null +++ b/onnxruntime/core/mlas/lib/halfconv.cpp @@ -0,0 +1,158 @@ +// +// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +/*++ + +Module Name: + + halfconv.cpp + +Abstract: + + This module implements public dispatch wrappers for optional half precision + convolution backends. + +--*/ + +#include "mlasi.h" + +bool +MLASCALL +MlasHalfConvPrepare( + MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + bool InputOutputChannelsLast, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig + ) +{ + if (GetMlasPlatform().MlasHalfConvPrepareOverride == nullptr) { + return false; + } + + return GetMlasPlatform().MlasHalfConvPrepareOverride( + Parameters, + Dimensions, + BatchCount, + GroupCount, + InputChannels, + InputShape, + KernelShape, + DilationShape, + Padding, + StrideShape, + OutputShape, + FilterCount, + Activation, + WorkingBufferSize, + Beta, + InputOutputChannelsLast, + ThreadPool, + BackendKernelSelectorConfig); +} + +bool +MLASCALL +MlasHalfConv( + const MLAS_CONV_PARAMETERS* Parameters, + const MLAS_FP16* Input, + const MLAS_FP16* Filter, + bool FilterAndBiasArePacked, + const MLAS_FP16* Bias, + MLAS_FP16* WorkingBuffer, + MLAS_FP16* Output, + MLAS_THREADPOOL* ThreadPool + ) +{ + if (GetMlasPlatform().MlasHalfConvOverride == nullptr) { + return false; + } + + if (FilterAndBiasArePacked && Bias != nullptr) { + return false; + } + + return GetMlasPlatform().MlasHalfConvOverride( + Parameters, + Input, + Filter, + FilterAndBiasArePacked, + Bias, + WorkingBuffer, + Output, + ThreadPool); +} + +size_t +MLASCALL +MlasHalfConvPackWeightsAndBiasSize( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig + ) +{ + if (BackendKernelSelectorConfig != nullptr && !BackendKernelSelectorConfig->use_kleidiai) { + return 0; + } + + if (GetMlasPlatform().MlasHalfConvPackWeightsAndBiasSizeOverride == nullptr) { + return 0; + } + + return GetMlasPlatform().MlasHalfConvPackWeightsAndBiasSizeOverride( + FilterCount, + InputChannels, + KernelShape, + DilationShape); +} + +bool +MLASCALL +MlasHalfConvPackWeightsAndBias( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape, + const MLAS_FP16* Filter, + const MLAS_FP16* Bias, + void* PackedWeightsAndBias, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig + ) +{ + if (BackendKernelSelectorConfig != nullptr && !BackendKernelSelectorConfig->use_kleidiai) { + return false; + } + + if (GetMlasPlatform().MlasHalfConvPackWeightsAndBiasOverride == nullptr) { + return false; + } + + return GetMlasPlatform().MlasHalfConvPackWeightsAndBiasOverride( + FilterCount, + InputChannels, + KernelShape, + DilationShape, + Filter, + Bias, + PackedWeightsAndBias, + ThreadPool); +} diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp index 66a335665d024..e7fe787b7f7dd 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.cpp +++ b/onnxruntime/core/mlas/lib/halfgemm.cpp @@ -19,9 +19,38 @@ Module Name: #include "mlas_float16.h" #include "halfgemm.h" +#if defined(USE_KLEIDIAI) +#include "kleidiai/mlasi_kleidiai.h" +#endif #include +static void +MlasHalfGemmZeroKBatch( + const size_t M, + const size_t N, + const size_t BatchN, + const MLAS_HALF_GEMM_DATA_PARAMS* DataParams + ) +{ + for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { + const auto* Data = &DataParams[gemm_i]; + auto* C = reinterpret_cast(Data->C); + const auto* Bias = reinterpret_cast(Data->Bias); + const size_t ldc = Data->ldc; + + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + C[m * ldc + n] = Bias == nullptr ? MLAS_FP16::FromBits(0) : Bias[n]; + } + } + + if (Data->OutputProcessor != nullptr) { + Data->OutputProcessor->Process(Data->C, 0, 0, M, N, ldc); + } + } +} + bool MLASCALL MlasFp16AccelerationSupported() { @@ -41,9 +70,41 @@ MlasHalfGemmBatch( const size_t K, const size_t BatchN, const MLAS_HALF_GEMM_DATA_PARAMS* DataParams, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig ) { + MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig); + + if (BatchN == 0 || M == 0 || N == 0) { + return; + } + + if (K == 0) { + MlasHalfGemmZeroKBatch(M, N, BatchN, DataParams); + return; + } + +#if defined(USE_KLEIDIAI) + if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) && + GetMlasPlatform().MlasHalfGemmBatchOverride != nullptr && + GetMlasPlatform().MlasHalfGemmBatchOverride( + M, N, K, BatchN, DataParams, ThreadPool, BackendKernelSelectorConfig)) { + return; + } +#endif + + // BIsPacked denotes the generic MLAS halfgemm packed-B layout and can be + // consumed here. Backend-native packed layouts are separate and must not + // silently fall through to the generic kernels. + for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { + if (DataParams[gemm_i].BIsBackendNativePacked) { + MLAS_THROW_EX( + std::runtime_error, + "backend-native halfgemm packed B is not supported by generic MLAS halfgemm"); + } + } + const MLAS_HALFGEMM_DISPATCH* dispatch = MlasHalfGemmGetDispatch(); MLAS_HALFGEMM_OPERATION* operation = dispatch->Operation; @@ -128,11 +189,24 @@ MlasHalfGemmPackBSize( // No packing routine provided return 0; } - const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1); - const size_t BytesRequired = N * AlignedK * FP16_SIZE + padding; + size_t aligned_k_input = 0; + if (MlasTryAddSizeT(K, PackedK - 1, &aligned_k_input)) { + return 0; + } + const size_t AlignedK = aligned_k_input & ~(PackedK - 1); + + size_t BytesRequired = 0; + if (MlasTryMultiplySizeT(N, AlignedK, &BytesRequired) || + MlasTryMultiplySizeT(BytesRequired, FP16_SIZE, &BytesRequired) || + MlasTryAddSizeT(BytesRequired, padding, &BytesRequired)) { + return 0; + } const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); - const size_t AlignedBytesRequired = - (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); + size_t aligned_bytes_input = 0; + if (MlasTryAddSizeT(BytesRequired, BufferAlignment - 1, &aligned_bytes_input)) { + return 0; + } + const size_t AlignedBytesRequired = aligned_bytes_input & ~(BufferAlignment - 1); return AlignedBytesRequired; } @@ -150,6 +224,62 @@ MlasHalfGemmPackB( dispatch->CopyPackBRoutine((_mlas_fp16_*)PackedB, (const _mlas_fp16_*)B, ldb, N, K); } +size_t +MLASCALL +MlasHalfGemmNativePackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig + ) +{ +#if defined(USE_KLEIDIAI) + if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) && + GetMlasPlatform().MlasHalfGemmPackBSizeOverride != nullptr) { + return GetMlasPlatform().MlasHalfGemmPackBSizeOverride(TransA, TransB, N, K); + } +#else + MLAS_UNREFERENCED_PARAMETER(TransA); + MLAS_UNREFERENCED_PARAMETER(TransB); + MLAS_UNREFERENCED_PARAMETER(N); + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig); +#endif + return 0; +} + +bool +MLASCALL +MlasHalfGemmNativePackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const MLAS_FP16* B, + size_t ldb, + void* PackedB, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig + ) +{ +#if defined(USE_KLEIDIAI) + if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) && + GetMlasPlatform().MlasHalfGemmPackBOverride != nullptr) { + return GetMlasPlatform().MlasHalfGemmPackBOverride(TransA, TransB, N, K, B, ldb, PackedB); + } +#else + MLAS_UNREFERENCED_PARAMETER(TransA); + MLAS_UNREFERENCED_PARAMETER(TransB); + MLAS_UNREFERENCED_PARAMETER(N); + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(B); + MLAS_UNREFERENCED_PARAMETER(ldb); + MLAS_UNREFERENCED_PARAMETER(PackedB); + MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig); +#endif + return false; +} + void MLASCALL MlasHalfGemmConvertPackB( @@ -574,7 +704,7 @@ MlasGemmBatch( const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchDefault = { MlasHalfGemmOperation, - nullptr, + MlasHalfGemmCopyPackB, MlasHalfGemmConvertPackB, MLAS_HALF_GEMM_KERNEL_DEFAULT::PackedK, MLAS_HALF_GEMM_KERNEL_DEFAULT::KernelMaxM, diff --git a/onnxruntime/core/mlas/lib/halfgemm.h b/onnxruntime/core/mlas/lib/halfgemm.h index 529db48f58e6f..1eaedbb1ccef0 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.h +++ b/onnxruntime/core/mlas/lib/halfgemm.h @@ -34,12 +34,53 @@ Module Name: #include #include +#include #include #include "mlasi.h" #include "mlas_float16.h" +MLAS_FORCEINLINE +bool +MlasTryMultiplySizeT( + size_t a, + size_t b, + size_t* out + ) +{ +#if defined(__has_builtin) +#if __has_builtin(__builtin_mul_overflow) + return __builtin_mul_overflow(a, b, out); +#endif +#endif + if (b != 0 && a > (std::numeric_limits::max)() / b) { + return true; + } + *out = a * b; + return false; +} + +MLAS_FORCEINLINE +bool +MlasTryAddSizeT( + size_t a, + size_t b, + size_t* out + ) +{ +#if defined(__has_builtin) +#if __has_builtin(__builtin_add_overflow) + return __builtin_add_overflow(a, b, out); +#endif +#endif + if (a > (std::numeric_limits::max)() - b) { + return true; + } + *out = a + b; + return false; +} + /** * @brief Define the default striding parameters for * the half precision gemm operation @@ -71,12 +112,26 @@ MlasHalfGemmCopyPackB( size_t CountK ) { - MLAS_UNREFERENCED_PARAMETER(D); - MLAS_UNREFERENCED_PARAMETER(B); - MLAS_UNREFERENCED_PARAMETER(ldb); - MLAS_UNREFERENCED_PARAMETER(CountN); - MLAS_UNREFERENCED_PARAMETER(CountK); - // No packing needed by default + if (ldb == CountN) { + size_t bytes_to_copy = 0; + ORT_ENFORCE( + !MlasTryMultiplySizeT(CountK, CountN, &bytes_to_copy) && + !MlasTryMultiplySizeT(bytes_to_copy, sizeof(_mlas_fp16_), &bytes_to_copy), + "MlasHalfGemmCopyPackB size overflow"); + std::memcpy(D, B, bytes_to_copy); + return; + } + + size_t row_bytes = 0; + ORT_ENFORCE( + !MlasTryMultiplySizeT(CountN, sizeof(_mlas_fp16_), &row_bytes), + "MlasHalfGemmCopyPackB row size overflow"); + while (CountK > 0) { + std::memcpy(D, B, row_bytes); + B += ldb; + D += CountN; + CountK--; + } } /** diff --git a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp index d7f5a90b00589..f9bb3da1c5b57 100644 --- a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp @@ -179,7 +179,7 @@ MlasHalfGemmKernel( const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchNeon = { MlasHalfGemmOperation, - nullptr, + MlasHalfGemmCopyPackB, MlasHalfGemmConvertPackB, MLAS_HALF_GEMM_KERNEL_NEON::PackedK, MLAS_HALF_GEMM_KERNEL_NEON::KernelMaxM, diff --git a/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp b/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp index 6ee80594c6b49..a352fc85abbd1 100644 --- a/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp +++ b/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp @@ -28,6 +28,7 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.h" // IMATMUL #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h" +#include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa.h" // SME2 kernels // GEMM/QGEMM/SBGEMM @@ -40,6 +41,11 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h" // IMATMUL #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h" +#include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h" + +// FP16 HGEMM kernels +#include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla.h" +#include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.h" #if defined(ENABLE_QMX_KERNELS) // QMX kernels (optional) @@ -229,6 +235,12 @@ const KaiF32IMatmulKernel imatmul_conv_sme = const KaiF32IMatmulKernel imatmul_conv_sme2 = KAI_WRAP_UKERNEL_RUN_IMATMUL_PACKED_7(imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa); +const KaiF16IMatmulKernel imatmul_f16_conv_sme = + KAI_WRAP_UKERNEL_RUN_IMATMUL_PACKED_7(imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa); + +const KaiF16IMatmulKernel imatmul_f16_conv_sme2 = + KAI_WRAP_UKERNEL_RUN_IMATMUL_PACKED_7(imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa); + const KaiBF16SBgemmKernel sbgemm_gemm_sme2 = KAI_WRAP_UKERNEL_RUN_MATMUL_11(matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa); @@ -246,6 +258,12 @@ const KaiDynamicQGemmKernel qgemm_gemm_sme = const KaiDynamicQGemmKernel qgemm_gemm_sme2 = KAI_WRAP_UKERNEL_RUN_MATMUL_11(matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa); +const KaiF16HgemmKernel hgemm_sme = + KAI_WRAP_UKERNEL_RUN_MATMUL_10_LHS_OFFSET(matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla); + +const KaiF16HgemmKernel hgemm_sme2 = + KAI_WRAP_UKERNEL_RUN_MATMUL_10_LHS_PACKED_OFFSET(matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot); + #if defined(ENABLE_QMX_KERNELS) @@ -350,6 +368,14 @@ const KaiF32IMatmulKernel& GetKleidiAIF32IMatmulUKernel() { } } +const KaiF16IMatmulKernel& GetKleidiAIF16IMatmulUKernel() { + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()) { + return imatmul_f16_conv_sme2; + } else { + return imatmul_f16_conv_sme; + } +} + const KaiDynamicQGemmKernel& GetKleidiAIQGemmUKernel() { if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()) { return qgemm_gemm_sme2; @@ -372,3 +398,11 @@ const KaiBF16SBgemmKernel& GetKleidiAISBGemmUKernel() { // Currently only SME2 variant exists for bfloat16/SBGEMM kernel return sbgemm_gemm_sme2; } + +const KaiF16HgemmKernel& GetKleidiAIHgemmUKernel() { + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()) { + return hgemm_sme2; + } else { + return hgemm_sme; + } +} diff --git a/onnxruntime/core/mlas/lib/kai_ukernel_interface.h b/onnxruntime/core/mlas/lib/kai_ukernel_interface.h index 155ecf1762b3b..9b42e30a969d0 100644 --- a/onnxruntime/core/mlas/lib/kai_ukernel_interface.h +++ b/onnxruntime/core/mlas/lib/kai_ukernel_interface.h @@ -18,6 +18,10 @@ #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h" +#include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h" + +#include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p_interface.h" + // Wrapper type that carries a stable "name" alongside the KAI ukernel interface. // This avoids needing to infer which underlying microkernel was selected from a function pointer. template @@ -41,8 +45,14 @@ using KaiDynamicQGemmKernel = KaiMatmulKernel; +// Wrapper for FP16 IMATMUL kernels used by the KleidiAI convolution implementation. +using KaiF16IMatmulKernel = KaiMatmulKernel; + using KaiBF16SBgemmKernel = KaiMatmulKernel; +// Wrapper for FP16 HGEMM kernels producing FP16 output. +using KaiF16HgemmKernel = KaiMatmulKernel; + // Returns the selected Qnbit GEMM ukernel based on runtime CPU capabilities. const KaiQnbitGemmKernel& GetKleidiAIGemmUKernel(); @@ -61,5 +71,11 @@ const KaiF32SgemvKernel& GetKleidiAISGemvUKernel(); // Returns the selected FP32 IMATMUL ukernel used by the KleidiAI convolution implementation. const KaiF32IMatmulKernel& GetKleidiAIF32IMatmulUKernel(); +// Returns the selected FP16 IMATMUL ukernel used by the KleidiAI convolution implementation. +const KaiF16IMatmulKernel& GetKleidiAIF16IMatmulUKernel(); + // Returns the selected BF16 SBGEMM ukernel used by the KleidiAI based on runtime CPU capabilities. -const KaiBF16SBgemmKernel& GetKleidiAISBGemmUKernel(); \ No newline at end of file +const KaiBF16SBgemmKernel& GetKleidiAISBGemmUKernel(); + +// Returns the selected FP16 HGEMM ukernel based on runtime CPU capabilities. +const KaiF16HgemmKernel& GetKleidiAIHgemmUKernel(); diff --git a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp index cca4f5a19c417..1b8b80ed496d9 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp @@ -480,14 +480,12 @@ static std::unique_ptr LhsPackImageDataSme(const size_t ci, const s 1, 1 }; - auto& lhs_ptrs_cache = lhs_ptrs_cache_by_pad[cur_pad_ptr]; - std::shared_ptr lhs_ptrs; - if (auto found = lhs_ptrs_cache.find(key); found != lhs_ptrs_cache.end()) { + if (auto found = lhs_ptrs_cache_by_pad[cur_pad_ptr].find(key); found != lhs_ptrs_cache_by_pad[cur_pad_ptr].end()) { lhs_ptrs = found->second; } else { lhs_ptrs = LhsPtrFill(ci, ih, iw, kh, kw, sh, sw, padding, &pad_ptr[0]); - lhs_ptrs_cache[key] = lhs_ptrs; + lhs_ptrs_cache_by_pad[cur_pad_ptr][key] = lhs_ptrs; } MultiThreadedLHSPackSme(ThreadPool, ci, m, kh, kw, &lhs_ptrs[0], &lhs[0], activation_src, &pad_ptr[0]); diff --git a/onnxruntime/core/mlas/lib/kleidiai/halfconv_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/halfconv_kleidiai.cpp new file mode 100644 index 0000000000000..2a0efe3aab32f --- /dev/null +++ b/onnxruntime/core/mlas/lib/kleidiai/halfconv_kleidiai.cpp @@ -0,0 +1,999 @@ +// +// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#include +#include +#include +#include +#include +#include + +#include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h" +#include "kai_ukernel_interface.h" +#include "mlasi_kleidiai.h" + +namespace +{ + +// Cache-oriented heuristics for the chunked IMATMUL path. These bound the +// packed LHS working set so we only split when the single-pass pack/compute +// footprint is large enough that chunking can improve locality. +constexpr size_t AutomaticMaximumLhsChunkBytes = 2 * 1024 * 1024; +constexpr size_t MinimumLhsPackedBytesForAutomaticChunking = 8 * 1024 * 1024; +constexpr size_t MinimumEffectiveKForAutomaticChunking = 1024; +// Small filter counts typically do not amortize the extra chunking overhead. +constexpr size_t MaximumFilterCountForAutomaticChunking = 32; + +bool +TryComputeKernelSize( + size_t dilation, + size_t kernel, + size_t& dilated_kernel +) +{ + if (dilation == 0 || kernel == 0) { + return false; + } + + size_t scaled_kernel = 0; + if (mul_overflow_size_t_builtin(dilation, kernel, &scaled_kernel) || scaled_kernel < (dilation - 1)) { + return false; + } + + dilated_kernel = scaled_kernel - (dilation - 1); + return true; +} + +bool +TryComputeConvOutSize( + size_t input, + size_t kernel, + size_t padding, + size_t stride, + size_t& output +) +{ + output = 0; + if (stride == 0) { + return false; + } + + size_t total_padding = 0; + size_t padded_input = 0; + if (mul_overflow_size_t_builtin(padding, 2, &total_padding) || + !TryAddSize(input, total_padding, padded_input) || + padded_input < kernel) { + return true; + } + + output = ((padded_input - kernel) / stride) + 1; + return true; +} + +bool +TryComputeOutputSize( + size_t input_height, + size_t input_width, + size_t kernel_height, + size_t kernel_width, + size_t padding_height, + size_t padding_width, + size_t stride_height, + size_t stride_width, + size_t& output_height, + size_t& output_width, + size_t& output_size +) +{ + if (!TryComputeConvOutSize(input_height, kernel_height, padding_height, stride_height, output_height) || + !TryComputeConvOutSize(input_width, kernel_width, padding_width, stride_width, output_width) || + mul_overflow_size_t_builtin(output_height, output_width, &output_size)) { + return false; + } + + return true; +} + +bool +TryComputeOutputSize( + size_t input_height, + size_t input_width, + size_t kernel_height, + size_t kernel_width, + size_t padding_height, + size_t padding_width, + size_t stride_height, + size_t stride_width, + size_t& output_size +) +{ + size_t output_height = 0; + size_t output_width = 0; + return TryComputeOutputSize( + input_height, + input_width, + kernel_height, + kernel_width, + padding_height, + padding_width, + stride_height, + stride_width, + output_height, + output_width, + output_size); +} + +size_t +SelectMaximumLhsChunkBytes( + size_t full_lhs_size, + size_t filter_count, + size_t effective_k +) +{ + // The bounded-LHS path is most useful when LHS packing is cache-hostile and + // there are few output channels, so full-LHS reuse across N tiles is limited. + if (full_lhs_size >= MinimumLhsPackedBytesForAutomaticChunking && + filter_count <= MaximumFilterCountForAutomaticChunking && + effective_k >= MinimumEffectiveKForAutomaticChunking) { + return AutomaticMaximumLhsChunkBytes; + } + + return 0; +} + +bool +IsPaddingSymmetric2D(const MLAS_CONV_PARAMETERS* parameters) +{ + return parameters->Padding[0] == parameters->Padding[1] && + parameters->Padding[0] == parameters->Padding[2] && + parameters->Padding[0] == parameters->Padding[3]; +} + +bool +CheckCapabilitiesSme(const MLAS_CONV_PARAMETERS* parameters) +{ + if (parameters == nullptr) { + return false; + } + + if (parameters->BackendKernelSelectorConfig != nullptr && + !parameters->BackendKernelSelectorConfig->use_kleidiai) { + KLEIDIAI_DEBUG_LOG("User explicitly disabled KleidiAI, returning false from MlasHalfConv."); + return false; + } + + if ((parameters->Dimensions != 2) || + (parameters->BatchCount != 1) || + (parameters->GroupCount != 1) || + (parameters->Beta != 0.0f) || + !(parameters->Activation == nullptr || + parameters->Activation->ActivationKind == MlasIdentityActivation) || + !IsPaddingSymmetric2D(parameters)) { + KLEIDIAI_DEBUG_LOG("MlasHalfConv capability check failed: unsupported configuration."); + return false; + } + + size_t d_kh = 0; + size_t d_kw = 0; + size_t output_size = 0; + if (!TryComputeKernelSize(parameters->DilationShape[0], parameters->KernelShape[0], d_kh) || + !TryComputeKernelSize(parameters->DilationShape[1], parameters->KernelShape[1], d_kw) || + !TryComputeOutputSize( + parameters->InputShape[0], + parameters->InputShape[1], + d_kh, + d_kw, + parameters->Padding[0], + parameters->Padding[1], + parameters->StrideShape[0], + parameters->StrideShape[1], + output_size)) { + return false; + } + + if (output_size == 0 || + output_size != parameters->OutputSize || + parameters->InputChannels == 0 || + parameters->FilterCount == 0 || + parameters->FilterCount == 1 || + parameters->KernelShape[0] < 3 || + parameters->KernelShape[1] < 3) { + KLEIDIAI_DEBUG_LOG("MlasHalfConv capability check failed: shape/heuristic gating."); + return false; + } + + return true; +} + +bool +GetPackedFilterSize( + size_t filter_count, + size_t input_channels, + const int64_t* kernel_shape, + const int64_t* dilation_shape, + size_t* packed_size +) +{ + if (packed_size == nullptr || + kernel_shape == nullptr || + dilation_shape == nullptr || + filter_count <= 1 || + input_channels == 0 || + kernel_shape[0] < 3 || + kernel_shape[1] < 3 || + dilation_shape[0] <= 0 || + dilation_shape[1] <= 0) { + return false; + } + + size_t d_kh = 0; + size_t d_kw = 0; + size_t k_chunk_count = 0; + if (!TryComputeKernelSize(static_cast(dilation_shape[0]), static_cast(kernel_shape[0]), d_kh) || + !TryComputeKernelSize(static_cast(dilation_shape[1]), static_cast(kernel_shape[1]), d_kw) || + mul_overflow_size_t_builtin(d_kh, d_kw, &k_chunk_count)) { + return false; + } + + *packed_size = kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( + filter_count, k_chunk_count, input_channels + ); + return *packed_size != 0; +} + +bool +NchwToNhwc( + const MLAS_FP16* input, + size_t channels, + size_t height, + size_t width, + std::vector& output +) +{ + size_t element_count = 0; + if (mul_overflow_size_t_builtin(channels, height, &element_count) || + mul_overflow_size_t_builtin(element_count, width, &element_count)) { + return false; + } + output.resize(element_count); + + if (input == nullptr) { + return false; + } + + for (size_t c = 0; c < channels; ++c) { + for (size_t h = 0; h < height; ++h) { + for (size_t w = 0; w < width; ++w) { + output[(h * width + w) * channels + c] = input[(c * height + h) * width + w]; + } + } + } + + return true; +} + +bool +FillIndirectionTable( + const MLAS_FP16* input_nhwc, + const MLAS_FP16* pad, + size_t input_channels, + size_t input_height, + size_t input_width, + size_t kernel_height, + size_t kernel_width, + size_t stride_height, + size_t stride_width, + size_t padding, + size_t output_size, + std::vector& indirection +) +{ + const auto& imatmul = GetKleidiAIF16IMatmulUKernel(); + const size_t m_step = imatmul.ukernel.get_m_step(); + size_t lhs_ptrs_k = 0; + if (mul_overflow_size_t_builtin(kernel_height, kernel_width, &lhs_ptrs_k)) { + return false; + } + + size_t lhs_ptrs_m = 0; + if (mul_overflow_size_t_builtin(m_step, MlasDivRoundup(output_size, m_step), &lhs_ptrs_m)) { + return false; + } + + size_t table_size = 0; + if (mul_overflow_size_t_builtin(lhs_ptrs_k, lhs_ptrs_m, &table_size)) { + return false; + } + indirection.resize(table_size); + + std::fill(indirection.begin(), indirection.end(), pad); + + auto ptr_offset = [lhs_ptrs_k, m_step](size_t k, size_t m) { + return ((m / m_step) * lhs_ptrs_k * m_step) + (k * m_step) + (m % m_step); + }; + + auto pixel_ptr = [=](size_t h, size_t w) -> const void* { + if (h < padding || w < padding) { + return pad; + } + + h -= padding; + w -= padding; + if (h >= input_height || w >= input_width) { + return pad; + } + + return input_nhwc + (h * input_width + w) * input_channels; + }; + + size_t output_height = 0; + size_t output_width = 0; + size_t computed_output_size = 0; + if (!TryComputeOutputSize( + input_height, + input_width, + kernel_height, + kernel_width, + padding, + padding, + stride_height, + stride_width, + output_height, + output_width, + computed_output_size) || + computed_output_size != output_size) { + return false; + } + + size_t m = 0; + for (size_t oh = 0; oh < output_height; ++oh) { + for (size_t ow = 0; ow < output_width; ++ow, ++m) { + size_t k = 0; + const size_t input_base_h = oh * stride_height; + const size_t input_base_w = ow * stride_width; + for (size_t kh = 0; kh < kernel_height; ++kh) { + for (size_t kw = 0; kw < kernel_width; ++kw, ++k) { + indirection[ptr_offset(k, m)] = pixel_ptr(input_base_h + kh, input_base_w + kw); + } + } + } + } + + return m == output_size; +} + +bool +PrepareLhsInput( + const MLAS_CONV_PARAMETERS* parameters, + const MLAS_FP16* input, + size_t output_size, + std::vector& input_nhwc, + std::vector& pad, + std::vector& indirection +) +{ + size_t d_kh = 0; + size_t d_kw = 0; + if (!TryComputeKernelSize(parameters->DilationShape[0], parameters->KernelShape[0], d_kh) || + !TryComputeKernelSize(parameters->DilationShape[1], parameters->KernelShape[1], d_kw)) { + return false; + } + + const size_t input_channels = parameters->InputChannels; + + const MLAS_FP16* input_nhwc_data = input; + if (!parameters->InputOutputChannelsLast) { + if (!NchwToNhwc( + input, + input_channels, + parameters->InputShape[0], + parameters->InputShape[1], + input_nhwc + )) { + return false; + } + input_nhwc_data = input_nhwc.data(); + } + + pad.resize(input_channels); + std::fill(pad.begin(), pad.end(), MLAS_FP16::FromBits(0)); + + return FillIndirectionTable( + input_nhwc_data, + pad.data(), + input_channels, + parameters->InputShape[0], + parameters->InputShape[1], + d_kh, + d_kw, + parameters->StrideShape[0], + parameters->StrideShape[1], + parameters->Padding[0], + output_size, + indirection + ); +} + +bool +PackFilter( + size_t filter_count, + size_t input_channels, + const int64_t* kernel_shape, + const int64_t* dilation_shape, + const MLAS_FP16* filter, + const MLAS_FP16* bias, + void* packed_filter +) +{ + if (filter == nullptr || packed_filter == nullptr || + kernel_shape == nullptr || dilation_shape == nullptr || + filter_count <= 1 || input_channels == 0 || + kernel_shape[0] < 3 || kernel_shape[1] < 3 || + dilation_shape[0] <= 0 || dilation_shape[1] <= 0) { + return false; + } + + const size_t kernel_height = static_cast(kernel_shape[0]); + const size_t kernel_width = static_cast(kernel_shape[1]); + const size_t dilation_height = static_cast(dilation_shape[0]); + const size_t dilation_width = static_cast(dilation_shape[1]); + size_t d_kh = 0; + size_t d_kw = 0; + size_t k_chunk_count = 0; + if (!TryComputeKernelSize(dilation_height, kernel_height, d_kh) || + !TryComputeKernelSize(dilation_width, kernel_width, d_kw) || + mul_overflow_size_t_builtin(d_kh, d_kw, &k_chunk_count)) { + return false; + } + + size_t reordered_size = 0; + if (mul_overflow_size_t_builtin(k_chunk_count, input_channels, &reordered_size) || + mul_overflow_size_t_builtin(reordered_size, filter_count, &reordered_size)) { + return false; + } + + std::vector reordered_filter; + reordered_filter.resize(reordered_size); + std::fill(reordered_filter.begin(), reordered_filter.end(), MLAS_FP16::FromBits(0)); + + for (size_t oc = 0; oc < filter_count; ++oc) { + for (size_t ic = 0; ic < input_channels; ++ic) { + for (size_t kh = 0; kh < kernel_height; ++kh) { + for (size_t kw = 0; kw < kernel_width; ++kw) { + const size_t src = ((oc * input_channels + ic) * kernel_height + kh) * kernel_width + kw; + const size_t dk = ((kh * dilation_height) * d_kw + (kw * dilation_width)) * input_channels + ic; + reordered_filter[dk * filter_count + oc] = filter[src]; + } + } + } + } + + std::vector zero_bias; + const MLAS_FP16* bias_data = bias; + if (bias_data == nullptr) { + zero_bias.resize(filter_count); + std::fill(zero_bias.begin(), zero_bias.end(), MLAS_FP16::FromBits(0)); + bias_data = zero_bias.data(); + } + + KLEIDIAI_KERNEL_LOG("kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme" << " N=" << filter_count << " k_chunk_count=" << k_chunk_count << " k_chunk_length=" << input_channels << " rhs_stride_row=" << (filter_count * sizeof(MLAS_FP16))); + kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( + filter_count, + k_chunk_count, + input_channels, + filter_count * sizeof(MLAS_FP16), + reordered_filter.data(), + bias_data, + packed_filter + ); + + return true; +} + +bool +ConvolveSme( + const MLAS_CONV_PARAMETERS* parameters, + const MLAS_FP16* input, + const MLAS_FP16* filter, + bool filter_and_bias_are_packed, + const MLAS_FP16* bias, + MLAS_FP16* working_buffer, + MLAS_FP16* output, + MLAS_THREADPOOL* thread_pool +) +{ + if (input == nullptr || filter == nullptr || output == nullptr || + (!parameters->InputOutputChannelsLast && working_buffer == nullptr)) { + return false; + } + + size_t d_kh = 0; + size_t d_kw = 0; + size_t output_size = 0; + if (!TryComputeKernelSize(parameters->DilationShape[0], parameters->KernelShape[0], d_kh) || + !TryComputeKernelSize(parameters->DilationShape[1], parameters->KernelShape[1], d_kw) || + !TryComputeOutputSize( + parameters->InputShape[0], + parameters->InputShape[1], + d_kh, + d_kw, + parameters->Padding[0], + parameters->Padding[1], + parameters->StrideShape[0], + parameters->StrideShape[1], + output_size)) { + return false; + } + + std::vector input_nhwc; + std::vector pad; + std::vector indirection; + if (!PrepareLhsInput(parameters, input, output_size, input_nhwc, pad, indirection)) { + return false; + } + + std::vector packed_filter_buffer; + const std::byte* packed_filter = reinterpret_cast(filter); + if (!filter_and_bias_are_packed) { + const std::array kernel_shape{ + static_cast(parameters->KernelShape[0]), + static_cast(parameters->KernelShape[1]) + }; + const std::array dilation_shape{ + static_cast(parameters->DilationShape[0]), + static_cast(parameters->DilationShape[1]) + }; + size_t packed_filter_size = 0; + if (!GetPackedFilterSize( + parameters->FilterCount, + parameters->InputChannels, + kernel_shape.data(), + dilation_shape.data(), + &packed_filter_size + )) { + return false; + } + packed_filter_buffer.resize(packed_filter_size); + if (!PackFilter( + parameters->FilterCount, + parameters->InputChannels, + kernel_shape.data(), + dilation_shape.data(), + filter, + bias, + packed_filter_buffer.data() + )) { + return false; + } + packed_filter = packed_filter_buffer.data(); + } + + const auto& imatmul = GetKleidiAIF16IMatmulUKernel(); + const size_t base_n_step = imatmul.ukernel.get_n_step(); + const size_t base_m_step = imatmul.ukernel.get_m_step(); + size_t n_step = base_n_step; + size_t m_step = base_m_step; + const size_t filter_count = parameters->FilterCount; + const size_t input_channels = parameters->InputChannels; + + std::array dim{ + MlasDivRoundup(output_size, m_step), + MlasDivRoundup(filter_count, n_step) + }; + + size_t tile_count = 0; + if (mul_overflow_size_t_builtin(dim[0], dim[1], &tile_count)) { + return false; + } + + const size_t required_tiles = std::min( + static_cast(MlasGetMaximumThreadCount(thread_pool)), + tile_count + ); + + if (required_tiles == 0) { + return false; + } + + const size_t original_dim0 = dim[0]; + const size_t original_dim1 = dim[1]; + size_t scaled_dim0 = 0; + size_t scaled_dim1 = 0; + if (mul_overflow_size_t_builtin(required_tiles, original_dim0, &scaled_dim0) || + mul_overflow_size_t_builtin(required_tiles, original_dim1, &scaled_dim1)) { + return false; + } + + dim[0] = MlasDivRoundup(scaled_dim0, tile_count); + dim[1] = MlasDivRoundup(scaled_dim1, tile_count); + + size_t new_m_step = 0; + size_t new_n_step = 0; + if (mul_overflow_size_t_builtin(m_step, MlasDivRoundup(MlasDivRoundup(output_size, dim[0]), m_step), &new_m_step) || + mul_overflow_size_t_builtin(n_step, MlasDivRoundup(MlasDivRoundup(filter_count, dim[1]), n_step), &new_n_step)) { + return false; + } + m_step = new_m_step; + n_step = new_n_step; + + dim[0] = MlasDivRoundup(output_size, m_step); + dim[1] = MlasDivRoundup(filter_count, n_step); + size_t finalized_tile_count = 0; + if (mul_overflow_size_t_builtin(dim[0], dim[1], &finalized_tile_count)) { + return false; + } + + const float clamp_min = -std::numeric_limits::infinity(); + const float clamp_max = std::numeric_limits::infinity(); + size_t dst_stride = 0; + if (mul_overflow_size_t_builtin(filter_count, sizeof(MLAS_FP16), &dst_stride)) { + return false; + } + + MLAS_FP16* destination = parameters->InputOutputChannelsLast ? output : working_buffer; + + size_t kernel_chunk_count = 0; + size_t effective_k = 0; + if (mul_overflow_size_t_builtin(d_kh, d_kw, &kernel_chunk_count) || + mul_overflow_size_t_builtin(kernel_chunk_count, input_channels, &effective_k)) { + return false; + } + + const size_t full_lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + output_size, kernel_chunk_count, input_channels + ); + if (full_lhs_size == 0) { + return false; + } + + const size_t maximum_lhs_chunk_bytes = SelectMaximumLhsChunkBytes( + full_lhs_size, + filter_count, + effective_k + ); + + if (maximum_lhs_chunk_bytes == 0 || full_lhs_size <= maximum_lhs_chunk_bytes) { + std::vector packed_lhs; + packed_lhs.resize(full_lhs_size); + + KLEIDIAI_KERNEL_LOG("kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme" + << " M=" << output_size + << " k_chunk_count=" << kernel_chunk_count + << " k_chunk_length=" << input_channels); + kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + output_size, + kernel_chunk_count, + input_channels, + indirection.data(), + 0, + pad.data(), + packed_lhs.data() + ); + + std::atomic ok{true}; + MlasTrySimpleParallel(thread_pool, static_cast(finalized_tile_count), [&](ptrdiff_t tid) { + if (!ok.load(std::memory_order_relaxed)) { + return; + } + + const size_t m_idx = (static_cast(tid) / dim[1]) * m_step; + const size_t n_idx = (static_cast(tid) % dim[1]) * n_step; + const size_t tile_m = std::min(m_step, output_size - m_idx); + const size_t tile_n = std::min(n_step, filter_count - n_idx); + + const std::byte* lhs_tile = + packed_lhs.data() + imatmul.ukernel.get_lhs_packed_offset(m_idx, kernel_chunk_count, input_channels); + const std::byte* rhs_tile = + packed_filter + imatmul.ukernel.get_rhs_packed_offset(n_idx, kernel_chunk_count, input_channels); + + size_t dst_elements = 0; + size_t dst_bytes = 0; + if (mul_overflow_size_t_builtin(m_idx, filter_count, &dst_elements) || + !TryAddSize(dst_elements, n_idx, dst_elements) || + mul_overflow_size_t_builtin(dst_elements, sizeof(MLAS_FP16), &dst_bytes)) { + ok.store(false, std::memory_order_relaxed); + return; + } + + std::byte* dst_tile = reinterpret_cast(destination) + dst_bytes; + + KLEIDIAI_KERNEL_LOG(imatmul.name << " M=" << tile_m + << " N=" << tile_n + << " k_chunk_count=" << kernel_chunk_count + << " k_chunk_length=" << input_channels); + imatmul.ukernel.run_imatmul( + tile_m, + tile_n, + kernel_chunk_count, + input_channels, + lhs_tile, + rhs_tile, + dst_tile, + dst_stride, + clamp_min, + clamp_max + ); + }); + + if (!ok.load(std::memory_order_relaxed)) { + return false; + } + } else { + const size_t bytes_per_m_step = kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + base_m_step, kernel_chunk_count, input_channels + ); + if (bytes_per_m_step == 0) { + return false; + } + + const size_t max_m_steps_per_chunk = std::min( + MlasDivRoundup(output_size, base_m_step), + std::max(1, maximum_lhs_chunk_bytes / bytes_per_m_step) + ); + size_t lhs_chunk_m = 0; + if (mul_overflow_size_t_builtin(base_m_step, max_m_steps_per_chunk, &lhs_chunk_m)) { + return false; + } + + const size_t lhs_chunk_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + lhs_chunk_m, kernel_chunk_count, input_channels + ); + if (lhs_chunk_size == 0 || lhs_chunk_size > std::vector().max_size()) { + return false; + } + + const size_t m_chunk_count = MlasDivRoundup(output_size, lhs_chunk_m); + std::atomic ok{true}; + MlasTrySimpleParallel(thread_pool, static_cast(m_chunk_count), [&](ptrdiff_t tid) { + if (!ok.load(std::memory_order_relaxed)) { + return; + } + + const size_t global_m_idx = static_cast(tid) * lhs_chunk_m; + const size_t chunk_m = std::min(lhs_chunk_m, output_size - global_m_idx); + size_t indirection_offset = 0; + size_t indirection_tiles = 0; + if (mul_overflow_size_t_builtin(global_m_idx / base_m_step, kernel_chunk_count, &indirection_tiles) || + mul_overflow_size_t_builtin(indirection_tiles, base_m_step, &indirection_offset)) { + ok.store(false, std::memory_order_relaxed); + return; + } + + std::vector packed_lhs; + packed_lhs.resize(lhs_chunk_size); + + KLEIDIAI_KERNEL_LOG("kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme" + << " M=" << chunk_m + << " k_chunk_count=" << kernel_chunk_count + << " k_chunk_length=" << input_channels); + kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + chunk_m, + kernel_chunk_count, + input_channels, + indirection.data() + indirection_offset, + 0, + pad.data(), + packed_lhs.data() + ); + + for (size_t n_tile_idx = 0; n_tile_idx < dim[1]; ++n_tile_idx) { + const size_t n_idx = n_tile_idx * n_step; + const size_t tile_n = std::min(n_step, filter_count - n_idx); + const std::byte* rhs_tile = + packed_filter + imatmul.ukernel.get_rhs_packed_offset(n_idx, kernel_chunk_count, input_channels); + + size_t dst_elements = 0; + size_t dst_bytes = 0; + if (mul_overflow_size_t_builtin(global_m_idx, filter_count, &dst_elements) || + !TryAddSize(dst_elements, n_idx, dst_elements) || + mul_overflow_size_t_builtin(dst_elements, sizeof(MLAS_FP16), &dst_bytes)) { + ok.store(false, std::memory_order_relaxed); + return; + } + std::byte* dst_tile = reinterpret_cast(destination) + dst_bytes; + + KLEIDIAI_KERNEL_LOG(imatmul.name << " M=" << chunk_m + << " N=" << tile_n + << " k_chunk_count=" << kernel_chunk_count + << " k_chunk_length=" << input_channels); + imatmul.ukernel.run_imatmul( + chunk_m, + tile_n, + kernel_chunk_count, + input_channels, + packed_lhs.data(), + rhs_tile, + dst_tile, + dst_stride, + clamp_min, + clamp_max + ); + } + }); + + if (!ok.load(std::memory_order_relaxed)) { + return false; + } + } + + if (parameters->InputOutputChannelsLast) { + return true; + } + + MlasTranspose(working_buffer, output, output_size, filter_count, thread_pool); + return true; +} + +} // namespace + +bool + MLASCALL + ArmKleidiAI::MlasHalfConvPrepare( + MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + bool InputOutputChannelsLast, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig + ) +{ + MLAS_UNREFERENCED_PARAMETER(ThreadPool); + + if (Parameters == nullptr || + InputShape == nullptr || + KernelShape == nullptr || + DilationShape == nullptr || + Padding == nullptr || + StrideShape == nullptr || + OutputShape == nullptr || + WorkingBufferSize == nullptr || + Dimensions != 2) { + return false; + } + + if (BackendKernelSelectorConfig != nullptr && + !BackendKernelSelectorConfig->use_kleidiai) { + KLEIDIAI_DEBUG_LOG("User explicitly disabled KleidiAI, returning false from MlasHalfConvPrepare."); + return false; + } + + Parameters->BackendKernelSelectorConfig = BackendKernelSelectorConfig; + Parameters->Activation = Activation; + Parameters->Dimensions = Dimensions; + Parameters->BatchCount = BatchCount; + Parameters->GroupCount = GroupCount; + Parameters->InputChannels = InputChannels; + Parameters->FilterCount = FilterCount; + Parameters->Beta = Beta; + Parameters->InputOutputChannelsLast = InputOutputChannelsLast; + + size_t input_size = 1; + size_t output_size = 1; + size_t k = InputChannels; + size_t dilated_kernel_size = InputChannels; + for (size_t dim = 0; dim < Dimensions; ++dim) { + if (InputShape[dim] <= 0 || + OutputShape[dim] <= 0 || + KernelShape[dim] <= 0 || + DilationShape[dim] <= 0 || + StrideShape[dim] <= 0 || + Padding[dim] < 0 || + Padding[dim + Dimensions] < 0) { + return false; + } + + Parameters->InputShape[dim] = static_cast(InputShape[dim]); + Parameters->OutputShape[dim] = static_cast(OutputShape[dim]); + Parameters->KernelShape[dim] = static_cast(KernelShape[dim]); + Parameters->DilationShape[dim] = static_cast(DilationShape[dim]); + Parameters->Padding[dim] = static_cast(Padding[dim]); + Parameters->Padding[dim + Dimensions] = static_cast(Padding[dim + Dimensions]); + Parameters->StrideShape[dim] = static_cast(StrideShape[dim]); + + size_t dilated_kernel_dim = 0; + if (!TryComputeKernelSize(Parameters->DilationShape[dim], Parameters->KernelShape[dim], dilated_kernel_dim)) { + return false; + } + + if (mul_overflow_size_t_builtin(input_size, Parameters->InputShape[dim], &input_size) || + mul_overflow_size_t_builtin(output_size, Parameters->OutputShape[dim], &output_size) || + mul_overflow_size_t_builtin(k, Parameters->KernelShape[dim], &k) || + mul_overflow_size_t_builtin(dilated_kernel_size, dilated_kernel_dim, &dilated_kernel_size)) { + return false; + } + } + + Parameters->InputSize = input_size; + Parameters->OutputSize = output_size; + Parameters->K = k; + Parameters->ThreadCount = MlasGetMaximumThreadCount(ThreadPool); + + if (!CheckCapabilitiesSme(Parameters)) { + return false; + } + + size_t working_elements = 0; + if (mul_overflow_size_t_builtin(Parameters->OutputSize, Parameters->FilterCount, &working_elements)) { + return false; + } + *WorkingBufferSize = Parameters->InputOutputChannelsLast ? 0 : working_elements; + return true; +} + +bool + MLASCALL + ArmKleidiAI::MlasHalfConv( + const MLAS_CONV_PARAMETERS* Parameters, + const MLAS_FP16* Input, + const MLAS_FP16* Filter, + bool FilterAndBiasArePacked, + const MLAS_FP16* Bias, + MLAS_FP16* WorkingBuffer, + MLAS_FP16* Output, + MLAS_THREADPOOL* ThreadPool + ) +{ + if (!CheckCapabilitiesSme(Parameters)) { + return false; + } + + return ConvolveSme(Parameters, Input, Filter, FilterAndBiasArePacked, Bias, WorkingBuffer, Output, ThreadPool); +} + +size_t + MLASCALL + ArmKleidiAI::MlasHalfConvPackWeightsAndBiasSize( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape + ) +{ + size_t packed_size = 0; + if (!GetPackedFilterSize(FilterCount, InputChannels, KernelShape, DilationShape, &packed_size)) { + return 0; + } + return packed_size; +} + +bool + MLASCALL + ArmKleidiAI::MlasHalfConvPackWeightsAndBias( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape, + const MLAS_FP16* Filter, + const MLAS_FP16* Bias, + void* PackedWeightsAndBias, + MLAS_THREADPOOL* ThreadPool + ) +{ + MLAS_UNREFERENCED_PARAMETER(ThreadPool); + + return PackFilter( + FilterCount, + InputChannels, + KernelShape, + DilationShape, + Filter, + Bias, + PackedWeightsAndBias + ); +} diff --git a/onnxruntime/core/mlas/lib/kleidiai/halfgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/halfgemm_kleidiai.cpp new file mode 100644 index 0000000000000..e93f1dbc485b1 --- /dev/null +++ b/onnxruntime/core/mlas/lib/kleidiai/halfgemm_kleidiai.cpp @@ -0,0 +1,278 @@ +// +// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#include +#include +#include +#include "mlas.h" + +#include "mlasi_kleidiai.h" + +#include "kai_ukernel_interface.h" + +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.h" + +namespace { +struct KaiHalfTlsBuffers { + std::vector lhs_converted; + std::vector rhs_converted; + std::vector bias_zero; + std::vector rhs_packed; +}; + +thread_local KaiHalfTlsBuffers g_kai_half_tls; + +template +bool TryResizeVector(std::vector& buffer, size_t size) { + if (size > buffer.max_size()) { + return false; + } + buffer.resize(size); + return true; +} + +static inline void ConvertFloatMatrixToHalf( + const float* src, + MLAS_FP16* dst, + size_t rows, + size_t cols, + size_t src_ld) { + for (size_t r = 0; r < rows; ++r) { + MlasConvertFloatToHalfBuffer(src + r * src_ld, dst + r * cols, cols); + } +} +} // namespace + +size_t +MLASCALL +ArmKleidiAI::MlasHalfGemmKleidiAIPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K +) { + if (TransA != CblasNoTrans || TransB != CblasNoTrans || N == 0 || K == 0) { + return 0; + } + + return kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(N, K); +} + +bool +MLASCALL +ArmKleidiAI::MlasHalfGemmKleidiAIPackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const MLAS_FP16* B, + size_t ldb, + void* PackedB +) { + if (TransA != CblasNoTrans || TransB != CblasNoTrans) { + return false; + } + + if (PackedB == nullptr || B == nullptr || N == 0 || K == 0) { + return false; + } + + const size_t packed_rhs_size = ArmKleidiAI::MlasHalfGemmKleidiAIPackBSize(TransA, TransB, N, K); + if (packed_rhs_size == 0) { + return false; + } + + std::vector zero_bias(N, MLAS_FP16::FromBits(0)); + + size_t ldb_bytes = 0; + if (mul_overflow_size_t_builtin(ldb, sizeof(MLAS_FP16), &ldb_bytes)) { + return false; + } + + const auto& hgemm = GetKleidiAIHgemmUKernel(); + kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme( + 1, N, K, hgemm.ukernel.get_nr(), hgemm.ukernel.get_kr(), hgemm.ukernel.get_sr(), ldb_bytes, + B, + zero_bias.data(), + nullptr, + PackedB, + 0, + nullptr); + + return true; +} + +bool +MLASCALL +ArmKleidiAI::MlasHalfGemmBatch( + size_t M, + size_t N, + size_t K, + size_t BatchN, + const MLAS_HALF_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig +) { + if (BatchN == 0 || M == 0 || N == 0) { + return true; + } + if (K == 0) { + return false; + } + if (DataParams == nullptr) { + return false; + } + + MLAS_UNREFERENCED_PARAMETER(ThreadPool); + MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig); + // Validate all batch entries up front so we never partially execute and then + // fall back (which would corrupt results for the already-written outputs). + for (size_t b = 0; b < BatchN; ++b) { + const auto& data = DataParams[b]; + if (data.OutputProcessor != nullptr) { + return false; + } + if (data.BIsBackendNativePacked && data.Bias != nullptr) { + return false; + } + if (data.BIsBackendNativePacked && data.ldb != 0) { + return false; + } + } + + const auto& hgemm = GetKleidiAIHgemmUKernel(); + const size_t n_step = hgemm.ukernel.get_n_step(); + const size_t nr = hgemm.ukernel.get_nr(); + const size_t kr = hgemm.ukernel.get_kr(); + const size_t sr = hgemm.ukernel.get_sr(); + KLEIDIAI_KERNEL_LOG(hgemm.name); + + const size_t packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(N, K); + if (packed_rhs_size == 0) { + return false; + } + + const float clamp_min = -std::numeric_limits::infinity(); + const float clamp_max = std::numeric_limits::infinity(); + + if (!TryResizeVector(g_kai_half_tls.rhs_packed, packed_rhs_size)) { + return false; + } + + for (size_t b = 0; b < BatchN; ++b) { + const auto& data = DataParams[b]; + + const MLAS_FP16* lhs_base = reinterpret_cast(data.A); + const MLAS_FP16* rhs_base = reinterpret_cast(data.B); + const std::byte* rhs_packed = nullptr; + size_t lhs_ld = data.lda; + size_t rhs_ld = data.ldb; + + if (data.AIsfp32) { + size_t lhs_elements = 0; + if (mul_overflow_size_t_builtin(M, K, &lhs_elements) || + !TryResizeVector(g_kai_half_tls.lhs_converted, lhs_elements)) { + return false; + } + ConvertFloatMatrixToHalf( + reinterpret_cast(data.A), + g_kai_half_tls.lhs_converted.data(), + M, K, data.lda); + lhs_base = g_kai_half_tls.lhs_converted.data(); + lhs_ld = K; + } + + if (data.BIsBackendNativePacked) { + rhs_packed = reinterpret_cast(data.B); + } else if (data.ldb == 0) { + // Prepacked B from MlasHalfGemmPackB/MlasHalfGemmConvertPackB. + // For the current default halfgemm dispatch this is a row-major + // fp16 KxN buffer with leading dimension N. It is not the native + // KleidiAI RHS-packed layout, so this path falls back to packing + // it into KleidiAI format before execution. + rhs_ld = N; + } else if (data.BIsfp32) { + size_t rhs_elements = 0; + if (mul_overflow_size_t_builtin(K, N, &rhs_elements) || + !TryResizeVector(g_kai_half_tls.rhs_converted, rhs_elements)) { + return false; + } + ConvertFloatMatrixToHalf( + reinterpret_cast(data.B), + g_kai_half_tls.rhs_converted.data(), + K, N, data.ldb); + rhs_base = g_kai_half_tls.rhs_converted.data(); + rhs_ld = N; + } + + if (rhs_packed == nullptr) { + auto* rhs_packed_buffer = g_kai_half_tls.rhs_packed.data(); + + size_t ldb_bytes = 0; + if (mul_overflow_size_t_builtin(rhs_ld, sizeof(MLAS_FP16), &ldb_bytes)) { + return false; + } + if (data.Bias == nullptr) { + if (!TryResizeVector(g_kai_half_tls.bias_zero, N)) { + return false; + } + std::fill(g_kai_half_tls.bias_zero.begin(), g_kai_half_tls.bias_zero.end(), MLAS_FP16::FromBits(0)); + } + + kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme( + 1, N, K, nr, kr, sr, ldb_bytes, + rhs_base, + data.Bias != nullptr ? data.Bias : g_kai_half_tls.bias_zero.data(), + nullptr, + rhs_packed_buffer, + 0, + nullptr); + rhs_packed = rhs_packed_buffer; + } + + size_t lda_bytes = 0; + if (mul_overflow_size_t_builtin(lhs_ld, sizeof(MLAS_FP16), &lda_bytes)) { + return false; + } + size_t dst_stride_bytes = 0; + if (mul_overflow_size_t_builtin(data.ldc, sizeof(MLAS_FP16), &dst_stride_bytes)) { + return false; + } + + MlasTrySimpleParallel(ThreadPool, static_cast(M), [&](ptrdiff_t m_idx) { + const size_t m = static_cast(m_idx); + const auto* lhs = lhs_base + m * lhs_ld; + auto* dst_row = data.C + m * data.ldc; + const auto* rhs_packed_base = rhs_packed; + // The selected KleidiAI HGEMM micro-kernel is 1xN by design. + // We execute one output row per call and parallelize over rows. + constexpr size_t kernel_m = 1; + for (size_t n_idx = 0; n_idx < N; n_idx += n_step) { + const size_t tile_n = std::min(n_step, N - n_idx); + const auto* rhs_tile = rhs_packed_base + hgemm.ukernel.get_rhs_packed_offset(n_idx, K); + auto* dst_tile = reinterpret_cast( + reinterpret_cast(dst_row) + + hgemm.ukernel.get_dst_offset(0, n_idx, dst_stride_bytes)); + + hgemm.ukernel.run_matmul( + kernel_m, + tile_n, + K, + lhs, + lda_bytes, + rhs_tile, + dst_tile, + dst_stride_bytes, + sizeof(MLAS_FP16), + clamp_min, + clamp_max); + } + }); + + } + + return true; +} diff --git a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h index 3c9f398ece887..207935b8a8bd8 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h +++ b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h @@ -7,6 +7,7 @@ #pragma once #include "../mlasi.h" +#include #include // Fix to ensure compatibility with MSVC build @@ -224,6 +225,102 @@ MlasConvSymmetricChannelsLast2DFloatPackW( size_t PackedFilterGroupStride, MLAS_THREADPOOL* ThreadPool ); + +bool +MLASCALL +MlasHalfGemmBatch( + size_t M, + size_t N, + size_t K, + size_t BatchN, + const MLAS_HALF_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig + ); + +size_t +MLASCALL +MlasHalfGemmKleidiAIPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K + ); + +// Packs B into the native KleidiAI RHS-packed layout for the supported +// halfgemm configuration only. This differs from the generic MLAS halfgemm +// prepacked-B format produced by MlasHalfGemmPackB and +// MlasHalfGemmConvertPackB, so generic MLAS prepacked weights may need to be +// repacked into this layout before running the KleidiAI halfgemm path. +// Unsupported transpose combinations return false/0 so the caller can fall +// back to the generic MLAS path. +bool +MLASCALL +MlasHalfGemmKleidiAIPackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const MLAS_FP16* B, + size_t ldb, + void* PackedB + ); + +bool +MLASCALL +MlasHalfConvPrepare(MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + bool InputOutputChannelsLast, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig); + +bool +MLASCALL +MlasHalfConv( + const MLAS_CONV_PARAMETERS* Parameters, + const MLAS_FP16* Input, + const MLAS_FP16* Filter, + bool FilterAndBiasArePacked, + const MLAS_FP16* Bias, + MLAS_FP16* WorkingBuffer, + MLAS_FP16* Output, + MLAS_THREADPOOL* ThreadPool + ); + +size_t +MLASCALL +MlasHalfConvPackWeightsAndBiasSize( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape + ); + +bool +MLASCALL +MlasHalfConvPackWeightsAndBias( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape, + const MLAS_FP16* Filter, + const MLAS_FP16* Bias, + void* PackedWeightsAndBias, + MLAS_THREADPOOL* ThreadPool + ); } /*++ @@ -259,3 +356,42 @@ inline bool mul_overflow_size_t_builtin(size_t a, size_t b, size_t* out) { if (out) *out = a * b; return false; } + +/*++ + +Routine Description: + + This routine adds two size_t values and returns the sum when no wraparound + occurs. Uses __builtin_add_overflow if available on the current system and + falls back to a default implementation otherwise. + +Arguments: + + a - Supplies the first number to be added. + + b - Supplies the second number to be added. + + out - Supplies a size_t reference which receives the result in success + cases. + +Return Value: + + Returns true if the operation was successful + Returns false if wraparound of size_t was detected + +--*/ +inline bool TryAddSize(size_t a, size_t b, size_t& out) +{ +#if defined(__has_builtin) +# if __has_builtin(__builtin_add_overflow) + return !__builtin_add_overflow(a, b, &out); +# endif +#endif + + if (a > (std::numeric_limits::max)() - b) { + return false; + } + + out = a + b; + return true; +} diff --git a/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp index 9e70b0742217a..2e5306d442476 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp @@ -101,7 +101,7 @@ MLASCALL ArmKleidiAI::MlasDynamicQGemmBatch( const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, - const size_t BatchSize, + const size_t BatchN, MLAS_THREADPOOL* ThreadPool ) { @@ -112,7 +112,7 @@ ArmKleidiAI::MlasDynamicQGemmBatch( size_t m_step = qgemm_gemm.ukernel.get_m_step(); size_t n_step = qgemm_gemm.ukernel.get_n_step(); - if (BatchSize == 0 || Shape.M == 0 || Shape.N == 0 || Shape.K == 0) { + if (BatchN == 0 || Shape.M == 0 || Shape.N == 0 || Shape.K == 0) { return; } @@ -123,7 +123,7 @@ ArmKleidiAI::MlasDynamicQGemmBatch( MLAS_THROW_EX(std::runtime_error, "Dynamic QGEMM requires valid DataParams."); } - for (size_t batch_idx = 0; batch_idx < BatchSize; ++batch_idx) { + for (size_t batch_idx = 0; batch_idx < BatchN; ++batch_idx) { const auto& params = DataParams[batch_idx]; if (params.A == nullptr) { @@ -151,24 +151,24 @@ ArmKleidiAI::MlasDynamicQGemmBatch( const size_t LhsPackedStride = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr); std::byte* LhsPackedData = nullptr; - if (g_kai_tls_qgemm.lhs_packed.capacity() < LhsPackedStride * BatchSize) { + if (g_kai_tls_qgemm.lhs_packed.capacity() < LhsPackedStride * BatchN) { - g_kai_tls_qgemm.lhs_packed.reserve(LhsPackedStride * BatchSize); + g_kai_tls_qgemm.lhs_packed.reserve(LhsPackedStride * BatchN); } - g_kai_tls_qgemm.lhs_packed.resize(LhsPackedStride * BatchSize); + g_kai_tls_qgemm.lhs_packed.resize(LhsPackedStride * BatchN); LhsPackedData = g_kai_tls_qgemm.lhs_packed.data(); // Per-batch table of LHS base pointers. - if (g_kai_tls_qgemm.lhs_base_table.capacity() < BatchSize) { + if (g_kai_tls_qgemm.lhs_base_table.capacity() < BatchN) { - g_kai_tls_qgemm.lhs_base_table.reserve(BatchSize); + g_kai_tls_qgemm.lhs_base_table.reserve(BatchN); } - g_kai_tls_qgemm.lhs_base_table.resize(BatchSize); + g_kai_tls_qgemm.lhs_base_table.resize(BatchN); // Capture the shared batch table pointer so worker threads use the same backing storage. const std::byte** tls_lhs_base = g_kai_tls_qgemm.lhs_base_table.data(); // B batches require no packing. // We have already decided the matmul variant we are using before having values for M, N, and K. - MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t batch_idx) { std::byte* lhs = nullptr; if (DataParams[batch_idx].Workspace && DataParams[batch_idx].WorkspaceSize >= LhsPackedStride) { @@ -184,7 +184,7 @@ ArmKleidiAI::MlasDynamicQGemmBatch( // Tile iteration dimensions. std::array dim; - dim[0] = BatchSize; // B + dim[0] = BatchN; // B dim[1] = MlasDivRoundup(Shape.M, m_step); // M dim[2] = MlasDivRoundup(Shape.N, n_step); // N diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 1193c1c5bbe27..ecc82d5f94755 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -967,6 +967,90 @@ bool void* PackedB); #endif +typedef +bool +(MLASCALL MLAS_HALF_GEMM_BATCH_OVERRIDE)( + size_t M, + size_t N, + size_t K, + size_t BatchN, + const MLAS_HALF_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig); + +typedef +size_t +(MLASCALL MLAS_HALF_GEMM_PACK_B_SIZE_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K); + +typedef +bool +(MLASCALL MLAS_HALF_GEMM_PACK_B_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const MLAS_FP16* B, + size_t ldb, + void* PackedB); + +typedef +bool +(MLASCALL MLAS_HALF_CONV_PREPARE_OVERRIDE)( + MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + bool InputOutputChannelsLast, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig); + +typedef +bool +(MLASCALL MLAS_HALF_CONV_OVERRIDE)( + const MLAS_CONV_PARAMETERS* Parameters, + const MLAS_FP16* Input, + const MLAS_FP16* Filter, + bool FilterAndBiasArePacked, + const MLAS_FP16* Bias, + MLAS_FP16* WorkingBuffer, + MLAS_FP16* Output, + MLAS_THREADPOOL* ThreadPool); + +typedef +size_t +(MLASCALL MLAS_HALF_CONV_PACK_WEIGHTS_AND_BIAS_SIZE_OVERRIDE)( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape); + +typedef +bool +(MLASCALL MLAS_HALF_CONV_PACK_WEIGHTS_AND_BIAS_OVERRIDE)( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape, + const MLAS_FP16* Filter, + const MLAS_FP16* Bias, + void* PackedWeightsAndBias, + MLAS_THREADPOOL* ThreadPool); + extern "C" { #if defined(MLAS_TARGET_AMD64_IX86) @@ -1477,6 +1561,15 @@ struct MLAS_PLATFORM { MLAS_DYNAMIC_QGEMM_BATCH_OVERRIDE* MlasDynamicQGemmBatchOverride = nullptr; MLAS_DYNAMIC_QGEMM_PACK_B_SIZE_OVERRIDE* MlasDynamicQGemmPackBSizeOverride = nullptr; MLAS_DYNAMIC_QGEMM_PACK_B_OVERRIDE* MlasDynamicQGemmPackBOverride = nullptr; + // MLAS HalfGemm overrides + MLAS_HALF_GEMM_BATCH_OVERRIDE* MlasHalfGemmBatchOverride = nullptr; + MLAS_HALF_GEMM_PACK_B_SIZE_OVERRIDE* MlasHalfGemmPackBSizeOverride = nullptr; + MLAS_HALF_GEMM_PACK_B_OVERRIDE* MlasHalfGemmPackBOverride = nullptr; + // MLAS HalfConv overrides + MLAS_HALF_CONV_PREPARE_OVERRIDE* MlasHalfConvPrepareOverride = nullptr; + MLAS_HALF_CONV_OVERRIDE* MlasHalfConvOverride = nullptr; + MLAS_HALF_CONV_PACK_WEIGHTS_AND_BIAS_SIZE_OVERRIDE* MlasHalfConvPackWeightsAndBiasSizeOverride = nullptr; + MLAS_HALF_CONV_PACK_WEIGHTS_AND_BIAS_OVERRIDE* MlasHalfConvPackWeightsAndBiasOverride = nullptr; // MLAS Conv overrides MLAS_CONV_PREPARE_FLOAT_OVERRIDE* MlasConvPrepareOverride = nullptr; MLAS_CONV_FLOAT_OVERRIDE* MlasConvOverride = nullptr; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index de521de0c3785..d9e0262c6c477 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -674,13 +674,20 @@ Return Value: } #if defined(USE_KLEIDIAI) - if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){ + if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME() || MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()){ this->MlasSGemmBatchOverride = ArmKleidiAI::MlasGemmBatch; this->MlasSGemmPackBSizeOverride = ArmKleidiAI::MlasGemmPackBSize; this->MlasSGemmPackBOverride = ArmKleidiAI::MlasGemmPackB; this->MlasDynamicQGemmBatchOverride = ArmKleidiAI::MlasDynamicQGemmBatch; this->MlasDynamicQGemmPackBSizeOverride = ArmKleidiAI::MlasDynamicQGemmPackBSize; this->MlasDynamicQGemmPackBOverride = ArmKleidiAI::MlasDynamicQGemmPackB; + this->MlasHalfGemmBatchOverride = ArmKleidiAI::MlasHalfGemmBatch; + this->MlasHalfGemmPackBSizeOverride = ArmKleidiAI::MlasHalfGemmKleidiAIPackBSize; + this->MlasHalfGemmPackBOverride = ArmKleidiAI::MlasHalfGemmKleidiAIPackB; + this->MlasHalfConvPrepareOverride = ArmKleidiAI::MlasHalfConvPrepare; + this->MlasHalfConvOverride = ArmKleidiAI::MlasHalfConv; + this->MlasHalfConvPackWeightsAndBiasSizeOverride = ArmKleidiAI::MlasHalfConvPackWeightsAndBiasSize; + this->MlasHalfConvPackWeightsAndBiasOverride = ArmKleidiAI::MlasHalfConvPackWeightsAndBias; this->MlasConvPrepareOverride = ArmKleidiAI::MlasConvPrepare; this->MlasConvOverride = ArmKleidiAI::MlasConv; #if defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 7f960648de40e..4ccc3caaa08e1 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -4,27 +4,388 @@ #include "core/optimizer/insert_cast_transformer.h" #include "core/framework/data_types.h" #include "core/graph/graph_utils.h" +#include "core/mlas/inc/mlas.h" + +#include +#include +#include using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; namespace onnxruntime { +template +static bool IsTensorOfType(const NodeArg& node_arg) { + const auto* type_proto = node_arg.TypeAsProto(); + return node_arg.Exists() && + type_proto != nullptr && + DataTypeImpl::TypeFromProto(*type_proto) == DataTypeImpl::GetTensorType(); +} + +template +static bool HasTensorArgOfType(const NodeArgs& node_args) { + return std::any_of(node_args.cbegin(), node_args.cend(), + [](const NodeArg* node_arg) { + return node_arg != nullptr && IsTensorOfType(*node_arg); + }); +} + static bool IsMLFloat16Tensor(const NodeArg& node_arg) { - // Type() will return nullptr if node_arg.Exists() is true so don't need an additional check for that - return node_arg.Type() != nullptr && - DataTypeImpl::TypeFromProto(*node_arg.TypeAsProto()) == DataTypeImpl::GetTensorType(); + return IsTensorOfType(node_arg); } bool InsertCastTransformer::NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input) const { - // If the node's input is float16 and currently the node is not assigned to any EP - // we need to insert a cast to float, and put the node on CPU for default behavior. - // We don't cast a node with a subgraph as we'd need to do a lot more checking of the subgraph inputs - // (both explicit and implicit) and contents to determine if it was safe to do so. - // TODO: a better check is to check does the CPU kernel with float exist or not. + // Returns true when this input is an fp16 input to an unassigned node that is eligible + // for the cast-to-fp32 fallback path. + // + // Nodes with subgraphs are excluded because rewriting explicit and implicit subgraph + // inputs safely requires additional checks of the subgraph boundaries and contents. return node->GetExecutionProviderType().empty() && !node->ContainsSubgraph() && IsMLFloat16Tensor(*input); } +static bool HasFp16IO(const onnxruntime::Node& node) { + return HasTensorArgOfType(node.InputDefs()) || + HasTensorArgOfType(node.OutputDefs()); +} + +static bool HasFp16Input(const onnxruntime::Node& node) { + return HasTensorArgOfType(node.InputDefs()); +} + +static bool IsCpuFp16OptInPolicyOp(const onnxruntime::Node& node) { + // Standard-domain ops whose CPU fp16 kernels are governed by session.enable_cpu_fp16 + // and the fp32 fallback heuristic. Existing CPU fp16 kernels are intentionally not + // included in this opt-in/fallback policy. + return node.Domain().empty() && + (node.OpType() == "MatMul" || node.OpType() == "Gemm"); +} + +static std::optional DimValue(const ONNX_NAMESPACE::TensorShapeProto* shape, int dim_idx) { + if (!shape || dim_idx < 0 || dim_idx >= shape->dim_size()) { + return std::nullopt; + } + + const auto& dim = shape->dim(dim_idx); + if (!dim.has_dim_value()) { + return std::nullopt; + } + + return dim.dim_value(); +} + +static std::optional ProductOfDimsExceptLast(const ONNX_NAMESPACE::TensorShapeProto* shape) { + if (!shape || shape->dim_size() < 2) { + return std::nullopt; + } + + int64_t product = 1; + for (int i = 0; i < shape->dim_size() - 1; ++i) { + const auto dim = DimValue(shape, i); + if (!dim || *dim < 0) { + return std::nullopt; + } + + if (*dim != 0 && product > std::numeric_limits::max() / *dim) { + return std::nullopt; + } + product *= *dim; + } + + return product; +} + +static std::optional GetAttributeInt(const onnxruntime::Node& node, + const std::string& attr_name, + int64_t default_value) { + const auto& attrs = node.GetAttributes(); + const auto attr = attrs.find(attr_name); + if (attr == attrs.end()) { + return default_value; + } + + if (attr->second.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_INT) { + return std::nullopt; + } + + return attr->second.i(); +} + +struct CpuFp16GemmShape { + int64_t M; + int64_t N; + int64_t K; +}; + +static std::optional GetMatMulShapeForCpuFp16Heuristic(const onnxruntime::Node& node) { + const auto& inputs = node.InputDefs(); + if (inputs.size() < 2 || !inputs[0] || !inputs[1]) { + return std::nullopt; + } + + const auto* a_shape = inputs[0]->Shape(); + const auto* b_shape = inputs[1]->Shape(); + if (!a_shape || !b_shape || a_shape->dim_size() < 2 || b_shape->dim_size() < 2) { + return std::nullopt; + } + + const auto b_inner_dim = DimValue(b_shape, b_shape->dim_size() - 2); + const auto b_size_to_inner_dim = ProductOfDimsExceptLast(b_shape); + // Match MatMulComputeHelper: RHS shapes like [1, ..., 1, K, N] also flatten + // the left input because the leading RHS dims are only padding. + const bool flattens_left = a_shape->dim_size() >= b_shape->dim_size() && + b_inner_dim && b_size_to_inner_dim && + *b_size_to_inner_dim == *b_inner_dim; + + // Match MatMulComputeHelper: effectively 2D RHS flattens the left input, while + // genuinely batched RHS uses the per-GEMM row count. + auto M = flattens_left ? ProductOfDimsExceptLast(a_shape) + : DimValue(a_shape, a_shape->dim_size() - 2); + auto K = DimValue(a_shape, a_shape->dim_size() - 1); + auto N = DimValue(b_shape, b_shape->dim_size() - 1); + if (!M || !N || !K) { + return std::nullopt; + } + + return CpuFp16GemmShape{*M, *N, *K}; +} + +struct CpuFp16MatMulRhsShape { + int64_t N; + int64_t K; +}; + +static std::optional GetMatMulRhsShapeForCpuFp16Heuristic(const onnxruntime::Node& node) { + const auto& inputs = node.InputDefs(); + if (inputs.size() < 2 || !inputs[1]) { + return std::nullopt; + } + + const auto* b_shape = inputs[1]->Shape(); + if (!b_shape || b_shape->dim_size() != 2) { + return std::nullopt; + } + + auto K = DimValue(b_shape, b_shape->dim_size() - 2); + auto N = DimValue(b_shape, b_shape->dim_size() - 1); + if (!N || !K || *N < 0 || *K < 0) { + return std::nullopt; + } + + if (*K != 0 && *N > std::numeric_limits::max() / *K) { + return std::nullopt; + } + + return CpuFp16MatMulRhsShape{*N, *K}; +} + +static std::optional GetGemmShapeForCpuFp16Heuristic(const onnxruntime::Node& node) { + const auto& inputs = node.InputDefs(); + if (inputs.size() < 2 || !inputs[0] || !inputs[1]) { + return std::nullopt; + } + + const auto* a_shape = inputs[0]->Shape(); + const auto* b_shape = inputs[1]->Shape(); + if (!a_shape || !b_shape || a_shape->dim_size() != 2 || b_shape->dim_size() != 2) { + return std::nullopt; + } + + const auto trans_a = GetAttributeInt(node, "transA", 0); + const auto trans_b = GetAttributeInt(node, "transB", 0); + if (!trans_a || !trans_b) { + return std::nullopt; + } + + const auto M = DimValue(a_shape, *trans_a ? 1 : 0); + const auto K = DimValue(a_shape, *trans_a ? 0 : 1); + const auto N = DimValue(b_shape, *trans_b ? 0 : 1); + if (!M || !N || !K) { + return std::nullopt; + } + + return CpuFp16GemmShape{*M, *N, *K}; +} + +static bool ShouldKeepNativeCpuFp16ForMatMulOrGemm( + const Graph& graph, + const onnxruntime::Node& node, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG& mlas_backend_kernel_selector_config) { + constexpr int64_t kMaxNativeFp16GemvM = 4; + constexpr int64_t kMinNativeFp16ConstantMatMulNK = 512 * 1024; + constexpr int64_t kMinNativeFp16NK = 1024 * 1024; + + std::optional shape; + if (node.OpType() == "MatMul") { + const auto& inputs = node.InputDefs(); + const bool has_constant_rhs = + inputs.size() > 1 && inputs[1] && graph_utils::IsInitializer(graph, inputs[1]->Name(), true); + if (has_constant_rhs) { + const auto rhs_shape = GetMatMulRhsShapeForCpuFp16Heuristic(node); + if (!rhs_shape) { + return false; + } + + const auto nk = rhs_shape->N * rhs_shape->K; + if (nk < kMinNativeFp16ConstantMatMulNK) { + return false; + } + + return MlasHalfGemmNativePackBSize(CblasNoTrans, CblasNoTrans, + static_cast(rhs_shape->N), + static_cast(rhs_shape->K), + &mlas_backend_kernel_selector_config) != 0; + } + + shape = GetMatMulShapeForCpuFp16Heuristic(node); + } else if (node.OpType() == "Gemm") { + shape = GetGemmShapeForCpuFp16Heuristic(node); + } + + if (!shape || shape->M < 0 || shape->N < 0 || shape->K < 0) { + return false; + } + + if (shape->K != 0 && shape->N > std::numeric_limits::max() / shape->K) { + return false; + } + + const auto nk = shape->N * shape->K; + + if (node.OpType() == "MatMul") { + return shape->M <= kMaxNativeFp16GemvM && + nk >= kMinNativeFp16NK && + MlasHGemmSupported(CblasNoTrans, CblasNoTrans); + } + + if (shape->M > kMaxNativeFp16GemvM) { + return false; + } + + return nk >= kMinNativeFp16NK; +} + +static const IKernelTypeStrResolver& GetInsertCastKernelTypeStrResolver() { +#if !defined(ORT_MINIMAL_BUILD) + static const OpSchemaKernelTypeStrResolver resolver; +#else + static const KernelTypeStrResolver resolver; +#endif + return resolver; +} + +static bool BuildTypeConstraintMapForNode(const onnxruntime::Node& node, + bool replace_fp16_with_float, + InlinedHashMap& type_constraint_map) { + // Build the type-constraint map that kernel lookup uses for this node. + // + // ONNX kernel lookup is based on the operator schema's type variables (e.g. T, T1, T2) + // rather than directly on individual NodeArg names. For example, a schema may say that + // both inputs and the output are of type "T". To ask "does CPU have a kernel for this + // node as currently typed?" or "does CPU have a float32 fallback for this fp16 node?", + // we first need to resolve those schema type variables to concrete MLDataType values. + // + // When replace_fp16_with_float is false we record the node's current types as-is. + // When it is true we rewrite any float16 tensors to float in the constructed map so + // we can ask whether a valid float32 fallback kernel exists for the same operator. + const auto* schema = node.Op(); + if (!schema) { + return false; + } + + const TypeConstraintMap& type_schema = schema->typeConstraintMap(); + type_constraint_map.reserve(type_schema.size()); + + const auto SetTypeConstraint = [&](const std::string& type_str, const NodeArg* def) { + if (!def || !def->Exists()) { + return; + } + + TypeConstraintMap::const_iterator it = type_schema.find(type_str); + if (it == type_schema.end()) { + return; + } + + auto type = DataTypeImpl::TypeFromProto(*(def->TypeAsProto())); + if (replace_fp16_with_float && type == DataTypeImpl::GetTensorType()) { + type = DataTypeImpl::GetTensorType(); + } + + type_constraint_map[type_str] = type; + }; + + const auto& input_arg_counts = node.InputArgCount(); + const auto& input_defs = node.InputDefs(); + const auto& formal_inputs = schema->inputs(); + const size_t num_inputs = std::min(formal_inputs.size(), input_arg_counts.size()); + int input_idx_start = 0; + for (size_t formal_idx = 0; + formal_idx < num_inputs; + input_idx_start += input_arg_counts[formal_idx], formal_idx++) { + const auto& type_str = formal_inputs[formal_idx].GetTypeStr(); + // Variadic formal parameters can map to multiple actual inputs. For current CPU fp16 + // preservation/fallback decisions we only need one concrete binding for the schema type + // variable, so we take the first existing actual input for that formal parameter. + for (int input_idx = 0; input_idx < input_arg_counts[formal_idx]; input_idx++) { + const size_t idx = static_cast(input_idx_start) + static_cast(input_idx); + ORT_ENFORCE(idx < input_defs.size()); + const NodeArg* input_def = input_defs[idx]; + if (!input_def || !input_def->Exists()) { + continue; + } + + SetTypeConstraint(type_str, input_def); + break; + } + } + + const auto& output_defs = node.OutputDefs(); + const auto& formal_outputs = schema->outputs(); + const size_t num_outputs = std::min(formal_outputs.size(), output_defs.size()); + for (size_t idx = 0; idx < num_outputs; idx++) { + const auto& type_str = formal_outputs[idx].GetTypeStr(); + SetTypeConstraint(type_str, output_defs[idx]); + } + + return true; +} + +static bool HasCpuKernelForCurrentTypes( + const onnxruntime::Node& node, + const InlinedVector>& cpu_kernel_registries, + const logging::Logger& logger) { + const auto& resolver = GetInsertCastKernelTypeStrResolver(); + for (const KernelRegistry* cpu_kernel_registry : cpu_kernel_registries) { + if (KernelRegistry::HasImplementationOf(*cpu_kernel_registry, node, kCpuExecutionProvider, resolver, logger)) { + return true; + } + } + + return false; +} + +static bool HasCpuFloat32FallbackKernel( + const onnxruntime::Node& node, + const InlinedVector>& cpu_kernel_registries, + const logging::Logger& logger) { + InlinedHashMap type_constraint_map; + if (!BuildTypeConstraintMapForNode(node, true, type_constraint_map)) { + return false; + } + + for (const KernelRegistry* cpu_kernel_registry : cpu_kernel_registries) { + const KernelCreateInfo* kernel_create_info{}; + const auto lookup_status = cpu_kernel_registry->TryFindKernel( + kCpuExecutionProvider, node.OpType(), node.Domain(), + node.SinceVersion(), type_constraint_map, logger, &kernel_create_info); + if (lookup_status.IsOK() && kernel_create_info != nullptr) { + return true; + } + } + + return false; +} + onnxruntime::NodeArg* AddCastNode(onnxruntime::Graph& graph, onnxruntime::NodeArg* old_arg, TypeProto* new_type, @@ -48,16 +409,13 @@ onnxruntime::NodeArg* AddCastNode(onnxruntime::Graph& graph, // check if the node has an fp16 input but was not able to be assigned an execution provider. // we will need to add casts to/from fp32 around the node for it to be executed using the CPU EP. -static bool NodeNeedsInputCastToFp32(const onnxruntime::Node& node) { +static bool NodeNeedsInputCastToFp32(const onnxruntime::Node& node, + const InlinedVector>& cpu_kernel_registries, + const logging::Logger& logger) { bool not_assigned = node.GetExecutionProviderType().empty(); if (not_assigned) { - const auto& input_defs = node.InputDefs(); - bool has_fp16_input = std::any_of(input_defs.cbegin(), input_defs.cend(), - [](const NodeArg* input_def) { - return IsMLFloat16Tensor(*input_def); - }); - return has_fp16_input; + return HasFp16Input(node) && HasCpuFloat32FallbackKernel(node, cpu_kernel_registries, logger); } return false; @@ -85,7 +443,7 @@ static bool NodeNeedsInputCastToFp32(const onnxruntime::Node& node) { // // Return true if all the fp16 inputs and outputs are connected to nodes that will be cast to fp32. static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph, - const KernelRegistry& cpu_kernel_registry, + const InlinedVector>& cpu_kernel_registries, const logging::Logger& logger) { // we can check if it's an isolated fp16 node // if node has input coming from other nodes (only consuming graph inputs or initializers if it doesn't), @@ -162,7 +520,7 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: const int arg_idx = input_edge->GetDstArgIndex(); if (fp16_args.find(arg_idx) != fp16_args.end()) { // if the node producing our fp16 input does not need its input cast to fp32 we should run in fp16 - if (!NodeNeedsInputCastToFp32(input_edge->GetNode())) { + if (!NodeNeedsInputCastToFp32(input_edge->GetNode(), cpu_kernel_registries, logger)) { return false; } } @@ -202,7 +560,7 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: const int arg_idx = output_edge->GetSrcArgIndex(); if (fp16_args.find(arg_idx) != fp16_args.end()) { // if the node producing our fp16 input does not need its input cast to fp32 we should run in fp16 - if (!NodeNeedsInputCastToFp32(output_edge->GetNode())) { + if (!NodeNeedsInputCastToFp32(output_edge->GetNode(), cpu_kernel_registries, logger)) { return false; } } @@ -210,24 +568,32 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: // now all fp16 inputs and outputs would have a cast // make sure fp32 version of the kernel is available. - const KernelCreateInfo* kernel_create_info{}; - const auto lookup_status = cpu_kernel_registry.TryFindKernel( - kCpuExecutionProvider, node.OpType(), node.Domain(), - node.SinceVersion(), type_constraint_map, logger, &kernel_create_info); - if (lookup_status.IsOK() && kernel_create_info != nullptr) { - return true; + for (const KernelRegistry* cpu_kernel_registry : cpu_kernel_registries) { + const KernelCreateInfo* kernel_create_info{}; + const auto lookup_status = cpu_kernel_registry->TryFindKernel( + kCpuExecutionProvider, node.OpType(), node.Domain(), + node.SinceVersion(), type_constraint_map, logger, &kernel_create_info); + if (lookup_status.IsOK() && kernel_create_info != nullptr) { + return true; + } } } return false; } -static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry, - const logging::Logger& logger) { +static Status ForceSingleNodeCPUFloat16ToFloat32( + onnxruntime::Graph& graph, + const InlinedVector>& cpu_kernel_registries, + const logging::Logger& logger, + InlinedHashSet& forced_fp32_nodes) { for (auto& node : graph.Nodes()) { - if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry, logger)) { - // unassign the node so that NeedInsertCast will return true for it, forcing it to fp32 + if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registries, logger)) { + // Unassign the node so that NeedInsertCast will return true for it, forcing it to fp32. + // Track the node index as well so the later broad fp16-preservation logic does not + // immediately assign it back to CPU and undo the heuristic. node.SetExecutionProviderType(""); + forced_fp32_nodes.insert(node.Index()); } } @@ -504,8 +870,11 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { - if (force_cpu_fp32_) - ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_, logger)); + InlinedHashSet forced_fp32_nodes; + if (force_cpu_fp32_ && !cpu_kernel_registries_.empty()) { + ORT_RETURN_IF_ERROR( + ForceSingleNodeCPUFloat16ToFloat32(graph, cpu_kernel_registries_, logger, forced_fp32_nodes)); + } GraphViewer graph_viewer(graph); auto& order = graph_viewer.GetNodesInTopologicalOrder(); @@ -520,6 +889,43 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie if (!node) return Status(ONNXRUNTIME, INVALID_ARGUMENT); + if (!enable_cpu_fp16_ && + node->GetExecutionProviderType() == kCpuExecutionProvider && + IsCpuFp16OptInPolicyOp(*node) && + HasFp16Input(*node) && + !node->ContainsSubgraph()) { + node->SetExecutionProviderType(""); + } + + if (enable_cpu_fp16_ && + force_cpu_fp32_ && + IsCpuFp16OptInPolicyOp(*node) && + HasFp16Input(*node) && + !node->ContainsSubgraph() && + (node->GetExecutionProviderType().empty() || + node->GetExecutionProviderType() == kCpuExecutionProvider) && + !ShouldKeepNativeCpuFp16ForMatMulOrGemm(graph, *node, mlas_backend_kernel_selector_config_) && + HasCpuFloat32FallbackKernel(*node, cpu_kernel_registries_, logger)) { + // Current Arm fp16 paths are profitable for constant-RHS MatMul once native + // packed-B is available, and for large GEMV-like shapes. Keep Gemm conservative + // until MLAS native fp16 is consistently faster across its common shapes. + node->SetExecutionProviderType(""); + forced_fp32_nodes.insert(node->Index()); + } + + const bool has_fp16_io = !node->ContainsSubgraph() && HasFp16IO(*node); + const bool has_cpu_fp16_kernel = + has_fp16_io && HasCpuKernelForCurrentTypes(*node, cpu_kernel_registries_, logger); + + if (enable_cpu_fp16_ && + node->GetExecutionProviderType().empty() && + has_cpu_fp16_kernel && + forced_fp32_nodes.find(node->Index()) == forced_fp32_nodes.end()) { + // When CPU fp16 is enabled, assign any currently-unassigned fp16-capable node to CPU + // so it is preserved in fp16 instead of being routed through the fp32 cast fallback. + node->SetExecutionProviderType(kCpuExecutionProvider); + } + auto& inputs = node->MutableInputDefs(); std::map replacement_defs; bool casted = false; @@ -546,7 +952,7 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie if (casted) { // Set current node to run on the CPU execution provider - // Keep in mind that the EP will be empty because NeedInsertCast() already insures that + // Keep in mind that the EP will be empty because NeedInsertCast() already ensures that node->SetExecutionProviderType(kCpuExecutionProvider); // Some ONNX operators have an attribute `dtype` which define the output type for these operators diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.h b/onnxruntime/core/optimizer/insert_cast_transformer.h index 8be08d51585cf..abe2a40a687be 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.h +++ b/onnxruntime/core/optimizer/insert_cast_transformer.h @@ -3,11 +3,13 @@ #pragma once #include "core/common/common.h" +#include "core/common/inlined_containers.h" #include "core/graph/graph_viewer.h" #include "core/framework/op_kernel.h" #include "core/optimizer/graph_transformer.h" #include "core/framework/kernel_registry_manager.h" #include "core/framework/kernel_registry.h" +#include "core/mlas/inc/mlas.h" namespace onnxruntime { @@ -22,22 +24,52 @@ class InsertCastTransformer : public onnxruntime::GraphTransformer { * @brief Initializer * @param name for logging purpose * @param cpu_kernel_registry used to query whether an op node can be safely created + * @param enable_cpu_fp16 if true, allows CPU fp16 kernels to run without forcing fp32 casts + * @param force_cpu_fp32 if true, applies the CPU fp16 fallback heuristic, which may keep selected + * fp16 CPU nodes on the existing fp32 cast fallback path when native fp16 is + * not expected to be profitable for the active MLAS backend + * @param mlas_backend_kernel_selector_config + * active MLAS backend selector config. Used by the CPU fp16 heuristic to avoid + * preserving native fp16 paths that rely on backend-specific support, such as + * native packed-B, when that backend is disabled or unavailable. */ - InsertCastTransformer(const std::string& name, const KernelRegistry* cpu_kernel_registry) + InsertCastTransformer(const std::string& name, const KernelRegistry* cpu_kernel_registry, + bool enable_cpu_fp16 = false, bool force_cpu_fp32 = true, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config = nullptr) : onnxruntime::GraphTransformer(name), - cpu_kernel_registries_(cpu_kernel_registry), - force_cpu_fp32_(cpu_kernel_registry != nullptr) {} + cpu_kernel_registries_(cpu_kernel_registry != nullptr ? InlinedVector>{cpu_kernel_registry} + : InlinedVector>{}), + enable_cpu_fp16_(enable_cpu_fp16), + force_cpu_fp32_(!cpu_kernel_registries_.empty() && force_cpu_fp32), + mlas_backend_kernel_selector_config_(mlas_backend_kernel_selector_config != nullptr + ? *mlas_backend_kernel_selector_config + : MLAS_BACKEND_KERNEL_SELECTOR_CONFIG{}) {} + + InsertCastTransformer(const std::string& name, + InlinedVector> cpu_kernel_registries, + bool enable_cpu_fp16 = false, bool force_cpu_fp32 = true, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config = nullptr) + : onnxruntime::GraphTransformer(name), + cpu_kernel_registries_(std::move(cpu_kernel_registries)), + enable_cpu_fp16_(enable_cpu_fp16), + force_cpu_fp32_(!cpu_kernel_registries_.empty() && force_cpu_fp32), + mlas_backend_kernel_selector_config_(mlas_backend_kernel_selector_config != nullptr + ? *mlas_backend_kernel_selector_config + : MLAS_BACKEND_KERNEL_SELECTOR_CONFIG{}) {} private: Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; bool NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input) const; - const KernelRegistry* cpu_kernel_registries_; + const InlinedVector> cpu_kernel_registries_; - // Currently because we only have very few cpu kernels support float16, place those nodes on float16 - // will introduce many cast between fp32 and fp16, which will slow the execution. - // A better solution is to have a cost model to evaluate does it works to place the node on float16. - // Here for simplify, we only force the single-node-float16 sub-graph to float32 + const bool enable_cpu_fp16_; + + // Some CPU fp16 kernels are only profitable for specific shapes and backend capabilities. A broader cost model would + // be better; for now we use conservative checks for known slower cases and native packed-B availability. const bool force_cpu_fp32_; + + // Copied from session options so graph optimization makes the same backend-capability decisions that execution will. + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index fb854b19accfa..5476a9b210f65 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -177,11 +177,13 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 21, Atan); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Hardmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, float, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, double, MatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, MLFloat16, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, Softmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, Softmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, float, TopK); @@ -397,8 +399,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, double, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, MLFloat16, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, MatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int64_t, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, float, @@ -575,6 +579,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, ScatterND); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t, BitShift); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t, BitShift); @@ -709,8 +714,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, #endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Gemm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, MatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, MatMul); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Min); @@ -1745,6 +1752,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { float, Gemm)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc index 08dbc46213f65..95445c73a39ce 100644 --- a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc +++ b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc @@ -9,9 +9,13 @@ #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +#include + +#include "core/common/narrow.h" #include "core/common/safeint.h" #include "core/common/float16.h" #include "core/framework/op_kernel.h" +#include "core/providers/cpu/mlas_backend_kernel_selector_config_utils.h" #include "core/providers/cpu/nn/conv_attributes.h" #include "contrib_ops/cpu/fused_activation.h" @@ -46,6 +50,13 @@ class FusedConvFp16 final : public OpKernel { FusedConvFp16(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) { ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK()); channels_last_ = (info.GetKernelDef().OpName() == "NhwcFusedConv"); + SetupMlasBackendKernelSelectorFromConfigOptions(mlas_backend_kernel_selector_config_, info.GetConfigOptions()); + const auto& input_defs = info.node().InputDefs(); + has_bias_input_ = input_defs.size() >= 3 && input_defs[2]->Exists(); + const Tensor* bias = nullptr; + if (has_bias_input_ && info.TryGetConstantInput(2, &bias)) { + constant_B_ = bias; + } } Status Compute(OpKernelContext* context) const override; @@ -93,8 +104,13 @@ class FusedConvFp16 final : public OpKernel { } MLAS_ACTIVATION activation_; + MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config_; ConvAttributes conv_attrs_; + bool has_bias_input_{false}; + const Tensor* constant_B_{nullptr}; TensorShape W_shape_; + BufferUniquePtr packed_halfconv_weights_and_bias_buffer_; + size_t packed_halfconv_weights_and_bias_size_{0}; BufferUniquePtr packed_W_buffer_; size_t packed_W_size_{0}; bool is_W_packed_{false}; @@ -139,8 +155,53 @@ Status FusedConvFp16::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr const size_t kernel_dim = group_input_channels * kernel_size; bool share_prepacked_weights = (prepacked_weights != nullptr); + const bool has_valid_constant_halfconv_bias = + constant_B_ != nullptr && + !share_prepacked_weights && + constant_B_->Shape().NumDimensions() == 1 && + constant_B_->Shape()[0] == static_cast(output_channels); const bool is_depthwise_conv = (group_input_channels == 1 && group_output_channels == 1); + + if (group_count == 1 && + rank == 4 && + output_channels > 1 && + activation_.ActivationKind == MlasIdentityActivation && + (!has_bias_input_ || has_valid_constant_halfconv_bias)) { + std::array halfconv_kernel_shape{shape[2], shape[3]}; + std::array halfconv_dilations{1, 1}; + if (!conv_attrs_.dilations.empty()) { + halfconv_dilations = {conv_attrs_.dilations[0], conv_attrs_.dilations[1]}; + } + + packed_halfconv_weights_and_bias_size_ = MlasHalfConvPackWeightsAndBiasSize( + output_channels, + group_input_channels, + halfconv_kernel_shape.data(), + halfconv_dilations.data(), + &mlas_backend_kernel_selector_config_); + if (packed_halfconv_weights_and_bias_size_ != 0) { + auto* packed_halfconv_weights_and_bias = alloc->Alloc(packed_halfconv_weights_and_bias_size_); + BufferUniquePtr packed_halfconv_weights_and_bias_buffer( + packed_halfconv_weights_and_bias, BufferDeleter(alloc)); + const auto* bias_data = constant_B_ != nullptr ? constant_B_->Data() : nullptr; + if (MlasHalfConvPackWeightsAndBias( + output_channels, + group_input_channels, + halfconv_kernel_shape.data(), + halfconv_dilations.data(), + Wdata, + bias_data, + packed_halfconv_weights_and_bias, + nullptr, + &mlas_backend_kernel_selector_config_)) { + packed_halfconv_weights_and_bias_buffer_ = std::move(packed_halfconv_weights_and_bias_buffer); + } else { + packed_halfconv_weights_and_bias_size_ = 0; + } + } + } + // Don't pack the filter buffer if the MlasConvDepthwise path is used. if (!is_depthwise_conv) { packed_W_size_ = MlasHalfGemmPackBSize(group_output_channels, kernel_dim, false); @@ -185,6 +246,9 @@ Status FusedConvFp16::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr } if (share_prepacked_weights) { + // `packed_halfconv_weights_and_bias_buffer_` is session-local only. It encodes a bias + // choice even when the bias is null, so it must not be shared under the + // W-only prepack cache key. prepacked_weights->buffers_.push_back(nullptr); // packed_W_buffer_ is nullptr prepacked_weights->buffer_sizes_.push_back(0); } @@ -222,12 +286,13 @@ Status FusedConvFp16::UseSharedPrePackedBuffers(std::vector& pr used_shared_buffers = true; - if (prepacked_buffers.size() == 1) { // This means that only packed_W_ exists + if (prepacked_buffers.size() == 1) { // only packed_W_ exists packed_W_buffer_ = std::move(prepacked_buffers[0]); - } else if (prepacked_buffers.size() == 2) { // This means that only reordered_W_ exists - // Enforce that the first "placeholder" buffer is nullptr + } else if (prepacked_buffers.size() == 2) { // placeholder + reordered_W_ ORT_ENFORCE(prepacked_buffers[0].get() == nullptr); reordered_W_buffer_ = std::move(prepacked_buffers[1]); + } else { + ORT_ENFORCE(false, "Unexpected number of shared prepacked fp16 conv buffers."); } return Status::OK(); @@ -296,6 +361,63 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { AllocatorPtr alloc; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); + concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); + + const auto* Bdata = B != nullptr ? B->Data() : nullptr; + + if (Sum == nullptr && kernel_rank == 2) { + MLAS_CONV_PARAMETERS parameters; + + size_t working_buffer_size = 0; + if (MlasHalfConvPrepare(¶meters, + kernel_rank, + narrow(N), + narrow(conv_attrs_.group), + narrow(C / conv_attrs_.group), + input_shape.GetDims().data(), + kernel_shape.data(), + dilations.data(), + pads.data(), + strides.data(), + output_shape.GetDims().data(), + narrow(M / conv_attrs_.group), + &activation_, + &working_buffer_size, + 0.0f, + channels_last_, + thread_pool, + &mlas_backend_kernel_selector_config_)) { + const MLFloat16* halfconv_filter = nullptr; + const MLFloat16* halfconv_bias = Bdata; + bool halfconv_filter_and_bias_are_packed = false; + + if (packed_halfconv_weights_and_bias_buffer_ != nullptr) { + halfconv_filter = static_cast(packed_halfconv_weights_and_bias_buffer_.get()); + halfconv_filter_and_bias_are_packed = true; + halfconv_bias = nullptr; + } else if (W != nullptr) { + halfconv_filter = W->Data(); + } + + if (halfconv_filter != nullptr) { + auto* working_data = working_buffer_size > 0 + ? alloc->Alloc(sizeof(MLFloat16) * SafeInt(working_buffer_size)) + : nullptr; + BufferUniquePtr working_buffer(working_data, BufferDeleter(alloc)); + + if (MlasHalfConv(¶meters, + X->Data(), + halfconv_filter, + halfconv_filter_and_bias_are_packed, + halfconv_bias, + static_cast(working_buffer.get()), + Y->MutableData(), + thread_pool)) { + return Status::OK(); + } + } + } + } // Handle the case of a dynamic weight filter. BufferUniquePtr reordered_W_buffer; @@ -337,7 +459,6 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { const int64_t col_buffer_size = kernel_dim * output_image_size; const auto* Xdata = X->Data(); - const auto* Bdata = B != nullptr ? B->Data() : nullptr; auto* Ydata = Y->MutableData(); const auto* sum_data = Sum != nullptr ? Sum->Data() : nullptr; @@ -387,8 +508,6 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { padding_data.resize(static_cast(C), MLFloat16()); } - concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); - /************************************* * Thread partition idea: we are essentially partition a GEMM A[M,K] x B[K,N]. * Here B contains the conv filters, which are usually not big, so we assume @@ -563,7 +682,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { static_cast(output_count), static_cast(group_output_channels), static_cast(kernel_dim), - 1, &gemm_params, nullptr); + 1, &gemm_params, nullptr, &mlas_backend_kernel_selector_config_); } } }; diff --git a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc index c0da9aec1e1b1..6fd633619b956 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.cc +++ b/onnxruntime/core/providers/cpu/math/gemm.cc @@ -206,15 +206,19 @@ void Gemm_MLFloat16(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, if (c_data == nullptr) beta = onnxruntime::MLFloat16::Zero; #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED - bool support_mlas = false; + bool support_mlas_bias = false; if (c_shape == nullptr) { - support_mlas = true; + support_mlas_bias = true; } else if (c_shape->NumDimensions() == 1 && (*c_shape)[0] == N) { - support_mlas = true; - } else if (c_shape->NumDimensions() == 2 && (((*c_shape)[0] == 1 && (*c_shape)[1] == N) || ((*c_shape)[0] == N && (*c_shape)[1] == 1))) { - support_mlas = true; + support_mlas_bias = true; + } else if (c_shape->NumDimensions() == 2 && + (((*c_shape)[0] == 1 && (*c_shape)[1] == N) || ((*c_shape)[0] == N && (*c_shape)[1] == 1))) { + support_mlas_bias = true; } - if (trans_a == CblasNoTrans && trans_b == CblasNoTrans && support_mlas && alpha.ToFloat() == 1.0 && beta.ToFloat() == 1.0) { + const bool use_mlas_no_bias = beta == onnxruntime::MLFloat16::Zero; + const bool use_mlas_bias = beta == onnxruntime::MLFloat16::One && support_mlas_bias; + if (trans_a == CblasNoTrans && trans_b == CblasNoTrans && + alpha == onnxruntime::MLFloat16::One && (use_mlas_no_bias || use_mlas_bias)) { MLAS_HALF_GEMM_DATA_PARAMS data; data.A = a_data; data.lda = K; @@ -222,10 +226,10 @@ void Gemm_MLFloat16(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, data.ldb = N; data.C = y_data; data.ldc = N; - if (c_shape != nullptr) { + if (use_mlas_bias && c_shape != nullptr) { data.Bias = c_data; } - MlasHalfGemmBatch(M, N, K, 1, &data, thread_pool); + MlasHalfGemmBatch(M, N, K, 1, &data, thread_pool, mlas_backend_kernel_selector_config); return; } #endif diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index ca91f46db93da..594b732683000 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -6,6 +6,8 @@ #include "core/providers/cpu/math/matmul_helper.h" #include "core/util/math.h" #include "core/util/math_cpuonly.h" +#include +#include namespace onnxruntime { @@ -23,6 +25,13 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), MatMul); +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + MatMul, + 1, 8, + MLFloat16, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + // opset 9 supports more types ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( MatMul, @@ -40,6 +49,14 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), MatMul); +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + MatMul, + 9, + 12, + MLFloat16, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( MatMul, 9, @@ -72,6 +89,13 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL( KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), MatMul); +ONNX_CPU_OPERATOR_TYPED_KERNEL( + MatMul, + 13, + MLFloat16, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + ONNX_CPU_OPERATOR_TYPED_KERNEL( MatMul, 13, @@ -106,9 +130,8 @@ Status MatMul::Compute(OpKernelContext* ctx) const { if (helper.K() == 0) { // When we have (M, 0, N) then the inputs are empty, but the output should // be filled out with zeros. - EigenMatrixMapRowMajor dest(y->MutableData(), - narrow(helper.M()), narrow(helper.N())); - dest.setZero(); + auto output_span = gsl::make_span(y->MutableData(), y->Shape().Size()); + std::fill(output_span.begin(), output_span.end(), T{}); return Status::OK(); } @@ -235,6 +258,45 @@ bool GemmPackBBfloat16(AllocatorPtr& alloc, } #endif +bool GemmPackBHalfNative(AllocatorPtr& alloc, + const Tensor& tensor_b, + IAllocatorUniquePtr& packed_b, + size_t& packed_b_size, + TensorShape& b_shape, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config) { + // Only handle the common case of a 2D weight matrix. Additional matrices + // could be handled by stacking the packed buffers. + if (tensor_b.Shape().NumDimensions() != 2) { + return false; + } + + b_shape = tensor_b.Shape(); + + const size_t K = static_cast(b_shape[0]); + const size_t N = static_cast(b_shape[1]); + + packed_b_size = MlasHalfGemmNativePackBSize(CblasNoTrans, CblasNoTrans, N, K, + mlas_backend_kernel_selector_config); + if (packed_b_size == 0) { + return false; + } + + packed_b = IAllocator::MakeUniquePtr(alloc, packed_b_size, true); + auto* packed_b_data = packed_b.get(); + memset(packed_b_data, 0, packed_b_size); + + if (!MlasHalfGemmNativePackB(CblasNoTrans, CblasNoTrans, N, K, + reinterpret_cast(tensor_b.Data()), N, packed_b_data, + mlas_backend_kernel_selector_config)) { + packed_b.reset(); + packed_b_size = 0; + b_shape = TensorShape(); + return false; + } + + return true; +} + Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { @@ -284,6 +346,103 @@ Status MatMul::UseSharedPrePackedBuffers(std::vector& pr return Status::OK(); } +Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; + ORT_UNUSED_PARAMETER(prepacked_weights); + + if (input_idx == 1) { + size_t packed_b_size = 0; + is_packed = GemmPackBHalfNative(alloc, tensor, packed_b_, packed_b_size, b_shape_, + &mlas_backend_kernel_selector_config_); + // The native fp16 packed-B layout depends on the active MLAS backend selector. + // Keep it owned by this kernel until shared prepacked weights carry layout metadata. + } + + return Status::OK(); +} + +Status MatMul::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span prepacked_buffer_sizes, + int input_idx, + /*out*/ bool& used_shared_buffers) { + ORT_UNUSED_PARAMETER(prepacked_buffers); + ORT_UNUSED_PARAMETER(prepacked_buffer_sizes); + ORT_UNUSED_PARAMETER(input_idx); + + // Native fp16 packed-B buffers are backend-layout-specific. Decline shared + // buffers until the shared prepack cache can validate that layout. + used_shared_buffers = false; + return Status::OK(); +} + +Status MatMul::Compute(OpKernelContext* ctx) const { + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); + + const Tensor* a = ctx->Input(0); + const Tensor* b = packed_b_ ? nullptr : ctx->Input(1); + const auto& b_shape = b ? b->Shape() : b_shape_; + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape)); + Tensor* y = ctx->Output(0, helper.OutputShape()); + + if (y->Shape().Size() == 0) { + return Status::OK(); + } + + if (helper.K() == 0) { + auto output_span = gsl::make_span(y->MutableData(), y->Shape().Size()); + std::fill(output_span.begin(), output_span.end(), MLFloat16{}); + return Status::OK(); + } + + const auto* a_data = a->Data(); + const auto* b_data = b ? b->Data() : nullptr; + auto* y_data = y->MutableData(); + + const size_t max_len = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + const size_t lda = helper.Lda(false); + const size_t ldb = helper.Ldb(false); + + if (M <= 2 && packed_b_ == nullptr && MlasHGemmSupported(CblasNoTrans, CblasNoTrans)) { + const auto alpha = MLFloat16(1.0f); + const auto beta = MLFloat16(0.0f); + std::vector data(max_len); + for (size_t i = 0; i < max_len; i++) { + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = b_data + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = alpha.val; + data[i].beta = beta.val; + } + + MlasGemmBatch(CblasNoTrans, CblasNoTrans, M, N, K, data.data(), max_len, thread_pool); + return Status::OK(); + } + + std::vector data(max_len); + for (size_t i = 0; i < max_len; i++) { + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = packed_b_ ? packed_b_.get() : static_cast(b_data + helper.RightOffsets()[i]); + data[i].ldb = packed_b_ ? 0 : ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].BIsBackendNativePacked = static_cast(packed_b_); + } + + MlasHalfGemmBatch(M, N, K, max_len, data.data(), thread_pool, &mlas_backend_kernel_selector_config_); + return Status::OK(); +} + Status MatMul::Compute(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); diff --git a/onnxruntime/core/providers/cpu/math/matmul.h b/onnxruntime/core/providers/cpu/math/matmul.h index d1c0df19f924e..25d222818aaf6 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.h +++ b/onnxruntime/core/providers/cpu/math/matmul.h @@ -106,4 +106,28 @@ class MatMul final : public OpKernel { #endif }; +template <> +class MatMul final : public OpKernel { + public: + MatMul(const OpKernelInfo& info) : OpKernel(info) { + SetupMlasBackendKernelSelectorFromConfigOptions(mlas_backend_kernel_selector_config_, info.GetConfigOptions()); + } + + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) override; + + Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, + int input_idx, + /*out*/ bool& used_shared_buffers) override; + + Status Compute(OpKernelContext* context) const override; + + private: + TensorShape b_shape_; + IAllocatorUniquePtr packed_b_; + MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config_; +}; + } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 48b76d13f01ff..5dccfe9a7d989 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -66,6 +66,7 @@ #endif #include "core/providers/cpu/controlflow/utils.h" #include "core/providers/cpu/cpu_execution_provider.h" +#include "core/providers/cpu/mlas_backend_kernel_selector_config_utils.h" #include "core/session/abi_devices.h" #ifdef USE_DML // TODO: This is necessary for the workaround in TransformGraph #include "core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h" @@ -1715,14 +1716,23 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool const InlinedVector> kernel_regs = kernel_registry_manager_.GetKernelRegistriesByProviderType(kCpuExecutionProvider); - const KernelRegistry* cpu_regs = nullptr; - if (!kernel_regs.empty()) { - // NOTE: This assumes that CPU kernels are always at the n-1 index of kernel registries vector as per the design - // of GetKernelRegistriesByProviderType function. - cpu_regs = kernel_regs[kernel_regs.size() - 1]; - } - - InsertCastTransformer insert_cast_transformer{"CastFloat16Transformer", cpu_regs}; + const bool enable_cpu_fp16 = + session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableCpuFp16, "0") == "1"; + const bool use_cpu_fp16_fp32_fallback_heuristic = + session_options_.config_options.GetConfigOrDefault( + kOrtSessionOptionsCpuFp16UseFp32FallbackHeuristic, "1") == "1"; + const bool force_cpu_fp32 = !enable_cpu_fp16 || use_cpu_fp16_fp32_fallback_heuristic; + + // Keep InsertCastTransformer's CPU fp16 profitability checks aligned with execution-time MLAS backend selection. + MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config; + SetupMlasBackendKernelSelectorFromConfigOptions(mlas_backend_kernel_selector_config, + session_options_.config_options); + + InsertCastTransformer insert_cast_transformer{ + "CastFloat16Transformer", kernel_regs, + /*enable_cpu_fp16*/ enable_cpu_fp16, + /*force_cpu_fp32*/ force_cpu_fp32, + &mlas_backend_kernel_selector_config}; ORT_RETURN_IF_ERROR_SESSIONID_( apply_transformer_once(insert_cast_transformer, *session_logger_, graph, ((graph_optimizations_loop_level > 1) ? &is_graph_modified : nullptr))); diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index 608d30c12b587..52e5c9d0092f0 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -190,6 +190,35 @@ void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const float* A, const MlasGemm(CblasNoTrans, CblasNoTrans, M, N, K, 1.f, A, K, B, N, 0.f, C, N, threadpool, mlas_backend_kernel_selector_config); } +template <> +void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, ThreadPool* threadpool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config) { +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED + MLAS_HALF_GEMM_DATA_PARAMS data; + data.A = A; + data.lda = static_cast(K); + data.B = B; + data.ldb = static_cast(N); + data.C = C; + data.ldc = static_cast(N); + MlasHalfGemmBatch(static_cast(M), static_cast(N), static_cast(K), 1, &data, threadpool, + mlas_backend_kernel_selector_config); +#else + ORT_UNUSED_PARAMETER(threadpool); + ORT_UNUSED_PARAMETER(mlas_backend_kernel_selector_config); +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + auto C_mat = EigenMatrixMap(reinterpret_cast(C), N, M); + C_mat.noalias() = ConstEigenMatrixMap(reinterpret_cast(B), N, K) * + ConstEigenMatrixMap(reinterpret_cast(A), K, M); +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif +#endif +} + #ifdef MLAS_SUPPORTS_GEMM_DOUBLE template <> void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const double* A, const double* B, double* C, ThreadPool* threadpool, diff --git a/onnxruntime/test/framework/insert_cast_transformer_test.cc b/onnxruntime/test/framework/insert_cast_transformer_test.cc index 1c42deff1a130..1f9e60c494e65 100644 --- a/onnxruntime/test/framework/insert_cast_transformer_test.cc +++ b/onnxruntime/test/framework/insert_cast_transformer_test.cc @@ -2,15 +2,20 @@ // Licensed under the MIT License. #include "core/framework/allocator.h" +#include "core/framework/tensorprotoutils.h" #include "core/optimizer/insert_cast_transformer.h" #include "core/graph/model.h" #include "core/graph/node_attr_utils.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include #include "gtest/gtest.h" +#include "test/internal_testing_ep/internal_testing_execution_provider.h" #include "test/unittest_util/framework_test_utils.h" #include "test/test_environment.h" #include "test/util/include/default_providers.h" #include "test/util/include/inference_session_wrapper.h" #include "test/util/include/asserts.h" +#include using namespace ONNX_NAMESPACE; namespace onnxruntime { @@ -19,6 +24,20 @@ namespace test { #define MODEL_FOLDER ORT_TSTR("testdata/transform/") typedef std::vector ArgMap; + +static TypeProto MakeFp16TensorType(std::initializer_list shape = {}) { + TypeProto tensor_float_16; + tensor_float_16.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT16); + if (shape.size() > 0) { + auto* tensor_shape = tensor_float_16.mutable_tensor_type()->mutable_shape(); + for (const auto dim : shape) { + tensor_shape->add_dim()->set_dim_value(dim); + } + } + + return tensor_float_16; +} + TEST(TransformerTest, InsertCastGPUTest) { auto model = std::make_shared("test", false, DefaultLoggingManager().DefaultLogger()); onnxruntime::Graph& graph = model->MainGraph(); @@ -309,6 +328,679 @@ TEST(TransformerTest, Fp16NodeWithSubgraph) { EXPECT_EQ(new_subgraph_ops.find("Cast")->second, 3) << "'Add' node in subgraph should have had Casts added"; } +static std::shared_ptr MakeCpuFp16Model(const std::string& model_name, const std::string& op_type, + bool assign_cpu_ep) { + auto model = std::make_shared(model_name, false, DefaultLoggingManager().DefaultLogger()); + auto& graph = model->MainGraph(); + + TypeProto tensor_float_16 = MakeFp16TensorType(); + + auto& a = graph.GetOrCreateNodeArg("A", &tensor_float_16); + auto& b = graph.GetOrCreateNodeArg("B", &tensor_float_16); + auto& y = graph.GetOrCreateNodeArg("Y", &tensor_float_16); + + Node& node = [&]() -> Node& { + if (op_type == "MatMul" || op_type == "Add") { + return graph.AddNode(op_type, op_type, "fp16 test op", ArgMap{&a, &b}, ArgMap{&y}); + } + + if (op_type == "Abs") { + return graph.AddNode(op_type, op_type, "fp16 test op", ArgMap{&a}, ArgMap{&y}); + } + + if (op_type == "Gemm") { + auto& c = graph.GetOrCreateNodeArg("C", &tensor_float_16); + graph.SetInputs({&a, &b, &c}); + return graph.AddNode(op_type, op_type, "fp16 test op", ArgMap{&a, &b, &c}, ArgMap{&y}); + } + + ORT_THROW("Unsupported op type for test: ", op_type); + }(); + + if (op_type == "Gemm") { + // Inputs already set when the node is created. + } else if (op_type == "Abs") { + graph.SetInputs({&a}); + } else { + graph.SetInputs({&a, &b}); + } + graph.SetOutputs({&y}); + + if (assign_cpu_ep) { + node.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider); + } + + ORT_THROW_IF_ERROR(graph.Resolve()); + return model; +} + +static std::shared_ptr MakeCpuFp16MatMulModelWithShapes(const std::string& model_name, + std::initializer_list a_shape, + std::initializer_list b_shape, + std::initializer_list y_shape, + bool assign_cpu_ep, + bool constant_b = false) { + auto model = std::make_shared(model_name, false, DefaultLoggingManager().DefaultLogger()); + auto& graph = model->MainGraph(); + + TypeProto a_type = MakeFp16TensorType(a_shape); + TypeProto b_type = MakeFp16TensorType(b_shape); + TypeProto y_type = MakeFp16TensorType(y_shape); + + auto& a = graph.GetOrCreateNodeArg("A", &a_type); + auto& b = graph.GetOrCreateNodeArg("B", &b_type); + auto& y = graph.GetOrCreateNodeArg("Y", &y_type); + + if (constant_b) { + TensorProto b_initializer; + b_initializer.set_name("B"); + b_initializer.set_data_type(TensorProto_DataType_FLOAT16); + for (const auto dim : b_shape) { + b_initializer.add_dims(dim); + } + + size_t element_count = 1; + for (const auto dim : b_shape) { + element_count *= static_cast(dim); + } + std::vector data(element_count, MLFloat16::Zero); + utils::SetRawDataInTensorProto(b_initializer, data.data(), data.size() * sizeof(MLFloat16)); + graph.AddInitializedTensor(b_initializer); + } + + auto& node = graph.AddNode("MatMul", "MatMul", "fp16 matmul shape heuristic test", ArgMap{&a, &b}, ArgMap{&y}); + if (assign_cpu_ep) { + node.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider); + } + + if (constant_b) { + graph.SetInputs({&a}); + } else { + graph.SetInputs({&a, &b}); + } + graph.SetOutputs({&y}); + ORT_THROW_IF_ERROR(graph.Resolve()); + + return model; +} + +static void ExpectMlFloat16Output(const std::vector& output, + const std::vector& expected) { + ASSERT_EQ(output.size(), expected.size()); + for (size_t i = 0; i < output.size(); ++i) { + EXPECT_EQ(output[i].val, expected[i].val); + } +} + +static bool IsNodeArgType(const NodeArg& node_arg, const MLDataType type) { + return node_arg.Type() != nullptr && + DataTypeImpl::TypeFromProto(*node_arg.TypeAsProto()) == type; +} + +static bool CanRunNativeCpuFp16GemmRuntime() { +#if defined(__aarch64__) || defined(_M_ARM64) + return MlasFp16AccelerationSupported(); +#else + return false; +#endif +} + +static const Node* FindNodeByOpType(const Graph& graph, const std::string& op_type) { + for (const auto& node : graph.Nodes()) { + if (node.OpType() == op_type) { + return &node; + } + } + + return nullptr; +} + +static std::vector RunCpuFp16Model(const std::string& model_name, + const std::string& op_type, + bool enable_cpu_fp16) { + const auto input_path = ToPathString(model_name + (enable_cpu_fp16 ? "_enabled_runtime.onnx" : "_disabled_runtime.onnx")); + ORT_THROW_IF_ERROR(Model::Save(*MakeCpuFp16Model(model_name, op_type, false), input_path)); + + SessionOptions so; + so.session_logid = model_name; + so.graph_optimization_level = TransformerLevel::Level4; + ORT_THROW_IF_ERROR(so.config_options.AddConfigEntry( + kOrtSessionOptionsEnableCpuFp16, enable_cpu_fp16 ? "1" : "0")); + ORT_THROW_IF_ERROR(so.config_options.AddConfigEntry( + kOrtSessionOptionsCpuFp16UseFp32FallbackHeuristic, enable_cpu_fp16 ? "0" : "1")); + + InferenceSessionWrapper session{so, GetEnvironment()}; + ORT_THROW_IF_ERROR(session.Load(input_path)); + ORT_THROW_IF_ERROR(session.Initialize()); + + auto allocator = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; + + NameMLValMap feeds; + OrtValue a; + CreateMLValue(allocator, {2, 2}, + std::vector{MLFloat16(1.0f), MLFloat16(2.0f), + MLFloat16(3.0f), MLFloat16(4.0f)}, + &a); + feeds.insert({"A", a}); + + OrtValue b; + CreateMLValue(allocator, {2, 2}, + std::vector{MLFloat16(5.0f), MLFloat16(6.0f), + MLFloat16(7.0f), MLFloat16(8.0f)}, + &b); + feeds.insert({"B", b}); + + if (op_type == "Gemm") { + OrtValue c; + CreateMLValue(allocator, {2, 2}, + std::vector{MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f)}, + &c); + feeds.insert({"C", c}); + } + + std::vector fetches; + RunOptions run_options; + ORT_THROW_IF_ERROR(session.Run(run_options, feeds, AsSpan({std::string("Y")}), &fetches)); + ORT_ENFORCE(fetches.size() == 1u, "Expected exactly one output for ", model_name); + + const auto& output = fetches[0].Get(); + std::vector result(output.Data(), + output.Data() + output.Shape().Size()); + + std::error_code ec; + std::filesystem::remove(std::filesystem::path{input_path}, ec); + + return result; +} + +TEST(TransformerTest, CpuFp16MatMulPreservesWhenEnabled) { + { + auto model = MakeCpuFp16Model("cpu_fp16_matmul_optin", "MatMul", true); + auto& graph = model->MainGraph(); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, false); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_TRUE(op_counts.find("Cast") == op_counts.cend()); + } +} + +TEST(TransformerTest, CpuFp16CpuAssignedMatMulHasNoCastsWhenEnabled) { + auto model = MakeCpuFp16Model("cpu_fp16_matmul_unassigned", "MatMul", false); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + node->SetExecutionProviderType(kCpuExecutionProvider); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, false); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_TRUE(op_counts.find("Cast") == op_counts.cend()); + + EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16CpuAssignedGemmHasNoCastsWhenEnabled) { + auto model = MakeCpuFp16Model("cpu_fp16_gemm_unassigned", "Gemm", false); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + node->SetExecutionProviderType(kCpuExecutionProvider); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, false); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_TRUE(op_counts.find("Cast") == op_counts.cend()); + + EXPECT_EQ(node->OpType(), "Gemm"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16NonCpuAssignedMatMulIsUntouchedWhenEnabled) { + auto model = MakeCpuFp16Model("cpu_fp16_matmul_other_ep", "MatMul", false); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + node->SetExecutionProviderType("SomeOtherEP"); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, false); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_TRUE(op_counts.find("Cast") == op_counts.cend()); + + EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), "SomeOtherEP"); +} + +TEST(TransformerTest, CpuFp16NonCpuAssignedGemmIsUntouchedWhenEnabled) { + auto model = MakeCpuFp16Model("cpu_fp16_gemm_other_ep", "Gemm", false); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + node->SetExecutionProviderType("SomeOtherEP"); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, false); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_TRUE(op_counts.find("Cast") == op_counts.cend()); + + EXPECT_EQ(node->OpType(), "Gemm"); + EXPECT_EQ(node->GetExecutionProviderType(), "SomeOtherEP"); +} + +TEST(TransformerTest, CpuFp16UnassignedMatMulKeepsFp16WhenEnabled) { + auto model = MakeCpuFp16Model("cpu_fp16_matmul_unassigned", "MatMul", false); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, false); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.count("Cast"), 0U); + + EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16MatMulHeuristicForcesBertLikeShapeToFp32) { + auto model = MakeCpuFp16MatMulModelWithShapes("cpu_fp16_matmul_heuristic_bert_like", + {128, 512}, {512, 512}, {128, 512}, true); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, true); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.at("Cast"), 3); + EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16MatMulHeuristicKeepsLargeGemvNativeFp16) { + auto model = MakeCpuFp16MatMulModelWithShapes("cpu_fp16_matmul_heuristic_large_gemv", + {1, 4096}, {4096, 4096}, {1, 4096}, true); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, true); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.count("Cast"), 0U); + EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16MatMulHeuristicKeepsBatchedLargeGemvNativeFp16) { + auto model = MakeCpuFp16MatMulModelWithShapes("cpu_fp16_matmul_heuristic_batched_large_gemv", + {8, 1, 4096}, {8, 4096, 4096}, {8, 1, 4096}, true); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, true); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.count("Cast"), 0U); + EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16MatMulHeuristicForcesSingletonBroadcastRhsToFp32) { + auto model = MakeCpuFp16MatMulModelWithShapes("cpu_fp16_matmul_heuristic_singleton_broadcast_rhs", + {8, 1, 4096}, {1, 4096, 4096}, {8, 1, 4096}, true); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, true); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.at("Cast"), 3); + EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16MatMulHeuristicForcesFlattenedLargeGemvToFp32) { + auto model = MakeCpuFp16MatMulModelWithShapes("cpu_fp16_matmul_heuristic_flattened_large_gemv", + {8, 1, 4096}, {4096, 4096}, {8, 1, 4096}, true); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, true); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.at("Cast"), 3); + EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16MatMulHeuristicKeepsConstantRhsBertLikeShapeNativeFp16) { + if (MlasHalfGemmNativePackBSize(CblasNoTrans, CblasNoTrans, 1024, 512) == 0) { + GTEST_SKIP() << "No native packed-B fp16 MatMul backend is available."; + } + + auto model = MakeCpuFp16MatMulModelWithShapes("cpu_fp16_matmul_heuristic_constant_rhs_bert_like", + {128, 512}, {512, 1024}, {128, 1024}, true, true); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, true); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.count("Cast"), 0U); + EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16MatMulHeuristicForcesBatchedConstantRhsToFp32) { + if (MlasHalfGemmNativePackBSize(CblasNoTrans, CblasNoTrans, 1024, 512) == 0) { + GTEST_SKIP() << "No native packed-B fp16 MatMul backend is available."; + } + + auto model = MakeCpuFp16MatMulModelWithShapes("cpu_fp16_matmul_heuristic_batched_constant_rhs", + {2, 1, 512}, {2, 512, 1024}, {2, 1, 1024}, true, true); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, true); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.at("Cast"), 3); + EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16MatMulHeuristicForcesConstantRhsWhenNativePackedBUnavailable) { + auto model = MakeCpuFp16MatMulModelWithShapes("cpu_fp16_matmul_heuristic_constant_rhs_no_native_pack", + {128, 512}, {512, 1024}, {128, 1024}, true, true); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config; + mlas_backend_kernel_selector_config.use_kleidiai = false; + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, true, &mlas_backend_kernel_selector_config); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.at("Cast"), 3); + EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16MatMulSessionConfigUsesFallbackHeuristicKey) { + const auto input_path = ToPathString("cpu_fp16_matmul_heuristic_mlas_config_runtime.onnx"); + ORT_THROW_IF_ERROR(Model::Save(*MakeCpuFp16MatMulModelWithShapes( + "cpu_fp16_matmul_heuristic_mlas_config_runtime", + {128, 512}, {512, 1024}, {128, 1024}, false, true), + input_path)); + + SessionOptions so; + so.session_logid = "cpu_fp16_matmul_heuristic_mlas_config_runtime"; + so.graph_optimization_level = TransformerLevel::Level4; + ORT_THROW_IF_ERROR(so.config_options.AddConfigEntry(kOrtSessionOptionsEnableCpuFp16, "1")); + ORT_THROW_IF_ERROR(so.config_options.AddConfigEntry(kOrtSessionOptionsCpuFp16UseFp32FallbackHeuristic, "1")); + ORT_THROW_IF_ERROR(so.config_options.AddConfigEntry(kOrtSessionOptionsMlasDisableKleidiAi, "1")); + + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(input_path)); + ASSERT_STATUS_OK(session.Initialize()); + + const auto op_counts = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_counts.at("Cast"), 2); + + const Node* matmul_node = FindNodeByOpType(session.GetGraph(), "MatMul"); + ASSERT_NE(matmul_node, nullptr); + + ASSERT_EQ(matmul_node->InputDefs().size(), 2U); + ASSERT_EQ(matmul_node->OutputDefs().size(), 1U); + EXPECT_TRUE(IsNodeArgType(*matmul_node->InputDefs()[0], DataTypeImpl::GetTensorType())); + EXPECT_TRUE(IsNodeArgType(*matmul_node->InputDefs()[1], DataTypeImpl::GetTensorType())); + EXPECT_TRUE(IsNodeArgType(*matmul_node->OutputDefs()[0], DataTypeImpl::GetTensorType())); + + std::error_code ec; + std::filesystem::remove(std::filesystem::path{input_path}, ec); +} + +TEST(TransformerTest, CpuFp16UnsupportedOpStillGetsCastsWhenEnabled) { + auto model = MakeCpuFp16Model("cpu_fp16_abs_unassigned", "Abs", false); + auto& graph = model->MainGraph(); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, false); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_TRUE(op_counts.find("Cast") != op_counts.cend()); + EXPECT_EQ(op_counts.at("Cast"), 2); + + const Node* abs_node = nullptr; + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "Abs") { + abs_node = &node; + break; + } + } + ASSERT_NE(abs_node, nullptr); + EXPECT_EQ(abs_node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16SupportedCpuOpKeepsFp16WhenEnabled) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP() << "CPU fp16 kernels are not registered on this platform."; + } + + auto model = MakeCpuFp16Model("cpu_fp16_add_unassigned", "Add", false); + auto& graph = model->MainGraph(); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, false); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.count("Cast"), 0U); + + const Node* add_node = nullptr; + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "Add") { + add_node = &node; + break; + } + } + ASSERT_NE(add_node, nullptr); + EXPECT_EQ(add_node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16CpuAssignedNewOptInOpsUseFp32FallbackWhenDisabled) { + const std::vector> test_cases{ + {"MatMul", 3}, + {"Gemm", 4}, + }; + + for (const auto& [op_type, expected_cast_count] : test_cases) { + auto model = MakeCpuFp16Model("cpu_fp16_" + op_type + "_disabled_cpu_assigned", op_type, true); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + false, true); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.at("Cast"), expected_cast_count); + EXPECT_EQ(node->OpType(), op_type); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); + } +} + +TEST(TransformerTest, CpuFp16CpuAssignedExistingFp16OpHasNoExtraCastsWhenDisabled) { + auto model = MakeCpuFp16Model("cpu_fp16_add_disabled_cpu_assigned", "Add", true); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + false, true); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.count("Cast"), 0U); + EXPECT_EQ(node->OpType(), "Add"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16MatMulRuntimeRunsWhenEnabled) { + if (!CanRunNativeCpuFp16GemmRuntime()) { + GTEST_SKIP() << "Native CPU fp16 MatMul runtime is not available on this platform."; + } + + const auto output = RunCpuFp16Model("cpu_fp16_matmul_runtime", "MatMul", true); + const std::vector expected{ + MLFloat16(19.0f), MLFloat16(22.0f), + MLFloat16(43.0f), MLFloat16(50.0f)}; + ExpectMlFloat16Output(output, expected); +} + +TEST(TransformerTest, CpuFp16GemmRuntimeRunsWhenEnabled) { + if (!CanRunNativeCpuFp16GemmRuntime()) { + GTEST_SKIP() << "Native CPU fp16 Gemm runtime is not available on this platform."; + } + + const auto output = RunCpuFp16Model("cpu_fp16_gemm_runtime", "Gemm", true); + const std::vector expected{ + MLFloat16(20.0f), MLFloat16(23.0f), + MLFloat16(44.0f), MLFloat16(51.0f)}; + ExpectMlFloat16Output(output, expected); +} + +TEST(TransformerTest, CpuFp16MatMulRuntimeRunsWhenDisabled) { + const auto output = RunCpuFp16Model("cpu_fp16_matmul_runtime", "MatMul", false); + const std::vector expected{ + MLFloat16(19.0f), MLFloat16(22.0f), + MLFloat16(43.0f), MLFloat16(50.0f)}; + ExpectMlFloat16Output(output, expected); +} + +TEST(TransformerTest, CpuFp16GemmRuntimeRunsWhenDisabled) { + const auto output = RunCpuFp16Model("cpu_fp16_gemm_runtime", "Gemm", false); + const std::vector expected{ + MLFloat16(20.0f), MLFloat16(23.0f), + MLFloat16(44.0f), MLFloat16(51.0f)}; + ExpectMlFloat16Output(output, expected); +} + +TEST(TransformerTest, CpuFp16MixedEpMatMulStaysOnNonCpuEpWhenEnabled) { + const auto input_path = ToPathString("cpu_fp16_matmul_mixed_ep_runtime.onnx"); + ORT_THROW_IF_ERROR(Model::Save(*MakeCpuFp16Model("cpu_fp16_matmul_mixed_ep_runtime", "MatMul", false), input_path)); + + SessionOptions so; + so.session_logid = "cpu_fp16_matmul_mixed_ep_runtime"; + so.graph_optimization_level = TransformerLevel::Level4; + ORT_THROW_IF_ERROR(so.config_options.AddConfigEntry(kOrtSessionOptionsEnableCpuFp16, "1")); + ORT_THROW_IF_ERROR(so.config_options.AddConfigEntry(kOrtSessionOptionsCpuFp16UseFp32FallbackHeuristic, "0")); + + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.RegisterExecutionProvider( + std::make_unique(std::unordered_set{"MatMul"}))); + ASSERT_STATUS_OK(session.Load(input_path)); + ASSERT_STATUS_OK(session.Initialize()); + + const auto& graph = session.GetGraph(); + size_t internal_testing_nodes = 0; + size_t cpu_nodes = 0; + size_t cast_nodes = 0; + for (const auto& node : graph.Nodes()) { + if (node.GetExecutionProviderType() == internal_testing_ep::kInternalTestingExecutionProvider) { + ++internal_testing_nodes; + } + if (node.GetExecutionProviderType() == kCpuExecutionProvider) { + ++cpu_nodes; + } + if (node.OpType() == "Cast") { + ++cast_nodes; + } + } + + EXPECT_EQ(internal_testing_nodes, 1u); + EXPECT_EQ(cpu_nodes, 0u); + EXPECT_EQ(cast_nodes, 0u); + + std::error_code ec; + std::filesystem::remove(std::filesystem::path{input_path}, ec); +} + TEST(TransformerTest, IsIsolatedFp16NodeOnCpuTest) { auto model = std::make_shared("test", false, DefaultLoggingManager().DefaultLogger()); onnxruntime::Graph& graph = model->MainGraph(); diff --git a/onnxruntime/test/mlas/unittest/test_conv2d.cpp b/onnxruntime/test/mlas/unittest/test_conv2d.cpp index 091d4ee833f8f..98bf05c8bd476 100644 --- a/onnxruntime/test/mlas/unittest/test_conv2d.cpp +++ b/onnxruntime/test/mlas/unittest/test_conv2d.cpp @@ -4,6 +4,64 @@ #include "test_conv2d.h" #include "test_conv2d_fixture.h" +TEST(Conv2d_HalfConv, PrepareRespectsBackendSelectorConfig) { + const int64_t input_shape[] = {5, 5}; + const int64_t kernel_shape[] = {3, 3}; + const int64_t dilation_shape[] = {1, 1}; + const int64_t padding[] = {1, 1, 1, 1}; + const int64_t stride_shape[] = {1, 1}; + const int64_t output_shape[] = {5, 5}; + + MLAS_ACTIVATION activation; + activation.ActivationKind = MlasIdentityActivation; + + MLAS_CONV_PARAMETERS parameters{}; + size_t working_buffer_size = 0; + if (!MlasHalfConvPrepare(¶meters, + 2, + 1, + 1, + 4, + input_shape, + kernel_shape, + dilation_shape, + padding, + stride_shape, + output_shape, + 8, + &activation, + &working_buffer_size, + 0.0f, + false, + nullptr, + nullptr)) { + GTEST_SKIP() << "HalfConv prepare path unavailable"; + } + + MLAS_BACKEND_KERNEL_SELECTOR_CONFIG selector_config; + selector_config.use_kleidiai = false; + + MLAS_CONV_PARAMETERS disabled_parameters{}; + EXPECT_FALSE(MlasHalfConvPrepare(&disabled_parameters, + 2, + 1, + 1, + 4, + input_shape, + kernel_shape, + dilation_shape, + padding, + stride_shape, + output_shape, + 8, + &activation, + &working_buffer_size, + 0.0f, + false, + nullptr, + &selector_config)); +} + static size_t Conv2dRegistLongExecute() { size_t count = MlasLongExecuteTests>::RegisterLongExecute(); if (GetMlasThreadPool() != nullptr) { diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.cpp b/onnxruntime/test/mlas/unittest/test_halfgemm.cpp index aafdcc14c0028..d2b686c29b6ca 100644 --- a/onnxruntime/test/mlas/unittest/test_halfgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_halfgemm.cpp @@ -15,6 +15,264 @@ Module Name: --*/ #include "test_halfgemm.h" +#if defined(USE_KLEIDIAI) +#include "core/mlas/lib/mlasi.h" +#include "core/mlas/lib/kleidiai/mlasi_kleidiai.h" +#endif + +#include +#include +#include +#include + +#if defined(USE_KLEIDIAI) +namespace { + +bool g_test_halfgemm_override_called = false; + +bool MLASCALL +TestHalfGemmBatchOverride( + size_t M, + size_t N, + size_t, + size_t BatchN, + const MLAS_HALF_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL*, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG*) { + g_test_halfgemm_override_called = true; + + for (size_t batch = 0; batch < BatchN; ++batch) { + auto* c = DataParams[batch].C; + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + c[m * DataParams[batch].ldc + n] = onnxruntime::MLFloat16(1.0f); + } + } + } + + return true; +} + +void ReferenceHalfGemm( + size_t M, + size_t N, + size_t K, + const MLFp16* A, + const MLFp16* B, + MLFp16* C) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float sum = 0.0f; + for (size_t k = 0; k < K; ++k) { + sum += float(A[m * K + k]) * float(B[k * N + n]); + } + C[m * N + n] = MLFp16(sum); + } + } +} + +struct HalfGemmOverrideGuard { + explicit HalfGemmOverrideGuard(MLAS_HALF_GEMM_BATCH_OVERRIDE* replacement) + : original_(GetMlasPlatform().MlasHalfGemmBatchOverride) { + GetMlasPlatform().MlasHalfGemmBatchOverride = replacement; + } + + ~HalfGemmOverrideGuard() { + GetMlasPlatform().MlasHalfGemmBatchOverride = original_; + } + + private: + MLAS_HALF_GEMM_BATCH_OVERRIDE* original_; +}; + +} // namespace + +TEST(HalfGemmKleidiAISelector, DisableKleidiAIBypassesOverride) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP() << "HalfGemm FP16 acceleration not available."; + } + + constexpr size_t M = 5; + constexpr size_t N = 7; + constexpr size_t K = 9; + + std::vector A(M * K); + std::vector B(K * N); + std::vector C(M * N, MLFp16(0.0f)); + std::vector CReference(M * N, MLFp16(0.0f)); + + SmallFloatFill(A.data(), A.size()); + SmallFloatFill(B.data(), B.size()); + + MLAS_HALF_GEMM_DATA_PARAMS data; + data.A = A.data(); + data.B = B.data(); + data.C = reinterpret_cast(C.data()); + data.lda = K; + data.ldb = N; + data.ldc = N; + + HalfGemmOverrideGuard guard(TestHalfGemmBatchOverride); + + g_test_halfgemm_override_called = false; + MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, nullptr); + ASSERT_TRUE(g_test_halfgemm_override_called); + for (const auto& value : C) { + ASSERT_EQ(float(value), 1.0f); + } + + std::fill(C.begin(), C.end(), MLFp16(0.0f)); + ReferenceHalfGemm(M, N, K, A.data(), B.data(), CReference.data()); + + MLAS_BACKEND_KERNEL_SELECTOR_CONFIG selector_config; + selector_config.use_kleidiai = false; + + g_test_halfgemm_override_called = false; + MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, &selector_config); + ASSERT_FALSE(g_test_halfgemm_override_called); + + for (size_t i = 0; i < C.size(); ++i) { + ASSERT_TRUE(CloseEnough(float(C[i]), float(CReference[i]))) << "index=" << i; + } +} +#endif + +namespace { + +struct HalfGemmCase { + const char* name; + size_t M; + size_t N; + size_t K; + size_t Batch; + bool has_bias; +}; + +template +void RunHalfGemmCases(const HalfGemmCase* test_cases, size_t num_cases, Runner run_case) { + for (size_t i = 0; i < num_cases; ++i) { + const auto& test_case = test_cases[i]; + SCOPED_TRACE(testing::Message() + << test_case.name + << " Batch=" << test_case.Batch + << " M=" << test_case.M + << " N=" << test_case.N + << " K=" << test_case.K + << " hasBias=" << test_case.has_bias); + run_case(test_case); + } +} + +void ReferenceHalfGemmPackedCompatibility( + size_t M, + size_t N, + size_t K, + const MLFp16* A, + const MLFp16* B, + MLFp16* C) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float sum = 0.0f; + for (size_t k = 0; k < K; ++k) { + MLFp16 down(float(A[m * K + k]) * float(B[k * N + n]) + sum); + sum = float(down); + } + C[m * N + n] = MLFp16(sum); + } + } +} + +constexpr HalfGemmCase kNativeFp16Cases[] = { + {"WideNBatch1WithBias", 43, 500, 401, 1, true}, + {"WideNBatch3NoBias", 43, 500, 401, 3, false}, + {"VectorLikeBatch3WithBias", 1, 32, 79, 3, true}, + {"RectangularBatch1WithBias", 64, 48, 80, 1, true}, + {"RectangularBatch1NoBias", 64, 48, 80, 1, false}, +}; + +template +void RunNativeFp16WithoutOutputProcessorCases(const HalfGemmCase* test_cases, size_t num_cases) { + RunHalfGemmCases(test_cases, num_cases, [](const HalfGemmCase& test_case) { + MlasHalfGemmTest test; + test.TestNativeFp16WithoutOutputProcessor( + test_case.M, test_case.N, test_case.K, test_case.Batch, test_case.has_bias); + }); +} + +constexpr HalfGemmCase kKleidiAIPathNonPackedCases[] = { + {"WideNBatch3WithBias", 43, 500, 401, 3, true}, + {"WideNBatch3NoBias", 43, 500, 401, 3, false}, + {"RectangularBatch1WithBias", 64, 48, 80, 1, true}, + {"RectangularBatch1NoBias", 64, 48, 80, 1, false}, +}; + +constexpr HalfGemmCase kKleidiAIPathPackedCases[] = { + {"WideNBatch1WithBias", 43, 500, 401, 1, true}, + {"WideNBatch1NoBias", 43, 500, 401, 1, false}, + {"RectangularBatch1WithBias", 64, 48, 80, 1, true}, + {"RectangularBatch1NoBias", 64, 48, 80, 1, false}, +}; + +template +void RunKleidiAIWithoutOutputProcessorCases(const HalfGemmCase* test_cases, size_t num_cases) { + RunHalfGemmCases(test_cases, num_cases, [](const HalfGemmCase& test_case) { + MlasHalfGemmTest test; + test.TestKleidiAIWithoutOutputProcessor( + test_case.M, test_case.N, test_case.K, test_case.Batch, test_case.has_bias); + }); +} + +} // namespace + +TEST(HalfGemm, ZeroKInitializesBiasAndRunsOutputProcessor) { + constexpr size_t M = 3; + constexpr size_t N = 4; + constexpr size_t K = 0; + + std::vector Bias{MLFp16(1.0f), MLFp16(-2.0f), MLFp16(3.5f), MLFp16(0.25f)}; + std::vector C(M * N, MLFp16(-9.0f)); + std::vector CFloat(M * N, -9.0f); + + MLAS_ACTIVATION act; + act.ActivationKind = MlasIdentityActivation; + MLAS_HALF_GEMM_2FLOAT_PROCESSOR output_processor(act, CFloat.data(), N); + + MLAS_HALF_GEMM_DATA_PARAMS data{}; + data.Bias = reinterpret_cast(Bias.data()); + data.C = reinterpret_cast(C.data()); + data.ldc = N; + data.AIsfp32 = true; + data.OutputProcessor = &output_processor; + + MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, nullptr); + + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + const size_t index = m * N + n; + ASSERT_EQ(float(C[index]), float(Bias[n])) << "index=" << index; + ASSERT_EQ(CFloat[index], float(Bias[n])) << "index=" << index; + } + } +} + +TEST(HalfGemm, ZeroKInitializesZeroWithoutBias) { + constexpr size_t M = 3; + constexpr size_t N = 4; + constexpr size_t K = 0; + + std::vector C(M * N, MLFp16(-9.0f)); + + MLAS_HALF_GEMM_DATA_PARAMS data{}; + data.C = reinterpret_cast(C.data()); + data.ldc = N; + data.BIsfp32 = true; + + MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, nullptr); + + for (const auto& value : C) { + ASSERT_EQ(float(value), 0.0f); + } +} // // Short Execute() test helper to register each test separately by all parameters. @@ -166,3 +424,365 @@ static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_exe } return HalfGemmRegistLongExecute() > 0; }); + +TEST(HalfGemmPackB, ReturnsZeroOnOverflow) { + const size_t max = (std::numeric_limits::max)(); + EXPECT_EQ(MlasHalfGemmPackBSize(1, max, true), size_t{0}); + EXPECT_EQ(MlasHalfGemmPackBSize(max, 2, true), size_t{0}); +} + +TEST(HalfGemmPackB, GenericPackedBFlagRunsOnFallback) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + + constexpr size_t M = 5; + constexpr size_t N = 7; + constexpr size_t K = 9; + + std::vector A(M * K); + std::vector B(K * N); + std::vector C(M * N, MLFp16(0.0f)); + std::vector CReference(M * N, MLFp16(0.0f)); + + SmallFloatFill(A.data(), A.size()); + SmallFloatFill(B.data(), B.size()); + + const size_t packed_b_size = MlasHalfGemmPackBSize(N, K, false); + if (packed_b_size == 0) { + GTEST_SKIP(); + } + + std::vector packed_b(packed_b_size); + MlasHalfGemmPackB(N, K, reinterpret_cast(B.data()), N, packed_b.data()); + + MLAS_HALF_GEMM_DATA_PARAMS data{}; + data.A = A.data(); + data.B = packed_b.data(); + data.C = reinterpret_cast(C.data()); + data.lda = K; + data.ldb = 0; + data.ldc = N; + data.BIsPacked = true; + + MLAS_BACKEND_KERNEL_SELECTOR_CONFIG selector_config{}; + selector_config.use_kleidiai = false; + + MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, &selector_config); + ReferenceHalfGemmPackedCompatibility(M, N, K, A.data(), B.data(), CReference.data()); + + for (size_t i = 0; i < C.size(); ++i) { + ASSERT_TRUE(CloseEnough(float(C[i]), float(CReference[i]))) << "index=" << i; + } +} + +TEST(HalfGemmKleidiAINativeFp16, NoPackSingleThreadWithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + + MlasHalfGemmTest test; + test.TestNativeFp16WithoutOutputProcessor(43, 500, 401, 1, true); +} + +TEST(HalfGemmKleidiAINativeFp16, NoPackSingleThreadWithoutOutputProcessorBatch3) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + + MlasHalfGemmTest test; + test.TestNativeFp16WithoutOutputProcessor(43, 500, 401, 3, true); +} + +TEST(HalfGemmKleidiAINativeFp16, NoPackThreadedWithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + if (GetMlasThreadPool() == nullptr) { + GTEST_SKIP(); + } + + MlasHalfGemmTest test; + test.TestNativeFp16WithoutOutputProcessor(43, 500, 401, 1, true); +} + +TEST(HalfGemmKleidiAINativeFp16, NoPackThreadedWithoutOutputProcessorBatch3) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + if (GetMlasThreadPool() == nullptr) { + GTEST_SKIP(); + } + + MlasHalfGemmTest test; + test.TestNativeFp16WithoutOutputProcessor(43, 500, 401, 3, true); +} + +TEST(HalfGemmKleidiAINativeFp16, NoPackSingleThreadWithoutOutputProcessorVariedShapesAndBias) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + + RunNativeFp16WithoutOutputProcessorCases(kNativeFp16Cases, std::size(kNativeFp16Cases)); +} + +TEST(HalfGemmKleidiAINativeFp16, NoPackThreadedWithoutOutputProcessorVariedShapesAndBias) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + if (GetMlasThreadPool() == nullptr) { + GTEST_SKIP(); + } + + RunNativeFp16WithoutOutputProcessorCases(kNativeFp16Cases, std::size(kNativeFp16Cases)); +} + +TEST(HalfGemmKleidiAIPath, Fp32AConversionSingleThreadBatch3WithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + + MlasHalfGemmTest test; + test.TestKleidiAIWithoutOutputProcessor(43, 500, 401, 3, true); +} + +TEST(HalfGemmKleidiAIPath, Fp32AConversionSingleThreadVariedShapesAndBiasWithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + + RunKleidiAIWithoutOutputProcessorCases( + kKleidiAIPathNonPackedCases, std::size(kKleidiAIPathNonPackedCases)); +} + +TEST(HalfGemmKleidiAIPath, Fp32BConversionSingleThreadBatch3WithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + + MlasHalfGemmTest test; + test.TestKleidiAIWithoutOutputProcessor(43, 500, 401, 3, true); +} + +TEST(HalfGemmKleidiAIPath, Fp32BConversionSingleThreadVariedShapesAndBiasWithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + + RunKleidiAIWithoutOutputProcessorCases( + kKleidiAIPathNonPackedCases, std::size(kKleidiAIPathNonPackedCases)); +} + +TEST(HalfGemmKleidiAIPath, Fp32ABConversionSingleThreadBatch3WithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + + MlasHalfGemmTest test; + test.TestKleidiAIWithoutOutputProcessor(43, 500, 401, 3, true); +} + +TEST(HalfGemmKleidiAIPath, Fp32ABConversionSingleThreadVariedShapesAndBiasWithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + + RunKleidiAIWithoutOutputProcessorCases( + kKleidiAIPathNonPackedCases, std::size(kKleidiAIPathNonPackedCases)); +} + +TEST(HalfGemmKleidiAIPath, PackedBFp16SingleThreadWithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + if (MlasHalfGemmPackBSize(128, 128, false) == 0) { + GTEST_SKIP(); + } + + MlasHalfGemmTest test; + test.TestKleidiAIWithoutOutputProcessor(43, 500, 401, 1, true); +} + +TEST(HalfGemmKleidiAIPath, PackedBFp16SingleThreadVariedShapesAndBiasWithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + if (MlasHalfGemmPackBSize(128, 128, false) == 0) { + GTEST_SKIP(); + } + + RunKleidiAIWithoutOutputProcessorCases( + kKleidiAIPathPackedCases, std::size(kKleidiAIPathPackedCases)); +} + +TEST(HalfGemmKleidiAIPath, PackedBFloatSingleThreadWithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + if (MlasHalfGemmPackBSize(128, 128, true) == 0) { + GTEST_SKIP(); + } + + MlasHalfGemmTest test; + test.TestKleidiAIWithoutOutputProcessor(43, 500, 401, 1, true); +} + +TEST(HalfGemmKleidiAIPath, PackedBFloatSingleThreadVariedShapesAndBiasWithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + if (MlasHalfGemmPackBSize(128, 128, true) == 0) { + GTEST_SKIP(); + } + + RunKleidiAIWithoutOutputProcessorCases( + kKleidiAIPathPackedCases, std::size(kKleidiAIPathPackedCases)); +} + +#if defined(USE_KLEIDIAI) +// KleidiAI-specific packed-B uses a separate direct-consumption contract from +// generic halfgemm PackB. Unsupported combinations fail at the public API +// boundary because generic MLAS cannot consume this backend-native layout. +TEST(HalfGemmKleidiAIPath, KleidiAIPackedBWithBiasThrows) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + if (GetMlasPlatform().MlasHalfGemmBatchOverride == nullptr) { + GTEST_SKIP() << "KleidiAI halfgemm override unavailable"; + } + + constexpr size_t M = 5; + constexpr size_t N = 7; + constexpr size_t K = 9; + + std::vector A(M * K); + std::vector B(K * N); + std::vector Bias(N); + std::vector C(M * N, MLFp16(0.0f)); + + SmallFloatFill(A.data(), A.size()); + SmallFloatFill(B.data(), B.size()); + SmallFloatFill(Bias.data(), Bias.size()); + + const size_t packed_b_size = ArmKleidiAI::MlasHalfGemmKleidiAIPackBSize(CblasNoTrans, CblasNoTrans, N, K); + ASSERT_NE(packed_b_size, size_t{0}); + + std::vector packed_b(packed_b_size); + ASSERT_TRUE(ArmKleidiAI::MlasHalfGemmKleidiAIPackB( + CblasNoTrans, CblasNoTrans, N, K, reinterpret_cast(B.data()), N, packed_b.data())); + + MLAS_HALF_GEMM_DATA_PARAMS data{}; + data.A = A.data(); + data.B = packed_b.data(); + data.Bias = reinterpret_cast(Bias.data()); + data.C = reinterpret_cast(C.data()); + data.lda = K; + data.ldb = 0; + data.ldc = N; + data.BIsBackendNativePacked = true; + + ASSERT_FALSE(ArmKleidiAI::MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, nullptr)); + EXPECT_THROW(MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, nullptr), std::runtime_error); +} + +TEST(HalfGemmKleidiAIPath, KleidiAIPackedBWithOutputProcessorThrows) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + if (GetMlasPlatform().MlasHalfGemmBatchOverride == nullptr) { + GTEST_SKIP() << "KleidiAI halfgemm override unavailable"; + } + + constexpr size_t M = 5; + constexpr size_t N = 7; + constexpr size_t K = 9; + + std::vector A(M * K); + std::vector B(K * N); + std::vector C(M * N, MLFp16(0.0f)); + std::vector CFloat(M * N, 0.0f); + + SmallFloatFill(A.data(), A.size()); + SmallFloatFill(B.data(), B.size()); + + const size_t packed_b_size = ArmKleidiAI::MlasHalfGemmKleidiAIPackBSize(CblasNoTrans, CblasNoTrans, N, K); + ASSERT_NE(packed_b_size, size_t{0}); + + std::vector packed_b(packed_b_size); + ASSERT_TRUE(ArmKleidiAI::MlasHalfGemmKleidiAIPackB( + CblasNoTrans, CblasNoTrans, N, K, reinterpret_cast(B.data()), N, packed_b.data())); + + MLAS_ACTIVATION act; + act.ActivationKind = MlasIdentityActivation; + MLAS_HALF_GEMM_2FLOAT_PROCESSOR output_processor(act, CFloat.data(), N); + + MLAS_HALF_GEMM_DATA_PARAMS data{}; + data.A = A.data(); + data.B = packed_b.data(); + data.C = reinterpret_cast(C.data()); + data.lda = K; + data.ldb = 0; + data.ldc = N; + data.BIsBackendNativePacked = true; + data.OutputProcessor = &output_processor; + + ASSERT_FALSE(ArmKleidiAI::MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, nullptr)); + EXPECT_THROW(MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, nullptr), std::runtime_error); +} + +TEST(HalfGemmKleidiAIPath, ZeroKFallsBack) { + if (GetMlasPlatform().MlasHalfGemmBatchOverride == nullptr) { + GTEST_SKIP() << "KleidiAI halfgemm override unavailable"; + } + + constexpr size_t M = 5; + constexpr size_t N = 7; + constexpr size_t K = 0; + + std::vector Bias(N); + std::vector C(M * N, MLFp16(1.0f)); + + SmallFloatFill(Bias.data(), Bias.size()); + + MLAS_HALF_GEMM_DATA_PARAMS data{}; + data.Bias = reinterpret_cast(Bias.data()); + data.C = reinterpret_cast(C.data()); + data.ldc = N; + + const bool handled = ArmKleidiAI::MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, nullptr); + ASSERT_FALSE(handled); +} + +TEST(HalfGemmKleidiAIPath, KleidiAIPackedBSizeRejectsUnsupportedTranspose) { + if (GetMlasPlatform().MlasHalfGemmBatchOverride == nullptr) { + GTEST_SKIP() << "KleidiAI halfgemm override unavailable"; + } + + constexpr size_t N = 7; + constexpr size_t K = 9; + + EXPECT_EQ(ArmKleidiAI::MlasHalfGemmKleidiAIPackBSize(CblasTrans, CblasNoTrans, N, K), size_t{0}); + EXPECT_EQ(ArmKleidiAI::MlasHalfGemmKleidiAIPackBSize(CblasNoTrans, CblasTrans, N, K), size_t{0}); +} + +TEST(HalfGemmKleidiAIPath, KleidiAIPackedBRejectsUnsupportedTranspose) { + if (GetMlasPlatform().MlasHalfGemmBatchOverride == nullptr) { + GTEST_SKIP() << "KleidiAI halfgemm override unavailable"; + } + + constexpr size_t N = 7; + constexpr size_t K = 9; + + std::vector B(K * N); + SmallFloatFill(B.data(), B.size()); + + const size_t packed_b_size = ArmKleidiAI::MlasHalfGemmKleidiAIPackBSize(CblasNoTrans, CblasNoTrans, N, K); + ASSERT_NE(packed_b_size, size_t{0}); + + std::vector packed_b(packed_b_size); + EXPECT_FALSE(ArmKleidiAI::MlasHalfGemmKleidiAIPackB( + CblasTrans, CblasNoTrans, N, K, reinterpret_cast(B.data()), N, packed_b.data())); + EXPECT_FALSE(ArmKleidiAI::MlasHalfGemmKleidiAIPackB( + CblasNoTrans, CblasTrans, N, K, reinterpret_cast(B.data()), N, packed_b.data())); +} +#endif diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.h b/onnxruntime/test/mlas/unittest/test_halfgemm.h index 4db5c2bebca40..c68697aa74b3c 100644 --- a/onnxruntime/test/mlas/unittest/test_halfgemm.h +++ b/onnxruntime/test/mlas/unittest/test_halfgemm.h @@ -17,6 +17,9 @@ Module Name: #pragma once #include "test_fp16.h" +#include "core/mlas/lib/mlasi.h" +#include +#include /** * @brief Test class for half precision GEMM @@ -26,6 +29,21 @@ Module Name: template class MlasHalfGemmTest : public MlasTestBase { private: + // Native FP16 is validated against the FP32 reference path + // rather than the existing stepwise-FP16 oracle, so these backend-specific + // tests use a separate tolerance. + static bool CloseEnoughNativeFp16(float got, float ref) { + constexpr float abs_tol = 0.03125f; + constexpr float rel_tol = 0.005f; + + const float diff = std::fabs(got - ref); + if (diff <= abs_tol) { + return true; + } + + return diff <= rel_tol * std::max(std::fabs(got), std::fabs(ref)); + } + MatrixGuardBuffer BufferBPacked; MatrixGuardBuffer BufferA; MatrixGuardBuffer BufferB; @@ -60,7 +78,9 @@ class MlasHalfGemmTest : public MlasTestBase { const MLFp16* Bias, MLFp16* C, size_t ldc, - float* Cfloat) { + float* Cfloat, + bool use_output_processor = true, + bool enforce_kleidiai_override = false) { MLAS_ACTIVATION act; act.ActivationKind = MlasIdentityActivation; std::vector Converters; @@ -84,17 +104,29 @@ class MlasHalfGemmTest : public MlasTestBase { ASSERT_EQ(BatchSize, size_t(1)) << "Packing B not supported in batching yet!"; params.B = PackB(N, K, B, ldb); params.ldb = 0; + params.BIsPacked = true; } else { params.B = B + (K * N * i); params.ldb = ldb; } params.AIsfp32 = std::is_same::value; params.BIsfp32 = std::is_same::value; - Converters.emplace_back(act, Cfloat + (M * N * i), N); - params.OutputProcessor = &(Converters[i]); + if (use_output_processor) { + Converters.emplace_back(act, Cfloat + (M * N * i), N); + params.OutputProcessor = &(Converters[i]); + } else { + params.OutputProcessor = nullptr; + } } - MlasHalfGemmBatch(M, N, K, BatchSize, GemmParameters.data(), threadpool_); + if (enforce_kleidiai_override) { + ASSERT_NE(GetMlasPlatform().MlasHalfGemmBatchOverride, nullptr); + const bool handled = GetMlasPlatform().MlasHalfGemmBatchOverride( + M, N, K, BatchSize, GemmParameters.data(), threadpool_, nullptr); + ASSERT_TRUE(handled); + } else { + MlasHalfGemmBatch(M, N, K, BatchSize, GemmParameters.data(), threadpool_, nullptr); + } } void ReferenceQgemm(size_t M, @@ -153,6 +185,60 @@ class MlasHalfGemmTest : public MlasTestBase { } } + void ReferenceMlasGemmFp32(size_t M, + size_t N, + size_t K, + size_t BatchSize, + const AType* A, + const BType* B, + const MLFp16* Bias, + float* C) { + MatrixGuardBuffer buffer_a_fp32{}; + MatrixGuardBuffer buffer_b_fp32{}; + MatrixGuardBuffer buffer_c_fp32{}; + + float* AFloat = buffer_a_fp32.GetBuffer(M * K * BatchSize); + float* BFloat = buffer_b_fp32.GetBuffer(K * N * BatchSize); + float* CFloat = buffer_c_fp32.GetBuffer(M * N * BatchSize, true); + + for (size_t i = 0; i < M * K * BatchSize; ++i) { + AFloat[i] = float(A[i]); + } + + for (size_t i = 0; i < K * N * BatchSize; ++i) { + BFloat[i] = float(B[i]); + } + + for (size_t batch = 0; batch < BatchSize; ++batch) { + MlasGemm( + CblasNoTrans, CblasNoTrans, M, N, K, + 1.0f, + AFloat + batch * (M * K), K, + BFloat + batch * (K * N), N, + 0.0f, + CFloat + batch * (M * N), N, + threadpool_, + nullptr); + } + + for (size_t batch = 0; batch < BatchSize; batch++) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + const size_t idx = (M * N * batch) + (m * N) + n; + float sum = CFloat[idx]; + if (Bias != nullptr) { + sum += float(Bias[n]); + } + C[idx] = float(MLFp16(sum)); + } + } + + if (Bias) { + Bias += N; + } + } + } + public: MlasHalfGemmTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {} @@ -200,6 +286,56 @@ class MlasHalfGemmTest : public MlasTestBase { } } + void TestKleidiAIWithoutOutputProcessor(size_t M, size_t N, size_t K, size_t BatchSize, bool withBias) { + if (GetMlasPlatform().MlasHalfGemmBatchOverride == nullptr) { + GTEST_SKIP() << "KleidiAI halfgemm override unavailable"; + } + + const AType* A = BufferA.GetFilledBuffer(K * M * BatchSize + 16, SmallFloatFill); + const BType* B = BufferB.GetFilledBuffer(N * K * BatchSize + 16, SmallFloatFill); + + const MLFp16* Bias = nullptr; + if (withBias) { + Bias = BufferBias.GetFilledBuffer(N * BatchSize + 16, SmallFloatFill); + } + + MLFp16* C = BufferC.GetFilledBuffer(N * M * BatchSize, SmallFloatFill); + float* Cfloat = BufferFloatC.GetBuffer(N * M * BatchSize, true); + float* CReference = BufferCReference.GetFilledBuffer( + N * M * BatchSize, + [](float* start, size_t size) { + std::fill_n(start, size, -1.0f); + }); + MatrixGuardBuffer buffer_fp32_reference{}; + float* CFp32Reference = buffer_fp32_reference.GetFilledBuffer( + N * M * BatchSize, + [](float* start, size_t size) { + std::fill_n(start, size, -1.0f); + }); + + this->CallGemm(M, N, K, BatchSize, A, K, B, N, Bias, C, N, Cfloat, false, true); + ReferenceQgemm(M, N, K, BatchSize, A, B, Bias, CReference); + ReferenceMlasGemmFp32(M, N, K, BatchSize, A, B, Bias, CFp32Reference); + + for (size_t batch = 0, f = 0; batch < BatchSize; batch++) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++, f++) { + ASSERT_TRUE(CloseEnoughNativeFp16(float(C[f]), CFp32Reference[f])) << "@[" << batch << "x" << m << "x" << n << "], " + << "Batch=" << BatchSize << "M=" << M << ", N=" << N << ", K=" << K + << " got=" << float(C[f]) + << " fp32=" << CFp32Reference[f] + << " stepwise=" << CReference[f]; + } + } + } + } + + void TestNativeFp16WithoutOutputProcessor(size_t M, size_t N, size_t K, size_t BatchSize, bool withBias) { + static_assert(std::is_same_v); + static_assert(std::is_same_v); + TestKleidiAIWithoutOutputProcessor(M, N, K, BatchSize, withBias); + } + private: public: static const char* GetTestSuiteName() { diff --git a/onnxruntime/test/mlas/unittest/test_util.h b/onnxruntime/test/mlas/unittest/test_util.h index a000e353f370d..8c31ad943e319 100644 --- a/onnxruntime/test/mlas/unittest/test_util.h +++ b/onnxruntime/test/mlas/unittest/test_util.h @@ -44,6 +44,7 @@ class MatrixGuardBuffer { _BaseBuffer = nullptr; _BaseBufferSize = 0; _ElementsAllocated = 0; + _GuardAddress = nullptr; } ~MatrixGuardBuffer(void) { @@ -150,6 +151,7 @@ class MatrixGuardBuffer { } _ElementsAllocated = 0; + _GuardAddress = nullptr; } private: diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index 9effdd7e5fb6e..05f9ac2b18224 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -118,6 +118,52 @@ TEST(GemmOpTest, GemmNoTrans_f16) { ConvertFloatToMLFloat16(A.data(), f_A.data(), 8); ConvertFloatToMLFloat16(B.data(), f_B.data(), 12); + { + // Missing C uses effective beta == 0. + std::vector f_Y(6); + std::vector Y{19.3f, -1.4f, -26.9f, + -19.3f, 1.4f, 26.9f}; + ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); + + OpTester test("Gemm", 13); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 0.0f); + test.AddInput("A", {2, 4}, f_A); + test.AddInput("B", {4, 3}, f_B, true); + test.AddOutput("Y", {2, 3}, f_Y); + test.SetOutputTolerance(0.005f); + test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported + .Config(run_with_tunable_op) + .RunWithConfig(); + } + { + // beta == 0 ignores C, even when C has the full output shape. + std::vector f_Y(6); + std::vector Y{19.3f, -1.4f, -26.9f, + -19.3f, 1.4f, 26.9f}; + ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); + + std::vector f_C(6); + ConvertFloatToMLFloat16(C.data(), f_C.data(), 6); + + OpTester test("Gemm", 13); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 0.0f); + test.AddInput("A", {2, 4}, f_A); + test.AddInput("B", {4, 3}, f_B); + test.AddInput("C", {2, 3}, f_C); + test.AddOutput("Y", {2, 3}, f_Y); + test.SetOutputTolerance(0.005f); + test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported + .Config(run_with_tunable_op) + .RunWithConfig(); + } { // bias has same shape as output std::vector f_Y(6); diff --git a/onnxruntime/test/providers/cpu/math/matmul_test.cc b/onnxruntime/test/providers/cpu/math/matmul_test.cc index f624ecf57d05e..e4d80b9212ba9 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_test.cc @@ -3,6 +3,7 @@ #include "gtest/gtest.h" +#include "core/mlas/inc/mlas.h" #include "test/providers/provider_test_utils.h" #include "test/common/dnnl_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" @@ -687,6 +688,68 @@ TEST(MathOpTest, MatMulBatchedSplitK) { .RunWithConfig(); } +TEST(MathOpTest, MatMulFloat16NativePrepackedWeightsAreNotShared) { + if (MlasHalfGemmNativePackBSize(CblasNoTrans, CblasNoTrans, 3, 4) == 0) { + GTEST_SKIP() << "Native fp16 MatMul prepack unavailable"; + } + + std::vector a_values{1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}; + std::vector b_values(12, 1.0f); + std::vector y_values{10.0f, 10.0f, 10.0f, + -10.0f, -10.0f, -10.0f}; + + std::vector a_fp16(8); + std::vector b_fp16(12); + std::vector y_fp16(6); + ConvertFloatToMLFloat16(a_values.data(), a_fp16.data(), a_fp16.size()); + ConvertFloatToMLFloat16(b_values.data(), b_fp16.data(), b_fp16.size()); + ConvertFloatToMLFloat16(y_values.data(), y_fp16.data(), y_fp16.size()); + + OpTester test("MatMul", 14); + test.AddInput("A", {2, 4}, a_fp16); + test.AddInput("B", {4, 3}, b_fp16, true); + test.AddOutput("Y", {2, 3}, y_fp16); + + OrtValue b; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({4, 3}), + b_fp16.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), b); + + SessionOptions so; + ASSERT_EQ(so.AddInitializer("B", &b), Status::OK()); + ASSERT_EQ(so.config_options.AddConfigEntry("session.enable_cpu_fp16", "1"), Status::OK()); + ASSERT_EQ(so.config_options.AddConfigEntry("session.cpu_fp16_use_fp32_fallback_heuristic", "0"), Status::OK()); + + test.EnableSharingOfPrePackedWeightsAcrossSessions(); + + auto cpu_ep = []() -> std::vector> { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + return execution_providers; + }; + + size_t first_session_prepacked_weights = 0; + size_t shared_prepacked_weights = 0; + + test.Config(so) + .Config(run_with_tunable_op) + .ConfigEps(cpu_ep()) + .RunWithConfig(&first_session_prepacked_weights, &shared_prepacked_weights); + + ASSERT_EQ(shared_prepacked_weights, static_cast(0)); + ASSERT_GT(first_session_prepacked_weights, static_cast(0)); + ASSERT_EQ(test.GetNumPrePackedWeightsShared(), static_cast(0)); + + size_t second_session_prepacked_weights = 0; + test.Config(so) + .Config(run_with_tunable_op) + .ConfigEps(cpu_ep()) + .RunWithConfig(&second_session_prepacked_weights, &shared_prepacked_weights); + + ASSERT_EQ(second_session_prepacked_weights, first_session_prepacked_weights); + ASSERT_EQ(shared_prepacked_weights, static_cast(0)); +} + #endif } // namespace test diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 843d925ed6638..eb4ed757d1a94 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -9,7 +9,9 @@ #include "test/common/cuda_op_test_utils.h" #include "test/common/random_generator.h" #include "test/providers/provider_test_utils.h" +#include "test/util/include/scoped_env_vars.h" #include "default_providers.h" +#include "core/session/onnxruntime_session_options_config_keys.h" using namespace std; namespace onnxruntime { @@ -47,7 +49,8 @@ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", int opset = 11, - float rel_error = 0.002f) { + float rel_error = 0.002f, + bool disable_kleidiai = false) { std::unique_ptr tester; if (!attributes.activation.empty()) { tester = std::make_unique("NhwcFusedConv", 1, onnxruntime::kMSDomain); @@ -89,6 +92,12 @@ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, tester->AddOutput("Y", expected_output_shape, expected_output, /*no sort*/ false, rel_error, 0.0f); + if (disable_kleidiai) { + SessionOptions session_options; + ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsMlasDisableKleidiAi, "1")); + tester->Config(session_options); + } + std::unordered_set excluded_providers(attributes.excluded_providers); // Disable TensorRT because weight as input is not supported excluded_providers.insert(kTensorrtExecutionProvider); @@ -369,6 +378,150 @@ TEST(ConvFp16Test, Conv2D_1) { TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } +TEST(ConvFp16Test, Conv2D_KleidiAiImatmulEligibleNoBias) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{3, 3}, // kernel_shape + vector{1, 1, 1, 1}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = { + MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f), MLFloat16(4.0f), + MLFloat16(5.0f), MLFloat16(6.0f), MLFloat16(7.0f), MLFloat16(8.0f), + MLFloat16(9.0f), MLFloat16(10.0f), MLFloat16(11.0f), MLFloat16(12.0f), + MLFloat16(13.0f), MLFloat16(14.0f), MLFloat16(15.0f), MLFloat16(16.0f)}; + vector X_shape = {1, 1, 4, 4}; + vector W = { + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f)}; + vector W_shape = {2, 1, 3, 3}; + vector Y_shape = {1, 2, 4, 4}; + auto expected_vals = { + MLFloat16(14.0f), MLFloat16(24.0f), MLFloat16(30.0f), MLFloat16(22.0f), + MLFloat16(33.0f), MLFloat16(54.0f), MLFloat16(63.0f), MLFloat16(45.0f), + MLFloat16(57.0f), MLFloat16(90.0f), MLFloat16(99.0f), MLFloat16(69.0f), + MLFloat16(46.0f), MLFloat16(72.0f), MLFloat16(78.0f), MLFloat16(54.0f), + MLFloat16(7.0f), MLFloat16(12.0f), MLFloat16(15.0f), MLFloat16(11.0f), + MLFloat16(16.5f), MLFloat16(27.0f), MLFloat16(31.5f), MLFloat16(22.5f), + MLFloat16(28.5f), MLFloat16(45.0f), MLFloat16(49.5f), MLFloat16(34.5f), + MLFloat16(23.0f), MLFloat16(36.0f), MLFloat16(39.0f), MLFloat16(27.0f)}; + + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +} + +TEST(ConvFp16Test, Conv2D_KleidiAiImatmulEligibleBiasAndDisabledFallback) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{3, 3}, // kernel_shape + vector{1, 1, 1, 1}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = { + MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f), MLFloat16(4.0f), + MLFloat16(5.0f), MLFloat16(6.0f), MLFloat16(7.0f), MLFloat16(8.0f), + MLFloat16(9.0f), MLFloat16(10.0f), MLFloat16(11.0f), MLFloat16(12.0f), + MLFloat16(13.0f), MLFloat16(14.0f), MLFloat16(15.0f), MLFloat16(16.0f)}; + vector X_shape = {1, 1, 4, 4}; + vector W = { + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f)}; + vector W_shape = {2, 1, 3, 3}; + vector B = {MLFloat16(1.0f), MLFloat16(-2.0f)}; + vector B_shape = {2}; + vector Y_shape = {1, 2, 4, 4}; + auto expected_vals = { + MLFloat16(15.0f), MLFloat16(25.0f), MLFloat16(31.0f), MLFloat16(23.0f), + MLFloat16(34.0f), MLFloat16(55.0f), MLFloat16(64.0f), MLFloat16(46.0f), + MLFloat16(58.0f), MLFloat16(91.0f), MLFloat16(100.0f), MLFloat16(70.0f), + MLFloat16(47.0f), MLFloat16(73.0f), MLFloat16(79.0f), MLFloat16(55.0f), + MLFloat16(5.0f), MLFloat16(10.0f), MLFloat16(13.0f), MLFloat16(9.0f), + MLFloat16(14.5f), MLFloat16(25.0f), MLFloat16(29.5f), MLFloat16(20.5f), + MLFloat16(26.5f), MLFloat16(43.0f), MLFloat16(47.5f), MLFloat16(32.5f), + MLFloat16(21.0f), MLFloat16(34.0f), MLFloat16(37.0f), MLFloat16(25.0f)}; + + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + + constexpr bool weight_is_initializer = true; + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, weight_is_initializer); + + constexpr bool disable_kleidiai = true; + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, weight_is_initializer, + OpTester::ExpectResult::kExpectSuccess, "", 11, 0.002f, disable_kleidiai); +} + +TEST(ConvFp16Test, NhwcFusedConv2D_KleidiAiImatmulEligibleBiasAndDisabledFallback) { + auto run_test = [](bool disable_kleidiai) { + OpTester test("NhwcFusedConv", 1, onnxruntime::kMSDomain); + test.AddAttribute("group", static_cast(1)); + test.AddAttribute("kernel_shape", vector{3, 3}); + test.AddAttribute("pads", vector{1, 1, 1, 1}); + test.AddAttribute("strides", vector{1, 1}); + test.AddAttribute("dilations", vector{1, 1}); + + vector X = { + MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f), MLFloat16(4.0f), + MLFloat16(5.0f), MLFloat16(6.0f), MLFloat16(7.0f), MLFloat16(8.0f), + MLFloat16(9.0f), MLFloat16(10.0f), MLFloat16(11.0f), MLFloat16(12.0f), + MLFloat16(13.0f), MLFloat16(14.0f), MLFloat16(15.0f), MLFloat16(16.0f)}; + vector X_shape = {1, 4, 4, 1}; + vector W = { + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f)}; + vector W_shape = {2, 1, 3, 3}; + vector B = {MLFloat16(1.0f), MLFloat16(-2.0f)}; + vector B_shape = {2}; + vector Y_shape = {1, 4, 4, 2}; + auto expected_vals = { + MLFloat16(15.0f), MLFloat16(5.0f), MLFloat16(25.0f), MLFloat16(10.0f), + MLFloat16(31.0f), MLFloat16(13.0f), MLFloat16(23.0f), MLFloat16(9.0f), + MLFloat16(34.0f), MLFloat16(14.5f), MLFloat16(55.0f), MLFloat16(25.0f), + MLFloat16(64.0f), MLFloat16(29.5f), MLFloat16(46.0f), MLFloat16(20.5f), + MLFloat16(58.0f), MLFloat16(26.5f), MLFloat16(91.0f), MLFloat16(43.0f), + MLFloat16(100.0f), MLFloat16(47.5f), MLFloat16(70.0f), MLFloat16(32.5f), + MLFloat16(47.0f), MLFloat16(21.0f), MLFloat16(73.0f), MLFloat16(34.0f), + MLFloat16(79.0f), MLFloat16(37.0f), MLFloat16(55.0f), MLFloat16(25.0f)}; + + test.AddInput("X", X_shape, X); + test.AddInput("W", W_shape, W, true); + test.AddInput("B", B_shape, B, true); + test.AddOutput("Y", Y_shape, expected_vals, /*no sort*/ false, 0.002f, 0.0f); + + if (disable_kleidiai) { + SessionOptions session_options; + ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsMlasDisableKleidiAi, "1")); + test.Config(session_options); + } + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.ConfigEps(std::move(execution_providers)).RunWithConfig(); + }; + + run_test(false); + run_test(true); +} + TEST(ConvFp16Test, Conv2D_2) { ConvOpAndTestAttributes attrs = { "", // auto_pad @@ -1496,6 +1649,165 @@ TEST(ConvFp16Test, SharedPrepackedWeights) { } } +TEST(ConvFp16Test, SharedPrepackedWeights_HalfConvEligible_NoBias) { + OpTester test("Conv", 11); + test.AddAttribute("group", static_cast(1)); + test.AddAttribute("kernel_shape", vector{3, 3}); + test.AddAttribute("pads", vector{1, 1, 1, 1}); + test.AddAttribute("strides", vector{1, 1}); + test.AddAttribute("dilations", vector{1, 1}); + + vector X = {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f), MLFloat16(4.0f), + MLFloat16(5.0f), MLFloat16(6.0f), MLFloat16(7.0f), MLFloat16(8.0f), + MLFloat16(9.0f), MLFloat16(10.0f), MLFloat16(11.0f), MLFloat16(12.0f), + MLFloat16(13.0f), MLFloat16(14.0f), MLFloat16(15.0f), MLFloat16(16.0f)}; + vector X_shape = {1, 1, 4, 4}; + vector W = {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(0.5f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f)}; + vector W_shape = {2, 1, 3, 3}; + vector Y_shape = {1, 2, 4, 4}; + auto expected_vals = { + MLFloat16(14.0f), MLFloat16(24.0f), MLFloat16(30.0f), MLFloat16(22.0f), MLFloat16(33.0f), MLFloat16(54.0f), + MLFloat16(63.0f), MLFloat16(45.0f), MLFloat16(57.0f), MLFloat16(90.0f), MLFloat16(99.0f), MLFloat16(69.0f), + MLFloat16(46.0f), MLFloat16(72.0f), MLFloat16(78.0f), MLFloat16(54.0f), MLFloat16(7.0f), MLFloat16(12.0f), + MLFloat16(15.0f), MLFloat16(11.0f), MLFloat16(16.5f), MLFloat16(27.0f), MLFloat16(31.5f), MLFloat16(22.5f), + MLFloat16(28.5f), MLFloat16(45.0f), MLFloat16(49.5f), MLFloat16(34.5f), MLFloat16(23.0f), MLFloat16(36.0f), + MLFloat16(39.0f), MLFloat16(27.0f)}; + + test.AddInput("X", X_shape, X); + test.AddInput("W", W_shape, W, true); + test.AddOutput("Y", Y_shape, expected_vals, /*no sort*/ false, 0.002f, 0.0f); + + OrtValue w; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape(W_shape), + W.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), w); + + SessionOptions so; + ASSERT_EQ(so.AddInitializer("W", &w), Status::OK()); + + test.EnableSharingOfPrePackedWeightsAcrossSessions(); + + auto cpu_ep = []() -> std::vector> { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + return execution_providers; + }; + + size_t number_of_pre_packed_weights_counter_session_1 = 0; + size_t number_of_shared_pre_packed_weights_counter = 0; + + { + test.Config(so) + .Config(run_with_tunable_op) + .ConfigEps(cpu_ep()) + .RunWithConfig(&number_of_pre_packed_weights_counter_session_1, &number_of_shared_pre_packed_weights_counter); + ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast(0)); + } + + const auto number_of_elements_in_shared_prepacked_buffers_container = test.GetNumPrePackedWeightsShared(); + + if (number_of_pre_packed_weights_counter_session_1 == 0) { + GTEST_SKIP() << "No pre-packed weights were produced."; + } + + ASSERT_EQ(number_of_elements_in_shared_prepacked_buffers_container, static_cast(1)); + + { + size_t number_of_pre_packed_weights_counter_session_2 = 0; + test.Config(so) + .Config(run_with_tunable_op) + .ConfigEps(cpu_ep()) + .RunWithConfig(&number_of_pre_packed_weights_counter_session_2, &number_of_shared_pre_packed_weights_counter); + + ASSERT_GE(number_of_pre_packed_weights_counter_session_1, number_of_shared_pre_packed_weights_counter); + ASSERT_GE(number_of_pre_packed_weights_counter_session_2, number_of_shared_pre_packed_weights_counter); + ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast(1)); + } +} + +TEST(ConvFp16Test, SharedPrepackedWeights_HalfConvEligible_BiasNotShared) { + OpTester test("Conv", 11); + test.AddAttribute("group", static_cast(1)); + test.AddAttribute("kernel_shape", vector{3, 3}); + test.AddAttribute("pads", vector{1, 1, 1, 1}); + test.AddAttribute("strides", vector{1, 1}); + test.AddAttribute("dilations", vector{1, 1}); + + vector X = {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f), MLFloat16(4.0f), + MLFloat16(5.0f), MLFloat16(6.0f), MLFloat16(7.0f), MLFloat16(8.0f), + MLFloat16(9.0f), MLFloat16(10.0f), MLFloat16(11.0f), MLFloat16(12.0f), + MLFloat16(13.0f), MLFloat16(14.0f), MLFloat16(15.0f), MLFloat16(16.0f)}; + vector X_shape = {1, 1, 4, 4}; + vector W = {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(0.5f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f)}; + vector W_shape = {2, 1, 3, 3}; + vector B = {MLFloat16(1.0f), MLFloat16(-2.0f)}; + vector B_shape = {2}; + vector Y_shape = {1, 2, 4, 4}; + auto expected_vals = { + MLFloat16(15.0f), MLFloat16(25.0f), MLFloat16(31.0f), MLFloat16(23.0f), MLFloat16(34.0f), MLFloat16(55.0f), + MLFloat16(64.0f), MLFloat16(46.0f), MLFloat16(58.0f), MLFloat16(91.0f), MLFloat16(100.0f), MLFloat16(70.0f), + MLFloat16(47.0f), MLFloat16(73.0f), MLFloat16(79.0f), MLFloat16(55.0f), MLFloat16(5.0f), MLFloat16(10.0f), + MLFloat16(13.0f), MLFloat16(9.0f), MLFloat16(14.5f), MLFloat16(25.0f), MLFloat16(29.5f), MLFloat16(20.5f), + MLFloat16(26.5f), MLFloat16(43.0f), MLFloat16(47.5f), MLFloat16(32.5f), MLFloat16(21.0f), MLFloat16(34.0f), + MLFloat16(37.0f), MLFloat16(25.0f)}; + + test.AddInput("X", X_shape, X); + test.AddInput("W", W_shape, W, true); + test.AddInput("B", B_shape, B, true); + test.AddOutput("Y", Y_shape, expected_vals, /*no sort*/ false, 0.002f, 0.0f); + + OrtValue w; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape(W_shape), + W.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), w); + + SessionOptions so; + ASSERT_EQ(so.AddInitializer("W", &w), Status::OK()); + + test.EnableSharingOfPrePackedWeightsAcrossSessions(); + + auto cpu_ep = []() -> std::vector> { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + return execution_providers; + }; + + size_t number_of_pre_packed_weights_counter_session_1 = 0; + size_t number_of_shared_pre_packed_weights_counter = 0; + + { + test.Config(so) + .Config(run_with_tunable_op) + .ConfigEps(cpu_ep()) + .RunWithConfig(&number_of_pre_packed_weights_counter_session_1, &number_of_shared_pre_packed_weights_counter); + ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast(0)); + } + + const auto number_of_elements_in_shared_prepacked_buffers_container = test.GetNumPrePackedWeightsShared(); + + if (number_of_pre_packed_weights_counter_session_1 == 0) { + GTEST_SKIP() << "No pre-packed weights were produced."; + } + + ASSERT_EQ(number_of_elements_in_shared_prepacked_buffers_container, static_cast(1)); + + { + size_t number_of_pre_packed_weights_counter_session_2 = 0; + test.Config(so) + .Config(run_with_tunable_op) + .ConfigEps(cpu_ep()) + .RunWithConfig(&number_of_pre_packed_weights_counter_session_2, &number_of_shared_pre_packed_weights_counter); + + ASSERT_GE(number_of_pre_packed_weights_counter_session_1, number_of_shared_pre_packed_weights_counter); + ASSERT_GE(number_of_pre_packed_weights_counter_session_2, number_of_shared_pre_packed_weights_counter); + ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast(1)); + } +} + #endif } // namespace test From 6b7463e14a0c5c18e59ce051058a5d2b0b85a5fa Mon Sep 17 00:00:00 2001 From: Cathal Lawlor Date: Wed, 13 May 2026 17:05:55 +0100 Subject: [PATCH 2/8] Add CPU FP16 fallback guards Ensure CPU-assigned fp16 nodes without a matching CPU fp16 kernel are routed through the existing fp32 cast fallback when a float CPU kernel exists. This preserves main-compatible behavior for fused nodes such as BiasGelu. Also split the graph-output cast coverage so the fallback behavior is easier to review. Signed-off-by: Cathal Lawlor --- .../mlas/lib/kleidiai/halfconv_kleidiai.cpp | 7 +- .../core/optimizer/insert_cast_transformer.cc | 15 ++++ .../framework/insert_cast_transformer_test.cc | 74 ++++++++++++++++--- 3 files changed, 83 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/mlas/lib/kleidiai/halfconv_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/halfconv_kleidiai.cpp index 2a0efe3aab32f..f1261b399e380 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/halfconv_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/halfconv_kleidiai.cpp @@ -65,8 +65,11 @@ TryComputeConvOutSize( size_t total_padding = 0; size_t padded_input = 0; if (mul_overflow_size_t_builtin(padding, 2, &total_padding) || - !TryAddSize(input, total_padding, padded_input) || - padded_input < kernel) { + !TryAddSize(input, total_padding, padded_input)) { + return false; + } + + if (padded_input < kernel) { return true; } diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 4ccc3caaa08e1..7eec669722728 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -421,6 +421,17 @@ static bool NodeNeedsInputCastToFp32(const onnxruntime::Node& node, return false; } +static bool CpuAssignedFp16NodeNeedsFallbackCast( + const onnxruntime::Node& node, + const InlinedVector>& cpu_kernel_registries, + const logging::Logger& logger) { + return node.GetExecutionProviderType() == kCpuExecutionProvider && + !node.ContainsSubgraph() && + HasFp16IO(node) && + !HasCpuKernelForCurrentTypes(node, cpu_kernel_registries, logger) && + HasCpuFloat32FallbackKernel(node, cpu_kernel_registries, logger); +} + // Detect an isolated node that is able to process fp16 data but is between other nodes that have fp16 inputs // but will need a Cast inserted to enable them to run. // @@ -889,6 +900,10 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie if (!node) return Status(ONNXRUNTIME, INVALID_ARGUMENT); + if (CpuAssignedFp16NodeNeedsFallbackCast(*node, cpu_kernel_registries_, logger)) { + node->SetExecutionProviderType(""); + } + if (!enable_cpu_fp16_ && node->GetExecutionProviderType() == kCpuExecutionProvider && IsCpuFp16OptInPolicyOp(*node) && diff --git a/onnxruntime/test/framework/insert_cast_transformer_test.cc b/onnxruntime/test/framework/insert_cast_transformer_test.cc index 1f9e60c494e65..85331f8646a9d 100644 --- a/onnxruntime/test/framework/insert_cast_transformer_test.cc +++ b/onnxruntime/test/framework/insert_cast_transformer_test.cc @@ -649,6 +649,10 @@ TEST(TransformerTest, CpuFp16MatMulHeuristicForcesBertLikeShapeToFp32) { } TEST(TransformerTest, CpuFp16MatMulHeuristicKeepsLargeGemvNativeFp16) { + if (!CanRunNativeCpuFp16GemmRuntime()) { + GTEST_SKIP() << "Native CPU fp16 Gemm/MatMul runtime support is unavailable."; + } + auto model = MakeCpuFp16MatMulModelWithShapes("cpu_fp16_matmul_heuristic_large_gemv", {1, 4096}, {4096, 4096}, {1, 4096}, true); auto& graph = model->MainGraph(); @@ -668,6 +672,10 @@ TEST(TransformerTest, CpuFp16MatMulHeuristicKeepsLargeGemvNativeFp16) { } TEST(TransformerTest, CpuFp16MatMulHeuristicKeepsBatchedLargeGemvNativeFp16) { + if (!CanRunNativeCpuFp16GemmRuntime()) { + GTEST_SKIP() << "Native CPU fp16 Gemm/MatMul runtime support is unavailable."; + } + auto model = MakeCpuFp16MatMulModelWithShapes("cpu_fp16_matmul_heuristic_batched_large_gemv", {8, 1, 4096}, {8, 4096, 4096}, {8, 1, 4096}, true); auto& graph = model->MainGraph(); @@ -850,6 +858,28 @@ TEST(TransformerTest, CpuFp16UnsupportedOpStillGetsCastsWhenEnabled) { EXPECT_EQ(abs_node->GetExecutionProviderType(), kCpuExecutionProvider); } +TEST(TransformerTest, CpuFp16UnsupportedCpuAssignedOpStillGetsCasts) { + auto model = MakeCpuFp16Model("cpu_fp16_abs_cpu_assigned", "Abs", true); + auto& graph = model->MainGraph(); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + false, true); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.at("Cast"), 2); + + const Node* abs_node = FindNodeByOpType(graph, "Abs"); + ASSERT_NE(abs_node, nullptr); + EXPECT_EQ(abs_node->GetExecutionProviderType(), kCpuExecutionProvider); + ASSERT_EQ(abs_node->InputDefs().size(), 1U); + ASSERT_EQ(abs_node->OutputDefs().size(), 1U); + EXPECT_TRUE(IsNodeArgType(*abs_node->InputDefs()[0], DataTypeImpl::GetTensorType())); + EXPECT_TRUE(IsNodeArgType(*abs_node->OutputDefs()[0], DataTypeImpl::GetTensorType())); +} + TEST(TransformerTest, CpuFp16SupportedCpuOpKeepsFp16WhenEnabled) { if (!MlasFp16AccelerationSupported()) { GTEST_SKIP() << "CPU fp16 kernels are not registered on this platform."; @@ -1015,21 +1045,22 @@ TEST(TransformerTest, IsIsolatedFp16NodeOnCpuTest) { o4_def("O4", &tensor_float_16), o5_def("O5", &tensor_float_16); - // for the sake of this example, pretend Clip has no fp16 kernel but Abs does - // -> Clip -> Abs -> Clip -> Abs -> Clip -> + // For the sake of this example, Clip requires fp32 fallback while Identity + // can run fp16 on CPU. + // -> Clip -> Identity -> Clip -> Identity -> Clip -> // | | // - O4 - O5 auto& node1 = graph.AddNode("node1", "Clip", "no fp16", {&i1_def}, {&o1_def}); - auto& node2 = graph.AddNode("node2", "Abs", "fp16", {&o1_def}, {&o2_def}); + auto& node2 = graph.AddNode("node2", "Identity", "fp16", {&o1_def}, {&o2_def}); auto& node3 = graph.AddNode("node3", "Clip", "no fp16", {&o2_def}, {&o3_def}); - auto& node4 = graph.AddNode("node4", "Abs", "fp16 producing graph output", {&o3_def}, {&o4_def}); + auto& node4 = graph.AddNode("node4", "Identity", "fp16 producing graph output", {&o3_def}, {&o4_def}); auto& node5 = graph.AddNode("node5", "Clip", "no fp16", {&o4_def}, {&o5_def}); // manually set outputs as we want O4 and well as O5 to be graph outputs. // AddNode creates a NodeArg instance in Graph so need to get address from the node graph.SetOutputs({node4.OutputDefs()[0], node5.OutputDefs()[0]}); - // node2 and node4 have a kernel + // node2 and node4 are pre-assigned to CPU. node2.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider); node4.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider); @@ -1047,12 +1078,8 @@ TEST(TransformerTest, IsIsolatedFp16NodeOnCpuTest) { }; // we expect: - // node2 Abs to get forced to fp32 as it's isolated between node1 and node3 which need Casts - // node4 Abs should not get forced to fp32 as it produces a graph output - // - // -> CastFp32 -> Clip -> Abs -> Clip -> CastFp16 -> Abs -> CastFp32 -> Clip -> CastFp16 - // | | - // - O4 - O5 + // node2 Identity to get forced to fp32 as it's isolated between node1 and node3 which need Casts. + // node4 Identity to stay fp16 as it produces graph output O4. EXPECT_TRUE(is_type(*node1.InputDefs()[0], DataTypeImpl::GetTensorType())); EXPECT_TRUE(is_type(*node2.InputDefs()[0], DataTypeImpl::GetTensorType())); EXPECT_TRUE(is_type(*node3.InputDefs()[0], DataTypeImpl::GetTensorType())); @@ -1063,6 +1090,31 @@ TEST(TransformerTest, IsIsolatedFp16NodeOnCpuTest) { EXPECT_EQ(ops["Cast"], 4); } +TEST(TransformerTest, CpuAssignedFp16FallbackPreservesGraphOutputType) { + auto model = MakeCpuFp16Model("cpu_fp16_abs_cpu_assigned_graph_output", "Abs", true); + auto& graph = model->MainGraph(); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); + + bool modified = true; + EXPECT_TRUE(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()).IsOK()); + + auto is_type = [](const NodeArg& node_arg, const MLDataType type) { + return node_arg.Type() != nullptr && + DataTypeImpl::TypeFromProto(*node_arg.TypeAsProto()) == type; + }; + + const Node* abs_node = FindNodeByOpType(graph, "Abs"); + ASSERT_NE(abs_node, nullptr); + EXPECT_TRUE(is_type(*abs_node->InputDefs()[0], DataTypeImpl::GetTensorType())); + EXPECT_TRUE(is_type(*abs_node->OutputDefs()[0], DataTypeImpl::GetTensorType())); + ASSERT_EQ(graph.GetOutputs().size(), 1U); + EXPECT_TRUE(is_type(*graph.GetOutputs()[0], DataTypeImpl::GetTensorType())); + + auto ops = CountOpsInGraph(graph); + EXPECT_EQ(ops["Cast"], 2); +} + // Verify that RemoveDuplicateCastTransformer does not fuse Cast(float->int32)->Cast(int32->bool) // because the intermediate int32 truncation changes semantics (e.g. -0.1 -> 0 -> false vs -0.1 -> true). // Regression test for https://github.com/microsoft/onnxruntime/issues/28089 From 30911477f096a581d9b4130852e708e13fdab82b Mon Sep 17 00:00:00 2001 From: Cathal Lawlor Date: Wed, 13 May 2026 18:06:00 +0100 Subject: [PATCH 3/8] Skip CPU fp16 tests if acceleration is not supported Signed-off-by: Cathal Lawlor --- onnxruntime/test/framework/insert_cast_transformer_test.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/onnxruntime/test/framework/insert_cast_transformer_test.cc b/onnxruntime/test/framework/insert_cast_transformer_test.cc index 85331f8646a9d..5a61ed94beaaa 100644 --- a/onnxruntime/test/framework/insert_cast_transformer_test.cc +++ b/onnxruntime/test/framework/insert_cast_transformer_test.cc @@ -934,6 +934,10 @@ TEST(TransformerTest, CpuFp16CpuAssignedNewOptInOpsUseFp32FallbackWhenDisabled) } TEST(TransformerTest, CpuFp16CpuAssignedExistingFp16OpHasNoExtraCastsWhenDisabled) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP() << "CPU fp16 kernels are not registered on this platform."; + } + auto model = MakeCpuFp16Model("cpu_fp16_add_disabled_cpu_assigned", "Add", true); auto& graph = model->MainGraph(); auto* node = graph.Nodes().begin().operator->(); From a504a0c3c2e11f45764a002e8d60785472b92b92 Mon Sep 17 00:00:00 2001 From: Cathal Lawlor Date: Mon, 18 May 2026 14:48:01 +0100 Subject: [PATCH 4/8] Fix CPU fp16 portability and packing issues Signed-off-by: Cathal Lawlor --- docs/OperatorKernels.md | 14 +++---- onnxruntime/core/mlas/lib/halfgemm.h | 38 +++++++++++++++---- .../mlas/lib/kleidiai/halfgemm_kleidiai.cpp | 11 ++++-- .../core/optimizer/insert_cast_transformer.cc | 13 +++++-- .../core/optimizer/insert_cast_transformer.h | 3 +- onnxruntime/core/providers/cpu/math/matmul.cc | 4 +- .../test/contrib_ops/nhwc_pool_in_op_test.cc | 19 ++++++++-- .../test/mlas/unittest/test_halfgemm.cpp | 38 +++++++++++++++++++ 8 files changed, 113 insertions(+), 27 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index d0d8e750285d4..a0c8ae3dfe4ca 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -174,10 +174,10 @@ The **OpSet Version** column uses the following notation: |||12|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**indices** = tensor(int64)| |||11|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**indices** = tensor(int64)| |Gelu|*in* X:**T**
*out* Y:**T**|20+|**T** = tensor(float)| -|Gemm|*in* A:**T**
*in* B:**T**
*in* C:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)| -|||[11, 12]|**T** = tensor(double), tensor(float)| -|||[9, 10]|**T** = tensor(double), tensor(float)| -|||[7, 8]|**T** = tensor(double), tensor(float)| +|Gemm|*in* A:**T**
*in* B:**T**
*in* C:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[9, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| |GlobalAveragePool|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(float)| |||[1, 21]|**T** = tensor(float)| |GlobalLpPool|*in* X:**T**
*out* Y:**T**|2+|**T** = tensor(float)| @@ -258,9 +258,9 @@ The **OpSet Version** column uses the following notation: |||[18, 21]|**T** = tensor(float)| |||[11, 17]|**T** = tensor(float)| |||[2, 10]|**T** = tensor(float)| -|MatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|||[1, 8]|**T** = tensor(double), tensor(float)| +|MatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||[9, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||[1, 8]|**T** = tensor(double), tensor(float), tensor(float16)| |MatMulInteger|*in* A:**T1**
*in* B:**T2**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*out* Y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int32)| |Max|*in* data_0:**T**
*out* max:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| |||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/mlas/lib/halfgemm.h b/onnxruntime/core/mlas/lib/halfgemm.h index 1eaedbb1ccef0..4d0a9b9888b3a 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.h +++ b/onnxruntime/core/mlas/lib/halfgemm.h @@ -32,9 +32,11 @@ Module Name: #pragma once -#include #include +#include +#include #include +#include #include #include "mlasi.h" @@ -112,26 +114,46 @@ MlasHalfGemmCopyPackB( size_t CountK ) { + size_t aligned_count_k_input = 0; + if (MlasTryAddSizeT(CountK, KernelType::PackedK - 1, &aligned_count_k_input)) { + MLAS_THROW_EX(std::runtime_error, "MlasHalfGemmCopyPackB aligned K overflow"); + } + const size_t AlignedCountK = aligned_count_k_input & ~(KernelType::PackedK - 1); + size_t PaddingCountK = AlignedCountK - CountK; + if (ldb == CountN) { size_t bytes_to_copy = 0; - ORT_ENFORCE( - !MlasTryMultiplySizeT(CountK, CountN, &bytes_to_copy) && - !MlasTryMultiplySizeT(bytes_to_copy, sizeof(_mlas_fp16_), &bytes_to_copy), - "MlasHalfGemmCopyPackB size overflow"); + if (MlasTryMultiplySizeT(CountK, CountN, &bytes_to_copy) || + MlasTryMultiplySizeT(bytes_to_copy, sizeof(_mlas_fp16_), &bytes_to_copy)) { + MLAS_THROW_EX(std::runtime_error, "MlasHalfGemmCopyPackB size overflow"); + } std::memcpy(D, B, bytes_to_copy); + if (PaddingCountK > 0) { + size_t padding_bytes = 0; + if (MlasTryMultiplySizeT(PaddingCountK, CountN, &padding_bytes) || + MlasTryMultiplySizeT(padding_bytes, sizeof(_mlas_fp16_), &padding_bytes)) { + MLAS_THROW_EX(std::runtime_error, "MlasHalfGemmCopyPackB padding size overflow"); + } + std::memset(D + CountK * CountN, 0, padding_bytes); + } return; } size_t row_bytes = 0; - ORT_ENFORCE( - !MlasTryMultiplySizeT(CountN, sizeof(_mlas_fp16_), &row_bytes), - "MlasHalfGemmCopyPackB row size overflow"); + if (MlasTryMultiplySizeT(CountN, sizeof(_mlas_fp16_), &row_bytes)) { + MLAS_THROW_EX(std::runtime_error, "MlasHalfGemmCopyPackB row size overflow"); + } while (CountK > 0) { std::memcpy(D, B, row_bytes); B += ldb; D += CountN; CountK--; } + while (PaddingCountK > 0) { + std::memset(D, 0, row_bytes); + D += CountN; + PaddingCountK--; + } } /** diff --git a/onnxruntime/core/mlas/lib/kleidiai/halfgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/halfgemm_kleidiai.cpp index e93f1dbc485b1..bc10db249466d 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/halfgemm_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/halfgemm_kleidiai.cpp @@ -4,9 +4,10 @@ // SPDX-License-Identifier: MIT // -#include -#include +#include #include +#include +#include #include "mlas.h" #include "mlasi_kleidiai.h" @@ -130,6 +131,7 @@ ArmKleidiAI::MlasHalfGemmBatch( MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig); // Validate all batch entries up front so we never partially execute and then // fall back (which would corrupt results for the already-written outputs). + bool needs_rhs_packing = false; for (size_t b = 0; b < BatchN; ++b) { const auto& data = DataParams[b]; if (data.OutputProcessor != nullptr) { @@ -141,6 +143,9 @@ ArmKleidiAI::MlasHalfGemmBatch( if (data.BIsBackendNativePacked && data.ldb != 0) { return false; } + // Native-packed RHS is consumed directly below. Only allocate the + // runtime RHS packing scratch when at least one batch entry needs it. + needs_rhs_packing = needs_rhs_packing || !data.BIsBackendNativePacked; } const auto& hgemm = GetKleidiAIHgemmUKernel(); @@ -158,7 +163,7 @@ ArmKleidiAI::MlasHalfGemmBatch( const float clamp_min = -std::numeric_limits::infinity(); const float clamp_max = std::numeric_limits::infinity(); - if (!TryResizeVector(g_kai_half_tls.rhs_packed, packed_rhs_size)) { + if (needs_rhs_packing && !TryResizeVector(g_kai_half_tls.rhs_packed, packed_rhs_size)) { return false; } diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 7eec669722728..dcf49882ecc73 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -33,7 +33,13 @@ static bool IsMLFloat16Tensor(const NodeArg& node_arg) { return IsTensorOfType(node_arg); } -bool InsertCastTransformer::NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input) const { +static bool HasCpuFloat32FallbackKernel( + const onnxruntime::Node& node, + const InlinedVector>& cpu_kernel_registries, + const logging::Logger& logger); + +bool InsertCastTransformer::NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input, + const logging::Logger& logger) const { // Returns true when this input is an fp16 input to an unassigned node that is eligible // for the cast-to-fp32 fallback path. // @@ -41,7 +47,8 @@ bool InsertCastTransformer::NeedInsertCast(const onnxruntime::Node* node, const // inputs safely requires additional checks of the subgraph boundaries and contents. return node->GetExecutionProviderType().empty() && !node->ContainsSubgraph() && - IsMLFloat16Tensor(*input); + IsMLFloat16Tensor(*input) && + HasCpuFloat32FallbackKernel(*node, cpu_kernel_registries_, logger); } static bool HasFp16IO(const onnxruntime::Node& node) { @@ -945,7 +952,7 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie std::map replacement_defs; bool casted = false; for (auto input : inputs) { - if (NeedInsertCast(node, input)) { + if (NeedInsertCast(node, input, logger)) { auto src_arg = input; if (input_def_updates.count(src_arg)) { replacement_defs[src_arg] = input_def_updates[src_arg]; diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.h b/onnxruntime/core/optimizer/insert_cast_transformer.h index abe2a40a687be..d3a86384e60dd 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.h +++ b/onnxruntime/core/optimizer/insert_cast_transformer.h @@ -59,7 +59,8 @@ class InsertCastTransformer : public onnxruntime::GraphTransformer { private: Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; - bool NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input) const; + bool NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input, + const logging::Logger& logger) const; const InlinedVector> cpu_kernel_registries_; diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 594b732683000..6181045f48102 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -130,7 +130,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { if (helper.K() == 0) { // When we have (M, 0, N) then the inputs are empty, but the output should // be filled out with zeros. - auto output_span = gsl::make_span(y->MutableData(), y->Shape().Size()); + auto output_span = gsl::make_span(y->MutableData(), narrow(y->Shape().Size())); std::fill(output_span.begin(), output_span.end(), T{}); return Status::OK(); } @@ -393,7 +393,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { } if (helper.K() == 0) { - auto output_span = gsl::make_span(y->MutableData(), y->Shape().Size()); + auto output_span = gsl::make_span(y->MutableData(), narrow(y->Shape().Size())); std::fill(output_span.begin(), output_span.end(), MLFloat16{}); return Status::OK(); } diff --git a/onnxruntime/test/contrib_ops/nhwc_pool_in_op_test.cc b/onnxruntime/test/contrib_ops/nhwc_pool_in_op_test.cc index e40e635c79a26..c58d813e1b061 100644 --- a/onnxruntime/test/contrib_ops/nhwc_pool_in_op_test.cc +++ b/onnxruntime/test/contrib_ops/nhwc_pool_in_op_test.cc @@ -7,6 +7,7 @@ #include #include +#include #include "core/util/math.h" #include "core/mlas/inc/mlas.h" @@ -17,6 +18,16 @@ namespace onnxruntime { namespace test { #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +namespace { + +// These tests target ORT's CPU implementation of the MS-internal NHWC fp16 pool +// ops. Do not offer the models to EPs that may be registered by the test +// harness but do not own this internal-domain CPU/MLAS coverage. +const std::unordered_set kNhwcFp16PoolExcludedProviders{ + kCoreMLExecutionProvider, + kTensorrtExecutionProvider, +}; + class NhwcFp16PoolOpTester { private: bool is_max_pool_; // max or average pool @@ -181,10 +192,12 @@ class NhwcFp16PoolOpTester { if (!dilations_.empty()) { test.AddAttribute("dilations", dilations_); } - test.Run(OpTester::ExpectResult::kExpectSuccess, ""); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", kNhwcFp16PoolExcludedProviders); } }; +} // namespace + TEST(NhwcFp16PoolOpTest, MaxPool1D) { for (int64_t channels = 1; channels < 94; channels++) { NhwcFp16PoolOpTester test(true); @@ -303,7 +316,7 @@ TEST(NhwcFp16PoolOpTest, AvgPoolIncludePadPixel) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", kNhwcFp16PoolExcludedProviders); } TEST(NhwcFp16PoolOpTest, GlobalAveragePool) { @@ -508,7 +521,7 @@ TEST(NhwcFp16PoolOpTest, GlobalAveragePool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", kNhwcFp16PoolExcludedProviders); } #endif diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.cpp b/onnxruntime/test/mlas/unittest/test_halfgemm.cpp index d2b686c29b6ca..9639c24ed90ae 100644 --- a/onnxruntime/test/mlas/unittest/test_halfgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_halfgemm.cpp @@ -15,6 +15,7 @@ Module Name: --*/ #include "test_halfgemm.h" +#include "core/mlas/lib/halfgemm.h" #if defined(USE_KLEIDIAI) #include "core/mlas/lib/mlasi.h" #include "core/mlas/lib/kleidiai/mlasi_kleidiai.h" @@ -25,6 +26,14 @@ Module Name: #include #include +namespace { + +struct HalfGemmPackBPaddingKernel { + static constexpr size_t PackedK = 4; +}; + +} // namespace + #if defined(USE_KLEIDIAI) namespace { @@ -431,6 +440,31 @@ TEST(HalfGemmPackB, ReturnsZeroOnOverflow) { EXPECT_EQ(MlasHalfGemmPackBSize(max, 2, true), size_t{0}); } +TEST(HalfGemmPackB, CopyPackBZeroPadsAlignedKTail) { + constexpr size_t N = 5; + constexpr size_t K = 3; + constexpr size_t AlignedK = 4; + + std::vector<_mlas_fp16_> b(K * N); + for (size_t i = 0; i < b.size(); ++i) { + b[i] = static_cast<_mlas_fp16_>(i + 1); + } + + constexpr _mlas_fp16_ stale_tail_value = 0xFFFF; + std::vector<_mlas_fp16_> packed(AlignedK * N, stale_tail_value); + + MlasHalfGemmCopyPackB( + packed.data(), + b.data(), + N, + N, + K); + + for (size_t n = 0; n < N; ++n) { + EXPECT_EQ(packed[K * N + n], 0) << "n=" << n; + } +} + TEST(HalfGemmPackB, GenericPackedBFlagRunsOnFallback) { if (!MlasFp16AccelerationSupported()) { GTEST_SKIP(); @@ -682,7 +716,9 @@ TEST(HalfGemmKleidiAIPath, KleidiAIPackedBWithBiasThrows) { data.BIsBackendNativePacked = true; ASSERT_FALSE(ArmKleidiAI::MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, nullptr)); +#if !defined(ORT_NO_EXCEPTIONS) EXPECT_THROW(MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, nullptr), std::runtime_error); +#endif } TEST(HalfGemmKleidiAIPath, KleidiAIPackedBWithOutputProcessorThrows) { @@ -727,7 +763,9 @@ TEST(HalfGemmKleidiAIPath, KleidiAIPackedBWithOutputProcessorThrows) { data.OutputProcessor = &output_processor; ASSERT_FALSE(ArmKleidiAI::MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, nullptr)); +#if !defined(ORT_NO_EXCEPTIONS) EXPECT_THROW(MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, nullptr), std::runtime_error); +#endif } TEST(HalfGemmKleidiAIPath, ZeroKFallsBack) { From 248399052eda5d3798b48cc0ea4d13a3a2aa9a75 Mon Sep 17 00:00:00 2001 From: Cathal Lawlor Date: Tue, 19 May 2026 17:20:18 +0100 Subject: [PATCH 5/8] Fix CPU fp16 test expectations for native ARM64 paths Native ARM64 fp16 execution can now exercise provider paths that were previously hidden by fp32 fallback behaviour. Update the affected tests to reflect supported provider coverage and the expected numerical drift from HQNBIT_CompFp16 fp16 accumulation on large-K MatMulNBits cases. Signed-off-by: Cathal Lawlor cathal.lawlor@arm.com --- onnxruntime/test/contrib_ops/attention_op_test.cc | 6 +++++- onnxruntime/test/contrib_ops/matmul_4bits_test.cc | 2 +- onnxruntime/test/optimizer/nhwc_transformer_test.cc | 5 ++++- onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc | 8 ++++++++ onnxruntime/test/providers/cpu/tensor/resize_op_test.cc | 9 +++++++++ 5 files changed, 27 insertions(+), 3 deletions(-) diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index 6268e425ebb61..ad4c8941da782 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -521,8 +521,12 @@ TEST(ContribOpAttentionTest, AttentionBatch1_Float16) { 3.154296875, 0.1082763671875, 4.25, 5.6484375, 3.970703125, 0.072998046875, 4.25, 5.6484375}; + // WebGPU Attention does not support mask_index input. + constexpr bool disable_webgpu = true; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, - batch_size, sequence_length, hidden_size, number_of_heads, true); + batch_size, sequence_length, hidden_size, number_of_heads, true /*use_float16*/, + false, false, 0, nullptr, nullptr, AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0, 0, + false /*disable_cpu*/, false /*disable_cuda*/, false /*disable_dml*/, disable_webgpu); } TEST(ContribOpAttentionTest, AttentionBatch2) { diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 29bf994ff3d34..cf18dc8509dd7 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -292,7 +292,7 @@ void TestMatMulNBitsTyped(std::optional abs_error = std::nullopt, } else if (base_opts.accuracy_level == 4) { base_opts.output_abs_error = 0.1f; } else if constexpr (std::is_same::value) { - base_opts.output_abs_error = 0.055f; + base_opts.output_abs_error = 0.065f; } else { base_opts.output_abs_error = 0.05f; } diff --git a/onnxruntime/test/optimizer/nhwc_transformer_test.cc b/onnxruntime/test/optimizer/nhwc_transformer_test.cc index b73929efab8a6..438823aaa6cec 100644 --- a/onnxruntime/test/optimizer/nhwc_transformer_test.cc +++ b/onnxruntime/test/optimizer/nhwc_transformer_test.cc @@ -997,7 +997,10 @@ TEST_F(NhwcTransformerTestsFp16, FusedConvWithSumFp16) { TransformerTester(build_test_case, check_nhwc_graph, TransformerLevel::Level2, - TransformerLevel::Level3); + TransformerLevel::Level3, + /*opset_version*/ 12, + /*per_sample_tolerance*/ 0.02, + /*relative_per_sample_tolerance*/ 0.02); } TEST_F(NhwcTransformerTestsFp16, ConvMaxPoolFp16) { diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index eb4ed757d1a94..891a0e3217b0a 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -467,6 +467,14 @@ TEST(ConvFp16Test, Conv2D_KleidiAiImatmulEligibleBiasAndDisabledFallback) { } TEST(ConvFp16Test, NhwcFusedConv2D_KleidiAiImatmulEligibleBiasAndDisabledFallback) { +#if !defined(__aarch64__) && !defined(_M_ARM64) + GTEST_SKIP() << "Native CPU fp16 Conv runtime support is only tested on Arm64."; +#else + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP() << "Native CPU fp16 Conv runtime support is unavailable."; + } +#endif + auto run_test = [](bool disable_kleidiai) { OpTester test("NhwcFusedConv", 1, onnxruntime::kMSDomain); test.AddAttribute("group", static_cast(1)); diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 4b4d9e9ddd1ba..4714f2c7fb0f9 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" @@ -253,6 +254,14 @@ TYPED_TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear) { // QNN: result diff // TRT: Segmentation fault in A100 std::unordered_set excluded_providers({kQnnExecutionProvider}); + if constexpr (std::is_same_v) { + // These EPs do not provide this Resize(13) MLFloat16 kernel. + excluded_providers.insert(kNnapiExecutionProvider); + excluded_providers.insert(kXnnpackExecutionProvider); + excluded_providers.insert(kCoreMLExecutionProvider); + excluded_providers.insert(kTensorrtExecutionProvider); + excluded_providers.insert("example_ep"); + } test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100(excluded_providers)); }; From 152ef310c4db6a9a632e9436f2c2d4861293e2e6 Mon Sep 17 00:00:00 2001 From: Cathal Lawlor Date: Thu, 21 May 2026 12:34:06 +0100 Subject: [PATCH 6/8] Initialize halfgemm params and include cstring Value-initialize MLAS_HALF_GEMM_DATA_PARAMS at call sites so optional flags and pointers default to null/false unless explicitly set. Include directly in matmul.cc for the native fp16 prepack memset calls. Signed-off-by: Cathal Lawlor --- onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc | 2 +- onnxruntime/core/providers/cpu/fp16/fp16_conv.cc | 2 +- onnxruntime/core/providers/cpu/math/gemm.cc | 2 +- onnxruntime/core/providers/cpu/math/matmul.cc | 1 + onnxruntime/core/util/math_cpu.cc | 2 +- onnxruntime/test/mlas/unittest/test_halfgemm.cpp | 2 +- 6 files changed, 6 insertions(+), 5 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc index 08fab6d2c110a..3cc1bdca98490 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc @@ -536,7 +536,7 @@ Status MoE::ComputeGEMM(const float* A, const float* B, float* C, template <> Status MoE::ComputeGEMM(const MLFloat16* A, const MLFloat16* B, MLFloat16* C, int64_t M, int64_t K, int64_t N, bool transpose_B) const { - MLAS_HALF_GEMM_DATA_PARAMS params; + MLAS_HALF_GEMM_DATA_PARAMS params{}; params.A = A; params.lda = static_cast(K); params.C = C; diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc index 95445c73a39ce..222dcc4b3c6dc 100644 --- a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc +++ b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc @@ -663,7 +663,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { const auto* gemm_add = add_src == nullptr ? nullptr : worker_addsrc + group_id * group_output_channels; MLAS_HALF_GEMM_ACTIVATION_PROCESSOR act(activation_, gemm_add); - MLAS_HALF_GEMM_DATA_PARAMS gemm_params; + MLAS_HALF_GEMM_DATA_PARAMS gemm_params{}; gemm_params.A = AData; gemm_params.lda = lda; if (packed_W_buffer_) { diff --git a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc index 6fd633619b956..a86e5a36aadf9 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.cc +++ b/onnxruntime/core/providers/cpu/math/gemm.cc @@ -219,7 +219,7 @@ void Gemm_MLFloat16(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, const bool use_mlas_bias = beta == onnxruntime::MLFloat16::One && support_mlas_bias; if (trans_a == CblasNoTrans && trans_b == CblasNoTrans && alpha == onnxruntime::MLFloat16::One && (use_mlas_no_bias || use_mlas_bias)) { - MLAS_HALF_GEMM_DATA_PARAMS data; + MLAS_HALF_GEMM_DATA_PARAMS data{}; data.A = a_data; data.lda = K; data.B = b_data; diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 6181045f48102..e19f8ec370d26 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -8,6 +8,7 @@ #include "core/util/math_cpuonly.h" #include #include +#include namespace onnxruntime { diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index 52e5c9d0092f0..3eeb0c5df8f1f 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -194,7 +194,7 @@ template <> void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, ThreadPool* threadpool, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config) { #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED - MLAS_HALF_GEMM_DATA_PARAMS data; + MLAS_HALF_GEMM_DATA_PARAMS data{}; data.A = A; data.lda = static_cast(K); data.B = B; diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.cpp b/onnxruntime/test/mlas/unittest/test_halfgemm.cpp index 9639c24ed90ae..d350a76f4a399 100644 --- a/onnxruntime/test/mlas/unittest/test_halfgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_halfgemm.cpp @@ -113,7 +113,7 @@ TEST(HalfGemmKleidiAISelector, DisableKleidiAIBypassesOverride) { SmallFloatFill(A.data(), A.size()); SmallFloatFill(B.data(), B.size()); - MLAS_HALF_GEMM_DATA_PARAMS data; + MLAS_HALF_GEMM_DATA_PARAMS data{}; data.A = A.data(); data.B = B.data(); data.C = reinterpret_cast(C.data()); From e08e9a9ed53ad36af25efb6cf87913624886d09b Mon Sep 17 00:00:00 2001 From: Cathal Lawlor Date: Mon, 25 May 2026 15:21:48 +0100 Subject: [PATCH 7/8] Refactor insert_cast_transformer and related tests for improved type handling and compatibility checks; add CPU FP16 resize model test Signed-off-by: Cathal Lawlor --- .../core/optimizer/insert_cast_transformer.cc | 196 ++++++++++-------- .../test/contrib_ops/matmul_4bits_test.cc | 4 + .../framework/insert_cast_transformer_test.cc | 71 ++++++- .../test/providers/cpu/math/matmul_test.cc | 5 +- .../test/providers/cpu/nn/conv_fp16_test.cc | 1 - 5 files changed, 183 insertions(+), 94 deletions(-) diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 2dcd2fa7a0216..38f045f913124 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -287,125 +287,138 @@ static bool ShouldKeepNativeCpuFp16ForMatMulOrGemm( return nk >= kMinNativeFp16NK; } -static const IKernelTypeStrResolver& GetInsertCastKernelTypeStrResolver() { +static bool KernelDomainMatchesNode(const KernelDef& kernel_def, const onnxruntime::Node& node) { + const auto& kernel_domain = kernel_def.Domain(); + const auto& node_domain = node.Domain(); + return kernel_domain == node_domain || + (kernel_domain == kOnnxDomainAlias && node_domain.empty()); +} + +static bool KernelVersionMatchesNode(const KernelDef& kernel_def, const onnxruntime::Node& node) { + const auto [kernel_start_version, kernel_end_version] = kernel_def.SinceVersion(); + return kernel_start_version <= node.SinceVersion() && + kernel_end_version >= node.SinceVersion(); +} + +static bool IsKernelTypeCompatible(gsl::span enabled_types, + const ONNX_NAMESPACE::TypeProto& actual_type, + bool replace_fp16_with_float) { + const auto* type_to_check = &actual_type; + TypeProto float_tensor_type; + if (replace_fp16_with_float && + DataTypeImpl::TypeFromProto(actual_type) == DataTypeImpl::GetTensorType()) { + float_tensor_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + type_to_check = &float_tensor_type; + } + + return std::any_of(enabled_types.begin(), enabled_types.end(), + [type_to_check](const DataTypeImpl* enabled_type) { + return enabled_type->IsCompatible(*type_to_check); + }); +} + +static bool KernelTypeConstraintsMatchAllNodeArgs(const onnxruntime::Node& node, + const KernelDef& kernel_def, + bool replace_fp16_with_float) { #if !defined(ORT_MINIMAL_BUILD) static const OpSchemaKernelTypeStrResolver resolver; #else static const KernelTypeStrResolver resolver; #endif - return resolver; -} -static bool BuildTypeConstraintMapForNode(const onnxruntime::Node& node, - bool replace_fp16_with_float, - InlinedHashMap& type_constraint_map) { - // Build the type-constraint map that kernel lookup uses for this node. - // - // ONNX kernel lookup is based on the operator schema's type variables (e.g. T, T1, T2) - // rather than directly on individual NodeArg names. For example, a schema may say that - // both inputs and the output are of type "T". To ask "does CPU have a kernel for this - // node as currently typed?" or "does CPU have a float32 fallback for this fp16 node?", - // we first need to resolve those schema type variables to concrete MLDataType values. - // - // When replace_fp16_with_float is false we record the node's current types as-is. - // When it is true we rewrite any float16 tensors to float in the constructed map so - // we can ask whether a valid float32 fallback kernel exists for the same operator. - const auto* schema = node.Op(); - if (!schema) { - return false; + const auto actual_inputs = node.InputDefs(); + const auto actual_outputs = node.OutputDefs(); + const auto& actual_input_arg_counts = node.InputArgCount(); + InlinedVector actual_input_arg_offsets; + actual_input_arg_offsets.reserve(actual_input_arg_counts.size()); + int current_offset = 0; + for (const auto arg_count : actual_input_arg_counts) { + actual_input_arg_offsets.push_back(current_offset); + current_offset += arg_count; } - const TypeConstraintMap& type_schema = schema->typeConstraintMap(); - type_constraint_map.reserve(type_schema.size()); - - const auto SetTypeConstraint = [&](const std::string& type_str, const NodeArg* def) { - if (!def || !def->Exists()) { - return; + const auto CheckArg = [&](const NodeArg* arg, + gsl::span enabled_types) { + if (!arg || !arg->Exists()) { + return true; } - TypeConstraintMap::const_iterator it = type_schema.find(type_str); - if (it == type_schema.end()) { - return; + const auto* type_proto = arg->TypeAsProto(); + if (type_proto == nullptr) { + return false; } - auto type = DataTypeImpl::TypeFromProto(*(def->TypeAsProto())); - if (replace_fp16_with_float && type == DataTypeImpl::GetTensorType()) { - type = DataTypeImpl::GetTensorType(); + return IsKernelTypeCompatible(enabled_types, *type_proto, replace_fp16_with_float); + }; + + for (const auto& [kernel_type_str, enabled_types] : kernel_def.TypeConstraints()) { + gsl::span constraint_args; + if (!resolver.ResolveKernelTypeStr(node, kernel_type_str, constraint_args).IsOK()) { + return false; } - type_constraint_map[type_str] = type; - }; + for (const auto& [arg_type, formal_arg_idx] : constraint_args) { + if (arg_type == ArgType::kInput) { + if (formal_arg_idx >= actual_input_arg_counts.size() || + actual_input_arg_counts[formal_arg_idx] == 0) { + continue; + } - const auto& input_arg_counts = node.InputArgCount(); - const auto& input_defs = node.InputDefs(); - const auto& formal_inputs = schema->inputs(); - const size_t num_inputs = std::min(formal_inputs.size(), input_arg_counts.size()); - int input_idx_start = 0; - for (size_t formal_idx = 0; - formal_idx < num_inputs; - input_idx_start += input_arg_counts[formal_idx], formal_idx++) { - const auto& type_str = formal_inputs[formal_idx].GetTypeStr(); - // Variadic formal parameters can map to multiple actual inputs. For current CPU fp16 - // preservation/fallback decisions we only need one concrete binding for the schema type - // variable, so we take the first existing actual input for that formal parameter. - for (int input_idx = 0; input_idx < input_arg_counts[formal_idx]; input_idx++) { - const size_t idx = static_cast(input_idx_start) + static_cast(input_idx); - ORT_ENFORCE(idx < input_defs.size()); - const NodeArg* input_def = input_defs[idx]; - if (!input_def || !input_def->Exists()) { - continue; + const auto first_arg_idx = actual_input_arg_offsets[formal_arg_idx]; + for (int arg_idx = 0; arg_idx < actual_input_arg_counts[formal_arg_idx]; arg_idx++) { + const auto actual_arg_idx = static_cast(first_arg_idx + arg_idx); + ORT_ENFORCE(actual_arg_idx < actual_inputs.size()); + if (!CheckArg(actual_inputs[actual_arg_idx], enabled_types)) { + return false; + } + } + } else { + if (formal_arg_idx < actual_outputs.size() && + !CheckArg(actual_outputs[formal_arg_idx], enabled_types)) { + return false; + } } - - SetTypeConstraint(type_str, input_def); - break; } } - const auto& output_defs = node.OutputDefs(); - const auto& formal_outputs = schema->outputs(); - const size_t num_outputs = std::min(formal_outputs.size(), output_defs.size()); - for (size_t idx = 0; idx < num_outputs; idx++) { - const auto& type_str = formal_outputs[idx].GetTypeStr(); - SetTypeConstraint(type_str, output_defs[idx]); - } - return true; } -static bool HasCpuKernelForCurrentTypes( +static bool HasCpuKernelWithTypeSupport( const onnxruntime::Node& node, const InlinedVector>& cpu_kernel_registries, - const logging::Logger& logger) { - const auto& resolver = GetInsertCastKernelTypeStrResolver(); + bool replace_fp16_with_float) { for (const KernelRegistry* cpu_kernel_registry : cpu_kernel_registries) { - if (KernelRegistry::HasImplementationOf(*cpu_kernel_registry, node, kCpuExecutionProvider, resolver, logger)) { - return true; + for (const auto& [_, kernel_create_info] : cpu_kernel_registry->GetKernelCreateMap()) { + const auto* kernel_def = kernel_create_info.kernel_def.get(); + if (kernel_def != nullptr && + kernel_def->Provider() == kCpuExecutionProvider && + kernel_def->OpName() == node.OpType() && + KernelDomainMatchesNode(*kernel_def, node) && + KernelVersionMatchesNode(*kernel_def, node) && + KernelTypeConstraintsMatchAllNodeArgs(node, *kernel_def, replace_fp16_with_float)) { + return true; + } } } return false; } -static bool HasCpuFloat32FallbackKernel( +static bool HasCpuKernelForCurrentTypes( const onnxruntime::Node& node, const InlinedVector>& cpu_kernel_registries, const logging::Logger& logger) { - InlinedHashMap type_constraint_map; - if (!BuildTypeConstraintMapForNode(node, true, type_constraint_map)) { - return false; - } - - for (const KernelRegistry* cpu_kernel_registry : cpu_kernel_registries) { - const KernelCreateInfo* kernel_create_info{}; - const auto lookup_status = cpu_kernel_registry->TryFindKernel( - kCpuExecutionProvider, node.OpType(), node.Domain(), - node.SinceVersion(), type_constraint_map, logger, &kernel_create_info); - if (lookup_status.IsOK() && kernel_create_info != nullptr) { - return true; - } - } + ORT_UNUSED_PARAMETER(logger); + return HasCpuKernelWithTypeSupport(node, cpu_kernel_registries, false); +} - return false; +static bool HasCpuFloat32FallbackKernel( + const onnxruntime::Node& node, + const InlinedVector>& cpu_kernel_registries, + const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); + return HasCpuKernelWithTypeSupport(node, cpu_kernel_registries, true); } onnxruntime::NodeArg* AddCastNode(onnxruntime::Graph& graph, @@ -539,7 +552,13 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: type_constraint_map[type_str] = DataTypeImpl::GetTensorType(); break; // we don't have multiple tensors feeding into one input } - type_constraint_map[type_str] = DataTypeImpl::TypeFromProto(*(input_def->TypeAsProto())); + + const auto* type_proto = input_def->TypeAsProto(); + if (type_proto == nullptr) { + return false; + } + + type_constraint_map[type_str] = DataTypeImpl::TypeFromProto(*type_proto); break; // we don't have multiple tensors feeding into one input } } @@ -581,7 +600,12 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: fp16_args.emplace((int)idx); type_constraint_map[type_str] = DataTypeImpl::GetTensorType(); } else { - type_constraint_map[type_str] = DataTypeImpl::TypeFromProto(*(output_def->TypeAsProto())); + const auto* type_proto = output_def->TypeAsProto(); + if (type_proto == nullptr) { + return false; + } + + type_constraint_map[type_str] = DataTypeImpl::TypeFromProto(*type_proto); } } diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index f61dcad200684..79a520fe49d44 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -292,7 +292,11 @@ void TestMatMulNBitsTyped(std::optional abs_error = std::nullopt, } else if (base_opts.accuracy_level == 4) { base_opts.output_abs_error = 0.1f; } else if constexpr (std::is_same::value) { +#if defined(USE_WEBGPU) + base_opts.output_abs_error = 0.1f; +#else base_opts.output_abs_error = 0.065f; +#endif } else { base_opts.output_abs_error = 0.05f; } diff --git a/onnxruntime/test/framework/insert_cast_transformer_test.cc b/onnxruntime/test/framework/insert_cast_transformer_test.cc index 6f04f925b4095..4429ac76e2d3c 100644 --- a/onnxruntime/test/framework/insert_cast_transformer_test.cc +++ b/onnxruntime/test/framework/insert_cast_transformer_test.cc @@ -25,17 +25,21 @@ namespace test { typedef std::vector ArgMap; -static TypeProto MakeFp16TensorType(std::initializer_list shape = {}) { - TypeProto tensor_float_16; - tensor_float_16.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT16); +static TypeProto MakeTensorType(TensorProto_DataType elem_type, std::initializer_list shape = {}) { + TypeProto tensor_type; + tensor_type.mutable_tensor_type()->set_elem_type(elem_type); if (shape.size() > 0) { - auto* tensor_shape = tensor_float_16.mutable_tensor_type()->mutable_shape(); + auto* tensor_shape = tensor_type.mutable_tensor_type()->mutable_shape(); for (const auto dim : shape) { tensor_shape->add_dim()->set_dim_value(dim); } } - return tensor_float_16; + return tensor_type; +} + +static TypeProto MakeFp16TensorType(std::initializer_list shape = {}) { + return MakeTensorType(TensorProto_DataType_FLOAT16, shape); } TEST(TransformerTest, InsertCastGPUTest) { @@ -424,6 +428,38 @@ static std::shared_ptr MakeCpuFp16MatMulModelWithShapes(const std::string return model; } +static std::shared_ptr MakeCpuFp16ResizeModel(const std::string& model_name, bool assign_cpu_ep) { + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = 13; + auto model = std::make_shared(model_name, false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, + std::vector(), + DefaultLoggingManager().DefaultLogger()); + auto& graph = model->MainGraph(); + + TypeProto x_type = MakeFp16TensorType({1, 1, 2, 2}); + TypeProto roi_type = MakeTensorType(TensorProto_DataType_FLOAT, {0}); + TypeProto scales_type = MakeTensorType(TensorProto_DataType_FLOAT, {4}); + TypeProto y_type = MakeFp16TensorType({1, 1, 4, 4}); + + auto& x = graph.GetOrCreateNodeArg("X", &x_type); + auto& roi = graph.GetOrCreateNodeArg("roi", &roi_type); + auto& scales = graph.GetOrCreateNodeArg("scales", &scales_type); + auto& y = graph.GetOrCreateNodeArg("Y", &y_type); + + auto& node = graph.AddNode("Resize", "Resize", "fp16 resize fallback test", + ArgMap{&x, &roi, &scales}, ArgMap{&y}); + if (assign_cpu_ep) { + node.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider); + } + + graph.SetInputs({&x, &roi, &scales}); + graph.SetOutputs({&y}); + ORT_THROW_IF_ERROR(graph.Resolve()); + + return model; +} + static void ExpectMlFloat16Output(const std::vector& output, const std::vector& expected) { ASSERT_EQ(output.size(), expected.size()); @@ -880,6 +916,31 @@ TEST(TransformerTest, CpuFp16UnsupportedCpuAssignedOpStillGetsCasts) { EXPECT_TRUE(IsNodeArgType(*abs_node->OutputDefs()[0], DataTypeImpl::GetTensorType())); } +TEST(TransformerTest, CpuFp16CpuAssignedResizeWithoutFp16KernelUsesFp32Fallback) { + auto model = MakeCpuFp16ResizeModel("cpu_fp16_resize_cpu_assigned", true); + auto& graph = model->MainGraph(); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, false); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.at("Cast"), 2); + + const Node* resize_node = FindNodeByOpType(graph, "Resize"); + ASSERT_NE(resize_node, nullptr); + EXPECT_EQ(resize_node->SinceVersion(), 13); + EXPECT_EQ(resize_node->GetExecutionProviderType(), kCpuExecutionProvider); + ASSERT_EQ(resize_node->InputDefs().size(), 3U); + ASSERT_EQ(resize_node->OutputDefs().size(), 1U); + EXPECT_TRUE(IsNodeArgType(*resize_node->InputDefs()[0], DataTypeImpl::GetTensorType())); + EXPECT_TRUE(IsNodeArgType(*resize_node->InputDefs()[1], DataTypeImpl::GetTensorType())); + EXPECT_TRUE(IsNodeArgType(*resize_node->InputDefs()[2], DataTypeImpl::GetTensorType())); + EXPECT_TRUE(IsNodeArgType(*resize_node->OutputDefs()[0], DataTypeImpl::GetTensorType())); +} + TEST(TransformerTest, CpuFp16SupportedCpuOpKeepsFp16WhenEnabled) { if (!MlasFp16AccelerationSupported()) { GTEST_SKIP() << "CPU fp16 kernels are not registered on this platform."; diff --git a/onnxruntime/test/providers/cpu/math/matmul_test.cc b/onnxruntime/test/providers/cpu/math/matmul_test.cc index e4d80b9212ba9..55eef75ffa569 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_test.cc @@ -4,6 +4,7 @@ #include "gtest/gtest.h" #include "core/mlas/inc/mlas.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "test/providers/provider_test_utils.h" #include "test/common/dnnl_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" @@ -717,8 +718,8 @@ TEST(MathOpTest, MatMulFloat16NativePrepackedWeightsAreNotShared) { SessionOptions so; ASSERT_EQ(so.AddInitializer("B", &b), Status::OK()); - ASSERT_EQ(so.config_options.AddConfigEntry("session.enable_cpu_fp16", "1"), Status::OK()); - ASSERT_EQ(so.config_options.AddConfigEntry("session.cpu_fp16_use_fp32_fallback_heuristic", "0"), Status::OK()); + ASSERT_EQ(so.config_options.AddConfigEntry(kOrtSessionOptionsEnableCpuFp16, "1"), Status::OK()); + ASSERT_EQ(so.config_options.AddConfigEntry(kOrtSessionOptionsCpuFp16UseFp32FallbackHeuristic, "0"), Status::OK()); test.EnableSharingOfPrePackedWeightsAcrossSessions(); diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 891a0e3217b0a..03b1eac1e91bf 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -9,7 +9,6 @@ #include "test/common/cuda_op_test_utils.h" #include "test/common/random_generator.h" #include "test/providers/provider_test_utils.h" -#include "test/util/include/scoped_env_vars.h" #include "default_providers.h" #include "core/session/onnxruntime_session_options_config_keys.h" From 93bba78c50ac21c64756f2ddb34edddab3682169 Mon Sep 17 00:00:00 2001 From: Cathal Lawlor Date: Mon, 25 May 2026 16:18:25 +0100 Subject: [PATCH 8/8] Add comment clarifying FP16 MatMulNBits tolerance for macOS arm64 WebGPU CI Signed-off-by: Cathal Lawlor --- onnxruntime/test/contrib_ops/matmul_4bits_test.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 79a520fe49d44..5740eacceee82 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -293,6 +293,7 @@ void TestMatMulNBitsTyped(std::optional abs_error = std::nullopt, base_opts.output_abs_error = 0.1f; } else if constexpr (std::is_same::value) { #if defined(USE_WEBGPU) + // Match existing fp16 MatMulNBits tolerance for WebGPU builds while keeping CPU stricter. base_opts.output_abs_error = 0.1f; #else base_opts.output_abs_error = 0.065f;