
Conversation

@lvliang-intel (Contributor) commented Feb 4, 2026

Description

This update adds quantization support for Qwen3-Omni by integrating a custom MLLM processor and template, implementing dedicated forward logic for thinker/talker calibration, and introducing model-specific block discovery.

Note: This feature requires Transformers >= 5.1.0, as earlier versions contain compatibility issues with Qwen3-Omni.
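
For reference, a guard of the kind this note implies could be expressed roughly as follows. This is a minimal sketch using packaging against transformers.__version__, not the exact check added in this PR:

import transformers
from packaging import version

# Hypothetical guard: Qwen3-Omni quantization needs Transformers >= 5.1.0
if version.parse(transformers.__version__) < version.parse("5.1.0"):
    raise ImportError(
        "Quantizing qwen3_omni_moe requires transformers>=5.1.0, "
        f"but found {transformers.__version__}"
    )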

Type of Change

  • Bug fix
  • New feature
  • Documentation update
  • Performance improvement
  • Code refactoring
  • Other (please specify):

Related Issues

#1387

Fixes or relates to #

Checklist Before Submitting

  • My code has been tested locally.
  • Documentation has been updated as needed.
  • New or updated tests are included where applicable.

lvliang-intel and others added 4 commits February 4, 2026 14:50
Signed-off-by: lvliang-intel <liang1.lv@intel.com>
@wenhuach21 (Contributor) commented:

Thank you for the PR! Could you help verify all inferences (vLLM, Transformers 4, and Transformers 5) before merging?

@lvliang-intel (Contributor, Author) commented:

Quantize:

from auto_round import AutoRound

model_name_or_path = "Qwen/Qwen3-Omni-30B-A3B-Instruct"

# W4A16: 4-bit weights, 16-bit activations; 100 tuning iterations at lr=5e-3
ar = AutoRound(
    model=model_name_or_path,
    scheme="W4A16",
    lr=5e-3,
    iters=100,
)
ar.quantize_and_save(format="auto_round", output_dir="tmp_qwen_omni_w4a16")

Inference with transformers 5.1.0:

#!/usr/bin/env python3
"""Verify a quantized Qwen3-Omni model with transformers.

Tests text-only, image, audio, and video inputs.
"""

import argparse
import os
import sys
import traceback

import soundfile as sf
import torch
from transformers import Qwen3OmniMoeForConditionalGeneration, Qwen3OmniMoeProcessor
from qwen_omni_utils import process_mm_info


USE_AUDIO_IN_VIDEO = True

# Demo resources from Qwen
DEMO_IMAGE = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/cars.jpg"
DEMO_AUDIO = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/cough.wav"
DEMO_VIDEO = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/draw.mp4"


TEST_CASES = {
    "text_only": {
        "conversation": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is the capital of France? Answer in one short sentence."},
                ],
            },
        ],
    },
    "image": {
        "conversation": [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": DEMO_IMAGE},
                    {"type": "text", "text": "Describe this image in one short sentence."},
                ],
            },
        ],
    },
    "audio": {
        "conversation": [
            {
                "role": "user",
                "content": [
                    {"type": "audio", "audio": DEMO_AUDIO},
                    {"type": "text", "text": "What sound can you hear? Answer in one short sentence."},
                ],
            },
        ],
    },
    "video": {
        "conversation": [
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": DEMO_VIDEO},
                    {"type": "text", "text": "Describe what happens in this video in one short sentence."},
                ],
            },
        ],
    },
    "image_audio": {
        "conversation": [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": DEMO_IMAGE},
                    {"type": "audio", "audio": DEMO_AUDIO},
                    {"type": "text", "text": "What can you see and hear? Answer in one short sentence."},
                ],
            },
        ],
    },
}


def run_test(model, processor, test_name, conversation, max_new_tokens, enable_audio_output):
    """Run a single test case and return the result."""
    print(f"\n{'='*60}")
    print(f"Test: {test_name}")
    print(f"{'='*60}")

    text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    audios, images, videos = process_mm_info(conversation, use_audio_in_video=USE_AUDIO_IN_VIDEO)

    inputs = processor(
        text=text,
        audio=audios,
        images=images,
        videos=videos,
        return_tensors="pt",
        padding=True,
        use_audio_in_video=USE_AUDIO_IN_VIDEO,
    )
    inputs = inputs.to(model.device).to(model.dtype)

    generate_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        use_audio_in_video=USE_AUDIO_IN_VIDEO,
        do_sample=False,
    )
    if enable_audio_output:
        generate_kwargs["speaker"] = "Ethan"
        generate_kwargs["thinker_return_dict_in_generate"] = True

    output = model.generate(**generate_kwargs)

    # Qwen3-Omni generate() may return a (text_ids, audio) tuple; handle the plain-tensor case too
    if isinstance(output, tuple):
        text_ids, audio_out = output
    else:
        text_ids, audio_out = output, None

    # With thinker_return_dict_in_generate=True, text_ids has .sequences
    if hasattr(text_ids, "sequences"):
        decode_ids = text_ids.sequences[:, inputs["input_ids"].shape[1]:]
    else:
        decode_ids = text_ids[:, inputs["input_ids"].shape[1]:]

    generated_text = processor.batch_decode(
        decode_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    print(f"Output: {generated_text}")

    if audio_out is not None and enable_audio_output:
        wav_path = f"output_{test_name}.wav"
        sf.write(wav_path, audio_out.reshape(-1).detach().cpu().numpy(), samplerate=24000)
        print(f"Audio saved to {wav_path}")

    if len(generated_text.strip()) == 0:
        print(f"WARNING: Empty output for test '{test_name}'")
        return False
    return True


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Verify a quantized Qwen3-Omni model with text, image, audio, and video inputs."
    )
    parser.add_argument(
        "--model-dir",
        required=True,
        help="Path to the quantized model directory.",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=256,
        help="Maximum number of new tokens to generate.",
    )
    parser.add_argument(
        "--tests",
        nargs="+",
        default=list(TEST_CASES.keys()),
        choices=list(TEST_CASES.keys()),
        help="Which tests to run (default: all).",
    )
    parser.add_argument(
        "--enable-audio-output",
        action="store_true",
        default=False,
        help="Enable audio output generation (requires talker model).",
    )
    return parser.parse_args()


def main() -> int:
    args = parse_args()
    model_dir = os.path.abspath(args.model_dir)
    if not os.path.isdir(model_dir):
        print(f"Model directory not found: {model_dir}")
        return 1

    print(f"Loading model from {model_dir} ...")
    model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
        model_dir,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
    )
    processor = Qwen3OmniMoeProcessor.from_pretrained(model_dir)
    print("Model and processor loaded.\n")

    passed, failed = [], []
    for test_name in args.tests:
        tc = TEST_CASES[test_name]
        try:
            ok = run_test(
                model, processor, test_name, tc["conversation"],
                args.max_new_tokens, args.enable_audio_output,
            )
            (passed if ok else failed).append(test_name)
        except Exception as exc:
            print(f"ERROR in test '{test_name}': {exc}")
            traceback.print_exc()
            failed.append(test_name)

    print(f"\n{'='*60}")
    print(f"Results: {len(passed)} passed, {len(failed)} failed out of {len(args.tests)} tests")
    if passed:
        print(f"  Passed: {', '.join(passed)}")
    if failed:
        print(f"  Failed: {', '.join(failed)}")
    print(f"{'='*60}")

    del model
    torch.cuda.empty_cache()
    return 1 if failed else 0


if __name__ == "__main__":
    sys.exit(main())


CUDA_VISIBLE_DEVICES=0 python verify_quantized_transformers.py --model-dir ./tmp_qwen_omni_w4a16/
Loading model from /mnt/disk1/lvl/auto-round/tmp_qwen_omni_w4a16 ...
Unrecognized keys in rope_parameters for 'rope_type'='default': {'interleaved', 'mrope_section'}
You are attempting to use Flash Attention 2 without specifying a torch dtype. This might lead to unexpected behaviour
2026-02-11 18:50:27 WARNING modeling_utils.py L4356: loss_type=None was set in the config but it is unrecognized. Using the default loss: ForCausalLMLoss.
2026-02-11 18:53:41 INFO moe_experts_interface.py L432: Unfused 68 MOE experts modules for quantization
2026-02-11 18:53:41 INFO replace_modules.py L80: Prepared 68 MOE modules for quantization
2026-02-11 18:53:45 WARNING backend.py L1088: Better backend is found, please install all the following requirements to enable it.
2026-02-11 18:53:45 WARNING backend.py L1088: pip install -v "gptqmodel>=2.0" --no-build-isolation
Loading weights: 100%|█| 54786/54786 [00:15<00:00, 3647.06it/s, Materiali
Unrecognized keys in rope_parameters for 'rope_type'='default': {'interleaved', 'mrope_section'}
Model and processor loaded.

Test: text_only
Setting pad_token_id to eos_token_id:151645 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's attention_mask to obtain reliable results.
Setting pad_token_id to eos_token_id:2150 for open-end generation.
2026-02-11 18:54:24 WARNING utils.py L2088: The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's attention_mask to obtain reliable results.
Output: The capital of France is Paris.

Test: image
Setting pad_token_id to eos_token_id:151645 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's attention_mask to obtain reliable results.
Setting pad_token_id to eos_token_id:2150 for open-end generation.
Output: This composite image displays four different luxury vehicles: a white Rolls-Royce Phantom, a grey Mercedes-Benz GLE SUV, a red Ferrari Portofino M, and a white Porsche 911.

Test: audio
/mnt/disk1/lvl/conda_envs/artest/lib/python3.11/site-packages/librosa/core/audio.py:172: FutureWarning: librosa.core.audio.__audioread_load
Deprecated as of librosa version 0.10.0.
It will be removed in librosa version 1.0.
y, sr_native = __audioread_load(path, offset, duration, dtype)
Setting pad_token_id to eos_token_id:151645 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's attention_mask to obtain reliable results.
Setting pad_token_id to eos_token_id:2150 for open-end generation.
Output: The sound of a person coughing is heard.

Test: video
/mnt/disk1/lvl/conda_envs/artest/lib/python3.11/site-packages/librosa/core/audio.py:172: FutureWarning: librosa.core.audio.__audioread_load
Deprecated as of librosa version 0.10.0.
It will be removed in librosa version 1.0.
y, sr_native = __audioread_load(path, offset, duration, dtype)
qwen-vl-utils using torchvision to read video.
/mnt/disk1/lvl/conda_envs/artest/lib/python3.11/site-packages/torchvision/io/_video_deprecation_warning.py:9: UserWarning: The video decoding and encoding capabilities of torchvision are deprecated from version 0.22 and will be removed in version 0.24. We recommend that you migrate to TorchCodec, where we'll consolidate the future decoding/encoding capabilities of PyTorch: https://github.com/pytorch/torchcodec
warnings.warn(
Setting pad_token_id to eos_token_id:151645 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's attention_mask to obtain reliable results.
Setting pad_token_id to eos_token_id:2150 for open-end generation.
Output: A person uses a stylus to draw a guitar on a tablet.

Test: image_audio
/mnt/disk1/lvl/conda_envs/artest/lib/python3.11/site-packages/librosa/core/audio.py:172: FutureWarning: librosa.core.audio.__audioread_load
Deprecated as of librosa version 0.10.0.
It will be removed in librosa version 1.0.
y, sr_native = __audioread_load(path, offset, duration, dtype)
Setting pad_token_id to eos_token_id:151645 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's attention_mask to obtain reliable results.
Setting pad_token_id to eos_token_id:2150 for open-end generation.
Output: The image displays four luxury vehicles—a Rolls-Royce, a Mercedes-Benz GLE SUV, a red Ferrari Portofino M, and a white Porsche 911—while the audio features a person coughing.

Results: 5 passed, 0 failed out of 5 tests
Passed: text_only, image, audio, video, image_audio

vLLM tests are currently blocked because the latest vLLM version depends on an outdated Transformers release. Qwen3-Omni requires Transformers >= 5.1.0 to address several known issues.
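
Once a vLLM release picks up a compatible Transformers version, a minimal text-only smoke test could look roughly like the sketch below. The model path is the output directory from above; whether vLLM can load the auto_round checkpoint directly is an assumption that still needs verification:

from vllm import LLM, SamplingParams

# Hypothetical smoke test: assumes a future vLLM release that is compatible with
# Transformers >= 5.1.0 and can load the auto_round-quantized checkpoint directly.
llm = LLM(model="./tmp_qwen_omni_w4a16", trust_remote_code=True)
sampling = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["What is the capital of France?"], sampling)
print(outputs[0].outputs[0].text)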

@lvliang-intel lvliang-intel marked this pull request as ready for review February 11, 2026 11:09
Copilot AI review requested due to automatic review settings February 11, 2026 11:09
Copilot AI (Contributor) left a comment:


Pull request overview

Adds quantization support for the Qwen3-Omni MoE model family by integrating model-specific loading/version gating, calibration forward behavior for thinker/talker, and custom multimodal block discovery.

Changes:

  • Added explicit Transformers version guard for qwen3_omni_moe.
  • Introduced Qwen3-Omni processor/template registration and model-specific multimodal block name discovery.
  • Implemented a Qwen3-Omni-specific forward path to run thinker (and optionally talker) during calibration.
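
As a rough illustration of the last bullet, a thinker/talker calibration forward could be shaped like the sketch below. The function name is hypothetical and the talker branch is modeled on the excerpts quoted in the review comments further down, not on the exact code in this PR:

def _calib_forward_qwen3_omni(model, input_ids=None, **kwargs):
    """Hypothetical calibration forward: always run the thinker, optionally the talker."""
    thinker_output = model.thinker(
        input_ids=input_ids,
        output_hidden_states=True,
        **kwargs,
    )
    # Optionally run the talker so its blocks also see calibration activations.
    if getattr(model, "has_talker", False) and getattr(model, "talker", None) is not None:
        if hasattr(model.talker, "text_projection"):
            # Project the thinker's last hidden states into the talker embedding space.
            talker_inputs_embeds = model.talker.text_projection(
                thinker_output.hidden_states[-1]
            )
            model.talker(inputs_embeds=talker_inputs_embeds)
    return thinker_output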

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 4 comments.

Summary per file:

  • pyproject.toml: Adds a project-specific word to the typos allowlist.
  • auto_round/utils/model.py: Adds a Transformers version guard and adjusts lm_head discovery logic.
  • auto_round/utils/common.py: Adds _no_split_modules normalization and extends multimodal ignore-key lists.
  • auto_round/special_model_handler.py: Adds the Qwen3-Omni special forward, block discovery, and ignore-layer rule.
  • auto_round/compressors/shard_writer.py: Improves the tie_word_embeddings lookup for nested multimodal configs (see the sketch after this list).
  • auto_round/compressors/mllm/utils.py: Extends the multimodal ignore-key list for Qwen3-Omni components.
  • auto_round/compressors/mllm/template.py: Registers a Qwen3-Omni model template with the new processor.
  • auto_round/compressors/mllm/processor.py: Adds a custom processor for Qwen3-Omni chat-template inputs.
  • auto_round/compressors/base.py: Imports the new normalization helper.
  • auto_round/auto_scheme/utils.py: Uses normalized _no_split_modules when dispatching across devices.
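
To illustrate the shard_writer.py entry, the nested lookup might look roughly like the helper below; the helper name and the set of sub-config attributes checked are assumptions, not the PR's exact code:

def _lookup_tie_word_embeddings(config) -> bool:
    """Hypothetical helper: resolve tie_word_embeddings from a possibly nested config."""
    value = getattr(config, "tie_word_embeddings", None)
    if value is not None:
        return value
    # Multimodal models often nest the text settings (e.g. thinker_config, text_config).
    for attr in ("thinker_config", "text_config"):
        sub = getattr(config, attr, None)
        sub_value = getattr(sub, "tie_word_embeddings", None) if sub is not None else None
        if sub_value is not None:
            return sub_value
    return False

The inline review comments that follow quote excerpts of the thinker/talker forward path.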


# Run talker forward if available (for calibration purposes)
if hasattr(model, "talker") and model.has_talker:

Copilot AI Feb 11, 2026


This can raise AttributeError when model.has_talker doesn’t exist (the hasattr only checks talker). Use getattr(model, "has_talker", False) (and optionally also ensure model.talker is not None) to make this guard safe.

Suggested change:

- if hasattr(model, "talker") and model.has_talker:
+ if getattr(model, "has_talker", False) and getattr(model, "talker", None) is not None:

Comment on lines +238 to +242
# Use text projection to convert thinker embeddings to talker space
if hasattr(model.talker, "text_projection"):
    # Get thinker embeddings
    thinker_embeds = model.thinker.get_input_embeddings()(input_ids)
    talker_inputs_embeds = model.talker.text_projection(thinker_embeds)

Copilot AI Feb 11, 2026


This path assumes input_ids is provided; if calibration runs with inputs_embeds (or other modalities without input_ids), this will throw and then be silently ignored (due to the broad except), meaning the talker forward never runs. Consider deriving inputs from inputs_embeds when present, or projecting from thinker_output.hidden_states[-1] (which you already compute) instead of re-embedding input_ids.

Suggested change:

- # Use text projection to convert thinker embeddings to talker space
- if hasattr(model.talker, "text_projection"):
-     # Get thinker embeddings
-     thinker_embeds = model.thinker.get_input_embeddings()(input_ids)
-     talker_inputs_embeds = model.talker.text_projection(thinker_embeds)
+ # Use text projection to convert thinker hidden states to talker space
+ if hasattr(model.talker, "text_projection"):
+     # Project thinker hidden states directly into the talker embedding space
+     talker_inputs_embeds = model.talker.text_projection(thinker_hidden)

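For the first option the reviewer mentions (deriving the talker inputs from inputs_embeds when input_ids is absent), a rough sketch reusing the variable names from the snippet above could be:

# Hypothetical variant: fall back to inputs_embeds when input_ids is not provided.
if hasattr(model.talker, "text_projection"):
    if input_ids is not None:
        thinker_embeds = model.thinker.get_input_embeddings()(input_ids)
    else:
        thinker_embeds = inputs_embeds  # assumes inputs_embeds is in scope here
    if thinker_embeds is not None:
        talker_inputs_embeds = model.talker.text_projection(thinker_embeds)
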
lvliang-intel and others added 2 commits February 11, 2026 19:20
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: lvliang-intel <liang1.lv@intel.com>
@chensuyue chensuyue requested a review from xin3he February 12, 2026 06:29