Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions aphrodite/config/speculative.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,10 @@ def max_num_new_slots_for_drafting(self) -> int:
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "mtp", "dflash")

def requires_eagle_cache_drop(self) -> bool:
"""Whether prefix cache hits must drop one block for hidden states."""
return self.use_eagle() and not self.use_dflash()

def use_dflash(self) -> bool:
return self.method == "dflash"

Expand Down
99 changes: 95 additions & 4 deletions aphrodite/model_executor/models/qwen3_dflash.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@
from aphrodite.multimodal.inputs import NestedTensors
from aphrodite.transformers_utils.config import set_default_rope_theta
from aphrodite.v1.attention.backend import AttentionType
from aphrodite.v1.attention.selector import get_attn_backend
from aphrodite.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheSpec,
SlidingWindowSpec,
)

from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen3 import Qwen3ForCausalLM
Expand All @@ -47,6 +53,53 @@
logger = init_logger(__name__)


_DFLASH_VALID_LAYER_TYPES = frozenset({"full_attention", "sliding_attention"})


def _get_dflash_layer_types(config: Qwen3Config) -> tuple[str, ...]:
layer_types = getattr(config, "layer_types", None)
if layer_types is None:
return ("full_attention",) * config.num_hidden_layers
if len(layer_types) != config.num_hidden_layers:
raise ValueError(
f"DFlash layer_types length {len(layer_types)} does not match "
f"num_hidden_layers {config.num_hidden_layers}."
)
invalid = set(layer_types) - _DFLASH_VALID_LAYER_TYPES
if invalid:
raise ValueError(f"Invalid DFlash layer_type(s): {sorted(invalid)}.")
if "sliding_attention" in layer_types and not getattr(
config, "sliding_window", None
):
raise ValueError(
"DFlash sliding_attention layers require `sliding_window` in config."
)
return tuple(layer_types)


class DFlashAttention(Attention):
"""Attention with DFlash-specific KV allocation semantics.

The compute path keeps the layer's configured sliding window. The KV cache
spec is widened to full attention because DFlash writes every context KV
before drafting and cannot evict old context blocks from draft layers.
"""

def get_kv_cache_spec(self, aphrodite_config: AphroditeConfig) -> KVCacheSpec | None:
spec = super().get_kv_cache_spec(aphrodite_config)
if isinstance(spec, SlidingWindowSpec):
return FullAttentionSpec(
block_size=spec.block_size,
num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size,
head_size_v=getattr(spec, "head_size_v", spec.head_size),
dtype=spec.dtype,
kv_quant_mode=spec.kv_quant_mode,
page_size_padded=spec.page_size_padded,
)
return spec


class DFlashQwen3Attention(nn.Module):
"""Attention for DFlash speculative decoding.

Expand All @@ -66,6 +119,7 @@ def __init__(
attention_bias: bool = False,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
sliding_window: int | None = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
) -> None:
Expand Down Expand Up @@ -109,15 +163,24 @@ def __init__(
max_position=max_position,
rope_parameters=rope_parameters,
)
self.attn = Attention(
draft_attn_backend = get_attn_backend(
self.head_dim,
torch.get_default_dtype(),
cache_config.cache_dtype if cache_config is not None else "auto",
use_mm_prefix=False,
attn_type=attn_type,
)
self.attn = DFlashAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
per_layer_sliding_window=sliding_window,
prefix=f"{prefix}.attn",
attn_type=attn_type,
attn_backend=draft_attn_backend,
)
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
Expand Down Expand Up @@ -154,12 +217,17 @@ def __init__(
config: Qwen3Config,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
layer_type: str = "full_attention",
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.layer_type = layer_type
set_default_rope_theta(config, default_theta=1000000)
attn_type = AttentionType.DECODER
sliding_window = (
config.sliding_window if layer_type == "sliding_attention" else None
)

self.self_attn = DFlashQwen3Attention(
hidden_size=self.hidden_size,
Expand All @@ -171,6 +239,7 @@ def __init__(
head_dim=getattr(config, "head_dim", None),
cache_config=cache_config,
quant_config=quant_config,
sliding_window=sliding_window,
rope_parameters=config.rope_parameters,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
Expand Down Expand Up @@ -236,17 +305,30 @@ def __init__(
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)

target_config = aphrodite_config.model_config.hf_text_config
self.embed_normalizer: float | None = None
if str(getattr(target_config, "model_type", "")).startswith("gemma4"):
# Gemma4 scales token embeddings by sqrt(hidden_size). DFlash
# shares the target embeddings, so the draft path must match.
self.embed_normalizer = target_config.hidden_size**0.5

self.layer_types = _get_dflash_layer_types(self.config)
self.layers = nn.ModuleList(
[
DFlashQwen3DecoderLayer(
current_aphrodite_config,
prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
config=self.config,
layer_type=self.layer_types[layer_idx],
)
for layer_idx in range(self.config.num_hidden_layers)
]
)
self.sliding_attention_layer_names = {
layer.self_attn.attn.layer_name
for layer in self.layers
if layer.layer_type == "sliding_attention"
}
if self.use_aux_hidden_state:
num_features_to_use = self.config.num_hidden_layers
if "target_layer_ids" in drafter_config:
Expand Down Expand Up @@ -276,7 +358,8 @@ def __init__(
)

def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
embeds = self.embed_tokens(input_ids)
return embeds * self.embed_normalizer if self.embed_normalizer else embeds

def _build_fused_kv_buffers(self) -> None:
"""Build fused weight buffers for precompute_and_store_context_kv.
Expand Down Expand Up @@ -504,7 +587,11 @@ def __init__(self, *, aphrodite_config: AphroditeConfig, prefix: str = ""):
self.config.hidden_size,
prefix=maybe_prefix(prefix, "lm_head"),
)
self.logits_processor = LogitsProcessor(self.config.draft_vocab_size, scale=logit_scale)
self.logits_processor = LogitsProcessor(
self.config.draft_vocab_size,
scale=logit_scale,
soft_cap=getattr(self.config, "final_logit_softcapping", None),
)
target_vocab_size = aphrodite_config.model_config.get_vocab_size()
if self.config.draft_vocab_size != target_vocab_size:
self.draft_id_to_target_id = nn.Parameter(
Expand Down Expand Up @@ -556,6 +643,10 @@ def precompute_and_store_context_kv(
"""Precompute projected + RoPE'd K/V and write to cache."""
self.model.precompute_and_store_context_kv(context_states, context_positions, context_slot_mapping)

@property
def sliding_attention_layer_names(self) -> set[str]:
return self.model.sliding_attention_layer_names

def combine_hidden_states(
self,
hidden_states: torch.Tensor,
Expand Down
9 changes: 9 additions & 0 deletions aphrodite/transformers_utils/configs/speculators/algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,16 @@ def update_dflash(config_dict: dict, pre_trained_config: dict) -> None:
pre_trained_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
if config_dict.get("target_hidden_size") is not None:
pre_trained_config["target_hidden_size"] = config_dict["target_hidden_size"]
for key in (
"layer_types",
"use_sliding_window",
"sliding_window",
"max_window_layers",
):
if key in config_dict:
pre_trained_config[key] = config_dict[key]

# TODO: does this need to be shifted by 1 like in gpu_model_runner?
aux_layer_ids = config_dict["aux_hidden_state_layer_ids"]
pre_trained_config["eagle_aux_hidden_state_layer_ids"] = aux_layer_ids
Comment on lines +72 to 74
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The TODO here should be addressed to ensure consistency with the changes in gpu_model_runner.py. Since gpu_model_runner.py now explicitly shifts DFlash target_layer_ids by 1 when converting them to Eagle-style auxiliary layer IDs, this function should perform the same transformation. Failing to do so will result in incorrect layer indices being used for hidden states when the config is updated via this path.

Suggested change
# TODO: does this need to be shifted by 1 like in gpu_model_runner?
aux_layer_ids = config_dict["aux_hidden_state_layer_ids"]
pre_trained_config["eagle_aux_hidden_state_layer_ids"] = aux_layer_ids
# Shift by 1 to convert DFlash's aux layer id semantics to match Eagle
aux_layer_ids = [i + 1 for i in config_dict["aux_hidden_state_layer_ids"]]
pre_trained_config["eagle_aux_hidden_state_layer_ids"] = aux_layer_ids


Expand Down
7 changes: 5 additions & 2 deletions aphrodite/v1/attention/backends/triton_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,11 @@ def __init__(

model_config = aphrodite_config.model_config
self.num_heads_q = model_config.get_num_attention_heads(aphrodite_config.parallel_config)
self.num_heads_kv = model_config.get_num_kv_heads(aphrodite_config.parallel_config)
self.headdim = model_config.get_head_size()
# Some models (e.g. Gemma4) use different KV/head geometry for
# different attention layer groups, so size decode metadata from the
# actual KV cache spec instead of the model-wide defaults.
self.num_heads_kv = kv_cache_spec.num_kv_heads
self.headdim = kv_cache_spec.head_size

# Check if CUDA Graphs are enabled for decode
self.decode_cudagraph_enabled = self.aphrodite_config.compilation_config.cudagraph_mode in (
Expand Down
89 changes: 81 additions & 8 deletions aphrodite/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import hashlib
import math
import os
import re
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass, replace
Expand Down Expand Up @@ -78,6 +79,8 @@ def maybe_convert_block_hash(hash_bytes: BlockHash) -> ExternalBlockHash:

logger = init_logger(__name__)

_LAYER_INDEX_RE = re.compile(r"(?:^|\.)layers\.(\d+)(?:\.|$)")

# The hash seed for the first block of any prefix block sequence.
#
# We use a random value to avoid hash collisions or PYTHONHASHSEED environment
Expand Down Expand Up @@ -846,7 +849,10 @@ def may_override_num_blocks(aphrodite_config: AphroditeConfig, num_blocks: int)
return num_blocks


def _pool_bytes_per_block(kv_cache_groups: list[KVCacheGroupSpec]) -> int:
def _pool_bytes_per_block(
kv_cache_groups: list[KVCacheGroupSpec],
aphrodite_config: AphroditeConfig | None = None,
) -> int:
"""
Bytes consumed by one block in the worker's shared KV cache pool, mirroring
the divisor used by `get_kv_cache_config_from_groups` to convert
Expand All @@ -863,7 +869,22 @@ def _pool_bytes_per_block(kv_cache_groups: list[KVCacheGroupSpec]) -> int:
cast(UniformTypeKVCacheSpecs, g.kv_cache_spec).get_num_layer_tuples() for g in kv_cache_groups
)
return layer_tuple_page_bytes * num_layer_tuples
group_size = max(len(g.layer_names) for g in kv_cache_groups)
if aphrodite_config is not None:
isolated_group_ids = _get_dflash_isolated_group_ids(
aphrodite_config, kv_cache_groups
)
shared_group_size = max(
(
len(group.layer_names)
for group_id, group in enumerate(kv_cache_groups)
if group_id not in isolated_group_ids
),
default=0,
)
isolated_layers = sum(len(kv_cache_groups[group_id].layer_names) for group_id in isolated_group_ids)
group_size = shared_group_size + isolated_layers
else:
group_size = max(len(g.layer_names) for g in kv_cache_groups)
page_size = get_uniform_page_size([g.kv_cache_spec for g in kv_cache_groups])
return page_size * group_size

Expand Down Expand Up @@ -897,6 +918,35 @@ def get_uniform_page_size(kv_cache_specs: Iterable[KVCacheSpec]) -> int:
return page_sizes.pop()


def _get_dflash_isolated_group_ids(
aphrodite_config: AphroditeConfig,
kv_cache_groups: list[KVCacheGroupSpec],
) -> set[int]:
spec_config = aphrodite_config.speculative_config
if spec_config is None or spec_config.method != "dflash":
return set()

try:
target_num_layers = aphrodite_config.model_config.get_num_layers(
aphrodite_config.parallel_config
)
except Exception:
return set()

group_ids: set[int] = set()
for group_id, group in enumerate(kv_cache_groups):
layer_indices: list[int] = []
for layer_name in group.layer_names:
match = _LAYER_INDEX_RE.search(layer_name)
if match is None:
layer_indices = []
break
layer_indices.append(int(match.group(1)))
if layer_indices and all(idx >= target_num_layers for idx in layer_indices):
group_ids.add(group_id)
return group_ids


def _get_kv_cache_groups_uniform_spec(
kv_cache_specs: dict[str, KVCacheSpec],
) -> list[KVCacheGroupSpec]:
Expand Down Expand Up @@ -1222,18 +1272,41 @@ def get_kv_cache_config_from_groups(
# (sw.1, padding) will be: (group_size = 2)
# full.0, sw.0, sw.1: share a Tensor with size=available_memory//2
# full.1, sw.2: share another Tensor with size=available_memory//2
group_size = max(len(group.layer_names) for group in kv_cache_groups)
# DFlash writes draft context KVs directly into cache using the draft
# block table. Do not row-share those tensors with target KV groups, or
# overlapping physical block ids can overwrite target KVs under batching.
isolated_group_ids = _get_dflash_isolated_group_ids(
aphrodite_config, kv_cache_groups
)
shared_groups = [
group
for group_id, group in enumerate(kv_cache_groups)
if group_id not in isolated_group_ids
]
isolated_layer_names = [
layer_name
for group_id in sorted(isolated_group_ids)
for layer_name in kv_cache_groups[group_id].layer_names
]
shared_group_size = (
max(len(group.layer_names) for group in shared_groups)
if shared_groups
else 0
)
group_size = shared_group_size + len(isolated_layer_names)

page_size = get_uniform_page_size([group.kv_cache_spec for group in kv_cache_groups])
assert group_size > 0, "group_size must be greater than 0"
num_blocks = get_num_blocks(aphrodite_config, group_size, available_memory, page_size)
kv_cache_tensors = []
for i in range(group_size):
for i in range(shared_group_size):
shared_by = []
for j in range(len(kv_cache_groups)):
if i < len(kv_cache_groups[j].layer_names):
shared_by.append(kv_cache_groups[j].layer_names[i])
for group in shared_groups:
if i < len(group.layer_names):
shared_by.append(group.layer_names[i])
kv_cache_tensors.append(KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by))
for layer_name in isolated_layer_names:
kv_cache_tensors.append(KVCacheTensor(size=page_size * num_blocks, shared_by=[layer_name]))

return KVCacheConfig(
num_blocks=num_blocks,
Expand Down Expand Up @@ -1839,7 +1912,7 @@ def get_kv_cache_configs(
if not groups:
adjusted_memory.append(avail_mem)
continue
bytes_per_block = _pool_bytes_per_block(groups)
bytes_per_block = _pool_bytes_per_block(groups, aphrodite_config)
logger.info(
"Overriding num_gpu_blocks=%d with num_gpu_blocks_override=%d",
avail_mem // bytes_per_block,
Expand Down
Loading
Loading