diff --git a/cookbook/transformers/deepseek_v4_patch/README.md b/cookbook/transformers/deepseek_v4_patch/README.md new file mode 100644 index 00000000..ace2f6e0 --- /dev/null +++ b/cookbook/transformers/deepseek_v4_patch/README.md @@ -0,0 +1,112 @@ +# DeepSeek-V4 NPU Sparse Attention (SAS) / Lightning Indexer (LI) + +Twinkle 提供的 DeepSeek-V4 NPU 加速 patch,通过 monkey-patch 方式替换 transformers 中的注意力计算和索引器实现,无需修改 transformers 源码。 + +## 功能说明 + +### SAS (Sparse Attention Shared-KV) + +替换 `DeepseekV4Attention.forward` 中的标准注意力计算,使用 mindspeed 提供的融合稀疏注意力核 `SparseAttnSharedKV`,支持三种注意力层类型: + +- **Sliding Attention**: 纯滑动窗口注意力 +- **CSA (Compressed Sparse Attention)**: 压缩稀疏注意力,使用 Lightning Indexer 选择 top-k 压缩条目 +- **HCA (Heavily Compressed Attention)**: 高度压缩注意力,所有压缩条目可见 + +### LI (Lightning Indexer) + +替换 `DeepseekV4Indexer.forward` 中的 torch 实现,使用 mindspeed 提供的 `npu_lightning_indexer` 加速 top-k 索引选择。 + +**注意**: 当前版本 SAS 和 LI 不能同时启用。 + +## 依赖 + +- **[ops-transformer](https://gitcode.com/cann/ops-transformer)**: 提供 NPU 算子实现,需要编译安装 +- **[mindspeed](https://gitcode.com/Ascend/MindSpeed)**: 提供 NPU 算子调用实现,需要使用git clone下载mindspeed并切换到master分支进行手动安装 + - `mindspeed.ops.npu_sparse_attn_shared_kv.SparseAttnSharedKV` (SAS) + - `mindspeed.ops.npu_lightning_indexer` (LI) +- **transformers**: 需包含 DeepSeek-V4 模型支持 +- **torch_npu**: Ascend NPU 运行时 + +## 环境变量 + +| 变量 | 默认值 | 说明 | +|------|--------|------| +| `TWINKLE_NPU_DSV4_SAS` | `0` | 启用 SAS patch | +| `TWINKLE_NPU_DSV4_LI` | `0` | 启用 LI patch | + +**约束**: `TWINKLE_NPU_DSV4_SAS` 和 `TWINKLE_NPU_DSV4_LI` 不能同时设置为 `1`。 + +## 使用示例 + +### 镜像(可选) +```shell +#A3 +docker pull swr.cn-southwest-2.myhuaweicloud.com/ascend-sact/twinkle-npu:v4 +``` + +### 启用 SAS + +```bash +export TWINKLE_NPU_DSV4_SAS=1 +torchrun --standalone --nnodes=1 --nproc_per_node=8 ep_fsdp2_lora_deepseek_v4_npu.py +``` + +### 启用 LI + +```bash +export TWINKLE_NPU_DSV4_LI=1 + +torchrun --standalone --nnodes=1 --nproc_per_node=8 ep_fsdp2_lora_deepseek_v4_npu.py +``` + +### 完整示例脚本 (ds16_sas.sh) + +```bash +#!/bin/bash +export GLOO_SOCKET_IFNAME="enp162s0f0" +export HCCL_SOCKET_IFNAME="enp162s0f0" +export HCCL_CONNECT_TIMEOUT=7200 +export HCCL_EXEC_TIMEOUT=7200 +export ACL_DEVICE_SYNC_TIMEOUT=7200 +export HCCL_IF_BASE_PORT=30000 +export BATCH_SIZE=8 +export MAX_STEPS=10 +export GRADIENT_CHECKPOINTING=1 +export USE_EP=1 + +# 启用 twinkle SAS patch +export TWINKLE_NPU_DSV4_SAS=1 + +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +source /usr/local/Ascend/cann/opp/vendors/custom_transformer/bin/set_env.bash +torchrun --standalone --nnodes=1 --nproc_per_node=8 ep_fsdp2_lora_deepseek_v4_npu.py +``` + +## 实现原理 + +Patch 在 `apply_npu_patch()` 阶段自动应用(位于 EP sharding 之后、FSDP wrap 之前),通过以下方式替换原始实现: + +1. **Compressor patch**: 包装 `DeepseekV4HCACompressor` 和 `DeepseekV4CSACompressor` 的 `forward` 方法,确保返回 3-tuple `(compressed_kv, block_bias, top_k_indices)` +2. **Attention patch**: 替换 `DeepseekV4Attention.forward`,调用 `SparseAttnSharedKV.apply()` 替代标准注意力 dispatch +3. **Indexer patch**: 替换 `DeepseekV4Indexer.forward`,调用 `mindspeed.ops.npu_lightning_indexer` 替代 torch 实现 + +所有 patch 均包含 `ImportError` fallback,当 mindspeed 不可用时自动回退到原始实现。 + +## 验证 + +运行测试后,检查日志中是否出现: + +``` +[NPU] [DSV4-SAS] Twinkle sparse attention active (layer_type=..., cmp_ratio=..., topk=...) +``` + +或 + +``` +[NPU] [DSV4-LI] Twinkle lightning indexer active (sparse_count=..., cmp_ratio=...) +``` + +## 相关文件 + +- `src/twinkle/kernel/deepseek_v4_npu.py`: Patch 核心实现 +- `src/twinkle/kernel/monkey_patch_npu.py`: Patch 注册和环境变量控制 diff --git a/cookbook/transformers/deepseek_v4_patch/ep_fsdp2_lora_deepseek_v4_npu.py b/cookbook/transformers/deepseek_v4_patch/ep_fsdp2_lora_deepseek_v4_npu.py new file mode 100644 index 00000000..da39157f --- /dev/null +++ b/cookbook/transformers/deepseek_v4_patch/ep_fsdp2_lora_deepseek_v4_npu.py @@ -0,0 +1,123 @@ +import os +import twinkle +from peft import LoraConfig +from transformers import AutoConfig +from twinkle import DeviceMesh, Platform, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.kernel import apply_npu_patch +from twinkle.model import TransformersModel +from twinkle.preprocessor import SelfCognitionProcessor + + +logger = get_logger() +MODEL_ID = os.environ.get('DSV4_MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-Flash') +DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition') +TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template') +OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output') + +MAX_LENGTH = int(os.environ.get('MAX_LENGTH', '4096')) +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '32')) +GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '2')) +LOG_INTERVAL = GRAD_ACCUM_STEPS +LR = float(os.environ.get('LR', '1e-5')) +MAX_STEPS = int(os.environ.get('MAX_STEPS', '0')) +SAVE_STEPS = int(os.environ.get('SAVE_STEPS', '50')) +USE_LORA = os.environ.get('USE_LORA', '1') == '1' +MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0')) +IGNORE_MISMATCHED_SIZES = os.environ.get('IGNORE_MISMATCHED_SIZES', '1') == '1' +GRADIENT_CHECKPOINTING = os.environ.get('GRADIENT_CHECKPOINTING', '1') == '1' +RESHARD_AFTER_FORWARD = os.environ.get('RESHARD_AFTER_FORWARD', '1') == '1' +LORA_TARGET_MODULES = os.environ.get( + 'LORA_TARGET_MODULES', + 'wq_a,wq_b,wkv,wgate,gate_proj,up_proj,down_proj', +) +USE_EP = os.environ.get('USE_EP', '0') == '1' +ADAPTER_NAME = os.environ.get('ADAPTER_NAME', 'default') +EP_SIZE = BATCH_SIZE if USE_EP else 1 +device_mesh = DeviceMesh.from_sizes( + fsdp_size=BATCH_SIZE, + dp_size=1, + ep_size=EP_SIZE, + device_type=Platform.get_platform().device_prefix(), +) + +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + + +def create_dataset(data_slice=None): + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID)) + dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope')) + return dataset + +def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): + return model.save( + name=checkpoint_name, + output_dir=OUTPUT_DIR, + adapter_name=ADAPTER_NAME, + save_optimizer=True, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], + ) + +def train(): + dataset = create_dataset() + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) + + config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) + if hasattr(config, 'use_cache'): + config.use_cache = False + model = TransformersModel( + model_id=MODEL_ID, + config=config, + device_mesh=device_mesh, + strategy="native_fsdp", + memory_efficient_init=True, + ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES, + fsdp_config={ + 'reshard_after_forward': RESHARD_AFTER_FORWARD, + 'expert_parallel': { + 'enabled': USE_EP, + 'router_dtype': 'fp32', + 'keep_router_logits': False, + } + }, + ) + + apply_npu_patch(model) + + if USE_LORA: + lora_target_modules = [name.strip() for name in LORA_TARGET_MODULES.split(',') if name.strip()] + lora_config = LoraConfig(r=8, lora_alpha=32, target_modules=lora_target_modules) + model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=GRAD_ACCUM_STEPS) + + if not GRADIENT_CHECKPOINTING: + model.model.gradient_checkpointing_disable() + + model.set_template(TEMPLATE_ID, model_id=MODEL_ID, adapter_name=ADAPTER_NAME) + model.set_optimizer('AdamW', lr=LR, foreach=False, adapter_name=ADAPTER_NAME) + model.set_lr_scheduler( + scheduler_cls='CosineWarmupScheduler', + num_warmup_steps=1, + num_training_steps=len(dataloader), + adapter_name=ADAPTER_NAME, + ) + optimizer_group = model.optimizer_group[ADAPTER_NAME] + for batch in dataloader: + if callable(batch): + batch = batch() + model.forward_backward(inputs=batch) + model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS) + cur_step = optimizer_group.cur_step + if cur_step > 0 and cur_step % LOG_INTERVAL == 0: + metric = model.calculate_metric(is_training=True) + if callable(metric): + metric = metric() + logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') + + final_checkpoint = save_checkpoint(model, 'checkpoint-final', dataloader) + logger.info(f'Saved final adapter to {final_checkpoint}') + +if __name__ == '__main__': + train() + diff --git a/src/twinkle/kernel/deepseek_v4_npu.py b/src/twinkle/kernel/deepseek_v4_npu.py new file mode 100644 index 00000000..49481882 --- /dev/null +++ b/src/twinkle/kernel/deepseek_v4_npu.py @@ -0,0 +1,307 @@ +import torch +import torch.nn.functional as F + +from twinkle import get_logger + +logger = get_logger() + +_sas_logged = False +_li_logged = False + + +def _npu_sparse_attn_shared_kv(query, ori_kv, cmp_kv, cmp_sparse_indices, sinks, softmax_scale, cmp_ratio, + ori_mask_mode=4, cmp_mask_mode=3, ori_win_left=127, ori_win_right=0): + cu_seq_lens_q = cu_seq_lens_ori_kv = cu_seq_lens_cmp_kv = None + ori_sparse_indices = None + batch_size, max_seq_len_q, num_heads_q, head_dim = query.size() + num_heads_kv = 1 + max_seq_len_kv = ori_kv.size(1) + topk = 0 if (cmp_ratio != 4 or cmp_sparse_indices is None) else cmp_sparse_indices.size(-1) + layout_q = layout_kv = 'BSND' + query = query.contiguous() + ori_kv = ori_kv.unsqueeze(2).contiguous() + cmp_kv = cmp_kv if cmp_kv is None else cmp_kv.unsqueeze(2).contiguous() + cmp_sparse_indices = None if (cmp_ratio != 4 or cmp_sparse_indices is None) else cmp_sparse_indices.unsqueeze(2).contiguous() + + from mindspeed.ops.npu_sparse_attn_shared_kv import SparseAttnSharedKV + + output = SparseAttnSharedKV.apply( + query, ori_kv, cmp_kv, + cu_seq_lens_q, cu_seq_lens_ori_kv, cu_seq_lens_cmp_kv, + ori_sparse_indices, cmp_sparse_indices, + sinks, softmax_scale, cmp_ratio, + ori_mask_mode, cmp_mask_mode, ori_win_left, ori_win_right, + num_heads_q, num_heads_kv, head_dim, + batch_size, max_seq_len_q, max_seq_len_kv, topk, + layout_q, layout_kv, + ) + return output.contiguous() + + +def _patched_attention_forward( + self, + hidden_states, + position_embeddings, + position_ids, + attention_mask, + past_key_values=None, + **kwargs, +): + from transformers.models.deepseek_v4.modeling_deepseek_v4 import ( + ALL_ATTENTION_FUNCTIONS, + apply_rotary_pos_emb, + eager_attention_forward, + ) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + cos, sin = position_embeddings[self.rope_layer_type] + + q_residual = self.q_a_norm(self.q_a_proj(hidden_states)) + q = self.q_b_proj(q_residual).view(*hidden_shape).transpose(1, 2) + q = self.q_b_norm(q) + q = apply_rotary_pos_emb(q, cos, sin) + + kv = self.kv_norm(self.kv_proj(hidden_states)).view(*hidden_shape).transpose(1, 2) + kv = apply_rotary_pos_emb(kv, cos, sin) + + if past_key_values is not None: + kv = past_key_values.update(kv, kv, self.layer_idx)[0] + + ori_kv = kv + compressed_kv = None + block_bias = None + top_k_indices = None + if self.compressor is not None: + compressor_out = self.compressor(hidden_states, q_residual, position_ids, past_key_values, self.layer_idx) + if len(compressor_out) == 3: + compressed_kv, block_bias, top_k_indices = compressor_out + else: + compressed_kv, block_bias = compressor_out + + use_sas = True + if self.layer_type == 'sliding_attention': + cmp_ratio = 1 + cmp_kv_arg = None + cmp_sparse_indices = None + elif self.layer_type == 'compressed_sparse_attention': + cmp_ratio = self.config.compress_rates['compressed_sparse_attention'] + # Check if compressed_kv is empty (no compressed entries) + if compressed_kv is not None and compressed_kv.shape[2] > 0: + cmp_kv_arg = compressed_kv.squeeze(1).contiguous() + cmp_sparse_indices = top_k_indices.to(torch.int32) if top_k_indices is not None else None + else: + # No compressed entries, fall back to standard attention + use_sas = False + cmp_kv_arg = None + cmp_sparse_indices = None + else: + cmp_ratio = self.config.compress_rates['heavily_compressed_attention'] + # Check if compressed_kv is empty (no compressed entries) + if compressed_kv is not None and compressed_kv.shape[2] > 0: + cmp_kv_arg = compressed_kv.squeeze(1).contiguous() + else: + # No compressed entries, fall back to standard attention + use_sas = False + cmp_kv_arg = None + cmp_sparse_indices = None + + try: + attn_output = _npu_sparse_attn_shared_kv( + query=q.transpose(1, 2).contiguous(), + ori_kv=ori_kv.squeeze(1).contiguous(), + cmp_kv=cmp_kv_arg, + cmp_sparse_indices=cmp_sparse_indices, + sinks=self.sinks.float(), + softmax_scale=self.scaling, + cmp_ratio=cmp_ratio, + ori_win_left=self.sliding_window - 1, + ) + global _sas_logged + if not _sas_logged: + logger.info( + '[NPU] [DSV4-SAS] Twinkle sparse attention active ' + '(layer_type=%s, cmp_ratio=%s, topk=%s)', + self.layer_type, cmp_ratio, + 0 if cmp_sparse_indices is None else cmp_sparse_indices.shape[-1], + ) + _sas_logged = True + attn_weights = None + except Exception as e: + logger.warning('[NPU] [DSV4-SAS] Failed to run sparse attention, falling back to standard attention. Error: %s', e) + use_sas = False + + if not use_sas: + if compressed_kv is not None: + kv = torch.cat([kv, compressed_kv], dim=2) + if isinstance(attention_mask, torch.Tensor) and kv.shape[2] > attention_mask.shape[-1]: + if block_bias is not None: + attention_mask = torch.cat([attention_mask, block_bias.to(attention_mask.dtype)], dim=-1) + else: + attention_mask = F.pad(attention_mask, (0, kv.shape[2] - attention_mask.shape[-1]), value=0.0) + + attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + attn_output, attn_weights = attention_interface( + self, q, kv, kv, attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, sliding_window=self.sliding_window, + s_aux=self.sinks, **kwargs, + ) + + attn_output = apply_rotary_pos_emb(attn_output.transpose(1, 2), cos, -sin).transpose(1, 2) + grouped = attn_output.reshape(*input_shape, self.config.o_groups, -1) + grouped = self.o_a_proj(grouped).flatten(2) + output = self.o_b_proj(grouped) + return output, attn_weights + + +def _patched_indexer_forward( + self, + hidden_states, + q_residual, + position_ids, + past_key_values, + layer_idx, +): + from transformers.models.deepseek_v4.modeling_deepseek_v4 import apply_rotary_pos_emb + + batch, seq_len, _ = hidden_states.shape + cache_layer = past_key_values.layers[layer_idx] if past_key_values is not None else None + kv = self.kv_proj(hidden_states) + gate = self.gate_proj(hidden_states) + + if cache_layer is None: + usable = (kv.shape[1] // self.compress_rate) * self.compress_rate + chunk_kv, chunk_gate, first_window_position = kv[:, :usable], gate[:, :usable], 0 + else: + chunk_kv, chunk_gate, first_window_position = cache_layer.store_compression_weights('indexer', kv, gate) + + if chunk_kv.shape[1] > 0: + n_windows = chunk_kv.shape[1] // self.compress_rate + ratio = self.compress_rate + chunk_kv = chunk_kv.view(batch, n_windows, ratio, -1) + chunk_gate = chunk_gate.view(batch, n_windows, ratio, -1) + self.position_bias.to(chunk_gate.dtype) + + new_kv = chunk_kv.new_zeros((batch, n_windows, 2 * ratio, self.head_dim)) + new_gate = chunk_gate.new_full((batch, n_windows, 2 * ratio, self.head_dim), float('-inf')) + new_kv[:, :, ratio:] = chunk_kv[..., self.head_dim:] + new_gate[:, :, ratio:] = chunk_gate[..., self.head_dim:] + if n_windows > 1: + new_kv[:, 1:, :ratio] = chunk_kv[:, :-1, :, :self.head_dim] + new_gate[:, 1:, :ratio] = chunk_gate[:, :-1, :, :self.head_dim] + if cache_layer is not None: + prior_kv, prior_gate = cache_layer.update_overlap_state('indexer', chunk_kv, chunk_gate, self.head_dim) + if prior_kv is not None: + new_kv[:, 0, :ratio] = prior_kv.to(new_kv.dtype) + new_gate[:, 0, :ratio] = prior_gate.to(new_gate.dtype) + + compressed = self.kv_norm( + (new_kv * new_gate.softmax(dim=2, dtype=torch.float32).to(new_kv.dtype)).sum(dim=2) + ) + positions = torch.arange(n_windows, device=compressed.device) + positions = positions * self.compress_rate + first_window_position + positions = positions.unsqueeze(0).expand(batch, -1) + cos, sin = self.rotary_emb(compressed, position_ids=positions, layer_type=self.rope_layer_type) + compressed = apply_rotary_pos_emb(compressed.unsqueeze(1), cos, sin).squeeze(1) + else: + compressed = chunk_kv.new_zeros((batch, 0, self.head_dim)) + + compressed_kv = ( + compressed if cache_layer is None else cache_layer.update_compressor_states('indexer', compressed) + ) + + cos_q, sin_q = self.rotary_emb(hidden_states, position_ids=position_ids, layer_type=self.rope_layer_type) + q = self.q_b_proj(q_residual).view(batch, seq_len, -1, self.head_dim).transpose(1, 2) + q = apply_rotary_pos_emb(q, cos_q, sin_q).transpose(1, 2) + + def torch_indexer_top_k_indices(): + scores = torch.matmul(q.float(), compressed_kv.transpose(-1, -2).float().unsqueeze(1)) + scores = F.relu(scores) * self.softmax_scale + weights = self.weights_proj(hidden_states).float() * self.weights_scaling + index_scores = (scores * weights.unsqueeze(-1)).sum(dim=2) + compressed_len = compressed_kv.shape[1] + top_k = min(self.index_topk, compressed_len) + if compressed_len > 0: + causal_threshold = (position_ids + 1) // self.compress_rate + entry_indices = torch.arange(compressed_len, device=index_scores.device) + future_mask = entry_indices.view(1, 1, -1) >= causal_threshold.unsqueeze(-1) + index_scores = index_scores.masked_fill(future_mask, float('-inf')) + top_k_indices = index_scores.topk(top_k, dim=-1).indices + invalid = top_k_indices >= causal_threshold.unsqueeze(-1) + top_k_indices = torch.where(invalid, torch.full_like(top_k_indices, -1), top_k_indices) + if top_k < self.index_topk: + padding = top_k_indices.new_full((batch, seq_len, self.index_topk - top_k), -1) + top_k_indices = torch.cat([top_k_indices, padding], dim=-1) + return top_k_indices + return index_scores.new_full((batch, seq_len, self.index_topk), -1, dtype=torch.long) + + if compressed_kv.shape[1] > 0: + try: + import mindspeed.ops.npu_lightning_indexer as mindspeed_li + + weights = self.weights_proj(hidden_states).to(torch.bfloat16) * self.weights_scaling + q_indexer = q.to(torch.bfloat16) + k_indexer = compressed_kv.to(torch.bfloat16).unsqueeze(2) + top_k_indices, _ = mindspeed_li.npu_lightning_indexer( + q_indexer, k_indexer, weights, + sparse_count=self.index_topk, + sparse_mode=3, + cmp_ratio=self.compress_rate, + ) + top_k_indices = top_k_indices.squeeze(2) + global _li_logged + if not _li_logged: + logger.info( + '[NPU] [DSV4-LI] Twinkle lightning indexer active ' + '(sparse_count=%s, cmp_ratio=%s)', + self.index_topk, self.compress_rate, + ) + _li_logged = True + return top_k_indices + except Exception as e: + logger.warning('[NPU] [DSV4-LI] Failed to run lightning indexer, falling back to torch indexer. Error: %s', e) + + return torch_indexer_top_k_indices() + + +def _make_compressor_wrapper(orig_forward, has_top_k): + def wrapper(self, hidden_states, q_residual, position_ids, past_key_values, layer_idx): + result = orig_forward(self, hidden_states, q_residual, position_ids, past_key_values, layer_idx) + if len(result) == 3: + return result + compressed_kv, block_bias = result + if has_top_k: + top_k_indices = self.indexer(hidden_states, q_residual, position_ids, past_key_values, layer_idx) + return compressed_kv, block_bias, top_k_indices + return compressed_kv, block_bias, None + return wrapper + + +def apply_deepseek_v4_npu_patch(model, sas_enabled=False, li_enabled=False): + try: + from transformers.models.deepseek_v4.modeling_deepseek_v4 import ( + DeepseekV4Attention, + DeepseekV4CSACompressor, + DeepseekV4HCACompressor, + DeepseekV4Indexer, + ) + except ImportError: + return + + if sas_enabled and not getattr(DeepseekV4Attention, '_twinkle_dsv4_sas_patched', False): + DeepseekV4Attention.forward = _patched_attention_forward + DeepseekV4Attention._twinkle_dsv4_sas_patched = True + + if li_enabled and not getattr(DeepseekV4Indexer, '_twinkle_dsv4_li_patched', False): + DeepseekV4Indexer.forward = _patched_indexer_forward + DeepseekV4Indexer._twinkle_dsv4_li_patched = True + + if (sas_enabled or li_enabled) and not getattr(DeepseekV4HCACompressor, '_twinkle_dsv4_compressor_patched', False): + orig_hca = DeepseekV4HCACompressor.forward + orig_csa = DeepseekV4CSACompressor.forward + DeepseekV4HCACompressor.forward = _make_compressor_wrapper(orig_hca, has_top_k=False) + DeepseekV4CSACompressor.forward = _make_compressor_wrapper(orig_csa, has_top_k=True) + DeepseekV4HCACompressor._twinkle_dsv4_compressor_patched = True + DeepseekV4CSACompressor._twinkle_dsv4_compressor_patched = True diff --git a/src/twinkle/kernel/monkey_patch_npu.py b/src/twinkle/kernel/monkey_patch_npu.py index c4ea67b4..0d61cb16 100644 --- a/src/twinkle/kernel/monkey_patch_npu.py +++ b/src/twinkle/kernel/monkey_patch_npu.py @@ -901,6 +901,39 @@ def _discover_and_patch_unknown_models() -> int: return patched +# ============================================================================= +# Section 5b: DeepSeek-V4 NPU Sparse Attention / Lightning Indexer +# ============================================================================= + + +def _apply_deepseek_v4_npu_patch(model=None) -> None: + sas_enabled = _is_env_enabled('TWINKLE_NPU_DSV4_SAS', default=False) + li_enabled = _is_env_enabled('TWINKLE_NPU_DSV4_LI', default=False) + + if not sas_enabled and not li_enabled: + return + + if sas_enabled and li_enabled: + raise ValueError( + '[NPU] [DSV4] TWINKLE_NPU_DSV4_SAS and TWINKLE_NPU_DSV4_LI cannot be enabled simultaneously. ' + 'Please enable only one.' + ) + + if model is not None: + config = getattr(model, 'hf_config', getattr(model, 'config', None)) + archs = getattr(config, 'architectures', None) if config else None + if archs and 'DeepseekV4ForCausalLM' not in archs: + return + + from .deepseek_v4_npu import apply_deepseek_v4_npu_patch + apply_deepseek_v4_npu_patch(model, sas_enabled=sas_enabled, li_enabled=li_enabled) + + if sas_enabled: + logger.info('[NPU] [DSV4] SAS patch applied') + if li_enabled: + logger.info('[NPU] [DSV4] Lightning Indexer patch applied') + + # ============================================================================= # Section 6: Public API # ============================================================================= @@ -935,6 +968,8 @@ def apply_npu_patch(model=None) -> None: When ``0``: disable the patch regardless. - ``TWINKLE_NPU_FLA``: FLA switch (``1``/``0``) - ``TWINKLE_NPU_GATED_RMSNorm_FP32``: force FP32 in Gated RMSNorm (``1``/``0``) + - ``TWINKLE_NPU_DSV4_SAS``: DeepSeek-V4 NPU Sparse Attention (``1``/``0``) + - ``TWINKLE_NPU_DSV4_LI``: DeepSeek-V4 NPU Lightning Indexer (``1``/``0``). Args: model: Optional model instance. If not provided, GMM patch is skipped. @@ -959,6 +994,8 @@ def apply_npu_patch(model=None) -> None: _apply_all_fused_ops(model) + _apply_deepseek_v4_npu_patch(model) + _NPU_PATCH_APPLIED = True logger.info('[NPU] All patches applied successfully')