diff --git a/dflash/model.py b/dflash/model.py
index 5a1d33a..619c78f 100644
--- a/dflash/model.py
+++ b/dflash/model.py
@@ -173,7 +173,13 @@ def dflash_generate(
 # DFlash model
 # ---------------------------------------------------------------------------
 
-def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+def apply_rotary_pos_emb(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    unsqueeze_dim: int = 1,
+) -> tuple[torch.Tensor, torch.Tensor]:
     cos = cos.unsqueeze(unsqueeze_dim)
     sin = sin.unsqueeze(unsqueeze_dim)
     q_len = q.size(-2)
@@ -300,6 +306,20 @@ def forward(
 
 
 class DFlashDraftModel(Qwen3PreTrainedModel):
+    """DFlash draft model for speculative decoding with block diffusion.
+
+    This model implements a lightweight block diffusion architecture that
+    generates multiple draft tokens in parallel for speculative decoding.
+    It uses a Qwen3-based backbone with cross-attention to target model
+    hidden states.
+
+    Args:
+        config: Qwen3Config with additional dflash_config parameters including:
+            - target_layer_ids: Layer IDs to extract from target model
+            - mask_token_id: Token ID used for masking in diffusion
+            - block_size: Number of tokens to generate per block
+            - num_target_layers: Number of layers in target model
+    """
     config_class = Qwen3Config
     _no_split_modules = ["Qwen3DFlashDecoderLayer"]
 
@@ -320,6 +340,31 @@ def __init__(self, config) -> None:
         self.mask_token_id = self.config.dflash_config.get("mask_token_id", None)
         self.post_init()
 
+    def _check_weights_for_nan(self) -> None:
+        """Check for NaN in model weights and warn if found.
+
+        This helps diagnose issues like #115 where layer norm weights
+        become NaN during inference.
+
+        The scan runs only on the first call: it visits every parameter,
+        and each truth test on a tensor may force a device-to-host sync
+        when the weights live on an accelerator, so it is not repeated.
+        """
+        import warnings
+
+        if getattr(self, "_nan_check_done", False):
+            return
+        self._nan_check_done = True
+        for name, param in self.named_parameters():
+            if torch.isnan(param).any():
+                warnings.warn(
+                    f"NaN detected in {name}. This may cause incorrect "
+                    f"outputs or further NaN propagation. Consider "
+                    f"re-downloading the model checkpoint.",
+                    stacklevel=2,
+                )
+                break  # One warning is enough; stop at the first hit
+
     def forward(
         self,
         position_ids: torch.LongTensor,
@@ -330,7 +375,13 @@ def forward(
         use_cache: bool = False,
         **kwargs,
     ) -> CausalLMOutputWithPast:
+        if noise_embedding is None:
+            raise ValueError("noise_embedding is required for DFlashDraftModel.forward()")
+        if target_hidden is None:
+            raise ValueError("target_hidden is required for DFlashDraftModel.forward()")
+
         hidden_states = noise_embedding
+        self._check_weights_for_nan()
         target_hidden = self.hidden_norm(self.fc(target_hidden))
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
         for layer in self.layers: