diff --git a/onnxruntime/core/mlas/lib/convolve.cpp b/onnxruntime/core/mlas/lib/convolve.cpp index 4378ec1948fdb..5dc5eb1281b12 100644 --- a/onnxruntime/core/mlas/lib/convolve.cpp +++ b/onnxruntime/core/mlas/lib/convolve.cpp @@ -1377,7 +1377,7 @@ MlasConvSupportsSymmetricChannelsLast2DFloatKernel( return false; } - if (Dimensions != 2 || BatchCount != 1 || GroupCount != 1 || Beta != 0.0f) { + if (Dimensions != 2 || BatchCount != 1 || Beta != 0.0f) { return false; } @@ -1395,7 +1395,12 @@ MlasConvSupportsSymmetricChannelsLast2DFloatKernel( return false; } - if (FilterCount <= 1 || KernelShape[0] < 3 || KernelShape[1] < 3) { + const bool is_depthwise = GroupCount > 1 && FilterCount == 1; + if (GroupCount > 1 && !is_depthwise) { + return false; + } + + if (!is_depthwise && (FilterCount <= 1 || KernelShape[0] < 3 || KernelShape[1] < 3)) { return false; } diff --git a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp index cca4f5a19c417..1da8530350d75 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp @@ -550,15 +550,40 @@ static void ConvolveSme(const size_t co, //channels out dim[1] = MlasDivRoundup(m, m_step); dim[2] = MlasDivRoundup(co, n_step); + const bool grouped_channels_last = input_is_channels_last && groups > 1; + const size_t input_channels_total = ci * groups; + const size_t output_channels_total = co * groups; + const float* input_base = in; + float* output_base = out; + std::vector input_group_buffer; + if (grouped_channels_last) { + input_group_buffer.resize(ih * iw * ci); + } + for (size_t g = 0; g < groups; ++g) { + const float* input_group = in; + if (grouped_channels_last) { + for (size_t pixel = 0; pixel < ih * iw; ++pixel) { + const float* src = input_base + pixel * input_channels_total + g * ci; + std::copy_n(src, ci, input_group_buffer.data() + pixel * ci); + } + input_group = input_group_buffer.data(); + } auto result = out; const bool need_transpose = (!input_is_channels_last) && (co > 1); + const bool use_temp_output = grouped_channels_last || need_transpose; if (need_transpose) { result = tmp_mlas_aligned; } + if (grouped_channels_last) { + result = tmp_mlas_aligned; + } - auto lhs = LhsPackImageDataSme(ci, ih, iw, d_kh, d_kw, sh, sw, padding, in, input_is_channels_last, ThreadPool); + auto lhs = LhsPackImageDataSme(ci, ih, iw, d_kh, d_kw, sh, sw, padding, + input_group, + input_is_channels_last, + ThreadPool); const std::byte* rhs_data = packed_rhs ? packed_rhs + g * packed_rhs_group_stride : nullptr; std::unique_ptr rhs_storage; if (rhs_data == nullptr) { @@ -613,13 +638,26 @@ static void ConvolveSme(const size_t co, //channels out ); }); + if (grouped_channels_last) { + for (size_t pixel = 0; pixel < m; ++pixel) { + float* dst = output_base + pixel * output_channels_total + g * co; + const float* src = result + pixel * co; + std::copy_n(src, co, dst); + } + } + if (need_transpose) { //Note: this could be absorbed into post conv activation MlasTranspose(tmp_mlas_aligned, out, m, co, ThreadPool); } - in += ci * ih * iw; - out += m * co; + if (!grouped_channels_last) { + in += ci * ih * iw; + out += use_temp_output ? 0 : m * co; + if (need_transpose) { + out += m * co; + } + } weights += co * ci * kh * kw; if(bias){ bias += co; diff --git a/onnxruntime/core/optimizer/nhwc_transformer.cc b/onnxruntime/core/optimizer/nhwc_transformer.cc index 6c0717865b135..2737dc9a02c0b 100644 --- a/onnxruntime/core/optimizer/nhwc_transformer.cc +++ b/onnxruntime/core/optimizer/nhwc_transformer.cc @@ -5,6 +5,7 @@ #include #include #include +#include #include #include "core/common/cpuid_info.h" #include "core/graph/constants.h" @@ -191,9 +192,14 @@ bool FloatNhwcWrapperFilter(const onnx_transpose_optimization::api::GraphRef& gr } const auto group = node.GetAttributeInt("group").value_or(1); - if (group != 1) { + if (group <= 0) { return false; } + constexpr uint64_t kSizeTMax = static_cast(std::numeric_limits::max()); + if (static_cast(group) > kSizeTMax) { + return false; + } + const auto group_count = narrow(group); std::array input_spatial_shape{}; std::array kernel_spatial_shape{}; @@ -201,17 +207,22 @@ bool FloatNhwcWrapperFilter(const onnx_transpose_optimization::api::GraphRef& gr std::array strides{1, 1}; std::array pads{}; size_t batch_count = 0; - size_t filter_count = 0; + size_t total_filter_count = 0; if (!TryGetDimValueAsSizeT(*input_shape, 0, batch_count) || !TryGetDimValueAsSizeT(*input_shape, 2, input_spatial_shape[0]) || !TryGetDimValueAsSizeT(*input_shape, 3, input_spatial_shape[1]) || - !TryGetDimValueAsSizeT(*weight_shape, 0, filter_count) || + !TryGetDimValueAsSizeT(*weight_shape, 0, total_filter_count) || !TryGetDimValueAsSizeT(*weight_shape, 2, kernel_spatial_shape[0]) || !TryGetDimValueAsSizeT(*weight_shape, 3, kernel_spatial_shape[1])) { return false; } + if (total_filter_count == 0 || total_filter_count % group_count != 0) { + return false; + } + const size_t filter_count = total_filter_count / group_count; + const auto dilations_opt = node.GetAttributeInts("dilations"); if (dilations_opt.has_value() && !TryReadPositiveInts(*dilations_opt, dilations)) { return false; @@ -229,7 +240,7 @@ bool FloatNhwcWrapperFilter(const onnx_transpose_optimization::api::GraphRef& gr return MlasConvSupportsSymmetricChannelsLast2DFloatKernel( /*Dimensions*/ 2, batch_count, - /*GroupCount*/ 1, + group_count, input_spatial_shape.data(), kernel_spatial_shape.data(), dilations.data(), diff --git a/onnxruntime/core/providers/cpu/nn/conv.cc b/onnxruntime/core/providers/cpu/nn/conv.cc index 87ce1b05caae2..eff72e4498b3f 100644 --- a/onnxruntime/core/providers/cpu/nn/conv.cc +++ b/onnxruntime/core/providers/cpu/nn/conv.cc @@ -23,7 +23,7 @@ #include "core/common/safeint.h" #include "core/util/math_cpuonly.h" -#if defined(USE_KLEIDIAI) && defined(__aarch64__) && defined(__linux__) +#if defined(USE_KLEIDIAI) && defined(MLAS_TARGET_ARM64) #include "core/mlas/lib/kleidiai/mlasi_kleidiai.h" #endif @@ -191,7 +191,7 @@ Status Conv::Compute(OpKernelContext* context) const { return Status::OK(); } -#if defined(USE_KLEIDIAI) && defined(__aarch64__) && defined(__linux__) +#if defined(USE_KLEIDIAI) && defined(MLAS_TARGET_ARM64) Status Conv::EnsurePackedChannelsLastFilter(concurrency::ThreadPool* thread_pool, size_t filter_count_per_group, size_t input_channels_per_group, @@ -329,7 +329,7 @@ Status Conv::Compute(OpKernelContext* context) const { narrow(M / conv_attrs_.group), /*Beta*/ 0.0f); -#if defined(USE_KLEIDIAI) && defined(__aarch64__) && defined(__linux__) +#if defined(USE_KLEIDIAI) && defined(MLAS_TARGET_ARM64) if (nhwc_fastpath && can_cache_packed_filter_) { ORT_RETURN_IF_ERROR(EnsurePackedChannelsLastFilter(thread_pool, narrow(M / conv_attrs_.group), @@ -385,7 +385,7 @@ Status Conv::Compute(OpKernelContext* context) const { nhwc_fastpath ? 0.0f : Beta, thread_pool); -#if defined(USE_KLEIDIAI) && defined(__aarch64__) && defined(__linux__) +#if defined(USE_KLEIDIAI) && defined(MLAS_TARGET_ARM64) if (nhwc_fastpath && packed_filter_ != nullptr) { Parameters.FilterIsPacked = true; Parameters.PackedFilter = packed_filter_.get(); diff --git a/onnxruntime/core/providers/cpu/nn/conv.h b/onnxruntime/core/providers/cpu/nn/conv.h index 1cbe417cdbd96..9e073df545328 100644 --- a/onnxruntime/core/providers/cpu/nn/conv.h +++ b/onnxruntime/core/providers/cpu/nn/conv.h @@ -31,7 +31,7 @@ class Conv : public OpKernel { activation_.ActivationKind = MlasIdentityActivation; SetupMlasBackendKernelSelectorFromConfigOptions(mlas_backend_kernel_selector_config_, info.GetConfigOptions()); -#if defined(USE_KLEIDIAI) && defined(__aarch64__) && defined(__linux__) +#if defined(USE_KLEIDIAI) && defined(MLAS_TARGET_ARM64) if (channels_last_) { const auto& input_defs = info.node().InputDefs(); const bool has_bias_input = input_defs.size() >= 3 && input_defs[2] != nullptr; @@ -56,7 +56,7 @@ class Conv : public OpKernel { ConvAttributes conv_attrs_; bool channels_last_{false}; -#if defined(USE_KLEIDIAI) && defined(__aarch64__) && defined(__linux__) +#if defined(USE_KLEIDIAI) && defined(MLAS_TARGET_ARM64) private: Status EnsurePackedChannelsLastFilter(concurrency::ThreadPool* thread_pool, size_t filter_count_per_group, diff --git a/onnxruntime/test/contrib_ops/fused_conv_test.cc b/onnxruntime/test/contrib_ops/fused_conv_test.cc index 608ccadff8f1d..f453cb3bca746 100644 --- a/onnxruntime/test/contrib_ops/fused_conv_test.cc +++ b/onnxruntime/test/contrib_ops/fused_conv_test.cc @@ -3,11 +3,19 @@ #include "gtest/gtest.h" +#include + +#include "core/common/narrow.h" +#include "core/framework/kernel_registry.h" #include "test/common/cuda_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" +#if defined(USE_KLEIDIAI) && defined(MLAS_TARGET_ARM64) +#include "core/mlas/lib/mlasi.h" +#endif + namespace onnxruntime { namespace test { @@ -121,6 +129,88 @@ void RunConvOp(const ConvOpAndTestAttributes& attributes, } #ifdef USE_KLEIDIAI +namespace { + +#if defined(MLAS_TARGET_ARM64) +bool HasFloatNhwcFusedConvKernel() { + auto cpu_ep = DefaultCpuExecutionProvider(); + if (cpu_ep == nullptr) { + return false; + } + + auto kernel_registry = cpu_ep->GetKernelRegistry(); + if (!kernel_registry) { + return false; + } + + KernelRegistry::TypeConstraintMap type_constraints{ + {"T", DataTypeImpl::GetTensorType()}, + }; + + const KernelCreateInfo* kernel_create_info{}; + const auto status = kernel_registry->TryFindKernel( + kCpuExecutionProvider, + "NhwcFusedConv", + kMSDomain, + 1, + type_constraints, + DefaultLoggingManager().DefaultLogger(), + &kernel_create_info); + + return status.IsOK() && kernel_create_info != nullptr; +} + +bool HasFloatNhwcNoTransposeSupport(const vector& input_shape, + const vector& weight_shape, + const vector& pads, + const vector& strides, + int64_t group) { + if (!HasFloatNhwcFusedConvKernel() || !MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) { + return false; + } + + if (group <= 0 || input_shape.size() != 4 || weight_shape.size() != 4 || + pads.size() != 4 || strides.size() != 2 || + weight_shape[0] <= 0 || weight_shape[0] % group != 0) { + return false; + } + + std::array input_spatial_shape{ + narrow(input_shape[1]), + narrow(input_shape[2]), + }; + std::array kernel_spatial_shape{ + narrow(weight_shape[2]), + narrow(weight_shape[3]), + }; + std::array dilations{1, 1}; + std::array strides_size_t{ + narrow(strides[0]), + narrow(strides[1]), + }; + std::array pads_size_t{ + narrow(pads[0]), + narrow(pads[1]), + narrow(pads[2]), + narrow(pads[3]), + }; + + return MlasConvSupportsSymmetricChannelsLast2DFloatKernel( + /*Dimensions*/ 2, + narrow(input_shape[0]), + narrow(group), + input_spatial_shape.data(), + kernel_spatial_shape.data(), + dilations.data(), + pads_size_t.data(), + strides_size_t.data(), + narrow(weight_shape[0] / group), + /*Beta*/ 0.0f); +} +#endif + +} // namespace + void TestNhwcFusedConvFloatOp(const ConvOpAndTestAttributes& attributes, const vector>& inputs, const vector>& input_shapes, @@ -372,6 +462,45 @@ TEST(FusedConvTest, Cpu_NhwcConv2D_AutoPadSameUpper) { TestNhwcFusedConvFloatOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); TestNhwcFusedConvFloatOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } + +TEST(FusedConvTest, Cpu_NhwcDepthwiseConv2D_SymmetricPadding) { +#if !defined(MLAS_TARGET_ARM64) + GTEST_SKIP() << "Float NHWC depthwise fast-path requires Arm64."; +#else + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 2, // group + vector{3, 3}, // kernel_shape + vector{1, 1, 1, 1}, // pads + vector{1, 1}, // strides + "Relu" // activation + }; + + vector X_shape = {1, 3, 3, 2}; + vector X = {1.0f, 10.0f, 2.0f, 20.0f, 3.0f, 30.0f, + 4.0f, 40.0f, 5.0f, 50.0f, 6.0f, 60.0f, + 7.0f, 70.0f, 8.0f, 80.0f, 9.0f, 90.0f}; + vector W_shape = {2, 1, 3, 3}; + vector W = {1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, + 1.0f, 0.0f, 1.0f, + 0.0f, 2.0f, 0.0f, + 1.0f, 0.0f, 1.0f}; + vector Y_shape = {1, 3, 3, 2}; + auto expected_vals = {12.0f, 70.0f, 21.0f, 140.0f, 16.0f, 110.0f, + 27.0f, 180.0f, 45.0f, 300.0f, 33.0f, 220.0f, + 24.0f, 190.0f, 39.0f, 260.0f, 28.0f, 230.0f}; + + if (!HasFloatNhwcNoTransposeSupport(X_shape, W_shape, attrs.pads, attrs.strides, attrs.group)) { + GTEST_SKIP() << "Float NHWC depthwise fast-path is not available on this configuration."; + } + + TestNhwcFusedConvFloatOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + TestNhwcFusedConvFloatOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +#endif +} #endif TEST(FusedConvTest, Cpu_Conv3D_Batched_Relu) { diff --git a/onnxruntime/test/mlas/bench/bench_sconv.cpp b/onnxruntime/test/mlas/bench/bench_sconv.cpp index 9df09728ffa17..85cb68076b638 100644 --- a/onnxruntime/test/mlas/bench/bench_sconv.cpp +++ b/onnxruntime/test/mlas/bench/bench_sconv.cpp @@ -146,6 +146,145 @@ static MLAS_THREADPOOL* GetMlasThreadPoolForConvBenchmark(void) { return threadpool.get(); } +void SCONV_NHWC_KLEIDIAI(benchmark::State& state, const char* /*dummy*/) { + const int64_t rank = state.range(0); // Rank + const int64_t batch_size = state.range(1); // N + const int64_t groups = state.range(2); // G + const int64_t input_channels_per_group = state.range(3); // Cpg + const int64_t output_channels_per_group = state.range(4); // Fpg + + if (rank <= 0) throw std::invalid_argument("Kernel rank must be greater than 0"); + if (batch_size <= 0) throw std::invalid_argument("Batch size must be greater than 0"); + if (groups <= 0) throw std::invalid_argument("Group count must be greater than 0"); + if (input_channels_per_group <= 0) throw std::invalid_argument("input_channels_per_group must be greater than 0"); + if (output_channels_per_group <= 0) throw std::invalid_argument("output_channels_per_group must be greater than 0"); + + size_t arg_position = 5; + const auto input_shape = BenchArgsVector(state, arg_position, rank); + const auto kernel_shape = BenchArgsVector(state, arg_position, rank); + const auto paddings = BenchArgsVector(state, arg_position, rank * 2); + const auto strides = BenchArgsVector(state, arg_position, rank); + const auto dilations = BenchArgsVector(state, arg_position, rank); + + if (std::any_of(input_shape.begin(), input_shape.end(), [](const int64_t& dim) { return dim <= 0; })) { + throw std::invalid_argument("all input image dim must > 0"); + } + + if (std::any_of(kernel_shape.begin(), kernel_shape.end(), [](const int64_t& dim) { return dim <= 0; })) { + throw std::invalid_argument("all kernel dim must > 0"); + } + + if (std::any_of(strides.begin(), strides.end(), [](const int64_t& dim) { return dim <= 0; })) { + throw std::invalid_argument("all strides dim must > 0"); + } + + if (std::any_of(dilations.begin(), dilations.end(), [](const int64_t& dim) { return dim <= 0; })) { + throw std::invalid_argument("all dilations dim must > 0"); + } + + if (rank != 2 || batch_size != 1) { + state.SkipWithError("KleidiAI NHWC benchmark requires 2D convolution with batch size 1."); + return; + } + + std::vector input_shape_size_t(static_cast(rank)); + std::vector kernel_shape_size_t(static_cast(rank)); + std::vector paddings_size_t(static_cast(rank * 2)); + std::vector strides_size_t(static_cast(rank)); + std::vector dilations_size_t(static_cast(rank)); + for (int64_t i = 0; i < rank; ++i) { + input_shape_size_t[static_cast(i)] = static_cast(input_shape[static_cast(i)]); + kernel_shape_size_t[static_cast(i)] = static_cast(kernel_shape[static_cast(i)]); + strides_size_t[static_cast(i)] = static_cast(strides[static_cast(i)]); + dilations_size_t[static_cast(i)] = static_cast(dilations[static_cast(i)]); + paddings_size_t[static_cast(i)] = static_cast(paddings[static_cast(i)]); + paddings_size_t[static_cast(i + rank)] = static_cast(paddings[static_cast(i + rank)]); + } + + if (!MlasConvSupportsSymmetricChannelsLast2DFloatKernel( + static_cast(rank), + static_cast(batch_size), + static_cast(groups), + input_shape_size_t.data(), + kernel_shape_size_t.data(), + dilations_size_t.data(), + paddings_size_t.data(), + strides_size_t.data(), + static_cast(output_channels_per_group), + 0.0f)) { + state.SkipWithError("KleidiAI NHWC kernel is not supported for this benchmark shape on the current platform."); + return; + } + + const int64_t GC = groups * input_channels_per_group; + const int64_t GF = groups * output_channels_per_group; + std::vector x_shape = {batch_size}; + x_shape.insert(x_shape.end(), input_shape.begin(), input_shape.end()); + x_shape.push_back(GC); + + std::vector f_shape = {GF, input_channels_per_group}; + f_shape.insert(f_shape.end(), kernel_shape.begin(), kernel_shape.end()); + + std::vector output_shape(static_cast(rank)); + for (int64_t i = 0; i < rank; ++i) { + auto km = 1 + dilations[static_cast(i)] * (kernel_shape[static_cast(i)] - 1); + output_shape[static_cast(i)] = + (paddings[static_cast(i)] + paddings[static_cast(i + rank)] + input_shape[static_cast(i)] - km) / + strides[static_cast(i)] + + 1; + } + + std::vector y_shape = {batch_size}; + y_shape.insert(y_shape.end(), output_shape.begin(), output_shape.end()); + y_shape.push_back(GF); + + MLAS_ACTIVATION activation; + activation.ActivationKind = MlasIdentityActivation; + MLAS_CONV_PARAMETERS Parameters; + size_t WorkingBufferSize = 0; + MlasConvPrepare(&Parameters, + static_cast(rank), + static_cast(batch_size), + static_cast(groups), + static_cast(input_channels_per_group), + input_shape.data(), + kernel_shape.data(), + dilations.data(), + paddings.data(), + strides.data(), + output_shape.data(), + static_cast(output_channels_per_group), + &activation, + &WorkingBufferSize, + true, + 0.0f, + nullptr); + + auto X = RandomVectorUniform(x_shape, -2.0, 2.0); + auto F = RandomVectorUniform(f_shape, -1.0, 1.0); + int64_t y_size = std::accumulate(y_shape.begin(), y_shape.end(), 1LL, std::multiplies()); + std::vector Y(static_cast(y_size)); + std::vector working_buffer(WorkingBufferSize); + + MlasConv(&Parameters, + X.data(), + F.data(), + nullptr, + working_buffer.data(), + Y.data(), + nullptr); + + for (auto _ : state) { + MlasConv(&Parameters, + X.data(), + F.data(), + nullptr, + working_buffer.data(), + Y.data(), + nullptr); + } +} + void SCONV_NCHW_THREADED(benchmark::State& state, const char* /*dummy*/) { MLAS_THREADPOOL* tp = GetMlasThreadPoolForConvBenchmark(); @@ -354,6 +493,21 @@ static void MobileClip(benchmark::internal::Benchmark* b) { BENCHMARK_CAPTURE(SCONV_NCHW, MobileClip, "")->Apply(MobileClip)->UseRealTime(); BENCHMARK_CAPTURE(SCONV_NCHW_THREADED, MobileClip, "")->Apply(MobileClip)->UseRealTime(); +static void KleidiAiNhwcComparison(benchmark::internal::Benchmark* b) { + b->ArgNames(ArgNamesForConv(2)); + + // Dense 3x3 conv shapes that fit the Arm SME / KleidiAI NHWC fast-path envelope. + b->Args({2, 1, 1, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1}); + b->Args({2, 1, 1, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1}); + + // Classic depthwise shapes now supported by the NHWC helper gate. + b->Args({2, 1, 64, 1, 1, 56, 56, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1}); + b->Args({2, 1, 72, 1, 1, 48, 80, 3, 3, 1, 1, 1, 1, 2, 2, 1, 1}); +} + +BENCHMARK_CAPTURE(SCONV_NCHW, KleidiAiNhwcComparison_NchwBaseline, "")->Apply(KleidiAiNhwcComparison)->UseRealTime(); +BENCHMARK_CAPTURE(SCONV_NHWC_KLEIDIAI, KleidiAiNhwcComparison_NhwcFastPath, "")->Apply(KleidiAiNhwcComparison)->UseRealTime(); + static void General_Conv2d(benchmark::internal::Benchmark* b) { b->ArgNames(ArgNamesForConv(2)); b->ArgsProduct( diff --git a/onnxruntime/test/optimizer/nhwc_transformer_test.cc b/onnxruntime/test/optimizer/nhwc_transformer_test.cc index b73929efab8a6..ff5a93360a8a9 100644 --- a/onnxruntime/test/optimizer/nhwc_transformer_test.cc +++ b/onnxruntime/test/optimizer/nhwc_transformer_test.cc @@ -81,9 +81,15 @@ static bool HasFloatNhwcNoTransposeSupport(const std::vector& input_sha return false; } - if (has_sum_input || group != 1 || input_shape.size() != 4 || weight_shape.size() != 4) { + if (has_sum_input || group <= 0 || input_shape.size() != 4 || weight_shape.size() != 4) { return false; } + const auto group_count = narrow(group); + + if (weight_shape[0] <= 0 || weight_shape[0] % group != 0) { + return false; + } + const auto filter_count = narrow(weight_shape[0] / group); std::array input_spatial_shape{ narrow(input_shape[2]), @@ -169,13 +175,13 @@ static bool HasFloatNhwcNoTransposeSupport(const std::vector& input_sha return MlasConvSupportsSymmetricChannelsLast2DFloatKernel( /*Dimensions*/ 2, narrow(input_shape[0]), - /*GroupCount*/ 1, + group_count, input_spatial_shape.data(), kernel_spatial_shape.data(), dilations_size_t.data(), pads_size_t.data(), strides_size_t.data(), - narrow(weight_shape[0]), + filter_count, /*Beta*/ 0.0f); #else ORT_UNUSED_PARAMETER(input_shape); @@ -407,7 +413,7 @@ TEST(NhwcTransformerTests, ConvGlobalAveragePool) { TransformerLevel::Level3); } -TEST(NhwcTransformerTests, ConvDepthwiseFloat_SkipNhwc) { +TEST(NhwcTransformerTests, ConvDepthwiseFloat_UsesHelperCapability) { auto build_test_case = [&](ModelTestBuilder& builder) { auto* input_arg = builder.MakeInput({1, 8, 7, 7}, -1.0f, 1.0f); auto* weight_arg = builder.MakeInitializer({8, 1, 3, 3}, -1.0f, 1.0f); @@ -435,6 +441,90 @@ TEST(NhwcTransformerTests, ConvDepthwiseFloat_SkipNhwc) { /*relative_per_sample_tolerance*/ 1e-6); } +TEST(NhwcTransformerTests, ConvDepthwiseFloat_AsymmetricPaddingSkipsNhwc) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({1, 8, 7, 7}, -1.0f, 1.0f); + auto* weight_arg = builder.MakeInitializer({8, 1, 3, 3}, -1.0f, 1.0f); + auto* output_arg = builder.MakeOutput(); + + Node& conv_node = builder.AddConvNode(input_arg, weight_arg, output_arg); + conv_node.AddAttribute("group", static_cast(8)); + conv_node.AddAttribute("pads", std::vector{0, 1, 1, 1}); + }; + + auto check_nhwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const bool expect_nhwc = HasFloatNhwcNoTransposeSupport({1, 8, 7, 7}, {8, 1, 3, 3}, {0, 1, 1, 1}, {}, {}, 8); + EXPECT_FALSE(expect_nhwc); + EXPECT_EQ(op_to_count["com.microsoft.NhwcFusedConv"], 0); + EXPECT_EQ(op_to_count["Transpose"], 0); + }; + + TransformerTester(build_test_case, + check_nhwc_graph, + TransformerLevel::Level2, + TransformerLevel::Level3, + /*opset_version*/ 12, + /*per_sample_tolerance*/ 1e-6, + /*relative_per_sample_tolerance*/ 1e-6); +} + +TEST(NhwcTransformerTests, ConvGroupedFloat_SkipNhwc) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({1, 8, 7, 7}, -1.0f, 1.0f); + auto* weight_arg = builder.MakeInitializer({8, 2, 3, 3}, -1.0f, 1.0f); + auto* output_arg = builder.MakeOutput(); + + Node& conv_node = builder.AddConvNode(input_arg, weight_arg, output_arg); + conv_node.AddAttribute("group", static_cast(4)); + conv_node.AddAttribute("pads", std::vector{1, 1, 1, 1}); + }; + + auto check_nhwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const bool expect_nhwc = HasFloatNhwcNoTransposeSupport({1, 8, 7, 7}, {8, 2, 3, 3}, {1, 1, 1, 1}, {}, {}, 4); + EXPECT_FALSE(expect_nhwc); + EXPECT_EQ(op_to_count["com.microsoft.NhwcFusedConv"], 0); + EXPECT_EQ(op_to_count["Transpose"], 0); + }; + + TransformerTester(build_test_case, + check_nhwc_graph, + TransformerLevel::Level2, + TransformerLevel::Level3, + /*opset_version*/ 12, + /*per_sample_tolerance*/ 1e-6, + /*relative_per_sample_tolerance*/ 1e-6); +} + +TEST(NhwcTransformerTests, ConvDepthwiseMultiplier2Float_SkipNhwc) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({1, 8, 7, 7}, -1.0f, 1.0f); + auto* weight_arg = builder.MakeInitializer({16, 1, 3, 3}, -1.0f, 1.0f); + auto* output_arg = builder.MakeOutput(); + + Node& conv_node = builder.AddConvNode(input_arg, weight_arg, output_arg); + conv_node.AddAttribute("group", static_cast(8)); + conv_node.AddAttribute("pads", std::vector{1, 1, 1, 1}); + }; + + auto check_nhwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const bool expect_nhwc = HasFloatNhwcNoTransposeSupport({1, 8, 7, 7}, {16, 1, 3, 3}, {1, 1, 1, 1}, {}, {}, 8); + EXPECT_FALSE(expect_nhwc); + EXPECT_EQ(op_to_count["com.microsoft.NhwcFusedConv"], 0); + EXPECT_EQ(op_to_count["Transpose"], 0); + }; + + TransformerTester(build_test_case, + check_nhwc_graph, + TransformerLevel::Level2, + TransformerLevel::Level3, + /*opset_version*/ 12, + /*per_sample_tolerance*/ 1e-6, + /*relative_per_sample_tolerance*/ 1e-6); +} + TEST(NhwcTransformerTests, ConvFloat_UsesNhwcOnlyWithKleidi) { auto build_test_case = [&](ModelTestBuilder& builder) { auto* input_arg = builder.MakeInput({1, 8, 7, 7}, -1.0f, 1.0f);