[MLAS] Add CPU DynamicQuantMatMulFp8 contrib op with MLAS FP8 fallback#28416
Open
melkap01-Arm wants to merge 22 commits into
Adds DynamicQuantMatMulFp8 schema under the Microsoft contrib opset.
Registers the CPU contrib kernel when FP8 types are enabled.
Adds dynamic_quant_matmul_fp8.{h,cc} CPU kernel implementation.
Adds MLAS FP8 GEMM API surface and scalar fallback implementation in qgemm_fp8.cpp.
Wires the MLAS FP8 source into the MLAS build.
Adds provider tests for the FP8 op-kernel path.
Signed-off-by: melkap01 <melike.kaptan@arm.com>
…tion enforced Signed-off-by: melkap01 <melike.kaptan@arm.com>
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Adds a CPU contrib implementation for com.microsoft::DynamicQuantMatMulFp8 backed by a new MLAS FP8 GEMM API, including a scalar fallback implementation and provider/MLAS unit tests.
Changes:
- Introduces `DynamicQuantMatMulFp8` schema, CPU kernel registration, and a CPU opkernel implementation with prepack support for constant non-FP8 B.
- Adds MLAS FP8 GEMM public API (`MlasFp8GemmBatch`) and scalar fallback implementation, plus a shared `size_t` overflow helper.
- Adds provider tests for the new contrib op and MLAS unit tests for the FP8 GEMM path.
Reviewed changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 9 comments.
Show a summary per file
| File | Description |
|---|---|
| onnxruntime/test/mlas/unittest/test_qgemm_fp8.cpp | Adds MLAS unit tests for the FP8 GEMM batch API (threaded + edge cases). |
| onnxruntime/test/contrib_ops/dynamic_quant_matmul_fp8_test.cc | Adds CPU provider tests covering op contract, prepack, scale/zero-point validation, and edge cases. |
| onnxruntime/core/mlas/lib/qgemm_fp8.cpp | Implements scalar fallback for MlasFp8GemmBatch with validation and parallelism. |
| onnxruntime/core/mlas/lib/mlasi.h | Adds MlasMultiplyOverflowsSizeT helper used for overflow-safe size computations. |
| onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp | Switches overflow checks to the shared MlasMultiplyOverflowsSizeT helper. |
| onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp | Switches overflow checks to the shared MlasMultiplyOverflowsSizeT helper. |
| onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h | Renames a parameter and removes the local overflow helper in favor of shared MLAS helper. |
| onnxruntime/core/mlas/inc/mlas.h | Adds public MLAS FP8 GEMM structs/API and the FP8 mode enum. |
| onnxruntime/core/graph/contrib_ops/quantization_defs.cc | Adds the DynamicQuantMatMulFp8 contrib operator schema + shape inference. |
| onnxruntime/core/graph/contrib_ops/ms_opset.h | Registers the new contrib schema in the Microsoft opset (gated on float8 support). |
| onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.h | Declares the new CPU contrib kernel with prepack/shared-prepack support. |
| onnxruntime/contrib_ops/cpu/quantization/dynamic_quant_matmul_fp8.cc | Implements the CPU kernel, including PrePack quantization of constant B and MLAS dispatch. |
| onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc | Registers the CPU kernel when float8 types are enabled. |
| cmake/onnxruntime_mlas.cmake | Wires qgemm_fp8.cpp into the MLAS build. |
…es implemented Signed-off-by: melkap01 <melike.kaptan@arm.com>
Description
Add CPU DynamicQuantMatMulFp8 contrib op with MLAS FP8 fallback
This PR adds a CPU contrib implementation for com.microsoft::DynamicQuantMatMulFp8. The path supports dynamic quantization of float/float16/bfloat16 activations to FP8, B weights that are either FP8 or constant initializers quantized to FP8 once during PrePack, block-wise scales, configurable block sizes, and float/float16/bfloat16 outputs.
Main Changes
Adds DynamicQuantMatMulFp8 schema under the Microsoft contrib opset.
Registers the CPU contrib kernel when FP8 types are enabled.
Adds dynamic_quant_matmul_fp8.{h,cc} CPU opkernel implementation.
Adds MlasFp8GemmBatch and its scalar qgemm_fp8.cpp fallback implementation, which performs the FP8 GEMM compute path used by the DynamicQuantMatMulFp8 CPU kernel.
Wires the MLAS FP8 source into the MLAS build.
Adds provider tests for the FP8 op-kernel path.
Operator Contract
A supports float, float16, and bfloat16.
Runtime B supports FP8 only.
Constant initializer B supports float, float16, bfloat16, or FP8.
Non-FP8 constant B is quantized once during PrePack.
Dynamic non-FP8 B is intentionally rejected.
Output Y supports float, float16, and bfloat16.
Scale tensors support float, float16, and bfloat16.
FP8 formats supported:
FLOAT8E4M3FN
FLOAT8E4M3FNUZ
FLOAT8E5M2
FLOAT8E5M2FNUZ
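As an illustration of what decoding one of these formats involves, here is a minimal sketch for FLOAT8E4M3FN only (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities, NaN at the all-ones exponent and mantissa pattern). `DecodeE4M3FN` and `DecodeE4M3FNBuffer` are hypothetical names chosen for this example; they are not the functions in qgemm_fp8.cpp.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>

// Hypothetical helper: decode one FLOAT8E4M3FN byte to float.
inline float DecodeE4M3FN(uint8_t v) {
  const bool negative = (v & 0x80) != 0;
  const int exponent = (v >> 3) & 0x0F;
  const int mantissa = v & 0x07;

  // E4M3FN has no infinities; S.1111.111 encodes NaN.
  if (exponent == 0x0F && mantissa == 0x07) {
    return std::numeric_limits<float>::quiet_NaN();
  }

  float magnitude;
  if (exponent == 0) {
    // Subnormal: (mantissa / 8) * 2^-6 == mantissa * 2^-9.
    magnitude = std::ldexp(static_cast<float>(mantissa), -9);
  } else {
    // Normal: (1 + mantissa / 8) * 2^(exponent - 7).
    magnitude = std::ldexp(1.0f + static_cast<float>(mantissa) / 8.0f, exponent - 7);
  }
  return negative ? -magnitude : magnitude;
}

// Hypothetical helper: decode a contiguous FP8 buffer into floats.
inline void DecodeE4M3FNBuffer(const uint8_t* src, size_t count, float* dst) {
  assert(count == 0 || (src != nullptr && dst != nullptr));
  for (size_t i = 0; i < count; ++i) {
    dst[i] = DecodeE4M3FN(src[i]);
  }
}
```

The other three formats differ in exponent bias, mantissa width, and how NaN and infinity are encoded, so each would need its own decoder.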
Quantization Semantics
The implementation enforces symmetric quantization.
All A/B/Y zero-point inputs, when provided, must encode 0.0.
Non-zero zero points are rejected.
Scale values are validated as finite and positive before use.
Y_scale and Y_zero_point are optional schema inputs.
Y_scale, when provided, must be scalar and is applied to the final accumulation.
Y_zero_point, when provided, must be scalar and zero-valued.
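The checks above can be sketched as follows. This is a minimal sketch that assumes scales and zero points have already been converted to float; `IsValidScale`, `IsSymmetricZeroPoint`, and `AllScalesValid` are hypothetical names, not the kernel's own helpers.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// A scale is usable only if it is finite and strictly positive.
inline bool IsValidScale(float scale) {
  return std::isfinite(scale) && scale > 0.0f;
}

// Symmetric quantization: the zero point must encode 0.0.
// Note: -0.0f compares equal to 0.0f here; whether the kernel accepts
// a negative-zero encoding is not stated in this PR.
inline bool IsSymmetricZeroPoint(float zero_point) {
  return zero_point == 0.0f;
}

// Validate a whole scale tensor before any value is used in the GEMM.
inline bool AllScalesValid(const float* scales, size_t count) {
  assert(count == 0 || scales != nullptr);
  for (size_t i = 0; i < count; ++i) {
    if (!IsValidScale(scales[i])) {
      return false;
    }
  }
  return true;
}
```

Validating every scale up front means a bad value is rejected before it can turn into NaN or infinity in the accumulated output.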
Block Layout
Adds block_size_m, block_size_k, and block_size_n attributes, all defaulting to 128.
A scale / zero-point layout is validated against ceil(M / block_size_m) and K / block_size_k.
B scale / zero-point layout is validated against K / block_size_k and N / block_size_n.
Rank-2 A scale / zero-point tensors are allowed as shared tensors across GEMMs.
Batched A scale / zero-point tensors must match the output GEMM batch layout currently supported by the kernel.
Shape inference was tightened to match runtime behavior and avoid accepting unsupported broadcasted A-scale layouts.
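The expected block counts can be computed as in the sketch below. It assumes that K must be an exact multiple of block_size_k and N an exact multiple of block_size_n (the description writes these as plain divisions, and only M uses ceil); that requirement is an assumption of this example, as are the names `BlockDims`, `CeilDiv`, `ExpectedAScaleDims`, and `ExpectedBScaleDims`.

```cpp
#include <cassert>
#include <cstddef>

struct BlockDims {
  size_t rows;
  size_t cols;
};

inline size_t CeilDiv(size_t value, size_t block) {
  assert(block > 0);
  return value / block + (value % block != 0 ? 1 : 0);
}

// Expected trailing dims of the A scale tensor:
// {ceil(M / block_size_m), K / block_size_k}.
// Assumption: K must divide evenly; returns false otherwise.
inline bool ExpectedAScaleDims(size_t M, size_t K, size_t block_m, size_t block_k,
                               BlockDims* out) {
  if (out == nullptr || block_m == 0 || block_k == 0 || K % block_k != 0) {
    return false;
  }
  out->rows = CeilDiv(M, block_m);  // a partial last M block still needs a scale
  out->cols = K / block_k;
  return true;
}

// Expected dims of the B scale tensor: {K / block_size_k, N / block_size_n}.
// Assumption: both K and N must divide evenly; returns false otherwise.
inline bool ExpectedBScaleDims(size_t K, size_t N, size_t block_k, size_t block_n,
                               BlockDims* out) {
  if (out == nullptr || block_k == 0 || block_n == 0 ||
      K % block_k != 0 || N % block_n != 0) {
    return false;
  }
  out->rows = K / block_k;
  out->cols = N / block_n;
  return true;
}
```

With the default block size of 128, an A of shape 130 x 256 would need a 2 x 2 scale tensor under these assumptions, because the last two rows form a partial M block.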
Kernel Behavior
Runtime FP8 B is consumed directly.
Constant non-FP8 B is converted to FP8 in PrePack.
Prepacked B metadata restores B shape, FP8 type, and packed buffer size for shared prepack reuse.
FP8 type consistency is validated across A/B and B/B-zero-point.
Runtime B rank is restricted to 2D for the non-prepacked path.
K == 0 produces zero-filled output instead of returning uninitialized data.
M == 0 and N == 0 empty outputs return cleanly.
MLAS FP8 Fallback
Adds MlasFp8GemmBatch / MlasFp8Gemm API.
Implements FP8 decode, scale application, float accumulation, optional output scaling, and output zero-point handling.
Supports all four FP8 modes listed above.
Parallelizes fallback work over BatchN * M.
Adds defensive validation before threaded execution:
valid FP8 mode
non-zero block sizes
required pointers only when actually dereferenced
leading dimensions only when used
strided offset overflow checks
block scale offset overflow checks
public block-count validation against shape-derived block counts
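The offset overflow checks in this list rely on an overflow-safe size_t multiply. A minimal sketch of such a check follows; `MultiplyOverflowsSizeT` is a hypothetical stand-in written for this example, and its signature is not claimed to match the MlasMultiplyOverflowsSizeT helper added in mlasi.h.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>

// Hypothetical stand-in: returns true if a * b does not fit in size_t.
// On success the product is written to *result and false is returned.
inline bool MultiplyOverflowsSizeT(size_t a, size_t b, size_t* result) {
  assert(result != nullptr);
  if (a != 0 && b > std::numeric_limits<size_t>::max() / a) {
    return true;  // product would wrap around
  }
  *result = a * b;
  return false;
}

// Example use: compute a strided element offset, refusing to wrap.
inline bool StridedOffset(size_t index, size_t stride, size_t* offset) {
  return !MultiplyOverflowsSizeT(index, stride, offset);
}
```

Checking with a division before multiplying avoids depending on wraparound behavior after the fact.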
This is a functional scalar fallback, not a hardware-optimized FP8 GEMM backend.
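To make the fallback's compute path concrete, here is a single-threaded reference sketch of a block-scaled FP8 GEMM: decode, apply the A and B block scales, and accumulate in float. Everything specific in it is an assumption for illustration: FLOAT8E4M3FN only, row-major A (M x K) and B (K x N), row-major scale tensors indexed by block, and the names `DecodeE4M3FN` and `BlockScaledGemmRef`. It omits Y_scale, zero points, batching, and threading, and it is not the code in qgemm_fp8.cpp.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical helper: decode one FLOAT8E4M3FN byte (bias 7, NaN at S.1111.111).
inline float DecodeE4M3FN(uint8_t v) {
  const int exponent = (v >> 3) & 0x0F;
  const int mantissa = v & 0x07;
  if (exponent == 0x0F && mantissa == 0x07) {
    return std::numeric_limits<float>::quiet_NaN();
  }
  const float magnitude =
      exponent == 0
          ? std::ldexp(static_cast<float>(mantissa), -9)
          : std::ldexp(1.0f + static_cast<float>(mantissa) / 8.0f, exponent - 7);
  return (v & 0x80) != 0 ? -magnitude : magnitude;
}

// Hypothetical reference: Y[m,n] = sum_k (A[m,k] * sa) * (B[k,n] * sb),
// where sa = a_scale[m / bm][k / bk] and sb = b_scale[k / bk][n / bn].
inline std::vector<float> BlockScaledGemmRef(
    const std::vector<uint8_t>& a, const std::vector<uint8_t>& b,
    const std::vector<float>& a_scale, const std::vector<float>& b_scale,
    size_t M, size_t N, size_t K, size_t bm, size_t bk, size_t bn) {
  assert(bm > 0 && bk > 0 && bn > 0);
  assert(a.size() == M * K && b.size() == K * N);
  const size_t k_blocks = (K + bk - 1) / bk;
  const size_t n_blocks = (N + bn - 1) / bn;

  // Zero-initialized, so K == 0 yields an all-zero output.
  std::vector<float> y(M * N, 0.0f);
  for (size_t m = 0; m < M; ++m) {
    for (size_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (size_t k = 0; k < K; ++k) {
        const float sa = a_scale[(m / bm) * k_blocks + k / bk];
        const float sb = b_scale[(k / bk) * n_blocks + n / bn];
        acc += (DecodeE4M3FN(a[m * K + k]) * sa) * (DecodeE4M3FN(b[k * N + n]) * sb);
      }
      y[m * N + n] = acc;
    }
  }
  return y;
}
```

For example, with A = {1.0, 2.0} (bytes 0x38, 0x40), B = {1.0, 1.0}, an A scale of 0.5, and a B scale of 2.0, the single output element is 1.0 + 2.0 = 3.0. A production version would also guard the `M * K`, `K * N`, and `M * N` products against overflow before indexing.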
Tests
Provider tests cover:
Constant non-FP8 B prepack path.
Runtime FP8 B path.
Omitted optional output quantization inputs.
Optional Y_scale.
Float16 and bfloat16 outputs.
Bfloat16 scale tensors.
Symmetric zero-point rejection for A/B/Y.
FP8 B / B-zero-point type mismatch rejection.
Non-default block sizes and partial M blocks.
Shared prepacked B metadata restore.
Shared prepack semantic correctness with different B scales.
Rejection of unsupported dynamic non-FP8 B.
Batched A-scale layout rejection.
Malformed A/B scale shape validation before scale reads.
M == 0, N == 0, and K == 0 edge cases.
Invalid Y_scale shape, value, and type on the K == 0 path.
Known Limitations
No dynamic non-FP8 B support by design.
No packed-B optimized FP8 backend is exposed yet.
No KleidiAI FP8 dispatch is included in this path.
MLAS FP8 GEMM is currently a correctness-oriented scalar fallback, not a production performance kernel.
Full MatMul broadcast semantics for batched A scale tensors are not implemented; schema/runtime validation is tightened to the currently supported layout.
Verification
Built onnxruntime_provider_test.
Built onnxruntime_mlas_test.
Result: Passed.
Ran on a Qwen3 model (converted to ONNX).
Result: All DynamicQuantMatMulFp8 tests passed.
Motivation and Context