From a45d09382d989829403ca5473eb40d21e2ff4472 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Fri, 15 May 2026 16:30:20 -0700 Subject: [PATCH 1/5] unblock the GatherBlockQuantized for 2 bits support --- .../quantization/gather_block_quantized.cc | 32 ++++++++++--- .../core/graph/contrib_ops/contrib_defs.cc | 9 ++-- .../gather_block_quantized_op_test.cc | 48 +++++++++++++++++++ 3 files changed, 79 insertions(+), 10 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc b/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc index 3a20a41696728..97ca8ba882826 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc @@ -30,6 +30,14 @@ int32_t Get4BitElement(const uint8_t* data_ptr, int64_t data_idx) { return data_val; } +// Extracts a 2-bit element from uint8_t storage. Four elements are packed per byte, +// with element index 0 in the lowest 2 bits and index 3 in the highest 2 bits. +int32_t Get2BitElementUint8(const uint8_t* data_ptr, int64_t data_idx) { + const uint8_t data_val_u8 = data_ptr[data_idx >> 2]; + const int shift = static_cast((data_idx & 3) * 2); + return static_cast((data_val_u8 >> shift) & 0x03); +} + } // namespace template @@ -53,7 +61,12 @@ class GatherBlockQuantized : public OpKernel { constexpr int64_t default_bits = 4; info.GetAttrOrDefault("bits", &bits_, default_bits); - ORT_ENFORCE(bits_ == 4 || bits_ == 8, "GatherBlockQuantized only support bits==4 or 8"); + if constexpr (std::is_same_v) { + ORT_ENFORCE(bits_ == 2 || bits_ == 4 || bits_ == 8, + "GatherBlockQuantized with uint8 data only supports bits==2, 4, or 8"); + } else { + ORT_ENFORCE(bits_ == 4, "GatherBlockQuantized with int4/uint4 data only supports bits==4"); + } ORT_ENFORCE(block_size_ >= 16 && ((block_size_ - 1) & block_size_) == 0, "'block_size' must be 2's power and not less than 16."); @@ -221,10 +234,12 @@ Status GatherBlockQuantized::CopyDataAndDequantize(const T1* data_ptr, int32_t data_val; if constexpr (!std::is_same_v) { data_val = Get4BitElement(data_ptr, data_idx); - } else { // unit8_t - if (bits_ == 4) { + } else { // uint8_t + if (bits_ == 2) { + data_val = Get2BitElementUint8(data_ptr, data_idx); + } else if (bits_ == 4) { data_val = Get4BitElement(data_ptr, data_idx); - } else { // buts_ == 8 + } else { // bits_ == 8 data_val = static_cast(data_ptr[data_idx]); } } @@ -238,7 +253,11 @@ Status GatherBlockQuantized::CopyDataAndDequantize(const T1* data_ptr, if constexpr (std::is_same_v) { if (zero_points_ptr) { - if (bits_ == 4) { + if (bits_ == 2) { + uint8_t packed = zero_points_ptr[scale_idx >> 2]; + const int shift = static_cast((scale_idx & 3) * 2); + zp_val = static_cast((packed >> shift) & 0x03); + } else if (bits_ == 4) { uint8_t packed = zero_points_ptr[scale_idx >> 1]; if (scale_idx & 1) { zp_val = static_cast((packed >> 4) & 0x0F); @@ -249,7 +268,8 @@ Status GatherBlockQuantized::CopyDataAndDequantize(const T1* data_ptr, zp_val = static_cast(zero_points_ptr[scale_idx]); } } else { - const int32_t default_zero_point = bits_ == 4 ? 8 : 128; + // Default zero point is 2^(bits-1): 2 for 2-bit, 8 for 4-bit, 128 for 8-bit. + const int32_t default_zero_point = 1 << (static_cast(bits_) - 1); zp_val = default_zero_point; } } else { diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index a5537c7d58b05..b9773e9aa3f96 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3692,7 +3692,8 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h 3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used to dequantize the output. 4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type. - 5. For uint8 data, the `gather_axis` must be 0. + 5. For uint8 data, the `gather_axis` must be 0. The supported `bits` values for uint8 data are 2, 4, and 8; + for `bits` < 8 the values are packed along the last dimension (low-order bits first). )DOC"; ONNX_CONTRIB_OPERATOR_SCHEMA(GatherBlockQuantized) @@ -3712,7 +3713,7 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h AttributeProto::INT, static_cast(128)) .Attr("bits", - "Number of bits used for weight quantization. Must be either 4 or 8. ", + "Number of bits used for weight quantization. Must be 2, 4, or 8 . ", AttributeProto::INT, static_cast(4)) .Input(0, "data", "Tensor of rank r >= 1. Block-wise quantized.", "T1") @@ -3799,9 +3800,9 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h if (!zp_shape.dim(i).has_dim_value() || zp_shape.dim(i).dim_value() != scales_shape.dim(i).dim_value()) { if (ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8 && - bits == 4 && + components > 1 && i == quantize_axis && - zp_shape.dim(i).dim_value() == (scales_shape.dim(i).dim_value() + 1) / 2) { + zp_shape.dim(i).dim_value() == (scales_shape.dim(i).dim_value() + components - 1) / components) { continue; } fail_shape_inference("zero points shape and scales shape do not match"); diff --git a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc index 6fea7a43712c7..48457282aab89 100644 --- a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc +++ b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc @@ -47,6 +47,24 @@ void PackDataForUint8TypeIfNecessary(std::vector& data, std::vector(8); Test_GatherAxis0_NoZeroPoints(8); } + +TEST(GatherBlockQuantizedOpTest, GatherAxis0NoZeroPoints_2Bits_Uint8) { + // 2-bit signed values in {-2, -1, 0, 1}. The test infra adds an offset of 2 when packing + // and the kernel uses default zero_point = 2^(bits-1) = 2, so the dequantized value matches. + // Block size 16 covers the entire last dim with one scale per row. + std::vector data = {-2, -1, 0, 1, -2, -1, 0, 1, -2, -1, 0, 1, -2, -1, 0, 1, + 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2, + 0, 1, -2, -1, 0, 1, -2, -1, 0, 1, -2, -1, 0, 1, -2, -1, + -1, -2, 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2}; + std::vector data_shape = {2, 3, 16}; + std::vector indices = {1}; + std::vector indices_shape = {1}; + std::vector scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}; + std::vector scales_shape = {2, 3, 1}; + + // indices = [1] -> pick outer index 1, so we expect rows 3, 4, 5 of the unpacked data above + // (the second {-1,-2,1,0,...}, {1,1,...}, {-2,-2,...} block), each scaled by + // scales[3], scales[4], scales[5] = 2.0, 1.0, 2.0. + std::vector output = {-2.f, -4.f, 2.f, 0.f, -2.f, -4.f, 2.f, 0.f, -2.f, -4.f, 2.f, 0.f, -2.f, -4.f, 2.f, 0.f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, + -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f}; + std::vector output_shape = {1, 3, 16}; + + std::vector zero_points = {}; + RunUnpackedData(data, data_shape, indices, indices_shape, scales, scales_shape, + zero_points, /*gather_axis=*/0, /*quantize_axis=*/2, + /*block_size=*/16, /*bits=*/2, output, output_shape, true); +} #endif template From 1a3af6b001bacde8137e03677eea2499bb283609 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Fri, 15 May 2026 17:40:26 -0700 Subject: [PATCH 2/5] fix the webgpu --- .../quantization/gather_block_quantized.cc | 74 ++++++++++-- .../quantization/gather_block_quantized.h | 2 +- .../core/graph/contrib_ops/contrib_defs.cc | 2 +- .../gather_block_quantized_op_test.cc | 105 ++++++++++++++++++ 4 files changed, 172 insertions(+), 11 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc b/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc index 62bfbcbbf9f5a..32b447c2cb128 100755 --- a/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc @@ -21,7 +21,8 @@ Status GatherBlockQuantizedProgram::GenerateShaderCode(ShaderHelper& shader) con const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseValueTypeAlias); - bool is_4bit = bits_ == 4; + const bool is_2bit = bits_ == 2; + const bool is_4bit = bits_ == 4; const std::string unpack = (is_signed_) ? "unpack4xI8" : "unpack4xU8"; shader.MainFunctionBody() @@ -57,7 +58,18 @@ Status GatherBlockQuantizedProgram::GenerateShaderCode(ShaderHelper& shader) con shader.MainFunctionBody() << " let data_offset = " << x_shape.IndicesToOffset("data_indices") << ";\n"; - if (is_4bit) { + if (is_2bit) { + // 2-bit values are packed 4 per byte (LSB first). x is the original uint8 tensor with + // Flatten=4 (4 bytes per u32); the input_shape uniform here is the *dequantized* shape, + // so data_offset is the dequantized 2-bit-element index. + shader.MainFunctionBody() + << " let byte_idx_2b = data_offset / 4;\n" + << " let bit_shift_2b = (data_offset % 4) * 2;\n" + << " let packed_word_2b = " << x.GetByOffset("byte_idx_2b / 4") << ";\n" + << " let byte_in_word_2b = byte_idx_2b % 4;\n" + << " let unpacked_bytes_2b = " << unpack << "(u32(packed_word_2b));\n" + << " var quantized_data = (unpacked_bytes_2b[byte_in_word_2b] >> bit_shift_2b) & 0x3;\n"; + } else if (is_4bit) { shader.MainFunctionBody() << " let data_index = data_offset % 8;\n" << " let packed_4bit_quantized_data = " << x.GetByOffset("data_offset / 8") << ";\n" @@ -83,7 +95,18 @@ Status GatherBlockQuantizedProgram::GenerateShaderCode(ShaderHelper& shader) con << " var scale = " << scales.GetByIndices("scale_indices") << ";\n"; if (!has_zeropoint_) { - const std::string default_zero_point = is_uint8_ ? is_4bit ? "input_element_t(8)" : "input_element_t(128)" : "input_element_t(0)"; + std::string default_zero_point; + if (is_uint8_) { + if (is_2bit) { + default_zero_point = "input_element_t(2)"; + } else if (is_4bit) { + default_zero_point = "input_element_t(8)"; + } else { + default_zero_point = "input_element_t(128)"; + } + } else { + default_zero_point = "input_element_t(0)"; + } shader.MainFunctionBody() << " let zero_point = " << default_zero_point << ";\n"; } else { @@ -91,7 +114,15 @@ Status GatherBlockQuantizedProgram::GenerateShaderCode(ShaderHelper& shader) con shader.MainFunctionBody() << " let zero_point_indices = scale_indices;\n" << " let zero_point_offset = " << scales.IndicesToOffset("zero_point_indices") << ";\n"; - if (is_4bit) { + if (is_2bit) { + shader.MainFunctionBody() + << " let zp_byte_idx_2b = zero_point_offset / 4;\n" + << " let zp_bit_shift_2b = (zero_point_offset % 4) * 2;\n" + << " let packed_zp_word_2b = " << zero_point.GetByOffset("zp_byte_idx_2b / 4") << ";\n" + << " let zp_byte_in_word_2b = zp_byte_idx_2b % 4;\n" + << " let zp_unpacked_2b = " << unpack << "(u32(packed_zp_word_2b));\n" + << " var zero_point = (zp_unpacked_2b[zp_byte_in_word_2b] >> zp_bit_shift_2b) & 0x3;\n"; + } else if (is_4bit) { shader.MainFunctionBody() << " let zero_point_index = zero_point_offset % 8;\n" << " let packed_4bit_zero_points = " << zero_point.GetByOffset("zero_point_offset / 8") << ";\n" @@ -174,7 +205,19 @@ Status GatherBlockQuantized::ComputeInternal(ComputeContext& context) const { zero_points = zero_points_representation_4bit.has_value() ? &zero_points_representation_4bit.value() : zero_points; } - const auto& x_shape = x->Shape(); + const auto& x_shape_intrinsic = x->Shape(); + // For bits == 2 with uint8 storage we don't construct a packed-type reinterpret (no UInt2x4 type + // exists). Instead, build a logical "dequantized" shape (last dim x4) and feed that to the shader + // as the input_shape uniform. The buffer remains the original uint8 storage with Flatten=4, and + // the shader does explicit byte+bit-position extraction. + TensorShape x_shape; + if (bits_ == 2 && is_int8) { + TensorShapeVector v = x_shape_intrinsic.AsShapeVector(); + v.back() *= 4; + x_shape = TensorShape(std::move(v)); + } else { + x_shape = x_shape_intrinsic; + } size_t indices_rank = indices->Shape().NumDimensions(); const auto scales_shape = scales->Shape(); @@ -208,12 +251,25 @@ Status GatherBlockQuantized::ComputeInternal(ComputeContext& context) const { .AddUniformVariables({{static_cast(quantize_axis)}}) .AddUniformVariables({{static_cast(gather_axis)}}) .AddUniformVariables({{static_cast(block_size_)}}) - .CacheHint(std::to_string(gather_axis), std::to_string(quantize_axis), std::to_string(block_size_)); + .CacheHint(std::to_string(bits_), std::to_string(gather_axis), std::to_string(quantize_axis), std::to_string(block_size_)); if (zero_points != nullptr) { - ORT_RETURN_IF_NOT(scales_shape == zero_points->Shape(), - "scales and zero_points must have the same shape."); - auto zero_points_shape = zero_points->Shape(); + if (bits_ == 2 && is_int8) { + // 2-bit zero points are packed 4 per byte along the quantize axis. + const auto& zp_shape = zero_points->Shape(); + ORT_RETURN_IF_NOT(zp_shape.NumDimensions() == scales_shape.NumDimensions(), + "scales and zero_points must have the same rank."); + for (size_t i = 0; i < scales_shape.NumDimensions(); ++i) { + int64_t expected = (i == static_cast(quantize_axis)) + ? (scales_shape[i] + 3) / 4 + : scales_shape[i]; + ORT_RETURN_IF_NOT(zp_shape[i] == expected, + "zero_points shape does not match expected packed shape for 2-bit data."); + } + } else { + ORT_RETURN_IF_NOT(scales_shape == zero_points->Shape(), + "scales and zero_points must have the same shape."); + } program.AddInputs({{zero_points, ProgramTensorMetadataDependency::None, ProgramInput::Flatten, (bits_ == 4) ? 8 : 4}}); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.h b/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.h index cd7392995f4cf..0a70307f4e9b2 100755 --- a/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.h @@ -52,7 +52,7 @@ class GatherBlockQuantized final : public WebGpuKernel { quantize_axis_ = static_cast(info.GetAttrOrDefault("quantize_axis", 1)); bits_ = static_cast(info.GetAttrOrDefault("bits", 4)); - ORT_ENFORCE(bits_ == 4 || bits_ == 8, "'bits' must be 4 or 8."); + ORT_ENFORCE(bits_ == 2 || bits_ == 4 || bits_ == 8, "'bits' must be 2, 4 or 8."); ORT_ENFORCE(block_size_ >= 16 && ((block_size_ - 1) & block_size_) == 0, "'block_size' must be 2's power and not less than 16."); } diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index b9773e9aa3f96..df195dfea9b66 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3713,7 +3713,7 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h AttributeProto::INT, static_cast(128)) .Attr("bits", - "Number of bits used for weight quantization. Must be 2, 4, or 8 . ", + "Number of bits used for weight quantization. Must be 2, 4 or 8 . ", AttributeProto::INT, static_cast(4)) .Input(0, "data", "Tensor of rank r >= 1. Block-wise quantized.", "T1") diff --git a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc index 48457282aab89..4421245439447 100644 --- a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc +++ b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc @@ -162,6 +162,64 @@ void RunGatherBlockQuantized(const std::vector& data, run_test(true); } +// WebGPU-specific runner for GatherBlockQuantized. Only supports uint8 data with gather_axis == 0. +template +void RunGatherBlockQuantizedWebGpu(const std::vector& data, + const std::vector& data_shape, + const std::vector& indices, + const std::vector& indices_shape, + const std::vector& scales, + const std::vector& scales_shape, + const std::vector& zero_points, + const std::vector& zero_points_shape, + const int64_t gather_axis, + const int64_t quantize_axis, + const int64_t block_size, + const int64_t bits, + const std::vector& output, + const std::vector& output_shape, + OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess) { +#ifdef USE_WEBGPU + if (DefaultWebGpuExecutionProvider().get() == nullptr) { + return; + } + + OpTester test("GatherBlockQuantized", 1, kMSDomain); + test.AddAttribute("gather_axis", gather_axis); + test.AddAttribute("quantize_axis", quantize_axis); + test.AddAttribute("block_size", block_size); + test.AddAttribute("bits", bits); + + test.AddInput("data", data_shape, data); + test.AddInput("indices", indices_shape, indices); + test.AddInput("scales", scales_shape, scales); + if (!zero_points.empty()) { + test.AddInput("zero_points", zero_points_shape, zero_points); + } + test.AddOutput("output", output_shape, output); + + std::vector> eps; + eps.push_back(DefaultWebGpuExecutionProvider()); + test.Run(expect_result, "", {}, nullptr, &eps); +#else + (void)data; + (void)data_shape; + (void)indices; + (void)indices_shape; + (void)scales; + (void)scales_shape; + (void)zero_points; + (void)zero_points_shape; + (void)gather_axis; + (void)quantize_axis; + (void)block_size; + (void)bits; + (void)output; + (void)output_shape; + (void)expect_result; +#endif +} + template typename std::enable_if< (boost::mp11::mp_contains, T1>::value && std::is_same::value) || @@ -621,6 +679,53 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis0NoZeroPoints_2Bits_Uint8) { } #endif +#ifdef USE_WEBGPU +TEST(GatherBlockQuantizedOpTest, WebGpu_GatherAxis0NoZeroPoints_2Bits_Uint8) { + // Same logical data and expectation as the CPU GatherAxis0NoZeroPoints_2Bits_Uint8 test. + // Logical 2-bit values in {-2, -1, 0, 1}, encoded as v+2 in {0..3} and packed 4 per byte + // (low-order bits first). 16 logical 2-bit values -> 4 bytes per row. + // Pack helper: + auto pack4 = [](int v0, int v1, int v2, int v3) -> uint8_t { + auto enc = [](int v) { return static_cast((v + 2) & 0x3); }; + return static_cast(enc(v0) | (enc(v1) << 2) | (enc(v2) << 4) | (enc(v3) << 6)); + }; + + // Build packed data: shape {2, 3, 4} (16 logical elements per row -> 4 bytes). + std::vector data; + data.reserve(2 * 3 * 4); + auto push_row = [&](std::vector row) { + ORT_ENFORCE(row.size() == 16); + for (size_t i = 0; i < 16; i += 4) { + data.push_back(pack4(row[i], row[i + 1], row[i + 2], row[i + 3])); + } + }; + push_row({-2, -1, 0, 1, -2, -1, 0, 1, -2, -1, 0, 1, -2, -1, 0, 1}); + push_row({1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2}); + push_row({0, 1, -2, -1, 0, 1, -2, -1, 0, 1, -2, -1, 0, 1, -2, -1}); + push_row({-1, -2, 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2, 1, 0}); + push_row({1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + push_row({-2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2}); + + std::vector data_shape = {2, 3, 4}; + std::vector indices = {1}; + std::vector indices_shape = {1}; + std::vector scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}; + std::vector scales_shape = {2, 3, 1}; + + std::vector output = {-2.f, -4.f, 2.f, 0.f, -2.f, -4.f, 2.f, 0.f, -2.f, -4.f, 2.f, 0.f, -2.f, -4.f, 2.f, 0.f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, + -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f}; + std::vector output_shape = {1, 3, 16}; + + std::vector zero_points = {}; + std::vector zero_points_shape = {}; + RunGatherBlockQuantizedWebGpu(data, data_shape, indices, indices_shape, scales, scales_shape, + zero_points, zero_points_shape, + /*gather_axis=*/0, /*quantize_axis=*/2, + /*block_size=*/16, /*bits=*/2, output, output_shape); +} +#endif // USE_WEBGPU + template void Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_4Bits() { // This test case specific to shared 4bit token_embedding/lm_head use case on CUDA From 517bdc14adaed422da6e49497dfc6e76b328d0a8 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Fri, 15 May 2026 21:24:03 -0700 Subject: [PATCH 3/5] update doc --- docs/ContribOperators.md | 5 +++-- onnxruntime/core/graph/contrib_ops/contrib_defs.cc | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index ca072113e832d..a13730c8c632d 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2108,7 +2108,8 @@ This version of the operator has been available since version 1 of the 'com.micr 3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used to dequantize the output. 4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type. - 5. For uint8 data, the `gather_axis` must be 0. + 5. For uint8 data, the `gather_axis` must be 0. The supported `bits` values for uint8 data are 2, 4, and 8; + for `bits` < 8 the values are packed along the last dimension (low-order bits first). #### Version @@ -2118,7 +2119,7 @@ This version of the operator has been available since version 1 of the 'com.micr
bits : int
-
Number of bits used for weight quantization. Must be either 4 or 8.
+
Number of bits used for weight quantization. Must be 2, 4 or 8.
block_size : int
(Optional) block size used for weight quantization. It needs to be a power of 2 and not smaller than 16.
gather_axis : int
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index df195dfea9b66..9370f67d8bc78 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3713,7 +3713,7 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h AttributeProto::INT, static_cast(128)) .Attr("bits", - "Number of bits used for weight quantization. Must be 2, 4 or 8 . ", + "Number of bits used for weight quantization. Must be 2, 4 or 8. ", AttributeProto::INT, static_cast(4)) .Input(0, "data", "Tensor of rank r >= 1. Block-wise quantized.", "T1") From 7c03e331bbd985c020b683e0db40521502c9b3cd Mon Sep 17 00:00:00 2001 From: Hector Li Date: Mon, 18 May 2026 11:40:57 -0700 Subject: [PATCH 4/5] address review comments --- .../quantization/gather_block_quantized.cc | 22 ++++++++--- .../quantization/gather_block_quantized.cc | 38 +++++++++++++++++-- .../quantization/gather_block_quantized.h | 5 ++- 3 files changed, 54 insertions(+), 11 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc b/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc index 97ca8ba882826..fb72db68de503 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc @@ -253,13 +253,25 @@ Status GatherBlockQuantized::CopyDataAndDequantize(const T1* data_ptr, if constexpr (std::is_same_v) { if (zero_points_ptr) { + // For uint8 we enforce quantize_axis == last dim, which makes quantize_N == 1 + // and scale_full_block == scale_qaxis_dim. Zero points are packed only along + // the quantize axis, so the packed byte must be addressed using the scale row + // index and the within-row quantize-axis index, not the flat scale_idx; the + // latter crosses row boundaries when scale_qaxis_dim is not a multiple of the + // packing factor. + const int64_t scale_qaxis_dim = scale_full_block; + const int64_t scale_row = scale_idx / scale_qaxis_dim; + const int64_t q_in_row = scale_idx % scale_qaxis_dim; if (bits_ == 2) { - uint8_t packed = zero_points_ptr[scale_idx >> 2]; - const int shift = static_cast((scale_idx & 3) * 2); - zp_val = static_cast((packed >> shift) & 0x03); + const int64_t packed_zp_qaxis_dim = (scale_qaxis_dim + 3) / 4; + const int64_t byte_idx = scale_row * packed_zp_qaxis_dim + (q_in_row >> 2); + const int shift = static_cast((q_in_row & 3) * 2); + zp_val = static_cast((zero_points_ptr[byte_idx] >> shift) & 0x03); } else if (bits_ == 4) { - uint8_t packed = zero_points_ptr[scale_idx >> 1]; - if (scale_idx & 1) { + const int64_t packed_zp_qaxis_dim = (scale_qaxis_dim + 1) / 2; + const int64_t byte_idx = scale_row * packed_zp_qaxis_dim + (q_in_row >> 1); + uint8_t packed = zero_points_ptr[byte_idx]; + if (q_in_row & 1) { zp_val = static_cast((packed >> 4) & 0x0F); } else { zp_val = static_cast(packed & 0x0F); diff --git a/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc b/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc index 32b447c2cb128..a5883ea7c45cb 100755 --- a/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc @@ -115,11 +115,18 @@ Status GatherBlockQuantizedProgram::GenerateShaderCode(ShaderHelper& shader) con << " let zero_point_indices = scale_indices;\n" << " let zero_point_offset = " << scales.IndicesToOffset("zero_point_indices") << ";\n"; if (is_2bit) { + // 2-bit zero points are packed 4-per-byte along the quantize axis only. The scales + // tensor's flat offset cannot be used directly because dividing it by 4 crosses row + // boundaries when scale_qaxis_dim is not a multiple of 4 (e.g. scales {2,3,1} has + // packed zp shape {2,3,1} with one usable 2-bit value per byte per row). Derive the + // packed byte index from the scale row index plus the within-row quantize-axis index. shader.MainFunctionBody() - << " let zp_byte_idx_2b = zero_point_offset / 4;\n" - << " let zp_bit_shift_2b = (zero_point_offset % 4) * 2;\n" - << " let packed_zp_word_2b = " << zero_point.GetByOffset("zp_byte_idx_2b / 4") << ";\n" - << " let zp_byte_in_word_2b = zp_byte_idx_2b % 4;\n" + << " let q_idx_2b = " << scales.IndicesGet("scale_indices", "uniforms.quantize_axis") << ";\n" + << " let scale_row_2b = zero_point_offset / uniforms.scale_qaxis_dim;\n" + << " let zp_byte_offset_2b = scale_row_2b * uniforms.zp_packed_qaxis_dim + q_idx_2b / 4u;\n" + << " let zp_bit_shift_2b = (q_idx_2b % 4u) * 2u;\n" + << " let packed_zp_word_2b = " << zero_point.GetByOffset("zp_byte_offset_2b / 4") << ";\n" + << " let zp_byte_in_word_2b = zp_byte_offset_2b % 4;\n" << " let zp_unpacked_2b = " << unpack << "(u32(packed_zp_word_2b));\n" << " var zero_point = (zp_unpacked_2b[zp_byte_in_word_2b] >> zp_bit_shift_2b) & 0x3;\n"; } else if (is_4bit) { @@ -173,6 +180,16 @@ Status GatherBlockQuantized::ComputeInternal(ComputeContext& context) const { int64_t x_dtype = x->GetElementType(); bool is_signed = x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 || x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4; bool is_int8 = x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 || x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; + bool is_uint8 = x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; + + // Only uint8 storage supports the full bits set {2, 4, 8}. The packed int4/uint4 types + // can only carry bits==4, matching the CPU kernel's constraint. + if (is_uint8) { + ORT_RETURN_IF_NOT(bits_ == 2 || bits_ == 4 || bits_ == 8, + "'bits' must be 2, 4 or 8 for uint8 input."); + } else { + ORT_RETURN_IF_NOT(bits_ == 4, "'bits' must be 4 for non-uint8 input."); + } std::optional data_representation_4bit; std::optional zero_points_representation_4bit; @@ -238,6 +255,17 @@ Status GatherBlockQuantized::ComputeInternal(ComputeContext& context) const { int64_t output_size = output_shape.Size(); auto* output_tensor = context.Output(0, output_shape); + // For the 2-bit zero-point path we need to address the packed byte using the scale row index + // and the within-row quantize-axis index (not the flat scales offset, which crosses row + // boundaries when scale_qaxis_dim isn't a multiple of the packing factor). To keep the shader + // simple we require quantize_axis to be the last dim for uint8 2-bit, matching the CPU kernel. + if (bits_ == 2 && is_uint8) { + ORT_RETURN_IF_NOT(quantize_axis == x_rank - 1, + "For uint8 2-bit data, quantize_axis must be the last dimension."); + } + const uint32_t scale_qaxis_dim = static_cast(scales_shape[quantize_axis]); + const uint32_t zp_packed_qaxis_dim = (scale_qaxis_dim + 3) / 4; + GatherBlockQuantizedProgram program{is_signed, is_int8, indices_rank, gather_axis, bits_, zero_points != nullptr, x_shape, output_shape}; program @@ -251,6 +279,8 @@ Status GatherBlockQuantized::ComputeInternal(ComputeContext& context) const { .AddUniformVariables({{static_cast(quantize_axis)}}) .AddUniformVariables({{static_cast(gather_axis)}}) .AddUniformVariables({{static_cast(block_size_)}}) + .AddUniformVariables({{scale_qaxis_dim}}) + .AddUniformVariables({{zp_packed_qaxis_dim}}) .CacheHint(std::to_string(bits_), std::to_string(gather_axis), std::to_string(quantize_axis), std::to_string(block_size_)); if (zero_points != nullptr) { diff --git a/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.h b/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.h index 0a70307f4e9b2..305146c715c86 100755 --- a/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.h @@ -31,7 +31,9 @@ class GatherBlockQuantizedProgram final : public Program= 16 && ((block_size_ - 1) & block_size_) == 0, "'block_size' must be 2's power and not less than 16."); } - Status ComputeInternal(ComputeContext& context) const override; private: From f25707aa94ed3eb0f80c0d8ff3b8b9663d193b37 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Wed, 20 May 2026 09:34:25 -0700 Subject: [PATCH 5/5] address review comments --- .../quantization/gather_block_quantized.cc | 4 +- .../gather_block_quantized_op_test.cc | 75 +++++++++++++++++++ 2 files changed, 77 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc b/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc index a5883ea7c45cb..8413fa4dbaed8 100755 --- a/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc @@ -228,7 +228,7 @@ Status GatherBlockQuantized::ComputeInternal(ComputeContext& context) const { // as the input_shape uniform. The buffer remains the original uint8 storage with Flatten=4, and // the shader does explicit byte+bit-position extraction. TensorShape x_shape; - if (bits_ == 2 && is_int8) { + if (bits_ == 2 && is_uint8) { TensorShapeVector v = x_shape_intrinsic.AsShapeVector(); v.back() *= 4; x_shape = TensorShape(std::move(v)); @@ -284,7 +284,7 @@ Status GatherBlockQuantized::ComputeInternal(ComputeContext& context) const { .CacheHint(std::to_string(bits_), std::to_string(gather_axis), std::to_string(quantize_axis), std::to_string(block_size_)); if (zero_points != nullptr) { - if (bits_ == 2 && is_int8) { + if (bits_ == 2 && is_uint8) { // 2-bit zero points are packed 4 per byte along the quantize axis. const auto& zp_shape = zero_points->Shape(); ORT_RETURN_IF_NOT(zp_shape.NumDimensions() == scales_shape.NumDimensions(), diff --git a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc index 4421245439447..b95238d03f508 100644 --- a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc +++ b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc @@ -677,6 +677,40 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis0NoZeroPoints_2Bits_Uint8) { zero_points, /*gather_axis=*/0, /*quantize_axis=*/2, /*block_size=*/16, /*bits=*/2, output, output_shape, true); } + +TEST(GatherBlockQuantizedOpTest, GatherAxis0WithZeroPoints_2Bits_Uint8_PackedZpNotMultipleOf4) { + // Exercises the 2-bit zero-point row-boundary logic: scale_qaxis_dim = 5 is NOT a multiple + // of the 2-bit packing factor (4), so each zero_points row occupies (5+3)/4 = 2 bytes and + // the within-row qaxis index spans both bytes (lanes 0..3 in byte 0, lane 0 in byte 1). + // The kernel must address the packed byte using (scale_row, q_in_row), not a flat scales + // offset which would cross row boundaries. + // + // Logical layout: data {2, 80}, quantize_axis = 1 (last), block_size = 16, bits = 2. + // scales {2, 5}; zero_points logical {2, 5} -> packs to {2, 2}. + std::vector data(2 * 80, 0); // all zeros; dequant simplifies to (0 - zp) * scale + std::vector data_shape = {2, 80}; + std::vector indices = {1}; + std::vector indices_shape = {1}; + // All scales = 1 so dequant value equals -zp directly. + std::vector scales(2 * 5, 1.0f); + std::vector scales_shape = {2, 5}; + // Logical 2-bit zero points in {-2,-1,0,1}; helper packs along last dim (5 -> 2 bytes). + // Row 0: [-2, 1, 0, -1, 1] ; Row 1: [1, 0, -1, -2, 0] + std::vector zero_points = {-2, 1, 0, -1, 1, 1, 0, -1, -2, 0}; + + // indices = [1] -> pick row 1: zp = [1, 0, -1, -2, 0] + // dequant per block = (0 - zp) * 1 = [-1, 0, 1, 2, 0] + std::vector output; + output.reserve(80); + for (float v : {-1.f, 0.f, 1.f, 2.f, 0.f}) { + output.insert(output.end(), 16, v); + } + std::vector output_shape = {1, 80}; + + RunUnpackedData(data, data_shape, indices, indices_shape, scales, scales_shape, + zero_points, /*gather_axis=*/0, /*quantize_axis=*/1, + /*block_size=*/16, /*bits=*/2, output, output_shape, true); +} #endif #ifdef USE_WEBGPU @@ -724,6 +758,47 @@ TEST(GatherBlockQuantizedOpTest, WebGpu_GatherAxis0NoZeroPoints_2Bits_Uint8) { /*gather_axis=*/0, /*quantize_axis=*/2, /*block_size=*/16, /*bits=*/2, output, output_shape); } + +TEST(GatherBlockQuantizedOpTest, WebGpu_GatherAxis0WithZeroPoints_2Bits_Uint8_PackedZpNotMultipleOf4) { + // WebGPU companion to GatherAxis0WithZeroPoints_2Bits_Uint8_PackedZpNotMultipleOf4. + // scale_qaxis_dim = 5 (not a multiple of 4); packed zero_points last dim = (5+3)/4 = 2 bytes. + // The within-row qaxis index spans both bytes, validating the packed-byte addressing path + // (scale_row * zp_packed_qaxis_dim + q_idx/4, shift (q_idx%4)*2) in the WebGPU shader. + auto pack4 = [](int v0, int v1, int v2, int v3) -> uint8_t { + auto enc = [](int v) { return static_cast((v + 2) & 0x3); }; + return static_cast(enc(v0) | (enc(v1) << 2) | (enc(v2) << 4) | (enc(v3) << 6)); + }; + + // Packed data: each logical row = 80 zeros -> 20 bytes of pack4(0,0,0,0) = 0xAA. data_shape {2, 20}. + std::vector data(2 * 20, pack4(0, 0, 0, 0)); + std::vector data_shape = {2, 20}; + + std::vector indices = {1}; + std::vector indices_shape = {1}; + std::vector scales(2 * 5, 1.0f); + std::vector scales_shape = {2, 5}; + + // Packed zero_points: shape {2, 2} (5 logical -> 2 bytes per row). + // Row 0 logical [-2, 1, 0, -1, 1] -> byte0 = pack4(-2, 1, 0, -1) = 0x6C; byte1 lane0=1, rest dont-care. + // Row 1 logical [ 1, 0,-1,-2, 0] -> byte0 = pack4( 1, 0,-1,-2) = 0x1B; byte1 lane0=0, rest dont-care. + std::vector zero_points = { + pack4(-2, 1, 0, -1), pack4(1, -2, -2, -2), + pack4(1, 0, -1, -2), pack4(0, -2, -2, -2)}; + std::vector zero_points_shape = {2, 2}; + + // Picking row 1 via indices=[1]: dequant per block = (0 - zp) * 1 = [-1, 0, 1, 2, 0]. + std::vector output; + output.reserve(80); + for (float v : {-1.f, 0.f, 1.f, 2.f, 0.f}) { + output.insert(output.end(), 16, v); + } + std::vector output_shape = {1, 80}; + + RunGatherBlockQuantizedWebGpu(data, data_shape, indices, indices_shape, scales, scales_shape, + zero_points, zero_points_shape, + /*gather_axis=*/0, /*quantize_axis=*/1, + /*block_size=*/16, /*bits=*/2, output, output_shape); +} #endif // USE_WEBGPU template