
[Triton] DSV4 fusions phase 1 #704

Open · k50112113 wants to merge 6 commits into `main` from `shaoclee/dsv4-fusions`

Conversation


@k50112113 (Contributor) commented May 7, 2026

This PR includes:

  1. `q_per_head_norm` + `qk_rope` + `swa_write` -> `fused_reduce_q_norm_qk_rope_swa_write` (gated by `ATOM_V4_USE_TRITON_FUSION`)
  2. clamp + silu + mul + quant -> `fused_clamp_act_mul_fp8_group_quant` (gated by `ATOM_V4_USE_TRITON_FUSION`)
  3. enabled `ATOM_ENABLE_DS_QKNORM_QUANT_FUSION` in DSV4 (the API is imported from `deepseek_v2.py`)
  4. fused the upcast in `combined = self.wkv_gate(x, otype=torch.bfloat16)` into `_fused_compress_attn_kernel`
  5. `neg1 = positions.new_full((), -1)` in `atom/model_ops/attentions/deepseek_v4_attn.py`, which prevents a host-side sync; ~5% end-to-end improvement
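
For context on item 1, the fused kernel folds per-head normalization and rotary embedding (plus the SWA cache write, omitted here) into one pass. A rough unfused NumPy reference for the first two steps; the interleaved-pair RoPE layout and the epsilon value are assumptions, not details taken from the PR:

```python
import numpy as np

def per_head_rmsnorm_rope(q, weight, cos, sin, eps=1e-6):
    """Unfused sketch of the math inside fused_reduce_q_norm_qk_rope_swa_write.

    q:        (tokens, heads, head_dim)
    weight:   (head_dim,) RMSNorm scale
    cos, sin: (tokens, 1, head_dim // 2) rotary tables
    """
    # per-head RMSNorm: each head vector is normalized independently
    rms = np.sqrt((q * q).mean(axis=-1, keepdims=True) + eps)
    qn = q / rms * weight
    # rotary embedding on the normalized q (interleaved-pairs variant assumed)
    q1, q2 = qn[..., ::2], qn[..., 1::2]
    out = np.empty_like(qn)
    out[..., ::2] = q1 * cos - q2 * sin
    out[..., 1::2] = q1 * sin + q2 * cos
    return out
```

Fusing these avoids materializing the normalized intermediate in global memory between the two elementwise passes.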

This PR requires ROCm/aiter#3057
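
Item 2 replaces a chain of elementwise ops plus a dynamic-quant pass with one kernel. A minimal unfused NumPy reference of what `fused_clamp_act_mul_fp8_group_quant` computes; the clamp limit, group size, and FP8 e4m3 range here are assumptions for illustration, not values from the PR:

```python
import numpy as np

FP8_MAX = 448.0  # max finite value of float8_e4m3fn (assumed target format)

def clamp_act_mul_fp8_group_quant(gate, up, clamp_limit=7.0, group_size=128):
    """Unfused sketch: clamp -> SiLU -> mul -> per-group FP8 quantization.

    gate, up: (tokens, hidden). Returns quantized values (still float here;
    the real kernel would cast to fp8) and one scale per group.
    """
    g = np.clip(gate, -clamp_limit, clamp_limit)   # clamp
    g = g * (1.0 / (1.0 + np.exp(-g)))             # SiLU = x * sigmoid(x)
    y = g * up                                     # gated elementwise mul
    # dynamic per-group quantization: one scale per group_size elements
    grouped = y.reshape(y.shape[0], -1, group_size)
    scales = np.abs(grouped).max(axis=-1, keepdims=True) / FP8_MAX
    scales = np.maximum(scales, 1e-12)             # avoid divide-by-zero
    q = np.clip(grouped / scales, -FP8_MAX, FP8_MAX)
    return q.reshape(y.shape), scales.squeeze(-1)
```

Done separately, each of these steps is a memory-bound kernel over the full activation; fused, the tensor is read and written once.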

baseline lm_eval

local-completions ({'model': '/data/deepseek-ai/DeepSeek-V4-Flash', 'base_url': 'http://localhost:8000/v1/completions', 'num_concurrent': 16, 'max_retries': 2, 'tokenized_requests': False}), gen_kwargs: ({}), limit: 100.0, num_fewshot: 3, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     3|exact_match|↑  | 0.97|±  |0.0171|
|     |       |strict-match    |     3|exact_match|↑  | 0.97|±  |0.0171|

fused lm_eval

local-completions ({'model': '/data/deepseek-ai/DeepSeek-V4-Flash', 'base_url': 'http://localhost:8000/v1/completions', 'num_concurrent': 16, 'max_retries': 2, 'tokenized_requests': False}), gen_kwargs: ({}), limit: 100.0, num_fewshot: 3, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     3|exact_match|↑  | 0.97|±  |0.0171|
|     |       |strict-match    |     3|exact_match|↑  | 0.97|±  |0.0171|

So far I am seeing a ~10% runtime reduction in the decode trace and no runtime difference in prefill.

baseline decode: [trace screenshot]

fused decode: [trace screenshot]

In end-to-end performance I saw an 8% uplift, but ~5% of that comes from item 5. The remaining optimizations, items 1–4, should together be worth ~10%, yet they only account for ~3%, so something other than item 5 may still be blocking performance.

Baseline

Maximum request concurrency: 64
============ Serving Benchmark Result ============
Successful requests:                     512       
Benchmark duration (s):                  210.15    
Total input tokens:                      524288    
Total generated tokens:                  524288    
Request throughput (req/s):              2.44      
Output token throughput (tok/s):         2494.83   
Total Token throughput (tok/s):          4989.66   
---------------Time to First Token----------------
Mean TTFT (ms):                          1205.41   
Median TTFT (ms):                        1251.45   
P99 TTFT (ms):                           2057.02   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          24.49     
Median TPOT (ms):                        24.49     
P99 TPOT (ms):                           25.49     
---------------Inter-token Latency----------------
Mean ITL (ms):                           24.47     
Median ITL (ms):                         23.97     
P99 ITL (ms):                            26.08     
==================================================

With all the optimizations in this PR

Maximum request concurrency: 64
============ Serving Benchmark Result ============
Successful requests:                     512       
Benchmark duration (s):                  193.79    
Total input tokens:                      524288    
Total generated tokens:                  524288    
Request throughput (req/s):              2.64      
Output token throughput (tok/s):         2705.39   
Total Token throughput (tok/s):          5410.78   
---------------Time to First Token----------------
Mean TTFT (ms):                          1060.13   
Median TTFT (ms):                        984.96    
P99 TTFT (ms):                           1554.14   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          22.64     
Median TPOT (ms):                        22.55     
P99 TPOT (ms):                           23.45     
---------------Inter-token Latency----------------
Mean ITL (ms):                           22.62     
Median ITL (ms):                         22.10     
P99 ITL (ms):                            24.22     
==================================================

@k50112113
Contributor Author

Update: fused the upcast in `combined = self.wkv_gate(x, otype=torch.bfloat16)` into `_fused_compress_attn_kernel`.

@k50112113
Contributor Author

Update: `neg1 = positions.new_full((), -1)` prevents a host-side sync; ~5% end-to-end improvement.
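
A sketch of why this helps. The exact call site isn't shown in the PR, so the `torch.where` below is an assumed stand-in for whatever op consumes the constant:

```python
import torch

positions = torch.arange(8)  # stands in for the device-resident positions tensor

# Before: building the scalar from a Python int at each step can trigger an
# implicit host-to-device copy, and under async execution a sync point, every
# time the op runs on GPU.
masked_old = torch.where(positions < 4, positions, torch.tensor(-1))

# After: materialize the constant once as a 0-dim tensor with the same dtype
# and device as `positions`; downstream ops reuse it with no host round-trip.
neg1 = positions.new_full((), -1)
masked_new = torch.where(positions < 4, positions, neg1)

assert torch.equal(masked_old, masked_new)
```

The two paths are numerically identical; the win is purely in removing per-step host/device traffic on the hot decode path.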
