Skip to content
9 changes: 7 additions & 2 deletions onnxruntime/core/mlas/lib/convolve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1377,7 +1377,7 @@ MlasConvSupportsSymmetricChannelsLast2DFloatKernel(
return false;
}

if (Dimensions != 2 || BatchCount != 1 || GroupCount != 1 || Beta != 0.0f) {
if (Dimensions != 2 || BatchCount != 1 || Beta != 0.0f) {
return false;
}

Expand All @@ -1395,7 +1395,12 @@ MlasConvSupportsSymmetricChannelsLast2DFloatKernel(
return false;
}

if (FilterCount <= 1 || KernelShape[0] < 3 || KernelShape[1] < 3) {
const bool is_depthwise = GroupCount > 1 && FilterCount == 1;
if (GroupCount > 1 && !is_depthwise) {
return false;
}

if (!is_depthwise && (FilterCount <= 1 || KernelShape[0] < 3 || KernelShape[1] < 3)) {
return false;
}

Expand Down
44 changes: 41 additions & 3 deletions onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -550,15 +550,40 @@ static void ConvolveSme(const size_t co, //channels out
dim[1] = MlasDivRoundup(m, m_step);
dim[2] = MlasDivRoundup(co, n_step);

const bool grouped_channels_last = input_is_channels_last && groups > 1;
const size_t input_channels_total = ci * groups;
const size_t output_channels_total = co * groups;
const float* input_base = in;
Comment on lines +553 to +556
float* output_base = out;
std::vector<float> input_group_buffer;
if (grouped_channels_last) {
input_group_buffer.resize(ih * iw * ci);
}

for (size_t g = 0; g < groups; ++g) {
const float* input_group = in;
if (grouped_channels_last) {
for (size_t pixel = 0; pixel < ih * iw; ++pixel) {
const float* src = input_base + pixel * input_channels_total + g * ci;
std::copy_n(src, ci, input_group_buffer.data() + pixel * ci);
}
Comment on lines 563 to +569
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a fairly minor concern but it will affect performance so I've moved it out of the loop and only size it once.

input_group = input_group_buffer.data();
}

auto result = out;
const bool need_transpose = (!input_is_channels_last) && (co > 1);
const bool use_temp_output = grouped_channels_last || need_transpose;
if (need_transpose) {
result = tmp_mlas_aligned;
}
if (grouped_channels_last) {
result = tmp_mlas_aligned;
}

auto lhs = LhsPackImageDataSme(ci, ih, iw, d_kh, d_kw, sh, sw, padding, in, input_is_channels_last, ThreadPool);
auto lhs = LhsPackImageDataSme(ci, ih, iw, d_kh, d_kw, sh, sw, padding,
input_group,
input_is_channels_last,
ThreadPool);
const std::byte* rhs_data = packed_rhs ? packed_rhs + g * packed_rhs_group_stride : nullptr;
std::unique_ptr<std::byte[]> rhs_storage;
if (rhs_data == nullptr) {
Expand Down Expand Up @@ -613,13 +638,26 @@ static void ConvolveSme(const size_t co, //channels out
);
});

if (grouped_channels_last) {
for (size_t pixel = 0; pixel < m; ++pixel) {
float* dst = output_base + pixel * output_channels_total + g * co;
const float* src = result + pixel * co;
std::copy_n(src, co, dst);
}
}

if (need_transpose) {
//Note: this could be absorbed into post conv activation
MlasTranspose(tmp_mlas_aligned, out, m, co, ThreadPool);
}

in += ci * ih * iw;
out += m * co;
if (!grouped_channels_last) {
in += ci * ih * iw;
out += use_temp_output ? 0 : m * co;
if (need_transpose) {
out += m * co;
}
}
weights += co * ci * kh * kw;
if(bias){
bias += co;
Expand Down
19 changes: 15 additions & 4 deletions onnxruntime/core/optimizer/nhwc_transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <array>
#include <cstdint>
#include <deque>
#include <limits>
#include <vector>
#include "core/common/cpuid_info.h"
#include "core/graph/constants.h"
Expand Down Expand Up @@ -191,27 +192,37 @@ bool FloatNhwcWrapperFilter(const onnx_transpose_optimization::api::GraphRef& gr
}

const auto group = node.GetAttributeInt("group").value_or(1);
if (group != 1) {
if (group <= 0) {
return false;
}
constexpr uint64_t kSizeTMax = static_cast<uint64_t>(std::numeric_limits<size_t>::max());
if (static_cast<uint64_t>(group) > kSizeTMax) {
return false;
}
const auto group_count = narrow<size_t>(group);

Comment on lines 194 to 203
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a valid concern, added a limit check to be safe

std::array<size_t, 2> input_spatial_shape{};
std::array<size_t, 2> kernel_spatial_shape{};
std::array<size_t, 2> dilations{1, 1};
std::array<size_t, 2> strides{1, 1};
std::array<size_t, 4> pads{};
size_t batch_count = 0;
size_t filter_count = 0;
size_t total_filter_count = 0;

if (!TryGetDimValueAsSizeT(*input_shape, 0, batch_count) ||
!TryGetDimValueAsSizeT(*input_shape, 2, input_spatial_shape[0]) ||
!TryGetDimValueAsSizeT(*input_shape, 3, input_spatial_shape[1]) ||
!TryGetDimValueAsSizeT(*weight_shape, 0, filter_count) ||
!TryGetDimValueAsSizeT(*weight_shape, 0, total_filter_count) ||
!TryGetDimValueAsSizeT(*weight_shape, 2, kernel_spatial_shape[0]) ||
!TryGetDimValueAsSizeT(*weight_shape, 3, kernel_spatial_shape[1])) {
return false;
}

if (total_filter_count == 0 || total_filter_count % group_count != 0) {
return false;
}
const size_t filter_count = total_filter_count / group_count;

const auto dilations_opt = node.GetAttributeInts("dilations");
if (dilations_opt.has_value() && !TryReadPositiveInts(*dilations_opt, dilations)) {
return false;
Expand All @@ -229,7 +240,7 @@ bool FloatNhwcWrapperFilter(const onnx_transpose_optimization::api::GraphRef& gr
return MlasConvSupportsSymmetricChannelsLast2DFloatKernel(
/*Dimensions*/ 2,
batch_count,
/*GroupCount*/ 1,
group_count,
input_spatial_shape.data(),
kernel_spatial_shape.data(),
dilations.data(),
Expand Down
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/cpu/nn/conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include "core/common/safeint.h"
#include "core/util/math_cpuonly.h"

#if defined(USE_KLEIDIAI) && defined(__aarch64__) && defined(__linux__)
#if defined(USE_KLEIDIAI) && defined(MLAS_TARGET_ARM64)
#include "core/mlas/lib/kleidiai/mlasi_kleidiai.h"
#endif

Expand Down Expand Up @@ -191,7 +191,7 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
return Status::OK();
}

#if defined(USE_KLEIDIAI) && defined(__aarch64__) && defined(__linux__)
#if defined(USE_KLEIDIAI) && defined(MLAS_TARGET_ARM64)
Status Conv<float>::EnsurePackedChannelsLastFilter(concurrency::ThreadPool* thread_pool,
size_t filter_count_per_group,
size_t input_channels_per_group,
Expand Down Expand Up @@ -329,7 +329,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
narrow<size_t>(M / conv_attrs_.group),
/*Beta*/ 0.0f);

#if defined(USE_KLEIDIAI) && defined(__aarch64__) && defined(__linux__)
#if defined(USE_KLEIDIAI) && defined(MLAS_TARGET_ARM64)
if (nhwc_fastpath && can_cache_packed_filter_) {
ORT_RETURN_IF_ERROR(EnsurePackedChannelsLastFilter(thread_pool,
narrow<size_t>(M / conv_attrs_.group),
Expand Down Expand Up @@ -385,7 +385,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
nhwc_fastpath ? 0.0f : Beta,
thread_pool);

#if defined(USE_KLEIDIAI) && defined(__aarch64__) && defined(__linux__)
#if defined(USE_KLEIDIAI) && defined(MLAS_TARGET_ARM64)
if (nhwc_fastpath && packed_filter_ != nullptr) {
Parameters.FilterIsPacked = true;
Parameters.PackedFilter = packed_filter_.get();
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/cpu/nn/conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class Conv<float> : public OpKernel {
activation_.ActivationKind = MlasIdentityActivation;
SetupMlasBackendKernelSelectorFromConfigOptions(mlas_backend_kernel_selector_config_, info.GetConfigOptions());

#if defined(USE_KLEIDIAI) && defined(__aarch64__) && defined(__linux__)
#if defined(USE_KLEIDIAI) && defined(MLAS_TARGET_ARM64)
if (channels_last_) {
const auto& input_defs = info.node().InputDefs();
const bool has_bias_input = input_defs.size() >= 3 && input_defs[2] != nullptr;
Expand All @@ -56,7 +56,7 @@ class Conv<float> : public OpKernel {
ConvAttributes conv_attrs_;
bool channels_last_{false};

#if defined(USE_KLEIDIAI) && defined(__aarch64__) && defined(__linux__)
#if defined(USE_KLEIDIAI) && defined(MLAS_TARGET_ARM64)
private:
Status EnsurePackedChannelsLastFilter(concurrency::ThreadPool* thread_pool,
size_t filter_count_per_group,
Expand Down
129 changes: 129 additions & 0 deletions onnxruntime/test/contrib_ops/fused_conv_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,19 @@

#include "gtest/gtest.h"

#include <array>

#include "core/common/narrow.h"
#include "core/framework/kernel_registry.h"
#include "test/common/cuda_op_test_utils.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/providers/provider_test_utils.h"
#include "test/util/include/default_providers.h"

#if defined(USE_KLEIDIAI) && defined(MLAS_TARGET_ARM64)
#include "core/mlas/lib/mlasi.h"
#endif

namespace onnxruntime {
namespace test {

Expand Down Expand Up @@ -121,6 +129,88 @@ void RunConvOp(const ConvOpAndTestAttributes& attributes,
}

#ifdef USE_KLEIDIAI
namespace {

#if defined(MLAS_TARGET_ARM64)
bool HasFloatNhwcFusedConvKernel() {
auto cpu_ep = DefaultCpuExecutionProvider();
if (cpu_ep == nullptr) {
return false;
}

auto kernel_registry = cpu_ep->GetKernelRegistry();
if (!kernel_registry) {
return false;
}

KernelRegistry::TypeConstraintMap type_constraints{
{"T", DataTypeImpl::GetTensorType<float>()},
};

const KernelCreateInfo* kernel_create_info{};
const auto status = kernel_registry->TryFindKernel(
kCpuExecutionProvider,
"NhwcFusedConv",
kMSDomain,
1,
type_constraints,
DefaultLoggingManager().DefaultLogger(),
&kernel_create_info);

return status.IsOK() && kernel_create_info != nullptr;
}

bool HasFloatNhwcNoTransposeSupport(const vector<int64_t>& input_shape,
const vector<int64_t>& weight_shape,
const vector<int64_t>& pads,
const vector<int64_t>& strides,
int64_t group) {
if (!HasFloatNhwcFusedConvKernel() || !MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) {
return false;
}

if (group <= 0 || input_shape.size() != 4 || weight_shape.size() != 4 ||
pads.size() != 4 || strides.size() != 2 ||
weight_shape[0] <= 0 || weight_shape[0] % group != 0) {
return false;
}
Comment on lines +172 to +176

std::array<size_t, 2> input_spatial_shape{
narrow<size_t>(input_shape[1]),
narrow<size_t>(input_shape[2]),
};
std::array<size_t, 2> kernel_spatial_shape{
narrow<size_t>(weight_shape[2]),
narrow<size_t>(weight_shape[3]),
};
std::array<size_t, 2> dilations{1, 1};
std::array<size_t, 2> strides_size_t{
narrow<size_t>(strides[0]),
narrow<size_t>(strides[1]),
};
std::array<size_t, 4> pads_size_t{
narrow<size_t>(pads[0]),
narrow<size_t>(pads[1]),
narrow<size_t>(pads[2]),
narrow<size_t>(pads[3]),
};

return MlasConvSupportsSymmetricChannelsLast2DFloatKernel(
/*Dimensions*/ 2,
narrow<size_t>(input_shape[0]),
narrow<size_t>(group),
input_spatial_shape.data(),
kernel_spatial_shape.data(),
dilations.data(),
pads_size_t.data(),
strides_size_t.data(),
narrow<size_t>(weight_shape[0] / group),
/*Beta*/ 0.0f);
}
#endif

} // namespace

void TestNhwcFusedConvFloatOp(const ConvOpAndTestAttributes& attributes,
const vector<vector<float>>& inputs,
const vector<vector<int64_t>>& input_shapes,
Expand Down Expand Up @@ -372,6 +462,45 @@ TEST(FusedConvTest, Cpu_NhwcConv2D_AutoPadSameUpper) {
TestNhwcFusedConvFloatOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
TestNhwcFusedConvFloatOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true);
}

TEST(FusedConvTest, Cpu_NhwcDepthwiseConv2D_SymmetricPadding) {
#if !defined(MLAS_TARGET_ARM64)
GTEST_SKIP() << "Float NHWC depthwise fast-path requires Arm64.";
#else
ConvOpAndTestAttributes attrs = {
"", // auto_pad
vector<int64_t>{1, 1}, // dilations
2, // group
vector<int64_t>{3, 3}, // kernel_shape
vector<int64_t>{1, 1, 1, 1}, // pads
vector<int64_t>{1, 1}, // strides
"Relu" // activation
};

vector<int64_t> X_shape = {1, 3, 3, 2};
vector<float> X = {1.0f, 10.0f, 2.0f, 20.0f, 3.0f, 30.0f,
4.0f, 40.0f, 5.0f, 50.0f, 6.0f, 60.0f,
7.0f, 70.0f, 8.0f, 80.0f, 9.0f, 90.0f};
vector<int64_t> W_shape = {2, 1, 3, 3};
vector<float> W = {1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f,
1.0f, 0.0f, 1.0f,
0.0f, 2.0f, 0.0f,
1.0f, 0.0f, 1.0f};
vector<int64_t> Y_shape = {1, 3, 3, 2};
auto expected_vals = {12.0f, 70.0f, 21.0f, 140.0f, 16.0f, 110.0f,
27.0f, 180.0f, 45.0f, 300.0f, 33.0f, 220.0f,
24.0f, 190.0f, 39.0f, 260.0f, 28.0f, 230.0f};

if (!HasFloatNhwcNoTransposeSupport(X_shape, W_shape, attrs.pads, attrs.strides, attrs.group)) {
GTEST_SKIP() << "Float NHWC depthwise fast-path is not available on this configuration.";
}

TestNhwcFusedConvFloatOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
TestNhwcFusedConvFloatOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true);
#endif
}
#endif

TEST(FusedConvTest, Cpu_Conv3D_Batched_Relu) {
Expand Down
Loading
Loading