Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions trinity/common/models/vllm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,24 @@ def __init__(
os.environ["VLLM_CACHE_ROOT"] = os.path.expanduser(
f"~/.cache/vllm/{config.bundle_indices}"
)
self.tokenization_kwargs = {
"truncate_prompt_tokens": config.max_prompt_tokens
if config.enable_prompt_truncation
else None
}
self.default_sampling_params = vllm.SamplingParams(
n=1,
temperature=config.temperature,
max_tokens=config.max_response_tokens,
min_tokens=config.min_response_tokens,
truncate_prompt_tokens=(
config.max_prompt_tokens if config.enable_prompt_truncation else None
),
skip_special_tokens=True,
include_stop_str_in_output=False,
output_kind=RequestOutputKind.FINAL_ONLY,
logprobs=config.logprobs,
top_p=config.top_p,
top_k=config.top_k,
ignore_eos=config.ignore_eos,
**(self.tokenization_kwargs if self.vllm_version <= parse_version("0.16.0") else {}),
)
self.ray_namespace = config.ray_namespace
self.request_id = 0
Expand Down Expand Up @@ -417,11 +420,17 @@ async def sample(
async def _generate_internal(self, prompt: Any, lora_request=None, **kwargs) -> Any:
# Send the request to the LLM engine.
self.request_id += 1
generate_kwargs = (
{"tokenization_kwargs": self.tokenization_kwargs}
if self.vllm_version > parse_version("0.16.0")
else {}
)
stream = self.async_llm.generate(
request_id=str(self.request_id),
prompt=prompt,
sampling_params=self._create_sampling_params(**kwargs),
lora_request=lora_request,
**generate_kwargs,
)

# Consume the stream until the request is finished.
Expand Down
11 changes: 11 additions & 0 deletions trinity/common/models/vllm_patch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,17 @@ def vllm_patch():
if trf_version >= parse_version("5.0.0") and vllm_version < parse_version("0.16.0"):
raise ImportError("Please upgrade vllm to 0.16.0 or above to use transformers>=5.0.0.")

from transformers.configuration_utils import PreTrainedConfig

original_init = PreTrainedConfig.__init__

def new_init(self, *args, **kwargs):
if "ignore_keys_at_rope_validation" in kwargs:
kwargs["ignore_keys_at_rope_validation"] = set(kwargs["ignore_keys_at_rope_validation"])
original_init(self, *args, **kwargs)

PreTrainedConfig.__init__ = new_init


def get_vllm_version():
try:
Expand Down
4 changes: 2 additions & 2 deletions trinity/common/models/vllm_patch/worker_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
def patch_vllm_prompt_logprobs(model_runner: GPUModelRunner): # noqa: C901
"""Patch vLLM model runner to support prompt logprobs extraction."""
version = get_vllm_version()
if version < parse_version("0.10.2") or version > parse_version("0.16.0"):
if version < parse_version("0.10.2") or version >= parse_version("0.17.0"):
raise ValueError(
f"Unsupported vllm version: {vllm.__version__}. "
"This patch requires vllm version >= 0.10.2, <= 0.16.0."
"This patch requires vllm version >= 0.10.2, < 0.17.0."
)
is_v0102 = version == parse_version("0.10.2")

Expand Down
201 changes: 201 additions & 0 deletions trinity/common/patch/qwen3_5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
from dataclasses import dataclass
from functools import wraps
from typing import Optional

import torch
from transformers.models.qwen3_5.modeling_qwen3_5 import (
BaseModelOutputWithPast,
Cache,
Qwen3_5CausalLMOutputWithPast,
Qwen3_5DynamicCache,
Qwen3_5ForConditionalGeneration,
Qwen3_5ModelOutputWithPast,
TransformersKwargs,
Unpack,
capture_outputs,
create_causal_mask,
merge_with_config_defaults,
)


# TODO: revisit for optimization (full-sequence gather may be avoidable)
def ulysses_gated_delta_net_forward_decorator(func):
    """Make a gated-delta-net forward compatible with Ulysses sequence parallelism.

    Under Ulysses SP each rank holds only a slice of the sequence dimension.
    The wrapped forward is given the *full* sequence (gathered along dim 1),
    and its output is sliced back down to this rank's local shard afterwards.
    When sequence parallelism is off (world size 1), the call passes through
    untouched.
    """

    @wraps(func)
    def wrapper(
        hidden_states: torch.Tensor,
        cache_params: Qwen3_5DynamicCache | None = None,
        cache_position: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
    ):
        # Imported lazily so the module does not hard-depend on verl.
        from verl.utils.ulysses import (
            gather_outputs_and_unpad,
            get_ulysses_sequence_parallel_world_size,
            slice_input_tensor,
        )

        sequence_parallel = get_ulysses_sequence_parallel_world_size() > 1

        if sequence_parallel:
            # Reassemble the complete sequence before the linear-attention pass.
            hidden_states = gather_outputs_and_unpad(hidden_states, gather_dim=1)

        result = func(hidden_states, cache_params, cache_position, attention_mask)

        if sequence_parallel:
            # Hand back only this rank's sequence shard.
            result = slice_input_tensor(result, dim=1, padding=False)

        return result

    return wrapper


@merge_with_config_defaults
@capture_outputs
def qwen35_text_forward(
    self,
    input_ids: torch.LongTensor | None = None,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.LongTensor | None = None,
    past_key_values: Cache | None = None,
    inputs_embeds: torch.FloatTensor | None = None,
    use_cache: bool | None = None,
    cache_position: torch.LongTensor | None = None,
    **kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
    """Replacement forward for the Qwen3.5 text backbone.

    Embeds the input, builds the attention masks, runs every decoder layer,
    and applies the final norm. Mirrors the upstream ``transformers``
    signature; extra per-layer kwargs flow through ``**kwargs``.

    Returns:
        ``Qwen3_5ModelOutputWithPast`` carrying only ``last_hidden_state``
        and ``past_key_values``.

    Raises:
        ValueError: if both or neither of ``input_ids``/``inputs_embeds``
            are provided.
    """
    # XOR: exactly one of input_ids / inputs_embeds must be supplied.
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    # Lazily create the cache only when caching was requested but none given.
    if use_cache and past_key_values is None:
        past_key_values = Qwen3_5DynamicCache(config=self.config)

    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )

    # mrope: the hard coded `3` is for temporal, height and width.
    if position_ids is None:
        position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
    elif position_ids.ndim == 2:
        position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

    # A 4-row position_ids presumably stacks a dedicated text row on top of the
    # 3 mrope rows — peel it off; otherwise the first mrope row doubles as the
    # text positions. TODO(review): confirm layout against the callers.
    if position_ids.ndim == 3 and position_ids.shape[0] == 4:
        text_position_ids = position_ids[0]
        position_ids = position_ids[1:]
    else:
        text_position_ids = position_ids[0]

    causal_mask = create_causal_mask(
        config=self.config,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        cache_position=cache_position,
        past_key_values=past_key_values,
        position_ids=text_position_ids,
    )
    # Separate mask for linear-attention layers (see layer_type dispatch below).
    linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position)

    hidden_states = inputs_embeds
    position_embeddings = self.rotary_emb(hidden_states, position_ids)

    for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
        # Linear-attention layers get the linear mask; all others the causal one.
        layer_mask = (
            linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask
        )

        hidden_states = decoder_layer(
            hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=layer_mask,
            position_ids=text_position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

    hidden_states = self.norm(hidden_states)

    return Qwen3_5ModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=past_key_values,
    )


@dataclass
class Qwen3_5CausalLMOutputForPPO(Qwen3_5CausalLMOutputWithPast):
    """Causal-LM output extended with the tensors PPO training needs.

    Produced by ``forward_with_torch_backend`` / ``forward_with_triton_backend``
    in this module instead of the stock logits-bearing output.
    """

    # Log-probabilities of the (left-shifted) target tokens.
    log_probs: Optional[torch.FloatTensor] = None
    # Entropy of the model's predictive distribution.
    entropy: Optional[torch.FloatTensor] = None


def forward_with_torch_backend(
    self: Qwen3_5ForConditionalGeneration,
    input_ids: torch.LongTensor = None,
    labels: Optional[torch.LongTensor] = None,
    temperature: float = 1.0,
    **kwargs,
) -> tuple | Qwen3_5CausalLMOutputForPPO:
    """Forward pass that emits PPO log-probs/entropy via the fused torch kernel.

    Runs the backbone, then computes per-token log-probabilities and entropy
    with verl's ``FusedLinearForPPO`` instead of materializing logits.
    ``labels`` take precedence over ``input_ids`` as the target sequence.

    Raises:
        RuntimeError: if neither ``labels`` nor ``input_ids`` is given.
    """
    # Imported lazily so the module does not hard-depend on verl.
    from verl.utils.experimental.torch_functional import FusedLinearForPPO

    model_outputs = self.model(input_ids=input_ids, **kwargs)
    last_hidden = model_outputs[0]

    # Targets: prefer labels, fall back to the inputs themselves.
    target_ids = labels if labels is not None else input_ids
    if target_ids is None:
        raise RuntimeError(
            "To use forward_with_torch_backend, either labels or input_ids must be provided."
        )
    # Shift left by one so position t is scored against token t+1.
    rolled_labels = torch.roll(target_ids, shifts=-1, dims=-1)

    fused_linear_for_ppo = FusedLinearForPPO()
    log_probs, entropy = fused_linear_for_ppo.forward(
        hidden_states=last_hidden,
        vocab_weights=self.lm_head.weight,
        input_ids=rolled_labels,
        temperature=temperature,
    )

    return Qwen3_5CausalLMOutputForPPO(
        log_probs=log_probs,
        entropy=entropy,
        hidden_states=model_outputs.hidden_states,
    )


def forward_with_triton_backend(
    self: Qwen3_5ForConditionalGeneration,
    input_ids: torch.LongTensor = None,
    labels: Optional[torch.LongTensor] = None,
    temperature: float = 1.0,
    **kwargs,
) -> tuple | Qwen3_5CausalLMOutputForPPO:
    """Forward pass that emits PPO log-probs/entropy via verl's Triton kernel.

    Runs the backbone, then computes per-token log-probabilities and entropy
    with ``linear_cross_entropy`` instead of materializing logits.
    ``labels`` take precedence over ``input_ids`` as the target sequence.

    Raises:
        RuntimeError: if neither ``labels`` nor ``input_ids`` is given.
    """
    # Imported lazily so the module does not hard-depend on verl.
    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy

    model_outputs = self.model(input_ids=input_ids, **kwargs)
    last_hidden = model_outputs[0]

    # Targets: prefer labels, fall back to the inputs themselves.
    target_ids = labels if labels is not None else input_ids
    if target_ids is None:
        raise RuntimeError(
            "To use forward_with_triton_backend, either labels or input_ids must be provided."
        )
    # Shift left by one so position t is scored against token t+1.
    rolled_labels = torch.roll(target_ids, shifts=-1, dims=-1)

    log_probs, entropy = linear_cross_entropy(
        last_hidden,
        self.lm_head.weight,
        rolled_labels,
        temperature,
        "none",
    )

    return Qwen3_5CausalLMOutputForPPO(
        log_probs=log_probs,
        entropy=entropy,
        hidden_states=model_outputs.hidden_states,
    )
Loading