From 5f6d359d6b49e9ac992c513c1d05c77af6bc193a Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 3 Mar 2026 20:09:56 +0800 Subject: [PATCH] Support Qwen3.5 --- trinity/common/models/vllm_model.py | 15 +- trinity/common/models/vllm_patch/__init__.py | 11 + .../common/models/vllm_patch/worker_patch.py | 4 +- trinity/common/patch/qwen3_5.py | 201 ++++++++++++++++++ trinity/trainer/verl/monkey_patch.py | 115 +++++++++- 5 files changed, 340 insertions(+), 6 deletions(-) create mode 100644 trinity/common/patch/qwen3_5.py diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 86394d11f0..2713372c94 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -63,14 +63,16 @@ def __init__( os.environ["VLLM_CACHE_ROOT"] = os.path.expanduser( f"~/.cache/vllm/{config.bundle_indices}" ) + self.tokenization_kwargs = { + "truncate_prompt_tokens": config.max_prompt_tokens + if config.enable_prompt_truncation + else None + } self.default_sampling_params = vllm.SamplingParams( n=1, temperature=config.temperature, max_tokens=config.max_response_tokens, min_tokens=config.min_response_tokens, - truncate_prompt_tokens=( - config.max_prompt_tokens if config.enable_prompt_truncation else None - ), skip_special_tokens=True, include_stop_str_in_output=False, output_kind=RequestOutputKind.FINAL_ONLY, @@ -78,6 +80,7 @@ def __init__( top_p=config.top_p, top_k=config.top_k, ignore_eos=config.ignore_eos, + **(self.tokenization_kwargs if self.vllm_version <= parse_version("0.16.0") else {}), ) self.ray_namespace = config.ray_namespace self.request_id = 0 @@ -417,11 +420,17 @@ async def sample( async def _generate_internal(self, prompt: Any, lora_request=None, **kwargs) -> Any: # Send the request to the LLM engine. self.request_id += 1 + generate_kwargs = ( + {"tokenization_kwargs": self.tokenization_kwargs} + if self.vllm_version > parse_version("0.16.0") + else {} + ) stream = self.async_llm.generate( request_id=str(self.request_id), prompt=prompt, sampling_params=self._create_sampling_params(**kwargs), lora_request=lora_request, + **generate_kwargs, ) # Consume the stream until the request is finished. diff --git a/trinity/common/models/vllm_patch/__init__.py b/trinity/common/models/vllm_patch/__init__.py index 6fa7b99fe8..40702fc227 100644 --- a/trinity/common/models/vllm_patch/__init__.py +++ b/trinity/common/models/vllm_patch/__init__.py @@ -20,6 +20,17 @@ def vllm_patch(): if trf_version >= parse_version("5.0.0") and vllm_version < parse_version("0.16.0"): raise ImportError("Please upgrade vllm to 0.16.0 or above to use transformers>=5.0.0.") + from transformers.configuration_utils import PreTrainedConfig + + original_init = PreTrainedConfig.__init__ + + def new_init(self, *args, **kwargs): + if "ignore_keys_at_rope_validation" in kwargs: + kwargs["ignore_keys_at_rope_validation"] = set(kwargs["ignore_keys_at_rope_validation"]) + original_init(self, *args, **kwargs) + + PreTrainedConfig.__init__ = new_init + def get_vllm_version(): try: diff --git a/trinity/common/models/vllm_patch/worker_patch.py b/trinity/common/models/vllm_patch/worker_patch.py index 89b116a954..866983cabe 100644 --- a/trinity/common/models/vllm_patch/worker_patch.py +++ b/trinity/common/models/vllm_patch/worker_patch.py @@ -13,10 +13,10 @@ def patch_vllm_prompt_logprobs(model_runner: GPUModelRunner): # noqa: C901 """Patch vLLM model runner to support prompt logprobs extraction.""" version = get_vllm_version() - if version < parse_version("0.10.2") or version > parse_version("0.16.0"): + if version < parse_version("0.10.2") or version >= parse_version("0.17.0"): raise ValueError( f"Unsupported vllm version: {vllm.__version__}. " - "This patch requires vllm version >= 0.10.2, <= 0.16.0." + "This patch requires vllm version >= 0.10.2, < 0.17.0." ) is_v0102 = version == parse_version("0.10.2") diff --git a/trinity/common/patch/qwen3_5.py b/trinity/common/patch/qwen3_5.py new file mode 100644 index 0000000000..80e87913d4 --- /dev/null +++ b/trinity/common/patch/qwen3_5.py @@ -0,0 +1,201 @@ +from dataclasses import dataclass +from functools import wraps +from typing import Optional + +import torch +from transformers.models.qwen3_5.modeling_qwen3_5 import ( + BaseModelOutputWithPast, + Cache, + Qwen3_5CausalLMOutputWithPast, + Qwen3_5DynamicCache, + Qwen3_5ForConditionalGeneration, + Qwen3_5ModelOutputWithPast, + TransformersKwargs, + Unpack, + capture_outputs, + create_causal_mask, + merge_with_config_defaults, +) + + +# TODO: may optimize this function +def ulysses_gated_delta_net_forward_decorator(func): + @wraps(func) + def wrapper( + hidden_states: torch.Tensor, + cache_params: Qwen3_5DynamicCache | None = None, + cache_position: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + ): + from verl.utils.ulysses import ( + gather_outputs_and_unpad, + get_ulysses_sequence_parallel_world_size, + slice_input_tensor, + ) + + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + if ulysses_sp_size > 1: + hidden_states = gather_outputs_and_unpad(hidden_states, gather_dim=1) + + output = func(hidden_states, cache_params, cache_position, attention_mask) + + if ulysses_sp_size > 1: + output = slice_input_tensor(output, dim=1, padding=False) + return output + + return wrapper + + +@merge_with_config_defaults +@capture_outputs +def qwen35_text_forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], +) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = Qwen3_5DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # mrope: the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = position_ids[0] + + causal_mask = create_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=text_position_ids, + ) + linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + layer_mask = ( + linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask + ) + + hidden_states = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=layer_mask, + position_ids=text_position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return Qwen3_5ModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@dataclass +class Qwen3_5CausalLMOutputForPPO(Qwen3_5CausalLMOutputWithPast): + log_probs: Optional[torch.FloatTensor] = None + entropy: Optional[torch.FloatTensor] = None + + +def forward_with_torch_backend( + self: Qwen3_5ForConditionalGeneration, + input_ids: torch.LongTensor = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + **kwargs, +) -> tuple | Qwen3_5CausalLMOutputForPPO: + from verl.utils.experimental.torch_functional import FusedLinearForPPO + + outputs = self.model(input_ids=input_ids, **kwargs) + hidden_states = outputs[0] + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError( + "To use forward_with_torch_backend, either labels or input_ids must be provided." + ) + + fused_linear_for_ppo = FusedLinearForPPO() + log_probs, entropy = fused_linear_for_ppo.forward( + hidden_states=hidden_states, + vocab_weights=self.lm_head.weight, + input_ids=rolled_labels, + temperature=temperature, + ) + return Qwen3_5CausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + hidden_states=outputs.hidden_states, + ) + + +def forward_with_triton_backend( + self: Qwen3_5ForConditionalGeneration, + input_ids: torch.LongTensor = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + **kwargs, +) -> tuple | Qwen3_5CausalLMOutputForPPO: + from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy + + outputs = self.model(input_ids=input_ids, **kwargs) + hidden_states = outputs[0] + + # Loss calculations + if labels is not None: + rolled_labels = torch.roll(labels, shifts=-1, dims=-1) + elif input_ids is not None: + rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) + else: + raise RuntimeError( + "To use forward_with_triton_backend, either labels or input_ids must be provided." + ) + + log_probs, entropy = linear_cross_entropy( + hidden_states, + self.lm_head.weight, + rolled_labels, + temperature, + "none", + ) + return Qwen3_5CausalLMOutputForPPO( + log_probs=log_probs, + entropy=entropy, + hidden_states=outputs.hidden_states, + ) diff --git a/trinity/trainer/verl/monkey_patch.py b/trinity/trainer/verl/monkey_patch.py index d0ba880a6c..608c602e0e 100644 --- a/trinity/trainer/verl/monkey_patch.py +++ b/trinity/trainer/verl/monkey_patch.py @@ -1,10 +1,77 @@ +import importlib import sys +from typing import Dict, Optional, Set import torch from transformers.modeling_utils import PreTrainedModel from trinity.utils.log import get_logger +# Map model types to their specific implementation modules. +# To extend support for a new model, simply add an entry here. +MODEL_TYPE_TO_MODULE_MAP: Dict[str, str] = { + "qwen2_5_vl": "verl.models.transformers.qwen2_vl", + "qwen2_vl": "verl.models.transformers.qwen2_vl", + "qwen3_vl": "verl.models.transformers.qwen3_vl", + "qwen3_vl_moe": "verl.models.transformers.qwen3_vl", + "qwen3_5": "trinity.common.patch.qwen3_5", + "qwen3_5_moe": "trinity.common.patch.qwen3_5", + "glm4v": "verl.models.transformers.glm4v", +} + +DEFAULT_MODULE_PATH = "verl.models.transformers.dense_common" +VALID_BACKENDS: Set[str] = {"triton", "torch"} + + +# modified from verl.models.transformers.monkey_patch.patch_forward_with_backends +def patch_forward_with_backends( + model: PreTrainedModel, + use_fused_kernels: bool = False, + fused_kernels_backend: Optional[str] = None, +) -> None: + """ + Monkey-patch the model's forward method with optimized backend implementations. + + Args: + model: The model to patch. + use_fused_kernels: Whether to enable fused kernels. + fused_kernels_backend: The backend to use ('triton' or 'torch'). + """ + logger = get_logger(__name__) + + # 1. Validation & Early Exit + if not use_fused_kernels: + return + + if fused_kernels_backend not in VALID_BACKENDS: + logger.warning( + f"Skipping patch for {model.__class__.__name__}: " + f"Invalid backend '{fused_kernels_backend}'. Choose from {VALID_BACKENDS}." + ) + return + + # 2. Resolve Module Path + model_type: str = getattr(model.config, "model_type", None) + module_path = MODEL_TYPE_TO_MODULE_MAP.get(model_type, DEFAULT_MODULE_PATH) + + # 3. Dynamic Import + try: + backend_module = importlib.import_module(module_path) + except ImportError as e: + logger.error(f"Failed to import {module_path} for {model.__class__.__name__}: {e}") + return + + # 4. Select and Apply Forward Function + func_name = f"forward_with_{fused_kernels_backend}_backend" + patched_forward = getattr(backend_module, func_name, None) + + if patched_forward is None: + logger.error(f"Function '{func_name}' not found in {module_path}") + return + + model.__class__.forward = patched_forward + logger.info(f"Applied {fused_kernels_backend.upper()} backend for {model.__class__.__name__}") + # modified from verl.models.transformers.monkey_patch.apply_monkey_patch def apply_monkey_patch( # noqa: C901 @@ -33,7 +100,6 @@ def apply_monkey_patch( # noqa: C901 """ from verl.models.transformers.monkey_patch import ( _ulysses_flash_attention_forward, - patch_forward_with_backends, patch_vlm_for_ulysses_input_slicing, ) from verl.utils.import_utils import is_trl_available @@ -127,6 +193,53 @@ def state_dict(self, *args, **kwargs): patch_vlm_for_ulysses_input_slicing(Qwen3VLTextModel) patch_vlm_for_ulysses_input_slicing(Qwen3VLMoeTextModel) + elif model.config.model_type in ["qwen3_5", "qwen3_5_moe"]: + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5TextModel + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( + Qwen3_5MoeTextModel, + ) + + # Step 1: bug fix in transformers==5.2.0 + # see https://github.com/huggingface/transformers/pull/44382 + if "Qwen3_5TextDecoderLayer" in model._no_split_modules: + model._no_split_modules.remove("Qwen3_5TextDecoderLayer") + model.model._no_split_modules.remove("Qwen3_5TextDecoderLayer") + if "Qwen3_5MoeTextDecoderLayer" in model._no_split_modules: + model._no_split_modules.remove("Qwen3_5MoeTextDecoderLayer") + model.model._no_split_modules.remove("Qwen3_5MoeTextDecoderLayer") + + # see https://github.com/huggingface/transformers/pull/44399 + if is_transformers_version_in_range(max_version="5.2.0"): + from trinity.common.patch.qwen3_5 import qwen35_text_forward + + Qwen3_5TextModel.forward = qwen35_text_forward + Qwen3_5MoeTextModel.forward = qwen35_text_forward + + # Step 2: patch input for multimodal sequence parallelism + if ulysses_sp_size > 1: + patch_vlm_for_ulysses_input_slicing(Qwen3_5TextModel) + patch_vlm_for_ulysses_input_slicing(Qwen3_5MoeTextModel) + + from trinity.common.patch.qwen3_5 import ( + ulysses_gated_delta_net_forward_decorator, + ) + + for layer in model.model.language_model.layers: + if layer.layer_type == "linear_attention": + layer.linear_attn.forward = ulysses_gated_delta_net_forward_decorator( + layer.linear_attn.forward + ) + + # Step 3: patch verl.utils.flops_counter + from verl.utils.flops_counter import ESTIMATE_FUNC, _estimate_qwen2_flops + + ESTIMATE_FUNC.update( + { + "qwen3_5": _estimate_qwen2_flops, + "qwen3_5_moe": _estimate_qwen2_flops, + } + ) + elif model.config.model_type == "glm4v": # Step 1: patch model to support image-text mixed data