diff --git a/aphrodite/config/speculative.py b/aphrodite/config/speculative.py index 9a4d513cba..38b20b85a5 100644 --- a/aphrodite/config/speculative.py +++ b/aphrodite/config/speculative.py @@ -936,6 +936,10 @@ def max_num_new_slots_for_drafting(self) -> int: def use_eagle(self) -> bool: return self.method in ("eagle", "eagle3", "mtp", "dflash") + def requires_eagle_cache_drop(self) -> bool: + """Whether prefix cache hits must drop one block for hidden states.""" + return self.use_eagle() and not self.use_dflash() + def use_dflash(self) -> bool: return self.method == "dflash" diff --git a/aphrodite/model_executor/models/qwen3_dflash.py b/aphrodite/model_executor/models/qwen3_dflash.py index 700c042b54..4c8fac33c2 100644 --- a/aphrodite/model_executor/models/qwen3_dflash.py +++ b/aphrodite/model_executor/models/qwen3_dflash.py @@ -34,6 +34,12 @@ from aphrodite.multimodal.inputs import NestedTensors from aphrodite.transformers_utils.config import set_default_rope_theta from aphrodite.v1.attention.backend import AttentionType +from aphrodite.v1.attention.selector import get_attn_backend +from aphrodite.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheSpec, + SlidingWindowSpec, +) from .qwen2 import Qwen2MLP as Qwen3MLP from .qwen3 import Qwen3ForCausalLM @@ -47,6 +53,53 @@ logger = init_logger(__name__) +_DFLASH_VALID_LAYER_TYPES = frozenset({"full_attention", "sliding_attention"}) + + +def _get_dflash_layer_types(config: Qwen3Config) -> tuple[str, ...]: + layer_types = getattr(config, "layer_types", None) + if layer_types is None: + return ("full_attention",) * config.num_hidden_layers + if len(layer_types) != config.num_hidden_layers: + raise ValueError( + f"DFlash layer_types length {len(layer_types)} does not match " + f"num_hidden_layers {config.num_hidden_layers}." + ) + invalid = set(layer_types) - _DFLASH_VALID_LAYER_TYPES + if invalid: + raise ValueError(f"Invalid DFlash layer_type(s): {sorted(invalid)}.") + if "sliding_attention" in layer_types and not getattr( + config, "sliding_window", None + ): + raise ValueError( + "DFlash sliding_attention layers require `sliding_window` in config." + ) + return tuple(layer_types) + + +class DFlashAttention(Attention): + """Attention with DFlash-specific KV allocation semantics. + + The compute path keeps the layer's configured sliding window. The KV cache + spec is widened to full attention because DFlash writes every context KV + before drafting and cannot evict old context blocks from draft layers. + """ + + def get_kv_cache_spec(self, aphrodite_config: AphroditeConfig) -> KVCacheSpec | None: + spec = super().get_kv_cache_spec(aphrodite_config) + if isinstance(spec, SlidingWindowSpec): + return FullAttentionSpec( + block_size=spec.block_size, + num_kv_heads=spec.num_kv_heads, + head_size=spec.head_size, + head_size_v=getattr(spec, "head_size_v", spec.head_size), + dtype=spec.dtype, + kv_quant_mode=spec.kv_quant_mode, + page_size_padded=spec.page_size_padded, + ) + return spec + + class DFlashQwen3Attention(nn.Module): """Attention for DFlash speculative decoding. @@ -66,6 +119,7 @@ def __init__( attention_bias: bool = False, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, + sliding_window: int | None = None, prefix: str = "", attn_type: str = AttentionType.DECODER, ) -> None: @@ -109,15 +163,24 @@ def __init__( max_position=max_position, rope_parameters=rope_parameters, ) - self.attn = Attention( + draft_attn_backend = get_attn_backend( + self.head_dim, + torch.get_default_dtype(), + cache_config.cache_dtype if cache_config is not None else "auto", + use_mm_prefix=False, + attn_type=attn_type, + ) + self.attn = DFlashAttention( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, + per_layer_sliding_window=sliding_window, prefix=f"{prefix}.attn", attn_type=attn_type, + attn_backend=draft_attn_backend, ) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) @@ -154,12 +217,17 @@ def __init__( config: Qwen3Config, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, + layer_type: str = "full_attention", prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size + self.layer_type = layer_type set_default_rope_theta(config, default_theta=1000000) attn_type = AttentionType.DECODER + sliding_window = ( + config.sliding_window if layer_type == "sliding_attention" else None + ) self.self_attn = DFlashQwen3Attention( hidden_size=self.hidden_size, @@ -171,6 +239,7 @@ def __init__( head_dim=getattr(config, "head_dim", None), cache_config=cache_config, quant_config=quant_config, + sliding_window=sliding_window, rope_parameters=config.rope_parameters, prefix=f"{prefix}.self_attn", attn_type=attn_type, @@ -236,17 +305,30 @@ def __init__( self.config.hidden_size, prefix=maybe_prefix(prefix, "embed_tokens"), ) - + target_config = aphrodite_config.model_config.hf_text_config + self.embed_normalizer: float | None = None + if str(getattr(target_config, "model_type", "")).startswith("gemma4"): + # Gemma4 scales token embeddings by sqrt(hidden_size). DFlash + # shares the target embeddings, so the draft path must match. + self.embed_normalizer = target_config.hidden_size**0.5 + + self.layer_types = _get_dflash_layer_types(self.config) self.layers = nn.ModuleList( [ DFlashQwen3DecoderLayer( current_aphrodite_config, prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"), config=self.config, + layer_type=self.layer_types[layer_idx], ) for layer_idx in range(self.config.num_hidden_layers) ] ) + self.sliding_attention_layer_names = { + layer.self_attn.attn.layer_name + for layer in self.layers + if layer.layer_type == "sliding_attention" + } if self.use_aux_hidden_state: num_features_to_use = self.config.num_hidden_layers if "target_layer_ids" in drafter_config: @@ -276,7 +358,8 @@ def __init__( ) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) + embeds = self.embed_tokens(input_ids) + return embeds * self.embed_normalizer if self.embed_normalizer else embeds def _build_fused_kv_buffers(self) -> None: """Build fused weight buffers for precompute_and_store_context_kv. @@ -504,7 +587,11 @@ def __init__(self, *, aphrodite_config: AphroditeConfig, prefix: str = ""): self.config.hidden_size, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor(self.config.draft_vocab_size, scale=logit_scale) + self.logits_processor = LogitsProcessor( + self.config.draft_vocab_size, + scale=logit_scale, + soft_cap=getattr(self.config, "final_logit_softcapping", None), + ) target_vocab_size = aphrodite_config.model_config.get_vocab_size() if self.config.draft_vocab_size != target_vocab_size: self.draft_id_to_target_id = nn.Parameter( @@ -556,6 +643,10 @@ def precompute_and_store_context_kv( """Precompute projected + RoPE'd K/V and write to cache.""" self.model.precompute_and_store_context_kv(context_states, context_positions, context_slot_mapping) + @property + def sliding_attention_layer_names(self) -> set[str]: + return self.model.sliding_attention_layer_names + def combine_hidden_states( self, hidden_states: torch.Tensor, diff --git a/aphrodite/transformers_utils/configs/speculators/algos.py b/aphrodite/transformers_utils/configs/speculators/algos.py index a7a9b1877c..3cc800e8ca 100644 --- a/aphrodite/transformers_utils/configs/speculators/algos.py +++ b/aphrodite/transformers_utils/configs/speculators/algos.py @@ -60,7 +60,16 @@ def update_dflash(config_dict: dict, pre_trained_config: dict) -> None: pre_trained_config["draft_vocab_size"] = config_dict.get("draft_vocab_size") if config_dict.get("target_hidden_size") is not None: pre_trained_config["target_hidden_size"] = config_dict["target_hidden_size"] + for key in ( + "layer_types", + "use_sliding_window", + "sliding_window", + "max_window_layers", + ): + if key in config_dict: + pre_trained_config[key] = config_dict[key] + # TODO: does this need to be shifted by 1 like in gpu_model_runner? aux_layer_ids = config_dict["aux_hidden_state_layer_ids"] pre_trained_config["eagle_aux_hidden_state_layer_ids"] = aux_layer_ids diff --git a/aphrodite/v1/attention/backends/triton_attn.py b/aphrodite/v1/attention/backends/triton_attn.py index c358e0c7e4..a497ad63ed 100644 --- a/aphrodite/v1/attention/backends/triton_attn.py +++ b/aphrodite/v1/attention/backends/triton_attn.py @@ -135,8 +135,11 @@ def __init__( model_config = aphrodite_config.model_config self.num_heads_q = model_config.get_num_attention_heads(aphrodite_config.parallel_config) - self.num_heads_kv = model_config.get_num_kv_heads(aphrodite_config.parallel_config) - self.headdim = model_config.get_head_size() + # Some models (e.g. Gemma4) use different KV/head geometry for + # different attention layer groups, so size decode metadata from the + # actual KV cache spec instead of the model-wide defaults. + self.num_heads_kv = kv_cache_spec.num_kv_heads + self.headdim = kv_cache_spec.head_size # Check if CUDA Graphs are enabled for decode self.decode_cudagraph_enabled = self.aphrodite_config.compilation_config.cudagraph_mode in ( diff --git a/aphrodite/v1/core/kv_cache_utils.py b/aphrodite/v1/core/kv_cache_utils.py index 5989cb6cab..c92d3d3018 100644 --- a/aphrodite/v1/core/kv_cache_utils.py +++ b/aphrodite/v1/core/kv_cache_utils.py @@ -6,6 +6,7 @@ import hashlib import math import os +import re from collections import defaultdict from collections.abc import Callable, Iterable, Iterator, Sequence from dataclasses import dataclass, replace @@ -78,6 +79,8 @@ def maybe_convert_block_hash(hash_bytes: BlockHash) -> ExternalBlockHash: logger = init_logger(__name__) +_LAYER_INDEX_RE = re.compile(r"(?:^|\.)layers\.(\d+)(?:\.|$)") + # The hash seed for the first block of any prefix block sequence. # # We use a random value to avoid hash collisions or PYTHONHASHSEED environment @@ -846,7 +849,10 @@ def may_override_num_blocks(aphrodite_config: AphroditeConfig, num_blocks: int) return num_blocks -def _pool_bytes_per_block(kv_cache_groups: list[KVCacheGroupSpec]) -> int: +def _pool_bytes_per_block( + kv_cache_groups: list[KVCacheGroupSpec], + aphrodite_config: AphroditeConfig | None = None, +) -> int: """ Bytes consumed by one block in the worker's shared KV cache pool, mirroring the divisor used by `get_kv_cache_config_from_groups` to convert @@ -863,7 +869,22 @@ def _pool_bytes_per_block(kv_cache_groups: list[KVCacheGroupSpec]) -> int: cast(UniformTypeKVCacheSpecs, g.kv_cache_spec).get_num_layer_tuples() for g in kv_cache_groups ) return layer_tuple_page_bytes * num_layer_tuples - group_size = max(len(g.layer_names) for g in kv_cache_groups) + if aphrodite_config is not None: + isolated_group_ids = _get_dflash_isolated_group_ids( + aphrodite_config, kv_cache_groups + ) + shared_group_size = max( + ( + len(group.layer_names) + for group_id, group in enumerate(kv_cache_groups) + if group_id not in isolated_group_ids + ), + default=0, + ) + isolated_layers = sum(len(kv_cache_groups[group_id].layer_names) for group_id in isolated_group_ids) + group_size = shared_group_size + isolated_layers + else: + group_size = max(len(g.layer_names) for g in kv_cache_groups) page_size = get_uniform_page_size([g.kv_cache_spec for g in kv_cache_groups]) return page_size * group_size @@ -897,6 +918,35 @@ def get_uniform_page_size(kv_cache_specs: Iterable[KVCacheSpec]) -> int: return page_sizes.pop() +def _get_dflash_isolated_group_ids( + aphrodite_config: AphroditeConfig, + kv_cache_groups: list[KVCacheGroupSpec], +) -> set[int]: + spec_config = aphrodite_config.speculative_config + if spec_config is None or spec_config.method != "dflash": + return set() + + try: + target_num_layers = aphrodite_config.model_config.get_num_layers( + aphrodite_config.parallel_config + ) + except Exception: + return set() + + group_ids: set[int] = set() + for group_id, group in enumerate(kv_cache_groups): + layer_indices: list[int] = [] + for layer_name in group.layer_names: + match = _LAYER_INDEX_RE.search(layer_name) + if match is None: + layer_indices = [] + break + layer_indices.append(int(match.group(1))) + if layer_indices and all(idx >= target_num_layers for idx in layer_indices): + group_ids.add(group_id) + return group_ids + + def _get_kv_cache_groups_uniform_spec( kv_cache_specs: dict[str, KVCacheSpec], ) -> list[KVCacheGroupSpec]: @@ -1222,18 +1272,41 @@ def get_kv_cache_config_from_groups( # (sw.1, padding) will be: (group_size = 2) # full.0, sw.0, sw.1: share a Tensor with size=available_memory//2 # full.1, sw.2: share another Tensor with size=available_memory//2 - group_size = max(len(group.layer_names) for group in kv_cache_groups) + # DFlash writes draft context KVs directly into cache using the draft + # block table. Do not row-share those tensors with target KV groups, or + # overlapping physical block ids can overwrite target KVs under batching. + isolated_group_ids = _get_dflash_isolated_group_ids( + aphrodite_config, kv_cache_groups + ) + shared_groups = [ + group + for group_id, group in enumerate(kv_cache_groups) + if group_id not in isolated_group_ids + ] + isolated_layer_names = [ + layer_name + for group_id in sorted(isolated_group_ids) + for layer_name in kv_cache_groups[group_id].layer_names + ] + shared_group_size = ( + max(len(group.layer_names) for group in shared_groups) + if shared_groups + else 0 + ) + group_size = shared_group_size + len(isolated_layer_names) page_size = get_uniform_page_size([group.kv_cache_spec for group in kv_cache_groups]) assert group_size > 0, "group_size must be greater than 0" num_blocks = get_num_blocks(aphrodite_config, group_size, available_memory, page_size) kv_cache_tensors = [] - for i in range(group_size): + for i in range(shared_group_size): shared_by = [] - for j in range(len(kv_cache_groups)): - if i < len(kv_cache_groups[j].layer_names): - shared_by.append(kv_cache_groups[j].layer_names[i]) + for group in shared_groups: + if i < len(group.layer_names): + shared_by.append(group.layer_names[i]) kv_cache_tensors.append(KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by)) + for layer_name in isolated_layer_names: + kv_cache_tensors.append(KVCacheTensor(size=page_size * num_blocks, shared_by=[layer_name])) return KVCacheConfig( num_blocks=num_blocks, @@ -1839,7 +1912,7 @@ def get_kv_cache_configs( if not groups: adjusted_memory.append(avail_mem) continue - bytes_per_block = _pool_bytes_per_block(groups) + bytes_per_block = _pool_bytes_per_block(groups, aphrodite_config) logger.info( "Overriding num_gpu_blocks=%d with num_gpu_blocks_override=%d", avail_mem // bytes_per_block, diff --git a/aphrodite/v1/core/sched/scheduler.py b/aphrodite/v1/core/sched/scheduler.py index 4f29915f1d..e51ae8a876 100644 --- a/aphrodite/v1/core/sched/scheduler.py +++ b/aphrodite/v1/core/sched/scheduler.py @@ -195,11 +195,15 @@ def __init__( speculative_config = aphrodite_config.speculative_config self.use_eagle = False + self.requires_eagle_cache_drop = False self.num_spec_tokens = self.num_lookahead_tokens = 0 if speculative_config: self.num_spec_tokens = speculative_config.num_speculative_tokens if speculative_config.use_eagle(): self.use_eagle = True + self.requires_eagle_cache_drop = ( + speculative_config.requires_eagle_cache_drop() + ) self.num_lookahead_tokens = self.num_spec_tokens if speculative_config.uses_draft_model(): self.num_lookahead_tokens = self.num_spec_tokens @@ -212,7 +216,7 @@ def __init__( max_model_len=self.max_model_len, max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens, enable_caching=self.cache_config.enable_prefix_caching, - use_eagle=self.use_eagle, + use_eagle=self.requires_eagle_cache_drop, log_stats=self.log_stats, enable_kv_cache_events=self.enable_kv_cache_events, dcp_world_size=self.dcp_world_size, @@ -288,13 +292,13 @@ def _mamba_block_aligned_split( # must be a multiple of `block_size`. # As an exception, if `num_new_tokens` is less than `block_size`, the # state is simply not cached, requiring no special handling. - # Additionally, when Eagle mode is enabled, FullAttn prunes the last - # matching block. To prevent this from causing a Mamba cache miss, the - # last chunk must be not smaller than `block_size`. + # Additionally, when Eagle-style cache drop is enabled, FullAttn + # prunes the last matching block. To prevent this from causing a + # Mamba cache miss, the last chunk must be not smaller than + # `block_size`. block_size = self.cache_config.block_size last_cache_position = request.num_tokens - request.num_tokens % block_size - # eagle prune - if self.use_eagle: + if self.requires_eagle_cache_drop: last_cache_position = max(last_cache_position - block_size, 0) num_computed_tokens_after_sched = num_computed_tokens + num_new_tokens if num_computed_tokens_after_sched < last_cache_position: @@ -390,7 +394,9 @@ def schedule(self) -> SchedulerOutput: request.num_computed_tokens, num_new_tokens, encoder_compute_budget, - shift_computed_tokens=1 if self.use_eagle else 0, + shift_computed_tokens=1 + if self.requires_eagle_cache_drop + else 0, ) if self.need_mamba_block_aligned_split: @@ -642,7 +648,9 @@ def schedule(self) -> SchedulerOutput: num_computed_tokens, num_new_tokens, encoder_compute_budget, - shift_computed_tokens=1 if self.use_eagle else 0, + shift_computed_tokens=1 + if self.requires_eagle_cache_drop + else 0, ) if num_new_tokens == 0: # The request cannot be scheduled. diff --git a/aphrodite/v1/spec_decode/dflash.py b/aphrodite/v1/spec_decode/dflash.py index 624c246713..89f23120d6 100644 --- a/aphrodite/v1/spec_decode/dflash.py +++ b/aphrodite/v1/spec_decode/dflash.py @@ -268,7 +268,33 @@ def build_per_group_and_layer_attn_metadata( self, cad: CommonAttentionMetadata, draft_index: int = 0 ) -> tuple[list[object], dict[str, object]]: per_group, per_layer = super().build_per_group_and_layer_attn_metadata(cad, draft_index) + sliding_layer_names: set[str] = getattr( + self.model, "sliding_attention_layer_names", set() + ) + if sliding_layer_names: + # DFlash layers consume attention metadata through the per-layer + # forward context. Keep the base non-causal group metadata for + # group-level spec decode checks, and specialize only the SWA + # layers that need a causal sliding-window mask. + causal_cad = cad.replace(causal=True) + for attn_group in self.draft_attn_groups: + causal_layers = sliding_layer_names & set(attn_group.layer_names) + if not causal_layers: + continue + attn_metadata = attn_group.get_metadata_builder().build_for_drafting( + common_attn_metadata=causal_cad, + draft_index=draft_index, + ) + for layer_name in causal_layers: + per_layer[layer_name] = attn_metadata + for layer_name, attn_metadata in per_layer.items(): + if layer_name in sliding_layer_names: + assert getattr(attn_metadata, "causal", None) is True, ( + f"Attention metadata for sliding layer {layer_name} does not have" + " causal support, which is required for DFlash SWA." + ) + continue assert getattr(attn_metadata, "causal", None) is False, ( f"Attention metadata for layer {layer_name} does not have" " non-causal support, which is required for DFlash." diff --git a/aphrodite/v1/spec_decode/utils.py b/aphrodite/v1/spec_decode/utils.py index 9c89f66b7a..73ed695dfa 100644 --- a/aphrodite/v1/spec_decode/utils.py +++ b/aphrodite/v1/spec_decode/utils.py @@ -482,20 +482,13 @@ def copy_and_expand_dflash_inputs_kernel( ctx_start = tl.load(query_start_loc_ptr + req_idx) ctx_end = tl.load(query_start_loc_ptr + req_idx + 1) num_ctx = ctx_end - ctx_start - total_tokens = num_ctx + num_query_per_req j = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - in_bounds = j < total_tokens + in_bounds = j < (num_ctx + num_query_per_req) is_ctx = j < num_ctx is_query = (~is_ctx) & in_bounds query_off = j - num_ctx # offset within query portion (0-indexed) - # --- Positions --- - # Context: load from target_positions - ctx_pos_idx = tl.minimum(ctx_start + j, total_input_tokens - 1) - ctx_pos = tl.load(target_positions_ptr + ctx_pos_idx, mask=is_ctx, other=0) - - # Query: last_valid_pos + 1 + query_off # In padded mode, ctx_end includes rejected tokens; use valid_ctx_end # to find the last accepted context position. if HAS_NUM_REJECTED: @@ -503,14 +496,37 @@ def copy_and_expand_dflash_inputs_kernel( valid_ctx_end = ctx_end - num_rejected else: valid_ctx_end = ctx_end - last_pos = tl.load(target_positions_ptr + valid_ctx_end - 1) + + valid_num_ctx = valid_ctx_end - ctx_start + is_valid_ctx = j < valid_num_ctx + + # --- Positions --- + # Context: load from target_positions. Rejected context positions are + # don't-care because their slot mappings are masked out below. + ctx_pos_idx = tl.minimum(ctx_start + j, total_input_tokens - 1) + ctx_pos = tl.load(target_positions_ptr + ctx_pos_idx, mask=is_ctx, other=0) + + # Query: last_valid_pos + 1 + query_off. If the previous sampled output + # was discarded, there may be no valid context token from that target pass; + # start from the first scheduled position in that case. + fallback_last_pos = tl.load(target_positions_ptr + ctx_start) - 1 + last_valid_pos_idx = tl.maximum(valid_ctx_end - 1, ctx_start) + last_pos = tl.load( + target_positions_ptr + last_valid_pos_idx, + mask=valid_num_ctx > 0, + other=fallback_last_pos, + ) query_pos = last_pos + 1 + query_off positions = tl.where(is_ctx, ctx_pos, query_pos) # Context and query positions go to separate buffers. ctx_pos_out = ctx_start + j - tl.store(out_context_positions_ptr + ctx_pos_out, ctx_pos, mask=is_ctx) + tl.store( + out_context_positions_ptr + ctx_pos_out, + tl.where(is_valid_ctx, ctx_pos, 0), + mask=is_ctx, + ) query_out = req_idx * num_query_per_req + query_off tl.store(out_query_positions_ptr + query_out, query_pos, mask=is_query) @@ -524,7 +540,11 @@ def copy_and_expand_dflash_inputs_kernel( other=0, ).to(tl.int64) slot = block_id * block_size + (positions % block_size) - tl.store(out_context_slot_mapping_ptr + ctx_pos_out, slot, mask=is_ctx) + tl.store( + out_context_slot_mapping_ptr + ctx_pos_out, + tl.where(is_valid_ctx, slot, -1), + mask=is_ctx, + ) tl.store(out_query_slot_mapping_ptr + query_out, slot, mask=is_query) # --- Input IDs (query tokens only) --- diff --git a/aphrodite/v1/worker/gpu_model_runner.py b/aphrodite/v1/worker/gpu_model_runner.py index 7e4d8530bd..f031080056 100644 --- a/aphrodite/v1/worker/gpu_model_runner.py +++ b/aphrodite/v1/worker/gpu_model_runner.py @@ -4607,7 +4607,8 @@ def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None: if not layer_ids: dflash_config = getattr(hf_config, "dflash_config", None) if dflash_config and isinstance(dflash_config, dict): - layer_ids = dflash_config.get("target_layer_ids") + # Add 1 to convert DFlash's aux layer id semantics + layer_ids = [i + 1 for i in dflash_config.get("target_layer_ids", [])] if layer_ids and isinstance(layer_ids, (list, tuple)): return tuple(layer_ids)