Add CPU DynamicQuantMatMulFp8 contrib op #28688
Open
melkap01-Arm wants to merge 5 commits into
Signed-off-by: melkap01 <melike.kaptan@arm.com>
…port_on_contribOps
| * <a href="#com.microsoft.DequantizeWithOrder">com.microsoft.DequantizeWithOrder</a> | ||
| * <a href="#com.microsoft.DynamicQuantMatMulFp8">com.microsoft.DynamicQuantMatMulFp8</a> | ||
| * <a href="#com.microsoft.DynamicQuantizeLSTM">com.microsoft.DynamicQuantizeLSTM</a> | ||
| * <a href="#com.microsoft.DynamicQuantizeMatMul">com.microsoft.DynamicQuantizeMatMul</a> |
Contributor
There already exists a DynamicQuantizeMatMul contrib op. Could we modify its op schema to support all of these FP8 changes?
| b_type_ = fp8_type_; | ||
| has_b_type_ = true; | ||
| if (K == 0) { |
Contributor
if (K == 0 || N == 0) {
  return Status::OK();
}
| size_t block_size_n, | ||
| float fp8_max_abs, | ||
| float* scales) { | ||
| // Reference-style dynamic quantization: derive one positive scale from each source block. |
Contributor
Can we use concurrency::ThreadPool::TryParallelFor here instead of many nested for loops?
| } | ||
| // Reject invalid scales before quantization divides by them or the GEMM dequantizes with them. | ||
| Status ValidatePositiveFiniteScaleTensor(const Tensor& scale, const char* scale_name) { |
Contributor
Can we reduce duplication across the if conditions and use a templated method? Something like the following could work.
template <typename T>
Status ValidatePositiveFiniteScaleTensorImpl(const Tensor& scale, const char* scale_name) {
  const auto* data = scale.Data<T>();
  const size_t count = static_cast<size_t>(scale.Shape().Size());
  for (size_t i = 0; i < count; ++i) {
    const float value = static_cast<float>(data[i]);
    ORT_RETURN_IF(!std::isfinite(value) || value <= 0.0f,
                  "DynamicQuantMatMulFp8 requires ", scale_name,
                  " values to be finite and positive.");
  }
  return Status::OK();
}

Status ValidatePositiveFiniteScaleTensor(const Tensor& scale, const char* scale_name) {
  if (scale.IsDataType<float>()) {
    return ValidatePositiveFiniteScaleTensorImpl<float>(scale, scale_name);
  }
  if (scale.IsDataType<MLFloat16>()) {
    return ValidatePositiveFiniteScaleTensorImpl<MLFloat16>(scale, scale_name);
  }
  if (scale.IsDataType<BFloat16>()) {
    return ValidatePositiveFiniteScaleTensorImpl<BFloat16>(scale, scale_name);
  }
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                         "DynamicQuantMatMulFp8 requires ", scale_name,
                         " input to be float, float16, or bfloat16.");
}
Description
This MR adds a CPU contrib implementation for com.microsoft::DynamicQuantMatMulFp8. It keeps the FP8 GEMM path
internal to the contrib kernel for this PR, without adding a public MLAS FP8 API.
Scope
Operator Contract
A supports float, float16, and bfloat16.
Runtime B supports FP8 only and must be rank-2.
Constant initializer B supports float, float16, bfloat16, or FP8.
Non-FP8 constant B is dynamically quantized once during PrePack.
Dynamic non-FP8 B is intentionally rejected.
Y supports float, float16, and bfloat16.
Optional Y_scale and Y_zero_point are supported.
Supported FP8 formats: FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
The op enforces symmetric quantization; provided zero points must encode 0.0.
A scales are computed dynamically by the kernel.
For non-FP8 constant B, scales are computed during PrePack.
For FP8 runtime/constant B, B_scale is required and validated.
block_size_k and block_size_n default to 128.
B_scale and B_zero_point use [N / block_size_n, K / block_size_k] layout.
Runtime FP8 B is consumed directly.
Constant non-FP8 B is quantized to FP8 in PrePack.
Constant FP8 B preserves its FP8 type metadata.
Shared prepack metadata restores B shape/type, quantized B size, and B scale count.
K == 0 produces zero-filled output.
M == 0 and N == 0 return cleanly after cheap runtime validation.
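The block-scale layout in the contract can be illustrated with a minimal sketch. It is not the PR's kernel code: `ComputeBlockScales` is a hypothetical helper, B is assumed to be row-major with shape [K, N], partial trailing blocks are assumed to be allowed (ceiling division), and an all-zero block is assumed to keep a scale of 1.0 so that every scale stays positive.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper, not the PR's code: derives one positive scale per
// (block_size_k x block_size_n) block of a row-major B with shape [K, N].
// Scales are stored as [n_blocks, k_blocks], mirroring the
// [N / block_size_n, K / block_size_k] layout listed in the contract.
std::vector<float> ComputeBlockScales(const std::vector<float>& b,
                                      size_t K, size_t N,
                                      size_t block_size_k, size_t block_size_n,
                                      float fp8_max_abs) {
  assert(block_size_k > 0 && block_size_n > 0 && fp8_max_abs > 0.0f);
  assert(b.size() == K * N);
  // Assumption: partial trailing blocks are allowed, hence ceiling division.
  const size_t k_blocks = (K + block_size_k - 1) / block_size_k;
  const size_t n_blocks = (N + block_size_n - 1) / block_size_n;
  std::vector<float> scales(n_blocks * k_blocks, 1.0f);
  for (size_t nb = 0; nb < n_blocks; ++nb) {
    for (size_t kb = 0; kb < k_blocks; ++kb) {
      const size_t k_end = std::min(K, (kb + 1) * block_size_k);
      const size_t n_end = std::min(N, (nb + 1) * block_size_n);
      float max_abs = 0.0f;
      for (size_t k = kb * block_size_k; k < k_end; ++k) {
        for (size_t n = nb * block_size_n; n < n_end; ++n) {
          max_abs = std::max(max_abs, std::fabs(b[k * N + n]));
        }
      }
      // Assumption: an all-zero block keeps scale 1.0 so the scale stays positive.
      if (max_abs > 0.0f) {
        scales[nb * k_blocks + kb] = max_abs / fp8_max_abs;
      }
    }
  }
  return scales;
}
```

With the default block sizes of 128, a [256, 512] B would yield a [4, 2] scale tensor under these assumptions.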
Tests
Provider tests cover
Ran the FP8-converted Qwen3 ONNX model successfully.
Known Limitations
Comments regarding the review points submitted on the main MR: #28416
Item 4 - Constant-input validation
Zero-point inputs are still part of the operator contract because they currently also carry the FP8 zero-point
type/encoding. The op only supports symmetric quantization, so any provided zero point must encode 0.0; non-zero
zero points are rejected.
Until a better contract is suggested, such as moving the quantization/FP8 encoding choice into a separate
attribute or argument, we prefer to keep zero points as explicit inputs. That keeps the current model format
flexible while we validate what scheme best matches the models we expect to support.
To reduce per-run cost, constant B_scale and B_zero_point validation is handled during PrePack when those inputs
are initializers. Runtime validation is kept only for dynamic inputs, where values can change between runs.
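As a rough illustration of what "must encode 0.0" means at the byte level, the sketch below checks a raw FP8 zero-point byte. `Fp8Format` and `EncodesZero` are hypothetical names, not the PR's code. Treating negative zero as acceptable is an assumption; the actual kernel may be stricter.

```cpp
#include <cstdint>

// Hypothetical enum covering the four FP8 formats named in the contract.
enum class Fp8Format { E4M3FN, E4M3FNUZ, E5M2, E5M2FNUZ };

// Hypothetical helper: returns true when the raw FP8 byte encodes zero.
bool EncodesZero(uint8_t bits, Fp8Format format) {
  if (bits == 0x00) {
    return true;  // +0.0 has an all-zero bit pattern in all four formats.
  }
  if (bits == 0x80) {
    // 0x80 is -0.0 in E4M3FN and E5M2, but NaN in the FNUZ formats,
    // which have no negative zero.
    // Assumption: -0.0 is accepted as a valid symmetric zero point.
    return format == Fp8Format::E4M3FN || format == Fp8Format::E5M2;
  }
  return false;
}
```

A check of this shape would run once in PrePack for initializer zero points and on every run only for dynamic ones, matching the split described above.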
Item 6 - Temporary allocations
The current implementation keeps the operator contract flexible, so temporary buffers are used only on the paths
that need them. A is dynamically quantized at runtime, so the kernel needs temporary FP8 A data and computed A
scales. Lower-precision outputs use a float scratch buffer because accumulation is done in float and then
converted to the requested output type. B_scale conversion is only needed when model-provided runtime/FP8 B_scale is not already float; for non-FP8 constant B, scales are computed and stored during PrePack.
This does not mean every allocation happens on every execution path. Some allocations are required for the main
dynamic-activation path, while others only happen for specific cases such as non-float scale tensors or
lower-precision outputs. This PR keeps the implementation correctness-focused while preserving the flexible contract. If a more constrained contract is suggested, such as not supporting a runtime-provided B_scale or B tensor, those allocations would be reduced as well.
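The float scratch buffer for lower-precision outputs can be sketched as follows. This is a simplified stand-in, not the kernel: it uses plain float inputs in place of dequantized FP8 values, `FloatToBf16Bits` and `MatVecToBf16` are hypothetical helpers, and the rounding mode (round-to-nearest-even) is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical stand-in for a bfloat16 conversion: keeps the upper 16 bits
// of the float, rounding to nearest even (assumed rounding mode).
uint16_t FloatToBf16Bits(float value) {
  uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  if ((bits & 0x7FFFFFFFu) > 0x7F800000u) {
    // NaN: truncate and force a quiet-NaN mantissa bit.
    return static_cast<uint16_t>((bits >> 16) | 0x0040u);
  }
  bits += 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<uint16_t>(bits >> 16);
}

// Computes one output row y = a * B, where a has K elements and B is
// row-major [K, N]. Accumulation happens in a float scratch buffer and
// the result is narrowed to bfloat16 bits only at the end.
std::vector<uint16_t> MatVecToBf16(const std::vector<float>& a,
                                   const std::vector<float>& b,
                                   size_t K, size_t N) {
  assert(a.size() == K && b.size() == K * N);
  std::vector<float> scratch(N, 0.0f);  // temporary float accumulator
  for (size_t k = 0; k < K; ++k) {
    for (size_t n = 0; n < N; ++n) {
      scratch[n] += a[k] * b[k * N + n];
    }
  }
  std::vector<uint16_t> y(N);
  for (size_t n = 0; n < N; ++n) {
    y[n] = FloatToBf16Bits(scratch[n]);
  }
  return y;
}
```

When the requested output type is already float, the scratch buffer is unnecessary because the accumulator can be the output itself. With K == 0 the loop body never runs and the output is zero-filled, consistent with the contract.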
Motivation and Context