Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion dflash/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,13 @@ def dflash_generate(
# DFlash model
# ---------------------------------------------------------------------------

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
unsqueeze_dim: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_len = q.size(-2)
Expand Down Expand Up @@ -300,6 +306,20 @@ def forward(


class DFlashDraftModel(Qwen3PreTrainedModel):
"""DFlash draft model for speculative decoding with block diffusion.

This model implements a lightweight block diffusion architecture that
generates multiple draft tokens in parallel for speculative decoding.
It uses a Qwen3-based backbone with cross-attention to target model
hidden states.

Args:
config: Qwen3Config with additional dflash_config parameters including:
- target_layer_ids: Layer IDs to extract from target model
- mask_token_id: Token ID used for masking in diffusion
- block_size: Number of tokens to generate per block
- num_target_layers: Number of layers in target model
"""
config_class = Qwen3Config
_no_split_modules = ["Qwen3DFlashDecoderLayer"]

Expand All @@ -320,6 +340,23 @@ def __init__(self, config) -> None:
self.mask_token_id = self.config.dflash_config.get("mask_token_id", None)
self.post_init()

def _check_weights_for_nan(self) -> None:
    """Warn if any model parameter contains NaN.

    Walks ``self.named_parameters()`` in order and stops at the first
    parameter holding at least one NaN value. At most one warning is
    emitted per call, naming that first offending parameter.

    This helps diagnose issues like #115 where layer norm weights
    become NaN during inference.

    Returns:
        None. The only observable effect is an optional ``UserWarning``.
    """
    # Imported once at function scope instead of inside the loop body.
    import warnings

    # Lazy search: the generator stops evaluating parameters as soon as the
    # first NaN-containing one is found, so later parameters are not scanned.
    first_bad = next(
        (
            name
            for name, param in self.named_parameters()
            if torch.isnan(param).any()
        ),
        None,
    )
    if first_bad is None:
        return

    # stacklevel=2 attributes the warning to the caller of this method.
    warnings.warn(
        f"NaN detected in {first_bad}. This may cause incorrect "
        f"outputs or further NaN propagation. Consider "
        f"re-downloading the model checkpoint.",
        stacklevel=2,
    )

def forward(
self,
position_ids: torch.LongTensor,
Expand All @@ -330,7 +367,13 @@ def forward(
use_cache: bool = False,
**kwargs,
) -> CausalLMOutputWithPast:
if noise_embedding is None:
raise ValueError("noise_embedding is required for DFlashDraftModel.forward()")
if target_hidden is None:
raise ValueError("target_hidden is required for DFlashDraftModel.forward()")

hidden_states = noise_embedding
self._check_weights_for_nan()
target_hidden = self.hidden_norm(self.fc(target_hidden))
position_embeddings = self.rotary_emb(hidden_states, position_ids)
for layer in self.layers:
Expand Down