diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 7ae18db235ccb..7b623fe4cbaad 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -23,6 +23,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/threading.cpp ${MLAS_SRC_DIR}/sgemm.cpp ${MLAS_SRC_DIR}/halfgemm.cpp + ${MLAS_SRC_DIR}/halfconv.cpp ${MLAS_SRC_DIR}/qgemm.cpp ${MLAS_SRC_DIR}/qdwconv.cpp ${MLAS_SRC_DIR}/convolve.cpp @@ -309,6 +310,8 @@ function(setup_kleidiai) target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/kai_ukernel_interface.cpp ${MLAS_SRC_DIR}/kleidiai/sgemm_kleidiai.cpp + ${MLAS_SRC_DIR}/kleidiai/halfgemm_kleidiai.cpp + ${MLAS_SRC_DIR}/kleidiai/halfconv_kleidiai.cpp ${MLAS_SRC_DIR}/kleidiai/sbgemm_kleidiai.cpp ${MLAS_SRC_DIR}/kleidiai/convolve_kleidiai.cpp ${MLAS_SRC_DIR}/kleidiai/qgemm_kleidiai.cpp @@ -363,7 +366,7 @@ function (setup_arm_neon_nchwc) ${MLAS_SRC_DIR}/aarch64/SconvNchwcKernelNeon.S ${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeon.S ${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeon.S - ) + ) endif() list(APPEND mlas_private_compile_definitions MLAS_USE_ARM_NEON_NCHWC) set(mlas_private_compile_definitions ${mlas_private_compile_definitions} PARENT_SCOPE) @@ -546,8 +549,7 @@ else() set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") - if (NOT APPLE) - set(mlas_platform_srcs + set(mlas_platform_srcs ${mlas_platform_srcs} ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S @@ -573,37 +575,36 @@ else() ${MLAS_SRC_DIR}/gelu_neon_fp16.h ${MLAS_SRC_DIR}/gelu_neon_fp16.cpp ) - if (onnxruntime_USE_ARM_NEON_NCHWC) - list(APPEND mlas_platform_srcs - ${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeonBf16.S - ${MLAS_SRC_DIR}/aarch64/SconvKernelNeonBf16.S - ${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeonBf16.S - ) - endif() - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S 
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16_8bit.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - if (onnxruntime_USE_ARM_NEON_NCHWC) - 
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - endif() - set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/gelu_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + if (onnxruntime_USE_ARM_NEON_NCHWC AND NOT APPLE) + list(APPEND mlas_platform_srcs + ${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeonBf16.S + ${MLAS_SRC_DIR}/aarch64/SconvKernelNeonBf16.S + ${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeonBf16.S + ) + endif() + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + 
set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16_8bit.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + if (onnxruntime_USE_ARM_NEON_NCHWC AND NOT APPLE) + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeonBf16.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") endif() + set_source_files_properties(${MLAS_SRC_DIR}/erf_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/gelu_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") if(ONNXRUNTIME_MLAS_MULTI_ARCH) onnxruntime_add_static_library(onnxruntime_mlas_arm64 ${mlas_platform_srcs}) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index d0d8e750285d4..a0c8ae3dfe4ca 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -174,10 +174,10 @@ The **OpSet Version** column uses the following notation: |||12|**T** = tensor(bfloat16), tensor(bool), 
tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**indices** = tensor(int64)| |||11|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**indices** = tensor(int64)| |Gelu|*in* X:**T**
*out* Y:**T**|20+|**T** = tensor(float)| -|Gemm|*in* A:**T**
*in* B:**T**
*in* C:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)| -|||[11, 12]|**T** = tensor(double), tensor(float)| -|||[9, 10]|**T** = tensor(double), tensor(float)| -|||[7, 8]|**T** = tensor(double), tensor(float)| +|Gemm|*in* A:**T**
*in* B:**T**
*in* C:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[9, 10]|**T** = tensor(double), tensor(float), tensor(float16)| +|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| |GlobalAveragePool|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(float)| |||[1, 21]|**T** = tensor(float)| |GlobalLpPool|*in* X:**T**
*out* Y:**T**|2+|**T** = tensor(float)| @@ -258,9 +258,9 @@ The **OpSet Version** column uses the following notation: |||[18, 21]|**T** = tensor(float)| |||[11, 17]|**T** = tensor(float)| |||[2, 10]|**T** = tensor(float)| -|MatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|||[1, 8]|**T** = tensor(double), tensor(float)| +|MatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||[9, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|||[1, 8]|**T** = tensor(double), tensor(float), tensor(float16)| |MatMulInteger|*in* A:**T1**
*in* B:**T2**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*out* Y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int32)| |Max|*in* data_0:**T**
*out* max:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| |||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 9d61165927d8c..2b1d4d8701563 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -462,6 +462,24 @@ static const char* const kOrtSessionOptionsEpContextModelExternalInitializersFil // - "1": Gemm FastMath mode is enabled. static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16"; +// Enables opt-in CPU EP float16 execution for supported ops. +// This allows InsertCastTransformer to preserve eligible fp16 CPU nodes instead of inserting fp16<->fp32 casts. +// The fp32 fallback heuristic below is still enabled by default and may keep selected shapes on the fp32 path when +// native fp16 is not expected to be profitable for the active MLAS backend. +// Option values: +// - "0": CPU fp16 is not explicitly enabled, and default cast/fallback behavior is used. [DEFAULT] +// - "1": Enable CPU fp16 preservation for the current opt-in scope. +static const char* const kOrtSessionOptionsEnableCpuFp16 = "session.enable_cpu_fp16"; + +// Controls the CPU fp16 -> fp32 fallback heuristic used when session.enable_cpu_fp16 is "1". +// The heuristic keeps native CPU fp16 only for cases expected to be profitable, such as supported GEMV-like shapes or +// constant-RHS MatMul when the active MLAS backend reports a native packed-B path. 
Other eligible nodes are assigned to +// CPU through fp32 casts when a valid CPU fp32 kernel exists. This is the recommended/default CPU fp16 mode. +// Option values: +// - "0": Do not use the heuristic; preserve native CPU fp16 more broadly for eligible CPU fp16 kernels. +// - "1": Use the fp32 fallback heuristic. [DEFAULT] +static const char* const kOrtSessionOptionsCpuFp16UseFp32FallbackHeuristic = "session.cpu_fp16_use_fp32_fallback_heuristic"; + // Use LUT (Lookup Table) based GEMM for quantized models when available. // Option values: // - "0": Do not use LUT based GEMM. [DEFAULT] diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc index c3194f1c57ba7..3cc1bdca98490 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc @@ -536,7 +536,7 @@ Status MoE::ComputeGEMM(const float* A, const float* B, float* C, template <> Status MoE::ComputeGEMM(const MLFloat16* A, const MLFloat16* B, MLFloat16* C, int64_t M, int64_t K, int64_t N, bool transpose_B) const { - MLAS_HALF_GEMM_DATA_PARAMS params; + MLAS_HALF_GEMM_DATA_PARAMS params{}; params.A = A; params.lda = static_cast<size_t>(K); params.C = C; @@ -551,7 +551,8 @@ Status MoE::ComputeGEMM(const MLFloat16* A, const MLFloat16* B, MLFlo params.ldb = static_cast<size_t>(N); } - MlasHalfGemmBatch(static_cast<size_t>(M), static_cast<size_t>(N), static_cast<size_t>(K), 1, &params, nullptr); + MlasHalfGemmBatch(static_cast<size_t>(M), static_cast<size_t>(N), static_cast<size_t>(K), 1, &params, nullptr, + &mlas_backend_kernel_selector_config_); return Status::OK(); } diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 6ef2319c1d3f4..69cfe16c20ff5 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -517,13 +517,10 @@ Status SessionState::PrepackConstantInitializedTensors( is_packed, &weights_to_be_filled_in)); - if (is_packed) { - // BUG CHECK: Ensure that the kernel has filled in the 
pre-packed weight - // to be cached if the weight was pre-packed - ORT_ENFORCE(weights_to_be_filled_in.buffers_.size() > 0, - "The kernel corresponding to the node ", node.Name(), - " doesn't have an implementation that can cache computed pre-packed weights"); - + // Some kernels pre-pack for their own session but intentionally do not + // expose those buffers for sharing, for example when the packed layout is + // backend-specific and the shared container has no layout metadata. + if (is_packed && !weights_to_be_filled_in.buffers_.empty()) { const auto& op_type = node.OpType(); // Sanity check diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index ddb9daa5e244b..1552c15db2e6d 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -92,15 +92,7 @@ Module Name: #if (!defined(_MSC_VER)) || (_MSC_VER >= 1930) #if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC) -#if !defined(__APPLE__) -// Had to temporary disable fp16 under APPLE ARM64, as compiling -// the source files require a hardware specific compilation flag. -// When building an universial binary for APPLE, this flag would -// cause trouble for x64 target. 
- #define MLAS_F16VEC_INTRINSICS_SUPPORTED - -#endif // #endif // ARM64 #endif // Visual Studio 16 or earlier does not support fp16 intrinsic @@ -906,6 +898,7 @@ struct MLAS_CONV_PARAMETERS { float Beta; MLAS_CONV_ALGORITHM Algorithm; ptrdiff_t ThreadCount; + bool InputOutputChannelsLast; const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr; const void* PackedFilter = nullptr; size_t PackedFilterGroupStride = 0; @@ -1898,11 +1891,20 @@ struct MLAS_HALF_GEMM_DATA_PARAMS { const MLAS_HALF_GEMM_POSTPROCESSOR* OutputProcessor = nullptr; bool AIsfp32 = false; /**< matrix A is fp32, needs to be casted into fp16*/ bool BIsfp32 = false; /**< matrix B is fp32, needs to be casted into fp16*/ + bool BIsPacked = false; /**< matrix B is pre-packed by MlasHalfGemmPackB/MlasHalfGemmConvertPackB */ + /** + * Matrix B uses a backend-specific direct-consumption packed layout. + * When true, B must be produced by MlasHalfGemmNativePackB, ldb must be 0, + * Bias must be nullptr, and OutputProcessor must be nullptr. + */ + bool BIsBackendNativePacked = false; }; /** * @brief Half precision Batched GEMM: C = A * B + Bias - * Either A or B can be fp32 or fp16 + * Either A or B can be fp32 or fp16. + * Backend-native packed B is a constrained direct-consumption layout + * and does not support runtime Bias or OutputProcessor. * * Note: We only support uniform batching, so shapes and types of the * input must be same across all parameter blocks. 
@@ -1913,6 +1915,7 @@ struct MLAS_HALF_GEMM_DATA_PARAMS { * @param[in] BatchN number of batches * @param[inout] DataParams An array (size BatchN) of parameter blocks * @param[in] ThreadPool + * @param[in] BackendKernelSelectorConfig Optional backend selector config * @return */ void @@ -1923,7 +1926,8 @@ MlasHalfGemmBatch( const size_t K, const size_t BatchN, const MLAS_HALF_GEMM_DATA_PARAMS* DataParams, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr ); /** @@ -1964,6 +1968,106 @@ MlasHalfGemmPackB( void* PackedB ); +/** + * @brief For half precision GEMM, returns the size of a backend-native + * direct-consumption packing buffer for right hand side B. + * + * The returned layout is only valid when MlasHalfGemmBatch consumes + * MLAS_HALF_GEMM_DATA_PARAMS with BIsBackendNativePacked set to true, ldb set + * to 0, Bias set to nullptr, and OutputProcessor set to nullptr. Returns 0 + * when no backend-native format is available for the given transpose/shape/config. + */ +size_t +MLASCALL +MlasHalfGemmNativePackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr + ); + +/** + * @brief For half precision GEMM, pack right hand side B into the + * backend-native direct-consumption format. + * + * Returns false when the backend-native format is unavailable for the given + * transpose/shape/config. The resulting buffer must be passed to + * MlasHalfGemmBatch with ldb set to 0, BIsBackendNativePacked set to true, + * Bias set to nullptr, and OutputProcessor set to nullptr. 
+ */ +bool +MLASCALL +MlasHalfGemmNativePackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const MLAS_FP16* B, + size_t ldb, + void* PackedB, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr + ); + +bool +MLASCALL +MlasHalfConvPrepare( + MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + bool InputOutputChannelsLast, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr); + +bool +MLASCALL +MlasHalfConv( + const MLAS_CONV_PARAMETERS* Parameters, + const MLAS_FP16* Input, + const MLAS_FP16* Filter, + // When true, Filter must point to a packed weights+bias buffer produced by + // MlasHalfConvPackWeightsAndBias and Bias must be nullptr. + bool FilterAndBiasArePacked, + const MLAS_FP16* Bias, + MLAS_FP16* WorkingBuffer, + MLAS_FP16* Output, + MLAS_THREADPOOL* ThreadPool); + +size_t +MLASCALL +MlasHalfConvPackWeightsAndBiasSize( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr); + +bool +MLASCALL +MlasHalfConvPackWeightsAndBias( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape, + const MLAS_FP16* Filter, + // Optional bias to bake into the packed direct-consumption buffer. 
+ const MLAS_FP16* Bias, + void* PackedWeightsAndBias, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr); + /** * @brief For half precision GEMM, convert the float matrix B * to half precision and pack it into a packing buffer diff --git a/onnxruntime/core/mlas/lib/halfconv.cpp b/onnxruntime/core/mlas/lib/halfconv.cpp new file mode 100644 index 0000000000000..22fa48dbc2874 --- /dev/null +++ b/onnxruntime/core/mlas/lib/halfconv.cpp @@ -0,0 +1,158 @@ +// +// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +/*++ + +Module Name: + + halfconv.cpp + +Abstract: + + This module implements public dispatch wrappers for optional half precision + convolution backends. + +--*/ + +#include "mlasi.h" + +bool +MLASCALL +MlasHalfConvPrepare( + MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + bool InputOutputChannelsLast, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig + ) +{ + if (GetMlasPlatform().MlasHalfConvPrepareOverride == nullptr) { + return false; + } + + return GetMlasPlatform().MlasHalfConvPrepareOverride( + Parameters, + Dimensions, + BatchCount, + GroupCount, + InputChannels, + InputShape, + KernelShape, + DilationShape, + Padding, + StrideShape, + OutputShape, + FilterCount, + Activation, + WorkingBufferSize, + Beta, + InputOutputChannelsLast, + ThreadPool, + BackendKernelSelectorConfig); +} + +bool +MLASCALL +MlasHalfConv( + const MLAS_CONV_PARAMETERS* Parameters, + const MLAS_FP16* Input, + const MLAS_FP16* Filter, + bool 
FilterAndBiasArePacked, + const MLAS_FP16* Bias, + MLAS_FP16* WorkingBuffer, + MLAS_FP16* Output, + MLAS_THREADPOOL* ThreadPool + ) +{ + if (GetMlasPlatform().MlasHalfConvOverride == nullptr) { + return false; + } + + if (FilterAndBiasArePacked && Bias != nullptr) { + return false; + } + + return GetMlasPlatform().MlasHalfConvOverride( + Parameters, + Input, + Filter, + FilterAndBiasArePacked, + Bias, + WorkingBuffer, + Output, + ThreadPool); +} + +size_t +MLASCALL +MlasHalfConvPackWeightsAndBiasSize( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig + ) +{ + if (BackendKernelSelectorConfig != nullptr && !BackendKernelSelectorConfig->use_kleidiai) { + return 0; + } + + if (GetMlasPlatform().MlasHalfConvPackWeightsAndBiasSizeOverride == nullptr) { + return 0; + } + + return GetMlasPlatform().MlasHalfConvPackWeightsAndBiasSizeOverride( + FilterCount, + InputChannels, + KernelShape, + DilationShape); +} + +bool +MLASCALL +MlasHalfConvPackWeightsAndBias( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape, + const MLAS_FP16* Filter, + const MLAS_FP16* Bias, + void* PackedWeightsAndBias, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig + ) +{ + if (BackendKernelSelectorConfig != nullptr && !BackendKernelSelectorConfig->use_kleidiai) { + return false; + } + + if (GetMlasPlatform().MlasHalfConvPackWeightsAndBiasOverride == nullptr) { + return false; + } + + return GetMlasPlatform().MlasHalfConvPackWeightsAndBiasOverride( + FilterCount, + InputChannels, + KernelShape, + DilationShape, + Filter, + Bias, + PackedWeightsAndBias, + ThreadPool); +} diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp index 66a335665d024..e7fe787b7f7dd 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.cpp +++ 
b/onnxruntime/core/mlas/lib/halfgemm.cpp @@ -19,9 +19,38 @@ Module Name: #include "mlas_float16.h" #include "halfgemm.h" +#if defined(USE_KLEIDIAI) +#include "kleidiai/mlasi_kleidiai.h" +#endif #include +static void +MlasHalfGemmZeroKBatch( + const size_t M, + const size_t N, + const size_t BatchN, + const MLAS_HALF_GEMM_DATA_PARAMS* DataParams + ) +{ + for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { + const auto* Data = &DataParams[gemm_i]; + auto* C = reinterpret_cast(Data->C); + const auto* Bias = reinterpret_cast(Data->Bias); + const size_t ldc = Data->ldc; + + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + C[m * ldc + n] = Bias == nullptr ? MLAS_FP16::FromBits(0) : Bias[n]; + } + } + + if (Data->OutputProcessor != nullptr) { + Data->OutputProcessor->Process(Data->C, 0, 0, M, N, ldc); + } + } +} + bool MLASCALL MlasFp16AccelerationSupported() { @@ -41,9 +70,41 @@ MlasHalfGemmBatch( const size_t K, const size_t BatchN, const MLAS_HALF_GEMM_DATA_PARAMS* DataParams, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig ) { + MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig); + + if (BatchN == 0 || M == 0 || N == 0) { + return; + } + + if (K == 0) { + MlasHalfGemmZeroKBatch(M, N, BatchN, DataParams); + return; + } + +#if defined(USE_KLEIDIAI) + if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) && + GetMlasPlatform().MlasHalfGemmBatchOverride != nullptr && + GetMlasPlatform().MlasHalfGemmBatchOverride( + M, N, K, BatchN, DataParams, ThreadPool, BackendKernelSelectorConfig)) { + return; + } +#endif + + // BIsPacked denotes the generic MLAS halfgemm packed-B layout and can be + // consumed here. Backend-native packed layouts are separate and must not + // silently fall through to the generic kernels. 
+ for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { + if (DataParams[gemm_i].BIsBackendNativePacked) { + MLAS_THROW_EX( + std::runtime_error, + "backend-native halfgemm packed B is not supported by generic MLAS halfgemm"); + } + } + const MLAS_HALFGEMM_DISPATCH* dispatch = MlasHalfGemmGetDispatch(); MLAS_HALFGEMM_OPERATION* operation = dispatch->Operation; @@ -128,11 +189,24 @@ MlasHalfGemmPackBSize( // No packing routine provided return 0; } - const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1); - const size_t BytesRequired = N * AlignedK * FP16_SIZE + padding; + size_t aligned_k_input = 0; + if (MlasTryAddSizeT(K, PackedK - 1, &aligned_k_input)) { + return 0; + } + const size_t AlignedK = aligned_k_input & ~(PackedK - 1); + + size_t BytesRequired = 0; + if (MlasTryMultiplySizeT(N, AlignedK, &BytesRequired) || + MlasTryMultiplySizeT(BytesRequired, FP16_SIZE, &BytesRequired) || + MlasTryAddSizeT(BytesRequired, padding, &BytesRequired)) { + return 0; + } const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); - const size_t AlignedBytesRequired = - (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); + size_t aligned_bytes_input = 0; + if (MlasTryAddSizeT(BytesRequired, BufferAlignment - 1, &aligned_bytes_input)) { + return 0; + } + const size_t AlignedBytesRequired = aligned_bytes_input & ~(BufferAlignment - 1); return AlignedBytesRequired; } @@ -150,6 +224,62 @@ MlasHalfGemmPackB( dispatch->CopyPackBRoutine((_mlas_fp16_*)PackedB, (const _mlas_fp16_*)B, ldb, N, K); } +size_t +MLASCALL +MlasHalfGemmNativePackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig + ) +{ +#if defined(USE_KLEIDIAI) + if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) && + GetMlasPlatform().MlasHalfGemmPackBSizeOverride != nullptr) { + return GetMlasPlatform().MlasHalfGemmPackBSizeOverride(TransA, TransB, N, K); + } +#else + 
MLAS_UNREFERENCED_PARAMETER(TransA); + MLAS_UNREFERENCED_PARAMETER(TransB); + MLAS_UNREFERENCED_PARAMETER(N); + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig); +#endif + return 0; +} + +bool +MLASCALL +MlasHalfGemmNativePackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const MLAS_FP16* B, + size_t ldb, + void* PackedB, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig + ) +{ +#if defined(USE_KLEIDIAI) + if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) && + GetMlasPlatform().MlasHalfGemmPackBOverride != nullptr) { + return GetMlasPlatform().MlasHalfGemmPackBOverride(TransA, TransB, N, K, B, ldb, PackedB); + } +#else + MLAS_UNREFERENCED_PARAMETER(TransA); + MLAS_UNREFERENCED_PARAMETER(TransB); + MLAS_UNREFERENCED_PARAMETER(N); + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(B); + MLAS_UNREFERENCED_PARAMETER(ldb); + MLAS_UNREFERENCED_PARAMETER(PackedB); + MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig); +#endif + return false; +} + void MLASCALL MlasHalfGemmConvertPackB( @@ -574,7 +704,7 @@ MlasGemmBatch( const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchDefault = { MlasHalfGemmOperation, - nullptr, + MlasHalfGemmCopyPackB, MlasHalfGemmConvertPackB, MLAS_HALF_GEMM_KERNEL_DEFAULT::PackedK, MLAS_HALF_GEMM_KERNEL_DEFAULT::KernelMaxM, diff --git a/onnxruntime/core/mlas/lib/halfgemm.h b/onnxruntime/core/mlas/lib/halfgemm.h index 529db48f58e6f..4d0a9b9888b3a 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.h +++ b/onnxruntime/core/mlas/lib/halfgemm.h @@ -32,14 +32,57 @@ Module Name: #pragma once -#include #include +#include +#include +#include +#include #include #include "mlasi.h" #include "mlas_float16.h" +MLAS_FORCEINLINE +bool +MlasTryMultiplySizeT( + size_t a, + size_t b, + size_t* out + ) +{ +#if defined(__has_builtin) +#if __has_builtin(__builtin_mul_overflow) + return __builtin_mul_overflow(a, b, out); +#endif 
+#endif + if (b != 0 && a > (std::numeric_limits<size_t>::max)() / b) { + return true; + } + *out = a * b; + return false; +} + +MLAS_FORCEINLINE +bool +MlasTryAddSizeT( + size_t a, + size_t b, + size_t* out + ) +{ +#if defined(__has_builtin) +#if __has_builtin(__builtin_add_overflow) + return __builtin_add_overflow(a, b, out); +#endif +#endif + if (a > (std::numeric_limits<size_t>::max)() - b) { + return true; + } + *out = a + b; + return false; +} + /** * @brief Define the default striding parameters for * the half precision gemm operation @@ -71,12 +114,46 @@ MlasHalfGemmCopyPackB( size_t CountK ) { - MLAS_UNREFERENCED_PARAMETER(D); - MLAS_UNREFERENCED_PARAMETER(B); - MLAS_UNREFERENCED_PARAMETER(ldb); - MLAS_UNREFERENCED_PARAMETER(CountN); - MLAS_UNREFERENCED_PARAMETER(CountK); - // No packing needed by default + size_t aligned_count_k_input = 0; + if (MlasTryAddSizeT(CountK, KernelType::PackedK - 1, &aligned_count_k_input)) { + MLAS_THROW_EX(std::runtime_error, "MlasHalfGemmCopyPackB aligned K overflow"); + } + const size_t AlignedCountK = aligned_count_k_input & ~(KernelType::PackedK - 1); + size_t PaddingCountK = AlignedCountK - CountK; + + if (ldb == CountN) { + size_t bytes_to_copy = 0; + if (MlasTryMultiplySizeT(CountK, CountN, &bytes_to_copy) || + MlasTryMultiplySizeT(bytes_to_copy, sizeof(_mlas_fp16_), &bytes_to_copy)) { + MLAS_THROW_EX(std::runtime_error, "MlasHalfGemmCopyPackB size overflow"); + } + std::memcpy(D, B, bytes_to_copy); + if (PaddingCountK > 0) { + size_t padding_bytes = 0; + if (MlasTryMultiplySizeT(PaddingCountK, CountN, &padding_bytes) || + MlasTryMultiplySizeT(padding_bytes, sizeof(_mlas_fp16_), &padding_bytes)) { + MLAS_THROW_EX(std::runtime_error, "MlasHalfGemmCopyPackB padding size overflow"); + } + std::memset(D + CountK * CountN, 0, padding_bytes); + } + return; + } + + size_t row_bytes = 0; + if (MlasTryMultiplySizeT(CountN, sizeof(_mlas_fp16_), &row_bytes)) { + MLAS_THROW_EX(std::runtime_error, "MlasHalfGemmCopyPackB row size overflow"); + } + 
while (CountK > 0) { + std::memcpy(D, B, row_bytes); + B += ldb; + D += CountN; + CountK--; + } + while (PaddingCountK > 0) { + std::memset(D, 0, row_bytes); + D += CountN; + PaddingCountK--; + } } /** diff --git a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp index d7f5a90b00589..f9bb3da1c5b57 100644 --- a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp @@ -179,7 +179,7 @@ MlasHalfGemmKernel( const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchNeon = { MlasHalfGemmOperation, - nullptr, + MlasHalfGemmCopyPackB, MlasHalfGemmConvertPackB, MLAS_HALF_GEMM_KERNEL_NEON::PackedK, MLAS_HALF_GEMM_KERNEL_NEON::KernelMaxM, diff --git a/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp b/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp index 6ee80594c6b49..a352fc85abbd1 100644 --- a/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp +++ b/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp @@ -28,6 +28,7 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.h" // IMATMUL #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h" +#include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa.h" // SME2 kernels // GEMM/QGEMM/SBGEMM @@ -40,6 +41,11 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h" // IMATMUL #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h" +#include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h" + +// FP16 HGEMM kernels +#include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla.h" +#include 
"kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot.h" #if defined(ENABLE_QMX_KERNELS) // QMX kernels (optional) @@ -229,6 +235,12 @@ const KaiF32IMatmulKernel imatmul_conv_sme = const KaiF32IMatmulKernel imatmul_conv_sme2 = KAI_WRAP_UKERNEL_RUN_IMATMUL_PACKED_7(imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa); +const KaiF16IMatmulKernel imatmul_f16_conv_sme = + KAI_WRAP_UKERNEL_RUN_IMATMUL_PACKED_7(imatmul_clamp_f16_f16p2vlx2_f16p2vlx2b_2vlx2vl_sme_mopa); + +const KaiF16IMatmulKernel imatmul_f16_conv_sme2 = + KAI_WRAP_UKERNEL_RUN_IMATMUL_PACKED_7(imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa); + const KaiBF16SBgemmKernel sbgemm_gemm_sme2 = KAI_WRAP_UKERNEL_RUN_MATMUL_11(matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa); @@ -246,6 +258,12 @@ const KaiDynamicQGemmKernel qgemm_gemm_sme = const KaiDynamicQGemmKernel qgemm_gemm_sme2 = KAI_WRAP_UKERNEL_RUN_MATMUL_11(matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa); +const KaiF16HgemmKernel hgemm_sme = + KAI_WRAP_UKERNEL_RUN_MATMUL_10_LHS_OFFSET(matmul_clamp_f16_f16_f16p2vlx2b_1x8vl_sme_mla); + +const KaiF16HgemmKernel hgemm_sme2 = + KAI_WRAP_UKERNEL_RUN_MATMUL_10_LHS_PACKED_OFFSET(matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot); + #if defined(ENABLE_QMX_KERNELS) @@ -350,6 +368,14 @@ const KaiF32IMatmulKernel& GetKleidiAIF32IMatmulUKernel() { } } +const KaiF16IMatmulKernel& GetKleidiAIF16IMatmulUKernel() { + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()) { + return imatmul_f16_conv_sme2; + } else { + return imatmul_f16_conv_sme; + } +} + const KaiDynamicQGemmKernel& GetKleidiAIQGemmUKernel() { if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()) { return qgemm_gemm_sme2; @@ -372,3 +398,11 @@ const KaiBF16SBgemmKernel& GetKleidiAISBGemmUKernel() { // Currently only SME2 variant exists for bfloat16/SBGEMM kernel return sbgemm_gemm_sme2; } + +const KaiF16HgemmKernel& GetKleidiAIHgemmUKernel() { + if 
(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()) { + return hgemm_sme2; + } else { + return hgemm_sme; + } +} diff --git a/onnxruntime/core/mlas/lib/kai_ukernel_interface.h b/onnxruntime/core/mlas/lib/kai_ukernel_interface.h index 155ecf1762b3b..9b42e30a969d0 100644 --- a/onnxruntime/core/mlas/lib/kai_ukernel_interface.h +++ b/onnxruntime/core/mlas/lib/kai_ukernel_interface.h @@ -18,6 +18,10 @@ #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h" +#include "kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/kai_imatmul_clamp_f16_f16p_f16p_interface.h" + +#include "kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p_interface.h" + // Wrapper type that carries a stable "name" alongside the KAI ukernel interface. // This avoids needing to infer which underlying microkernel was selected from a function pointer. template @@ -41,8 +45,14 @@ using KaiDynamicQGemmKernel = KaiMatmulKernel; +// Wrapper for FP16 IMATMUL kernels used by the KleidiAI convolution implementation. +using KaiF16IMatmulKernel = KaiMatmulKernel; + using KaiBF16SBgemmKernel = KaiMatmulKernel; +// Wrapper for FP16 HGEMM kernels producing FP16 output. +using KaiF16HgemmKernel = KaiMatmulKernel; + // Returns the selected Qnbit GEMM ukernel based on runtime CPU capabilities. const KaiQnbitGemmKernel& GetKleidiAIGemmUKernel(); @@ -61,5 +71,11 @@ const KaiF32SgemvKernel& GetKleidiAISGemvUKernel(); // Returns the selected FP32 IMATMUL ukernel used by the KleidiAI convolution implementation. const KaiF32IMatmulKernel& GetKleidiAIF32IMatmulUKernel(); +// Returns the selected FP16 IMATMUL ukernel used by the KleidiAI convolution implementation. +const KaiF16IMatmulKernel& GetKleidiAIF16IMatmulUKernel(); + // Returns the selected BF16 SBGEMM ukernel used by the KleidiAI based on runtime CPU capabilities. 
-const KaiBF16SBgemmKernel& GetKleidiAISBGemmUKernel(); \ No newline at end of file +const KaiBF16SBgemmKernel& GetKleidiAISBGemmUKernel(); + +// Returns the selected FP16 HGEMM ukernel based on runtime CPU capabilities. +const KaiF16HgemmKernel& GetKleidiAIHgemmUKernel(); diff --git a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp index cca4f5a19c417..1b8b80ed496d9 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp @@ -480,14 +480,12 @@ static std::unique_ptr LhsPackImageDataSme(const size_t ci, const s 1, 1 }; - auto& lhs_ptrs_cache = lhs_ptrs_cache_by_pad[cur_pad_ptr]; - std::shared_ptr lhs_ptrs; - if (auto found = lhs_ptrs_cache.find(key); found != lhs_ptrs_cache.end()) { + if (auto found = lhs_ptrs_cache_by_pad[cur_pad_ptr].find(key); found != lhs_ptrs_cache_by_pad[cur_pad_ptr].end()) { lhs_ptrs = found->second; } else { lhs_ptrs = LhsPtrFill(ci, ih, iw, kh, kw, sh, sw, padding, &pad_ptr[0]); - lhs_ptrs_cache[key] = lhs_ptrs; + lhs_ptrs_cache_by_pad[cur_pad_ptr][key] = lhs_ptrs; } MultiThreadedLHSPackSme(ThreadPool, ci, m, kh, kw, &lhs_ptrs[0], &lhs[0], activation_src, &pad_ptr[0]); diff --git a/onnxruntime/core/mlas/lib/kleidiai/halfconv_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/halfconv_kleidiai.cpp new file mode 100644 index 0000000000000..f1261b399e380 --- /dev/null +++ b/onnxruntime/core/mlas/lib/kleidiai/halfconv_kleidiai.cpp @@ -0,0 +1,1002 @@ +// +// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#include +#include +#include +#include +#include +#include + +#include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h" +#include "kai_ukernel_interface.h" +#include "mlasi_kleidiai.h" + +namespace +{ + +// Cache-oriented 
heuristics for the chunked IMATMUL path. These bound the +// packed LHS working set so we only split when the single-pass pack/compute +// footprint is large enough that chunking can improve locality. +constexpr size_t AutomaticMaximumLhsChunkBytes = 2 * 1024 * 1024; +constexpr size_t MinimumLhsPackedBytesForAutomaticChunking = 8 * 1024 * 1024; +constexpr size_t MinimumEffectiveKForAutomaticChunking = 1024; +// Small filter counts typically do not amortize the extra chunking overhead. +constexpr size_t MaximumFilterCountForAutomaticChunking = 32; + +bool +TryComputeKernelSize( + size_t dilation, + size_t kernel, + size_t& dilated_kernel +) +{ + if (dilation == 0 || kernel == 0) { + return false; + } + + size_t scaled_kernel = 0; + if (mul_overflow_size_t_builtin(dilation, kernel, &scaled_kernel) || scaled_kernel < (dilation - 1)) { + return false; + } + + dilated_kernel = scaled_kernel - (dilation - 1); + return true; +} + +bool +TryComputeConvOutSize( + size_t input, + size_t kernel, + size_t padding, + size_t stride, + size_t& output +) +{ + output = 0; + if (stride == 0) { + return false; + } + + size_t total_padding = 0; + size_t padded_input = 0; + if (mul_overflow_size_t_builtin(padding, 2, &total_padding) || + !TryAddSize(input, total_padding, padded_input)) { + return false; + } + + if (padded_input < kernel) { + return true; + } + + output = ((padded_input - kernel) / stride) + 1; + return true; +} + +bool +TryComputeOutputSize( + size_t input_height, + size_t input_width, + size_t kernel_height, + size_t kernel_width, + size_t padding_height, + size_t padding_width, + size_t stride_height, + size_t stride_width, + size_t& output_height, + size_t& output_width, + size_t& output_size +) +{ + if (!TryComputeConvOutSize(input_height, kernel_height, padding_height, stride_height, output_height) || + !TryComputeConvOutSize(input_width, kernel_width, padding_width, stride_width, output_width) || + mul_overflow_size_t_builtin(output_height, output_width, 
&output_size)) { + return false; + } + + return true; +} + +bool +TryComputeOutputSize( + size_t input_height, + size_t input_width, + size_t kernel_height, + size_t kernel_width, + size_t padding_height, + size_t padding_width, + size_t stride_height, + size_t stride_width, + size_t& output_size +) +{ + size_t output_height = 0; + size_t output_width = 0; + return TryComputeOutputSize( + input_height, + input_width, + kernel_height, + kernel_width, + padding_height, + padding_width, + stride_height, + stride_width, + output_height, + output_width, + output_size); +} + +size_t +SelectMaximumLhsChunkBytes( + size_t full_lhs_size, + size_t filter_count, + size_t effective_k +) +{ + // The bounded-LHS path is most useful when LHS packing is cache-hostile and + // there are few output channels, so full-LHS reuse across N tiles is limited. + if (full_lhs_size >= MinimumLhsPackedBytesForAutomaticChunking && + filter_count <= MaximumFilterCountForAutomaticChunking && + effective_k >= MinimumEffectiveKForAutomaticChunking) { + return AutomaticMaximumLhsChunkBytes; + } + + return 0; +} + +bool +IsPaddingSymmetric2D(const MLAS_CONV_PARAMETERS* parameters) +{ + return parameters->Padding[0] == parameters->Padding[1] && + parameters->Padding[0] == parameters->Padding[2] && + parameters->Padding[0] == parameters->Padding[3]; +} + +bool +CheckCapabilitiesSme(const MLAS_CONV_PARAMETERS* parameters) +{ + if (parameters == nullptr) { + return false; + } + + if (parameters->BackendKernelSelectorConfig != nullptr && + !parameters->BackendKernelSelectorConfig->use_kleidiai) { + KLEIDIAI_DEBUG_LOG("User explicitly disabled KleidiAI, returning false from MlasHalfConv."); + return false; + } + + if ((parameters->Dimensions != 2) || + (parameters->BatchCount != 1) || + (parameters->GroupCount != 1) || + (parameters->Beta != 0.0f) || + !(parameters->Activation == nullptr || + parameters->Activation->ActivationKind == MlasIdentityActivation) || + !IsPaddingSymmetric2D(parameters)) { + 
KLEIDIAI_DEBUG_LOG("MlasHalfConv capability check failed: unsupported configuration."); + return false; + } + + size_t d_kh = 0; + size_t d_kw = 0; + size_t output_size = 0; + if (!TryComputeKernelSize(parameters->DilationShape[0], parameters->KernelShape[0], d_kh) || + !TryComputeKernelSize(parameters->DilationShape[1], parameters->KernelShape[1], d_kw) || + !TryComputeOutputSize( + parameters->InputShape[0], + parameters->InputShape[1], + d_kh, + d_kw, + parameters->Padding[0], + parameters->Padding[1], + parameters->StrideShape[0], + parameters->StrideShape[1], + output_size)) { + return false; + } + + if (output_size == 0 || + output_size != parameters->OutputSize || + parameters->InputChannels == 0 || + parameters->FilterCount == 0 || + parameters->FilterCount == 1 || + parameters->KernelShape[0] < 3 || + parameters->KernelShape[1] < 3) { + KLEIDIAI_DEBUG_LOG("MlasHalfConv capability check failed: shape/heuristic gating."); + return false; + } + + return true; +} + +bool +GetPackedFilterSize( + size_t filter_count, + size_t input_channels, + const int64_t* kernel_shape, + const int64_t* dilation_shape, + size_t* packed_size +) +{ + if (packed_size == nullptr || + kernel_shape == nullptr || + dilation_shape == nullptr || + filter_count <= 1 || + input_channels == 0 || + kernel_shape[0] < 3 || + kernel_shape[1] < 3 || + dilation_shape[0] <= 0 || + dilation_shape[1] <= 0) { + return false; + } + + size_t d_kh = 0; + size_t d_kw = 0; + size_t k_chunk_count = 0; + if (!TryComputeKernelSize(static_cast(dilation_shape[0]), static_cast(kernel_shape[0]), d_kh) || + !TryComputeKernelSize(static_cast(dilation_shape[1]), static_cast(kernel_shape[1]), d_kw) || + mul_overflow_size_t_builtin(d_kh, d_kw, &k_chunk_count)) { + return false; + } + + *packed_size = kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( + filter_count, k_chunk_count, input_channels + ); + return *packed_size != 0; +} + +bool +NchwToNhwc( + const MLAS_FP16* input, + size_t channels, + 
size_t height, + size_t width, + std::vector& output +) +{ + size_t element_count = 0; + if (mul_overflow_size_t_builtin(channels, height, &element_count) || + mul_overflow_size_t_builtin(element_count, width, &element_count)) { + return false; + } + output.resize(element_count); + + if (input == nullptr) { + return false; + } + + for (size_t c = 0; c < channels; ++c) { + for (size_t h = 0; h < height; ++h) { + for (size_t w = 0; w < width; ++w) { + output[(h * width + w) * channels + c] = input[(c * height + h) * width + w]; + } + } + } + + return true; +} + +bool +FillIndirectionTable( + const MLAS_FP16* input_nhwc, + const MLAS_FP16* pad, + size_t input_channels, + size_t input_height, + size_t input_width, + size_t kernel_height, + size_t kernel_width, + size_t stride_height, + size_t stride_width, + size_t padding, + size_t output_size, + std::vector& indirection +) +{ + const auto& imatmul = GetKleidiAIF16IMatmulUKernel(); + const size_t m_step = imatmul.ukernel.get_m_step(); + size_t lhs_ptrs_k = 0; + if (mul_overflow_size_t_builtin(kernel_height, kernel_width, &lhs_ptrs_k)) { + return false; + } + + size_t lhs_ptrs_m = 0; + if (mul_overflow_size_t_builtin(m_step, MlasDivRoundup(output_size, m_step), &lhs_ptrs_m)) { + return false; + } + + size_t table_size = 0; + if (mul_overflow_size_t_builtin(lhs_ptrs_k, lhs_ptrs_m, &table_size)) { + return false; + } + indirection.resize(table_size); + + std::fill(indirection.begin(), indirection.end(), pad); + + auto ptr_offset = [lhs_ptrs_k, m_step](size_t k, size_t m) { + return ((m / m_step) * lhs_ptrs_k * m_step) + (k * m_step) + (m % m_step); + }; + + auto pixel_ptr = [=](size_t h, size_t w) -> const void* { + if (h < padding || w < padding) { + return pad; + } + + h -= padding; + w -= padding; + if (h >= input_height || w >= input_width) { + return pad; + } + + return input_nhwc + (h * input_width + w) * input_channels; + }; + + size_t output_height = 0; + size_t output_width = 0; + size_t computed_output_size = 
0; + if (!TryComputeOutputSize( + input_height, + input_width, + kernel_height, + kernel_width, + padding, + padding, + stride_height, + stride_width, + output_height, + output_width, + computed_output_size) || + computed_output_size != output_size) { + return false; + } + + size_t m = 0; + for (size_t oh = 0; oh < output_height; ++oh) { + for (size_t ow = 0; ow < output_width; ++ow, ++m) { + size_t k = 0; + const size_t input_base_h = oh * stride_height; + const size_t input_base_w = ow * stride_width; + for (size_t kh = 0; kh < kernel_height; ++kh) { + for (size_t kw = 0; kw < kernel_width; ++kw, ++k) { + indirection[ptr_offset(k, m)] = pixel_ptr(input_base_h + kh, input_base_w + kw); + } + } + } + } + + return m == output_size; +} + +bool +PrepareLhsInput( + const MLAS_CONV_PARAMETERS* parameters, + const MLAS_FP16* input, + size_t output_size, + std::vector& input_nhwc, + std::vector& pad, + std::vector& indirection +) +{ + size_t d_kh = 0; + size_t d_kw = 0; + if (!TryComputeKernelSize(parameters->DilationShape[0], parameters->KernelShape[0], d_kh) || + !TryComputeKernelSize(parameters->DilationShape[1], parameters->KernelShape[1], d_kw)) { + return false; + } + + const size_t input_channels = parameters->InputChannels; + + const MLAS_FP16* input_nhwc_data = input; + if (!parameters->InputOutputChannelsLast) { + if (!NchwToNhwc( + input, + input_channels, + parameters->InputShape[0], + parameters->InputShape[1], + input_nhwc + )) { + return false; + } + input_nhwc_data = input_nhwc.data(); + } + + pad.resize(input_channels); + std::fill(pad.begin(), pad.end(), MLAS_FP16::FromBits(0)); + + return FillIndirectionTable( + input_nhwc_data, + pad.data(), + input_channels, + parameters->InputShape[0], + parameters->InputShape[1], + d_kh, + d_kw, + parameters->StrideShape[0], + parameters->StrideShape[1], + parameters->Padding[0], + output_size, + indirection + ); +} + +bool +PackFilter( + size_t filter_count, + size_t input_channels, + const int64_t* kernel_shape, + 
const int64_t* dilation_shape, + const MLAS_FP16* filter, + const MLAS_FP16* bias, + void* packed_filter +) +{ + if (filter == nullptr || packed_filter == nullptr || + kernel_shape == nullptr || dilation_shape == nullptr || + filter_count <= 1 || input_channels == 0 || + kernel_shape[0] < 3 || kernel_shape[1] < 3 || + dilation_shape[0] <= 0 || dilation_shape[1] <= 0) { + return false; + } + + const size_t kernel_height = static_cast(kernel_shape[0]); + const size_t kernel_width = static_cast(kernel_shape[1]); + const size_t dilation_height = static_cast(dilation_shape[0]); + const size_t dilation_width = static_cast(dilation_shape[1]); + size_t d_kh = 0; + size_t d_kw = 0; + size_t k_chunk_count = 0; + if (!TryComputeKernelSize(dilation_height, kernel_height, d_kh) || + !TryComputeKernelSize(dilation_width, kernel_width, d_kw) || + mul_overflow_size_t_builtin(d_kh, d_kw, &k_chunk_count)) { + return false; + } + + size_t reordered_size = 0; + if (mul_overflow_size_t_builtin(k_chunk_count, input_channels, &reordered_size) || + mul_overflow_size_t_builtin(reordered_size, filter_count, &reordered_size)) { + return false; + } + + std::vector reordered_filter; + reordered_filter.resize(reordered_size); + std::fill(reordered_filter.begin(), reordered_filter.end(), MLAS_FP16::FromBits(0)); + + for (size_t oc = 0; oc < filter_count; ++oc) { + for (size_t ic = 0; ic < input_channels; ++ic) { + for (size_t kh = 0; kh < kernel_height; ++kh) { + for (size_t kw = 0; kw < kernel_width; ++kw) { + const size_t src = ((oc * input_channels + ic) * kernel_height + kh) * kernel_width + kw; + const size_t dk = ((kh * dilation_height) * d_kw + (kw * dilation_width)) * input_channels + ic; + reordered_filter[dk * filter_count + oc] = filter[src]; + } + } + } + } + + std::vector zero_bias; + const MLAS_FP16* bias_data = bias; + if (bias_data == nullptr) { + zero_bias.resize(filter_count); + std::fill(zero_bias.begin(), zero_bias.end(), MLAS_FP16::FromBits(0)); + bias_data = 
zero_bias.data(); + } + + KLEIDIAI_KERNEL_LOG("kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme" << " N=" << filter_count << " k_chunk_count=" << k_chunk_count << " k_chunk_length=" << input_channels << " rhs_stride_row=" << (filter_count * sizeof(MLAS_FP16))); + kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme( + filter_count, + k_chunk_count, + input_channels, + filter_count * sizeof(MLAS_FP16), + reordered_filter.data(), + bias_data, + packed_filter + ); + + return true; +} + +bool +ConvolveSme( + const MLAS_CONV_PARAMETERS* parameters, + const MLAS_FP16* input, + const MLAS_FP16* filter, + bool filter_and_bias_are_packed, + const MLAS_FP16* bias, + MLAS_FP16* working_buffer, + MLAS_FP16* output, + MLAS_THREADPOOL* thread_pool +) +{ + if (input == nullptr || filter == nullptr || output == nullptr || + (!parameters->InputOutputChannelsLast && working_buffer == nullptr)) { + return false; + } + + size_t d_kh = 0; + size_t d_kw = 0; + size_t output_size = 0; + if (!TryComputeKernelSize(parameters->DilationShape[0], parameters->KernelShape[0], d_kh) || + !TryComputeKernelSize(parameters->DilationShape[1], parameters->KernelShape[1], d_kw) || + !TryComputeOutputSize( + parameters->InputShape[0], + parameters->InputShape[1], + d_kh, + d_kw, + parameters->Padding[0], + parameters->Padding[1], + parameters->StrideShape[0], + parameters->StrideShape[1], + output_size)) { + return false; + } + + std::vector input_nhwc; + std::vector pad; + std::vector indirection; + if (!PrepareLhsInput(parameters, input, output_size, input_nhwc, pad, indirection)) { + return false; + } + + std::vector packed_filter_buffer; + const std::byte* packed_filter = reinterpret_cast(filter); + if (!filter_and_bias_are_packed) { + const std::array kernel_shape{ + static_cast(parameters->KernelShape[0]), + static_cast(parameters->KernelShape[1]) + }; + const std::array dilation_shape{ + static_cast(parameters->DilationShape[0]), + static_cast(parameters->DilationShape[1]) + }; + size_t 
packed_filter_size = 0; + if (!GetPackedFilterSize( + parameters->FilterCount, + parameters->InputChannels, + kernel_shape.data(), + dilation_shape.data(), + &packed_filter_size + )) { + return false; + } + packed_filter_buffer.resize(packed_filter_size); + if (!PackFilter( + parameters->FilterCount, + parameters->InputChannels, + kernel_shape.data(), + dilation_shape.data(), + filter, + bias, + packed_filter_buffer.data() + )) { + return false; + } + packed_filter = packed_filter_buffer.data(); + } + + const auto& imatmul = GetKleidiAIF16IMatmulUKernel(); + const size_t base_n_step = imatmul.ukernel.get_n_step(); + const size_t base_m_step = imatmul.ukernel.get_m_step(); + size_t n_step = base_n_step; + size_t m_step = base_m_step; + const size_t filter_count = parameters->FilterCount; + const size_t input_channels = parameters->InputChannels; + + std::array dim{ + MlasDivRoundup(output_size, m_step), + MlasDivRoundup(filter_count, n_step) + }; + + size_t tile_count = 0; + if (mul_overflow_size_t_builtin(dim[0], dim[1], &tile_count)) { + return false; + } + + const size_t required_tiles = std::min( + static_cast(MlasGetMaximumThreadCount(thread_pool)), + tile_count + ); + + if (required_tiles == 0) { + return false; + } + + const size_t original_dim0 = dim[0]; + const size_t original_dim1 = dim[1]; + size_t scaled_dim0 = 0; + size_t scaled_dim1 = 0; + if (mul_overflow_size_t_builtin(required_tiles, original_dim0, &scaled_dim0) || + mul_overflow_size_t_builtin(required_tiles, original_dim1, &scaled_dim1)) { + return false; + } + + dim[0] = MlasDivRoundup(scaled_dim0, tile_count); + dim[1] = MlasDivRoundup(scaled_dim1, tile_count); + + size_t new_m_step = 0; + size_t new_n_step = 0; + if (mul_overflow_size_t_builtin(m_step, MlasDivRoundup(MlasDivRoundup(output_size, dim[0]), m_step), &new_m_step) || + mul_overflow_size_t_builtin(n_step, MlasDivRoundup(MlasDivRoundup(filter_count, dim[1]), n_step), &new_n_step)) { + return false; + } + m_step = new_m_step; + n_step = 
new_n_step; + + dim[0] = MlasDivRoundup(output_size, m_step); + dim[1] = MlasDivRoundup(filter_count, n_step); + size_t finalized_tile_count = 0; + if (mul_overflow_size_t_builtin(dim[0], dim[1], &finalized_tile_count)) { + return false; + } + + const float clamp_min = -std::numeric_limits::infinity(); + const float clamp_max = std::numeric_limits::infinity(); + size_t dst_stride = 0; + if (mul_overflow_size_t_builtin(filter_count, sizeof(MLAS_FP16), &dst_stride)) { + return false; + } + + MLAS_FP16* destination = parameters->InputOutputChannelsLast ? output : working_buffer; + + size_t kernel_chunk_count = 0; + size_t effective_k = 0; + if (mul_overflow_size_t_builtin(d_kh, d_kw, &kernel_chunk_count) || + mul_overflow_size_t_builtin(kernel_chunk_count, input_channels, &effective_k)) { + return false; + } + + const size_t full_lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + output_size, kernel_chunk_count, input_channels + ); + if (full_lhs_size == 0) { + return false; + } + + const size_t maximum_lhs_chunk_bytes = SelectMaximumLhsChunkBytes( + full_lhs_size, + filter_count, + effective_k + ); + + if (maximum_lhs_chunk_bytes == 0 || full_lhs_size <= maximum_lhs_chunk_bytes) { + std::vector packed_lhs; + packed_lhs.resize(full_lhs_size); + + KLEIDIAI_KERNEL_LOG("kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme" + << " M=" << output_size + << " k_chunk_count=" << kernel_chunk_count + << " k_chunk_length=" << input_channels); + kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + output_size, + kernel_chunk_count, + input_channels, + indirection.data(), + 0, + pad.data(), + packed_lhs.data() + ); + + std::atomic ok{true}; + MlasTrySimpleParallel(thread_pool, static_cast(finalized_tile_count), [&](ptrdiff_t tid) { + if (!ok.load(std::memory_order_relaxed)) { + return; + } + + const size_t m_idx = (static_cast(tid) / dim[1]) * m_step; + const size_t n_idx = (static_cast(tid) % dim[1]) * n_step; + const size_t tile_m = std::min(m_step, output_size - 
m_idx); + const size_t tile_n = std::min(n_step, filter_count - n_idx); + + const std::byte* lhs_tile = + packed_lhs.data() + imatmul.ukernel.get_lhs_packed_offset(m_idx, kernel_chunk_count, input_channels); + const std::byte* rhs_tile = + packed_filter + imatmul.ukernel.get_rhs_packed_offset(n_idx, kernel_chunk_count, input_channels); + + size_t dst_elements = 0; + size_t dst_bytes = 0; + if (mul_overflow_size_t_builtin(m_idx, filter_count, &dst_elements) || + !TryAddSize(dst_elements, n_idx, dst_elements) || + mul_overflow_size_t_builtin(dst_elements, sizeof(MLAS_FP16), &dst_bytes)) { + ok.store(false, std::memory_order_relaxed); + return; + } + + std::byte* dst_tile = reinterpret_cast(destination) + dst_bytes; + + KLEIDIAI_KERNEL_LOG(imatmul.name << " M=" << tile_m + << " N=" << tile_n + << " k_chunk_count=" << kernel_chunk_count + << " k_chunk_length=" << input_channels); + imatmul.ukernel.run_imatmul( + tile_m, + tile_n, + kernel_chunk_count, + input_channels, + lhs_tile, + rhs_tile, + dst_tile, + dst_stride, + clamp_min, + clamp_max + ); + }); + + if (!ok.load(std::memory_order_relaxed)) { + return false; + } + } else { + const size_t bytes_per_m_step = kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + base_m_step, kernel_chunk_count, input_channels + ); + if (bytes_per_m_step == 0) { + return false; + } + + const size_t max_m_steps_per_chunk = std::min( + MlasDivRoundup(output_size, base_m_step), + std::max(1, maximum_lhs_chunk_bytes / bytes_per_m_step) + ); + size_t lhs_chunk_m = 0; + if (mul_overflow_size_t_builtin(base_m_step, max_m_steps_per_chunk, &lhs_chunk_m)) { + return false; + } + + const size_t lhs_chunk_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + lhs_chunk_m, kernel_chunk_count, input_channels + ); + if (lhs_chunk_size == 0 || lhs_chunk_size > std::vector().max_size()) { + return false; + } + + const size_t m_chunk_count = MlasDivRoundup(output_size, lhs_chunk_m); + std::atomic ok{true}; + 
MlasTrySimpleParallel(thread_pool, static_cast(m_chunk_count), [&](ptrdiff_t tid) { + if (!ok.load(std::memory_order_relaxed)) { + return; + } + + const size_t global_m_idx = static_cast(tid) * lhs_chunk_m; + const size_t chunk_m = std::min(lhs_chunk_m, output_size - global_m_idx); + size_t indirection_offset = 0; + size_t indirection_tiles = 0; + if (mul_overflow_size_t_builtin(global_m_idx / base_m_step, kernel_chunk_count, &indirection_tiles) || + mul_overflow_size_t_builtin(indirection_tiles, base_m_step, &indirection_offset)) { + ok.store(false, std::memory_order_relaxed); + return; + } + + std::vector packed_lhs; + packed_lhs.resize(lhs_chunk_size); + + KLEIDIAI_KERNEL_LOG("kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme" + << " M=" << chunk_m + << " k_chunk_count=" << kernel_chunk_count + << " k_chunk_length=" << input_channels); + kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme( + chunk_m, + kernel_chunk_count, + input_channels, + indirection.data() + indirection_offset, + 0, + pad.data(), + packed_lhs.data() + ); + + for (size_t n_tile_idx = 0; n_tile_idx < dim[1]; ++n_tile_idx) { + const size_t n_idx = n_tile_idx * n_step; + const size_t tile_n = std::min(n_step, filter_count - n_idx); + const std::byte* rhs_tile = + packed_filter + imatmul.ukernel.get_rhs_packed_offset(n_idx, kernel_chunk_count, input_channels); + + size_t dst_elements = 0; + size_t dst_bytes = 0; + if (mul_overflow_size_t_builtin(global_m_idx, filter_count, &dst_elements) || + !TryAddSize(dst_elements, n_idx, dst_elements) || + mul_overflow_size_t_builtin(dst_elements, sizeof(MLAS_FP16), &dst_bytes)) { + ok.store(false, std::memory_order_relaxed); + return; + } + std::byte* dst_tile = reinterpret_cast(destination) + dst_bytes; + + KLEIDIAI_KERNEL_LOG(imatmul.name << " M=" << chunk_m + << " N=" << tile_n + << " k_chunk_count=" << kernel_chunk_count + << " k_chunk_length=" << input_channels); + imatmul.ukernel.run_imatmul( + chunk_m, + tile_n, + kernel_chunk_count, + input_channels, + 
packed_lhs.data(), + rhs_tile, + dst_tile, + dst_stride, + clamp_min, + clamp_max + ); + } + }); + + if (!ok.load(std::memory_order_relaxed)) { + return false; + } + } + + if (parameters->InputOutputChannelsLast) { + return true; + } + + MlasTranspose(working_buffer, output, output_size, filter_count, thread_pool); + return true; +} + +} // namespace + +bool + MLASCALL + ArmKleidiAI::MlasHalfConvPrepare( + MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + bool InputOutputChannelsLast, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig + ) +{ + MLAS_UNREFERENCED_PARAMETER(ThreadPool); + + if (Parameters == nullptr || + InputShape == nullptr || + KernelShape == nullptr || + DilationShape == nullptr || + Padding == nullptr || + StrideShape == nullptr || + OutputShape == nullptr || + WorkingBufferSize == nullptr || + Dimensions != 2) { + return false; + } + + if (BackendKernelSelectorConfig != nullptr && + !BackendKernelSelectorConfig->use_kleidiai) { + KLEIDIAI_DEBUG_LOG("User explicitly disabled KleidiAI, returning false from MlasHalfConvPrepare."); + return false; + } + + Parameters->BackendKernelSelectorConfig = BackendKernelSelectorConfig; + Parameters->Activation = Activation; + Parameters->Dimensions = Dimensions; + Parameters->BatchCount = BatchCount; + Parameters->GroupCount = GroupCount; + Parameters->InputChannels = InputChannels; + Parameters->FilterCount = FilterCount; + Parameters->Beta = Beta; + Parameters->InputOutputChannelsLast = InputOutputChannelsLast; + + size_t input_size = 1; + size_t output_size = 1; + size_t k = InputChannels; + size_t 
dilated_kernel_size = InputChannels; + for (size_t dim = 0; dim < Dimensions; ++dim) { + if (InputShape[dim] <= 0 || + OutputShape[dim] <= 0 || + KernelShape[dim] <= 0 || + DilationShape[dim] <= 0 || + StrideShape[dim] <= 0 || + Padding[dim] < 0 || + Padding[dim + Dimensions] < 0) { + return false; + } + + Parameters->InputShape[dim] = static_cast(InputShape[dim]); + Parameters->OutputShape[dim] = static_cast(OutputShape[dim]); + Parameters->KernelShape[dim] = static_cast(KernelShape[dim]); + Parameters->DilationShape[dim] = static_cast(DilationShape[dim]); + Parameters->Padding[dim] = static_cast(Padding[dim]); + Parameters->Padding[dim + Dimensions] = static_cast(Padding[dim + Dimensions]); + Parameters->StrideShape[dim] = static_cast(StrideShape[dim]); + + size_t dilated_kernel_dim = 0; + if (!TryComputeKernelSize(Parameters->DilationShape[dim], Parameters->KernelShape[dim], dilated_kernel_dim)) { + return false; + } + + if (mul_overflow_size_t_builtin(input_size, Parameters->InputShape[dim], &input_size) || + mul_overflow_size_t_builtin(output_size, Parameters->OutputShape[dim], &output_size) || + mul_overflow_size_t_builtin(k, Parameters->KernelShape[dim], &k) || + mul_overflow_size_t_builtin(dilated_kernel_size, dilated_kernel_dim, &dilated_kernel_size)) { + return false; + } + } + + Parameters->InputSize = input_size; + Parameters->OutputSize = output_size; + Parameters->K = k; + Parameters->ThreadCount = MlasGetMaximumThreadCount(ThreadPool); + + if (!CheckCapabilitiesSme(Parameters)) { + return false; + } + + size_t working_elements = 0; + if (mul_overflow_size_t_builtin(Parameters->OutputSize, Parameters->FilterCount, &working_elements)) { + return false; + } + *WorkingBufferSize = Parameters->InputOutputChannelsLast ? 
0 : working_elements; + return true; +} + +bool + MLASCALL + ArmKleidiAI::MlasHalfConv( + const MLAS_CONV_PARAMETERS* Parameters, + const MLAS_FP16* Input, + const MLAS_FP16* Filter, + bool FilterAndBiasArePacked, + const MLAS_FP16* Bias, + MLAS_FP16* WorkingBuffer, + MLAS_FP16* Output, + MLAS_THREADPOOL* ThreadPool + ) +{ + if (!CheckCapabilitiesSme(Parameters)) { + return false; + } + + return ConvolveSme(Parameters, Input, Filter, FilterAndBiasArePacked, Bias, WorkingBuffer, Output, ThreadPool); +} + +size_t + MLASCALL + ArmKleidiAI::MlasHalfConvPackWeightsAndBiasSize( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape + ) +{ + size_t packed_size = 0; + if (!GetPackedFilterSize(FilterCount, InputChannels, KernelShape, DilationShape, &packed_size)) { + return 0; + } + return packed_size; +} + +bool + MLASCALL + ArmKleidiAI::MlasHalfConvPackWeightsAndBias( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape, + const MLAS_FP16* Filter, + const MLAS_FP16* Bias, + void* PackedWeightsAndBias, + MLAS_THREADPOOL* ThreadPool + ) +{ + MLAS_UNREFERENCED_PARAMETER(ThreadPool); + + return PackFilter( + FilterCount, + InputChannels, + KernelShape, + DilationShape, + Filter, + Bias, + PackedWeightsAndBias + ); +} diff --git a/onnxruntime/core/mlas/lib/kleidiai/halfgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/halfgemm_kleidiai.cpp new file mode 100644 index 0000000000000..bc10db249466d --- /dev/null +++ b/onnxruntime/core/mlas/lib/kleidiai/halfgemm_kleidiai.cpp @@ -0,0 +1,283 @@ +// +// SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#include +#include +#include +#include +#include "mlas.h" + +#include "mlasi_kleidiai.h" + +#include "kai_ukernel_interface.h" + +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.h" + +namespace { +struct KaiHalfTlsBuffers { + 
std::vector lhs_converted; + std::vector rhs_converted; + std::vector bias_zero; + std::vector rhs_packed; +}; + +thread_local KaiHalfTlsBuffers g_kai_half_tls; + +template +bool TryResizeVector(std::vector& buffer, size_t size) { + if (size > buffer.max_size()) { + return false; + } + buffer.resize(size); + return true; +} + +static inline void ConvertFloatMatrixToHalf( + const float* src, + MLAS_FP16* dst, + size_t rows, + size_t cols, + size_t src_ld) { + for (size_t r = 0; r < rows; ++r) { + MlasConvertFloatToHalfBuffer(src + r * src_ld, dst + r * cols, cols); + } +} +} // namespace + +size_t +MLASCALL +ArmKleidiAI::MlasHalfGemmKleidiAIPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K +) { + if (TransA != CblasNoTrans || TransB != CblasNoTrans || N == 0 || K == 0) { + return 0; + } + + return kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(N, K); +} + +bool +MLASCALL +ArmKleidiAI::MlasHalfGemmKleidiAIPackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const MLAS_FP16* B, + size_t ldb, + void* PackedB +) { + if (TransA != CblasNoTrans || TransB != CblasNoTrans) { + return false; + } + + if (PackedB == nullptr || B == nullptr || N == 0 || K == 0) { + return false; + } + + const size_t packed_rhs_size = ArmKleidiAI::MlasHalfGemmKleidiAIPackBSize(TransA, TransB, N, K); + if (packed_rhs_size == 0) { + return false; + } + + std::vector zero_bias(N, MLAS_FP16::FromBits(0)); + + size_t ldb_bytes = 0; + if (mul_overflow_size_t_builtin(ldb, sizeof(MLAS_FP16), &ldb_bytes)) { + return false; + } + + const auto& hgemm = GetKleidiAIHgemmUKernel(); + kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme( + 1, N, K, hgemm.ukernel.get_nr(), hgemm.ukernel.get_kr(), hgemm.ukernel.get_sr(), ldb_bytes, + B, + zero_bias.data(), + nullptr, + PackedB, + 0, + nullptr); + + return true; +} + +bool +MLASCALL +ArmKleidiAI::MlasHalfGemmBatch( + size_t M, + size_t N, + size_t K, + size_t BatchN, + const 
MLAS_HALF_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig +) { + if (BatchN == 0 || M == 0 || N == 0) { + return true; + } + if (K == 0) { + return false; + } + if (DataParams == nullptr) { + return false; + } + + MLAS_UNREFERENCED_PARAMETER(ThreadPool); + MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig); + // Validate all batch entries up front so we never partially execute and then + // fall back (which would corrupt results for the already-written outputs). + bool needs_rhs_packing = false; + for (size_t b = 0; b < BatchN; ++b) { + const auto& data = DataParams[b]; + if (data.OutputProcessor != nullptr) { + return false; + } + if (data.BIsBackendNativePacked && data.Bias != nullptr) { + return false; + } + if (data.BIsBackendNativePacked && data.ldb != 0) { + return false; + } + // Native-packed RHS is consumed directly below. Only allocate the + // runtime RHS packing scratch when at least one batch entry needs it. 
+ needs_rhs_packing = needs_rhs_packing || !data.BIsBackendNativePacked; + } + + const auto& hgemm = GetKleidiAIHgemmUKernel(); + const size_t n_step = hgemm.ukernel.get_n_step(); + const size_t nr = hgemm.ukernel.get_nr(); + const size_t kr = hgemm.ukernel.get_kr(); + const size_t sr = hgemm.ukernel.get_sr(); + KLEIDIAI_KERNEL_LOG(hgemm.name); + + const size_t packed_rhs_size = kai_get_rhs_packed_size_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(N, K); + if (packed_rhs_size == 0) { + return false; + } + + const float clamp_min = -std::numeric_limits::infinity(); + const float clamp_max = std::numeric_limits::infinity(); + + if (needs_rhs_packing && !TryResizeVector(g_kai_half_tls.rhs_packed, packed_rhs_size)) { + return false; + } + + for (size_t b = 0; b < BatchN; ++b) { + const auto& data = DataParams[b]; + + const MLAS_FP16* lhs_base = reinterpret_cast(data.A); + const MLAS_FP16* rhs_base = reinterpret_cast(data.B); + const std::byte* rhs_packed = nullptr; + size_t lhs_ld = data.lda; + size_t rhs_ld = data.ldb; + + if (data.AIsfp32) { + size_t lhs_elements = 0; + if (mul_overflow_size_t_builtin(M, K, &lhs_elements) || + !TryResizeVector(g_kai_half_tls.lhs_converted, lhs_elements)) { + return false; + } + ConvertFloatMatrixToHalf( + reinterpret_cast(data.A), + g_kai_half_tls.lhs_converted.data(), + M, K, data.lda); + lhs_base = g_kai_half_tls.lhs_converted.data(); + lhs_ld = K; + } + + if (data.BIsBackendNativePacked) { + rhs_packed = reinterpret_cast(data.B); + } else if (data.ldb == 0) { + // Prepacked B from MlasHalfGemmPackB/MlasHalfGemmConvertPackB. + // For the current default halfgemm dispatch this is a row-major + // fp16 KxN buffer with leading dimension N. It is not the native + // KleidiAI RHS-packed layout, so this path falls back to packing + // it into KleidiAI format before execution. 
+ rhs_ld = N; + } else if (data.BIsfp32) { + size_t rhs_elements = 0; + if (mul_overflow_size_t_builtin(K, N, &rhs_elements) || + !TryResizeVector(g_kai_half_tls.rhs_converted, rhs_elements)) { + return false; + } + ConvertFloatMatrixToHalf( + reinterpret_cast(data.B), + g_kai_half_tls.rhs_converted.data(), + K, N, data.ldb); + rhs_base = g_kai_half_tls.rhs_converted.data(); + rhs_ld = N; + } + + if (rhs_packed == nullptr) { + auto* rhs_packed_buffer = g_kai_half_tls.rhs_packed.data(); + + size_t ldb_bytes = 0; + if (mul_overflow_size_t_builtin(rhs_ld, sizeof(MLAS_FP16), &ldb_bytes)) { + return false; + } + if (data.Bias == nullptr) { + if (!TryResizeVector(g_kai_half_tls.bias_zero, N)) { + return false; + } + std::fill(g_kai_half_tls.bias_zero.begin(), g_kai_half_tls.bias_zero.end(), MLAS_FP16::FromBits(0)); + } + + kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme( + 1, N, K, nr, kr, sr, ldb_bytes, + rhs_base, + data.Bias != nullptr ? data.Bias : g_kai_half_tls.bias_zero.data(), + nullptr, + rhs_packed_buffer, + 0, + nullptr); + rhs_packed = rhs_packed_buffer; + } + + size_t lda_bytes = 0; + if (mul_overflow_size_t_builtin(lhs_ld, sizeof(MLAS_FP16), &lda_bytes)) { + return false; + } + size_t dst_stride_bytes = 0; + if (mul_overflow_size_t_builtin(data.ldc, sizeof(MLAS_FP16), &dst_stride_bytes)) { + return false; + } + + MlasTrySimpleParallel(ThreadPool, static_cast(M), [&](ptrdiff_t m_idx) { + const size_t m = static_cast(m_idx); + const auto* lhs = lhs_base + m * lhs_ld; + auto* dst_row = data.C + m * data.ldc; + const auto* rhs_packed_base = rhs_packed; + // The selected KleidiAI HGEMM micro-kernel is 1xN by design. + // We execute one output row per call and parallelize over rows. 
+ constexpr size_t kernel_m = 1; + for (size_t n_idx = 0; n_idx < N; n_idx += n_step) { + const size_t tile_n = std::min(n_step, N - n_idx); + const auto* rhs_tile = rhs_packed_base + hgemm.ukernel.get_rhs_packed_offset(n_idx, K); + auto* dst_tile = reinterpret_cast( + reinterpret_cast(dst_row) + + hgemm.ukernel.get_dst_offset(0, n_idx, dst_stride_bytes)); + + hgemm.ukernel.run_matmul( + kernel_m, + tile_n, + K, + lhs, + lda_bytes, + rhs_tile, + dst_tile, + dst_stride_bytes, + sizeof(MLAS_FP16), + clamp_min, + clamp_max); + } + }); + + } + + return true; +} diff --git a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h index 3c9f398ece887..207935b8a8bd8 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h +++ b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h @@ -7,6 +7,7 @@ #pragma once #include "../mlasi.h" +#include #include // Fix to ensure compatibility with MSVC build @@ -224,6 +225,102 @@ MlasConvSymmetricChannelsLast2DFloatPackW( size_t PackedFilterGroupStride, MLAS_THREADPOOL* ThreadPool ); + +bool +MLASCALL +MlasHalfGemmBatch( + size_t M, + size_t N, + size_t K, + size_t BatchN, + const MLAS_HALF_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig + ); + +size_t +MLASCALL +MlasHalfGemmKleidiAIPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K + ); + +// Packs B into the native KleidiAI RHS-packed layout for the supported +// halfgemm configuration only. This differs from the generic MLAS halfgemm +// prepacked-B format produced by MlasHalfGemmPackB and +// MlasHalfGemmConvertPackB, so generic MLAS prepacked weights may need to be +// repacked into this layout before running the KleidiAI halfgemm path. +// Unsupported transpose combinations return false/0 so the caller can fall +// back to the generic MLAS path. 
+bool +MLASCALL +MlasHalfGemmKleidiAIPackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const MLAS_FP16* B, + size_t ldb, + void* PackedB + ); + +bool +MLASCALL +MlasHalfConvPrepare(MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + bool InputOutputChannelsLast, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig); + +bool +MLASCALL +MlasHalfConv( + const MLAS_CONV_PARAMETERS* Parameters, + const MLAS_FP16* Input, + const MLAS_FP16* Filter, + bool FilterAndBiasArePacked, + const MLAS_FP16* Bias, + MLAS_FP16* WorkingBuffer, + MLAS_FP16* Output, + MLAS_THREADPOOL* ThreadPool + ); + +size_t +MLASCALL +MlasHalfConvPackWeightsAndBiasSize( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape + ); + +bool +MLASCALL +MlasHalfConvPackWeightsAndBias( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape, + const MLAS_FP16* Filter, + const MLAS_FP16* Bias, + void* PackedWeightsAndBias, + MLAS_THREADPOOL* ThreadPool + ); } /*++ @@ -259,3 +356,42 @@ inline bool mul_overflow_size_t_builtin(size_t a, size_t b, size_t* out) { if (out) *out = a * b; return false; } + +/*++ + +Routine Description: + + This routine adds two size_t values and returns the sum when no wraparound + occurs. Uses __builtin_add_overflow if available on the current system and + falls back to a default implementation otherwise. + +Arguments: + + a - Supplies the first number to be added. + + b - Supplies the second number to be added. 
+ + out - Supplies a size_t reference which receives the result in success + cases. + +Return Value: + + Returns true if the operation was successful + Returns false if wraparound of size_t was detected + +--*/ +inline bool TryAddSize(size_t a, size_t b, size_t& out) +{ +#if defined(__has_builtin) +# if __has_builtin(__builtin_add_overflow) + return !__builtin_add_overflow(a, b, &out); +# endif +#endif + + if (a > (std::numeric_limits::max)() - b) { + return false; + } + + out = a + b; + return true; +} diff --git a/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp index 9e70b0742217a..2e5306d442476 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp @@ -101,7 +101,7 @@ MLASCALL ArmKleidiAI::MlasDynamicQGemmBatch( const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, - const size_t BatchSize, + const size_t BatchN, MLAS_THREADPOOL* ThreadPool ) { @@ -112,7 +112,7 @@ ArmKleidiAI::MlasDynamicQGemmBatch( size_t m_step = qgemm_gemm.ukernel.get_m_step(); size_t n_step = qgemm_gemm.ukernel.get_n_step(); - if (BatchSize == 0 || Shape.M == 0 || Shape.N == 0 || Shape.K == 0) { + if (BatchN == 0 || Shape.M == 0 || Shape.N == 0 || Shape.K == 0) { return; } @@ -123,7 +123,7 @@ ArmKleidiAI::MlasDynamicQGemmBatch( MLAS_THROW_EX(std::runtime_error, "Dynamic QGEMM requires valid DataParams."); } - for (size_t batch_idx = 0; batch_idx < BatchSize; ++batch_idx) { + for (size_t batch_idx = 0; batch_idx < BatchN; ++batch_idx) { const auto& params = DataParams[batch_idx]; if (params.A == nullptr) { @@ -151,24 +151,24 @@ ArmKleidiAI::MlasDynamicQGemmBatch( const size_t LhsPackedStride = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr); std::byte* LhsPackedData = nullptr; - if (g_kai_tls_qgemm.lhs_packed.capacity() < LhsPackedStride * BatchSize) { + if 
(g_kai_tls_qgemm.lhs_packed.capacity() < LhsPackedStride * BatchN) { - g_kai_tls_qgemm.lhs_packed.reserve(LhsPackedStride * BatchSize); + g_kai_tls_qgemm.lhs_packed.reserve(LhsPackedStride * BatchN); } - g_kai_tls_qgemm.lhs_packed.resize(LhsPackedStride * BatchSize); + g_kai_tls_qgemm.lhs_packed.resize(LhsPackedStride * BatchN); LhsPackedData = g_kai_tls_qgemm.lhs_packed.data(); // Per-batch table of LHS base pointers. - if (g_kai_tls_qgemm.lhs_base_table.capacity() < BatchSize) { + if (g_kai_tls_qgemm.lhs_base_table.capacity() < BatchN) { - g_kai_tls_qgemm.lhs_base_table.reserve(BatchSize); + g_kai_tls_qgemm.lhs_base_table.reserve(BatchN); } - g_kai_tls_qgemm.lhs_base_table.resize(BatchSize); + g_kai_tls_qgemm.lhs_base_table.resize(BatchN); // Capture the shared batch table pointer so worker threads use the same backing storage. const std::byte** tls_lhs_base = g_kai_tls_qgemm.lhs_base_table.data(); // B batches require no packing. // We have already decided the matmul variant we are using before having values for M, N, and K. - MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t batch_idx) { std::byte* lhs = nullptr; if (DataParams[batch_idx].Workspace && DataParams[batch_idx].WorkspaceSize >= LhsPackedStride) { @@ -184,7 +184,7 @@ ArmKleidiAI::MlasDynamicQGemmBatch( // Tile iteration dimensions. 
std::array dim; - dim[0] = BatchSize; // B + dim[0] = BatchN; // B dim[1] = MlasDivRoundup(Shape.M, m_step); // M dim[2] = MlasDivRoundup(Shape.N, n_step); // N diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index dbb414505ff38..b06761befd48a 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -967,6 +967,90 @@ bool void* PackedB); #endif +typedef +bool +(MLASCALL MLAS_HALF_GEMM_BATCH_OVERRIDE)( + size_t M, + size_t N, + size_t K, + size_t BatchN, + const MLAS_HALF_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig); + +typedef +size_t +(MLASCALL MLAS_HALF_GEMM_PACK_B_SIZE_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K); + +typedef +bool +(MLASCALL MLAS_HALF_GEMM_PACK_B_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const MLAS_FP16* B, + size_t ldb, + void* PackedB); + +typedef +bool +(MLASCALL MLAS_HALF_CONV_PREPARE_OVERRIDE)( + MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + bool InputOutputChannelsLast, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig); + +typedef +bool +(MLASCALL MLAS_HALF_CONV_OVERRIDE)( + const MLAS_CONV_PARAMETERS* Parameters, + const MLAS_FP16* Input, + const MLAS_FP16* Filter, + bool FilterAndBiasArePacked, + const MLAS_FP16* Bias, + MLAS_FP16* WorkingBuffer, + MLAS_FP16* Output, + MLAS_THREADPOOL* ThreadPool); + +typedef +size_t +(MLASCALL MLAS_HALF_CONV_PACK_WEIGHTS_AND_BIAS_SIZE_OVERRIDE)( + size_t 
FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape); + +typedef +bool +(MLASCALL MLAS_HALF_CONV_PACK_WEIGHTS_AND_BIAS_OVERRIDE)( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape, + const MLAS_FP16* Filter, + const MLAS_FP16* Bias, + void* PackedWeightsAndBias, + MLAS_THREADPOOL* ThreadPool); + extern "C" { #if defined(MLAS_TARGET_AMD64_IX86) @@ -1485,6 +1569,15 @@ struct MLAS_PLATFORM { MLAS_DYNAMIC_QGEMM_BATCH_OVERRIDE* MlasDynamicQGemmBatchOverride = nullptr; MLAS_DYNAMIC_QGEMM_PACK_B_SIZE_OVERRIDE* MlasDynamicQGemmPackBSizeOverride = nullptr; MLAS_DYNAMIC_QGEMM_PACK_B_OVERRIDE* MlasDynamicQGemmPackBOverride = nullptr; + // MLAS HalfGemm overrides + MLAS_HALF_GEMM_BATCH_OVERRIDE* MlasHalfGemmBatchOverride = nullptr; + MLAS_HALF_GEMM_PACK_B_SIZE_OVERRIDE* MlasHalfGemmPackBSizeOverride = nullptr; + MLAS_HALF_GEMM_PACK_B_OVERRIDE* MlasHalfGemmPackBOverride = nullptr; + // MLAS HalfConv overrides + MLAS_HALF_CONV_PREPARE_OVERRIDE* MlasHalfConvPrepareOverride = nullptr; + MLAS_HALF_CONV_OVERRIDE* MlasHalfConvOverride = nullptr; + MLAS_HALF_CONV_PACK_WEIGHTS_AND_BIAS_SIZE_OVERRIDE* MlasHalfConvPackWeightsAndBiasSizeOverride = nullptr; + MLAS_HALF_CONV_PACK_WEIGHTS_AND_BIAS_OVERRIDE* MlasHalfConvPackWeightsAndBiasOverride = nullptr; // MLAS Conv overrides MLAS_CONV_PREPARE_FLOAT_OVERRIDE* MlasConvPrepareOverride = nullptr; MLAS_CONV_FLOAT_OVERRIDE* MlasConvOverride = nullptr; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 466fa9a3e9497..e975787877462 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -677,13 +677,20 @@ Return Value: } #if defined(USE_KLEIDIAI) - if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){ + if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME() || MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()){ this->MlasSGemmBatchOverride = ArmKleidiAI::MlasGemmBatch; 
this->MlasSGemmPackBSizeOverride = ArmKleidiAI::MlasGemmPackBSize; this->MlasSGemmPackBOverride = ArmKleidiAI::MlasGemmPackB; this->MlasDynamicQGemmBatchOverride = ArmKleidiAI::MlasDynamicQGemmBatch; this->MlasDynamicQGemmPackBSizeOverride = ArmKleidiAI::MlasDynamicQGemmPackBSize; this->MlasDynamicQGemmPackBOverride = ArmKleidiAI::MlasDynamicQGemmPackB; + this->MlasHalfGemmBatchOverride = ArmKleidiAI::MlasHalfGemmBatch; + this->MlasHalfGemmPackBSizeOverride = ArmKleidiAI::MlasHalfGemmKleidiAIPackBSize; + this->MlasHalfGemmPackBOverride = ArmKleidiAI::MlasHalfGemmKleidiAIPackB; + this->MlasHalfConvPrepareOverride = ArmKleidiAI::MlasHalfConvPrepare; + this->MlasHalfConvOverride = ArmKleidiAI::MlasHalfConv; + this->MlasHalfConvPackWeightsAndBiasSizeOverride = ArmKleidiAI::MlasHalfConvPackWeightsAndBiasSize; + this->MlasHalfConvPackWeightsAndBiasOverride = ArmKleidiAI::MlasHalfConvPackWeightsAndBias; this->MlasConvPrepareOverride = ArmKleidiAI::MlasConvPrepare; this->MlasConvOverride = ArmKleidiAI::MlasConv; #if defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index c8a0bbaba9df5..38f045f913124 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -3,28 +3,422 @@ #include "core/optimizer/insert_cast_transformer.h" #include "core/framework/data_types.h" -#include "core/graph/graph_utils.h" #include "core/framework/compute_capability.h" +#include "core/graph/graph_utils.h" #include "core/graph/indexed_sub_graph.h" +#include "core/mlas/inc/mlas.h" + +#include +#include +#include using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; namespace onnxruntime { +void InsertCastTransformer::RecordPartitionAssignment(const onnxruntime::Graph& graph, + const onnxruntime::Node& node, + const std::string& ep_type) const { + if (!on_partition_assignment_fn_) { + return; + } + + 
auto indexed_subgraph = std::make_unique(); + indexed_subgraph->nodes.push_back(node.Index()); + const ComputeCapability compute_capability{std::move(indexed_subgraph)}; + on_partition_assignment_fn_(graph, compute_capability, ep_type); +} + +template +static bool IsTensorOfType(const NodeArg& node_arg) { + const auto* type_proto = node_arg.TypeAsProto(); + return node_arg.Exists() && + type_proto != nullptr && + DataTypeImpl::TypeFromProto(*type_proto) == DataTypeImpl::GetTensorType(); +} + +template +static bool HasTensorArgOfType(const NodeArgs& node_args) { + return std::any_of(node_args.cbegin(), node_args.cend(), + [](const NodeArg* node_arg) { + return node_arg != nullptr && IsTensorOfType(*node_arg); + }); +} + static bool IsMLFloat16Tensor(const NodeArg& node_arg) { - // Type() will return nullptr if node_arg.Exists() is true so don't need an additional check for that - return node_arg.Type() != nullptr && - DataTypeImpl::TypeFromProto(*node_arg.TypeAsProto()) == DataTypeImpl::GetTensorType(); + return IsTensorOfType(node_arg); } -bool InsertCastTransformer::NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input) const { - // If the node's input is float16 and currently the node is not assigned to any EP - // we need to insert a cast to float, and put the node on CPU for default behavior. - // We don't cast a node with a subgraph as we'd need to do a lot more checking of the subgraph inputs - // (both explicit and implicit) and contents to determine if it was safe to do so. - // TODO: a better check is to check does the CPU kernel with float exist or not. 
+static bool HasCpuFloat32FallbackKernel( + const onnxruntime::Node& node, + const InlinedVector>& cpu_kernel_registries, + const logging::Logger& logger); + +bool InsertCastTransformer::NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input, + const logging::Logger& logger) const { + // Returns true when this input is an fp16 input to an unassigned node that is eligible + // for the cast-to-fp32 fallback path. + // + // Nodes with subgraphs are excluded because rewriting explicit and implicit subgraph + // inputs safely requires additional checks of the subgraph boundaries and contents. return node->GetExecutionProviderType().empty() && !node->ContainsSubgraph() && - IsMLFloat16Tensor(*input); + IsMLFloat16Tensor(*input) && + HasCpuFloat32FallbackKernel(*node, cpu_kernel_registries_, logger); +} + +static bool HasFp16IO(const onnxruntime::Node& node) { + return HasTensorArgOfType(node.InputDefs()) || + HasTensorArgOfType(node.OutputDefs()); +} + +static bool HasFp16Input(const onnxruntime::Node& node) { + return HasTensorArgOfType(node.InputDefs()); +} + +static bool IsCpuFp16OptInPolicyOp(const onnxruntime::Node& node) { + // Standard-domain ops whose CPU fp16 kernels are governed by session.enable_cpu_fp16 + // and the fp32 fallback heuristic. Existing CPU fp16 kernels are intentionally not + // included in this opt-in/fallback policy. 
+ return node.Domain().empty() && + (node.OpType() == "MatMul" || node.OpType() == "Gemm"); +} + +static std::optional DimValue(const ONNX_NAMESPACE::TensorShapeProto* shape, int dim_idx) { + if (!shape || dim_idx < 0 || dim_idx >= shape->dim_size()) { + return std::nullopt; + } + + const auto& dim = shape->dim(dim_idx); + if (!dim.has_dim_value()) { + return std::nullopt; + } + + return dim.dim_value(); +} + +static std::optional ProductOfDimsExceptLast(const ONNX_NAMESPACE::TensorShapeProto* shape) { + if (!shape || shape->dim_size() < 2) { + return std::nullopt; + } + + int64_t product = 1; + for (int i = 0; i < shape->dim_size() - 1; ++i) { + const auto dim = DimValue(shape, i); + if (!dim || *dim < 0) { + return std::nullopt; + } + + if (*dim != 0 && product > std::numeric_limits::max() / *dim) { + return std::nullopt; + } + product *= *dim; + } + + return product; +} + +static std::optional GetAttributeInt(const onnxruntime::Node& node, + const std::string& attr_name, + int64_t default_value) { + const auto& attrs = node.GetAttributes(); + const auto attr = attrs.find(attr_name); + if (attr == attrs.end()) { + return default_value; + } + + if (attr->second.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_INT) { + return std::nullopt; + } + + return attr->second.i(); +} + +struct CpuFp16GemmShape { + int64_t M; + int64_t N; + int64_t K; +}; + +static std::optional GetMatMulShapeForCpuFp16Heuristic(const onnxruntime::Node& node) { + const auto& inputs = node.InputDefs(); + if (inputs.size() < 2 || !inputs[0] || !inputs[1]) { + return std::nullopt; + } + + const auto* a_shape = inputs[0]->Shape(); + const auto* b_shape = inputs[1]->Shape(); + if (!a_shape || !b_shape || a_shape->dim_size() < 2 || b_shape->dim_size() < 2) { + return std::nullopt; + } + + const auto b_inner_dim = DimValue(b_shape, b_shape->dim_size() - 2); + const auto b_size_to_inner_dim = ProductOfDimsExceptLast(b_shape); + // Match MatMulComputeHelper: RHS shapes like [1, ..., 1, K, N] 
also flatten + // the left input because the leading RHS dims are only padding. + const bool flattens_left = a_shape->dim_size() >= b_shape->dim_size() && + b_inner_dim && b_size_to_inner_dim && + *b_size_to_inner_dim == *b_inner_dim; + + // Match MatMulComputeHelper: effectively 2D RHS flattens the left input, while + // genuinely batched RHS uses the per-GEMM row count. + auto M = flattens_left ? ProductOfDimsExceptLast(a_shape) + : DimValue(a_shape, a_shape->dim_size() - 2); + auto K = DimValue(a_shape, a_shape->dim_size() - 1); + auto N = DimValue(b_shape, b_shape->dim_size() - 1); + if (!M || !N || !K) { + return std::nullopt; + } + + return CpuFp16GemmShape{*M, *N, *K}; +} + +struct CpuFp16MatMulRhsShape { + int64_t N; + int64_t K; +}; + +static std::optional GetMatMulRhsShapeForCpuFp16Heuristic(const onnxruntime::Node& node) { + const auto& inputs = node.InputDefs(); + if (inputs.size() < 2 || !inputs[1]) { + return std::nullopt; + } + + const auto* b_shape = inputs[1]->Shape(); + if (!b_shape || b_shape->dim_size() != 2) { + return std::nullopt; + } + + auto K = DimValue(b_shape, b_shape->dim_size() - 2); + auto N = DimValue(b_shape, b_shape->dim_size() - 1); + if (!N || !K || *N < 0 || *K < 0) { + return std::nullopt; + } + + if (*K != 0 && *N > std::numeric_limits::max() / *K) { + return std::nullopt; + } + + return CpuFp16MatMulRhsShape{*N, *K}; +} + +static std::optional GetGemmShapeForCpuFp16Heuristic(const onnxruntime::Node& node) { + const auto& inputs = node.InputDefs(); + if (inputs.size() < 2 || !inputs[0] || !inputs[1]) { + return std::nullopt; + } + + const auto* a_shape = inputs[0]->Shape(); + const auto* b_shape = inputs[1]->Shape(); + if (!a_shape || !b_shape || a_shape->dim_size() != 2 || b_shape->dim_size() != 2) { + return std::nullopt; + } + + const auto trans_a = GetAttributeInt(node, "transA", 0); + const auto trans_b = GetAttributeInt(node, "transB", 0); + if (!trans_a || !trans_b) { + return std::nullopt; + } + + const auto M = 
DimValue(a_shape, *trans_a ? 1 : 0); + const auto K = DimValue(a_shape, *trans_a ? 0 : 1); + const auto N = DimValue(b_shape, *trans_b ? 0 : 1); + if (!M || !N || !K) { + return std::nullopt; + } + + return CpuFp16GemmShape{*M, *N, *K}; +} + +static bool ShouldKeepNativeCpuFp16ForMatMulOrGemm( + const Graph& graph, + const onnxruntime::Node& node, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG& mlas_backend_kernel_selector_config) { + constexpr int64_t kMaxNativeFp16GemvM = 4; + constexpr int64_t kMinNativeFp16ConstantMatMulNK = 512 * 1024; + constexpr int64_t kMinNativeFp16NK = 1024 * 1024; + + std::optional shape; + if (node.OpType() == "MatMul") { + const auto& inputs = node.InputDefs(); + const bool has_constant_rhs = + inputs.size() > 1 && inputs[1] && graph_utils::IsInitializer(graph, inputs[1]->Name(), true); + if (has_constant_rhs) { + const auto rhs_shape = GetMatMulRhsShapeForCpuFp16Heuristic(node); + if (!rhs_shape) { + return false; + } + + const auto nk = rhs_shape->N * rhs_shape->K; + if (nk < kMinNativeFp16ConstantMatMulNK) { + return false; + } + + return MlasHalfGemmNativePackBSize(CblasNoTrans, CblasNoTrans, + static_cast(rhs_shape->N), + static_cast(rhs_shape->K), + &mlas_backend_kernel_selector_config) != 0; + } + + shape = GetMatMulShapeForCpuFp16Heuristic(node); + } else if (node.OpType() == "Gemm") { + shape = GetGemmShapeForCpuFp16Heuristic(node); + } + + if (!shape || shape->M < 0 || shape->N < 0 || shape->K < 0) { + return false; + } + + if (shape->K != 0 && shape->N > std::numeric_limits::max() / shape->K) { + return false; + } + + const auto nk = shape->N * shape->K; + + if (node.OpType() == "MatMul") { + return shape->M <= kMaxNativeFp16GemvM && + nk >= kMinNativeFp16NK && + MlasHGemmSupported(CblasNoTrans, CblasNoTrans); + } + + if (shape->M > kMaxNativeFp16GemvM) { + return false; + } + + return nk >= kMinNativeFp16NK; +} + +static bool KernelDomainMatchesNode(const KernelDef& kernel_def, const onnxruntime::Node& node) { + const auto& 
kernel_domain = kernel_def.Domain(); + const auto& node_domain = node.Domain(); + return kernel_domain == node_domain || + (kernel_domain == kOnnxDomainAlias && node_domain.empty()); +} + +static bool KernelVersionMatchesNode(const KernelDef& kernel_def, const onnxruntime::Node& node) { + const auto [kernel_start_version, kernel_end_version] = kernel_def.SinceVersion(); + return kernel_start_version <= node.SinceVersion() && + kernel_end_version >= node.SinceVersion(); +} + +static bool IsKernelTypeCompatible(gsl::span enabled_types, + const ONNX_NAMESPACE::TypeProto& actual_type, + bool replace_fp16_with_float) { + const auto* type_to_check = &actual_type; + TypeProto float_tensor_type; + if (replace_fp16_with_float && + DataTypeImpl::TypeFromProto(actual_type) == DataTypeImpl::GetTensorType()) { + float_tensor_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + type_to_check = &float_tensor_type; + } + + return std::any_of(enabled_types.begin(), enabled_types.end(), + [type_to_check](const DataTypeImpl* enabled_type) { + return enabled_type->IsCompatible(*type_to_check); + }); +} + +static bool KernelTypeConstraintsMatchAllNodeArgs(const onnxruntime::Node& node, + const KernelDef& kernel_def, + bool replace_fp16_with_float) { +#if !defined(ORT_MINIMAL_BUILD) + static const OpSchemaKernelTypeStrResolver resolver; +#else + static const KernelTypeStrResolver resolver; +#endif + + const auto actual_inputs = node.InputDefs(); + const auto actual_outputs = node.OutputDefs(); + const auto& actual_input_arg_counts = node.InputArgCount(); + InlinedVector actual_input_arg_offsets; + actual_input_arg_offsets.reserve(actual_input_arg_counts.size()); + int current_offset = 0; + for (const auto arg_count : actual_input_arg_counts) { + actual_input_arg_offsets.push_back(current_offset); + current_offset += arg_count; + } + + const auto CheckArg = [&](const NodeArg* arg, + gsl::span enabled_types) { + if (!arg || !arg->Exists()) { + return true; + } + + 
const auto* type_proto = arg->TypeAsProto(); + if (type_proto == nullptr) { + return false; + } + + return IsKernelTypeCompatible(enabled_types, *type_proto, replace_fp16_with_float); + }; + + for (const auto& [kernel_type_str, enabled_types] : kernel_def.TypeConstraints()) { + gsl::span constraint_args; + if (!resolver.ResolveKernelTypeStr(node, kernel_type_str, constraint_args).IsOK()) { + return false; + } + + for (const auto& [arg_type, formal_arg_idx] : constraint_args) { + if (arg_type == ArgType::kInput) { + if (formal_arg_idx >= actual_input_arg_counts.size() || + actual_input_arg_counts[formal_arg_idx] == 0) { + continue; + } + + const auto first_arg_idx = actual_input_arg_offsets[formal_arg_idx]; + for (int arg_idx = 0; arg_idx < actual_input_arg_counts[formal_arg_idx]; arg_idx++) { + const auto actual_arg_idx = static_cast(first_arg_idx + arg_idx); + ORT_ENFORCE(actual_arg_idx < actual_inputs.size()); + if (!CheckArg(actual_inputs[actual_arg_idx], enabled_types)) { + return false; + } + } + } else { + if (formal_arg_idx < actual_outputs.size() && + !CheckArg(actual_outputs[formal_arg_idx], enabled_types)) { + return false; + } + } + } + } + + return true; +} + +static bool HasCpuKernelWithTypeSupport( + const onnxruntime::Node& node, + const InlinedVector>& cpu_kernel_registries, + bool replace_fp16_with_float) { + for (const KernelRegistry* cpu_kernel_registry : cpu_kernel_registries) { + for (const auto& [_, kernel_create_info] : cpu_kernel_registry->GetKernelCreateMap()) { + const auto* kernel_def = kernel_create_info.kernel_def.get(); + if (kernel_def != nullptr && + kernel_def->Provider() == kCpuExecutionProvider && + kernel_def->OpName() == node.OpType() && + KernelDomainMatchesNode(*kernel_def, node) && + KernelVersionMatchesNode(*kernel_def, node) && + KernelTypeConstraintsMatchAllNodeArgs(node, *kernel_def, replace_fp16_with_float)) { + return true; + } + } + } + + return false; +} + +static bool HasCpuKernelForCurrentTypes( + const 
onnxruntime::Node& node, + const InlinedVector>& cpu_kernel_registries, + const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); + return HasCpuKernelWithTypeSupport(node, cpu_kernel_registries, false); +} + +static bool HasCpuFloat32FallbackKernel( + const onnxruntime::Node& node, + const InlinedVector>& cpu_kernel_registries, + const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); + return HasCpuKernelWithTypeSupport(node, cpu_kernel_registries, true); } onnxruntime::NodeArg* AddCastNode(onnxruntime::Graph& graph, @@ -50,21 +444,29 @@ onnxruntime::NodeArg* AddCastNode(onnxruntime::Graph& graph, // check if the node has an fp16 input but was not able to be assigned an execution provider. // we will need to add casts to/from fp32 around the node for it to be executed using the CPU EP. -static bool NodeNeedsInputCastToFp32(const onnxruntime::Node& node) { +static bool NodeNeedsInputCastToFp32(const onnxruntime::Node& node, + const InlinedVector>& cpu_kernel_registries, + const logging::Logger& logger) { bool not_assigned = node.GetExecutionProviderType().empty(); if (not_assigned) { - const auto& input_defs = node.InputDefs(); - bool has_fp16_input = std::any_of(input_defs.cbegin(), input_defs.cend(), - [](const NodeArg* input_def) { - return IsMLFloat16Tensor(*input_def); - }); - return has_fp16_input; + return HasFp16Input(node) && HasCpuFloat32FallbackKernel(node, cpu_kernel_registries, logger); } return false; } +static bool CpuAssignedFp16NodeNeedsFallbackCast( + const onnxruntime::Node& node, + const InlinedVector>& cpu_kernel_registries, + const logging::Logger& logger) { + return node.GetExecutionProviderType() == kCpuExecutionProvider && + !node.ContainsSubgraph() && + HasFp16IO(node) && + !HasCpuKernelForCurrentTypes(node, cpu_kernel_registries, logger) && + HasCpuFloat32FallbackKernel(node, cpu_kernel_registries, logger); +} + // Detect an isolated node that is able to process fp16 data but is between other nodes that have fp16 inputs 
// but will need a Cast inserted to enable them to run. // @@ -87,7 +489,7 @@ static bool NodeNeedsInputCastToFp32(const onnxruntime::Node& node) { // // Return true if all the fp16 inputs and outputs are connected to nodes that will be cast to fp32. static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph, - const KernelRegistry& cpu_kernel_registry, + const InlinedVector>& cpu_kernel_registries, const logging::Logger& logger) { // we can check if it's an isolated fp16 node // if node has input coming from other nodes (only consuming graph inputs or initializers if it doesn't), @@ -150,7 +552,13 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: type_constraint_map[type_str] = DataTypeImpl::GetTensorType(); break; // we don't have multiple tensors feeding into one input } - type_constraint_map[type_str] = DataTypeImpl::TypeFromProto(*(input_def->TypeAsProto())); + + const auto* type_proto = input_def->TypeAsProto(); + if (type_proto == nullptr) { + return false; + } + + type_constraint_map[type_str] = DataTypeImpl::TypeFromProto(*type_proto); break; // we don't have multiple tensors feeding into one input } } @@ -164,7 +572,7 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: const int arg_idx = input_edge->GetDstArgIndex(); if (fp16_args.find(arg_idx) != fp16_args.end()) { // if the node producing our fp16 input does not need its input cast to fp32 we should run in fp16 - if (!NodeNeedsInputCastToFp32(input_edge->GetNode())) { + if (!NodeNeedsInputCastToFp32(input_edge->GetNode(), cpu_kernel_registries, logger)) { return false; } } @@ -192,7 +600,12 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: fp16_args.emplace((int)idx); type_constraint_map[type_str] = DataTypeImpl::GetTensorType(); } else { - type_constraint_map[type_str] = DataTypeImpl::TypeFromProto(*(output_def->TypeAsProto())); + const auto* type_proto = 
output_def->TypeAsProto(); + if (type_proto == nullptr) { + return false; + } + + type_constraint_map[type_str] = DataTypeImpl::TypeFromProto(*type_proto); } } @@ -204,7 +617,7 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: const int arg_idx = output_edge->GetSrcArgIndex(); if (fp16_args.find(arg_idx) != fp16_args.end()) { // if the node producing our fp16 input does not need its input cast to fp32 we should run in fp16 - if (!NodeNeedsInputCastToFp32(output_edge->GetNode())) { + if (!NodeNeedsInputCastToFp32(output_edge->GetNode(), cpu_kernel_registries, logger)) { return false; } } @@ -212,32 +625,36 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: // now all fp16 inputs and outputs would have a cast // make sure fp32 version of the kernel is available. - const KernelCreateInfo* kernel_create_info{}; - const auto lookup_status = cpu_kernel_registry.TryFindKernel( - kCpuExecutionProvider, node.OpType(), node.Domain(), - node.SinceVersion(), type_constraint_map, logger, &kernel_create_info); - if (lookup_status.IsOK() && kernel_create_info != nullptr) { - return true; + for (const KernelRegistry* cpu_kernel_registry : cpu_kernel_registries) { + const KernelCreateInfo* kernel_create_info{}; + const auto lookup_status = cpu_kernel_registry->TryFindKernel( + kCpuExecutionProvider, node.OpType(), node.Domain(), + node.SinceVersion(), type_constraint_map, logger, &kernel_create_info); + if (lookup_status.IsOK() && kernel_create_info != nullptr) { + return true; + } } } return false; } -// These nodes have an fp16 CPU kernel and were therefore assigned to CPU EP by the partitioner; -// that assignment has already been recorded. Clearing the EP is the mechanism that makes -// NeedInsertCast return true so they get wrapped with fp32 casts like any other unassigned node. 
-// Collect their indices so ApplyImpl can skip the on_partition_assignment_fn_ callback for them: -// the callback is only for nodes that are newly receiving a CPU EP fallback from this transformer, -// not for nodes whose partitioner assignment is already on record. -static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry, - const logging::Logger& logger, - InlinedHashSet& nodes_with_recorded_cpu_assignment) { +static Status ForceSingleNodeCPUFloat16ToFloat32( + onnxruntime::Graph& graph, + const InlinedVector>& cpu_kernel_registries, + const logging::Logger& logger, + InlinedHashSet& forced_fp32_nodes, + InlinedHashSet& nodes_with_recorded_cpu_assignment) { for (auto& node : graph.Nodes()) { - if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry, logger)) { - // unassign the node so that NeedInsertCast will return true for it, forcing it to fp32 + if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registries, logger)) { + // These nodes have an fp16 CPU kernel and were therefore assigned to CPU EP by the partitioner; + // that assignment has already been recorded. Clearing the EP is the mechanism that makes + // NeedInsertCast return true so they get wrapped with fp32 casts like any other unassigned node. + // Track the node index as well so the later broad fp16-preservation logic does not + // immediately assign it back to CPU and undo the heuristic. nodes_with_recorded_cpu_assignment.insert(node.Index()); node.SetExecutionProviderType(""); + forced_fp32_nodes.insert(node.Index()); } } @@ -518,14 +935,16 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie // is rewritten to consume fp32 (via inserted casts) and given a CPU EP assignment. // on_partition_assignment_fn_ is fired to record each such new CPU EP assignment. 
// - // Exception: ForceSingleNodeCPUFloat16ToFloat32 may clear the EP of nodes that were already - // assigned to CPU EP by the partitioner (they have an fp16 CPU kernel and got cast-wrapped to - // avoid isolated fp16 islands). Those nodes are tracked here so we can skip the callback — - // their assignment was already recorded by the partitioner. + // Exception: some CPU EP assignments may be cleared by this transformer to route isolated or + // unprofitable fp16 islands through fp32 casts. Those assignments were already recorded by the + // partitioner, so track them and skip the callback when they are assigned back to CPU. + InlinedHashSet forced_fp32_nodes; InlinedHashSet nodes_with_recorded_cpu_assignment; - if (force_cpu_fp32_) - ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_, logger, - nodes_with_recorded_cpu_assignment)); + if (force_cpu_fp32_ && !cpu_kernel_registries_.empty()) { + ORT_RETURN_IF_ERROR( + ForceSingleNodeCPUFloat16ToFloat32(graph, cpu_kernel_registries_, logger, + forced_fp32_nodes, nodes_with_recorded_cpu_assignment)); + } GraphViewer graph_viewer(graph); auto& order = graph_viewer.GetNodesInTopologicalOrder(); @@ -540,11 +959,56 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie if (!node) return Status(ONNXRUNTIME, INVALID_ARGUMENT); + if (CpuAssignedFp16NodeNeedsFallbackCast(*node, cpu_kernel_registries_, logger)) { + node->SetExecutionProviderType(""); + } + + if (!enable_cpu_fp16_ && + node->GetExecutionProviderType() == kCpuExecutionProvider && + IsCpuFp16OptInPolicyOp(*node) && + HasFp16Input(*node) && + !node->ContainsSubgraph()) { + node->SetExecutionProviderType(""); + } + + if (enable_cpu_fp16_ && + force_cpu_fp32_ && + IsCpuFp16OptInPolicyOp(*node) && + HasFp16Input(*node) && + !node->ContainsSubgraph() && + (node->GetExecutionProviderType().empty() || + node->GetExecutionProviderType() == kCpuExecutionProvider) && + 
!ShouldKeepNativeCpuFp16ForMatMulOrGemm(graph, *node, mlas_backend_kernel_selector_config_) && + HasCpuFloat32FallbackKernel(*node, cpu_kernel_registries_, logger)) { + // Current Arm fp16 paths are profitable for constant-RHS MatMul once native + // packed-B is available, and for large GEMV-like shapes. Keep Gemm conservative + // until MLAS native fp16 is consistently faster across its common shapes. + if (node->GetExecutionProviderType() == kCpuExecutionProvider) { + nodes_with_recorded_cpu_assignment.insert(node->Index()); + } + node->SetExecutionProviderType(""); + forced_fp32_nodes.insert(node->Index()); + } + + const bool has_fp16_io = !node->ContainsSubgraph() && HasFp16IO(*node); + const bool has_cpu_fp16_kernel = + has_fp16_io && HasCpuKernelForCurrentTypes(*node, cpu_kernel_registries_, logger); + + if (enable_cpu_fp16_ && + node->GetExecutionProviderType().empty() && + has_cpu_fp16_kernel && + forced_fp32_nodes.find(node->Index()) == forced_fp32_nodes.end()) { + // When CPU fp16 is enabled, assign any currently-unassigned fp16-capable node to CPU + // so it is preserved in fp16 instead of being routed through the fp32 cast fallback. 
+ node->SetExecutionProviderType(kCpuExecutionProvider); + RecordPartitionAssignment(graph, *node, kCpuExecutionProvider); + } + auto& inputs = node->MutableInputDefs(); std::map replacement_defs; bool casted = false; for (auto input : inputs) { - if (NeedInsertCast(node, input)) { + if (NeedInsertCast(node, input, logger)) { auto src_arg = input; if (input_def_updates.count(src_arg)) { replacement_defs[src_arg] = input_def_updates[src_arg]; @@ -566,17 +1030,10 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie if (casted) { // Set current node to run on the CPU execution provider - // Keep in mind that the EP will be empty because NeedInsertCast() already insures that + // Keep in mind that the EP will be empty because NeedInsertCast() already ensures that node->SetExecutionProviderType(kCpuExecutionProvider); - - // Record the new CPU EP assignment via the partition assignment callback if provided. - // Skip nodes in nodes_with_recorded_cpu_assignment: their CPU EP assignment was made by the partitioner - // and is already on record; the callback must not fire again for them. 
- if (on_partition_assignment_fn_ && nodes_with_recorded_cpu_assignment.find(node->Index()) == nodes_with_recorded_cpu_assignment.end()) { - auto sub_graph = std::make_unique(); - sub_graph->nodes = {node->Index()}; - ComputeCapability capability(std::move(sub_graph)); - on_partition_assignment_fn_(graph, capability, kCpuExecutionProvider); + if (nodes_with_recorded_cpu_assignment.find(node->Index()) == nodes_with_recorded_cpu_assignment.end()) { + RecordPartitionAssignment(graph, *node, kCpuExecutionProvider); } // Some ONNX operators have an attribute `dtype` which define the output type for these operators diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.h b/onnxruntime/core/optimizer/insert_cast_transformer.h index ca968fdfa1545..f5b1822d9fed0 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.h +++ b/onnxruntime/core/optimizer/insert_cast_transformer.h @@ -3,12 +3,14 @@ #pragma once #include "core/common/common.h" +#include "core/common/inlined_containers.h" #include "core/graph/graph_viewer.h" #include "core/framework/op_kernel.h" +#include "core/framework/graph_partitioner.h" #include "core/optimizer/graph_transformer.h" #include "core/framework/kernel_registry_manager.h" #include "core/framework/kernel_registry.h" -#include "core/framework/graph_partitioner.h" +#include "core/mlas/inc/mlas.h" namespace onnxruntime { @@ -23,28 +25,73 @@ class InsertCastTransformer : public onnxruntime::GraphTransformer { * @brief Initializer * @param name for logging purpose * @param cpu_kernel_registry used to query whether an op node can be safely created + * @param enable_cpu_fp16 if true, allows CPU fp16 kernels to run without forcing fp32 casts + * @param force_cpu_fp32 if true, applies the CPU fp16 fallback heuristic, which may keep selected + * fp16 CPU nodes on the existing fp32 cast fallback path when native fp16 is + * not expected to be profitable for the active MLAS backend + * @param mlas_backend_kernel_selector_config + * active MLAS 
backend selector config. Used by the CPU fp16 heuristic to avoid + * preserving native fp16 paths that rely on backend-specific support, such as + * native packed-B, when that backend is disabled or unavailable. + * @param on_partition_assignment_fn + * optional callback for recording nodes manually assigned by this transformer. */ InsertCastTransformer(const std::string& name, const KernelRegistry* cpu_kernel_registry, + bool enable_cpu_fp16 = false, bool force_cpu_fp32 = true, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config = nullptr, + OnPartitionAssignmentFunction on_partition_assignment_fn = {}) + : onnxruntime::GraphTransformer(name), + cpu_kernel_registries_(cpu_kernel_registry != nullptr ? InlinedVector>{cpu_kernel_registry} + : InlinedVector>{}), + enable_cpu_fp16_(enable_cpu_fp16), + force_cpu_fp32_(!cpu_kernel_registries_.empty() && force_cpu_fp32), + mlas_backend_kernel_selector_config_(mlas_backend_kernel_selector_config != nullptr + ? *mlas_backend_kernel_selector_config + : MLAS_BACKEND_KERNEL_SELECTOR_CONFIG{}), + on_partition_assignment_fn_(std::move(on_partition_assignment_fn)) {} + + InsertCastTransformer(const std::string& name, const KernelRegistry* cpu_kernel_registry, + OnPartitionAssignmentFunction on_partition_assignment_fn) + : InsertCastTransformer(name, cpu_kernel_registry, + /*enable_cpu_fp16*/ false, + /*force_cpu_fp32*/ true, + /*mlas_backend_kernel_selector_config*/ nullptr, + std::move(on_partition_assignment_fn)) {} + + InsertCastTransformer(const std::string& name, + InlinedVector> cpu_kernel_registries, + bool enable_cpu_fp16 = false, bool force_cpu_fp32 = true, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config = nullptr, OnPartitionAssignmentFunction on_partition_assignment_fn = {}) : onnxruntime::GraphTransformer(name), - cpu_kernel_registries_(cpu_kernel_registry), - force_cpu_fp32_(cpu_kernel_registry != nullptr), + 
cpu_kernel_registries_(std::move(cpu_kernel_registries)), + enable_cpu_fp16_(enable_cpu_fp16), + force_cpu_fp32_(!cpu_kernel_registries_.empty() && force_cpu_fp32), + mlas_backend_kernel_selector_config_(mlas_backend_kernel_selector_config != nullptr + ? *mlas_backend_kernel_selector_config + : MLAS_BACKEND_KERNEL_SELECTOR_CONFIG{}), on_partition_assignment_fn_(std::move(on_partition_assignment_fn)) {} private: Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; - bool NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input) const; + bool NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input, + const logging::Logger& logger) const; + void RecordPartitionAssignment(const onnxruntime::Graph& graph, const onnxruntime::Node& node, + const std::string& ep_type) const; - const KernelRegistry* cpu_kernel_registries_; + const InlinedVector> cpu_kernel_registries_; - // Currently because we only have very few cpu kernels support float16, place those nodes on float16 - // will introduce many cast between fp32 and fp16, which will slow the execution. - // A better solution is to have a cost model to evaluate does it works to place the node on float16. - // Here for simplify, we only force the single-node-float16 sub-graph to float32 + const bool enable_cpu_fp16_; + + // Some CPU fp16 kernels are only profitable for specific shapes and backend capabilities. A broader cost model would + // be better; for now we use conservative checks for known slower cases and native packed-B availability. const bool force_cpu_fp32_; + // Copied from session options so graph optimization makes the same backend-capability decisions that execution will. + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config_; + // Optional callback to record when nodes are assigned to CPU EP by this transformer. 
// Reuses the same callback type as GraphPartitioner to maintain consistent EP assignment tracking. - OnPartitionAssignmentFunction on_partition_assignment_fn_; + const OnPartitionAssignmentFunction on_partition_assignment_fn_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index fb854b19accfa..5476a9b210f65 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -177,11 +177,13 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 21, Atan); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Hardmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, float, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, double, MatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, MLFloat16, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, Softmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, Softmax); 
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, float, TopK); @@ -397,8 +399,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, double, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, MLFloat16, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, MatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int64_t, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, float, @@ -575,6 +579,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, ScatterND); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements); class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t, BitShift); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t, BitShift); @@ -709,8 +714,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, #endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Gemm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, MatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, MatMul); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Min); @@ -1745,6 +1752,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { float, Gemm)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff 
--git a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc index 08dbc46213f65..222dcc4b3c6dc 100644 --- a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc +++ b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc @@ -9,9 +9,13 @@ #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +#include + +#include "core/common/narrow.h" #include "core/common/safeint.h" #include "core/common/float16.h" #include "core/framework/op_kernel.h" +#include "core/providers/cpu/mlas_backend_kernel_selector_config_utils.h" #include "core/providers/cpu/nn/conv_attributes.h" #include "contrib_ops/cpu/fused_activation.h" @@ -46,6 +50,13 @@ class FusedConvFp16 final : public OpKernel { FusedConvFp16(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) { ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK()); channels_last_ = (info.GetKernelDef().OpName() == "NhwcFusedConv"); + SetupMlasBackendKernelSelectorFromConfigOptions(mlas_backend_kernel_selector_config_, info.GetConfigOptions()); + const auto& input_defs = info.node().InputDefs(); + has_bias_input_ = input_defs.size() >= 3 && input_defs[2]->Exists(); + const Tensor* bias = nullptr; + if (has_bias_input_ && info.TryGetConstantInput(2, &bias)) { + constant_B_ = bias; + } } Status Compute(OpKernelContext* context) const override; @@ -93,8 +104,13 @@ class FusedConvFp16 final : public OpKernel { } MLAS_ACTIVATION activation_; + MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config_; ConvAttributes conv_attrs_; + bool has_bias_input_{false}; + const Tensor* constant_B_{nullptr}; TensorShape W_shape_; + BufferUniquePtr packed_halfconv_weights_and_bias_buffer_; + size_t packed_halfconv_weights_and_bias_size_{0}; BufferUniquePtr packed_W_buffer_; size_t packed_W_size_{0}; bool is_W_packed_{false}; @@ -139,8 +155,53 @@ Status FusedConvFp16::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr const size_t kernel_dim = group_input_channels * kernel_size; bool 
share_prepacked_weights = (prepacked_weights != nullptr); + const bool has_valid_constant_halfconv_bias = + constant_B_ != nullptr && + !share_prepacked_weights && + constant_B_->Shape().NumDimensions() == 1 && + constant_B_->Shape()[0] == static_cast(output_channels); const bool is_depthwise_conv = (group_input_channels == 1 && group_output_channels == 1); + + if (group_count == 1 && + rank == 4 && + output_channels > 1 && + activation_.ActivationKind == MlasIdentityActivation && + (!has_bias_input_ || has_valid_constant_halfconv_bias)) { + std::array halfconv_kernel_shape{shape[2], shape[3]}; + std::array halfconv_dilations{1, 1}; + if (!conv_attrs_.dilations.empty()) { + halfconv_dilations = {conv_attrs_.dilations[0], conv_attrs_.dilations[1]}; + } + + packed_halfconv_weights_and_bias_size_ = MlasHalfConvPackWeightsAndBiasSize( + output_channels, + group_input_channels, + halfconv_kernel_shape.data(), + halfconv_dilations.data(), + &mlas_backend_kernel_selector_config_); + if (packed_halfconv_weights_and_bias_size_ != 0) { + auto* packed_halfconv_weights_and_bias = alloc->Alloc(packed_halfconv_weights_and_bias_size_); + BufferUniquePtr packed_halfconv_weights_and_bias_buffer( + packed_halfconv_weights_and_bias, BufferDeleter(alloc)); + const auto* bias_data = constant_B_ != nullptr ? constant_B_->Data() : nullptr; + if (MlasHalfConvPackWeightsAndBias( + output_channels, + group_input_channels, + halfconv_kernel_shape.data(), + halfconv_dilations.data(), + Wdata, + bias_data, + packed_halfconv_weights_and_bias, + nullptr, + &mlas_backend_kernel_selector_config_)) { + packed_halfconv_weights_and_bias_buffer_ = std::move(packed_halfconv_weights_and_bias_buffer); + } else { + packed_halfconv_weights_and_bias_size_ = 0; + } + } + } + // Don't pack the filter buffer if the MlasConvDepthwise path is used. 
if (!is_depthwise_conv) { packed_W_size_ = MlasHalfGemmPackBSize(group_output_channels, kernel_dim, false); @@ -185,6 +246,9 @@ Status FusedConvFp16::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr } if (share_prepacked_weights) { + // `packed_halfconv_weights_and_bias_buffer_` is session-local only. It encodes a bias + // choice even when the bias is null, so it must not be shared under the + // W-only prepack cache key. prepacked_weights->buffers_.push_back(nullptr); // packed_W_buffer_ is nullptr prepacked_weights->buffer_sizes_.push_back(0); } @@ -222,12 +286,13 @@ Status FusedConvFp16::UseSharedPrePackedBuffers(std::vector& pr used_shared_buffers = true; - if (prepacked_buffers.size() == 1) { // This means that only packed_W_ exists + if (prepacked_buffers.size() == 1) { // only packed_W_ exists packed_W_buffer_ = std::move(prepacked_buffers[0]); - } else if (prepacked_buffers.size() == 2) { // This means that only reordered_W_ exists - // Enforce that the first "placeholder" buffer is nullptr + } else if (prepacked_buffers.size() == 2) { // placeholder + reordered_W_ ORT_ENFORCE(prepacked_buffers[0].get() == nullptr); reordered_W_buffer_ = std::move(prepacked_buffers[1]); + } else { + ORT_ENFORCE(false, "Unexpected number of shared prepacked fp16 conv buffers."); } return Status::OK(); @@ -296,6 +361,63 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { AllocatorPtr alloc; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); + concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); + + const auto* Bdata = B != nullptr ? 
B->Data() : nullptr; + + if (Sum == nullptr && kernel_rank == 2) { + MLAS_CONV_PARAMETERS parameters; + + size_t working_buffer_size = 0; + if (MlasHalfConvPrepare(&parameters, + kernel_rank, + narrow(N), + narrow(conv_attrs_.group), + narrow(C / conv_attrs_.group), + input_shape.GetDims().data(), + kernel_shape.data(), + dilations.data(), + pads.data(), + strides.data(), + output_shape.GetDims().data(), + narrow(M / conv_attrs_.group), + &activation_, + &working_buffer_size, + 0.0f, + channels_last_, + thread_pool, + &mlas_backend_kernel_selector_config_)) { + const MLFloat16* halfconv_filter = nullptr; + const MLFloat16* halfconv_bias = Bdata; + bool halfconv_filter_and_bias_are_packed = false; + + if (packed_halfconv_weights_and_bias_buffer_ != nullptr) { + halfconv_filter = static_cast(packed_halfconv_weights_and_bias_buffer_.get()); + halfconv_filter_and_bias_are_packed = true; + halfconv_bias = nullptr; + } else if (W != nullptr) { + halfconv_filter = W->Data(); + } + + if (halfconv_filter != nullptr) { + auto* working_data = working_buffer_size > 0 + ? alloc->Alloc(sizeof(MLFloat16) * SafeInt(working_buffer_size)) + : nullptr; + BufferUniquePtr working_buffer(working_data, BufferDeleter(alloc)); + + if (MlasHalfConv(&parameters, + X->Data(), + halfconv_filter, + halfconv_filter_and_bias_are_packed, + halfconv_bias, + static_cast(working_buffer.get()), + Y->MutableData(), + thread_pool)) { + return Status::OK(); + } + } + } + } // Handle the case of a dynamic weight filter. BufferUniquePtr reordered_W_buffer; @@ -337,7 +459,6 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { const int64_t col_buffer_size = kernel_dim * output_image_size; const auto* Xdata = X->Data(); - const auto* Bdata = B != nullptr ? B->Data() : nullptr; auto* Ydata = Y->MutableData(); const auto* sum_data = Sum != nullptr ? 
Sum->Data() : nullptr; @@ -387,8 +508,6 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { padding_data.resize(static_cast(C), MLFloat16()); } - concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); - /************************************* * Thread partition idea: we are essentially partition a GEMM A[M,K] x B[K,N]. * Here B contains the conv filters, which are usually not big, so we assume @@ -544,7 +663,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { const auto* gemm_add = add_src == nullptr ? nullptr : worker_addsrc + group_id * group_output_channels; MLAS_HALF_GEMM_ACTIVATION_PROCESSOR act(activation_, gemm_add); - MLAS_HALF_GEMM_DATA_PARAMS gemm_params; + MLAS_HALF_GEMM_DATA_PARAMS gemm_params{}; gemm_params.A = AData; gemm_params.lda = lda; if (packed_W_buffer_) { @@ -563,7 +682,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { static_cast(output_count), static_cast(group_output_channels), static_cast(kernel_dim), - 1, &gemm_params, nullptr); + 1, &gemm_params, nullptr, &mlas_backend_kernel_selector_config_); } } }; diff --git a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc index c0da9aec1e1b1..a86e5a36aadf9 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.cc +++ b/onnxruntime/core/providers/cpu/math/gemm.cc @@ -206,26 +206,30 @@ void Gemm_MLFloat16(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, if (c_data == nullptr) beta = onnxruntime::MLFloat16::Zero; #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED - bool support_mlas = false; + bool support_mlas_bias = false; if (c_shape == nullptr) { - support_mlas = true; + support_mlas_bias = true; } else if (c_shape->NumDimensions() == 1 && (*c_shape)[0] == N) { - support_mlas = true; - } else if (c_shape->NumDimensions() == 2 && (((*c_shape)[0] == 1 && (*c_shape)[1] == N) || ((*c_shape)[0] == N && (*c_shape)[1] == 1))) { - support_mlas = true; + support_mlas_bias = true; + } else if 
(c_shape->NumDimensions() == 2 && + (((*c_shape)[0] == 1 && (*c_shape)[1] == N) || ((*c_shape)[0] == N && (*c_shape)[1] == 1))) { + support_mlas_bias = true; } - if (trans_a == CblasNoTrans && trans_b == CblasNoTrans && support_mlas && alpha.ToFloat() == 1.0 && beta.ToFloat() == 1.0) { - MLAS_HALF_GEMM_DATA_PARAMS data; + const bool use_mlas_no_bias = beta == onnxruntime::MLFloat16::Zero; + const bool use_mlas_bias = beta == onnxruntime::MLFloat16::One && support_mlas_bias; + if (trans_a == CblasNoTrans && trans_b == CblasNoTrans && + alpha == onnxruntime::MLFloat16::One && (use_mlas_no_bias || use_mlas_bias)) { + MLAS_HALF_GEMM_DATA_PARAMS data{}; data.A = a_data; data.lda = K; data.B = b_data; data.ldb = N; data.C = y_data; data.ldc = N; - if (c_shape != nullptr) { + if (use_mlas_bias && c_shape != nullptr) { data.Bias = c_data; } - MlasHalfGemmBatch(M, N, K, 1, &data, thread_pool); + MlasHalfGemmBatch(M, N, K, 1, &data, thread_pool, mlas_backend_kernel_selector_config); return; } #endif diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index ca91f46db93da..e19f8ec370d26 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -6,6 +6,9 @@ #include "core/providers/cpu/math/matmul_helper.h" #include "core/util/math.h" #include "core/util/math_cpuonly.h" +#include +#include +#include namespace onnxruntime { @@ -23,6 +26,13 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), MatMul); +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + MatMul, + 1, 8, + MLFloat16, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + // opset 9 supports more types ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( MatMul, @@ -40,6 +50,14 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), MatMul); 
+ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + MatMul, + 9, + 12, + MLFloat16, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( MatMul, 9, @@ -72,6 +90,13 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL( KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), MatMul); +ONNX_CPU_OPERATOR_TYPED_KERNEL( + MatMul, + 13, + MLFloat16, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + ONNX_CPU_OPERATOR_TYPED_KERNEL( MatMul, 13, @@ -106,9 +131,8 @@ Status MatMul::Compute(OpKernelContext* ctx) const { if (helper.K() == 0) { // When we have (M, 0, N) then the inputs are empty, but the output should // be filled out with zeros. - EigenMatrixMapRowMajor dest(y->MutableData(), - narrow(helper.M()), narrow(helper.N())); - dest.setZero(); + auto output_span = gsl::make_span(y->MutableData(), narrow(y->Shape().Size())); + std::fill(output_span.begin(), output_span.end(), T{}); return Status::OK(); } @@ -235,6 +259,45 @@ bool GemmPackBBfloat16(AllocatorPtr& alloc, } #endif +bool GemmPackBHalfNative(AllocatorPtr& alloc, + const Tensor& tensor_b, + IAllocatorUniquePtr& packed_b, + size_t& packed_b_size, + TensorShape& b_shape, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config) { + // Only handle the common case of a 2D weight matrix. Additional matrices + // could be handled by stacking the packed buffers. 
+ if (tensor_b.Shape().NumDimensions() != 2) { + return false; + } + + b_shape = tensor_b.Shape(); + + const size_t K = static_cast(b_shape[0]); + const size_t N = static_cast(b_shape[1]); + + packed_b_size = MlasHalfGemmNativePackBSize(CblasNoTrans, CblasNoTrans, N, K, + mlas_backend_kernel_selector_config); + if (packed_b_size == 0) { + return false; + } + + packed_b = IAllocator::MakeUniquePtr(alloc, packed_b_size, true); + auto* packed_b_data = packed_b.get(); + memset(packed_b_data, 0, packed_b_size); + + if (!MlasHalfGemmNativePackB(CblasNoTrans, CblasNoTrans, N, K, + reinterpret_cast(tensor_b.Data()), N, packed_b_data, + mlas_backend_kernel_selector_config)) { + packed_b.reset(); + packed_b_size = 0; + b_shape = TensorShape(); + return false; + } + + return true; +} + Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { @@ -284,6 +347,103 @@ Status MatMul::UseSharedPrePackedBuffers(std::vector& pr return Status::OK(); } +Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; + ORT_UNUSED_PARAMETER(prepacked_weights); + + if (input_idx == 1) { + size_t packed_b_size = 0; + is_packed = GemmPackBHalfNative(alloc, tensor, packed_b_, packed_b_size, b_shape_, + &mlas_backend_kernel_selector_config_); + // The native fp16 packed-B layout depends on the active MLAS backend selector. + // Keep it owned by this kernel until shared prepacked weights carry layout metadata. 
+ } + + return Status::OK(); +} + +Status MatMul::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span prepacked_buffer_sizes, + int input_idx, + /*out*/ bool& used_shared_buffers) { + ORT_UNUSED_PARAMETER(prepacked_buffers); + ORT_UNUSED_PARAMETER(prepacked_buffer_sizes); + ORT_UNUSED_PARAMETER(input_idx); + + // Native fp16 packed-B buffers are backend-layout-specific. Decline shared + // buffers until the shared prepack cache can validate that layout. + used_shared_buffers = false; + return Status::OK(); +} + +Status MatMul::Compute(OpKernelContext* ctx) const { + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); + + const Tensor* a = ctx->Input(0); + const Tensor* b = packed_b_ ? nullptr : ctx->Input(1); + const auto& b_shape = b ? b->Shape() : b_shape_; + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape)); + Tensor* y = ctx->Output(0, helper.OutputShape()); + + if (y->Shape().Size() == 0) { + return Status::OK(); + } + + if (helper.K() == 0) { + auto output_span = gsl::make_span(y->MutableData(), narrow(y->Shape().Size())); + std::fill(output_span.begin(), output_span.end(), MLFloat16{}); + return Status::OK(); + } + + const auto* a_data = a->Data(); + const auto* b_data = b ? 
b->Data() : nullptr; + auto* y_data = y->MutableData(); + + const size_t max_len = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + const size_t lda = helper.Lda(false); + const size_t ldb = helper.Ldb(false); + + if (M <= 2 && packed_b_ == nullptr && MlasHGemmSupported(CblasNoTrans, CblasNoTrans)) { + const auto alpha = MLFloat16(1.0f); + const auto beta = MLFloat16(0.0f); + std::vector data(max_len); + for (size_t i = 0; i < max_len; i++) { + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = b_data + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = alpha.val; + data[i].beta = beta.val; + } + + MlasGemmBatch(CblasNoTrans, CblasNoTrans, M, N, K, data.data(), max_len, thread_pool); + return Status::OK(); + } + + std::vector data(max_len); + for (size_t i = 0; i < max_len; i++) { + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = packed_b_ ? packed_b_.get() : static_cast(b_data + helper.RightOffsets()[i]); + data[i].ldb = packed_b_ ? 
0 : ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].BIsBackendNativePacked = static_cast(packed_b_); + } + + MlasHalfGemmBatch(M, N, K, max_len, data.data(), thread_pool, &mlas_backend_kernel_selector_config_); + return Status::OK(); +} + Status MatMul::Compute(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); diff --git a/onnxruntime/core/providers/cpu/math/matmul.h b/onnxruntime/core/providers/cpu/math/matmul.h index d1c0df19f924e..25d222818aaf6 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.h +++ b/onnxruntime/core/providers/cpu/math/matmul.h @@ -106,4 +106,28 @@ class MatMul final : public OpKernel { #endif }; +template <> +class MatMul final : public OpKernel { + public: + MatMul(const OpKernelInfo& info) : OpKernel(info) { + SetupMlasBackendKernelSelectorFromConfigOptions(mlas_backend_kernel_selector_config_, info.GetConfigOptions()); + } + + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) override; + + Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, + int input_idx, + /*out*/ bool& used_shared_buffers) override; + + Status Compute(OpKernelContext* context) const override; + + private: + TensorShape b_shape_; + IAllocatorUniquePtr packed_b_; + MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config_; +}; + } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index bf21b300f07b6..517841af4f8e5 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -66,6 +66,7 @@ #endif #include "core/providers/cpu/controlflow/utils.h" #include "core/providers/cpu/cpu_execution_provider.h" +#include "core/providers/cpu/mlas_backend_kernel_selector_config_utils.h" #include 
"core/session/abi_devices.h" #ifdef USE_DML // TODO: This is necessary for the workaround in TransformGraph #include "core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h" @@ -1715,14 +1716,24 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool const InlinedVector> kernel_regs = kernel_registry_manager_.GetKernelRegistriesByProviderType(kCpuExecutionProvider); - const KernelRegistry* cpu_regs = nullptr; - if (!kernel_regs.empty()) { - // NOTE: This assumes that CPU kernels are always at the n-1 index of kernel registries vector as per the design - // of GetKernelRegistriesByProviderType function. - cpu_regs = kernel_regs[kernel_regs.size() - 1]; - } - - InsertCastTransformer insert_cast_transformer{"CastFloat16Transformer", cpu_regs, on_partition_assignment_fn}; + const bool enable_cpu_fp16 = + session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableCpuFp16, "0") == "1"; + const bool use_cpu_fp16_fp32_fallback_heuristic = + session_options_.config_options.GetConfigOrDefault( + kOrtSessionOptionsCpuFp16UseFp32FallbackHeuristic, "1") == "1"; + const bool force_cpu_fp32 = !enable_cpu_fp16 || use_cpu_fp16_fp32_fallback_heuristic; + + // Keep InsertCastTransformer's CPU fp16 profitability checks aligned with execution-time MLAS backend selection. + MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config; + SetupMlasBackendKernelSelectorFromConfigOptions(mlas_backend_kernel_selector_config, + session_options_.config_options); + + InsertCastTransformer insert_cast_transformer{ + "CastFloat16Transformer", kernel_regs, + /*enable_cpu_fp16*/ enable_cpu_fp16, + /*force_cpu_fp32*/ force_cpu_fp32, + &mlas_backend_kernel_selector_config, + on_partition_assignment_fn}; ORT_RETURN_IF_ERROR_SESSIONID_( apply_transformer_once(insert_cast_transformer, *session_logger_, graph, ((graph_optimizations_loop_level > 1) ? 
&is_graph_modified : nullptr))); diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index 608d30c12b587..3eeb0c5df8f1f 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -190,6 +190,35 @@ void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const float* A, const MlasGemm(CblasNoTrans, CblasNoTrans, M, N, K, 1.f, A, K, B, N, 0.f, C, N, threadpool, mlas_backend_kernel_selector_config); } +template <> +void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const MLFloat16* A, const MLFloat16* B, MLFloat16* C, ThreadPool* threadpool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config) { +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED + MLAS_HALF_GEMM_DATA_PARAMS data{}; + data.A = A; + data.lda = static_cast(K); + data.B = B; + data.ldb = static_cast(N); + data.C = C; + data.ldc = static_cast(N); + MlasHalfGemmBatch(static_cast(M), static_cast(N), static_cast(K), 1, &data, threadpool, + mlas_backend_kernel_selector_config); +#else + ORT_UNUSED_PARAMETER(threadpool); + ORT_UNUSED_PARAMETER(mlas_backend_kernel_selector_config); +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + auto C_mat = EigenMatrixMap(reinterpret_cast(C), N, M); + C_mat.noalias() = ConstEigenMatrixMap(reinterpret_cast(B), N, K) * + ConstEigenMatrixMap(reinterpret_cast(A), K, M); +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif +#endif +} + #ifdef MLAS_SUPPORTS_GEMM_DOUBLE template <> void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const double* A, const double* B, double* C, ThreadPool* threadpool, diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index 6268e425ebb61..ad4c8941da782 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -521,8 +521,12 @@ TEST(ContribOpAttentionTest, AttentionBatch1_Float16) { 
3.154296875, 0.1082763671875, 4.25, 5.6484375, 3.970703125, 0.072998046875, 4.25, 5.6484375}; + // WebGPU Attention does not support mask_index input. + constexpr bool disable_webgpu = true; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, - batch_size, sequence_length, hidden_size, number_of_heads, true); + batch_size, sequence_length, hidden_size, number_of_heads, true /*use_float16*/, + false, false, 0, nullptr, nullptr, AttentionMaskType::MASK_1D_KEY_SEQ_LEN, 0, 0, + false /*disable_cpu*/, false /*disable_cuda*/, false /*disable_dml*/, disable_webgpu); } TEST(ContribOpAttentionTest, AttentionBatch2) { diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index de9e236b17e4e..5740eacceee82 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -292,7 +292,12 @@ void TestMatMulNBitsTyped(std::optional abs_error = std::nullopt, } else if (base_opts.accuracy_level == 4) { base_opts.output_abs_error = 0.1f; } else if constexpr (std::is_same::value) { - base_opts.output_abs_error = 0.055f; +#if defined(USE_WEBGPU) + // Match existing fp16 MatMulNBits tolerance for WebGPU builds while keeping CPU stricter. 
+ base_opts.output_abs_error = 0.1f; +#else + base_opts.output_abs_error = 0.065f; +#endif } else { base_opts.output_abs_error = 0.05f; } diff --git a/onnxruntime/test/contrib_ops/nhwc_pool_in_op_test.cc b/onnxruntime/test/contrib_ops/nhwc_pool_in_op_test.cc index e40e635c79a26..c58d813e1b061 100644 --- a/onnxruntime/test/contrib_ops/nhwc_pool_in_op_test.cc +++ b/onnxruntime/test/contrib_ops/nhwc_pool_in_op_test.cc @@ -7,6 +7,7 @@ #include #include +#include #include "core/util/math.h" #include "core/mlas/inc/mlas.h" @@ -17,6 +18,16 @@ namespace onnxruntime { namespace test { #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +namespace { + +// These tests target ORT's CPU implementation of the MS-internal NHWC fp16 pool +// ops. Do not offer the models to EPs that may be registered by the test +// harness but do not own this internal-domain CPU/MLAS coverage. +const std::unordered_set kNhwcFp16PoolExcludedProviders{ + kCoreMLExecutionProvider, + kTensorrtExecutionProvider, +}; + class NhwcFp16PoolOpTester { private: bool is_max_pool_; // max or average pool @@ -181,10 +192,12 @@ class NhwcFp16PoolOpTester { if (!dilations_.empty()) { test.AddAttribute("dilations", dilations_); } - test.Run(OpTester::ExpectResult::kExpectSuccess, ""); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", kNhwcFp16PoolExcludedProviders); } }; +} // namespace + TEST(NhwcFp16PoolOpTest, MaxPool1D) { for (int64_t channels = 1; channels < 94; channels++) { NhwcFp16PoolOpTester test(true); @@ -303,7 +316,7 @@ TEST(NhwcFp16PoolOpTest, AvgPoolIncludePadPixel) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", kNhwcFp16PoolExcludedProviders); } TEST(NhwcFp16PoolOpTest, GlobalAveragePool) { @@ -508,7 +521,7 @@ TEST(NhwcFp16PoolOpTest, GlobalAveragePool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, 
expected_vals); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", kNhwcFp16PoolExcludedProviders); } #endif diff --git a/onnxruntime/test/framework/insert_cast_transformer_test.cc b/onnxruntime/test/framework/insert_cast_transformer_test.cc index 4fa3799f1750f..4429ac76e2d3c 100644 --- a/onnxruntime/test/framework/insert_cast_transformer_test.cc +++ b/onnxruntime/test/framework/insert_cast_transformer_test.cc @@ -2,15 +2,20 @@ // Licensed under the MIT License. #include "core/framework/allocator.h" +#include "core/framework/tensorprotoutils.h" #include "core/optimizer/insert_cast_transformer.h" #include "core/graph/model.h" #include "core/graph/node_attr_utils.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include #include "gtest/gtest.h" +#include "test/internal_testing_ep/internal_testing_execution_provider.h" #include "test/unittest_util/framework_test_utils.h" #include "test/test_environment.h" #include "test/util/include/default_providers.h" #include "test/util/include/inference_session_wrapper.h" #include "test/util/include/asserts.h" +#include using namespace ONNX_NAMESPACE; namespace onnxruntime { @@ -19,6 +24,24 @@ namespace test { #define MODEL_FOLDER ORT_TSTR("testdata/transform/") typedef std::vector ArgMap; + +static TypeProto MakeTensorType(TensorProto_DataType elem_type, std::initializer_list shape = {}) { + TypeProto tensor_type; + tensor_type.mutable_tensor_type()->set_elem_type(elem_type); + if (shape.size() > 0) { + auto* tensor_shape = tensor_type.mutable_tensor_type()->mutable_shape(); + for (const auto dim : shape) { + tensor_shape->add_dim()->set_dim_value(dim); + } + } + + return tensor_type; +} + +static TypeProto MakeFp16TensorType(std::initializer_list shape = {}) { + return MakeTensorType(TensorProto_DataType_FLOAT16, shape); +} + TEST(TransformerTest, InsertCastGPUTest) { auto model = std::make_shared("test", false, DefaultLoggingManager().DefaultLogger()); onnxruntime::Graph& graph = 
model->MainGraph(); @@ -309,6 +332,770 @@ TEST(TransformerTest, Fp16NodeWithSubgraph) { EXPECT_EQ(new_subgraph_ops.find("Cast")->second, 3) << "'Add' node in subgraph should have had Casts added"; } +static std::shared_ptr MakeCpuFp16Model(const std::string& model_name, const std::string& op_type, + bool assign_cpu_ep) { + auto model = std::make_shared(model_name, false, DefaultLoggingManager().DefaultLogger()); + auto& graph = model->MainGraph(); + + TypeProto tensor_float_16 = MakeFp16TensorType(); + + auto& a = graph.GetOrCreateNodeArg("A", &tensor_float_16); + auto& b = graph.GetOrCreateNodeArg("B", &tensor_float_16); + auto& y = graph.GetOrCreateNodeArg("Y", &tensor_float_16); + + Node& node = [&]() -> Node& { + if (op_type == "MatMul" || op_type == "Add") { + return graph.AddNode(op_type, op_type, "fp16 test op", ArgMap{&a, &b}, ArgMap{&y}); + } + + if (op_type == "Abs") { + return graph.AddNode(op_type, op_type, "fp16 test op", ArgMap{&a}, ArgMap{&y}); + } + + if (op_type == "Gemm") { + auto& c = graph.GetOrCreateNodeArg("C", &tensor_float_16); + graph.SetInputs({&a, &b, &c}); + return graph.AddNode(op_type, op_type, "fp16 test op", ArgMap{&a, &b, &c}, ArgMap{&y}); + } + + ORT_THROW("Unsupported op type for test: ", op_type); + }(); + + if (op_type == "Gemm") { + // Inputs already set when the node is created. 
+ } else if (op_type == "Abs") { + graph.SetInputs({&a}); + } else { + graph.SetInputs({&a, &b}); + } + graph.SetOutputs({&y}); + + if (assign_cpu_ep) { + node.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider); + } + + ORT_THROW_IF_ERROR(graph.Resolve()); + return model; +} + +static std::shared_ptr MakeCpuFp16MatMulModelWithShapes(const std::string& model_name, + std::initializer_list a_shape, + std::initializer_list b_shape, + std::initializer_list y_shape, + bool assign_cpu_ep, + bool constant_b = false) { + auto model = std::make_shared(model_name, false, DefaultLoggingManager().DefaultLogger()); + auto& graph = model->MainGraph(); + + TypeProto a_type = MakeFp16TensorType(a_shape); + TypeProto b_type = MakeFp16TensorType(b_shape); + TypeProto y_type = MakeFp16TensorType(y_shape); + + auto& a = graph.GetOrCreateNodeArg("A", &a_type); + auto& b = graph.GetOrCreateNodeArg("B", &b_type); + auto& y = graph.GetOrCreateNodeArg("Y", &y_type); + + if (constant_b) { + TensorProto b_initializer; + b_initializer.set_name("B"); + b_initializer.set_data_type(TensorProto_DataType_FLOAT16); + for (const auto dim : b_shape) { + b_initializer.add_dims(dim); + } + + size_t element_count = 1; + for (const auto dim : b_shape) { + element_count *= static_cast(dim); + } + std::vector data(element_count, MLFloat16::Zero); + utils::SetRawDataInTensorProto(b_initializer, data.data(), data.size() * sizeof(MLFloat16)); + graph.AddInitializedTensor(b_initializer); + } + + auto& node = graph.AddNode("MatMul", "MatMul", "fp16 matmul shape heuristic test", ArgMap{&a, &b}, ArgMap{&y}); + if (assign_cpu_ep) { + node.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider); + } + + if (constant_b) { + graph.SetInputs({&a}); + } else { + graph.SetInputs({&a, &b}); + } + graph.SetOutputs({&y}); + ORT_THROW_IF_ERROR(graph.Resolve()); + + return model; +} + +static std::shared_ptr MakeCpuFp16ResizeModel(const std::string& model_name, bool assign_cpu_ep) { + std::unordered_map 
domain_to_version; + domain_to_version[kOnnxDomain] = 13; + auto model = std::make_shared(model_name, false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, + std::vector(), + DefaultLoggingManager().DefaultLogger()); + auto& graph = model->MainGraph(); + + TypeProto x_type = MakeFp16TensorType({1, 1, 2, 2}); + TypeProto roi_type = MakeTensorType(TensorProto_DataType_FLOAT, {0}); + TypeProto scales_type = MakeTensorType(TensorProto_DataType_FLOAT, {4}); + TypeProto y_type = MakeFp16TensorType({1, 1, 4, 4}); + + auto& x = graph.GetOrCreateNodeArg("X", &x_type); + auto& roi = graph.GetOrCreateNodeArg("roi", &roi_type); + auto& scales = graph.GetOrCreateNodeArg("scales", &scales_type); + auto& y = graph.GetOrCreateNodeArg("Y", &y_type); + + auto& node = graph.AddNode("Resize", "Resize", "fp16 resize fallback test", + ArgMap{&x, &roi, &scales}, ArgMap{&y}); + if (assign_cpu_ep) { + node.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider); + } + + graph.SetInputs({&x, &roi, &scales}); + graph.SetOutputs({&y}); + ORT_THROW_IF_ERROR(graph.Resolve()); + + return model; +} + +static void ExpectMlFloat16Output(const std::vector& output, + const std::vector& expected) { + ASSERT_EQ(output.size(), expected.size()); + for (size_t i = 0; i < output.size(); ++i) { + EXPECT_EQ(output[i].val, expected[i].val); + } +} + +static bool IsNodeArgType(const NodeArg& node_arg, const MLDataType type) { + return node_arg.Type() != nullptr && + DataTypeImpl::TypeFromProto(*node_arg.TypeAsProto()) == type; +} + +static bool CanRunNativeCpuFp16GemmRuntime() { +#if defined(__aarch64__) || defined(_M_ARM64) + return MlasFp16AccelerationSupported(); +#else + return false; +#endif +} + +static const Node* FindNodeByOpType(const Graph& graph, const std::string& op_type) { + for (const auto& node : graph.Nodes()) { + if (node.OpType() == op_type) { + return &node; + } + } + + return nullptr; +} + +static std::vector RunCpuFp16Model(const std::string& 
model_name, + const std::string& op_type, + bool enable_cpu_fp16) { + const auto input_path = ToPathString(model_name + (enable_cpu_fp16 ? "_enabled_runtime.onnx" : "_disabled_runtime.onnx")); + ORT_THROW_IF_ERROR(Model::Save(*MakeCpuFp16Model(model_name, op_type, false), input_path)); + + SessionOptions so; + so.session_logid = model_name; + so.graph_optimization_level = TransformerLevel::Level4; + ORT_THROW_IF_ERROR(so.config_options.AddConfigEntry( + kOrtSessionOptionsEnableCpuFp16, enable_cpu_fp16 ? "1" : "0")); + ORT_THROW_IF_ERROR(so.config_options.AddConfigEntry( + kOrtSessionOptionsCpuFp16UseFp32FallbackHeuristic, enable_cpu_fp16 ? "0" : "1")); + + InferenceSessionWrapper session{so, GetEnvironment()}; + ORT_THROW_IF_ERROR(session.Load(input_path)); + ORT_THROW_IF_ERROR(session.Initialize()); + + auto allocator = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; + + NameMLValMap feeds; + OrtValue a; + CreateMLValue(allocator, {2, 2}, + std::vector{MLFloat16(1.0f), MLFloat16(2.0f), + MLFloat16(3.0f), MLFloat16(4.0f)}, + &a); + feeds.insert({"A", a}); + + OrtValue b; + CreateMLValue(allocator, {2, 2}, + std::vector{MLFloat16(5.0f), MLFloat16(6.0f), + MLFloat16(7.0f), MLFloat16(8.0f)}, + &b); + feeds.insert({"B", b}); + + if (op_type == "Gemm") { + OrtValue c; + CreateMLValue(allocator, {2, 2}, + std::vector{MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f)}, + &c); + feeds.insert({"C", c}); + } + + std::vector fetches; + RunOptions run_options; + ORT_THROW_IF_ERROR(session.Run(run_options, feeds, AsSpan({std::string("Y")}), &fetches)); + ORT_ENFORCE(fetches.size() == 1u, "Expected exactly one output for ", model_name); + + const auto& output = fetches[0].Get(); + std::vector result(output.Data(), + output.Data() + output.Shape().Size()); + + std::error_code ec; + std::filesystem::remove(std::filesystem::path{input_path}, ec); + + return result; +} + +TEST(TransformerTest, CpuFp16MatMulPreservesWhenEnabled) { + { + auto model = 
MakeCpuFp16Model("cpu_fp16_matmul_optin", "MatMul", true); + auto& graph = model->MainGraph(); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, false); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_TRUE(op_counts.find("Cast") == op_counts.cend()); + } +} + +TEST(TransformerTest, CpuFp16CpuAssignedMatMulHasNoCastsWhenEnabled) { + auto model = MakeCpuFp16Model("cpu_fp16_matmul_unassigned", "MatMul", false); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + node->SetExecutionProviderType(kCpuExecutionProvider); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, false); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_TRUE(op_counts.find("Cast") == op_counts.cend()); + + EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16CpuAssignedGemmHasNoCastsWhenEnabled) { + auto model = MakeCpuFp16Model("cpu_fp16_gemm_unassigned", "Gemm", false); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + node->SetExecutionProviderType(kCpuExecutionProvider); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, false); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + 
EXPECT_TRUE(op_counts.find("Cast") == op_counts.cend()); + + EXPECT_EQ(node->OpType(), "Gemm"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16NonCpuAssignedMatMulIsUntouchedWhenEnabled) { + auto model = MakeCpuFp16Model("cpu_fp16_matmul_other_ep", "MatMul", false); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + node->SetExecutionProviderType("SomeOtherEP"); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, false); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_TRUE(op_counts.find("Cast") == op_counts.cend()); + + EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), "SomeOtherEP"); +} + +TEST(TransformerTest, CpuFp16NonCpuAssignedGemmIsUntouchedWhenEnabled) { + auto model = MakeCpuFp16Model("cpu_fp16_gemm_other_ep", "Gemm", false); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + node->SetExecutionProviderType("SomeOtherEP"); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, false); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_TRUE(op_counts.find("Cast") == op_counts.cend()); + + EXPECT_EQ(node->OpType(), "Gemm"); + EXPECT_EQ(node->GetExecutionProviderType(), "SomeOtherEP"); +} + +TEST(TransformerTest, CpuFp16UnassignedMatMulKeepsFp16WhenEnabled) { + auto model = MakeCpuFp16Model("cpu_fp16_matmul_unassigned", "MatMul", false); + auto& graph = model->MainGraph(); + auto* node = 
graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, false); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.count("Cast"), 0U); + + EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16MatMulHeuristicForcesBertLikeShapeToFp32) { + auto model = MakeCpuFp16MatMulModelWithShapes("cpu_fp16_matmul_heuristic_bert_like", + {128, 512}, {512, 512}, {128, 512}, true); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, true); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.at("Cast"), 3); + EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16MatMulHeuristicKeepsLargeGemvNativeFp16) { + if (!CanRunNativeCpuFp16GemmRuntime()) { + GTEST_SKIP() << "Native CPU fp16 Gemm/MatMul runtime support is unavailable."; + } + + auto model = MakeCpuFp16MatMulModelWithShapes("cpu_fp16_matmul_heuristic_large_gemv", + {1, 4096}, {4096, 4096}, {1, 4096}, true); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, true); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, 
modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.count("Cast"), 0U); + EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16MatMulHeuristicKeepsBatchedLargeGemvNativeFp16) { + if (!CanRunNativeCpuFp16GemmRuntime()) { + GTEST_SKIP() << "Native CPU fp16 Gemm/MatMul runtime support is unavailable."; + } + + auto model = MakeCpuFp16MatMulModelWithShapes("cpu_fp16_matmul_heuristic_batched_large_gemv", + {8, 1, 4096}, {8, 4096, 4096}, {8, 1, 4096}, true); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, true); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.count("Cast"), 0U); + EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16MatMulHeuristicForcesSingletonBroadcastRhsToFp32) { + auto model = MakeCpuFp16MatMulModelWithShapes("cpu_fp16_matmul_heuristic_singleton_broadcast_rhs", + {8, 1, 4096}, {1, 4096, 4096}, {8, 1, 4096}, true); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, true); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.at("Cast"), 3); + 
EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16MatMulHeuristicForcesFlattenedLargeGemvToFp32) { + auto model = MakeCpuFp16MatMulModelWithShapes("cpu_fp16_matmul_heuristic_flattened_large_gemv", + {8, 1, 4096}, {4096, 4096}, {8, 1, 4096}, true); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, true); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.at("Cast"), 3); + EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16MatMulHeuristicKeepsConstantRhsBertLikeShapeNativeFp16) { + if (MlasHalfGemmNativePackBSize(CblasNoTrans, CblasNoTrans, 1024, 512) == 0) { + GTEST_SKIP() << "No native packed-B fp16 MatMul backend is available."; + } + + auto model = MakeCpuFp16MatMulModelWithShapes("cpu_fp16_matmul_heuristic_constant_rhs_bert_like", + {128, 512}, {512, 1024}, {128, 1024}, true, true); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, true); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.count("Cast"), 0U); + EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, 
CpuFp16MatMulHeuristicForcesBatchedConstantRhsToFp32) { + if (MlasHalfGemmNativePackBSize(CblasNoTrans, CblasNoTrans, 1024, 512) == 0) { + GTEST_SKIP() << "No native packed-B fp16 MatMul backend is available."; + } + + auto model = MakeCpuFp16MatMulModelWithShapes("cpu_fp16_matmul_heuristic_batched_constant_rhs", + {2, 1, 512}, {2, 512, 1024}, {2, 1, 1024}, true, true); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, true); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.at("Cast"), 3); + EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16MatMulHeuristicForcesConstantRhsWhenNativePackedBUnavailable) { + auto model = MakeCpuFp16MatMulModelWithShapes("cpu_fp16_matmul_heuristic_constant_rhs_no_native_pack", + {128, 512}, {512, 1024}, {128, 1024}, true, true); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config; + mlas_backend_kernel_selector_config.use_kleidiai = false; + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, true, &mlas_backend_kernel_selector_config); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.at("Cast"), 3); + EXPECT_EQ(node->OpType(), "MatMul"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + 
+TEST(TransformerTest, CpuFp16MatMulSessionConfigUsesFallbackHeuristicKey) { + const auto input_path = ToPathString("cpu_fp16_matmul_heuristic_mlas_config_runtime.onnx"); + ORT_THROW_IF_ERROR(Model::Save(*MakeCpuFp16MatMulModelWithShapes( + "cpu_fp16_matmul_heuristic_mlas_config_runtime", + {128, 512}, {512, 1024}, {128, 1024}, false, true), + input_path)); + + SessionOptions so; + so.session_logid = "cpu_fp16_matmul_heuristic_mlas_config_runtime"; + so.graph_optimization_level = TransformerLevel::Level4; + ORT_THROW_IF_ERROR(so.config_options.AddConfigEntry(kOrtSessionOptionsEnableCpuFp16, "1")); + ORT_THROW_IF_ERROR(so.config_options.AddConfigEntry(kOrtSessionOptionsCpuFp16UseFp32FallbackHeuristic, "1")); + ORT_THROW_IF_ERROR(so.config_options.AddConfigEntry(kOrtSessionOptionsMlasDisableKleidiAi, "1")); + + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(input_path)); + ASSERT_STATUS_OK(session.Initialize()); + + const auto op_counts = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_counts.at("Cast"), 2); + + const Node* matmul_node = FindNodeByOpType(session.GetGraph(), "MatMul"); + ASSERT_NE(matmul_node, nullptr); + + ASSERT_EQ(matmul_node->InputDefs().size(), 2U); + ASSERT_EQ(matmul_node->OutputDefs().size(), 1U); + EXPECT_TRUE(IsNodeArgType(*matmul_node->InputDefs()[0], DataTypeImpl::GetTensorType())); + EXPECT_TRUE(IsNodeArgType(*matmul_node->InputDefs()[1], DataTypeImpl::GetTensorType())); + EXPECT_TRUE(IsNodeArgType(*matmul_node->OutputDefs()[0], DataTypeImpl::GetTensorType())); + + std::error_code ec; + std::filesystem::remove(std::filesystem::path{input_path}, ec); +} + +TEST(TransformerTest, CpuFp16UnsupportedOpStillGetsCastsWhenEnabled) { + auto model = MakeCpuFp16Model("cpu_fp16_abs_unassigned", "Abs", false); + auto& graph = model->MainGraph(); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, false); + bool modified = false; + 
EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_TRUE(op_counts.find("Cast") != op_counts.cend()); + EXPECT_EQ(op_counts.at("Cast"), 2); + + const Node* abs_node = nullptr; + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "Abs") { + abs_node = &node; + break; + } + } + ASSERT_NE(abs_node, nullptr); + EXPECT_EQ(abs_node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16UnsupportedCpuAssignedOpStillGetsCasts) { + auto model = MakeCpuFp16Model("cpu_fp16_abs_cpu_assigned", "Abs", true); + auto& graph = model->MainGraph(); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + false, true); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.at("Cast"), 2); + + const Node* abs_node = FindNodeByOpType(graph, "Abs"); + ASSERT_NE(abs_node, nullptr); + EXPECT_EQ(abs_node->GetExecutionProviderType(), kCpuExecutionProvider); + ASSERT_EQ(abs_node->InputDefs().size(), 1U); + ASSERT_EQ(abs_node->OutputDefs().size(), 1U); + EXPECT_TRUE(IsNodeArgType(*abs_node->InputDefs()[0], DataTypeImpl::GetTensorType())); + EXPECT_TRUE(IsNodeArgType(*abs_node->OutputDefs()[0], DataTypeImpl::GetTensorType())); +} + +TEST(TransformerTest, CpuFp16CpuAssignedResizeWithoutFp16KernelUsesFp32Fallback) { + auto model = MakeCpuFp16ResizeModel("cpu_fp16_resize_cpu_assigned", true); + auto& graph = model->MainGraph(); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, false); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + 
EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.at("Cast"), 2); + + const Node* resize_node = FindNodeByOpType(graph, "Resize"); + ASSERT_NE(resize_node, nullptr); + EXPECT_EQ(resize_node->SinceVersion(), 13); + EXPECT_EQ(resize_node->GetExecutionProviderType(), kCpuExecutionProvider); + ASSERT_EQ(resize_node->InputDefs().size(), 3U); + ASSERT_EQ(resize_node->OutputDefs().size(), 1U); + EXPECT_TRUE(IsNodeArgType(*resize_node->InputDefs()[0], DataTypeImpl::GetTensorType())); + EXPECT_TRUE(IsNodeArgType(*resize_node->InputDefs()[1], DataTypeImpl::GetTensorType())); + EXPECT_TRUE(IsNodeArgType(*resize_node->InputDefs()[2], DataTypeImpl::GetTensorType())); + EXPECT_TRUE(IsNodeArgType(*resize_node->OutputDefs()[0], DataTypeImpl::GetTensorType())); +} + +TEST(TransformerTest, CpuFp16SupportedCpuOpKeepsFp16WhenEnabled) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP() << "CPU fp16 kernels are not registered on this platform."; + } + + auto model = MakeCpuFp16Model("cpu_fp16_add_unassigned", "Add", false); + auto& graph = model->MainGraph(); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + true, false); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.count("Cast"), 0U); + + const Node* add_node = nullptr; + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "Add") { + add_node = &node; + break; + } + } + ASSERT_NE(add_node, nullptr); + EXPECT_EQ(add_node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16CpuAssignedNewOptInOpsUseFp32FallbackWhenDisabled) { + const std::vector> test_cases{ + {"MatMul", 3}, + {"Gemm", 4}, + }; + + for (const auto& [op_type, expected_cast_count] : test_cases) { + auto model = 
MakeCpuFp16Model("cpu_fp16_" + op_type + "_disabled_cpu_assigned", op_type, true); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + false, true); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.at("Cast"), expected_cast_count); + EXPECT_EQ(node->OpType(), op_type); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); + } +} + +TEST(TransformerTest, CpuFp16CpuAssignedExistingFp16OpHasNoExtraCastsWhenDisabled) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP() << "CPU fp16 kernels are not registered on this platform."; + } + + auto model = MakeCpuFp16Model("cpu_fp16_add_disabled_cpu_assigned", "Add", true); + auto& graph = model->MainGraph(); + auto* node = graph.Nodes().begin().operator->(); + ASSERT_NE(node, nullptr); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get(), + false, true); + bool modified = false; + EXPECT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); + EXPECT_STATUS_OK(graph.Resolve()); + + const auto op_counts = CountOpsInGraph(graph); + EXPECT_EQ(op_counts.count("Cast"), 0U); + EXPECT_EQ(node->OpType(), "Add"); + EXPECT_EQ(node->GetExecutionProviderType(), kCpuExecutionProvider); +} + +TEST(TransformerTest, CpuFp16MatMulRuntimeRunsWhenEnabled) { + if (!CanRunNativeCpuFp16GemmRuntime()) { + GTEST_SKIP() << "Native CPU fp16 MatMul runtime is not available on this platform."; + } + + const auto output = RunCpuFp16Model("cpu_fp16_matmul_runtime", "MatMul", true); + const std::vector expected{ + MLFloat16(19.0f), MLFloat16(22.0f), + MLFloat16(43.0f), MLFloat16(50.0f)}; + 
ExpectMlFloat16Output(output, expected); +} + +TEST(TransformerTest, CpuFp16GemmRuntimeRunsWhenEnabled) { + if (!CanRunNativeCpuFp16GemmRuntime()) { + GTEST_SKIP() << "Native CPU fp16 Gemm runtime is not available on this platform."; + } + + const auto output = RunCpuFp16Model("cpu_fp16_gemm_runtime", "Gemm", true); + const std::vector expected{ + MLFloat16(20.0f), MLFloat16(23.0f), + MLFloat16(44.0f), MLFloat16(51.0f)}; + ExpectMlFloat16Output(output, expected); +} + +TEST(TransformerTest, CpuFp16MatMulRuntimeRunsWhenDisabled) { + const auto output = RunCpuFp16Model("cpu_fp16_matmul_runtime", "MatMul", false); + const std::vector expected{ + MLFloat16(19.0f), MLFloat16(22.0f), + MLFloat16(43.0f), MLFloat16(50.0f)}; + ExpectMlFloat16Output(output, expected); +} + +TEST(TransformerTest, CpuFp16GemmRuntimeRunsWhenDisabled) { + const auto output = RunCpuFp16Model("cpu_fp16_gemm_runtime", "Gemm", false); + const std::vector expected{ + MLFloat16(20.0f), MLFloat16(23.0f), + MLFloat16(44.0f), MLFloat16(51.0f)}; + ExpectMlFloat16Output(output, expected); +} + +TEST(TransformerTest, CpuFp16MixedEpMatMulStaysOnNonCpuEpWhenEnabled) { + const auto input_path = ToPathString("cpu_fp16_matmul_mixed_ep_runtime.onnx"); + ORT_THROW_IF_ERROR(Model::Save(*MakeCpuFp16Model("cpu_fp16_matmul_mixed_ep_runtime", "MatMul", false), input_path)); + + SessionOptions so; + so.session_logid = "cpu_fp16_matmul_mixed_ep_runtime"; + so.graph_optimization_level = TransformerLevel::Level4; + ORT_THROW_IF_ERROR(so.config_options.AddConfigEntry(kOrtSessionOptionsEnableCpuFp16, "1")); + ORT_THROW_IF_ERROR(so.config_options.AddConfigEntry(kOrtSessionOptionsCpuFp16UseFp32FallbackHeuristic, "0")); + + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.RegisterExecutionProvider( + std::make_unique(std::unordered_set{"MatMul"}))); + ASSERT_STATUS_OK(session.Load(input_path)); + ASSERT_STATUS_OK(session.Initialize()); + + const auto& graph = session.GetGraph(); + size_t 
internal_testing_nodes = 0; + size_t cpu_nodes = 0; + size_t cast_nodes = 0; + for (const auto& node : graph.Nodes()) { + if (node.GetExecutionProviderType() == internal_testing_ep::kInternalTestingExecutionProvider) { + ++internal_testing_nodes; + } + if (node.GetExecutionProviderType() == kCpuExecutionProvider) { + ++cpu_nodes; + } + if (node.OpType() == "Cast") { + ++cast_nodes; + } + } + + EXPECT_EQ(internal_testing_nodes, 1u); + EXPECT_EQ(cpu_nodes, 0u); + EXPECT_EQ(cast_nodes, 0u); + + std::error_code ec; + std::filesystem::remove(std::filesystem::path{input_path}, ec); +} + TEST(TransformerTest, IsIsolatedFp16NodeOnCpuTest) { auto model = std::make_shared("test", false, DefaultLoggingManager().DefaultLogger()); onnxruntime::Graph& graph = model->MainGraph(); @@ -323,21 +1110,22 @@ TEST(TransformerTest, IsIsolatedFp16NodeOnCpuTest) { o4_def("O4", &tensor_float_16), o5_def("O5", &tensor_float_16); - // for the sake of this example, pretend Clip has no fp16 kernel but Abs does - // -> Clip -> Abs -> Clip -> Abs -> Clip -> + // For the sake of this example, Clip requires fp32 fallback while Identity + // can run fp16 on CPU. + // -> Clip -> Identity -> Clip -> Identity -> Clip -> // | | // - O4 - O5 auto& node1 = graph.AddNode("node1", "Clip", "no fp16", {&i1_def}, {&o1_def}); - auto& node2 = graph.AddNode("node2", "Abs", "fp16", {&o1_def}, {&o2_def}); + auto& node2 = graph.AddNode("node2", "Identity", "fp16", {&o1_def}, {&o2_def}); auto& node3 = graph.AddNode("node3", "Clip", "no fp16", {&o2_def}, {&o3_def}); - auto& node4 = graph.AddNode("node4", "Abs", "fp16 producing graph output", {&o3_def}, {&o4_def}); + auto& node4 = graph.AddNode("node4", "Identity", "fp16 producing graph output", {&o3_def}, {&o4_def}); auto& node5 = graph.AddNode("node5", "Clip", "no fp16", {&o4_def}, {&o5_def}); // manually set outputs as we want O4 and well as O5 to be graph outputs. 
// AddNode creates a NodeArg instance in Graph so need to get address from the node graph.SetOutputs({node4.OutputDefs()[0], node5.OutputDefs()[0]}); - // node2 and node4 have a kernel + // node2 and node4 are pre-assigned to CPU. node2.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider); node4.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider); @@ -355,12 +1143,8 @@ TEST(TransformerTest, IsIsolatedFp16NodeOnCpuTest) { }; // we expect: - // node2 Abs to get forced to fp32 as it's isolated between node1 and node3 which need Casts - // node4 Abs should not get forced to fp32 as it produces a graph output - // - // -> CastFp32 -> Clip -> Abs -> Clip -> CastFp16 -> Abs -> CastFp32 -> Clip -> CastFp16 - // | | - // - O4 - O5 + // node2 Identity to get forced to fp32 as it's isolated between node1 and node3 which need Casts. + // node4 Identity to stay fp16 as it produces graph output O4. EXPECT_TRUE(is_type(*node1.InputDefs()[0], DataTypeImpl::GetTensorType())); EXPECT_TRUE(is_type(*node2.InputDefs()[0], DataTypeImpl::GetTensorType())); EXPECT_TRUE(is_type(*node3.InputDefs()[0], DataTypeImpl::GetTensorType())); @@ -371,6 +1155,31 @@ TEST(TransformerTest, IsIsolatedFp16NodeOnCpuTest) { EXPECT_EQ(ops["Cast"], 4); } +TEST(TransformerTest, CpuAssignedFp16FallbackPreservesGraphOutputType) { + auto model = MakeCpuFp16Model("cpu_fp16_abs_cpu_assigned_graph_output", "Abs", true); + auto& graph = model->MainGraph(); + + InsertCastTransformer transformer("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); + + bool modified = true; + EXPECT_TRUE(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()).IsOK()); + + auto is_type = [](const NodeArg& node_arg, const MLDataType type) { + return node_arg.Type() != nullptr && + DataTypeImpl::TypeFromProto(*node_arg.TypeAsProto()) == type; + }; + + const Node* abs_node = FindNodeByOpType(graph, "Abs"); + ASSERT_NE(abs_node, nullptr); + EXPECT_TRUE(is_type(*abs_node->InputDefs()[0], 
DataTypeImpl::GetTensorType())); + EXPECT_TRUE(is_type(*abs_node->OutputDefs()[0], DataTypeImpl::GetTensorType())); + ASSERT_EQ(graph.GetOutputs().size(), 1U); + EXPECT_TRUE(is_type(*graph.GetOutputs()[0], DataTypeImpl::GetTensorType())); + + auto ops = CountOpsInGraph(graph); + EXPECT_EQ(ops["Cast"], 2); +} + // Verify that RemoveDuplicateCastTransformer does not fuse Cast(float->int32)->Cast(int32->bool) // because the intermediate int32 truncation changes semantics (e.g. -0.1 -> 0 -> false vs -0.1 -> true). // Regression test for https://github.com/microsoft/onnxruntime/issues/28089 diff --git a/onnxruntime/test/mlas/unittest/test_conv2d.cpp b/onnxruntime/test/mlas/unittest/test_conv2d.cpp index 091d4ee833f8f..98bf05c8bd476 100644 --- a/onnxruntime/test/mlas/unittest/test_conv2d.cpp +++ b/onnxruntime/test/mlas/unittest/test_conv2d.cpp @@ -4,6 +4,64 @@ #include "test_conv2d.h" #include "test_conv2d_fixture.h" +TEST(Conv2d_HalfConv, PrepareRespectsBackendSelectorConfig) { + const int64_t input_shape[] = {5, 5}; + const int64_t kernel_shape[] = {3, 3}; + const int64_t dilation_shape[] = {1, 1}; + const int64_t padding[] = {1, 1, 1, 1}; + const int64_t stride_shape[] = {1, 1}; + const int64_t output_shape[] = {5, 5}; + + MLAS_ACTIVATION activation; + activation.ActivationKind = MlasIdentityActivation; + + MLAS_CONV_PARAMETERS parameters{}; + size_t working_buffer_size = 0; + if (!MlasHalfConvPrepare(¶meters, + 2, + 1, + 1, + 4, + input_shape, + kernel_shape, + dilation_shape, + padding, + stride_shape, + output_shape, + 8, + &activation, + &working_buffer_size, + 0.0f, + false, + nullptr, + nullptr)) { + GTEST_SKIP() << "HalfConv prepare path unavailable"; + } + + MLAS_BACKEND_KERNEL_SELECTOR_CONFIG selector_config; + selector_config.use_kleidiai = false; + + MLAS_CONV_PARAMETERS disabled_parameters{}; + EXPECT_FALSE(MlasHalfConvPrepare(&disabled_parameters, + 2, + 1, + 1, + 4, + input_shape, + kernel_shape, + dilation_shape, + padding, + stride_shape, + 
output_shape, + 8, + &activation, + &working_buffer_size, + 0.0f, + false, + nullptr, + &selector_config)); +} + static size_t Conv2dRegistLongExecute() { size_t count = MlasLongExecuteTests>::RegisterLongExecute(); if (GetMlasThreadPool() != nullptr) { diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.cpp b/onnxruntime/test/mlas/unittest/test_halfgemm.cpp index aafdcc14c0028..d350a76f4a399 100644 --- a/onnxruntime/test/mlas/unittest/test_halfgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_halfgemm.cpp @@ -15,6 +15,273 @@ Module Name: --*/ #include "test_halfgemm.h" +#include "core/mlas/lib/halfgemm.h" +#if defined(USE_KLEIDIAI) +#include "core/mlas/lib/mlasi.h" +#include "core/mlas/lib/kleidiai/mlasi_kleidiai.h" +#endif + +#include +#include +#include +#include + +namespace { + +struct HalfGemmPackBPaddingKernel { + static constexpr size_t PackedK = 4; +}; + +} // namespace + +#if defined(USE_KLEIDIAI) +namespace { + +bool g_test_halfgemm_override_called = false; + +bool MLASCALL +TestHalfGemmBatchOverride( + size_t M, + size_t N, + size_t, + size_t BatchN, + const MLAS_HALF_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL*, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG*) { + g_test_halfgemm_override_called = true; + + for (size_t batch = 0; batch < BatchN; ++batch) { + auto* c = DataParams[batch].C; + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + c[m * DataParams[batch].ldc + n] = onnxruntime::MLFloat16(1.0f); + } + } + } + + return true; +} + +void ReferenceHalfGemm( + size_t M, + size_t N, + size_t K, + const MLFp16* A, + const MLFp16* B, + MLFp16* C) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float sum = 0.0f; + for (size_t k = 0; k < K; ++k) { + sum += float(A[m * K + k]) * float(B[k * N + n]); + } + C[m * N + n] = MLFp16(sum); + } + } +} + +struct HalfGemmOverrideGuard { + explicit HalfGemmOverrideGuard(MLAS_HALF_GEMM_BATCH_OVERRIDE* replacement) + : 
original_(GetMlasPlatform().MlasHalfGemmBatchOverride) { + GetMlasPlatform().MlasHalfGemmBatchOverride = replacement; + } + + ~HalfGemmOverrideGuard() { + GetMlasPlatform().MlasHalfGemmBatchOverride = original_; + } + + private: + MLAS_HALF_GEMM_BATCH_OVERRIDE* original_; +}; + +} // namespace + +TEST(HalfGemmKleidiAISelector, DisableKleidiAIBypassesOverride) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP() << "HalfGemm FP16 acceleration not available."; + } + + constexpr size_t M = 5; + constexpr size_t N = 7; + constexpr size_t K = 9; + + std::vector A(M * K); + std::vector B(K * N); + std::vector C(M * N, MLFp16(0.0f)); + std::vector CReference(M * N, MLFp16(0.0f)); + + SmallFloatFill(A.data(), A.size()); + SmallFloatFill(B.data(), B.size()); + + MLAS_HALF_GEMM_DATA_PARAMS data{}; + data.A = A.data(); + data.B = B.data(); + data.C = reinterpret_cast(C.data()); + data.lda = K; + data.ldb = N; + data.ldc = N; + + HalfGemmOverrideGuard guard(TestHalfGemmBatchOverride); + + g_test_halfgemm_override_called = false; + MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, nullptr); + ASSERT_TRUE(g_test_halfgemm_override_called); + for (const auto& value : C) { + ASSERT_EQ(float(value), 1.0f); + } + + std::fill(C.begin(), C.end(), MLFp16(0.0f)); + ReferenceHalfGemm(M, N, K, A.data(), B.data(), CReference.data()); + + MLAS_BACKEND_KERNEL_SELECTOR_CONFIG selector_config; + selector_config.use_kleidiai = false; + + g_test_halfgemm_override_called = false; + MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, &selector_config); + ASSERT_FALSE(g_test_halfgemm_override_called); + + for (size_t i = 0; i < C.size(); ++i) { + ASSERT_TRUE(CloseEnough(float(C[i]), float(CReference[i]))) << "index=" << i; + } +} +#endif + +namespace { + +struct HalfGemmCase { + const char* name; + size_t M; + size_t N; + size_t K; + size_t Batch; + bool has_bias; +}; + +template +void RunHalfGemmCases(const HalfGemmCase* test_cases, size_t num_cases, Runner run_case) { + for (size_t i = 0; i < 
num_cases; ++i) { + const auto& test_case = test_cases[i]; + SCOPED_TRACE(testing::Message() + << test_case.name + << " Batch=" << test_case.Batch + << " M=" << test_case.M + << " N=" << test_case.N + << " K=" << test_case.K + << " hasBias=" << test_case.has_bias); + run_case(test_case); + } +} + +void ReferenceHalfGemmPackedCompatibility( + size_t M, + size_t N, + size_t K, + const MLFp16* A, + const MLFp16* B, + MLFp16* C) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float sum = 0.0f; + for (size_t k = 0; k < K; ++k) { + MLFp16 down(float(A[m * K + k]) * float(B[k * N + n]) + sum); + sum = float(down); + } + C[m * N + n] = MLFp16(sum); + } + } +} + +constexpr HalfGemmCase kNativeFp16Cases[] = { + {"WideNBatch1WithBias", 43, 500, 401, 1, true}, + {"WideNBatch3NoBias", 43, 500, 401, 3, false}, + {"VectorLikeBatch3WithBias", 1, 32, 79, 3, true}, + {"RectangularBatch1WithBias", 64, 48, 80, 1, true}, + {"RectangularBatch1NoBias", 64, 48, 80, 1, false}, +}; + +template +void RunNativeFp16WithoutOutputProcessorCases(const HalfGemmCase* test_cases, size_t num_cases) { + RunHalfGemmCases(test_cases, num_cases, [](const HalfGemmCase& test_case) { + MlasHalfGemmTest test; + test.TestNativeFp16WithoutOutputProcessor( + test_case.M, test_case.N, test_case.K, test_case.Batch, test_case.has_bias); + }); +} + +constexpr HalfGemmCase kKleidiAIPathNonPackedCases[] = { + {"WideNBatch3WithBias", 43, 500, 401, 3, true}, + {"WideNBatch3NoBias", 43, 500, 401, 3, false}, + {"RectangularBatch1WithBias", 64, 48, 80, 1, true}, + {"RectangularBatch1NoBias", 64, 48, 80, 1, false}, +}; + +constexpr HalfGemmCase kKleidiAIPathPackedCases[] = { + {"WideNBatch1WithBias", 43, 500, 401, 1, true}, + {"WideNBatch1NoBias", 43, 500, 401, 1, false}, + {"RectangularBatch1WithBias", 64, 48, 80, 1, true}, + {"RectangularBatch1NoBias", 64, 48, 80, 1, false}, +}; + +template +void RunKleidiAIWithoutOutputProcessorCases(const HalfGemmCase* test_cases, size_t num_cases) { + 
RunHalfGemmCases(test_cases, num_cases, [](const HalfGemmCase& test_case) { + MlasHalfGemmTest test; + test.TestKleidiAIWithoutOutputProcessor( + test_case.M, test_case.N, test_case.K, test_case.Batch, test_case.has_bias); + }); +} + +} // namespace + +TEST(HalfGemm, ZeroKInitializesBiasAndRunsOutputProcessor) { + constexpr size_t M = 3; + constexpr size_t N = 4; + constexpr size_t K = 0; + + std::vector Bias{MLFp16(1.0f), MLFp16(-2.0f), MLFp16(3.5f), MLFp16(0.25f)}; + std::vector C(M * N, MLFp16(-9.0f)); + std::vector CFloat(M * N, -9.0f); + + MLAS_ACTIVATION act; + act.ActivationKind = MlasIdentityActivation; + MLAS_HALF_GEMM_2FLOAT_PROCESSOR output_processor(act, CFloat.data(), N); + + MLAS_HALF_GEMM_DATA_PARAMS data{}; + data.Bias = reinterpret_cast(Bias.data()); + data.C = reinterpret_cast(C.data()); + data.ldc = N; + data.AIsfp32 = true; + data.OutputProcessor = &output_processor; + + MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, nullptr); + + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + const size_t index = m * N + n; + ASSERT_EQ(float(C[index]), float(Bias[n])) << "index=" << index; + ASSERT_EQ(CFloat[index], float(Bias[n])) << "index=" << index; + } + } +} + +TEST(HalfGemm, ZeroKInitializesZeroWithoutBias) { + constexpr size_t M = 3; + constexpr size_t N = 4; + constexpr size_t K = 0; + + std::vector C(M * N, MLFp16(-9.0f)); + + MLAS_HALF_GEMM_DATA_PARAMS data{}; + data.C = reinterpret_cast(C.data()); + data.ldc = N; + data.BIsfp32 = true; + + MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, nullptr); + + for (const auto& value : C) { + ASSERT_EQ(float(value), 0.0f); + } +} // // Short Execute() test helper to register each test separately by all parameters. 
@@ -166,3 +433,394 @@ static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_exe } return HalfGemmRegistLongExecute() > 0; }); + +TEST(HalfGemmPackB, ReturnsZeroOnOverflow) { + const size_t max = (std::numeric_limits::max)(); + EXPECT_EQ(MlasHalfGemmPackBSize(1, max, true), size_t{0}); + EXPECT_EQ(MlasHalfGemmPackBSize(max, 2, true), size_t{0}); +} + +TEST(HalfGemmPackB, CopyPackBZeroPadsAlignedKTail) { + constexpr size_t N = 5; + constexpr size_t K = 3; + constexpr size_t AlignedK = 4; + + std::vector<_mlas_fp16_> b(K * N); + for (size_t i = 0; i < b.size(); ++i) { + b[i] = static_cast<_mlas_fp16_>(i + 1); + } + + constexpr _mlas_fp16_ stale_tail_value = 0xFFFF; + std::vector<_mlas_fp16_> packed(AlignedK * N, stale_tail_value); + + MlasHalfGemmCopyPackB( + packed.data(), + b.data(), + N, + N, + K); + + for (size_t n = 0; n < N; ++n) { + EXPECT_EQ(packed[K * N + n], 0) << "n=" << n; + } +} + +TEST(HalfGemmPackB, GenericPackedBFlagRunsOnFallback) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + + constexpr size_t M = 5; + constexpr size_t N = 7; + constexpr size_t K = 9; + + std::vector A(M * K); + std::vector B(K * N); + std::vector C(M * N, MLFp16(0.0f)); + std::vector CReference(M * N, MLFp16(0.0f)); + + SmallFloatFill(A.data(), A.size()); + SmallFloatFill(B.data(), B.size()); + + const size_t packed_b_size = MlasHalfGemmPackBSize(N, K, false); + if (packed_b_size == 0) { + GTEST_SKIP(); + } + + std::vector packed_b(packed_b_size); + MlasHalfGemmPackB(N, K, reinterpret_cast(B.data()), N, packed_b.data()); + + MLAS_HALF_GEMM_DATA_PARAMS data{}; + data.A = A.data(); + data.B = packed_b.data(); + data.C = reinterpret_cast(C.data()); + data.lda = K; + data.ldb = 0; + data.ldc = N; + data.BIsPacked = true; + + MLAS_BACKEND_KERNEL_SELECTOR_CONFIG selector_config{}; + selector_config.use_kleidiai = false; + + MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, &selector_config); + ReferenceHalfGemmPackedCompatibility(M, N, K, A.data(), 
B.data(), CReference.data()); + + for (size_t i = 0; i < C.size(); ++i) { + ASSERT_TRUE(CloseEnough(float(C[i]), float(CReference[i]))) << "index=" << i; + } +} + +TEST(HalfGemmKleidiAINativeFp16, NoPackSingleThreadWithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + + MlasHalfGemmTest test; + test.TestNativeFp16WithoutOutputProcessor(43, 500, 401, 1, true); +} + +TEST(HalfGemmKleidiAINativeFp16, NoPackSingleThreadWithoutOutputProcessorBatch3) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + + MlasHalfGemmTest test; + test.TestNativeFp16WithoutOutputProcessor(43, 500, 401, 3, true); +} + +TEST(HalfGemmKleidiAINativeFp16, NoPackThreadedWithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + if (GetMlasThreadPool() == nullptr) { + GTEST_SKIP(); + } + + MlasHalfGemmTest test; + test.TestNativeFp16WithoutOutputProcessor(43, 500, 401, 1, true); +} + +TEST(HalfGemmKleidiAINativeFp16, NoPackThreadedWithoutOutputProcessorBatch3) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + if (GetMlasThreadPool() == nullptr) { + GTEST_SKIP(); + } + + MlasHalfGemmTest test; + test.TestNativeFp16WithoutOutputProcessor(43, 500, 401, 3, true); +} + +TEST(HalfGemmKleidiAINativeFp16, NoPackSingleThreadWithoutOutputProcessorVariedShapesAndBias) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + + RunNativeFp16WithoutOutputProcessorCases(kNativeFp16Cases, std::size(kNativeFp16Cases)); +} + +TEST(HalfGemmKleidiAINativeFp16, NoPackThreadedWithoutOutputProcessorVariedShapesAndBias) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + if (GetMlasThreadPool() == nullptr) { + GTEST_SKIP(); + } + + RunNativeFp16WithoutOutputProcessorCases(kNativeFp16Cases, std::size(kNativeFp16Cases)); +} + +TEST(HalfGemmKleidiAIPath, Fp32AConversionSingleThreadBatch3WithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + + MlasHalfGemmTest test; + 
test.TestKleidiAIWithoutOutputProcessor(43, 500, 401, 3, true); +} + +TEST(HalfGemmKleidiAIPath, Fp32AConversionSingleThreadVariedShapesAndBiasWithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + + RunKleidiAIWithoutOutputProcessorCases( + kKleidiAIPathNonPackedCases, std::size(kKleidiAIPathNonPackedCases)); +} + +TEST(HalfGemmKleidiAIPath, Fp32BConversionSingleThreadBatch3WithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + + MlasHalfGemmTest test; + test.TestKleidiAIWithoutOutputProcessor(43, 500, 401, 3, true); +} + +TEST(HalfGemmKleidiAIPath, Fp32BConversionSingleThreadVariedShapesAndBiasWithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + + RunKleidiAIWithoutOutputProcessorCases( + kKleidiAIPathNonPackedCases, std::size(kKleidiAIPathNonPackedCases)); +} + +TEST(HalfGemmKleidiAIPath, Fp32ABConversionSingleThreadBatch3WithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + + MlasHalfGemmTest test; + test.TestKleidiAIWithoutOutputProcessor(43, 500, 401, 3, true); +} + +TEST(HalfGemmKleidiAIPath, Fp32ABConversionSingleThreadVariedShapesAndBiasWithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + + RunKleidiAIWithoutOutputProcessorCases( + kKleidiAIPathNonPackedCases, std::size(kKleidiAIPathNonPackedCases)); +} + +TEST(HalfGemmKleidiAIPath, PackedBFp16SingleThreadWithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + if (MlasHalfGemmPackBSize(128, 128, false) == 0) { + GTEST_SKIP(); + } + + MlasHalfGemmTest test; + test.TestKleidiAIWithoutOutputProcessor(43, 500, 401, 1, true); +} + +TEST(HalfGemmKleidiAIPath, PackedBFp16SingleThreadVariedShapesAndBiasWithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + if (MlasHalfGemmPackBSize(128, 128, false) == 0) { + GTEST_SKIP(); + } + + RunKleidiAIWithoutOutputProcessorCases( 
+ kKleidiAIPathPackedCases, std::size(kKleidiAIPathPackedCases)); +} + +TEST(HalfGemmKleidiAIPath, PackedBFloatSingleThreadWithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + if (MlasHalfGemmPackBSize(128, 128, true) == 0) { + GTEST_SKIP(); + } + + MlasHalfGemmTest test; + test.TestKleidiAIWithoutOutputProcessor(43, 500, 401, 1, true); +} + +TEST(HalfGemmKleidiAIPath, PackedBFloatSingleThreadVariedShapesAndBiasWithoutOutputProcessor) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + if (MlasHalfGemmPackBSize(128, 128, true) == 0) { + GTEST_SKIP(); + } + + RunKleidiAIWithoutOutputProcessorCases( + kKleidiAIPathPackedCases, std::size(kKleidiAIPathPackedCases)); +} + +#if defined(USE_KLEIDIAI) +// KleidiAI-specific packed-B uses a separate direct-consumption contract from +// generic halfgemm PackB. Unsupported combinations fail at the public API +// boundary because generic MLAS cannot consume this backend-native layout. +TEST(HalfGemmKleidiAIPath, KleidiAIPackedBWithBiasThrows) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + if (GetMlasPlatform().MlasHalfGemmBatchOverride == nullptr) { + GTEST_SKIP() << "KleidiAI halfgemm override unavailable"; + } + + constexpr size_t M = 5; + constexpr size_t N = 7; + constexpr size_t K = 9; + + std::vector A(M * K); + std::vector B(K * N); + std::vector Bias(N); + std::vector C(M * N, MLFp16(0.0f)); + + SmallFloatFill(A.data(), A.size()); + SmallFloatFill(B.data(), B.size()); + SmallFloatFill(Bias.data(), Bias.size()); + + const size_t packed_b_size = ArmKleidiAI::MlasHalfGemmKleidiAIPackBSize(CblasNoTrans, CblasNoTrans, N, K); + ASSERT_NE(packed_b_size, size_t{0}); + + std::vector packed_b(packed_b_size); + ASSERT_TRUE(ArmKleidiAI::MlasHalfGemmKleidiAIPackB( + CblasNoTrans, CblasNoTrans, N, K, reinterpret_cast(B.data()), N, packed_b.data())); + + MLAS_HALF_GEMM_DATA_PARAMS data{}; + data.A = A.data(); + data.B = packed_b.data(); + data.Bias = 
reinterpret_cast(Bias.data()); + data.C = reinterpret_cast(C.data()); + data.lda = K; + data.ldb = 0; + data.ldc = N; + data.BIsBackendNativePacked = true; + + ASSERT_FALSE(ArmKleidiAI::MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, nullptr)); +#if !defined(ORT_NO_EXCEPTIONS) + EXPECT_THROW(MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, nullptr), std::runtime_error); +#endif +} + +TEST(HalfGemmKleidiAIPath, KleidiAIPackedBWithOutputProcessorThrows) { + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP(); + } + if (GetMlasPlatform().MlasHalfGemmBatchOverride == nullptr) { + GTEST_SKIP() << "KleidiAI halfgemm override unavailable"; + } + + constexpr size_t M = 5; + constexpr size_t N = 7; + constexpr size_t K = 9; + + std::vector A(M * K); + std::vector B(K * N); + std::vector C(M * N, MLFp16(0.0f)); + std::vector CFloat(M * N, 0.0f); + + SmallFloatFill(A.data(), A.size()); + SmallFloatFill(B.data(), B.size()); + + const size_t packed_b_size = ArmKleidiAI::MlasHalfGemmKleidiAIPackBSize(CblasNoTrans, CblasNoTrans, N, K); + ASSERT_NE(packed_b_size, size_t{0}); + + std::vector packed_b(packed_b_size); + ASSERT_TRUE(ArmKleidiAI::MlasHalfGemmKleidiAIPackB( + CblasNoTrans, CblasNoTrans, N, K, reinterpret_cast(B.data()), N, packed_b.data())); + + MLAS_ACTIVATION act; + act.ActivationKind = MlasIdentityActivation; + MLAS_HALF_GEMM_2FLOAT_PROCESSOR output_processor(act, CFloat.data(), N); + + MLAS_HALF_GEMM_DATA_PARAMS data{}; + data.A = A.data(); + data.B = packed_b.data(); + data.C = reinterpret_cast(C.data()); + data.lda = K; + data.ldb = 0; + data.ldc = N; + data.BIsBackendNativePacked = true; + data.OutputProcessor = &output_processor; + + ASSERT_FALSE(ArmKleidiAI::MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, nullptr)); +#if !defined(ORT_NO_EXCEPTIONS) + EXPECT_THROW(MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, nullptr), std::runtime_error); +#endif +} + +TEST(HalfGemmKleidiAIPath, ZeroKFallsBack) { + if (GetMlasPlatform().MlasHalfGemmBatchOverride == nullptr) { + 
GTEST_SKIP() << "KleidiAI halfgemm override unavailable"; + } + + constexpr size_t M = 5; + constexpr size_t N = 7; + constexpr size_t K = 0; + + std::vector Bias(N); + std::vector C(M * N, MLFp16(1.0f)); + + SmallFloatFill(Bias.data(), Bias.size()); + + MLAS_HALF_GEMM_DATA_PARAMS data{}; + data.Bias = reinterpret_cast(Bias.data()); + data.C = reinterpret_cast(C.data()); + data.ldc = N; + + const bool handled = ArmKleidiAI::MlasHalfGemmBatch(M, N, K, 1, &data, nullptr, nullptr); + ASSERT_FALSE(handled); +} + +TEST(HalfGemmKleidiAIPath, KleidiAIPackedBSizeRejectsUnsupportedTranspose) { + if (GetMlasPlatform().MlasHalfGemmBatchOverride == nullptr) { + GTEST_SKIP() << "KleidiAI halfgemm override unavailable"; + } + + constexpr size_t N = 7; + constexpr size_t K = 9; + + EXPECT_EQ(ArmKleidiAI::MlasHalfGemmKleidiAIPackBSize(CblasTrans, CblasNoTrans, N, K), size_t{0}); + EXPECT_EQ(ArmKleidiAI::MlasHalfGemmKleidiAIPackBSize(CblasNoTrans, CblasTrans, N, K), size_t{0}); +} + +TEST(HalfGemmKleidiAIPath, KleidiAIPackedBRejectsUnsupportedTranspose) { + if (GetMlasPlatform().MlasHalfGemmBatchOverride == nullptr) { + GTEST_SKIP() << "KleidiAI halfgemm override unavailable"; + } + + constexpr size_t N = 7; + constexpr size_t K = 9; + + std::vector B(K * N); + SmallFloatFill(B.data(), B.size()); + + const size_t packed_b_size = ArmKleidiAI::MlasHalfGemmKleidiAIPackBSize(CblasNoTrans, CblasNoTrans, N, K); + ASSERT_NE(packed_b_size, size_t{0}); + + std::vector packed_b(packed_b_size); + EXPECT_FALSE(ArmKleidiAI::MlasHalfGemmKleidiAIPackB( + CblasTrans, CblasNoTrans, N, K, reinterpret_cast(B.data()), N, packed_b.data())); + EXPECT_FALSE(ArmKleidiAI::MlasHalfGemmKleidiAIPackB( + CblasNoTrans, CblasTrans, N, K, reinterpret_cast(B.data()), N, packed_b.data())); +} +#endif diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.h b/onnxruntime/test/mlas/unittest/test_halfgemm.h index 4db5c2bebca40..c68697aa74b3c 100644 --- a/onnxruntime/test/mlas/unittest/test_halfgemm.h +++ 
b/onnxruntime/test/mlas/unittest/test_halfgemm.h @@ -17,6 +17,9 @@ Module Name: #pragma once #include "test_fp16.h" +#include "core/mlas/lib/mlasi.h" +#include +#include /** * @brief Test class for half precision GEMM @@ -26,6 +29,21 @@ Module Name: template class MlasHalfGemmTest : public MlasTestBase { private: + // Native FP16 is validated against the FP32 reference path + // rather than the existing stepwise-FP16 oracle, so these backend-specific + // tests use a separate tolerance. + static bool CloseEnoughNativeFp16(float got, float ref) { + constexpr float abs_tol = 0.03125f; + constexpr float rel_tol = 0.005f; + + const float diff = std::fabs(got - ref); + if (diff <= abs_tol) { + return true; + } + + return diff <= rel_tol * std::max(std::fabs(got), std::fabs(ref)); + } + MatrixGuardBuffer BufferBPacked; MatrixGuardBuffer BufferA; MatrixGuardBuffer BufferB; @@ -60,7 +78,9 @@ class MlasHalfGemmTest : public MlasTestBase { const MLFp16* Bias, MLFp16* C, size_t ldc, - float* Cfloat) { + float* Cfloat, + bool use_output_processor = true, + bool enforce_kleidiai_override = false) { MLAS_ACTIVATION act; act.ActivationKind = MlasIdentityActivation; std::vector Converters; @@ -84,17 +104,29 @@ class MlasHalfGemmTest : public MlasTestBase { ASSERT_EQ(BatchSize, size_t(1)) << "Packing B not supported in batching yet!"; params.B = PackB(N, K, B, ldb); params.ldb = 0; + params.BIsPacked = true; } else { params.B = B + (K * N * i); params.ldb = ldb; } params.AIsfp32 = std::is_same::value; params.BIsfp32 = std::is_same::value; - Converters.emplace_back(act, Cfloat + (M * N * i), N); - params.OutputProcessor = &(Converters[i]); + if (use_output_processor) { + Converters.emplace_back(act, Cfloat + (M * N * i), N); + params.OutputProcessor = &(Converters[i]); + } else { + params.OutputProcessor = nullptr; + } } - MlasHalfGemmBatch(M, N, K, BatchSize, GemmParameters.data(), threadpool_); + if (enforce_kleidiai_override) { + 
ASSERT_NE(GetMlasPlatform().MlasHalfGemmBatchOverride, nullptr); + const bool handled = GetMlasPlatform().MlasHalfGemmBatchOverride( + M, N, K, BatchSize, GemmParameters.data(), threadpool_, nullptr); + ASSERT_TRUE(handled); + } else { + MlasHalfGemmBatch(M, N, K, BatchSize, GemmParameters.data(), threadpool_, nullptr); + } } void ReferenceQgemm(size_t M, @@ -153,6 +185,60 @@ class MlasHalfGemmTest : public MlasTestBase { } } + void ReferenceMlasGemmFp32(size_t M, + size_t N, + size_t K, + size_t BatchSize, + const AType* A, + const BType* B, + const MLFp16* Bias, + float* C) { + MatrixGuardBuffer buffer_a_fp32{}; + MatrixGuardBuffer buffer_b_fp32{}; + MatrixGuardBuffer buffer_c_fp32{}; + + float* AFloat = buffer_a_fp32.GetBuffer(M * K * BatchSize); + float* BFloat = buffer_b_fp32.GetBuffer(K * N * BatchSize); + float* CFloat = buffer_c_fp32.GetBuffer(M * N * BatchSize, true); + + for (size_t i = 0; i < M * K * BatchSize; ++i) { + AFloat[i] = float(A[i]); + } + + for (size_t i = 0; i < K * N * BatchSize; ++i) { + BFloat[i] = float(B[i]); + } + + for (size_t batch = 0; batch < BatchSize; ++batch) { + MlasGemm( + CblasNoTrans, CblasNoTrans, M, N, K, + 1.0f, + AFloat + batch * (M * K), K, + BFloat + batch * (K * N), N, + 0.0f, + CFloat + batch * (M * N), N, + threadpool_, + nullptr); + } + + for (size_t batch = 0; batch < BatchSize; batch++) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + const size_t idx = (M * N * batch) + (m * N) + n; + float sum = CFloat[idx]; + if (Bias != nullptr) { + sum += float(Bias[n]); + } + C[idx] = float(MLFp16(sum)); + } + } + + if (Bias) { + Bias += N; + } + } + } + public: MlasHalfGemmTest() : threadpool_(Threaded ? 
GetMlasThreadPool() : nullptr) {} @@ -200,6 +286,56 @@ class MlasHalfGemmTest : public MlasTestBase { } } + void TestKleidiAIWithoutOutputProcessor(size_t M, size_t N, size_t K, size_t BatchSize, bool withBias) { + if (GetMlasPlatform().MlasHalfGemmBatchOverride == nullptr) { + GTEST_SKIP() << "KleidiAI halfgemm override unavailable"; + } + + const AType* A = BufferA.GetFilledBuffer(K * M * BatchSize + 16, SmallFloatFill); + const BType* B = BufferB.GetFilledBuffer(N * K * BatchSize + 16, SmallFloatFill); + + const MLFp16* Bias = nullptr; + if (withBias) { + Bias = BufferBias.GetFilledBuffer(N * BatchSize + 16, SmallFloatFill); + } + + MLFp16* C = BufferC.GetFilledBuffer(N * M * BatchSize, SmallFloatFill); + float* Cfloat = BufferFloatC.GetBuffer(N * M * BatchSize, true); + float* CReference = BufferCReference.GetFilledBuffer( + N * M * BatchSize, + [](float* start, size_t size) { + std::fill_n(start, size, -1.0f); + }); + MatrixGuardBuffer buffer_fp32_reference{}; + float* CFp32Reference = buffer_fp32_reference.GetFilledBuffer( + N * M * BatchSize, + [](float* start, size_t size) { + std::fill_n(start, size, -1.0f); + }); + + this->CallGemm(M, N, K, BatchSize, A, K, B, N, Bias, C, N, Cfloat, false, true); + ReferenceQgemm(M, N, K, BatchSize, A, B, Bias, CReference); + ReferenceMlasGemmFp32(M, N, K, BatchSize, A, B, Bias, CFp32Reference); + + for (size_t batch = 0, f = 0; batch < BatchSize; batch++) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++, f++) { + ASSERT_TRUE(CloseEnoughNativeFp16(float(C[f]), CFp32Reference[f])) << "@[" << batch << "x" << m << "x" << n << "], " + << "Batch=" << BatchSize << "M=" << M << ", N=" << N << ", K=" << K + << " got=" << float(C[f]) + << " fp32=" << CFp32Reference[f] + << " stepwise=" << CReference[f]; + } + } + } + } + + void TestNativeFp16WithoutOutputProcessor(size_t M, size_t N, size_t K, size_t BatchSize, bool withBias) { + static_assert(std::is_same_v); + static_assert(std::is_same_v); + 
TestKleidiAIWithoutOutputProcessor(M, N, K, BatchSize, withBias); + } + private: public: static const char* GetTestSuiteName() { diff --git a/onnxruntime/test/mlas/unittest/test_util.h b/onnxruntime/test/mlas/unittest/test_util.h index a000e353f370d..8c31ad943e319 100644 --- a/onnxruntime/test/mlas/unittest/test_util.h +++ b/onnxruntime/test/mlas/unittest/test_util.h @@ -44,6 +44,7 @@ class MatrixGuardBuffer { _BaseBuffer = nullptr; _BaseBufferSize = 0; _ElementsAllocated = 0; + _GuardAddress = nullptr; } ~MatrixGuardBuffer(void) { @@ -150,6 +151,7 @@ class MatrixGuardBuffer { } _ElementsAllocated = 0; + _GuardAddress = nullptr; } private: diff --git a/onnxruntime/test/optimizer/nhwc_transformer_test.cc b/onnxruntime/test/optimizer/nhwc_transformer_test.cc index b73929efab8a6..438823aaa6cec 100644 --- a/onnxruntime/test/optimizer/nhwc_transformer_test.cc +++ b/onnxruntime/test/optimizer/nhwc_transformer_test.cc @@ -997,7 +997,10 @@ TEST_F(NhwcTransformerTestsFp16, FusedConvWithSumFp16) { TransformerTester(build_test_case, check_nhwc_graph, TransformerLevel::Level2, - TransformerLevel::Level3); + TransformerLevel::Level3, + /*opset_version*/ 12, + /*per_sample_tolerance*/ 0.02, + /*relative_per_sample_tolerance*/ 0.02); } TEST_F(NhwcTransformerTestsFp16, ConvMaxPoolFp16) { diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index 9effdd7e5fb6e..05f9ac2b18224 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -118,6 +118,52 @@ TEST(GemmOpTest, GemmNoTrans_f16) { ConvertFloatToMLFloat16(A.data(), f_A.data(), 8); ConvertFloatToMLFloat16(B.data(), f_B.data(), 12); + { + // Missing C uses effective beta == 0. 
+ std::vector f_Y(6); + std::vector Y{19.3f, -1.4f, -26.9f, + -19.3f, 1.4f, 26.9f}; + ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); + + OpTester test("Gemm", 13); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 0.0f); + test.AddInput("A", {2, 4}, f_A); + test.AddInput("B", {4, 3}, f_B, true); + test.AddOutput("Y", {2, 3}, f_Y); + test.SetOutputTolerance(0.005f); + test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported + .Config(run_with_tunable_op) + .RunWithConfig(); + } + { + // beta == 0 ignores C, even when C has the full output shape. + std::vector f_Y(6); + std::vector Y{19.3f, -1.4f, -26.9f, + -19.3f, 1.4f, 26.9f}; + ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); + + std::vector f_C(6); + ConvertFloatToMLFloat16(C.data(), f_C.data(), 6); + + OpTester test("Gemm", 13); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 0.0f); + test.AddInput("A", {2, 4}, f_A); + test.AddInput("B", {4, 3}, f_B); + test.AddInput("C", {2, 3}, f_C); + test.AddOutput("Y", {2, 3}, f_Y); + test.SetOutputTolerance(0.005f); + test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported + .Config(run_with_tunable_op) + .RunWithConfig(); + } { // bias has same shape as output std::vector f_Y(6); diff --git a/onnxruntime/test/providers/cpu/math/matmul_test.cc b/onnxruntime/test/providers/cpu/math/matmul_test.cc index f624ecf57d05e..55eef75ffa569 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_test.cc @@ -3,6 +3,8 @@ #include "gtest/gtest.h" +#include "core/mlas/inc/mlas.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "test/providers/provider_test_utils.h" #include "test/common/dnnl_op_test_utils.h" #include 
"test/common/cuda_op_test_utils.h" @@ -687,6 +689,68 @@ TEST(MathOpTest, MatMulBatchedSplitK) { .RunWithConfig(); } +TEST(MathOpTest, MatMulFloat16NativePrepackedWeightsAreNotShared) { + if (MlasHalfGemmNativePackBSize(CblasNoTrans, CblasNoTrans, 3, 4) == 0) { + GTEST_SKIP() << "Native fp16 MatMul prepack unavailable"; + } + + std::vector a_values{1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}; + std::vector b_values(12, 1.0f); + std::vector y_values{10.0f, 10.0f, 10.0f, + -10.0f, -10.0f, -10.0f}; + + std::vector a_fp16(8); + std::vector b_fp16(12); + std::vector y_fp16(6); + ConvertFloatToMLFloat16(a_values.data(), a_fp16.data(), a_fp16.size()); + ConvertFloatToMLFloat16(b_values.data(), b_fp16.data(), b_fp16.size()); + ConvertFloatToMLFloat16(y_values.data(), y_fp16.data(), y_fp16.size()); + + OpTester test("MatMul", 14); + test.AddInput("A", {2, 4}, a_fp16); + test.AddInput("B", {4, 3}, b_fp16, true); + test.AddOutput("Y", {2, 3}, y_fp16); + + OrtValue b; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({4, 3}), + b_fp16.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), b); + + SessionOptions so; + ASSERT_EQ(so.AddInitializer("B", &b), Status::OK()); + ASSERT_EQ(so.config_options.AddConfigEntry(kOrtSessionOptionsEnableCpuFp16, "1"), Status::OK()); + ASSERT_EQ(so.config_options.AddConfigEntry(kOrtSessionOptionsCpuFp16UseFp32FallbackHeuristic, "0"), Status::OK()); + + test.EnableSharingOfPrePackedWeightsAcrossSessions(); + + auto cpu_ep = []() -> std::vector> { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + return execution_providers; + }; + + size_t first_session_prepacked_weights = 0; + size_t shared_prepacked_weights = 0; + + test.Config(so) + .Config(run_with_tunable_op) + .ConfigEps(cpu_ep()) + .RunWithConfig(&first_session_prepacked_weights, &shared_prepacked_weights); + + ASSERT_EQ(shared_prepacked_weights, static_cast(0)); + ASSERT_GT(first_session_prepacked_weights, 
static_cast(0)); + ASSERT_EQ(test.GetNumPrePackedWeightsShared(), static_cast(0)); + + size_t second_session_prepacked_weights = 0; + test.Config(so) + .Config(run_with_tunable_op) + .ConfigEps(cpu_ep()) + .RunWithConfig(&second_session_prepacked_weights, &shared_prepacked_weights); + + ASSERT_EQ(second_session_prepacked_weights, first_session_prepacked_weights); + ASSERT_EQ(shared_prepacked_weights, static_cast(0)); +} + #endif } // namespace test diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 843d925ed6638..03b1eac1e91bf 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -10,6 +10,7 @@ #include "test/common/random_generator.h" #include "test/providers/provider_test_utils.h" #include "default_providers.h" +#include "core/session/onnxruntime_session_options_config_keys.h" using namespace std; namespace onnxruntime { @@ -47,7 +48,8 @@ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", int opset = 11, - float rel_error = 0.002f) { + float rel_error = 0.002f, + bool disable_kleidiai = false) { std::unique_ptr tester; if (!attributes.activation.empty()) { tester = std::make_unique("NhwcFusedConv", 1, onnxruntime::kMSDomain); @@ -89,6 +91,12 @@ void TestConvFp16Op(const ConvOpAndTestAttributes& attributes, tester->AddOutput("Y", expected_output_shape, expected_output, /*no sort*/ false, rel_error, 0.0f); + if (disable_kleidiai) { + SessionOptions session_options; + ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsMlasDisableKleidiAi, "1")); + tester->Config(session_options); + } + std::unordered_set excluded_providers(attributes.excluded_providers); // Disable TensorRT because weight as input is not supported excluded_providers.insert(kTensorrtExecutionProvider); @@ 
-369,6 +377,158 @@ TEST(ConvFp16Test, Conv2D_1) { TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } +TEST(ConvFp16Test, Conv2D_KleidiAiImatmulEligibleNoBias) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{3, 3}, // kernel_shape + vector{1, 1, 1, 1}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = { + MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f), MLFloat16(4.0f), + MLFloat16(5.0f), MLFloat16(6.0f), MLFloat16(7.0f), MLFloat16(8.0f), + MLFloat16(9.0f), MLFloat16(10.0f), MLFloat16(11.0f), MLFloat16(12.0f), + MLFloat16(13.0f), MLFloat16(14.0f), MLFloat16(15.0f), MLFloat16(16.0f)}; + vector X_shape = {1, 1, 4, 4}; + vector W = { + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f)}; + vector W_shape = {2, 1, 3, 3}; + vector Y_shape = {1, 2, 4, 4}; + auto expected_vals = { + MLFloat16(14.0f), MLFloat16(24.0f), MLFloat16(30.0f), MLFloat16(22.0f), + MLFloat16(33.0f), MLFloat16(54.0f), MLFloat16(63.0f), MLFloat16(45.0f), + MLFloat16(57.0f), MLFloat16(90.0f), MLFloat16(99.0f), MLFloat16(69.0f), + MLFloat16(46.0f), MLFloat16(72.0f), MLFloat16(78.0f), MLFloat16(54.0f), + MLFloat16(7.0f), MLFloat16(12.0f), MLFloat16(15.0f), MLFloat16(11.0f), + MLFloat16(16.5f), MLFloat16(27.0f), MLFloat16(31.5f), MLFloat16(22.5f), + MLFloat16(28.5f), MLFloat16(45.0f), MLFloat16(49.5f), MLFloat16(34.5f), + MLFloat16(23.0f), MLFloat16(36.0f), MLFloat16(39.0f), MLFloat16(27.0f)}; + + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +} + +TEST(ConvFp16Test, 
Conv2D_KleidiAiImatmulEligibleBiasAndDisabledFallback) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{3, 3}, // kernel_shape + vector{1, 1, 1, 1}, // pads + vector{1, 1}, // strides + {} // excluded EPs + }; + + vector X = { + MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f), MLFloat16(4.0f), + MLFloat16(5.0f), MLFloat16(6.0f), MLFloat16(7.0f), MLFloat16(8.0f), + MLFloat16(9.0f), MLFloat16(10.0f), MLFloat16(11.0f), MLFloat16(12.0f), + MLFloat16(13.0f), MLFloat16(14.0f), MLFloat16(15.0f), MLFloat16(16.0f)}; + vector X_shape = {1, 1, 4, 4}; + vector W = { + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f)}; + vector W_shape = {2, 1, 3, 3}; + vector B = {MLFloat16(1.0f), MLFloat16(-2.0f)}; + vector B_shape = {2}; + vector Y_shape = {1, 2, 4, 4}; + auto expected_vals = { + MLFloat16(15.0f), MLFloat16(25.0f), MLFloat16(31.0f), MLFloat16(23.0f), + MLFloat16(34.0f), MLFloat16(55.0f), MLFloat16(64.0f), MLFloat16(46.0f), + MLFloat16(58.0f), MLFloat16(91.0f), MLFloat16(100.0f), MLFloat16(70.0f), + MLFloat16(47.0f), MLFloat16(73.0f), MLFloat16(79.0f), MLFloat16(55.0f), + MLFloat16(5.0f), MLFloat16(10.0f), MLFloat16(13.0f), MLFloat16(9.0f), + MLFloat16(14.5f), MLFloat16(25.0f), MLFloat16(29.5f), MLFloat16(20.5f), + MLFloat16(26.5f), MLFloat16(43.0f), MLFloat16(47.5f), MLFloat16(32.5f), + MLFloat16(21.0f), MLFloat16(34.0f), MLFloat16(37.0f), MLFloat16(25.0f)}; + + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape); + + constexpr bool weight_is_initializer = true; + TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, weight_is_initializer); + + constexpr bool disable_kleidiai = true; + 
TestConvFp16Op(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, weight_is_initializer, + OpTester::ExpectResult::kExpectSuccess, "", 11, 0.002f, disable_kleidiai); +} + +TEST(ConvFp16Test, NhwcFusedConv2D_KleidiAiImatmulEligibleBiasAndDisabledFallback) { +#if !defined(__aarch64__) && !defined(_M_ARM64) + GTEST_SKIP() << "Native CPU fp16 Conv runtime support is only tested on Arm64."; +#else + if (!MlasFp16AccelerationSupported()) { + GTEST_SKIP() << "Native CPU fp16 Conv runtime support is unavailable."; + } +#endif + + auto run_test = [](bool disable_kleidiai) { + OpTester test("NhwcFusedConv", 1, onnxruntime::kMSDomain); + test.AddAttribute("group", static_cast(1)); + test.AddAttribute("kernel_shape", vector{3, 3}); + test.AddAttribute("pads", vector{1, 1, 1, 1}); + test.AddAttribute("strides", vector{1, 1}); + test.AddAttribute("dilations", vector{1, 1}); + + vector X = { + MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f), MLFloat16(4.0f), + MLFloat16(5.0f), MLFloat16(6.0f), MLFloat16(7.0f), MLFloat16(8.0f), + MLFloat16(9.0f), MLFloat16(10.0f), MLFloat16(11.0f), MLFloat16(12.0f), + MLFloat16(13.0f), MLFloat16(14.0f), MLFloat16(15.0f), MLFloat16(16.0f)}; + vector X_shape = {1, 4, 4, 1}; + vector W = { + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f)}; + vector W_shape = {2, 1, 3, 3}; + vector B = {MLFloat16(1.0f), MLFloat16(-2.0f)}; + vector B_shape = {2}; + vector Y_shape = {1, 4, 4, 2}; + auto expected_vals = { + MLFloat16(15.0f), MLFloat16(5.0f), MLFloat16(25.0f), MLFloat16(10.0f), + MLFloat16(31.0f), MLFloat16(13.0f), MLFloat16(23.0f), MLFloat16(9.0f), + MLFloat16(34.0f), MLFloat16(14.5f), MLFloat16(55.0f), MLFloat16(25.0f), + MLFloat16(64.0f), MLFloat16(29.5f), 
MLFloat16(46.0f), MLFloat16(20.5f), + MLFloat16(58.0f), MLFloat16(26.5f), MLFloat16(91.0f), MLFloat16(43.0f), + MLFloat16(100.0f), MLFloat16(47.5f), MLFloat16(70.0f), MLFloat16(32.5f), + MLFloat16(47.0f), MLFloat16(21.0f), MLFloat16(73.0f), MLFloat16(34.0f), + MLFloat16(79.0f), MLFloat16(37.0f), MLFloat16(55.0f), MLFloat16(25.0f)}; + + test.AddInput("X", X_shape, X); + test.AddInput("W", W_shape, W, true); + test.AddInput("B", B_shape, B, true); + test.AddOutput("Y", Y_shape, expected_vals, /*no sort*/ false, 0.002f, 0.0f); + + if (disable_kleidiai) { + SessionOptions session_options; + ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsMlasDisableKleidiAi, "1")); + test.Config(session_options); + } + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.ConfigEps(std::move(execution_providers)).RunWithConfig(); + }; + + run_test(false); + run_test(true); +} + TEST(ConvFp16Test, Conv2D_2) { ConvOpAndTestAttributes attrs = { "", // auto_pad @@ -1496,6 +1656,165 @@ TEST(ConvFp16Test, SharedPrepackedWeights) { } } +TEST(ConvFp16Test, SharedPrepackedWeights_HalfConvEligible_NoBias) { + OpTester test("Conv", 11); + test.AddAttribute("group", static_cast(1)); + test.AddAttribute("kernel_shape", vector{3, 3}); + test.AddAttribute("pads", vector{1, 1, 1, 1}); + test.AddAttribute("strides", vector{1, 1}); + test.AddAttribute("dilations", vector{1, 1}); + + vector X = {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f), MLFloat16(4.0f), + MLFloat16(5.0f), MLFloat16(6.0f), MLFloat16(7.0f), MLFloat16(8.0f), + MLFloat16(9.0f), MLFloat16(10.0f), MLFloat16(11.0f), MLFloat16(12.0f), + MLFloat16(13.0f), MLFloat16(14.0f), MLFloat16(15.0f), MLFloat16(16.0f)}; + vector X_shape = {1, 1, 4, 4}; + vector W = {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(0.5f), + MLFloat16(0.5f), 
MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f)}; + vector W_shape = {2, 1, 3, 3}; + vector Y_shape = {1, 2, 4, 4}; + auto expected_vals = { + MLFloat16(14.0f), MLFloat16(24.0f), MLFloat16(30.0f), MLFloat16(22.0f), MLFloat16(33.0f), MLFloat16(54.0f), + MLFloat16(63.0f), MLFloat16(45.0f), MLFloat16(57.0f), MLFloat16(90.0f), MLFloat16(99.0f), MLFloat16(69.0f), + MLFloat16(46.0f), MLFloat16(72.0f), MLFloat16(78.0f), MLFloat16(54.0f), MLFloat16(7.0f), MLFloat16(12.0f), + MLFloat16(15.0f), MLFloat16(11.0f), MLFloat16(16.5f), MLFloat16(27.0f), MLFloat16(31.5f), MLFloat16(22.5f), + MLFloat16(28.5f), MLFloat16(45.0f), MLFloat16(49.5f), MLFloat16(34.5f), MLFloat16(23.0f), MLFloat16(36.0f), + MLFloat16(39.0f), MLFloat16(27.0f)}; + + test.AddInput("X", X_shape, X); + test.AddInput("W", W_shape, W, true); + test.AddOutput("Y", Y_shape, expected_vals, /*no sort*/ false, 0.002f, 0.0f); + + OrtValue w; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape(W_shape), + W.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), w); + + SessionOptions so; + ASSERT_EQ(so.AddInitializer("W", &w), Status::OK()); + + test.EnableSharingOfPrePackedWeightsAcrossSessions(); + + auto cpu_ep = []() -> std::vector> { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + return execution_providers; + }; + + size_t number_of_pre_packed_weights_counter_session_1 = 0; + size_t number_of_shared_pre_packed_weights_counter = 0; + + { + test.Config(so) + .Config(run_with_tunable_op) + .ConfigEps(cpu_ep()) + .RunWithConfig(&number_of_pre_packed_weights_counter_session_1, &number_of_shared_pre_packed_weights_counter); + ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast(0)); + } + + const auto number_of_elements_in_shared_prepacked_buffers_container = test.GetNumPrePackedWeightsShared(); + + if (number_of_pre_packed_weights_counter_session_1 == 0) { + 
GTEST_SKIP() << "No pre-packed weights were produced."; + } + + ASSERT_EQ(number_of_elements_in_shared_prepacked_buffers_container, static_cast(1)); + + { + size_t number_of_pre_packed_weights_counter_session_2 = 0; + test.Config(so) + .Config(run_with_tunable_op) + .ConfigEps(cpu_ep()) + .RunWithConfig(&number_of_pre_packed_weights_counter_session_2, &number_of_shared_pre_packed_weights_counter); + + ASSERT_GE(number_of_pre_packed_weights_counter_session_1, number_of_shared_pre_packed_weights_counter); + ASSERT_GE(number_of_pre_packed_weights_counter_session_2, number_of_shared_pre_packed_weights_counter); + ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast(1)); + } +} + +TEST(ConvFp16Test, SharedPrepackedWeights_HalfConvEligible_BiasNotShared) { + OpTester test("Conv", 11); + test.AddAttribute("group", static_cast(1)); + test.AddAttribute("kernel_shape", vector{3, 3}); + test.AddAttribute("pads", vector{1, 1, 1, 1}); + test.AddAttribute("strides", vector{1, 1}); + test.AddAttribute("dilations", vector{1, 1}); + + vector X = {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f), MLFloat16(4.0f), + MLFloat16(5.0f), MLFloat16(6.0f), MLFloat16(7.0f), MLFloat16(8.0f), + MLFloat16(9.0f), MLFloat16(10.0f), MLFloat16(11.0f), MLFloat16(12.0f), + MLFloat16(13.0f), MLFloat16(14.0f), MLFloat16(15.0f), MLFloat16(16.0f)}; + vector X_shape = {1, 1, 4, 4}; + vector W = {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), + MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(0.5f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f), + MLFloat16(0.5f), MLFloat16(0.5f), MLFloat16(0.5f)}; + vector W_shape = {2, 1, 3, 3}; + vector B = {MLFloat16(1.0f), MLFloat16(-2.0f)}; + vector B_shape = {2}; + vector Y_shape = {1, 2, 4, 4}; + auto expected_vals = { + MLFloat16(15.0f), MLFloat16(25.0f), MLFloat16(31.0f), MLFloat16(23.0f), MLFloat16(34.0f), MLFloat16(55.0f), + 
MLFloat16(64.0f), MLFloat16(46.0f), MLFloat16(58.0f), MLFloat16(91.0f), MLFloat16(100.0f), MLFloat16(70.0f), + MLFloat16(47.0f), MLFloat16(73.0f), MLFloat16(79.0f), MLFloat16(55.0f), MLFloat16(5.0f), MLFloat16(10.0f), + MLFloat16(13.0f), MLFloat16(9.0f), MLFloat16(14.5f), MLFloat16(25.0f), MLFloat16(29.5f), MLFloat16(20.5f), + MLFloat16(26.5f), MLFloat16(43.0f), MLFloat16(47.5f), MLFloat16(32.5f), MLFloat16(21.0f), MLFloat16(34.0f), + MLFloat16(37.0f), MLFloat16(25.0f)}; + + test.AddInput("X", X_shape, X); + test.AddInput("W", W_shape, W, true); + test.AddInput("B", B_shape, B, true); + test.AddOutput("Y", Y_shape, expected_vals, /*no sort*/ false, 0.002f, 0.0f); + + OrtValue w; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape(W_shape), + W.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), w); + + SessionOptions so; + ASSERT_EQ(so.AddInitializer("W", &w), Status::OK()); + + test.EnableSharingOfPrePackedWeightsAcrossSessions(); + + auto cpu_ep = []() -> std::vector> { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + return execution_providers; + }; + + size_t number_of_pre_packed_weights_counter_session_1 = 0; + size_t number_of_shared_pre_packed_weights_counter = 0; + + { + test.Config(so) + .Config(run_with_tunable_op) + .ConfigEps(cpu_ep()) + .RunWithConfig(&number_of_pre_packed_weights_counter_session_1, &number_of_shared_pre_packed_weights_counter); + ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast(0)); + } + + const auto number_of_elements_in_shared_prepacked_buffers_container = test.GetNumPrePackedWeightsShared(); + + if (number_of_pre_packed_weights_counter_session_1 == 0) { + GTEST_SKIP() << "No pre-packed weights were produced."; + } + + ASSERT_EQ(number_of_elements_in_shared_prepacked_buffers_container, static_cast(1)); + + { + size_t number_of_pre_packed_weights_counter_session_2 = 0; + test.Config(so) + .Config(run_with_tunable_op) + .ConfigEps(cpu_ep()) + 
.RunWithConfig(&number_of_pre_packed_weights_counter_session_2, &number_of_shared_pre_packed_weights_counter); + + ASSERT_GE(number_of_pre_packed_weights_counter_session_1, number_of_shared_pre_packed_weights_counter); + ASSERT_GE(number_of_pre_packed_weights_counter_session_2, number_of_shared_pre_packed_weights_counter); + ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast(1)); + } +} + #endif } // namespace test diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 4b4d9e9ddd1ba..4714f2c7fb0f9 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" @@ -253,6 +254,14 @@ TYPED_TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear) { // QNN: result diff // TRT: Segmentation fault in A100 std::unordered_set excluded_providers({kQnnExecutionProvider}); + if constexpr (std::is_same_v) { + // These EPs do not provide this Resize(13) MLFloat16 kernel. + excluded_providers.insert(kNnapiExecutionProvider); + excluded_providers.insert(kXnnpackExecutionProvider); + excluded_providers.insert(kCoreMLExecutionProvider); + excluded_providers.insert(kTensorrtExecutionProvider); + excluded_providers.insert("example_ep"); + } test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100(excluded_providers)); };