
Native Qwen3-Reranker CausalLM support in RerankCalculatorOV #4063

Open
ambeckley wants to merge 1 commit into openvinotoolkit:main from ambeckley:ambeckley/native-qwen3-reranker-support

Conversation


@ambeckley ambeckley commented Mar 17, 2026

Summary

  • Adds native support for Qwen3-Reranker models (all sizes, e.g. 0.6B and 8B) using the CausalLM architecture, exported with --task text-generation
  • Auto-detects Qwen3 via model_type in config.json — no changes needed for existing reranker models
  • Applies server-side chat template formatting and CausalLM graph postprocessing (yes/no logit extraction via PrePostProcessor), so clients use the standard /v3/rerank API with no workarounds

Motivation

Qwen3-Reranker models use the CausalLM architecture, not cross-encoder text-classification. The current workaround (#3578) requires a community-modified seq-cls model. This PR enables all official Qwen3-Reranker model sizes to work natively through OVMS without client modifications, and it remains backwards compatible with tomaarsen/Qwen3-Reranker-*-seq-cls models.

Changes

src/rerank/rerank_servable.hpp

  • Added isQwen3, hasPositionIds, hasBeamIdx detection flags
  • Override applyPrePostProcessing() to:
    • Parse config.json for model_type: "qwen3"
    • Detect position_ids and beam_idx model inputs
    • Check output dimensionality (3D = CausalLM, 2D = text-classification with warning)
    • Look up yes/no token IDs via tokenizer
    • Build PrePostProcessor graph: Slice last token → Squeeze → Gather yes/no logits → Subtract (yes - no), producing [batch, 1] output compatible with existing sigmoid scoring

src/rerank/rerank_calculator_ov.cc

  • Added Qwen3 chat template input formatting in PrepareInputsForRerankModel()
  • Compute position_ids from attention mask for CausalLM models
  • Zero-fill beam_idx for CausalLM models
  • Guard token_type_ids creation with !isQwen3 check (CausalLM uses position_ids, not token_type_ids, as 3rd input)

Model Export

Models must be exported with --task text-generation (not the default text-classification):

optimum-cli export openvino --model Qwen/Qwen3-Reranker-8B --task text-generation --weight-format int8 Qwen3-Reranker-8B-causal-int8-ov

The default text-classification export produces a model with an untrained random classification head that outputs garbage scores.

Test plan

  • Tested with Qwen3-Reranker-0.6B (int8) on CPU — correct relevance scores
  • Tested with Qwen3-Reranker-8B (int8) on CPU and Intel Arc GPU — correct relevance scores
  • Verified existing non-Qwen3 reranker models are unaffected (isQwen3 = false, no code path changes)
  • CI/unit tests (to be added if maintainers request)

Qwen3-Reranker models use CausalLM architecture instead of cross-encoder
text-classification, requiring different input formatting and output
postprocessing. This enables OVMS to natively serve Qwen3-Reranker models
(all sizes: 0.6B, 8B) exported with --task text-generation via the
standard /v3/rerank API, with no client-side workarounds needed.

Changes:
- Auto-detect Qwen3 models via model_type in config.json
- Apply server-side chat template formatting for query-document pairs
- Add CausalLM graph postprocessing (Slice/Squeeze/Gather/Subtract)
  to extract yes/no logits from 3D output, producing scores compatible
  with existing sigmoid scoring
- Handle CausalLM-specific inputs (position_ids, beam_idx)
- Guard token_type_ids to avoid conflicts with CausalLM input layout
- Warn if model was exported as text-classification (random head weights)

Tested with Qwen3-Reranker-0.6B and Qwen3-Reranker-8B (int8) on
CPU and Intel Arc GPU, producing correct relevance scores.
Copilot AI (Contributor) left a comment

Pull request overview

Adds native support in the OVMS rerank pipeline for Qwen3-Reranker models exported as CausalLM (--task text-generation) by detecting model_type=qwen3, applying server-side chat-template formatting, and postprocessing the CausalLM logits into a [batch, 1] rerank score tensor compatible with existing /v3/rerank scoring.

Changes:

  • Detect Qwen3 via config.json and build an OpenVINO PrePostProcessor postprocess graph to compute yes_logit - no_logit.
  • Add Qwen3-specific input text formatting (chat template) in the OV rerank calculator.
  • Add optional handling for extra CausalLM inputs (position_ids, beam_idx) and avoid creating token_type_ids for Qwen3.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.

File Description
src/rerank/rerank_servable.hpp Adds Qwen3 detection and CausalLM logits postprocessing via PrePostProcessor (yes/no logit extraction).
src/rerank/rerank_calculator_ov.cc Adds Qwen3 chat-template formatting and supplies CausalLM-specific inputs (position_ids, beam_idx), while skipping token_type_ids for Qwen3.

Comment on lines +299 to +313
        auto position_ids = ov::Tensor(ov::element::i64, input_ids.get_shape());
        int64_t* pos_data = position_ids.data<int64_t>();
        int64_t* attn_data = attention_mask.data<int64_t>();
        for (size_t b = 0; b < batch; b++) {
            int64_t pos = 0;
            for (size_t s = 0; s < seq_len; s++) {
                pos_data[b * seq_len + s] = attn_data[b * seq_len + s] ? pos++ : 0;
            }
        }
        inferRequest.set_tensor(RERANK_MODEL_POSITION_IDS_NAME, position_ids);
    }
    if (rerank_session->hasBeamIdx) {
        size_t batch = input_ids.get_shape()[0];
        auto beam_idx = ov::Tensor(ov::element::i32, {batch});
        std::fill_n(beam_idx.data<int32_t>(), batch, 0);
Copilot AI Mar 23, 2026

position_ids and beam_idx tensors are created with hard-coded element types (i64 / i32). If the model’s actual input element types differ, set_tensor() will fail at runtime. Please derive the element type from the model/compiled model input (by name) and allocate the tensors with the expected type (or validate and error with a clear message).

Suggested change
        // Derive element types from the compiled model inputs to avoid dtype mismatches.
        const ov::element::Type pos_element_type =
            inferRequest.get_compiled_model().input(RERANK_MODEL_POSITION_IDS_NAME).get_element_type();
        const ov::element::Type mask_element_type =
            inferRequest.get_compiled_model().input(RERANK_MODEL_ATTENTION_MASK_NAME).get_element_type();
        ov::Tensor position_ids(pos_element_type, input_ids.get_shape());
        if (pos_element_type == ov::element::i64) {
            int64_t* pos_data = position_ids.data<int64_t>();
            if (mask_element_type == ov::element::i64) {
                int64_t* attn_data = attention_mask.data<int64_t>();
                for (size_t b = 0; b < batch; b++) {
                    int64_t pos = 0;
                    for (size_t s = 0; s < seq_len; s++) {
                        pos_data[b * seq_len + s] = attn_data[b * seq_len + s] ? pos++ : 0;
                    }
                }
            } else if (mask_element_type == ov::element::i32) {
                int32_t* attn_data = attention_mask.data<int32_t>();
                for (size_t b = 0; b < batch; b++) {
                    int64_t pos = 0;
                    for (size_t s = 0; s < seq_len; s++) {
                        pos_data[b * seq_len + s] = attn_data[b * seq_len + s] ? pos++ : 0;
                    }
                }
            } else {
                throw std::runtime_error("Unsupported attention_mask element type for position_ids generation");
            }
        } else if (pos_element_type == ov::element::i32) {
            int32_t* pos_data = position_ids.data<int32_t>();
            if (mask_element_type == ov::element::i64) {
                int64_t* attn_data = attention_mask.data<int64_t>();
                for (size_t b = 0; b < batch; b++) {
                    int32_t pos = 0;
                    for (size_t s = 0; s < seq_len; s++) {
                        pos_data[b * seq_len + s] =
                            attn_data[b * seq_len + s] ? static_cast<int32_t>(pos++) : 0;
                    }
                }
            } else if (mask_element_type == ov::element::i32) {
                int32_t* attn_data = attention_mask.data<int32_t>();
                for (size_t b = 0; b < batch; b++) {
                    int32_t pos = 0;
                    for (size_t s = 0; s < seq_len; s++) {
                        pos_data[b * seq_len + s] = attn_data[b * seq_len + s] ? pos++ : 0;
                    }
                }
            } else {
                throw std::runtime_error("Unsupported attention_mask element type for position_ids generation");
            }
        } else {
            throw std::runtime_error("Unsupported position_ids element type in compiled model");
        }
        inferRequest.set_tensor(RERANK_MODEL_POSITION_IDS_NAME, position_ids);
    }
    if (rerank_session->hasBeamIdx) {
        size_t batch = input_ids.get_shape()[0];
        const ov::element::Type beam_element_type =
            inferRequest.get_compiled_model().input(RERANK_MODEL_BEAM_IDX_NAME).get_element_type();
        ov::Tensor beam_idx(beam_element_type, {batch});
        if (beam_element_type == ov::element::i32) {
            std::fill_n(beam_idx.data<int32_t>(), batch, 0);
        } else if (beam_element_type == ov::element::i64) {
            std::fill_n(beam_idx.data<int64_t>(), batch, 0);
        } else {
            throw std::runtime_error("Unsupported beam_idx element type in compiled model");
        }

Comment on lines +176 to +186
    if (tokens.input_ids.get_shape().size() != 2) {
        throw std::runtime_error("Tokens shape invalid.");
    }
    if (this->max_position_embeddings < tokens.input_ids.get_shape()[1]) {
        std::ostringstream msg;
        msg << "Qwen3 rerank request length of " << tokens.input_ids.get_shape()[1]
            << " tokens exceeds the model context of " << max_position_embeddings;
        throw std::runtime_error(msg.str());
    }
    SPDLOG_LOGGER_DEBUG(rerank_calculator_logger, "Qwen3 rerank: {} documents, {} tokens per sequence",
        batchSize, tokens.input_ids.get_shape()[1]);
Copilot AI Mar 23, 2026

In the Qwen3 path, the tokenizer outputs are not validated the way the non-Qwen3 path effectively is (via chunkDocuments()), but later code assumes attention_mask is i64 and uses attention_mask.data<int64_t>() to compute position_ids. Please add explicit validation of tokens.input_ids/tokens.attention_mask element types and shapes for the Qwen3 branch to avoid UB if the tokenizer output precision/layout differs.

Suggested change
    const ov::Shape& input_shape = tokens.input_ids.get_shape();
    const ov::Shape& mask_shape = tokens.attention_mask.get_shape();
    // Basic rank validation
    if (input_shape.size() != 2 || mask_shape.size() != 2) {
        throw std::runtime_error("Qwen3 tokenizer outputs must be 2D tensors [batch, sequence].");
    }
    // Ensure attention_mask layout matches input_ids layout
    if (input_shape != mask_shape) {
        throw std::runtime_error("Qwen3 tokenizer outputs have mismatched shapes for input_ids and attention_mask.");
    }
    // Ensure element types match the assumptions made later (attention_mask.data<int64_t>())
    if (tokens.input_ids.get_element_type() != ov::element::i64) {
        throw std::runtime_error("Qwen3 tokenizer input_ids tensor must have i64 element type.");
    }
    if (tokens.attention_mask.get_element_type() != ov::element::i64) {
        throw std::runtime_error("Qwen3 tokenizer attention_mask tensor must have i64 element type.");
    }
    if (this->max_position_embeddings < input_shape[1]) {
        std::ostringstream msg;
        msg << "Qwen3 rerank request length of " << input_shape[1]
            << " tokens exceeds the model context of " << max_position_embeddings;
        throw std::runtime_error(msg.str());
    }
    SPDLOG_LOGGER_DEBUG(rerank_calculator_logger, "Qwen3 rerank: {} documents, {} tokens per sequence",
        batchSize, input_shape[1]);

Comment on lines +295 to +315
    // For CausalLM models (e.g. Qwen3 rerankers): set position_ids and beam_idx
    if (rerank_session->hasPositionIds) {
        size_t batch = input_ids.get_shape()[0];
        size_t seq_len = input_ids.get_shape()[1];
        auto position_ids = ov::Tensor(ov::element::i64, input_ids.get_shape());
        int64_t* pos_data = position_ids.data<int64_t>();
        int64_t* attn_data = attention_mask.data<int64_t>();
        for (size_t b = 0; b < batch; b++) {
            int64_t pos = 0;
            for (size_t s = 0; s < seq_len; s++) {
                pos_data[b * seq_len + s] = attn_data[b * seq_len + s] ? pos++ : 0;
            }
        }
        inferRequest.set_tensor(RERANK_MODEL_POSITION_IDS_NAME, position_ids);
    }
    if (rerank_session->hasBeamIdx) {
        size_t batch = input_ids.get_shape()[0];
        auto beam_idx = ov::Tensor(ov::element::i32, {batch});
        std::fill_n(beam_idx.data<int32_t>(), batch, 0);
        inferRequest.set_tensor(RERANK_MODEL_BEAM_IDX_NAME, beam_idx);
    }
Copilot AI Mar 23, 2026

This PR introduces a new Qwen3-specific request formatting path and new model inputs (position_ids, beam_idx) handling, but there are no unit/functional tests covering these behaviors. Since the repo already has rerank-related tests (e.g. src/test/reranknode_test.cpp, src/test/rerank_chunking_test.cpp), please add coverage that at least validates Qwen3 detection from config.json and that the calculator sets the expected extra tensors / produces a [batch, 1] logits output.

Comment on lines +140 to +142
auto start = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-1});
auto stop = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{std::numeric_limits<int64_t>::max()});
auto step = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1});
Copilot AI Mar 23, 2026

std::numeric_limits<int64_t>::max() is used here but <limits> isn’t included in this header. Please include <limits> explicitly to avoid relying on transitive includes (IWYU) and prevent fragile builds.

Comment on lines +101 to +106
if (outputShape.rank().get_length() == 2) {
    // Already a 2D output (text-classification export) — postprocessing won't help
    // because the classification head has random weights
    SPDLOG_WARN("Qwen3 reranker has 2D output shape (text-classification export). "
        "Re-export with --task text-generation for correct scoring.");
    return model;
Copilot AI Mar 23, 2026

outputShape.rank().get_length() will throw/assert if the output rank is dynamic. Consider using outputShape.rank() == 2 (which is safe with dynamic ranks) and also explicitly handling the expected CausalLM case (rank == 3) vs unexpected ranks (log and return/throw).

Suggested change
ov::Rank outputRank = outputShape.rank();
if (outputRank.is_dynamic()) {
    SPDLOG_WARN("Qwen3 reranker output rank is dynamic; skipping specialized postprocessing");
    return model;
}
std::size_t outputRankLength = outputRank.get_length();
if (outputRankLength == 2) {
    // Already a 2D output (text-classification export) — postprocessing won't help
    // because the classification head has random weights
    SPDLOG_WARN("Qwen3 reranker has 2D output shape (text-classification export). "
        "Re-export with --task text-generation for correct scoring.");
    return model;
} else if (outputRankLength != 3) {
    SPDLOG_WARN("Qwen3 reranker has unexpected output rank {}. Expected 2 (classification) or 3 (CausalLM). "
        "Skipping specialized postprocessing.",
        outputRankLength);
    return model;

// yes_logit - no_logit → sigmoid of this = softmax P(yes)
auto diff = std::make_shared<ov::op::v1::Subtract>(yesSlice, noSlice);

return diff; // [batch, 1]
Copilot AI Mar 23, 2026

The custom postprocess returns a new node without ensuring the output tensor name and element type match what the rest of the rerank pipeline expects. RerankCalculatorOV later fetches inferRequest.get_tensor("logits") and reads it as float*; if PrePostProcessor drops/changes the output name or leaves the output as f16, inference or scoring will break. Please either (a) set the postprocessed result tensor names back to "logits" and convert the output to f32 in the postprocess graph, or (b) update the calculator to query the compiled model’s output name and handle non-f32 element types.

Suggested change
// Ensure the final output tensor matches pipeline expectations:
// - element type: f32 (RerankCalculatorOV reads as float*)
// - tensor name: "logits" (queried via inferRequest.get_tensor("logits"))
auto diffF32 = std::make_shared<ov::op::v0::Convert>(diff, ov::element::f32);
diffF32->set_friendly_name("logits");
return diffF32;  // [batch, 1] logits in f32

Collaborator

I think this might be valid. Setting the name will ensure the output read will work even if something changes due to graph surgery.
Can you confirm new output node is indeed f32? I wonder if the convert op here is indeed necessary.

@mzegla mzegla (Collaborator) left a comment

Thank you for the contribution. I added a few comments. Please also check copilot suggestions.

I think the ultimate solution for enablement of that model would be to use GenAI rerank pipeline in OVMS, as it seems like a better fit to have core OV logic like graphs processing there. But that's an idea for the future - this change is very useful either way 😃


Comment on lines +113 to +120
auto yesTokens = tokenizer->encode("yes");
if (yesTokens.input_ids.get_size() == 1 && yesTokens.input_ids.get_element_type() == ov::element::i64) {
    yesTokenId = reinterpret_cast<int64_t*>(yesTokens.input_ids.data())[0];
}
auto noTokens = tokenizer->encode("no");
if (noTokens.input_ids.get_size() == 1 && noTokens.input_ids.get_element_type() == ov::element::i64) {
    noTokenId = reinterpret_cast<int64_t*>(noTokens.input_ids.data())[0];
}
Collaborator

Not sure it's safe to rely on tokenizer->encode: in certain settings it will treat the string as an input prompt and add special tokens (e.g. <bos><yes_token>), so we end up picking the wrong tokens.
I can see in GenAI there is a direct vocab read for that:
https://github.com/openvinotoolkit/openvino.genai/blob/716a778fc0ccfa86f1395b186a0cb2ca8ed7ece5/src/cpp/src/rag/text_rerank_pipeline.cpp#L179
I think it would be safer to do it that way.

auto start = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-1});
auto stop = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{std::numeric_limits<int64_t>::max()});
auto step = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1});
auto axis1 = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1});
Collaborator

Could be named just axis or sliceAxis right? There is no axis2, so we don't need to differ and value of the constant tells which axis we pick.

}

protected:
std::shared_ptr<ov::Model> applyPrePostProcessing(ov::Core& core, std::shared_ptr<ov::Model> model, ov::AnyMap& properties) override {
Collaborator

I would split that logic. Since we only need graph postprocessing for Qwen3 I would extract detection to another isQwen3Model method and go with something like:

if (isQwen3Model)
	applyQwen3GraphPostProcessing()

I think it will be more straightforward at the higher level, since currently you have to read into that function to find the early return for non-Qwen3 models.

        hasPositionIds = true;
        SPDLOG_DEBUG("Qwen3 reranker model has position_ids input");
    }
    if (input.get_any_name() == "beam_idx") {
Collaborator

Could be else if ?


// Check output shape — only apply postprocessing for CausalLM models (3D output)
ov::PartialShape outputShape = model->get_output_partial_shape(0);
if (outputShape.rank().get_length() == 2) {
Collaborator

I think I would go for the wider check. If we expect 3D, let's check for 3D, so if rank != 3.


mzegla commented Mar 24, 2026

Also looks like style check failed on the CI. Please run make style to fix linter issues.
