From e6aed8fcf1ecd5a0cff597723cea767660040db2 Mon Sep 17 00:00:00 2001 From: Lu Fang Date: Wed, 12 Feb 2025 09:42:59 -0800 Subject: [PATCH 1/6] support DeepSeek MTP spec decode Signed-off-by: Lu Fang --- tests/models/registry.py | 2 + tests/spec_decode/e2e/test_mtp_correctness.py | 313 +++++++ vllm/config.py | 2 +- vllm/model_executor/models/deepseek_mtp.py | 283 ++++++ vllm/model_executor/models/deepseek_v3.py | 812 ++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/spec_decode/draft_model_runner.py | 7 +- vllm/spec_decode/spec_decode_worker.py | 7 +- vllm/transformers_utils/config.py | 7 +- vllm/transformers_utils/configs/__init__.py | 4 + .../configs/deepseek_mtp.py | 48 ++ .../transformers_utils/configs/deepseek_v3.py | 228 +++++ vllm/worker/worker.py | 2 +- 13 files changed, 1708 insertions(+), 8 deletions(-) create mode 100644 tests/spec_decode/e2e/test_mtp_correctness.py create mode 100644 vllm/model_executor/models/deepseek_mtp.py create mode 100644 vllm/model_executor/models/deepseek_v3.py create mode 100644 vllm/transformers_utils/configs/deepseek_mtp.py create mode 100644 vllm/transformers_utils/configs/deepseek_v3.py diff --git a/tests/models/registry.py b/tests/models/registry.py index 7b1db55494fe..c7eef89d96bc 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -285,6 +285,8 @@ def check_available_online( speculative_model="abhigoyal/vllm-medusa-llama-68m-random"), # noqa: E501 "MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m", speculative_model="ibm-ai-platform/llama-160m-accelerator"), # noqa: E501 + "DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random", + speculative_model="luccafong/deepseek_mtp_draft_random"), # noqa: E501 } _FALLBACK_MODEL = { diff --git a/tests/spec_decode/e2e/test_mtp_correctness.py b/tests/spec_decode/e2e/test_mtp_correctness.py new file mode 100644 index 000000000000..eec37a985886 --- /dev/null +++ b/tests/spec_decode/e2e/test_mtp_correctness.py @@ 
-0,0 +1,313 @@ +# SPDX-License-Identifier: Apache-2.0 +"""This docstring details important information on the testing methodology. + +Most of the tests rely on "greedy equality", where we expect the output of +speculative decoding on a sequence to exactly match the output of normal non- +speculative decoding. + +Since speculative decoding with rejection sampling guarantees that the output +distribution matches the target model's output distribution (up to hardware +numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy +equality. + +However, we still need to verify that the scenarios below pass: + * Batch size 1 greedy equality + * Batch size >1 greedy equality + * Test greedy equality under preemption + * Test greedy equality under various numbers of speculative tokens. + +With those tests, we can at least say that MTP does not break the +correctness of the target model outputs. +""" + +import pytest + +from .conftest import run_equality_correctness_test + +# main model +MAIN_MODEL = "luccafong/deepseek_mtp_main_random" + +# speculative model +SPEC_MODEL = "luccafong/deepseek_mtp_draft_random" + +# max. number of speculative tokens: this corresponds to +# num_heads in the config.json of the speculator model. +MAX_SPEC_TOKENS = 3 + +# precision +PRECISION = "bfloat16" + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Print spec metrics. 
+ "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, + seed: int): + + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Print spec metrics. 
+ "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": False, + }, + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "enforce_eager": False, + + # Print spec metrics. 
+ "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + "gpu_memory_utilization": 0.85 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size: int, + output_len: int, seed: int): + """Verify greedy equality with cuda graph enabled and different + batch sizes.""" + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "block_size": 8, + # 2 for small prompt, 256//8 for generated. + "num_gpu_blocks_override": 2 + 256 // 8, + "max_model_len": (2 + 256 // 8) * 8, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use small output len for fast test. 
+ 128, + ]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_e2e_greedy_correctness_with_preemption( + vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify greedy equality, even when some sequences are preempted mid- + generation. + """ + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": k, + } + # Try a range of num. speculative tokens + for k in range(1, 1 + MAX_SPEC_TOKENS) + ]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_different_k(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, + test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify that mtp speculative decoding produces exact equality + to without spec decode with different values of num_speculative_tokens. + """ + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. 
+ "enforce_eager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_disable_by_batch_size": 4 + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_disable_queue(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, + test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify that mtp speculative decoding produces exact equality + to without spec decode when speculation is disabled for large + batch sizes. + """ + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +if __name__ == "__main__": + import pytest + pytest.main([__file__]) diff --git a/vllm/config.py b/vllm/config.py index 1740871e7c10..7acbae8815d8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -757,7 +757,7 @@ def get_hidden_size(self) -> int: def is_deepseek_mla(self) -> bool: return (hasattr(self.hf_text_config, "model_type")) \ and (self.hf_text_config.model_type in \ - ('deepseek_v2', 'deepseek_v3'))\ + ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'))\ and (self.hf_text_config.kv_lora_rank is not None) def get_head_size(self) -> int: diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py new file mode 100644 index 000000000000..6d2bf9af5828 --- /dev/null +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -0,0 +1,283 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Iterable, List, Optional, Set, Tuple + +import torch +import 
torch.nn as nn +from transformers import PretrainedConfig + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .deepseek_v3 import (DeepseekV3DecoderLayer, + get_spec_layer_idx_from_weight_name) +from .utils import maybe_prefix + + +class SharedHead(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(hidden_states) + + +class DeepSeekMultiTokenPredictorLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.eh_proj = nn.Linear(config.model.hidden_size * 2, + config.model.hidden_size, 
+ bias=False) + self.shared_head = SharedHead(config=config, quant_config=quant_config) + self.block = DeepseekV3DecoderLayer(config, prefix, model_config, + cache_config, quant_config) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + assert inputs_embeds is not None + inputs_embeds[positions == 0] = 0 # masking inputs at position=0 + inputs_embeds = self.enorm(inputs_embeds) + previous_hidden_states = self.hnorm(previous_hidden_states) + + hidden_states = self.eh_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + hidden_states, _ = self.block(positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + residual=None) + return self.shared_head(hidden_states) + + +class DeepSeekMultiTokenPredictor(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + # MTP layer indices start at config.num_hidden_layers. + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict({ + str(idx): + DeepSeekMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + }) + + self.logits_processor = LogitsProcessor(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + inputs_embeds: 
Optional[torch.Tensor] = None, + step_idx: int = 0, + ) -> torch.Tensor: + return self.layers[str(self.mtp_start_layer_idx + step_idx)]( + input_ids, + positions, + kv_caches[step_idx], + attn_metadata, + previous_hidden_states, + inputs_embeds, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + step_idx: int = 0, + ) -> torch.Tensor: + mtp_layer = self.layers[str(self.mtp_start_layer_idx + step_idx)] + logits = self.logits_processor(mtp_layer.shared_head.head, + hidden_states, sampling_metadata) + return logits + + +class DeepSeekMTP(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.config = config + self.model_config = config.model + self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + + self.sampler = get_sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + step_idx: int = 0, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, previous_hidden_states, + inputs_embeds, step_idx) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + step_idx: int = 0, + ) -> Optional[torch.Tensor]: + return self.model.compute_logits(hidden_states, sampling_metadata, + step_idx) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + 
("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is None: + continue + name = self._rewrite_spec_layer_name(spec_layer, name) + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: + """ + Rewrite the weight name to match the format of the original model. + Add .block for modules in transformer layer block for spec layer + """ + spec_layer_weight_names = [ + "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" + ] + spec_layer_weight = False + for weight_name in spec_layer_weight_names: + if weight_name in name: + spec_layer_weight = True + break + if not spec_layer_weight: + # treat rest weights as weights for transformer layer block + name = name.replace(f"model.layers.{spec_layer}.", + f"model.layers.{spec_layer}.block.") + return name diff --git a/vllm/model_executor/models/deepseek_v3.py b/vllm/model_executor/models/deepseek_v3.py new file mode 100644 index 000000000000..9d50186f3b81 --- /dev/null +++ b/vllm/model_executor/models/deepseek_v3.py @@ -0,0 +1,812 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only DeepseekV3 model.""" +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) 
+ + +class DeepseekV3MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class DeepseekV3MoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + self.n_shared_experts = config.n_shared_experts + if self.tp_size > config.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.n_routed_experts}.") + + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. 
" + "Only silu is supported for now.") + + self.gate = ReplicatedLinear(config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts)) + else: + self.gate.e_score_correction_bias = None + + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias) + + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = DeepseekV3MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + shared_output = None # stays None when there are no shared experts + if self.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits) * self.routed_scaling_factor + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_dim) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + 
import math + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekV3Attention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int, + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") + else: + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + self.kv_a_layernorm = 
RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj") + # O projection. + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + self.use_normal_rope = False + else: + self.use_normal_rope = True + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.attn = Attention(self.num_local_heads, + self.qk_head_dim, + self.scaling, + num_kv_heads=self.num_local_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, + self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, + self.qk_head_dim) + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, _ = latent_cache.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + latent_cache = latent_cache.unsqueeze(1) + kv_a = self.kv_a_layernorm(kv_a.contiguous()) + kv = self.kv_b_proj(kv_a)[0] + kv = 
kv.view(-1, self.num_local_heads, + self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_pe = latent_cache[:, :, self.kv_lora_rank:] + + if self.use_normal_rope: + seq_len = positions.size(0) + ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape + q_pe = q_pe.reshape(seq_len, -1) + k_pe = k_pe.reshape(seq_len, -1) + + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + + if self.use_normal_rope: + q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape) + + q[..., self.qk_nope_head_dim:] = q_pe + k = torch.empty_like(q) + k[..., :self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim:] = k_pe + # padding value to qk_head_dim for alignment + v = torch.nn.functional.pad( + v, [0, self.qk_head_dim - self.v_head_dim], + value=0).view(-1, self.num_local_heads * self.qk_head_dim) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = attn_output.view( + -1, self.num_local_heads, + self.qk_head_dim)[..., :self.v_head_dim].reshape( + -1, self.num_local_heads * self.v_head_dim) + output, _ = self.o_proj(attn_output) + return output + + +class DeepseekV3MLAAttention(nn.Module): + """ + Main reference: DeepseekV2 paper, and FlashInfer Implementation + (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). 
+ + For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py + """ + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") + else: + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + 
eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj") + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.mla_attn = Attention( + num_heads=self.num_local_heads, + head_size=self.kv_lora_rank, + scale=self.scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + # MLA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + rotary_emb=self.rotary_emb, + q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, + kv_b_proj=self.kv_b_proj, + o_proj=self.o_proj, + ) + + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + if self.q_lora_rank is not None: + ckq = self.q_a_proj(hidden_states)[0] + hidden_states_or_q_c = self.q_a_layernorm(ckq) + else: + hidden_states_or_q_c = hidden_states + kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( + [self.kv_lora_rank, 
self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache, + attn_metadata) + + +class DeepseekV3DecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. + layer_idx = int(prefix.split(sep='.')[-1]) + if model_config.use_mla: + attn_cls = DeepseekV3MLAAttention + else: + attn_cls = DeepseekV3Attention + self.self_attn = attn_cls( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank + if hasattr(config, "q_lora_rank") else None, + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = DeepseekV3MoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = DeepseekV3MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + 
eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class DeepseekV3Model(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: DeepseekV3DecoderLayer( + config, + prefix, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + ), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", 
"residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class DeepseekV3ForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = DeepseekV3Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + 
+ def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + spec_layer = 
get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, + weight_name: str) -> Optional[int]: + if hasattr(config, + "num_nextn_predict_layers") and (config.num_nextn_predict_layers + > 0): + layer_idx = config.num_hidden_layers + for i in range(config.num_nextn_predict_layers): + if weight_name.startswith(f"model.layers.{layer_idx+i}."): + return layer_idx + i + return None diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 198b6d134718..74c0e3c05c18 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -187,6 +187,7 @@ _SPECULATIVE_DECODING_MODELS = { "EAGLEModel": ("eagle", "EAGLE"), + "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "MedusaModel": ("medusa", "Medusa"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), } diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 3948298db40c..319a2bb437ee 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -274,7 +274,11 @@ def execute_model( kwargs = {"previous_hidden_states": hidden_states} \ if previous_hidden_states is not None else {} + compute_logits_kwargs = {} # Run model + if hasattr(self.model.config, "num_nextn_predict_layers"): + kwargs["step_idx"] = step + compute_logits_kwargs["step_idx"] = step with set_forward_context(model_input.attn_metadata, self.vllm_config): hidden_states = model_executable( @@ -290,7 +294,8 @@ def execute_model( # Compute the logits. 
logits = self.model.compute_logits(hidden_states, - model_input.sampling_metadata) + model_input.sampling_metadata, + **compute_logits_kwargs) # Sample the next token. output = self.model.sample( diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 8653bece8b5a..e73591846ffc 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -182,9 +182,12 @@ def create_worker( draft_worker_kwargs[ "model_runner_cls"] = TP1DraftModelRunner else: - if draft_model_config.hf_config.model_type == "eagle": + if draft_model_config.hf_config.model_type in [ + "eagle", "deepseek_mtp" + ]: raise NotImplementedError( - "EAGLE does not support TP > 1 yet") + f"{draft_model_config.hf_config.model_type} " + "does not support TP > 1 yet") allow_zero_draft_token_step = False proposer_worker = MultiStepWorker(**draft_worker_kwargs) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 4b76509e4541..25ed320318a3 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -29,9 +29,9 @@ # yapf conflicts with isort for this block # yapf: disable from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config, - DbrxConfig, DeepseekVLV2Config, - EAGLEConfig, ExaoneConfig, - H2OVLChatConfig, + DbrxConfig, DeepSeekMTPConfig, + DeepseekVLV2Config, EAGLEConfig, + ExaoneConfig, H2OVLChatConfig, InternVLChatConfig, JAISConfig, MedusaConfig, MllamaConfig, MLPSpeculatorConfig, MPTConfig, @@ -69,6 +69,7 @@ "mlp_speculator": MLPSpeculatorConfig, "medusa": MedusaConfig, "eagle": EAGLEConfig, + "deepseek_mtp": DeepSeekMTPConfig, "exaone": ExaoneConfig, "h2ovl_chat": H2OVLChatConfig, "internvl_chat": InternVLChatConfig, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 9060565596b2..846e50a10957 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ 
b/vllm/transformers_utils/configs/__init__.py @@ -3,6 +3,8 @@ from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.cohere2 import Cohere2Config from vllm.transformers_utils.configs.dbrx import DbrxConfig +from vllm.transformers_utils.configs.deepseek_mtp import DeepSeekMTPConfig +from vllm.transformers_utils.configs.deepseek_v3 import DeepseekV3Config from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config from vllm.transformers_utils.configs.eagle import EAGLEConfig from vllm.transformers_utils.configs.exaone import ExaoneConfig @@ -45,4 +47,6 @@ "SolarConfig", "Telechat2Config", "UltravoxConfig", + "DeepSeekMTPConfig", + "DeepseekV3Config", ] diff --git a/vllm/transformers_utils/configs/deepseek_mtp.py b/vllm/transformers_utils/configs/deepseek_mtp.py new file mode 100644 index 000000000000..324499a6abdd --- /dev/null +++ b/vllm/transformers_utils/configs/deepseek_mtp.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +from typing import Union + +from transformers import PretrainedConfig + +from vllm.transformers_utils.configs.deepseek_v3 import DeepseekV3Config + + +class DeepSeekMTPConfig(PretrainedConfig): + model_type = "deepseek_mtp" + + def __init__(self, + model: Union[PretrainedConfig, dict, None] = None, + **kwargs): + print("model: %s", model) + if model is not None: + self.model = DeepseekV3Config.from_dict(model, **kwargs) + else: + self.model = None + + if self.model is not None: + for k, v in kwargs.items(): + if k != "architectures" and k != "model_type" and hasattr( + self.model, k): + setattr(self.model, k, v) + + if "architectures" not in kwargs: + kwargs["architectures"] = ["DeepSeekMTPModel"] + + super().__init__(**kwargs) + + if self.model is not None: + for k, v in self.model.to_dict().items(): + if not hasattr(self, k): + setattr(self, k, v) + # for loading MTP kv cache + self.model.num_hidden_layers = self.model.num_nextn_predict_layers + + 
@classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + **kwargs, + ) -> "DeepSeekMTPConfig": + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs) + return cls.from_dict(config_dict, **kwargs) diff --git a/vllm/transformers_utils/configs/deepseek_v3.py b/vllm/transformers_utils/configs/deepseek_v3.py new file mode 100644 index 000000000000..2037b5d52aa8 --- /dev/null +++ b/vllm/transformers_utils/configs/deepseek_v3.py @@ -0,0 +1,228 @@ +# SPDX-License-Identifier: Apache-2.0 +from transformers.configuration_utils import PretrainedConfig + + +class DeepseekV3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration + of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek + model according to the specified arguments, defining the model + architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DeepSeek-V3. + Configuration objects inherit from [`PretrainedConfig`] and can be used to + control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 129280): + Vocabulary size of the Deep model. Defines the number of different + tokens that can be represented by the + `inputs_ids` passed when calling [`DeepseekV3Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 1407): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_nextn_predict_layers (`int`, *optional*, defaults to 1): + Number of nextn predict layers in the DeepSeekV3 Model. 
+ num_attention_heads (`int`, *optional*, defaults to 128): + Number of attention heads for each attention layer in the + Transformer decoder. + n_shared_experts (`int`, *optional*, defaults to 1): + Number of shared experts, None means dense model. + n_routed_experts (`int`, *optional*, defaults to 256): + Number of routed experts, None means dense model. + routed_scaling_factor (`float`, *optional*, defaults to 2.5): + Scaling factor of routed experts. + topk_method (`str`, *optional*, defaults to `noaux_tc`): + Topk method used in routed gate. + n_group (`int`, *optional*, defaults to 8): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to 4): + Number of selected groups for each token (for each token, ensuring + the selected experts are only within `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to 8): + Number of selected experts, None means dense model. + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every + `moe_layer_freq - 1` dense layers. + first_k_dense_replace (`int`, *optional*, defaults to 3): + Number of dense layers in shallow layers + (embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to True): + Whether to normalize the weights of the routed experts. + scoring_func (`str`, *optional*, defaults to 'sigmoid'): + Method of computing expert weights. + aux_loss_alpha (`float`, *optional*, defaults to 0.001): + Auxiliary loss weight coefficient. + seq_aux (`bool`, *optional*, defaults to True): + Whether to compute the auxiliary loss for each individual sample. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to + implement Grouped Query Attention.
If + `num_key_value_heads=num_attention_heads`, the model will use + Multi Head Attention (MHA), if + `num_key_value_heads=1`, the model will use Multi Query Attention + (MQA), otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group + key and value head should be constructed + by meanpooling all the original heads within that group. For more + details check out [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not + specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the + decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used + with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values + attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 0): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during + pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to + understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining + results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232).
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE + embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a + float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using + this flag, don't update + `max_position_embeddings` to the expected new maximum. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to + `False`): + Whether to use a bias in the query, key, value and output + projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + ```python + >>> from transformers import DeepseekV3Model, DeepseekV3Config + >>> # Initializing a Deepseek-V3 style configuration + >>> configuration = DeepseekV3Config() + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "deepseek_v3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=129280, + hidden_size=7168, + intermediate_size=18432, + moe_intermediate_size=2048, + num_hidden_layers=61, + num_nextn_predict_layers=1, + num_attention_heads=128, + num_key_value_heads=128, + n_shared_experts=1, + n_routed_experts=256, + ep_size=1, + routed_scaling_factor=2.5, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method='noaux_tc', + n_group=8, + topk_group=4, + num_experts_per_tok=8, + moe_layer_freq=1, + first_k_dense_replace=3, + norm_topk_prob=True, + scoring_func='sigmoid', + aux_loss_alpha=0.001, + seq_aux=True, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-6, + 
use_cache=True, + pad_token_id=None, + bos_token_id=0, + eos_token_id=1, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_nextn_predict_layers = num_nextn_predict_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git 
a/vllm/worker/worker.py b/vllm/worker/worker.py index 582aa460eb4f..090d8f44d4ef 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -71,7 +71,7 @@ def __init__( or (speculative_config.draft_model_config.model == model_config.model) \ or (speculative_config.draft_model_config.hf_config.model_type - not in ["medusa", "mlp_speculator", "eagle"]) \ + not in ["medusa", "mlp_speculator", "eagle", "deepseek_mtp"]) \ else {"return_hidden_states": True} ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner From a2ae6bb26e10d41f550ace2eeeb0640b8924d7ab Mon Sep 17 00:00:00 2001 From: Lu Fang Date: Wed, 12 Feb 2025 09:42:59 -0800 Subject: [PATCH 2/6] cleanup and add comments Signed-off-by: Lu Fang --- vllm/model_executor/models/deepseek_mtp.py | 1 - vllm/spec_decode/draft_model_runner.py | 2 ++ vllm/spec_decode/spec_decode_worker.py | 5 ++--- vllm/transformers_utils/configs/deepseek_mtp.py | 1 - vllm/worker/worker.py | 2 +- 5 files changed, 5 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 6d2bf9af5828..318032a77423 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -97,7 +97,6 @@ class DeepSeekMultiTokenPredictor(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - print(f"{config=}") self.mtp_start_layer_idx = config.num_hidden_layers self.num_mtp_layers = config.num_nextn_predict_layers # to map the exact layer index from weights diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 319a2bb437ee..b57ca0cde01a 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -277,6 +277,8 @@ def execute_model( compute_logits_kwargs = {} # Run model if hasattr(self.model.config, "num_nextn_predict_layers"): + # for DeepSeek MTP only to use the 
corresponding layer for + # each step kwargs["step_idx"] = step compute_logits_kwargs["step_idx"] = step with set_forward_context(model_input.attn_metadata, diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index e73591846ffc..bf3aa8e40b0d 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -182,9 +182,8 @@ def create_worker( draft_worker_kwargs[ "model_runner_cls"] = TP1DraftModelRunner else: - if draft_model_config.hf_config.model_type in [ - "eagle", "deepseek_mtp" - ]: + if draft_model_config.hf_config.model_type in ( + "eagle", "deepseek_mtp"): raise NotImplementedError( f"{draft_model_config.hf_config.model_type} " "does not support TP > 1 yet") diff --git a/vllm/transformers_utils/configs/deepseek_mtp.py b/vllm/transformers_utils/configs/deepseek_mtp.py index 324499a6abdd..ac2baa9c75cf 100644 --- a/vllm/transformers_utils/configs/deepseek_mtp.py +++ b/vllm/transformers_utils/configs/deepseek_mtp.py @@ -13,7 +13,6 @@ class DeepSeekMTPConfig(PretrainedConfig): def __init__(self, model: Union[PretrainedConfig, dict, None] = None, **kwargs): - print("model: %s", model) if model is not None: self.model = DeepseekV3Config.from_dict(model, **kwargs) else: diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 090d8f44d4ef..bd07608f788f 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -71,7 +71,7 @@ def __init__( or (speculative_config.draft_model_config.model == model_config.model) \ or (speculative_config.draft_model_config.hf_config.model_type - not in ["medusa", "mlp_speculator", "eagle", "deepseek_mtp"]) \ + not in ("medusa", "mlp_speculator", "eagle", "deepseek_mtp")) \ else {"return_hidden_states": True} ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner From b2834a783ba9fd648c217a3fed2d28f3a278cdb3 Mon Sep 17 00:00:00 2001 From: Lu Fang Date: Wed, 12 Feb 2025 09:42:59 -0800 Subject: [PATCH 3/6] fix multi step and support same model 
checkpoint for main and spec model --- tests/spec_decode/e2e/test_mtp_correctness.py | 13 +----- vllm/config.py | 40 +++++++++++++++---- vllm/model_executor/models/deepseek_mtp.py | 31 +++++++------- vllm/sequence.py | 2 + vllm/spec_decode/draft_model_runner.py | 10 +++-- vllm/spec_decode/multi_step_worker.py | 6 ++- vllm/spec_decode/spec_decode_worker.py | 21 ++++++++-- vllm/worker/model_runner.py | 1 + vllm/worker/worker.py | 4 +- vllm/worker/worker_base.py | 2 + 10 files changed, 85 insertions(+), 45 deletions(-) diff --git a/tests/spec_decode/e2e/test_mtp_correctness.py b/tests/spec_decode/e2e/test_mtp_correctness.py index eec37a985886..f0fca64fcba4 100644 --- a/tests/spec_decode/e2e/test_mtp_correctness.py +++ b/tests/spec_decode/e2e/test_mtp_correctness.py @@ -27,12 +27,10 @@ # main model MAIN_MODEL = "luccafong/deepseek_mtp_main_random" -# speculative model -SPEC_MODEL = "luccafong/deepseek_mtp_draft_random" # max. number of speculative tokens: this corresponds to -# num_heads in the config.json of the speculator model. -MAX_SPEC_TOKENS = 3 +# num_nextn_predict_layers in the config.json of the speculator model. 
+MAX_SPEC_TOKENS = 1 # precision PRECISION = "bfloat16" @@ -57,7 +55,6 @@ @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, "num_speculative_tokens": MAX_SPEC_TOKENS, }, ]) @@ -97,12 +94,10 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, "num_speculative_tokens": MAX_SPEC_TOKENS, "disable_logprobs_during_spec_decoding": False, }, { - "speculative_model": SPEC_MODEL, "num_speculative_tokens": MAX_SPEC_TOKENS, "disable_logprobs_during_spec_decoding": True, }, @@ -152,7 +147,6 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, "num_speculative_tokens": MAX_SPEC_TOKENS, }, ]) @@ -196,7 +190,6 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, "num_speculative_tokens": MAX_SPEC_TOKENS, }, ]) @@ -239,7 +232,6 @@ def test_mtp_e2e_greedy_correctness_with_preemption( "test_llm_kwargs", [ { - "speculative_model": SPEC_MODEL, "num_speculative_tokens": k, } # Try a range of num. 
speculative tokens @@ -282,7 +274,6 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [{ - "speculative_model": SPEC_MODEL, "num_speculative_tokens": MAX_SPEC_TOKENS, "speculative_disable_by_batch_size": 4 }]) diff --git a/vllm/config.py b/vllm/config.py index 7acbae8815d8..1d29736898f0 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -850,8 +850,12 @@ def get_num_attention_heads(self, def get_layers_start_end_indices( self, parallel_config: "ParallelConfig") -> Tuple[int, int]: from vllm.distributed.utils import get_pp_indices - total_num_hidden_layers = getattr(self.hf_text_config, - "num_hidden_layers", 0) + if self.hf_text_config.model_type == "deepseek_mtp": + total_num_hidden_layers = getattr(self.hf_text_config, + "num_nextn_predict_layers", 0) + else: + total_num_hidden_layers = getattr(self.hf_text_config, + "num_hidden_layers", 0) pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size pp_size = parallel_config.pipeline_parallel_size start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size) @@ -1667,6 +1671,21 @@ def compute_hash(self) -> str: hash_str = hashlib.md5(str(factors).encode()).hexdigest() return hash_str + @staticmethod + def hf_config_override( + hf_config: PretrainedConfig + ) -> PretrainedConfig: + if hf_config.model_type == "deepseek_v3": + hf_config.model_type = "deepseek_mtp" + if hf_config.model_type == "deepseek_mtp": + n_predict = getattr( + hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["DeepSeekMTPModel"] + }) + return hf_config + @staticmethod def maybe_create_spec_config( target_model_config: ModelConfig, @@ -1749,12 +1768,16 @@ def maybe_create_spec_config( Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if the necessary conditions are met, else None. 
""" - if speculative_model is None: if num_speculative_tokens is not None: - raise ValueError("num_speculative_tokens was provided without " + if target_model_config.hf_text_config.model_type == "deepseek_v3": + # use the draft model from the same model: + speculative_model = target_model_config.model + else: + raise ValueError("num_speculative_tokens was provided without " "speculative_model.") - return None + else: + return None if (speculative_disable_by_batch_size is not None and speculative_disable_by_batch_size < 2): @@ -1808,6 +1831,7 @@ def maybe_create_spec_config( max_seq_len_to_capture=target_model_config. max_seq_len_to_capture, max_logprobs=target_model_config.max_logprobs, + hf_overrides=SpeculativeConfig.hf_config_override, ) draft_hf_config = draft_model_config.hf_config @@ -1815,7 +1839,6 @@ def maybe_create_spec_config( if (num_speculative_tokens is not None and hasattr(draft_hf_config, "num_lookahead_tokens")): draft_hf_config.num_lookahead_tokens = num_speculative_tokens - n_predict = getattr(draft_hf_config, "n_predict", None) if n_predict is not None: if num_speculative_tokens is None: @@ -1925,11 +1948,12 @@ def _verify_and_get_draft_model_tensor_parallel_size( # If speculative_draft_tensor_parallel_size is unset then set it # appropriately else verify that it is set correctly. 
if speculative_draft_tensor_parallel_size is None: - if draft_hf_config.model_type == "mlp_speculator": + if draft_hf_config.model_type in ("mlp_speculator", "deepseek_mtp"): speculative_draft_tensor_parallel_size = 1 if target_parallel_config.tensor_parallel_size > 1: logger.warning( - "MLPSpeculator cannot currently be run with tp>1; " + f"{draft_hf_config.model_type} cannot currently " + "be run with tp>1; " "setting speculative_draft_tensor_parallel_size=1") else: speculative_draft_tensor_parallel_size = \ diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 318032a77423..c862eb9bfb15 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -58,8 +58,8 @@ def __init__( self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.eh_proj = nn.Linear(config.model.hidden_size * 2, - config.model.hidden_size, + self.eh_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, bias=False) self.shared_head = SharedHead(config=config, quant_config=quant_config) self.block = DeepseekV3DecoderLayer(config, prefix, model_config, @@ -73,11 +73,13 @@ def forward( attn_metadata: AttentionMetadata, previous_hidden_states: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None, + spec_step_index: int = 0, ) -> torch.Tensor: if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) assert inputs_embeds is not None - inputs_embeds[positions == 0] = 0 # masking inputs at position=0 + inputs_embeds[positions <= spec_step_index] = 0 + # masking inputs at position<=k, token from k+1 inputs_embeds = self.enorm(inputs_embeds) previous_hidden_states = self.hnorm(previous_hidden_states) @@ -123,24 +125,25 @@ def forward( attn_metadata: AttentionMetadata, previous_hidden_states: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None, - step_idx: int = 0, + spec_step_idx: int = 0, ) 
-> torch.Tensor: - return self.layers[str(self.mtp_start_layer_idx + step_idx)]( + return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]( input_ids, positions, - kv_caches[step_idx], + kv_caches[spec_step_idx], attn_metadata, previous_hidden_states, inputs_embeds, + spec_step_idx, ) def compute_logits( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - step_idx: int = 0, + spec_step_idx: int = 0, ) -> torch.Tensor: - mtp_layer = self.layers[str(self.mtp_start_layer_idx + step_idx)] + mtp_layer = self.layers[str(self.mtp_start_layer_idx + spec_step_idx)] logits = self.logits_processor(mtp_layer.shared_head.head, hidden_states, sampling_metadata) return logits @@ -150,9 +153,7 @@ class DeepSeekMTP(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - config = vllm_config.model_config.hf_config - self.config = config - self.model_config = config.model + self.config = vllm_config.model_config.hf_config self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config, prefix=maybe_prefix( prefix, "model")) @@ -168,21 +169,21 @@ def forward( previous_hidden_states: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - step_idx: int = 0, + spec_step_idx: int = 0, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, previous_hidden_states, - inputs_embeds, step_idx) + inputs_embeds, spec_step_idx) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - step_idx: int = 0, + spec_step_idx: int = 0, ) -> Optional[torch.Tensor]: return self.model.compute_logits(hidden_states, sampling_metadata, - step_idx) + spec_step_idx) def sample( self, diff --git a/vllm/sequence.py b/vllm/sequence.py index 534b9e60610a..b60ee3304f84 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1304,6 +1304,8 @@ class ExecuteModelRequest( 
previous_hidden_states: Optional[HiddenStates] = None # The number of forward steps to run. num_steps: int = 1 + # The step index for spec model input. + spec_step_idx: int = 0 # Finished request ids since last step. finished_requests_ids: List[str] = msgspec.field(default_factory=list) # The last sampled token ids for multi step decoding. diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index b57ca0cde01a..09a5f77a3872 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -175,6 +175,7 @@ def execute_model( previous_hidden_states: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + **kwargs, ) -> Optional[List[SamplerOutput]]: """Executes num_steps forward passes with advacement of input tensors on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions. @@ -271,7 +272,7 @@ def execute_model( for step in range(num_steps): multi_modal_kwargs = model_input.multi_modal_kwargs or {} - kwargs = {"previous_hidden_states": hidden_states} \ + model_execute_kwargs = {"previous_hidden_states": hidden_states} \ if previous_hidden_states is not None else {} compute_logits_kwargs = {} @@ -279,8 +280,9 @@ def execute_model( if hasattr(self.model.config, "num_nextn_predict_layers"): # for DeepSeek MTP only to use the corresponding layer for # each step - kwargs["step_idx"] = step - compute_logits_kwargs["step_idx"] = step + spec_step_idx = kwargs.get("spec_step_idx", 0) + model_execute_kwargs["spec_step_idx"] = spec_step_idx + compute_logits_kwargs["spec_step_idx"] = spec_step_idx with set_forward_context(model_input.attn_metadata, self.vllm_config): hidden_states = model_executable( @@ -291,7 +293,7 @@ def execute_model( intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), - **kwargs, + **model_execute_kwargs, ) # Compute the logits. 
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 5474917a6fab..55d3b20a284f 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -2,7 +2,7 @@ import copy import weakref -from typing import Dict, List, Set, Tuple +from typing import Dict, List, Set, Tuple, Optional import torch @@ -95,9 +95,10 @@ def sampler_output( # TODO: Remove this branch once DraftModelRunner supports TP>1 # and other restrictions that are part of DraftModelRunner's # supports_gpu_multi_step(..) - for _ in range(sample_len): + for i in range(sample_len): model_output: List[SamplerOutput] = self.worker.execute_model( execute_model_req=expanded_request) + expanded_request.spec_step_idx += 1 assert (len(model_output) == 1 ), "composing multistep workers not supported" model_output = model_output[0] @@ -106,6 +107,7 @@ def sampler_output( model_output, expanded_request.seq_group_metadata_list, indices_of_seq_with_bonus_tokens) model_outputs.append(model_output) + expanded_request.spec_step_idx = 0 # move indices to device to avoid stream sync indices_of_seq_with_bonus_tokens = torch.tensor( diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index bf3aa8e40b0d..f7c9b33707f0 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -107,6 +107,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": typical_acceptance_sampler_posterior_alpha, disable_logprobs=speculative_config.disable_logprobs, disable_log_stats=speculative_config.disable_log_stats, + num_speculative_tokens=speculative_config.num_speculative_tokens, ) return spec_decode_worker @@ -152,9 +153,11 @@ def create_worker( typical_acceptance_sampler_posterior_alpha: float, disable_logprobs: bool, disable_log_stats: bool, + num_speculative_tokens: int, ) -> "SpecDecodeWorker": allow_zero_draft_token_step = True + num_spec_prefill_steps = 1 
ngram_prompt_lookup_max = ( draft_worker_kwargs.pop("ngram_prompt_lookup_max")) ngram_prompt_lookup_min = ( @@ -190,6 +193,8 @@ def create_worker( allow_zero_draft_token_step = False proposer_worker = MultiStepWorker(**draft_worker_kwargs) + if draft_model_config.hf_config.model_type == "deepseek_mtp": + num_spec_prefill_steps = num_speculative_tokens proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker( proposer_worker, draft_tp, target_tp) @@ -241,7 +246,9 @@ def create_worker( disable_log_stats=disable_log_stats, disable_by_batch_size=disable_by_batch_size, spec_decode_sampler=spec_decode_sampler, - allow_zero_draft_token_step=allow_zero_draft_token_step) + allow_zero_draft_token_step=allow_zero_draft_token_step, + num_spec_prefill_steps=num_spec_prefill_steps, + ) def __init__( self, @@ -254,6 +261,7 @@ def __init__( metrics_collector: Optional[AsyncMetricsCollector] = None, disable_by_batch_size: Optional[int] = None, allow_zero_draft_token_step: Optional[bool] = True, + num_spec_prefill_steps: int = 1, ): """ Create a SpecDecodeWorker. @@ -284,6 +292,10 @@ def __init__( allow_zero_draft_token_step: whether to allow a step where the draft model generates no draft token; should disallow when the tp of draft model is larger than 1 (TODO: #5814) + num_spec_prefill_steps: number of speculative prefill steps to run + before the speculative decoding starts. This is only used when + the draft model is a deepseek_mtp model that requires prefill + kv cache separately for each step layer. """ self.proposer_worker = proposer_worker self.scorer_worker = scorer_worker @@ -316,6 +328,7 @@ def __init__( self.previous_hidden_states: Optional[HiddenStates] = None self._disable_logprobs = disable_logprobs self._disable_log_stats = disable_log_stats + self._num_spec_prefill_steps = num_spec_prefill_steps def init_device(self) -> None: """Initialize both scorer and proposer models. 
@@ -664,8 +677,10 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, execute_model_req.previous_hidden_states = \ prepare_prefill_hidden_states( sampler_output.prefill_hidden_states) - - self.proposer_worker.execute_model(execute_model_req) + execute_model_req.spec_step_idx = 0 + for _ in range(self._num_spec_prefill_steps): + self.proposer_worker.execute_model(execute_model_req) + execute_model_req.spec_step_idx += 1 sampler_output_to_return = (self._serialize_sampler_output_no_logprobs( execute_model_req=execute_model_req, sampler_output=sampler_output) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c7814f17375b..eb4234432a82 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1649,6 +1649,7 @@ def execute_model( kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + **kwargs, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError("num_steps > 1 is not supported in ModelRunner") diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index bd07608f788f..ff38e3bfc207 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -68,8 +68,8 @@ def __init__( speculative_config = self.speculative_config model_config = self.model_config speculative_args = {} if speculative_config is None \ - or (speculative_config.draft_model_config.model == - model_config.model) \ + or (speculative_config.draft_model_config.hf_config.model_type == + model_config.hf_config.model_type) \ or (speculative_config.draft_model_config.hf_config.model_type not in ("medusa", "mlp_speculator", "eagle", "deepseek_mtp")) \ else {"return_hidden_states": True} diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 819b81fbfdbb..89b4652f201f 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -392,6 +392,8 @@ def execute_model( model_input, worker_input, kwargs = inputs 
num_steps = worker_input.num_steps + if execute_model_req is not None: + kwargs["spec_step_idx"] = execute_model_req.spec_step_idx self.execute_worker(worker_input) From 859893577b71ab001281ada5fe3e94adb286bf6f Mon Sep 17 00:00:00 2001 From: Lu Fang Date: Wed, 12 Feb 2025 09:42:59 -0800 Subject: [PATCH 4/6] allow mla for gpu multi step spec decode --- vllm/spec_decode/draft_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 09a5f77a3872..a7eb4baef11d 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -153,7 +153,7 @@ def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest): return False # TODO: Add support for other attn backends - if self.attn_backend.get_name() != "FLASH_ATTN": + if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"): return False # TODO: Add support for LORA From 7a159b1ca592b3f552e05775dbef56209e0e7683 Mon Sep 17 00:00:00 2001 From: Lu Fang Date: Wed, 12 Feb 2025 09:53:39 -0800 Subject: [PATCH 5/6] migrate changes to deepseekv3 and remove unused config --- vllm/model_executor/models/deepseek_mtp.py | 8 +- vllm/model_executor/models/deepseek_v2.py | 27 +- vllm/model_executor/models/deepseek_v3.py | 812 ------------------ vllm/transformers_utils/config.py | 3 +- vllm/transformers_utils/configs/__init__.py | 4 - .../configs/deepseek_mtp.py | 47 - .../transformers_utils/configs/deepseek_v3.py | 228 ----- 7 files changed, 22 insertions(+), 1107 deletions(-) delete mode 100644 vllm/model_executor/models/deepseek_v3.py delete mode 100644 vllm/transformers_utils/configs/deepseek_mtp.py delete mode 100644 vllm/transformers_utils/configs/deepseek_v3.py diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index c862eb9bfb15..339210b478c1 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ 
b/vllm/model_executor/models/deepseek_mtp.py @@ -18,7 +18,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .deepseek_v3 import (DeepseekV3DecoderLayer, +from .deepseek_v2 import (DeepseekV2DecoderLayer, get_spec_layer_idx_from_weight_name) from .utils import maybe_prefix @@ -62,7 +62,7 @@ def __init__( config.hidden_size, bias=False) self.shared_head = SharedHead(config=config, quant_config=quant_config) - self.block = DeepseekV3DecoderLayer(config, prefix, model_config, + self.block = DeepseekV2DecoderLayer(config, prefix, model_config, cache_config, quant_config) def forward( @@ -78,8 +78,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) assert inputs_embeds is not None - inputs_embeds[positions <= spec_step_index] = 0 - # masking inputs at position<=k, token from k+1 + # masking inputs at position 0, as not needed by MTP + inputs_embeds[positions == 0] = 0 inputs_embeds = self.enorm(inputs_embeds) previous_hidden_states = self.hnorm(previous_hidden_states) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index fd0e58fa1458..5e5ae11b2b45 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -640,8 +640,7 @@ def forward( "residual": residual }) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states + return hidden_states + residual class DeepseekV2ForCausalLM(nn.Module, SupportsPP): @@ -684,7 +683,7 @@ def compute_logits( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, + logits = self.logits_processor(self.lm_head, self.model.norm(hidden_states), sampling_metadata) return logits @@ -732,13 +731,9 @@ def load_weights(self, weights: Iterable[Tuple[str, if "rotary_emb.inv_freq" in name: continue - # TODO(simon): support 
nextn predict layers - if hasattr(self.config, "num_nextn_predict_layers" - ) and self.config.num_nextn_predict_layers > 0: - assert self.config.num_nextn_predict_layers == 1 - layer_idx = self.config.num_hidden_layers - if name.startswith(f"model.layers.{layer_idx}"): - continue + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model for (param_name, weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). @@ -805,3 +800,15 @@ def load_weights(self, weights: Iterable[Tuple[str, class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass + + +def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, + weight_name: str) -> Optional[int]: + if hasattr(config, + "num_nextn_predict_layers") and (config.num_nextn_predict_layers + > 0): + layer_idx = config.num_hidden_layers + for i in range(config.num_nextn_predict_layers): + if weight_name.startswith(f"model.layers.{layer_idx+i}."): + return layer_idx + i + return None diff --git a/vllm/model_executor/models/deepseek_v3.py b/vllm/model_executor/models/deepseek_v3.py deleted file mode 100644 index 9d50186f3b81..000000000000 --- a/vllm/model_executor/models/deepseek_v3.py +++ /dev/null @@ -1,812 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -# Adapted from -# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# Copyright 2023 The vLLM team. -# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Inference-only DeepseekV3 model.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union - -import torch -from torch import nn -from transformers import PretrainedConfig - -from vllm.attention import Attention, AttentionMetadata -from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - ReplicatedLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors - -from .interfaces import 
SupportsPP -from .utils import (PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) - - -class DeepseekV3MLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True, - prefix: str = "", - ) -> None: - super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.down_proj") - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - - -class DeepseekV3MoE(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.tp_size = get_tensor_model_parallel_world_size() - self.routed_scaling_factor = config.routed_scaling_factor - self.n_shared_experts = config.n_shared_experts - self.routed_scaling_factor = config.routed_scaling_factor - if self.tp_size > config.n_routed_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.n_routed_experts}.") - - if config.hidden_act != "silu": - raise ValueError(f"Unsupported activation: {config.hidden_act}. 
" - "Only silu is supported for now.") - - self.gate = ReplicatedLinear(config.hidden_size, - config.n_routed_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") - if config.topk_method == "noaux_tc": - self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts)) - else: - self.gate.e_score_correction_bias = None - - self.experts = FusedMoE( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - prefix=f"{prefix}.experts", - scoring_func=config.scoring_func, - e_score_correction_bias=self.gate.e_score_correction_bias) - - if config.n_shared_experts is not None: - intermediate_size = (config.moe_intermediate_size * - config.n_shared_experts) - self.shared_experts = DeepseekV3MLP( - hidden_size=config.hidden_size, - intermediate_size=intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=False, - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - if self.n_shared_experts is not None: - shared_output = self.shared_experts(hidden_states) - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits) * self.routed_scaling_factor - if shared_output is not None: - final_hidden_states = final_hidden_states + shared_output - if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) - - return final_hidden_states.view(num_tokens, hidden_dim) - - -def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: - 
import math - if scale <= 1: - return 1.0 - return 0.1 * mscale * math.log(scale) + 1.0 - - -class DeepseekV3Attention(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - hidden_size: int, - num_heads: int, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - v_head_dim: int, - q_lora_rank: int, - kv_lora_rank: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.hidden_size = hidden_size - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_rope_head_dim = qk_rope_head_dim - self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim - self.v_head_dim = v_head_dim - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - self.num_heads = num_heads - tp_size = get_tensor_model_parallel_world_size() - assert num_heads % tp_size == 0 - self.num_local_heads = num_heads // tp_size - self.scaling = self.qk_head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - if self.q_lora_rank is not None: - self.q_a_proj = ReplicatedLinear(self.hidden_size, - self.q_lora_rank, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_a_proj") - self.q_a_layernorm = RMSNorm(self.q_lora_rank, - eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear(q_lora_rank, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_b_proj") - else: - self.q_proj = ColumnParallelLinear(self.hidden_size, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj") - - self.kv_a_proj_with_mqa = ReplicatedLinear( - self.hidden_size, - self.kv_lora_rank + self.qk_rope_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_a_proj_with_mqa") - self.kv_a_layernorm = 
RMSNorm(self.kv_lora_rank, - eps=config.rms_norm_eps) - self.kv_b_proj = ColumnParallelLinear( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_b_proj") - # O projection. - self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") - if rope_scaling: - rope_scaling["rope_type"] = 'deepseek_yarn' - self.use_normal_rope = False - else: - self.use_normal_rope = True - self.rotary_emb = get_rope(qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - is_neox_style=False) - - if rope_scaling: - mscale_all_dim = rope_scaling.get("mscale_all_dim", False) - scaling_factor = rope_scaling["factor"] - mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) - self.scaling = self.scaling * mscale * mscale - - self.attn = Attention(self.num_local_heads, - self.qk_head_dim, - self.scaling, - num_kv_heads=self.num_local_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - if self.q_lora_rank is not None: - q = self.q_a_proj(hidden_states)[0] - q = self.q_a_layernorm(q) - q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, - self.qk_head_dim) - else: - q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, - self.qk_head_dim) - q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], - dim=-1) - latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] - kv_a, _ = latent_cache.split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - latent_cache = latent_cache.unsqueeze(1) - kv_a = self.kv_a_layernorm(kv_a.contiguous()) - kv = self.kv_b_proj(kv_a)[0] - kv = 
kv.view(-1, self.num_local_heads, - self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank:] - - if self.use_normal_rope: - seq_len = positions.size(0) - ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape - q_pe = q_pe.reshape(seq_len, -1) - k_pe = k_pe.reshape(seq_len, -1) - - q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - - if self.use_normal_rope: - q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape) - - q[..., self.qk_nope_head_dim:] = q_pe - k = torch.empty_like(q) - k[..., :self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe - # padding value to qk_head_dim for alignment - v = torch.nn.functional.pad( - v, [0, self.qk_head_dim - self.v_head_dim], - value=0).view(-1, self.num_local_heads * self.qk_head_dim) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) - attn_output = attn_output.view( - -1, self.num_local_heads, - self.qk_head_dim)[..., :self.v_head_dim].reshape( - -1, self.num_local_heads * self.v_head_dim) - output, _ = self.o_proj(attn_output) - return output - - -class DeepseekV3MLAAttention(nn.Module): - """ - Main reference: DeepseekV2 paper, and FlashInfer Implementation - (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). 
- - For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py - """ - - def __init__( - self, - config: PretrainedConfig, - hidden_size: int, - num_heads: int, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - v_head_dim: int, - q_lora_rank: Optional[int], - kv_lora_rank: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.hidden_size = hidden_size - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_rope_head_dim = qk_rope_head_dim - self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim - self.v_head_dim = v_head_dim - - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - - self.num_heads = num_heads - tp_size = get_tensor_model_parallel_world_size() - assert num_heads % tp_size == 0 - self.num_local_heads = num_heads // tp_size - - self.scaling = self.qk_head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - if self.q_lora_rank is not None: - self.q_a_proj = ReplicatedLinear(self.hidden_size, - self.q_lora_rank, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_a_proj") - self.q_a_layernorm = RMSNorm(self.q_lora_rank, - eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear(q_lora_rank, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_b_proj") - else: - self.q_proj = ColumnParallelLinear(self.hidden_size, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj") - - self.kv_a_proj_with_mqa = ReplicatedLinear( - self.hidden_size, - self.kv_lora_rank + self.qk_rope_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_a_proj_with_mqa") - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, - 
eps=config.rms_norm_eps) - self.kv_b_proj = ColumnParallelLinear( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_b_proj") - self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") - - if rope_scaling: - rope_scaling["rope_type"] = 'deepseek_yarn' - self.rotary_emb = get_rope(qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - is_neox_style=False) - if rope_scaling: - mscale_all_dim = rope_scaling.get("mscale_all_dim", False) - scaling_factor = rope_scaling["factor"] - mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) - self.scaling = self.scaling * mscale * mscale - - self.mla_attn = Attention( - num_heads=self.num_local_heads, - head_size=self.kv_lora_rank, - scale=self.scaling, - num_kv_heads=1, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_mla=True, - # MLA Args - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - qk_head_dim=self.qk_head_dim, - v_head_dim=self.v_head_dim, - rotary_emb=self.rotary_emb, - q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, - kv_b_proj=self.kv_b_proj, - o_proj=self.o_proj, - ) - - self.prefix = prefix - self.debug_layer_idx = int(self.prefix.split(".")[-2]) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - if self.q_lora_rank is not None: - ckq = self.q_a_proj(hidden_states)[0] - hidden_states_or_q_c = self.q_a_layernorm(ckq) - else: - hidden_states_or_q_c = hidden_states - kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( - [self.kv_lora_rank, 
self.qk_rope_head_dim], dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) - return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache, - attn_metadata) - - -class DeepseekV3DecoderLayer(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - prefix: str, - model_config: ModelConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - # DecoderLayers are created with `make_layers` which passes the prefix - # with the layer's index. - layer_idx = int(prefix.split(sep='.')[-1]) - if model_config.use_mla: - attn_cls = DeepseekV3MLAAttention - else: - attn_cls = DeepseekV3Attention - self.self_attn = attn_cls( - config=config, - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - qk_nope_head_dim=config.qk_nope_head_dim, - qk_rope_head_dim=config.qk_rope_head_dim, - v_head_dim=config.v_head_dim, - q_lora_rank=config.q_lora_rank - if hasattr(config, "q_lora_rank") else None, - kv_lora_rank=config.kv_lora_rank, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - if (config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0): - self.mlp = DeepseekV3MoE( - config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - else: - self.mlp = DeepseekV3MLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - self.input_layernorm = RMSNorm(config.hidden_size, - 
eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - residual: Optional[torch.Tensor], - ) -> torch.Tensor: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.mlp(hidden_states) - return hidden_states, residual - - -@support_torch_compile -class DeepseekV3Model(nn.Module): - - fall_back_to_pt_during_load = False - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - model_config = vllm_config.model_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - if get_pp_group().is_first_rank: - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) - else: - self.embed_tokens = PPMissingLayer() - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: DeepseekV3DecoderLayer( - config, - prefix, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - ), - prefix=f"{prefix}.layers") - - if get_pp_group().is_last_rank: - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - else: - self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", 
"residual"], config.hidden_size)) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual) - - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - -class DeepseekV3ForCausalLM(nn.Module, SupportsPP): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - self.model = DeepseekV3Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) - self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) - 
- def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts) - - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - spec_layer = 
get_spec_layer_idx_from_weight_name(self.config, name) - if spec_layer is not None: - continue # skip spec decode layers for main model - for (param_name, weight_name, shard_id) in stacked_params_mapping: - # Skip non-stacked layers and experts (experts handled below). - if weight_name not in name: - continue - # We have mlp.experts[0].gate_proj in the checkpoint. - # Since we handle the experts below in expert_params_mapping, - # we need to skip here BEFORE we update the name, otherwise - # name will be updated to mlp.experts[0].gate_up_proj, which - # will then be updated below in expert_params_mapping - # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) - break - else: - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, - weight_name: str) -> Optional[int]: - if hasattr(config, - "num_nextn_predict_layers") and (config.num_nextn_predict_layers - > 0): - layer_idx = config.num_hidden_layers - for i in range(config.num_nextn_predict_layers): - if weight_name.startswith(f"model.layers.{layer_idx+i}."): - return layer_idx + i - return None diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 25ed320318a3..eda5a8143d12 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -29,7 +29,7 @@ # yapf conflicts with isort for this block # yapf: disable from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config, - DbrxConfig, DeepSeekMTPConfig, + DbrxConfig, DeepseekVLV2Config, EAGLEConfig, ExaoneConfig, H2OVLChatConfig, InternVLChatConfig, JAISConfig, @@ -69,7 +69,6 @@ "mlp_speculator": MLPSpeculatorConfig, "medusa": MedusaConfig, "eagle": EAGLEConfig, - "deepseek_mtp": DeepSeekMTPConfig, "exaone": ExaoneConfig, "h2ovl_chat": H2OVLChatConfig, "internvl_chat": InternVLChatConfig, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 846e50a10957..9060565596b2 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -3,8 +3,6 @@ from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.cohere2 import Cohere2Config from vllm.transformers_utils.configs.dbrx import DbrxConfig -from vllm.transformers_utils.configs.deepseek_mtp import DeepSeekMTPConfig -from 
vllm.transformers_utils.configs.deepseek_v3 import DeepseekV3Config from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config from vllm.transformers_utils.configs.eagle import EAGLEConfig from vllm.transformers_utils.configs.exaone import ExaoneConfig @@ -47,6 +45,4 @@ "SolarConfig", "Telechat2Config", "UltravoxConfig", - "DeepSeekMTPConfig", - "DeepseekV3Config", ] diff --git a/vllm/transformers_utils/configs/deepseek_mtp.py b/vllm/transformers_utils/configs/deepseek_mtp.py deleted file mode 100644 index ac2baa9c75cf..000000000000 --- a/vllm/transformers_utils/configs/deepseek_mtp.py +++ /dev/null @@ -1,47 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import os -from typing import Union - -from transformers import PretrainedConfig - -from vllm.transformers_utils.configs.deepseek_v3 import DeepseekV3Config - - -class DeepSeekMTPConfig(PretrainedConfig): - model_type = "deepseek_mtp" - - def __init__(self, - model: Union[PretrainedConfig, dict, None] = None, - **kwargs): - if model is not None: - self.model = DeepseekV3Config.from_dict(model, **kwargs) - else: - self.model = None - - if self.model is not None: - for k, v in kwargs.items(): - if k != "architectures" and k != "model_type" and hasattr( - self.model, k): - setattr(self.model, k, v) - - if "architectures" not in kwargs: - kwargs["architectures"] = ["DeepSeekMTPModel"] - - super().__init__(**kwargs) - - if self.model is not None: - for k, v in self.model.to_dict().items(): - if not hasattr(self, k): - setattr(self, k, v) - # for loading MTP kv cache - self.model.num_hidden_layers = self.model.num_nextn_predict_layers - - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: Union[str, os.PathLike], - **kwargs, - ) -> "DeepSeekMTPConfig": - config_dict, kwargs = cls.get_config_dict( - pretrained_model_name_or_path, **kwargs) - return cls.from_dict(config_dict, **kwargs) diff --git a/vllm/transformers_utils/configs/deepseek_v3.py 
b/vllm/transformers_utils/configs/deepseek_v3.py deleted file mode 100644 index 2037b5d52aa8..000000000000 --- a/vllm/transformers_utils/configs/deepseek_v3.py +++ /dev/null @@ -1,228 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -from transformers.configuration_utils import PretrainedConfig - - -class DeepseekV3Config(PretrainedConfig): - r""" - This is the configuration class to store the configuration - of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek - model according to the specified arguments, defining the model - architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the DeepSeek-V3. - Configuration objects inherit from [`PretrainedConfig`] and can be used to - control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - Args: - vocab_size (`int`, *optional*, defaults to 129280): - Vocabulary size of the Deep model. Defines the number of different - tokens that can be represented by the - `inputs_ids` passed when calling [`DeepseekV3Model`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - moe_intermediate_size (`int`, *optional*, defaults to 1407): - Dimension of the MoE representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer decoder. - num_nextn_predict_layers (`int`, *optional*, defaults to 1): - Number of nextn predict layers in the DeepSeekV3 Model. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the - Transformer decoder. - n_shared_experts (`int`, *optional*, defaults to None): - Number of shared experts, None means dense model. - n_routed_experts (`int`, *optional*, defaults to None): - Number of routed experts, None means dense model. 
- routed_scaling_factor (`float`, *optional*, defaults to 1.0): - Scaling factor or routed experts. - topk_method (`str`, *optional*, defaults to `gready`): - Topk method used in routed gate. - n_group (`int`, *optional*, defaults to None): - Number of groups for routed experts. - topk_group (`int`, *optional*, defaults to None): - Number of selected groups for each token(for each token, ensuring - the selected experts is only within `topk_group` groups). - num_experts_per_tok (`int`, *optional*, defaults to None): - Number of selected experts, None means dense model. - moe_layer_freq (`int`, *optional*, defaults to 1): - The frequency of the MoE layer: one expert layer for every - `moe_layer_freq - 1` dense layers. - first_k_dense_replace (`int`, *optional*, defaults to 0): - Number of dense layers in shallow layers - (embed->dense->dense->...->dense->moe->moe...->lm_head). - \--k dense layers--/ - norm_topk_prob (`bool`, *optional*, defaults to False): - Whether to normalize the weights of the routed experts. - scoring_func (`str`, *optional*, defaults to 'softmax'): - Method of computing expert weights. - aux_loss_alpha (`float`, *optional*, defaults to 0.001): - Auxiliary loss weight coefficient. - seq_aux = (`bool`, *optional*, defaults to True): - Whether to compute the auxiliary loss for each individual sample. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to - implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use - Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention - (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group - key and value head should be constructed - by meanpooling all the original heads within that group. For more - details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). 
If it is not - specified, will default to - `num_attention_heads`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the - decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used - with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for - initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values - attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*): - Padding token id. - bos_token_id (`int`, *optional*, defaults to 1): - Beginning of stream token id. - eos_token_id (`int`, *optional*, defaults to 2): - End of stream token id. - pretraining_tp (`int`, *optional*, defaults to 1): - Experimental feature. Tensor parallelism rank used during - pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/parallelism) to - understand more about it. This value is - necessary to ensure exact reproducibility of the pretraining - results. Please refer to [this - issue](https://github.com/pytorch/pytorch/issues/76232). - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE - embeddings. Currently supports two scaling - strategies: linear and dynamic. Their scaling factor must be a - float greater than 1. The expected format is - `{"type": strategy name, "factor": scaling factor}`. 
When using - this flag, don't update - `max_position_embeddings` to the expected new maximum. - attention_bias (`bool`, defaults to `False`, *optional*, defaults to - `False`): - Whether to use a bias in the query, key, value and output - projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - ```python - >>> from transformers import DeepseekV3Model, DeepseekV3Config - >>> # Initializing a Deepseek-V3 style configuration - >>> configuration = DeepseekV3Config() - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "deepseek_v3" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=129280, - hidden_size=7168, - intermediate_size=18432, - moe_intermediate_size=2048, - num_hidden_layers=61, - num_nextn_predict_layers=1, - num_attention_heads=128, - num_key_value_heads=128, - n_shared_experts=1, - n_routed_experts=256, - ep_size=1, - routed_scaling_factor=2.5, - kv_lora_rank=512, - q_lora_rank=1536, - qk_rope_head_dim=64, - v_head_dim=128, - qk_nope_head_dim=128, - topk_method='noaux_tc', - n_group=8, - topk_group=4, - num_experts_per_tok=8, - moe_layer_freq=1, - first_k_dense_replace=3, - norm_topk_prob=True, - scoring_func='sigmoid', - aux_loss_alpha=0.001, - seq_aux=True, - hidden_act="silu", - max_position_embeddings=4096, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=None, - bos_token_id=0, - eos_token_id=1, - pretraining_tp=1, - tie_word_embeddings=False, - rope_theta=10000.0, - rope_scaling=None, - attention_bias=False, - attention_dropout=0.0, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.moe_intermediate_size = moe_intermediate_size - self.num_hidden_layers = num_hidden_layers - 
self.num_nextn_predict_layers = num_nextn_predict_layers - self.num_attention_heads = num_attention_heads - self.n_shared_experts = n_shared_experts - self.n_routed_experts = n_routed_experts - self.ep_size = ep_size - self.routed_scaling_factor = routed_scaling_factor - self.kv_lora_rank = kv_lora_rank - self.q_lora_rank = q_lora_rank - self.qk_rope_head_dim = qk_rope_head_dim - self.v_head_dim = v_head_dim - self.qk_nope_head_dim = qk_nope_head_dim - self.topk_method = topk_method - self.n_group = n_group - self.topk_group = topk_group - self.num_experts_per_tok = num_experts_per_tok - self.moe_layer_freq = moe_layer_freq - self.first_k_dense_replace = first_k_dense_replace - self.norm_topk_prob = norm_topk_prob - self.scoring_func = scoring_func - self.aux_loss_alpha = aux_loss_alpha - self.seq_aux = seq_aux - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.pretraining_tp = pretraining_tp - self.use_cache = use_cache - self.rope_theta = rope_theta - self.rope_scaling = rope_scaling - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) From a87488e21b3cbe484c9524540a7d6ed40ffc3435 Mon Sep 17 00:00:00 2001 From: Lu Fang Date: Wed, 12 Feb 2025 10:38:35 -0800 Subject: [PATCH 6/6] support k > 1 --- vllm/engine/output_processor/multi_step.py | 2 + vllm/model_executor/models/deepseek_mtp.py | 64 ++++++++++++++++++- vllm/sequence.py | 34 +++++++--- vllm/spec_decode/draft_model_runner.py | 58 ++++++++++++----- vllm/spec_decode/multi_step_worker.py | 29 ++++++--- vllm/spec_decode/spec_decode_worker.py | 73 +++++++++++++--------- 
vllm/spec_decode/top1_proposer.py | 36 ++++++++--- 7 files changed, 226 insertions(+), 70 deletions(-) diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 8ceef855e020..07d9be6ac3f2 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -185,6 +185,8 @@ def _process_seq_outputs(self, seq: Sequence, is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0 # Incrementally append tokens to the sequence, as if we had only one new # token. + # TODO: add an attribute here for reset, can be set at output processor + seq.data.reset_new_appended_tokens() for output_token_id, output_logprob in zip(output_token_ids, output_logprobs): seq.append_token_id( diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 339210b478c1..98d362a374b1 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -18,6 +18,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from vllm.model_executor.layers.sampler import SamplerOutput + from .deepseek_v2 import (DeepseekV2DecoderLayer, get_spec_layer_idx_from_weight_name) from .utils import maybe_prefix @@ -75,6 +77,12 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, spec_step_index: int = 0, ) -> torch.Tensor: + + # print(f"draft {attn_metadata=}") + # print(f"draft {input_ids=}") + # print(f"draft {positions=}") + # print(f"draft {previous_hidden_states.shape=}") + # print(f"draft {kv_cache.shape=}") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) assert inputs_embeds is not None @@ -116,7 +124,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): }) self.logits_processor = LogitsProcessor(config.vocab_size) - def forward( self, input_ids: torch.Tensor, @@ -185,6 +192,49 @@ def compute_logits( return 
self.model.compute_logits(hidden_states, sampling_metadata, spec_step_idx) + def generate_proposals( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> List[SamplerOutput]: + hidden_states = previous_hidden_states + cur_input_ids = input_ids + outputs = [] + for i in range(self.model.num_mtp_layers): + hidden_states = self.forward(cur_input_ids, positions, kv_caches, + attn_metadata, hidden_states, + spec_step_idx=i) + logits = self.compute_logits( + hidden_states=hidden_states, + sampling_metadata=sampling_metadata, + spec_step_idx=i + ) + output = self.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + outputs.append(output) + cur_input_ids = self.get_next_layer_input( + input_ids, attn_metadata, output) + return outputs + + def get_next_layer_input( + self, + input_ids: torch.Tensor, + attn_metadata: AttentionMetadata, + outputs: SamplerOutput + ) -> Tuple[torch.Tensor, SamplerOutput]: + assert outputs.sampled_token_ids is not None + assert attn_metadata.query_start_loc is not None + input_ids = input_ids.roll(shifts=-1, dims=0) + query_end_loc = attn_metadata.query_start_loc[1:] - 1 + input_ids[query_end_loc] = outputs.sampled_token_ids[:, 0] + return input_ids + def sample( self, logits: torch.Tensor, @@ -193,6 +243,18 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens + def get_last_sample_output( + self, + output: SamplerOutput, + attn_metadata: AttentionMetadata, + ) -> SamplerOutput: + query_end_loc = attn_metadata.query_start_loc[1:] - 1 + output.sampled_token_ids = output.sampled_token_ids[query_end_loc] + if output.sampled_token_probs is not None: + output.sampled_token_probs = output.sampled_token_probs[query_end_loc] + return output + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: 
stacked_params_mapping = [ diff --git a/vllm/sequence.py b/vllm/sequence.py index b60ee3304f84..4ea53e94f75d 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -365,6 +365,9 @@ def get_delta_and_reset(self) -> SequenceDataDelta: self._new_appended_tokens = [] return delta + def reset_new_appended_tokens(self) -> None: + self._new_appended_tokens = [] + def apply_delta(self, delta: SequenceDataDelta): self._num_computed_tokens = delta.new_num_computed_tokens self._cumulative_logprob = delta.new_cumulative_logprob @@ -1209,12 +1212,14 @@ class HiddenStates(msgspec.Struct, array_like=True, # last proposed token is accepted (i.e., in case of bonus tokens). For the # case of no bonus tokens, these are ignored. second_last_token_hidden_states: Optional[torch.Tensor] = None - + # for varseq + hidden_states_seq_indices: Optional[torch.Tensor] = None _seq_ids: List[int] = msgspec.field(default_factory=list) def __post_init__(self): if self.seq_group_metadata_list is not None: - assert len(self.seq_group_metadata_list) == len(self.hidden_states) + # TODO: add assertion for the group metadata list with var seqs + # assert len(self.seq_group_metadata_list) == len(self.hidden_states) self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list) @property @@ -1228,8 +1233,18 @@ def update(self, """Update hidden states from target model invocation. 
Only used for decode steps""" assert len(seq_group_metadata_list) == len(hidden_states) - self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) + last_seq_indice = len(self._seq_ids) + new_seq_ids = get_all_seq_ids(seq_group_metadata_list) + self._seq_ids.extend(new_seq_ids) self.hidden_states = torch.cat([self.hidden_states, hidden_states]) + if self.hidden_states_seq_indices is not None: + updated_indices = list(range(last_seq_indice, len(self._seq_ids))) + # assume new updated are hidden states from prefill which is always length of 1 + new_seq_indices = torch.tensor(updated_indices, device=self.hidden_states_seq_indices.device) + self.hidden_states_seq_indices = torch.concat([ + self.hidden_states_seq_indices, + new_seq_indices, + ]) if self.second_last_token_hidden_states is not None: # Adding dummy hidden_states to this to maintain same shape @@ -1252,10 +1267,15 @@ def prune(self, if seq_ids != self._seq_ids: # Batch contents changed - prune removed sequences. index = [self._seq_ids.index(seq_id) for seq_id in seq_ids] - self.hidden_states = self.hidden_states[index] - if self.second_last_token_hidden_states is not None: - self.second_last_token_hidden_states = self\ - .second_last_token_hidden_states[index] + if self.hidden_states_seq_indices is not None: + target_indices_tensor = torch.tensor(index, device=self.hidden_states_seq_indices.device) + index = (self.hidden_states_seq_indices[..., None] == target_indices_tensor).any(dim=-1) + self.hidden_states = self.hidden_states[index] + else: + self.hidden_states = self.hidden_states[index] + if self.second_last_token_hidden_states is not None: + self.second_last_token_hidden_states = self\ + .second_last_token_hidden_states[index] self._seq_ids = seq_ids def expand_with_bonus_tokens( diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index a7eb4baef11d..590756344713 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ 
b/vllm/spec_decode/draft_model_runner.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional +from typing import List, Optional, Tuple import torch @@ -14,6 +14,11 @@ # vllm_flash_attn is not installed, try the ROCm FA metadata from vllm.attention.backends.rocm_flash_attn import ( ROCmFlashAttentionMetadata as FlashAttentionMetadata) + try: + from vllm.attention.backends.triton_mla import TritonMLAMetadata + except (ModuleNotFoundError, ImportError): + TritonMLAMetadata = FlashAttentionMetadata + except (ModuleNotFoundError, ImportError) as err: raise RuntimeError( "Draft model speculative decoding currently only supports" @@ -57,7 +62,7 @@ def __init__(self, model_runner: ModelRunnerBase): "return_hidden_states is not supported for TP1DraftModelRunner." ) super().__init__(model_runner) - + self.mtp = False self.indices_of_seq_with_bonus_tokens = None def _update_sampling_metadata(self, sampling_metadata, num_seqs, @@ -92,7 +97,8 @@ def _gpu_advance_step(self, model_input: ModelRunnerInputBase, # Update attn_metadata attn_metadata = model_input.attn_metadata - assert isinstance(attn_metadata, FlashAttentionMetadata) + assert isinstance(attn_metadata, FlashAttentionMetadata) \ + or isinstance(attn_metadata, TritonMLAMetadata) attn_metadata.advance_step(model_input, sampled_token_ids, self.block_size, num_seqs, num_queries) @@ -193,6 +199,7 @@ def execute_model( # iteration invokes this function only once # (Look at multi-step-worker code) is_fallback = num_steps == 1 + self.mtp = self.model.config.model_type == "deepseek_mtp" if not is_fallback: # Since we do not broadcast data inside execute_model anymore, # we need to figure out the best way to support TP > 1 in this @@ -269,6 +276,9 @@ def execute_model( hidden_states = previous_hidden_states outputs: List[SamplerOutput] = [] + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + attn_metadata = model_input.attn_metadata for step in 
range(num_steps): multi_modal_kwargs = model_input.multi_modal_kwargs or {} @@ -277,19 +287,33 @@ def execute_model( compute_logits_kwargs = {} # Run model + spec_step_idx = kwargs.get("spec_step_idx", 0) if hasattr(self.model.config, "num_nextn_predict_layers"): # for DeepSeek MTP only to use the corresponding layer for # each step - spec_step_idx = kwargs.get("spec_step_idx", 0) - model_execute_kwargs["spec_step_idx"] = spec_step_idx - compute_logits_kwargs["spec_step_idx"] = spec_step_idx - with set_forward_context(model_input.attn_metadata, + if spec_step_idx >= 0: + model_execute_kwargs["spec_step_idx"] = spec_step_idx + compute_logits_kwargs["spec_step_idx"] = spec_step_idx + else: + # for single step prefill + with set_forward_context(attn_metadata, + self.vllm_config): + return model_executable.generate_proposals( + input_ids=input_tokens, + positions=input_positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + sampling_metadata=model_input.sampling_metadata, + **model_execute_kwargs, + ) + # model_execute_kwargs["spec_step_idx"] = spec_step_idx + with set_forward_context(attn_metadata, self.vllm_config): hidden_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, + input_ids=input_tokens, + positions=input_positions, kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, + attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), @@ -297,18 +321,24 @@ def execute_model( ) # Compute the logits. - logits = self.model.compute_logits(hidden_states, + logits = self.model.compute_logits(hidden_states, # do not sample for the previous tokens model_input.sampling_metadata, **compute_logits_kwargs) - # Sample the next token. 
output = self.model.sample( logits=logits, sampling_metadata=model_input.sampling_metadata, ) + # TODO: do sampling/compute logits for the last token only + if self.mtp: + # return last token only for each step for MTP + output = self.model.get_last_sample_output(output, attn_metadata) + input_tokens = self.model.get_next_layer_input( + input_tokens, + attn_metadata, output) outputs.append(output) - if model_input.attn_metadata.num_prefills == 0 \ + if not self.mtp and model_input.attn_metadata.num_prefills == 0 \ and self.indices_of_seq_with_bonus_tokens is not None: assert output.sampled_token_ids is not None # output.sampled_token_ids should be of shape (num_seqs, 1) @@ -328,7 +358,7 @@ def execute_model( count += 1 # Prepare inputs for the next step - if step != num_steps - 1: + if step != num_steps - 1 and not self.mtp: model_input = self._gpu_advance_step(model_input, outputs[-1]) return outputs diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 55d3b20a284f..b5c24ef424e6 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -55,6 +55,10 @@ def set_should_modify_greedy_probs_inplace(self) -> None: self.model_runner.model.sampler.should_modify_greedy_probs_inplace = ( True) + @property + def has_mtp_runner(self) -> bool: + return getattr(self.model_runner, "mtp", False) + @torch.inference_mode() def sampler_output( self, @@ -72,11 +76,14 @@ def sampler_output( self._raise_if_unsupported(execute_model_req) # Expand the batch for sequences with a bonus token. # Perform a forward pass on the expanded batch and filter the - # response to retain only the original sequences' responses. - expanded_request, indices_of_seq_with_bonus_tokens =\ - self._expand_execute_model_request( - execute_model_req, seq_ids_with_bonus_token_in_last_step) - + # response to retain only the original sequences' responses. 
+ if self.has_mtp_runner: + expanded_request, indices_of_seq_with_bonus_tokens =\ + execute_model_req, [] + else: + expanded_request, indices_of_seq_with_bonus_tokens =\ + self._expand_execute_model_request( + execute_model_req, seq_ids_with_bonus_token_in_last_step) # Run model sample_len times. model_outputs: List[SamplerOutput] = [] if current_platform.is_cuda_alike() and isinstance( @@ -110,10 +117,14 @@ def sampler_output( expanded_request.spec_step_idx = 0 # move indices to device to avoid stream sync - indices_of_seq_with_bonus_tokens = torch.tensor( - indices_of_seq_with_bonus_tokens, device=self.device) - filtered_model_outputs = self._filter_model_output( - model_outputs, indices_of_seq_with_bonus_tokens) + if self.has_mtp_runner: + filtered_model_outputs = model_outputs + else: + indices_of_seq_with_bonus_tokens = torch.tensor( + indices_of_seq_with_bonus_tokens, device=self.device) + filtered_model_outputs = self._filter_model_output( + model_outputs, indices_of_seq_with_bonus_tokens) + return filtered_model_outputs, True @staticmethod diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index f7c9b33707f0..68b6905fbb55 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -157,7 +157,7 @@ def create_worker( ) -> "SpecDecodeWorker": allow_zero_draft_token_step = True - num_spec_prefill_steps = 1 + next_n_prediction_steps = -1 ngram_prompt_lookup_max = ( draft_worker_kwargs.pop("ngram_prompt_lookup_max")) ngram_prompt_lookup_min = ( @@ -192,9 +192,9 @@ def create_worker( "does not support TP > 1 yet") allow_zero_draft_token_step = False - proposer_worker = MultiStepWorker(**draft_worker_kwargs) if draft_model_config.hf_config.model_type == "deepseek_mtp": - num_spec_prefill_steps = num_speculative_tokens + next_n_prediction_steps = num_speculative_tokens + proposer_worker = MultiStepWorker(**draft_worker_kwargs) proposer_worker = 
SmallerTpProposerWorker.maybe_wrap_worker( proposer_worker, draft_tp, target_tp) @@ -247,7 +247,7 @@ def create_worker( disable_by_batch_size=disable_by_batch_size, spec_decode_sampler=spec_decode_sampler, allow_zero_draft_token_step=allow_zero_draft_token_step, - num_spec_prefill_steps=num_spec_prefill_steps, + next_n_prediction_steps=next_n_prediction_steps, ) def __init__( @@ -261,7 +261,7 @@ def __init__( metrics_collector: Optional[AsyncMetricsCollector] = None, disable_by_batch_size: Optional[int] = None, allow_zero_draft_token_step: Optional[bool] = True, - num_spec_prefill_steps: int = 1, + next_n_prediction_steps: int = -1, ): """ Create a SpecDecodeWorker. @@ -292,7 +292,7 @@ def __init__( allow_zero_draft_token_step: whether to allow a step where the draft model generates no draft token; should disallow when the tp of draft model is larger than 1 (TODO: #5814) - num_spec_prefill_steps: number of speculative prefill steps to run + next_n_prediction_steps: number of speculative prefill steps to run before the speculative decoding starts. This is only used when the draft model is a deepseek_mtp model that requires prefill kv cache separately for each step layer. @@ -328,7 +328,7 @@ def __init__( self.previous_hidden_states: Optional[HiddenStates] = None self._disable_logprobs = disable_logprobs self._disable_log_stats = disable_log_stats - self._num_spec_prefill_steps = num_spec_prefill_steps + self._next_n_prediction_steps = next_n_prediction_steps def init_device(self) -> None: """Initialize both scorer and proposer models. 
@@ -513,8 +513,10 @@ def execute_model( if no_spec: return self._run_no_spec(execute_model_req, skip_proposer=disable_all_speculation) - return self._run_speculative_decoding_step(execute_model_req, - num_lookahead_slots) + results = self._run_speculative_decoding_step( + execute_model_req, + num_lookahead_slots) + return results @torch.inference_mode() def start_worker_execution_loop(self) -> None: @@ -643,7 +645,6 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, not called, meaning that the kv-cache in proposer for requests is not updated, so they cannot enable spec decode in the rest decoding. """ - sampler_output = self.scorer_worker.execute_model(execute_model_req) assert len(sampler_output) == 1 sampler_output = sampler_output[0] @@ -664,7 +665,7 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, if self.previous_hidden_states is None and len( seq_group_meta_with_hidden): self.previous_hidden_states = HiddenStates( - hidden_states, seq_group_meta_with_hidden) + hidden_states, seq_group_meta_with_hidden) # hidden states for T, (T+1 token) elif self.previous_hidden_states and len( seq_group_meta_with_hidden): self.previous_hidden_states.update(hidden_states, @@ -674,14 +675,11 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, # We prepare the prefill hidden states here so that there no # additional complexity in worker for spec_decode vs non_spec_decode # flow and execute_model doesn't need additional modifications. 
+ execute_model_req.spec_step_idx = -1 execute_model_req.previous_hidden_states = \ prepare_prefill_hidden_states( sampler_output.prefill_hidden_states) - execute_model_req.spec_step_idx = 0 - for _ in range(self._num_spec_prefill_steps): - self.proposer_worker.execute_model(execute_model_req) - execute_model_req.spec_step_idx += 1 - + self.proposer_worker.execute_model(execute_model_req) sampler_output_to_return = (self._serialize_sampler_output_no_logprobs( execute_model_req=execute_model_req, sampler_output=sampler_output) if self._disable_logprobs else @@ -887,22 +885,35 @@ def _verify_tokens( # Contract hidden states based on accepted tokens hs_size = hidden_states.shape[-1] accepted_index = accepted_token_ids + 1 # Convert -1 to 0 - accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) # b - # Drop non-terminal prefill chunks hidden states. - hidden_states = hidden_states[accepted_index != - VLLM_INVALID_TOKEN_ID] - accepted_index = accepted_index[accepted_index != + accepted_index = accepted_index.count_nonzero(dim=1) + if self._next_n_prediction_steps > 0: + hidden_states = hidden_states.reshape(-1, hs_size)[ + accepted_token_ids.reshape(-1) != VLLM_INVALID_TOKEN_ID] + seq_indices = torch.repeat_interleave( + torch.arange(0, accepted_index.shape[0]).to(hidden_states.device), + accepted_index + ) # seq indices for each hidden state + self.previous_hidden_states = HiddenStates( + hidden_states, terminal_metadata, + hidden_states_seq_indices=seq_indices, + ) + else: + # Drop non-terminal prefill chunks hidden states. 
+ hidden_states = hidden_states[accepted_index != VLLM_INVALID_TOKEN_ID] - assert len(accepted_index) == hidden_states.shape[0] == len( - terminal_metadata) - index = accepted_index[:, None, None].expand(-1, 1, - hs_size) # b x 1 x d - second_last_token_hidden_states = hidden_states[:, -2] # b x d - hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d - # Store hidden states from target model for subsequent decode step - self.previous_hidden_states = HiddenStates( - hidden_states, terminal_metadata, - second_last_token_hidden_states) + accepted_index.add_(-1) + accepted_index = accepted_index[accepted_index != + VLLM_INVALID_TOKEN_ID] + assert len(accepted_index) == hidden_states.shape[0] == len( + terminal_metadata) + index = accepted_index[:, None, None].expand(-1, 1, + hs_size) # b x 1 x d + second_last_token_hidden_states = hidden_states[:, -2] # b x d + hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d + # Store hidden states from target model for subsequent decode step + self.previous_hidden_states = HiddenStates( + hidden_states, terminal_metadata, + second_last_token_hidden_states) return accepted_token_ids, logprobs def _create_output_sampler_list( diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index b538923c03e7..06bc3a23bd56 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -3,9 +3,8 @@ from typing import List, Optional, Set, Tuple import torch - from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata, SequenceData from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase @@ -52,15 +51,37 @@ def get_spec_proposals( speculation. 
""" proposal_len = execute_model_req.num_lookahead_slots - seq_group_metadata_list = execute_model_req.seq_group_metadata_list - + if hasattr(self._worker, "model_runner") and getattr(self._worker.model_runner,"mtp", False): + seq_group_metadata_list_for_proposal = [] + for metadata in execute_model_req.seq_group_metadata_list: + mtp_seq_data = {} + for key, seq_data in metadata.seq_data.items(): + mtp_seq_data[key] = SequenceData.from_seqs( + seq_data.prompt_token_ids, + output_token_ids=seq_data.output_token_ids, + ) + mtp_seq_data[key].update_num_computed_tokens( + len(seq_data.prompt_token_ids) + + len(seq_data.output_token_ids) - + len(seq_data._new_appended_tokens) + ) + new_metadata = SequenceGroupMetadata( + request_id=metadata.request_id, + is_prompt=False, + seq_data=mtp_seq_data, + sampling_params=metadata.sampling_params, + block_tables=metadata.block_tables, + lora_request=metadata.lora_request, + ) + seq_group_metadata_list_for_proposal.append(new_metadata) + else: + seq_group_metadata_list_for_proposal = execute_model_req.seq_group_metadata_list # Split speculative- and non-speculative- sequences. ( proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices, - ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len) - + ) = self._split_by_proposal_len(seq_group_metadata_list_for_proposal, proposal_len) if nonzero_proposal_len_seqs: # Speculate tokens using the draft worker for the speculative # sequences. @@ -98,7 +119,7 @@ def get_spec_proposals( # Combine speculative- and non-speculative sequences into the same # representation. 
proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs( - batch_size=len(seq_group_metadata_list), + batch_size=len(seq_group_metadata_list_for_proposal), proposal_len=proposal_len, maybe_sampler_output=maybe_sampler_output, proposal_lens=proposal_lens, @@ -246,7 +267,6 @@ def _merge_outputs( sampler_output = maybe_sampler_output proposal_tokens, proposal_probs, *_ = sampler_output_to_torch( sampler_output, sampler_transposed) - # Now, reformat the output GPU tensors such that each sequence has # a proposal. the proposal can be empty, e.g. [-1, -1, -1]