[MLAS] Update the NHWC sans transposes path to also support Depthwise convolutions#28565
[MLAS] Update the NHWC sans transposes path to also support Depthwise convolutions#28565orlmon01 wants to merge 8 commits into
Conversation
* Allow for NHWC Depthwise convolutions when groups are values other than 1 * Added verification tests * Changed the fallback / skip tests to now check for asymettric padding, non-depthwise grouped conv, and multiplier > 1 Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
|
@microsoft-github-policy-service agree company="Arm" |
There was a problem hiding this comment.
Pull request overview
Expands the existing MLAS/KleidiAI NHWC “no-transpose” convolution fast path to support true depthwise convolutions (grouped conv where filters-per-group == 1), and wires that capability through the NHWC transformer plus adds test/benchmark coverage.
Changes:
- Relax MLAS NHWC capability gating to allow GroupCount > 1 only for true depthwise (FilterCount-per-group == 1).
- Update NHWC transformer filtering to pass the real group count and compute per-group filter count.
- Extend KleidiAI NHWC execution to handle grouped NHWC tensors via per-group gather/compute/scatter, plus add unit tests and a benchmark comparison.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| onnxruntime/core/mlas/lib/convolve.cpp | Updates NHWC capability gate to allow depthwise grouped convs. |
| onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp | Implements grouped-NHWC handling by gathering/scattering channels per group. |
| onnxruntime/core/optimizer/nhwc_transformer.cc | Passes group count + per-group filter count into the NHWC fast-path capability check. |
| onnxruntime/core/providers/cpu/nn/conv.h | Broadens KleidiAI fast-path compilation guard to MLAS_TARGET_ARM64. |
| onnxruntime/core/providers/cpu/nn/conv.cc | Same guard update for KleidiAI fast-path code. |
| onnxruntime/test/optimizer/nhwc_transformer_test.cc | Adds/updates tests validating depthwise enablement and expected skip cases. |
| onnxruntime/test/contrib_ops/fused_conv_test.cc | Adds an NHWC depthwise FusedConv correctness test (conditionally enabled). |
| onnxruntime/test/mlas/bench/bench_sconv.cpp | Adds benchmark cases comparing NCHW baseline vs NHWC KleidiAI fast path, including depthwise shapes. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| const auto group = node.GetAttributeInt("group").value_or(1); | ||
| if (group != 1) { | ||
| if (group <= 0) { | ||
| return false; | ||
| } | ||
| const auto group_count = narrow<size_t>(group); | ||
|
|
There was a problem hiding this comment.
This is a valid concern, added a limit check to be safe
| for (size_t g = 0; g < groups; ++g) { | ||
| const float* input_group = in; | ||
| std::vector<float> input_group_buffer; | ||
| if (grouped_channels_last) { | ||
| input_group_buffer.resize(ih * iw * ci); | ||
| for (size_t pixel = 0; pixel < ih * iw; ++pixel) { | ||
| const float* src = input_base + pixel * input_channels_total + g * ci; | ||
| std::copy_n(src, ci, input_group_buffer.data() + pixel * ci); | ||
| } |
There was a problem hiding this comment.
This is a fairly minor concern but it will affect performance so I've moved it out of the loop and only size it once.
| if (rank <= 0) throw std::invalid_argument("Kernel rank must greater than 0!"); | ||
| if (batch_size <= 0) throw std::invalid_argument("Batch size must greater than 0!"); | ||
| if (groups <= 0) throw std::invalid_argument("Group count must greater than 0!"); | ||
| if (input_channels_per_group <= 0) throw std::invalid_argument("input_channels_per_group must greater than 0!"); | ||
| if (output_channels_per_group <= 0) throw std::invalid_argument("output_channels_per_group must greater than 0!"); |
Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
…eSme and only size it once * Fixed some grammer in throw statements Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
Description
A path for MLAS to support NHWC Convolutions without the need for transposes was added in PR: #26834
This PR expands those changes to also support Depthwise Convolutions via the same pathway
What changed:
1.
depthwise Conv/FusedConv nodes get rewritten to com.microsoft.NhwcFusedConv.
onnxruntime/test/optimizer/nhwc_transformer_test.cc:416.
Added performance benchmark tests to allow for comparison between the new NHWC path and the old NCHW default.
Sample output: