diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index ca072113e832d..a13730c8c632d 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2108,7 +2108,8 @@ This version of the operator has been available since version 1 of the 'com.micr
3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used
to dequantize the output.
4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type.
- 5. For uint8 data, the `gather_axis` must be 0.
+ 5. For uint8 data, the `gather_axis` must be 0. The supported `bits` values for uint8 data are 2, 4, and 8;
+ for `bits` < 8 the values are packed along the last dimension (low-order bits first).
#### Version
@@ -2118,7 +2119,7 @@ This version of the operator has been available since version 1 of the 'com.micr
- bits : int
-- Number of bits used for weight quantization. Must be either 4 or 8.
+- Number of bits used for weight quantization. Must be 2, 4 or 8.
- block_size : int
- (Optional) block size used for weight quantization. It needs to be a power of 2 and not smaller than 16.
- gather_axis : int
diff --git a/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc b/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc
index 3a20a41696728..fb72db68de503 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc
@@ -30,6 +30,14 @@ int32_t Get4BitElement(const uint8_t* data_ptr, int64_t data_idx) {
return data_val;
}
+// Returns the 2-bit element at element index `data_idx` from uint8_t storage as an unsigned value in [0, 3].
+// Four elements are packed per byte: index 0 sits in the lowest 2 bits and index 3 in the highest 2 bits.
+int32_t Get2BitElementUint8(const uint8_t* data_ptr, int64_t data_idx) {
+ const uint8_t data_val_u8 = data_ptr[data_idx >> 2];  // byte that holds this element (4 elements per byte)
+ const int shift = static_cast((data_idx & 3) * 2);  // bit offset of the element inside that byte: 0, 2, 4 or 6
+ return static_cast((data_val_u8 >> shift) & 0x03);  // isolate the 2-bit field; no zero point is applied here
+}
+
} // namespace
template
@@ -53,7 +61,12 @@ class GatherBlockQuantized : public OpKernel {
constexpr int64_t default_bits = 4;
info.GetAttrOrDefault("bits", &bits_, default_bits);
- ORT_ENFORCE(bits_ == 4 || bits_ == 8, "GatherBlockQuantized only support bits==4 or 8");
+ if constexpr (std::is_same_v) {
+ ORT_ENFORCE(bits_ == 2 || bits_ == 4 || bits_ == 8,
+ "GatherBlockQuantized with uint8 data only supports bits==2, 4, or 8");
+ } else {
+ ORT_ENFORCE(bits_ == 4, "GatherBlockQuantized with int4/uint4 data only supports bits==4");
+ }
ORT_ENFORCE(block_size_ >= 16 && ((block_size_ - 1) & block_size_) == 0,
"'block_size' must be 2's power and not less than 16.");
@@ -221,10 +234,12 @@ Status GatherBlockQuantized::CopyDataAndDequantize(const T1* data_ptr,
int32_t data_val;
if constexpr (!std::is_same_v) {
data_val = Get4BitElement(data_ptr, data_idx);
- } else { // unit8_t
- if (bits_ == 4) {
+ } else { // uint8_t
+ if (bits_ == 2) {
+ data_val = Get2BitElementUint8(data_ptr, data_idx);
+ } else if (bits_ == 4) {
data_val = Get4BitElement(data_ptr, data_idx);
- } else { // buts_ == 8
+ } else { // bits_ == 8
data_val = static_cast(data_ptr[data_idx]);
}
}
@@ -238,9 +253,25 @@ Status GatherBlockQuantized::CopyDataAndDequantize(const T1* data_ptr,
if constexpr (std::is_same_v) {
if (zero_points_ptr) {
- if (bits_ == 4) {
- uint8_t packed = zero_points_ptr[scale_idx >> 1];
- if (scale_idx & 1) {
+ // For uint8 we enforce quantize_axis == last dim, which makes quantize_N == 1
+ // and scale_full_block == scale_qaxis_dim. Zero points are packed only along
+ // the quantize axis, so the packed byte must be addressed using the scale row
+ // index and the within-row quantize-axis index, not the flat scale_idx; the
+ // latter crosses row boundaries when scale_qaxis_dim is not a multiple of the
+ // packing factor.
+ const int64_t scale_qaxis_dim = scale_full_block;
+ const int64_t scale_row = scale_idx / scale_qaxis_dim;
+ const int64_t q_in_row = scale_idx % scale_qaxis_dim;
+ if (bits_ == 2) {
+ const int64_t packed_zp_qaxis_dim = (scale_qaxis_dim + 3) / 4;
+ const int64_t byte_idx = scale_row * packed_zp_qaxis_dim + (q_in_row >> 2);
+ const int shift = static_cast((q_in_row & 3) * 2);
+ zp_val = static_cast((zero_points_ptr[byte_idx] >> shift) & 0x03);
+ } else if (bits_ == 4) {
+ const int64_t packed_zp_qaxis_dim = (scale_qaxis_dim + 1) / 2;
+ const int64_t byte_idx = scale_row * packed_zp_qaxis_dim + (q_in_row >> 1);
+ uint8_t packed = zero_points_ptr[byte_idx];
+ if (q_in_row & 1) {
zp_val = static_cast((packed >> 4) & 0x0F);
} else {
zp_val = static_cast(packed & 0x0F);
@@ -249,7 +280,8 @@ Status GatherBlockQuantized::CopyDataAndDequantize(const T1* data_ptr,
zp_val = static_cast(zero_points_ptr[scale_idx]);
}
} else {
- const int32_t default_zero_point = bits_ == 4 ? 8 : 128;
+ // Default zero point is 2^(bits-1): 2 for 2-bit, 8 for 4-bit, 128 for 8-bit.
+ const int32_t default_zero_point = 1 << (static_cast(bits_) - 1);
zp_val = default_zero_point;
}
} else {
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc b/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc
index 62bfbcbbf9f5a..8413fa4dbaed8 100755
--- a/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc
@@ -21,7 +21,8 @@ Status GatherBlockQuantizedProgram::GenerateShaderCode(ShaderHelper& shader) con
const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseValueTypeAlias);
- bool is_4bit = bits_ == 4;
+ const bool is_2bit = bits_ == 2;
+ const bool is_4bit = bits_ == 4;
const std::string unpack = (is_signed_) ? "unpack4xI8" : "unpack4xU8";
shader.MainFunctionBody()
@@ -57,7 +58,18 @@ Status GatherBlockQuantizedProgram::GenerateShaderCode(ShaderHelper& shader) con
shader.MainFunctionBody()
<< " let data_offset = " << x_shape.IndicesToOffset("data_indices") << ";\n";
- if (is_4bit) {
+ if (is_2bit) {
+ // 2-bit values are packed 4 per byte (LSB first). x is the original uint8 tensor with
+ // Flatten=4 (4 bytes per u32); the input_shape uniform here is the *dequantized* shape,
+ // so data_offset is the dequantized 2-bit-element index.
+ shader.MainFunctionBody()
+ << " let byte_idx_2b = data_offset / 4;\n"
+ << " let bit_shift_2b = (data_offset % 4) * 2;\n"
+ << " let packed_word_2b = " << x.GetByOffset("byte_idx_2b / 4") << ";\n"
+ << " let byte_in_word_2b = byte_idx_2b % 4;\n"
+ << " let unpacked_bytes_2b = " << unpack << "(u32(packed_word_2b));\n"
+ << " var quantized_data = (unpacked_bytes_2b[byte_in_word_2b] >> bit_shift_2b) & 0x3;\n";
+ } else if (is_4bit) {
shader.MainFunctionBody()
<< " let data_index = data_offset % 8;\n"
<< " let packed_4bit_quantized_data = " << x.GetByOffset("data_offset / 8") << ";\n"
@@ -83,7 +95,18 @@ Status GatherBlockQuantizedProgram::GenerateShaderCode(ShaderHelper& shader) con
<< " var scale = " << scales.GetByIndices("scale_indices") << ";\n";
if (!has_zeropoint_) {
- const std::string default_zero_point = is_uint8_ ? is_4bit ? "input_element_t(8)" : "input_element_t(128)" : "input_element_t(0)";
+ std::string default_zero_point;
+ if (is_uint8_) {
+ if (is_2bit) {
+ default_zero_point = "input_element_t(2)";
+ } else if (is_4bit) {
+ default_zero_point = "input_element_t(8)";
+ } else {
+ default_zero_point = "input_element_t(128)";
+ }
+ } else {
+ default_zero_point = "input_element_t(0)";
+ }
shader.MainFunctionBody()
<< " let zero_point = " << default_zero_point << ";\n";
} else {
@@ -91,7 +114,22 @@ Status GatherBlockQuantizedProgram::GenerateShaderCode(ShaderHelper& shader) con
shader.MainFunctionBody()
<< " let zero_point_indices = scale_indices;\n"
<< " let zero_point_offset = " << scales.IndicesToOffset("zero_point_indices") << ";\n";
- if (is_4bit) {
+ if (is_2bit) {
+ // 2-bit zero points are packed 4-per-byte along the quantize axis only. The scales
+ // tensor's flat offset cannot be used directly because dividing it by 4 crosses row
+ // boundaries when scale_qaxis_dim is not a multiple of 4 (e.g. scales {2,3,1} has
+ // packed zp shape {2,3,1} with one usable 2-bit value per byte per row). Derive the
+ // packed byte index from the scale row index plus the within-row quantize-axis index.
+ shader.MainFunctionBody()
+ << " let q_idx_2b = " << scales.IndicesGet("scale_indices", "uniforms.quantize_axis") << ";\n"
+ << " let scale_row_2b = zero_point_offset / uniforms.scale_qaxis_dim;\n"
+ << " let zp_byte_offset_2b = scale_row_2b * uniforms.zp_packed_qaxis_dim + q_idx_2b / 4u;\n"
+ << " let zp_bit_shift_2b = (q_idx_2b % 4u) * 2u;\n"
+ << " let packed_zp_word_2b = " << zero_point.GetByOffset("zp_byte_offset_2b / 4") << ";\n"
+ << " let zp_byte_in_word_2b = zp_byte_offset_2b % 4;\n"
+ << " let zp_unpacked_2b = " << unpack << "(u32(packed_zp_word_2b));\n"
+ << " var zero_point = (zp_unpacked_2b[zp_byte_in_word_2b] >> zp_bit_shift_2b) & 0x3;\n";
+ } else if (is_4bit) {
shader.MainFunctionBody()
<< " let zero_point_index = zero_point_offset % 8;\n"
<< " let packed_4bit_zero_points = " << zero_point.GetByOffset("zero_point_offset / 8") << ";\n"
@@ -142,6 +180,16 @@ Status GatherBlockQuantized::ComputeInternal(ComputeContext& context) const {
int64_t x_dtype = x->GetElementType();
bool is_signed = x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 || x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4;
bool is_int8 = x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 || x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
+ bool is_uint8 = x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
+
+ // Only uint8 storage supports the full bits set {2, 4, 8}. The packed int4/uint4 types
+ // can only carry bits==4, matching the CPU kernel's constraint.
+ if (is_uint8) {
+ ORT_RETURN_IF_NOT(bits_ == 2 || bits_ == 4 || bits_ == 8,
+ "'bits' must be 2, 4 or 8 for uint8 input.");
+ } else {
+ ORT_RETURN_IF_NOT(bits_ == 4, "'bits' must be 4 for non-uint8 input.");
+ }
std::optional data_representation_4bit;
std::optional zero_points_representation_4bit;
@@ -174,7 +222,19 @@ Status GatherBlockQuantized::ComputeInternal(ComputeContext& context) const {
zero_points = zero_points_representation_4bit.has_value() ? &zero_points_representation_4bit.value() : zero_points;
}
- const auto& x_shape = x->Shape();
+ const auto& x_shape_intrinsic = x->Shape();
+ // For bits == 2 with uint8 storage we don't construct a packed-type reinterpret (no UInt2x4 type
+ // exists). Instead, build a logical "dequantized" shape (last dim x4) and feed that to the shader
+ // as the input_shape uniform. The buffer remains the original uint8 storage with Flatten=4, and
+ // the shader does explicit byte+bit-position extraction.
+ TensorShape x_shape;
+ if (bits_ == 2 && is_uint8) {
+ TensorShapeVector v = x_shape_intrinsic.AsShapeVector();
+ v.back() *= 4;
+ x_shape = TensorShape(std::move(v));
+ } else {
+ x_shape = x_shape_intrinsic;
+ }
size_t indices_rank = indices->Shape().NumDimensions();
const auto scales_shape = scales->Shape();
@@ -195,6 +255,17 @@ Status GatherBlockQuantized::ComputeInternal(ComputeContext& context) const {
int64_t output_size = output_shape.Size();
auto* output_tensor = context.Output(0, output_shape);
+ // For the 2-bit zero-point path we need to address the packed byte using the scale row index
+ // and the within-row quantize-axis index (not the flat scales offset, which crosses row
+ // boundaries when scale_qaxis_dim isn't a multiple of the packing factor). To keep the shader
+ // simple we require quantize_axis to be the last dim for uint8 2-bit, matching the CPU kernel.
+ if (bits_ == 2 && is_uint8) {
+ ORT_RETURN_IF_NOT(quantize_axis == x_rank - 1,
+ "For uint8 2-bit data, quantize_axis must be the last dimension.");
+ }
+ const uint32_t scale_qaxis_dim = static_cast(scales_shape[quantize_axis]);
+ const uint32_t zp_packed_qaxis_dim = (scale_qaxis_dim + 3) / 4;
+
GatherBlockQuantizedProgram program{is_signed, is_int8, indices_rank, gather_axis, bits_, zero_points != nullptr, x_shape, output_shape};
program
@@ -208,12 +279,27 @@ Status GatherBlockQuantized::ComputeInternal(ComputeContext& context) const {
.AddUniformVariables({{static_cast(quantize_axis)}})
.AddUniformVariables({{static_cast(gather_axis)}})
.AddUniformVariables({{static_cast(block_size_)}})
- .CacheHint(std::to_string(gather_axis), std::to_string(quantize_axis), std::to_string(block_size_));
+ .AddUniformVariables({{scale_qaxis_dim}})
+ .AddUniformVariables({{zp_packed_qaxis_dim}})
+ .CacheHint(std::to_string(bits_), std::to_string(gather_axis), std::to_string(quantize_axis), std::to_string(block_size_));
if (zero_points != nullptr) {
- ORT_RETURN_IF_NOT(scales_shape == zero_points->Shape(),
- "scales and zero_points must have the same shape.");
- auto zero_points_shape = zero_points->Shape();
+ if (bits_ == 2 && is_uint8) {
+ // 2-bit zero points are packed 4 per byte along the quantize axis.
+ const auto& zp_shape = zero_points->Shape();
+ ORT_RETURN_IF_NOT(zp_shape.NumDimensions() == scales_shape.NumDimensions(),
+ "scales and zero_points must have the same rank.");
+ for (size_t i = 0; i < scales_shape.NumDimensions(); ++i) {
+ int64_t expected = (i == static_cast(quantize_axis))
+ ? (scales_shape[i] + 3) / 4
+ : scales_shape[i];
+ ORT_RETURN_IF_NOT(zp_shape[i] == expected,
+ "zero_points shape does not match expected packed shape for 2-bit data.");
+ }
+ } else {
+ ORT_RETURN_IF_NOT(scales_shape == zero_points->Shape(),
+ "scales and zero_points must have the same shape.");
+ }
program.AddInputs({{zero_points, ProgramTensorMetadataDependency::None, ProgramInput::Flatten, (bits_ == 4) ? 8 : 4}});
}
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.h b/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.h
index cd7392995f4cf..305146c715c86 100755
--- a/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.h
+++ b/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.h
@@ -31,7 +31,9 @@ class GatherBlockQuantizedProgram final : public Program(info.GetAttrOrDefault("quantize_axis", 1));
bits_ = static_cast(info.GetAttrOrDefault("bits", 4));
- ORT_ENFORCE(bits_ == 4 || bits_ == 8, "'bits' must be 4 or 8.");
+ ORT_ENFORCE(bits_ == 2 || bits_ == 4 || bits_ == 8, "'bits' must be 2, 4 or 8.");
ORT_ENFORCE(block_size_ >= 16 && ((block_size_ - 1) & block_size_) == 0,
"'block_size' must be 2's power and not less than 16.");
}
-
Status ComputeInternal(ComputeContext& context) const override;
private:
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index a5537c7d58b05..9370f67d8bc78 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3692,7 +3692,8 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used
to dequantize the output.
4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type.
- 5. For uint8 data, the `gather_axis` must be 0.
+ 5. For uint8 data, the `gather_axis` must be 0. The supported `bits` values for uint8 data are 2, 4, and 8;
+ for `bits` < 8 the values are packed along the last dimension (low-order bits first).
)DOC";
ONNX_CONTRIB_OPERATOR_SCHEMA(GatherBlockQuantized)
@@ -3712,7 +3713,7 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
AttributeProto::INT,
static_cast(128))
.Attr("bits",
- "Number of bits used for weight quantization. Must be either 4 or 8. ",
+ "Number of bits used for weight quantization. Must be 2, 4 or 8. ",
AttributeProto::INT,
static_cast(4))
.Input(0, "data", "Tensor of rank r >= 1. Block-wise quantized.", "T1")
@@ -3799,9 +3800,9 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
if (!zp_shape.dim(i).has_dim_value() ||
zp_shape.dim(i).dim_value() != scales_shape.dim(i).dim_value()) {
if (ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8 &&
- bits == 4 &&
+ components > 1 &&
i == quantize_axis &&
- zp_shape.dim(i).dim_value() == (scales_shape.dim(i).dim_value() + 1) / 2) {
+ zp_shape.dim(i).dim_value() == (scales_shape.dim(i).dim_value() + components - 1) / components) {
continue;
}
fail_shape_inference("zero points shape and scales shape do not match");
diff --git a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc
index 6fea7a43712c7..b95238d03f508 100644
--- a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc
+++ b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc
@@ -47,6 +47,24 @@ void PackDataForUint8TypeIfNecessary(std::vector& data, std::vector& data,
run_test(true);
}
+// WebGPU-specific runner for GatherBlockQuantized. Only supports uint8 data with gather_axis == 0. Compiles to a no-op without USE_WEBGPU.
+template
+void RunGatherBlockQuantizedWebGpu(const std::vector& data,
+ const std::vector& data_shape,
+ const std::vector& indices,
+ const std::vector& indices_shape,
+ const std::vector& scales,
+ const std::vector& scales_shape,
+ const std::vector& zero_points,  // may be empty, in which case no zero_points input is added
+ const std::vector& zero_points_shape,
+ const int64_t gather_axis,
+ const int64_t quantize_axis,
+ const int64_t block_size,
+ const int64_t bits,
+ const std::vector& output,
+ const std::vector& output_shape,
+ OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess) {
+#ifdef USE_WEBGPU
+ if (DefaultWebGpuExecutionProvider().get() == nullptr) {
+ return;  // no WebGPU execution provider available: skip without failing the test
+ }
+
+ OpTester test("GatherBlockQuantized", 1, kMSDomain);
+ test.AddAttribute("gather_axis", gather_axis);
+ test.AddAttribute("quantize_axis", quantize_axis);
+ test.AddAttribute("block_size", block_size);
+ test.AddAttribute("bits", bits);
+
+ test.AddInput("data", data_shape, data);  // handed over as given; callers supply already-packed bytes
+ test.AddInput("indices", indices_shape, indices);
+ test.AddInput("scales", scales_shape, scales);
+ if (!zero_points.empty()) {  // zero_points is an optional input
+ test.AddInput("zero_points", zero_points_shape, zero_points);
+ }
+ test.AddOutput("output", output_shape, output);
+
+ std::vector> eps;
+ eps.push_back(DefaultWebGpuExecutionProvider());  // the WebGPU EP is the only provider passed to Run
+ test.Run(expect_result, "", {}, nullptr, &eps);
+#else
+ (void)data;  // this branch only marks every parameter as used when WebGPU is not compiled in
+ (void)data_shape;
+ (void)indices;
+ (void)indices_shape;
+ (void)scales;
+ (void)scales_shape;
+ (void)zero_points;
+ (void)zero_points_shape;
+ (void)gather_axis;
+ (void)quantize_axis;
+ (void)block_size;
+ (void)bits;
+ (void)output;
+ (void)output_shape;
+ (void)expect_result;
+#endif
+}
+
template
typename std::enable_if<
(boost::mp11::mp_contains, T1>::value && std::is_same::value) ||
@@ -571,8 +647,160 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis0NoZeroPoints_8Bits) {
Test_GatherAxis0_NoZeroPoints(8);
Test_GatherAxis0_NoZeroPoints(8);
}
+
+TEST(GatherBlockQuantizedOpTest, GatherAxis0NoZeroPoints_2Bits_Uint8) {
+ // 2-bit signed values in {-2, -1, 0, 1}. The test infra adds an offset of 2 when packing
+ // and the kernel uses default zero_point = 2^(bits-1) = 2, so the dequantized value matches.
+ // Block size 16 covers the entire last dim with one scale per row.
+ std::vector data = {-2, -1, 0, 1, -2, -1, 0, 1, -2, -1, 0, 1, -2, -1, 0, 1,
+ 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2,
+ 0, 1, -2, -1, 0, 1, -2, -1, 0, 1, -2, -1, 0, 1, -2, -1,
+ -1, -2, 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2, 1, 0,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2};
+ std::vector data_shape = {2, 3, 16};  // logical (unpacked) element shape: 6 rows of 16 values
+ std::vector indices = {1};
+ std::vector indices_shape = {1};
+ std::vector scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f};
+ std::vector scales_shape = {2, 3, 1};  // one scale per 16-element row, since block_size == 16
+
+ // indices = [1] -> pick outer index 1 along gather_axis 0, i.e. rows 3, 4, 5 of the unpacked data above
+ // ({-1,-2,1,0,...}, {1,1,...}, {-2,-2,...}), each multiplied by its own scale:
+ // scales[3], scales[4], scales[5] = 2.0, 1.0, 2.0.
+ std::vector output = {-2.f, -4.f, 2.f, 0.f, -2.f, -4.f, 2.f, 0.f, -2.f, -4.f, 2.f, 0.f, -2.f, -4.f, 2.f, 0.f,
+ 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f,
+ -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f};
+ std::vector output_shape = {1, 3, 16};  // leading dim of data replaced by indices_shape {1}
+
+ std::vector zero_points = {};  // empty: rely on the default zero point described at the top of this test
+ RunUnpackedData(data, data_shape, indices, indices_shape, scales, scales_shape,
+ zero_points, /*gather_axis=*/0, /*quantize_axis=*/2,
+ /*block_size=*/16, /*bits=*/2, output, output_shape, true);
+}
+
+TEST(GatherBlockQuantizedOpTest, GatherAxis0WithZeroPoints_2Bits_Uint8_PackedZpNotMultipleOf4) {
+ // Exercises the 2-bit zero-point row-boundary logic: scale_qaxis_dim = 5 is NOT a multiple
+ // of the 2-bit packing factor (4), so each zero_points row occupies (5+3)/4 = 2 bytes and
+ // the within-row qaxis index spans both bytes (lanes 0..3 in byte 0, lane 0 in byte 1).
+ // The kernel must address the packed byte using (scale_row, q_in_row), not a flat scales
+ // offset which would cross row boundaries.
+ //
+ // Logical layout: data {2, 80}, quantize_axis = 1 (last), block_size = 16, bits = 2.
+ // scales {2, 5}; zero_points logical {2, 5} -> packs to {2, 2}.
+ std::vector data(2 * 80, 0); // all zeros; dequant simplifies to (0 - zp) * scale
+ std::vector data_shape = {2, 80};  // logical (unpacked) shape: 80 / 16 = 5 blocks per row
+ std::vector indices = {1};
+ std::vector indices_shape = {1};
+ // All scales = 1 so dequant value equals -zp directly.
+ std::vector scales(2 * 5, 1.0f);
+ std::vector scales_shape = {2, 5};
+ // Logical 2-bit zero points in {-2,-1,0,1}; helper packs along last dim (5 -> 2 bytes).
+ // Row 0: [-2, 1, 0, -1, 1] ; Row 1: [1, 0, -1, -2, 0]
+ std::vector zero_points = {-2, 1, 0, -1, 1, 1, 0, -1, -2, 0};
+
+ // indices = [1] -> pick row 1: zp = [1, 0, -1, -2, 0]
+ // dequant per block = (0 - zp) * 1 = [-1, 0, 1, 2, 0]
+ std::vector output;
+ output.reserve(80);
+ for (float v : {-1.f, 0.f, 1.f, 2.f, 0.f}) {
+ output.insert(output.end(), 16, v);  // every element of a 16-wide block shares that block's zero point
+ }
+ std::vector output_shape = {1, 80};
+
+ RunUnpackedData(data, data_shape, indices, indices_shape, scales, scales_shape,
+ zero_points, /*gather_axis=*/0, /*quantize_axis=*/1,
+ /*block_size=*/16, /*bits=*/2, output, output_shape, true);
+}
#endif
+#ifdef USE_WEBGPU
+TEST(GatherBlockQuantizedOpTest, WebGpu_GatherAxis0NoZeroPoints_2Bits_Uint8) {
+ // Same logical data and expectation as the CPU GatherAxis0NoZeroPoints_2Bits_Uint8 test.
+ // Logical 2-bit values in {-2, -1, 0, 1}, encoded as v+2 in {0..3} and packed 4 per byte
+ // (low-order bits first). 16 logical 2-bit values -> 4 bytes per row.
+ // Pack helper: encodes each value as (v + 2) & 0x3 and places v0 in the lowest 2 bits, v3 in the highest.
+ auto pack4 = [](int v0, int v1, int v2, int v3) -> uint8_t {
+ auto enc = [](int v) { return static_cast((v + 2) & 0x3); };
+ return static_cast(enc(v0) | (enc(v1) << 2) | (enc(v2) << 4) | (enc(v3) << 6));
+ };
+
+ // Build packed data: shape {2, 3, 4} (16 logical elements per row -> 4 bytes).
+ std::vector data;
+ data.reserve(2 * 3 * 4);
+ auto push_row = [&](std::vector row) {  // appends one 16-element logical row to `data` as 4 packed bytes
+ ORT_ENFORCE(row.size() == 16);
+ for (size_t i = 0; i < 16; i += 4) {
+ data.push_back(pack4(row[i], row[i + 1], row[i + 2], row[i + 3]));
+ }
+ };
+ push_row({-2, -1, 0, 1, -2, -1, 0, 1, -2, -1, 0, 1, -2, -1, 0, 1});
+ push_row({1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2});
+ push_row({0, 1, -2, -1, 0, 1, -2, -1, 0, 1, -2, -1, 0, 1, -2, -1});
+ push_row({-1, -2, 1, 0, -1, -2, 1, 0, -1, -2, 1, 0, -1, -2, 1, 0});
+ push_row({1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
+ push_row({-2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2});
+
+ std::vector data_shape = {2, 3, 4};  // packed byte shape; the logical last dim is 4 * 4 = 16
+ std::vector indices = {1};
+ std::vector indices_shape = {1};
+ std::vector scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f};
+ std::vector scales_shape = {2, 3, 1};  // one scale per 16-element logical row (block_size == 16)
+
+ std::vector output = {-2.f, -4.f, 2.f, 0.f, -2.f, -4.f, 2.f, 0.f, -2.f, -4.f, 2.f, 0.f, -2.f, -4.f, 2.f, 0.f,
+ 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f,
+ -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f, -4.f};
+ std::vector output_shape = {1, 3, 16};  // output uses the logical (unpacked) last dim
+
+ std::vector zero_points = {};  // empty: the runner omits the zero_points input
+ std::vector zero_points_shape = {};
+ RunGatherBlockQuantizedWebGpu(data, data_shape, indices, indices_shape, scales, scales_shape,
+ zero_points, zero_points_shape,
+ /*gather_axis=*/0, /*quantize_axis=*/2,
+ /*block_size=*/16, /*bits=*/2, output, output_shape);
+}
+
+TEST(GatherBlockQuantizedOpTest, WebGpu_GatherAxis0WithZeroPoints_2Bits_Uint8_PackedZpNotMultipleOf4) {
+ // WebGPU companion to GatherAxis0WithZeroPoints_2Bits_Uint8_PackedZpNotMultipleOf4.
+ // scale_qaxis_dim = 5 (not a multiple of 4); packed zero_points last dim = (5+3)/4 = 2 bytes.
+ // The within-row qaxis index spans both bytes, validating the packed-byte addressing path
+ // (scale_row * zp_packed_qaxis_dim + q_idx/4, shift (q_idx%4)*2) in the WebGPU shader.
+ auto pack4 = [](int v0, int v1, int v2, int v3) -> uint8_t {  // v0 lands in the lowest 2 bits, v3 in the highest
+ auto enc = [](int v) { return static_cast((v + 2) & 0x3); };
+ return static_cast(enc(v0) | (enc(v1) << 2) | (enc(v2) << 4) | (enc(v3) << 6));
+ };
+
+ // Packed data: each logical row = 80 zeros -> 20 bytes of pack4(0,0,0,0) = 0xAA. data_shape {2, 20}.
+ std::vector data(2 * 20, pack4(0, 0, 0, 0));
+ std::vector data_shape = {2, 20};
+
+ std::vector indices = {1};
+ std::vector indices_shape = {1};
+ std::vector scales(2 * 5, 1.0f);
+ std::vector scales_shape = {2, 5};
+
+ // Packed zero_points: shape {2, 2} (5 logical -> 2 bytes per row).
+ // Row 0 logical [-2, 1, 0, -1, 1] -> byte0 = pack4(-2, 1, 0, -1) = 0x6C; byte1 lane0=1, rest don't-care.
+ // Row 1 logical [ 1, 0,-1,-2, 0] -> byte0 = pack4( 1, 0,-1,-2) = 0x1B; byte1 lane0=0, rest don't-care.
+ std::vector zero_points = {
+ pack4(-2, 1, 0, -1), pack4(1, -2, -2, -2),  // row 0: lanes 0..3, then lane 4 in the second byte (upper lanes are padding)
+ pack4(1, 0, -1, -2), pack4(0, -2, -2, -2)};  // row 1: same layout
+ std::vector zero_points_shape = {2, 2};
+
+ // Picking row 1 via indices=[1]: dequant per block = (0 - zp) * 1 = [-1, 0, 1, 2, 0].
+ std::vector output;
+ output.reserve(80);
+ for (float v : {-1.f, 0.f, 1.f, 2.f, 0.f}) {
+ output.insert(output.end(), 16, v);  // every element of a 16-wide block shares that block's zero point
+ }
+ std::vector output_shape = {1, 80};
+
+ RunGatherBlockQuantizedWebGpu(data, data_shape, indices, indices_shape, scales, scales_shape,
+ zero_points, zero_points_shape,
+ /*gather_axis=*/0, /*quantize_axis=*/1,
+ /*block_size=*/16, /*bits=*/2, output, output_shape);
+}
+#endif // USE_WEBGPU
+
template
void Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_4Bits() {
// This test case specific to shared 4bit token_embedding/lm_head use case on CUDA