diff --git a/tests/models/registry.py b/tests/models/registry.py index 7b1db55494fe..c7eef89d96bc 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -285,6 +285,8 @@ def check_available_online( speculative_model="abhigoyal/vllm-medusa-llama-68m-random"), # noqa: E501 "MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m", speculative_model="ibm-ai-platform/llama-160m-accelerator"), # noqa: E501 + "DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random", + speculative_model="luccafong/deepseek_mtp_draft_random"), # noqa: E501 } _FALLBACK_MODEL = { diff --git a/tests/spec_decode/e2e/test_mtp_correctness.py b/tests/spec_decode/e2e/test_mtp_correctness.py new file mode 100644 index 000000000000..f0fca64fcba4 --- /dev/null +++ b/tests/spec_decode/e2e/test_mtp_correctness.py @@ -0,0 +1,304 @@ +# SPDX-License-Identifier: Apache-2.0 +"""This docstring details important information on the testing methodology. + +Most of the tests rely on "greedy equality", where we expect the output of +speculative decoding on a sequence to exactly match the output of normal non- +speculative decoding. + +Since speculative decoding with rejection sampling guarantees that the output +distribution matches the target model's output distribution (up to hardware +numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy +equality. + +However, we still need to verify that the scenarios below pass: + * Batch size 1 greedy equality + * Batch size >1 greedy equality + * Test greedy equality under preemption + * Test greedy equality under various numbers of speculative tokens. + +With those tests, we can at least say that MTP does not break the +correctness of the target model outputs. +""" + +import pytest + +from .conftest import run_equality_correctness_test + +# main model +MAIN_MODEL = "luccafong/deepseek_mtp_main_random" + + +# max. 
number of speculative tokens: this corresponds to +# num_nextn_predict_layers in the config.json of the speculator model. +MAX_SPEC_TOKENS = 1 + +# precision +PRECISION = "bfloat16" + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, + seed: int): + + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Print spec metrics. 
+ "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": False, + }, + { + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "enforce_eager": False, + + # Print spec metrics. 
+ "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + "gpu_memory_utilization": 0.85 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size: int, + output_len: int, seed: int): + """Verify greedy equality with cuda graph enabled and different + batch sizes.""" + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "block_size": 8, + # 2 for small prompt, 256//8 for generated. + "num_gpu_blocks_override": 2 + 256 // 8, + "max_model_len": (2 + 256 // 8) * 8, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "num_speculative_tokens": MAX_SPEC_TOKENS, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use small output len for fast test. 
+ 128, + ]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_e2e_greedy_correctness_with_preemption( + vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify greedy equality, even when some sequences are preempted mid- + generation. + """ + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "num_speculative_tokens": k, + } + # Try a range of num. speculative tokens + for k in range(1, 1 + MAX_SPEC_TOKENS) + ]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_different_k(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, + test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify that mtp speculative decoding produces exact equality + to without spec decode with different values of num_speculative_tokens. + """ + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. 
+ "enforce_eager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_disable_by_batch_size": 4 + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mtp_disable_queue(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, + test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify that mtp speculative decoding produces exact equality + to without spec decode when speculation is disabled for large + batch sizes. + """ + run_equality_correctness_test(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size, output_len, seed) + + +if __name__ == "__main__": + import pytest + pytest.main([__file__]) diff --git a/vllm/config.py b/vllm/config.py index 1740871e7c10..1d29736898f0 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -757,7 +757,7 @@ def get_hidden_size(self) -> int: def is_deepseek_mla(self) -> bool: return (hasattr(self.hf_text_config, "model_type")) \ and (self.hf_text_config.model_type in \ - ('deepseek_v2', 'deepseek_v3'))\ + ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'))\ and (self.hf_text_config.kv_lora_rank is not None) def get_head_size(self) -> int: @@ -850,8 +850,12 @@ def get_num_attention_heads(self, def get_layers_start_end_indices( self, parallel_config: "ParallelConfig") -> Tuple[int, int]: from vllm.distributed.utils import get_pp_indices - total_num_hidden_layers = getattr(self.hf_text_config, - "num_hidden_layers", 0) + if self.hf_text_config.model_type == "deepseek_mtp": + total_num_hidden_layers = 
getattr(self.hf_text_config, + "num_nextn_predict_layers", 0) + else: + total_num_hidden_layers = getattr(self.hf_text_config, + "num_hidden_layers", 0) pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size pp_size = parallel_config.pipeline_parallel_size start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size) @@ -1667,6 +1671,21 @@ def compute_hash(self) -> str: hash_str = hashlib.md5(str(factors).encode()).hexdigest() return hash_str + @staticmethod + def hf_config_override( + hf_config: PretrainedConfig + ) -> PretrainedConfig: + if hf_config.model_type == "deepseek_v3": + hf_config.model_type = "deepseek_mtp" + if hf_config.model_type == "deepseek_mtp": + n_predict = getattr( + hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["DeepSeekMTPModel"] + }) + return hf_config + @staticmethod def maybe_create_spec_config( target_model_config: ModelConfig, @@ -1749,12 +1768,16 @@ def maybe_create_spec_config( Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if the necessary conditions are met, else None. """ - if speculative_model is None: if num_speculative_tokens is not None: - raise ValueError("num_speculative_tokens was provided without " + if target_model_config.hf_text_config.model_type == "deepseek_v3": + # use the draft model from the same model: + speculative_model = target_model_config.model + else: + raise ValueError("num_speculative_tokens was provided without " "speculative_model.") - return None + else: + return None if (speculative_disable_by_batch_size is not None and speculative_disable_by_batch_size < 2): @@ -1808,6 +1831,7 @@ def maybe_create_spec_config( max_seq_len_to_capture=target_model_config. 
max_seq_len_to_capture, max_logprobs=target_model_config.max_logprobs, + hf_overrides=SpeculativeConfig.hf_config_override, ) draft_hf_config = draft_model_config.hf_config @@ -1815,7 +1839,6 @@ def maybe_create_spec_config( if (num_speculative_tokens is not None and hasattr(draft_hf_config, "num_lookahead_tokens")): draft_hf_config.num_lookahead_tokens = num_speculative_tokens - n_predict = getattr(draft_hf_config, "n_predict", None) if n_predict is not None: if num_speculative_tokens is None: @@ -1925,11 +1948,12 @@ def _verify_and_get_draft_model_tensor_parallel_size( # If speculative_draft_tensor_parallel_size is unset then set it # appropriately else verify that it is set correctly. if speculative_draft_tensor_parallel_size is None: - if draft_hf_config.model_type == "mlp_speculator": + if draft_hf_config.model_type in ("mlp_speculator", "deepseek_mtp"): speculative_draft_tensor_parallel_size = 1 if target_parallel_config.tensor_parallel_size > 1: logger.warning( - "MLPSpeculator cannot currently be run with tp>1; " + f"{draft_hf_config.model_type} cannot currently " + "be run with tp>1; " "setting speculative_draft_tensor_parallel_size=1") else: speculative_draft_tensor_parallel_size = \ diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 8ceef855e020..07d9be6ac3f2 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -185,6 +185,8 @@ def _process_seq_outputs(self, seq: Sequence, is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0 # Incrementally append tokens to the sequence, as if we had only one new # token. 
+ # TODO: add an attribute here for reset, can be set at output processor + seq.data.reset_new_appended_tokens() for output_token_id, output_logprob in zip(output_token_ids, output_logprobs): seq.append_token_id( diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py new file mode 100644 index 000000000000..98d362a374b1 --- /dev/null +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -0,0 +1,345 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Iterable, List, Optional, Set, Tuple + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from vllm.model_executor.layers.sampler import SamplerOutput + +from .deepseek_v2 import (DeepseekV2DecoderLayer, + get_spec_layer_idx_from_weight_name) +from .utils import maybe_prefix + + +class SharedHead(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + 
return self.norm(hidden_states) + + +class DeepSeekMultiTokenPredictorLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.eh_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, + bias=False) + self.shared_head = SharedHead(config=config, quant_config=quant_config) + self.block = DeepseekV2DecoderLayer(config, prefix, model_config, + cache_config, quant_config) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_index: int = 0, + ) -> torch.Tensor: + + # print(f"draft {attn_metadata=}") + # print(f"draft {input_ids=}") + # print(f"draft {positions=}") + # print(f"draft {previous_hidden_states.shape=}") + # print(f"draft {kv_cache.shape=}") + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds[positions == 0] = 0 + inputs_embeds = self.enorm(inputs_embeds) + previous_hidden_states = self.hnorm(previous_hidden_states) + + hidden_states = self.eh_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + hidden_states, _ = self.block(positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + residual=None) + return self.shared_head(hidden_states) + + +class DeepSeekMultiTokenPredictor(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = 
""): + super().__init__() + config = vllm_config.model_config.hf_config + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict({ + str(idx): + DeepSeekMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + }) + + self.logits_processor = LogitsProcessor(config.vocab_size) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]( + input_ids, + positions, + kv_caches[spec_step_idx], + attn_metadata, + previous_hidden_states, + inputs_embeds, + spec_step_idx, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> torch.Tensor: + mtp_layer = self.layers[str(self.mtp_start_layer_idx + spec_step_idx)] + logits = self.logits_processor(mtp_layer.shared_head.head, + hidden_states, sampling_metadata) + return logits + + +class DeepSeekMTP(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + + self.sampler = get_sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + intermediate_tensors: 
Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, previous_hidden_states, + inputs_embeds, spec_step_idx) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> Optional[torch.Tensor]: + return self.model.compute_logits(hidden_states, sampling_metadata, + spec_step_idx) + + def generate_proposals( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> List[SamplerOutput]: + hidden_states = previous_hidden_states + cur_input_ids = input_ids + outputs = [] + for i in range(self.model.num_mtp_layers): + hidden_states = self.forward(cur_input_ids, positions, kv_caches, + attn_metadata, hidden_states, + spec_step_idx=i) + logits = self.compute_logits( + hidden_states=hidden_states, + sampling_metadata=sampling_metadata, + spec_step_idx=i + ) + output = self.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + outputs.append(output) + cur_input_ids = self.get_next_layer_input( + input_ids, attn_metadata, output) + return outputs + + def get_next_layer_input( + self, + input_ids: torch.Tensor, + attn_metadata: AttentionMetadata, + outputs: SamplerOutput + ) -> Tuple[torch.Tensor, SamplerOutput]: + assert outputs.sampled_token_ids is not None + assert attn_metadata.query_start_loc is not None + input_ids = input_ids.roll(shifts=-1, dims=0) + query_end_loc = attn_metadata.query_start_loc[1:] - 1 + input_ids[query_end_loc] = outputs.sampled_token_ids[:, 0] + return input_ids + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, 
sampling_metadata) + return next_tokens + + def get_last_sample_output( + self, + output: SamplerOutput, + attn_metadata: AttentionMetadata, + ) -> SamplerOutput: + query_end_loc = attn_metadata.query_start_loc[1:] - 1 + output.sampled_token_ids = output.sampled_token_ids[query_end_loc] + if output.sampled_token_probs is not None: + output.sampled_token_probs = output.sampled_token_probs[query_end_loc] + return output + + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is None: + continue + name = self._rewrite_spec_layer_name(spec_layer, name) + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: + """ + Rewrite the weight name to match the format of the original model. 
+ Add .block for modules in transformer layer block for spec layer + """ + spec_layer_weight_names = [ + "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" + ] + spec_layer_weight = False + for weight_name in spec_layer_weight_names: + if weight_name in name: + spec_layer_weight = True + break + if not spec_layer_weight: + # treat rest weights as weights for transformer layer block + name = name.replace(f"model.layers.{spec_layer}.", + f"model.layers.{spec_layer}.block.") + return name diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index fd0e58fa1458..5e5ae11b2b45 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -640,8 +640,7 @@ def forward( "residual": residual }) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states + return hidden_states + residual class DeepseekV2ForCausalLM(nn.Module, SupportsPP): @@ -684,7 +683,7 @@ def compute_logits( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, + logits = self.logits_processor(self.lm_head, self.model.norm(hidden_states), sampling_metadata) return logits @@ -732,13 +731,9 @@ def load_weights(self, weights: Iterable[Tuple[str, if "rotary_emb.inv_freq" in name: continue - # TODO(simon): support nextn predict layers - if hasattr(self.config, "num_nextn_predict_layers" - ) and self.config.num_nextn_predict_layers > 0: - assert self.config.num_nextn_predict_layers == 1 - layer_idx = self.config.num_hidden_layers - if name.startswith(f"model.layers.{layer_idx}"): - continue + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model for (param_name, weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). 
@@ -805,3 +800,15 @@ def load_weights(self, weights: Iterable[Tuple[str, class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass + + +def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, + weight_name: str) -> Optional[int]: + if hasattr(config, + "num_nextn_predict_layers") and (config.num_nextn_predict_layers + > 0): + layer_idx = config.num_hidden_layers + for i in range(config.num_nextn_predict_layers): + if weight_name.startswith(f"model.layers.{layer_idx+i}."): + return layer_idx + i + return None diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 198b6d134718..74c0e3c05c18 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -187,6 +187,7 @@ _SPECULATIVE_DECODING_MODELS = { "EAGLEModel": ("eagle", "EAGLE"), + "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "MedusaModel": ("medusa", "Medusa"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), } diff --git a/vllm/sequence.py b/vllm/sequence.py index 534b9e60610a..4ea53e94f75d 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -365,6 +365,9 @@ def get_delta_and_reset(self) -> SequenceDataDelta: self._new_appended_tokens = [] return delta + def reset_new_appended_tokens(self) -> None: + self._new_appended_tokens = [] + def apply_delta(self, delta: SequenceDataDelta): self._num_computed_tokens = delta.new_num_computed_tokens self._cumulative_logprob = delta.new_cumulative_logprob @@ -1209,12 +1212,14 @@ class HiddenStates(msgspec.Struct, array_like=True, # last proposed token is accepted (i.e., in case of bonus tokens). For the # case of no bonus tokens, these are ignored. 
second_last_token_hidden_states: Optional[torch.Tensor] = None - + # for varseq + hidden_states_seq_indices: Optional[torch.Tensor] = None _seq_ids: List[int] = msgspec.field(default_factory=list) def __post_init__(self): if self.seq_group_metadata_list is not None: - assert len(self.seq_group_metadata_list) == len(self.hidden_states) + # TODO: add assertion for the group metadata list with var seqs + # assert len(self.seq_group_metadata_list) == len(self.hidden_states) self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list) @property @@ -1228,8 +1233,18 @@ def update(self, """Update hidden states from target model invocation. Only used for decode steps""" assert len(seq_group_metadata_list) == len(hidden_states) - self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) + last_seq_indice = len(self._seq_ids) + new_seq_ids = get_all_seq_ids(seq_group_metadata_list) + self._seq_ids.extend(new_seq_ids) self.hidden_states = torch.cat([self.hidden_states, hidden_states]) + if self.hidden_states_seq_indices is not None: + updated_indices = list(range(last_seq_indice, len(self._seq_ids))) + # assume new updated are hidden states from prefill which is always length of 1 + new_seq_indices = torch.tensor(updated_indices, device=self.hidden_states_seq_indices.device) + self.hidden_states_seq_indices = torch.concat([ + self.hidden_states_seq_indices, + new_seq_indices, + ]) if self.second_last_token_hidden_states is not None: # Adding dummy hidden_states to this to maintain same shape @@ -1252,10 +1267,15 @@ def prune(self, if seq_ids != self._seq_ids: # Batch contents changed - prune removed sequences. 
index = [self._seq_ids.index(seq_id) for seq_id in seq_ids] - self.hidden_states = self.hidden_states[index] - if self.second_last_token_hidden_states is not None: - self.second_last_token_hidden_states = self\ - .second_last_token_hidden_states[index] + if self.hidden_states_seq_indices is not None: + target_indices_tensor = torch.tensor(index, device=self.hidden_states_seq_indices.device) + index = (self.hidden_states_seq_indices[..., None] == target_indices_tensor).any(dim=-1) + self.hidden_states = self.hidden_states[index] + else: + self.hidden_states = self.hidden_states[index] + if self.second_last_token_hidden_states is not None: + self.second_last_token_hidden_states = self\ + .second_last_token_hidden_states[index] self._seq_ids = seq_ids def expand_with_bonus_tokens( @@ -1304,6 +1324,8 @@ class ExecuteModelRequest( previous_hidden_states: Optional[HiddenStates] = None # The number of forward steps to run. num_steps: int = 1 + # The step index for spec model input. + spec_step_idx: int = 0 # Finished request ids since last step. finished_requests_ids: List[str] = msgspec.field(default_factory=list) # The last sampled token ids for multi step decoding. 
diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 3948298db40c..590756344713 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional +from typing import List, Optional, Tuple import torch @@ -14,6 +14,11 @@ # vllm_flash_attn is not installed, try the ROCm FA metadata from vllm.attention.backends.rocm_flash_attn import ( ROCmFlashAttentionMetadata as FlashAttentionMetadata) + try: + from vllm.attention.backends.triton_mla import TritonMLAMetadata + except (ModuleNotFoundError, ImportError): + TritonMLAMetadata = FlashAttentionMetadata + except (ModuleNotFoundError, ImportError) as err: raise RuntimeError( "Draft model speculative decoding currently only supports" @@ -57,7 +62,7 @@ def __init__(self, model_runner: ModelRunnerBase): "return_hidden_states is not supported for TP1DraftModelRunner." ) super().__init__(model_runner) - + self.mtp = False self.indices_of_seq_with_bonus_tokens = None def _update_sampling_metadata(self, sampling_metadata, num_seqs, @@ -92,7 +97,8 @@ def _gpu_advance_step(self, model_input: ModelRunnerInputBase, # Update attn_metadata attn_metadata = model_input.attn_metadata - assert isinstance(attn_metadata, FlashAttentionMetadata) + assert isinstance(attn_metadata, FlashAttentionMetadata) \ + or isinstance(attn_metadata, TritonMLAMetadata) attn_metadata.advance_step(model_input, sampled_token_ids, self.block_size, num_seqs, num_queries) @@ -153,7 +159,7 @@ def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest): return False # TODO: Add support for other attn backends - if self.attn_backend.get_name() != "FLASH_ATTN": + if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"): return False # TODO: Add support for LORA @@ -175,6 +181,7 @@ def execute_model( previous_hidden_states: Optional[torch.Tensor] = None, intermediate_tensors: 
Optional[IntermediateTensors] = None, num_steps: int = 1, + **kwargs, ) -> Optional[List[SamplerOutput]]: """Executes num_steps forward passes with advacement of input tensors on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions. @@ -192,6 +199,7 @@ def execute_model( # iteration invokes this function only once # (Look at multi-step-worker code) is_fallback = num_steps == 1 + self.mtp = self.model.config.model_type == "deepseek_mtp" if not is_fallback: # Since we do not broadcast data inside execute_model anymore, # we need to figure out the best way to support TP > 1 in this @@ -268,38 +276,69 @@ def execute_model( hidden_states = previous_hidden_states outputs: List[SamplerOutput] = [] + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + attn_metadata = model_input.attn_metadata for step in range(num_steps): multi_modal_kwargs = model_input.multi_modal_kwargs or {} - kwargs = {"previous_hidden_states": hidden_states} \ + model_execute_kwargs = {"previous_hidden_states": hidden_states} \ if previous_hidden_states is not None else {} + compute_logits_kwargs = {} # Run model - with set_forward_context(model_input.attn_metadata, + spec_step_idx = kwargs.get("spec_step_idx", 0) + if hasattr(self.model.config, "num_nextn_predict_layers"): + # for DeepSeek MTP only to use the corresponding layer for + # each step + if spec_step_idx >= 0: + model_execute_kwargs["spec_step_idx"] = spec_step_idx + compute_logits_kwargs["spec_step_idx"] = spec_step_idx + else: + # for single step prefill + with set_forward_context(attn_metadata, + self.vllm_config): + return model_executable.generate_proposals( + input_ids=input_tokens, + positions=input_positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + sampling_metadata=model_input.sampling_metadata, + **model_execute_kwargs, + ) + # model_execute_kwargs["spec_step_idx"] = spec_step_idx + with set_forward_context(attn_metadata, self.vllm_config): hidden_states = 
model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, + input_ids=input_tokens, + positions=input_positions, kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, + attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), - **kwargs, + **model_execute_kwargs, ) # Compute the logits. - logits = self.model.compute_logits(hidden_states, - model_input.sampling_metadata) - + logits = self.model.compute_logits(hidden_states, # do not sample for the previous tokens + model_input.sampling_metadata, + **compute_logits_kwargs) # Sample the next token. output = self.model.sample( logits=logits, sampling_metadata=model_input.sampling_metadata, ) + # TODO: do sampling/compute logits for the last token only + if self.mtp: + # return last token only for each step for MTP + output = self.model.get_last_sample_output(output, attn_metadata) + input_tokens = self.model.get_next_layer_input( + input_tokens, + attn_metadata, output) outputs.append(output) - if model_input.attn_metadata.num_prefills == 0 \ + if not self.mtp and model_input.attn_metadata.num_prefills == 0 \ and self.indices_of_seq_with_bonus_tokens is not None: assert output.sampled_token_ids is not None # output.sampled_token_ids should be of shape (num_seqs, 1) @@ -319,7 +358,7 @@ def execute_model( count += 1 # Prepare inputs for the next step - if step != num_steps - 1: + if step != num_steps - 1 and not self.mtp: model_input = self._gpu_advance_step(model_input, outputs[-1]) return outputs diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 5474917a6fab..b5c24ef424e6 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -2,7 +2,7 @@ import copy import weakref -from typing import Dict, List, Set, Tuple +from typing import Dict, List, Set, Tuple, Optional import torch @@ -55,6 +55,10 @@ def 
set_should_modify_greedy_probs_inplace(self) -> None: self.model_runner.model.sampler.should_modify_greedy_probs_inplace = ( True) + @property + def has_mtp_runner(self) -> bool: + return getattr(self.model_runner, "mtp", False) + @torch.inference_mode() def sampler_output( self, @@ -72,11 +76,14 @@ def sampler_output( self._raise_if_unsupported(execute_model_req) # Expand the batch for sequences with a bonus token. # Perform a forward pass on the expanded batch and filter the - # response to retain only the original sequences' responses. - expanded_request, indices_of_seq_with_bonus_tokens =\ - self._expand_execute_model_request( - execute_model_req, seq_ids_with_bonus_token_in_last_step) - + # response to retain only the original sequences' responses. + if self.has_mtp_runner: + expanded_request, indices_of_seq_with_bonus_tokens =\ + execute_model_req, [] + else: + expanded_request, indices_of_seq_with_bonus_tokens =\ + self._expand_execute_model_request( + execute_model_req, seq_ids_with_bonus_token_in_last_step) # Run model sample_len times. model_outputs: List[SamplerOutput] = [] if current_platform.is_cuda_alike() and isinstance( @@ -95,9 +102,10 @@ def sampler_output( # TODO: Remove this branch once DraftModelRunner supports TP>1 # and other restrictions that are part of DraftModelRunner's # supports_gpu_multi_step(..) 
- for _ in range(sample_len): + for i in range(sample_len): model_output: List[SamplerOutput] = self.worker.execute_model( execute_model_req=expanded_request) + expanded_request.spec_step_idx += 1 assert (len(model_output) == 1 ), "composing multistep workers not supported" model_output = model_output[0] @@ -106,12 +114,17 @@ def sampler_output( model_output, expanded_request.seq_group_metadata_list, indices_of_seq_with_bonus_tokens) model_outputs.append(model_output) + expanded_request.spec_step_idx = 0 # move indices to device to avoid stream sync - indices_of_seq_with_bonus_tokens = torch.tensor( - indices_of_seq_with_bonus_tokens, device=self.device) - filtered_model_outputs = self._filter_model_output( - model_outputs, indices_of_seq_with_bonus_tokens) + if self.has_mtp_runner: + filtered_model_outputs = model_outputs + else: + indices_of_seq_with_bonus_tokens = torch.tensor( + indices_of_seq_with_bonus_tokens, device=self.device) + filtered_model_outputs = self._filter_model_output( + model_outputs, indices_of_seq_with_bonus_tokens) + return filtered_model_outputs, True @staticmethod diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 8653bece8b5a..68b6905fbb55 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -107,6 +107,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": typical_acceptance_sampler_posterior_alpha, disable_logprobs=speculative_config.disable_logprobs, disable_log_stats=speculative_config.disable_log_stats, + num_speculative_tokens=speculative_config.num_speculative_tokens, ) return spec_decode_worker @@ -152,9 +153,11 @@ def create_worker( typical_acceptance_sampler_posterior_alpha: float, disable_logprobs: bool, disable_log_stats: bool, + num_speculative_tokens: int, ) -> "SpecDecodeWorker": allow_zero_draft_token_step = True + next_n_prediction_steps = -1 ngram_prompt_lookup_max = ( 
draft_worker_kwargs.pop("ngram_prompt_lookup_max")) ngram_prompt_lookup_min = ( @@ -182,11 +185,15 @@ def create_worker( draft_worker_kwargs[ "model_runner_cls"] = TP1DraftModelRunner else: - if draft_model_config.hf_config.model_type == "eagle": + if draft_model_config.hf_config.model_type in ( + "eagle", "deepseek_mtp"): raise NotImplementedError( - "EAGLE does not support TP > 1 yet") + f"{draft_model_config.hf_config.model_type} " + "does not support TP > 1 yet") allow_zero_draft_token_step = False + if draft_model_config.hf_config.model_type == "deepseek_mtp": + next_n_prediction_steps = num_speculative_tokens proposer_worker = MultiStepWorker(**draft_worker_kwargs) proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker( @@ -239,7 +246,9 @@ def create_worker( disable_log_stats=disable_log_stats, disable_by_batch_size=disable_by_batch_size, spec_decode_sampler=spec_decode_sampler, - allow_zero_draft_token_step=allow_zero_draft_token_step) + allow_zero_draft_token_step=allow_zero_draft_token_step, + next_n_prediction_steps=next_n_prediction_steps, + ) def __init__( self, @@ -252,6 +261,7 @@ def __init__( metrics_collector: Optional[AsyncMetricsCollector] = None, disable_by_batch_size: Optional[int] = None, allow_zero_draft_token_step: Optional[bool] = True, + next_n_prediction_steps: int = -1, ): """ Create a SpecDecodeWorker. @@ -282,6 +292,10 @@ def __init__( allow_zero_draft_token_step: whether to allow a step where the draft model generates no draft token; should disallow when the tp of draft model is larger than 1 (TODO: #5814) + next_n_prediction_steps: number of speculative prefill steps to run + before the speculative decoding starts. This is only used when + the draft model is a deepseek_mtp model that requires prefill + kv cache separately for each step layer. 
""" self.proposer_worker = proposer_worker self.scorer_worker = scorer_worker @@ -314,6 +328,7 @@ def __init__( self.previous_hidden_states: Optional[HiddenStates] = None self._disable_logprobs = disable_logprobs self._disable_log_stats = disable_log_stats + self._next_n_prediction_steps = next_n_prediction_steps def init_device(self) -> None: """Initialize both scorer and proposer models. @@ -498,8 +513,10 @@ def execute_model( if no_spec: return self._run_no_spec(execute_model_req, skip_proposer=disable_all_speculation) - return self._run_speculative_decoding_step(execute_model_req, - num_lookahead_slots) + results = self._run_speculative_decoding_step( + execute_model_req, + num_lookahead_slots) + return results @torch.inference_mode() def start_worker_execution_loop(self) -> None: @@ -628,7 +645,6 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, not called, meaning that the kv-cache in proposer for requests is not updated, so they cannot enable spec decode in the rest decoding. """ - sampler_output = self.scorer_worker.execute_model(execute_model_req) assert len(sampler_output) == 1 sampler_output = sampler_output[0] @@ -649,7 +665,7 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, if self.previous_hidden_states is None and len( seq_group_meta_with_hidden): self.previous_hidden_states = HiddenStates( - hidden_states, seq_group_meta_with_hidden) + hidden_states, seq_group_meta_with_hidden) # hidden states for T, (T+1 token) elif self.previous_hidden_states and len( seq_group_meta_with_hidden): self.previous_hidden_states.update(hidden_states, @@ -659,12 +675,11 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, # We prepare the prefill hidden states here so that there no # additional complexity in worker for spec_decode vs non_spec_decode # flow and execute_model doesn't need additional modifications. 
+ execute_model_req.spec_step_idx = -1 execute_model_req.previous_hidden_states = \ prepare_prefill_hidden_states( sampler_output.prefill_hidden_states) - self.proposer_worker.execute_model(execute_model_req) - sampler_output_to_return = (self._serialize_sampler_output_no_logprobs( execute_model_req=execute_model_req, sampler_output=sampler_output) if self._disable_logprobs else @@ -870,22 +885,35 @@ def _verify_tokens( # Contract hidden states based on accepted tokens hs_size = hidden_states.shape[-1] accepted_index = accepted_token_ids + 1 # Convert -1 to 0 - accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) # b - # Drop non-terminal prefill chunks hidden states. - hidden_states = hidden_states[accepted_index != - VLLM_INVALID_TOKEN_ID] - accepted_index = accepted_index[accepted_index != + accepted_index = accepted_index.count_nonzero(dim=1) + if self._next_n_prediction_steps > 0: + hidden_states = hidden_states.reshape(-1, hs_size)[ + accepted_token_ids.reshape(-1) != VLLM_INVALID_TOKEN_ID] + seq_indices = torch.repeat_interleave( + torch.arange(0, accepted_index.shape[0]).to(hidden_states.device), + accepted_index + ) # seq indices for each hidden state + self.previous_hidden_states = HiddenStates( + hidden_states, terminal_metadata, + hidden_states_seq_indices=seq_indices, + ) + else: + # Drop non-terminal prefill chunks hidden states. 
+ hidden_states = hidden_states[accepted_index != VLLM_INVALID_TOKEN_ID] - assert len(accepted_index) == hidden_states.shape[0] == len( - terminal_metadata) - index = accepted_index[:, None, None].expand(-1, 1, - hs_size) # b x 1 x d - second_last_token_hidden_states = hidden_states[:, -2] # b x d - hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d - # Store hidden states from target model for subsequent decode step - self.previous_hidden_states = HiddenStates( - hidden_states, terminal_metadata, - second_last_token_hidden_states) + accepted_index.add_(-1) + accepted_index = accepted_index[accepted_index != + VLLM_INVALID_TOKEN_ID] + assert len(accepted_index) == hidden_states.shape[0] == len( + terminal_metadata) + index = accepted_index[:, None, None].expand(-1, 1, + hs_size) # b x 1 x d + second_last_token_hidden_states = hidden_states[:, -2] # b x d + hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d + # Store hidden states from target model for subsequent decode step + self.previous_hidden_states = HiddenStates( + hidden_states, terminal_metadata, + second_last_token_hidden_states) return accepted_token_ids, logprobs def _create_output_sampler_list( diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index b538923c03e7..06bc3a23bd56 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -3,9 +3,8 @@ from typing import List, Optional, Set, Tuple import torch - from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata, SequenceData from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase @@ -52,15 +51,37 @@ def get_spec_proposals( speculation. 
""" proposal_len = execute_model_req.num_lookahead_slots - seq_group_metadata_list = execute_model_req.seq_group_metadata_list - + if hasattr(self._worker, "model_runner") and getattr(self._worker.model_runner,"mtp", False): + seq_group_metadata_list_for_proposal = [] + for metadata in execute_model_req.seq_group_metadata_list: + mtp_seq_data = {} + for key, seq_data in metadata.seq_data.items(): + mtp_seq_data[key] = SequenceData.from_seqs( + seq_data.prompt_token_ids, + output_token_ids=seq_data.output_token_ids, + ) + mtp_seq_data[key].update_num_computed_tokens( + len(seq_data.prompt_token_ids) + + len(seq_data.output_token_ids) - + len(seq_data._new_appended_tokens) + ) + new_metadata = SequenceGroupMetadata( + request_id=metadata.request_id, + is_prompt=False, + seq_data=mtp_seq_data, + sampling_params=metadata.sampling_params, + block_tables=metadata.block_tables, + lora_request=metadata.lora_request, + ) + seq_group_metadata_list_for_proposal.append(new_metadata) + else: + seq_group_metadata_list_for_proposal = execute_model_req.seq_group_metadata_list # Split speculative- and non-speculative- sequences. ( proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices, - ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len) - + ) = self._split_by_proposal_len(seq_group_metadata_list_for_proposal, proposal_len) if nonzero_proposal_len_seqs: # Speculate tokens using the draft worker for the speculative # sequences. @@ -98,7 +119,7 @@ def get_spec_proposals( # Combine speculative- and non-speculative sequences into the same # representation. 
proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs( - batch_size=len(seq_group_metadata_list), + batch_size=len(seq_group_metadata_list_for_proposal), proposal_len=proposal_len, maybe_sampler_output=maybe_sampler_output, proposal_lens=proposal_lens, @@ -246,7 +267,6 @@ def _merge_outputs( sampler_output = maybe_sampler_output proposal_tokens, proposal_probs, *_ = sampler_output_to_torch( sampler_output, sampler_transposed) - # Now, reformat the output GPU tensors such that each sequence has # a proposal. the proposal can be empty, e.g. [-1, -1, -1] diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 4b76509e4541..eda5a8143d12 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -29,9 +29,9 @@ # yapf conflicts with isort for this block # yapf: disable from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config, - DbrxConfig, DeepseekVLV2Config, - EAGLEConfig, ExaoneConfig, - H2OVLChatConfig, + DbrxConfig, + DeepseekVLV2Config, EAGLEConfig, + ExaoneConfig, H2OVLChatConfig, InternVLChatConfig, JAISConfig, MedusaConfig, MllamaConfig, MLPSpeculatorConfig, MPTConfig, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c7814f17375b..eb4234432a82 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1649,6 +1649,7 @@ def execute_model( kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + **kwargs, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError("num_steps > 1 is not supported in ModelRunner") diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 582aa460eb4f..ff38e3bfc207 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -68,10 +68,10 @@ def __init__( speculative_config = self.speculative_config model_config = self.model_config speculative_args = {} if speculative_config is None \ - 
or (speculative_config.draft_model_config.model == - model_config.model) \ + or (speculative_config.draft_model_config.hf_config.model_type == + model_config.hf_config.model_type) \ or (speculative_config.draft_model_config.hf_config.model_type - not in ["medusa", "mlp_speculator", "eagle"]) \ + not in ("medusa", "mlp_speculator", "eagle", "deepseek_mtp")) \ else {"return_hidden_states": True} ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 819b81fbfdbb..89b4652f201f 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -392,6 +392,8 @@ def execute_model( model_input, worker_input, kwargs = inputs num_steps = worker_input.num_steps + if execute_model_req is not None: + kwargs["spec_step_idx"] = execute_model_req.spec_step_idx self.execute_worker(worker_input)