From 9237456f8ff9d58a2b3ed816e8befae41481820c Mon Sep 17 00:00:00 2001 From: HaonanWang98 Date: Fri, 8 May 2026 09:25:23 +0000 Subject: [PATCH 1/2] add triton fallback for ds & gptoss --- atom/model_loader/loader.py | 73 ++++---- atom/model_ops/attention_mha.py | 186 ++++++++++++------- atom/model_ops/attention_mla.py | 134 +++++++------ atom/model_ops/attentions/aiter_attention.py | 55 ++++-- atom/model_ops/attentions/triton_mla.py | 101 ++++++++++ atom/model_ops/fused_moe_triton.py | 70 ++++--- atom/model_ops/moe.py | 19 +- atom/utils/envs.py | 10 + atom/utils/selector.py | 11 +- 9 files changed, 453 insertions(+), 206 deletions(-) create mode 100644 atom/model_ops/attentions/triton_mla.py diff --git a/atom/model_loader/loader.py b/atom/model_loader/loader.py index 987b05e18..df2fbf02a 100644 --- a/atom/model_loader/loader.py +++ b/atom/model_loader/loader.py @@ -318,6 +318,20 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None: with concurrent.futures.ThreadPoolExecutor() as executor: futures = [] + use_threadpool = envs.ATOM_LOADER_USE_THREADPOOL + if use_threadpool: + executor = concurrent.futures.ThreadPoolExecutor() + else: + executor = None + futures = [] + + def _submit(fn, *args): + if executor is not None: + futures.append(executor.submit(fn, *args)) + else: + fn(*args) + + try: disable_mmap = envs.ATOM_DISABLE_MMAP for name, weight_tensor in safetensors_weights_iterator( model_name_or_path, disable_mmap=disable_mmap @@ -389,11 +403,7 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None: ) continue weight_loader = getattr(param, "weight_loader") - futures.append( - executor.submit( - weight_loader, param, weight_tensor, shard_idx - ) - ) + _submit(weight_loader, param, weight_tensor, shard_idx) loaded_weights_record.add(prefix + param_name) else: # Checkpoint has separate weights, load into fused param @@ -407,12 +417,7 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None: dropped_ckpt_keys.append((_orig_ckpt_name, param_name)) break weight_loader = getattr(param, "weight_loader") - # weight_loader(param, weight_tensor, shard_id) - futures.append( - executor.submit( - weight_loader, param, weight_tensor, shard_id - ) - ) + _submit(weight_loader, param, weight_tensor, shard_id) loaded_weights_record.add(prefix + param_name) break else: @@ -482,15 +487,13 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None: matched = True break weight_loader = getattr(param, "weight_loader") - futures.append( - executor.submit( - weight_loader, - param, - weight_tensor, - name, - shard_id, - expert_id, - ) + _submit( + weight_loader, + param, + weight_tensor, + name, + shard_id, + expert_id, ) loaded_weights_record.add(prefix + name) matched = True @@ -508,15 +511,13 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None: weight_loader = getattr( param, "weight_loader", default_weight_loader ) - futures.append( - executor.submit( - weight_loader, - param, - weight_tensor, - "", # use merged moe loader - "", - expert_id, - ) + _submit( + weight_loader, + param, + weight_tensor, + "", # use merged moe loader + "", + expert_id, ) loaded_weights_record.add(prefix + name) try: @@ -527,9 +528,7 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None: weight_loader = getattr( param, "weight_loader", default_weight_loader ) - futures.append( - executor.submit(weight_loader, param, weight_tensor) - ) + _submit(weight_loader, param, weight_tensor) loaded_weights_record.add(prefix + name) else: # Model doesn't have expert mapping, use generic loading @@ -541,12 +540,12 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None: weight_loader = getattr( param, "weight_loader", default_weight_loader ) - # weight_loader(param, weight_tensor) - futures.append(executor.submit(weight_loader, param, weight_tensor)) + _submit(weight_loader, param, weight_tensor) loaded_weights_record.add(prefix + name) - # Wait for all tasks to complete and raise any exceptions. - for future in concurrent.futures.as_completed(futures): - future.result() + finally: + if executor is not None: + concurrent.futures.wait(futures) + executor.shutdown(wait=True) # Verify every model parameter actually got loaded from the checkpoint. # Without this check, weights_mapping bugs (e.g. a substring rule diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index c3ebd525a..aa91f2b27 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -10,6 +10,7 @@ from aiter.ops.triton.gluon.pa_decode_gluon import get_recommended_splits from aiter.ops.triton.unified_attention import unified_attention from atom.config import get_current_atom_config +from atom.utils import envs from atom.utils.forward_context import ForwardContext, get_forward_context from torch import nn @@ -229,7 +230,7 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): k_scale, v_scale, self.rotary_emb.is_neox_style, - flash_layout=False, + flash_layout=envs.ATOM_USE_UNIFIED_ATTN, apply_scale=self.kv_cache_dtype.startswith("fp8"), offs=None, q_out=q, @@ -372,71 +373,101 @@ def paged_attention_triton( o = torch.empty_like(q) num_seqs = attn_metadata.context_lens.shape[0] - _, num_q_heads_total, head_size = q.shape - num_blocks, num_kv_heads, _, block_size, _ = k_cache.shape - # assume all query have same length - query_group_size = attn_metadata.max_seqlen_q * ( - num_q_heads_total // num_kv_heads - ) - assert num_q_heads_total % num_kv_heads == 0 - max_context_partition_num = get_recommended_splits(num_seqs, num_kv_heads) + if envs.ATOM_USE_UNIFIED_ATTN: + sliding_window = ( + (self.sliding_window - 1, 0) if self.sliding_window > 0 else (-1, -1) + ) - context_partition_size = 256 - if self.sliding_window > 0: - max_context_partition_num = 1 - context_partition_size = 128 + # KV cache is already in flash layout (4D) when + # ATOM_USE_UNIFIED_ATTN is set, allocated by model_runner. + nkv = k_cache.shape[2] + descale_shape = (num_seqs, nkv) - # Output buffers (same as Triton) - intermediate_shape = ( - num_seqs, - num_kv_heads, - max_context_partition_num, - query_group_size, - ) - exp_sums = torch.empty(intermediate_shape, dtype=torch.float32, device=q.device) - max_logits = torch.empty( - intermediate_shape, dtype=torch.float32, device=q.device - ) - temporary_output = torch.empty( - *intermediate_shape, - head_size, - dtype=q.dtype, - device=q.device, - ) + unified_attention( + q, + k_cache, + v_cache, + o, + cu_seqlens_q=attn_metadata.cu_seqlens_q, + seqused_k=attn_metadata.context_lens, + max_seqlen_q=attn_metadata.max_seqlen_q, + max_seqlen_k=attn_metadata.max_seqlen_k, + softmax_scale=self.scale, + causal=True, + alibi_slopes=None, + window_size=sliding_window, + block_table=attn_metadata.block_tables, + softcap=0, + q_descale=None, + k_descale=self.kv_scale.expand(descale_shape), + v_descale=self.kv_scale.expand(descale_shape), + sinks=self.sinks, + ) + else: + _, num_q_heads_total, head_size = q.shape + num_blocks, num_kv_heads, _, block_size, _ = k_cache.shape + query_group_size = attn_metadata.max_seqlen_q * ( + num_q_heads_total // num_kv_heads + ) + assert num_q_heads_total % num_kv_heads == 0 - if k_scale is not None and k_scale.numel() > 1: - k_scale = k_scale.unsqueeze(-1) - v_scale = v_scale.unsqueeze(-1) + max_context_partition_num = get_recommended_splits(num_seqs, num_kv_heads) - compute_type = ( - torch.bfloat16 - if self.kv_cache_dtype == "bf16" # or per_tensor - else aiter.dtypes.fp8 - ) - torch.ops.aiter.pa_decode_gluon( - o, - q, - k_cache, - v_cache, - attn_metadata.context_lens, - attn_metadata.block_tables, - self.scale, - attn_metadata.max_seqlen_q, - max_context_partition_num, - context_partition_size, - compute_type, - None, # q_scale - None if self.kv_cache_dtype == "bf16" else k_scale, - None if self.kv_cache_dtype == "bf16" else v_scale, - exp_sums=exp_sums, - max_logits=max_logits, - temporary_output=temporary_output, - alibi_slopes=None, - sinks=self.sinks, - sliding_window=self.sliding_window, - ps=True, - ) + context_partition_size = 256 + if self.sliding_window > 0: + max_context_partition_num = 1 + context_partition_size = 128 + + intermediate_shape = ( + num_seqs, + num_kv_heads, + max_context_partition_num, + query_group_size, + ) + exp_sums = torch.empty( + intermediate_shape, dtype=torch.float32, device=q.device + ) + max_logits = torch.empty( + intermediate_shape, dtype=torch.float32, device=q.device + ) + temporary_output = torch.empty( + *intermediate_shape, + head_size, + dtype=q.dtype, + device=q.device, + ) + + if k_scale is not None and k_scale.numel() > 1: + k_scale = k_scale.unsqueeze(-1) + v_scale = v_scale.unsqueeze(-1) + + compute_type = ( + torch.bfloat16 if self.kv_cache_dtype == "bf16" else aiter.dtypes.fp8 + ) + torch.ops.aiter.pa_decode_gluon( + o, + q, + k_cache, + v_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + self.scale, + attn_metadata.max_seqlen_q, + max_context_partition_num, + context_partition_size, + compute_type, + None, # q_scale + None if self.kv_cache_dtype == "bf16" else k_scale, + None if self.kv_cache_dtype == "bf16" else v_scale, + exp_sums=exp_sums, + max_logits=max_logits, + temporary_output=temporary_output, + alibi_slopes=None, + sinks=self.sinks, + sliding_window=self.sliding_window, + ps=True, + ) return o @@ -543,19 +574,42 @@ def prefill_attention_triton( block_tables = attn_metadata.block_tables o = torch.empty_like(q) - descale_shape = (attn_metadata.cu_seqlens_q.shape[0] - 1, k.shape[1]) + num_seqs = attn_metadata.cu_seqlens_q.shape[0] - 1 + descale_shape = (num_seqs, k.shape[1]) sliding_window = ( (self.sliding_window - 1, 0) if self.sliding_window is not None else (-1, -1) ) + + if block_tables is None: + # Prefill has no block_table. Use k/v directly as kv_cache with + # block_size=1 and a fake block_table (see comments above). + # k: [total_tokens, num_kv_heads, head_size] + # -> [total_tokens, 1, num_kv_heads, head_size] + k_for_attn = k.unsqueeze(1) + v_for_attn = v.unsqueeze(1) + # Build per-seq block tables: seq i maps to token indices + # [cu_seqlens_k[i], cu_seqlens_k[i]+1, ..., cu_seqlens_k[i+1]-1] + max_seqlen_k = attn_metadata.max_seqlen_k + cu_k = attn_metadata.cu_seqlens_k + offsets = cu_k[:num_seqs] # [num_seqs] + block_tables = offsets.unsqueeze(1) + torch.arange( + max_seqlen_k, dtype=torch.int32, device=q.device + ) + seqused_k = cu_k[1:] - cu_k[:num_seqs] + else: + k_for_attn = k_cache + v_for_attn = v_cache + seqused_k = attn_metadata.context_lens + unified_attention( q, - k_cache, - v_cache, + k_for_attn, + v_for_attn, o, cu_seqlens_q=attn_metadata.cu_seqlens_q, - seqused_k=attn_metadata.context_lens, + seqused_k=seqused_k, max_seqlen_q=attn_metadata.max_seqlen_q, max_seqlen_k=attn_metadata.max_seqlen_k, softmax_scale=self.scale, @@ -577,6 +631,8 @@ def dispatch_backend(self, fwd_ctx: ForwardContext): ctx = fwd_ctx.context if ctx.is_prefill: + if envs.ATOM_USE_UNIFIED_ATTN and self.use_triton_attn: + return self.prefill_attention_triton return self.prefill_attention else: if self.use_triton_attn: diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index b2a8aa67d..6ec06f776 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -546,64 +546,88 @@ def _forward_decode( device=q.device, ) - kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) - paged_kv_indptr = attn_metadata.kv_indptr - paged_kv_indices = attn_metadata.kv_indices - if self.topk_indices_buffer is not None: - paged_kv_indptr = attn_metadata.sparse_kv_indptr - paged_kv_indices = triton_convert_req_index_to_global_index( - attn_metadata.cu_seqlens_q, - attn_metadata.kv_indptr, - paged_kv_indptr, - attn_metadata.kv_indices, - self.topk_indices_buffer[:B], - NUM_TOPK_TOKENS=self.topk_indices_buffer.shape[1], + if hasattr(attn_metadata, "triton_block_table"): + from aiter.ops.triton.attention.mla_decode import decode_attention_fwd + + k_buffer = kv_c_and_k_pe_cache.unsqueeze(2) + v_buffer = k_buffer[..., : self.kv_lora_rank] + page_size = k_buffer.shape[1] + + q_for_triton = ( + q.to(torch.bfloat16) + if q.dtype.is_floating_point and q.element_size() == 1 + else q ) - # q_scale = kv_scale = None - # if self.kv_cache_dtype.startswith("fp8"): - # q = q.to(dtypes.fp8) - # q_scale = kv_scale = self.one_scale - - dp_size = get_dp_group().world_size - use_persistent_mode = not (dp_size > 1) - - if not use_persistent_mode: - # DP : disable persistent mode to avoid overflow - work_meta_data = None - work_indptr = None - work_info_set = None - reduce_indptr = None - reduce_final_map = None - reduce_partial_map = None + # Use pre-built dense block_table from prepare_decode() + decode_attention_fwd( + q_for_triton, + k_buffer, + v_buffer, + o, + attn_metadata.triton_lse, + attn_metadata.triton_block_table, + attn_metadata.context_lens, + attn_metadata.triton_attn_logits, + 4, # num_kv_splits + self.scale, + page_size, + k_scale=self._k_scale, + v_scale=self._k_scale, + ) else: - work_meta_data = attn_metadata.work_meta_data - work_indptr = attn_metadata.work_indptr - work_info_set = attn_metadata.work_info_set - reduce_indptr = attn_metadata.reduce_indptr - reduce_final_map = attn_metadata.reduce_final_map - reduce_partial_map = attn_metadata.reduce_partial_map - - mla_decode_fwd( - q, - kv_buffer.view(-1, 1, 1, q.shape[-1]), - o, - attn_metadata.cu_seqlens_q, - paged_kv_indptr, - paged_kv_indices, - attn_metadata.kv_last_page_lens, - attn_metadata.max_seqlen_q, - num_kv_splits=16, - sm_scale=self.scale, - work_meta_data=work_meta_data, - work_indptr=work_indptr, - work_info_set=work_info_set, - reduce_indptr=reduce_indptr, - reduce_final_map=reduce_final_map, - reduce_partial_map=reduce_partial_map, - q_scale=self._q_scale, - kv_scale=self._k_scale, - ) + kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) + paged_kv_indptr = attn_metadata.kv_indptr + paged_kv_indices = attn_metadata.kv_indices + if self.topk_indices_buffer is not None: + paged_kv_indptr = attn_metadata.sparse_kv_indptr + paged_kv_indices = triton_convert_req_index_to_global_index( + attn_metadata.cu_seqlens_q, + attn_metadata.kv_indptr, + paged_kv_indptr, + attn_metadata.kv_indices, + self.topk_indices_buffer[:B], + NUM_TOPK_TOKENS=self.topk_indices_buffer.shape[1], + ) + + dp_size = get_dp_group().world_size + use_persistent_mode = not (dp_size > 1) + + if not use_persistent_mode: + work_meta_data = None + work_indptr = None + work_info_set = None + reduce_indptr = None + reduce_final_map = None + reduce_partial_map = None + else: + work_meta_data = attn_metadata.work_meta_data + work_indptr = attn_metadata.work_indptr + work_info_set = attn_metadata.work_info_set + reduce_indptr = attn_metadata.reduce_indptr + reduce_final_map = attn_metadata.reduce_final_map + reduce_partial_map = attn_metadata.reduce_partial_map + + mla_decode_fwd( + q, + kv_buffer.view(-1, 1, 1, q.shape[-1]), + o, + attn_metadata.cu_seqlens_q, + paged_kv_indptr, + paged_kv_indices, + attn_metadata.kv_last_page_lens, + attn_metadata.max_seqlen_q, + num_kv_splits=16, + sm_scale=self.scale, + work_meta_data=work_meta_data, + work_indptr=work_indptr, + work_info_set=work_info_set, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + q_scale=self._q_scale, + kv_scale=self._k_scale, + ) if self.head_repeat_factor > 1: o = o[:, :: self.head_repeat_factor, :].contiguous() diff --git a/atom/model_ops/attentions/aiter_attention.py b/atom/model_ops/attentions/aiter_attention.py index c6defd2fd..7e02dedbf 100644 --- a/atom/model_ops/attentions/aiter_attention.py +++ b/atom/model_ops/attentions/aiter_attention.py @@ -8,7 +8,7 @@ import torch from aiter.dist.parallel_state import get_tp_group from atom.model_engine.scheduler import ScheduledBatch -from atom.utils import CpuGpuBuffer +from atom.utils import CpuGpuBuffer, envs from atom.utils.block_convert import ( block_table_convert_triton, kv_indices_generate_triton, @@ -498,19 +498,50 @@ def build_kv_cache_tensor(self, layer_id: int, module): ) else: x = 16 // runner.kv_cache.element_size() - k_cache = runner.kv_cache[0, attn_idx].view( - runner.num_physical_kvcache_blocks, - runner.num_kv_heads, - hf_config.head_dim // x, - runner.physical_block_size, - x, + # unified_attention consumes flash layout directly; the ASM path + # keeps the legacy shuffled K layout. + impl = getattr(module, "impl", None) + use_triton_attn = impl is not None and ( + getattr(impl, "sliding_window", -1) != -1 + or getattr(impl, "head_dim", 128) != 128 ) - v_cache = runner.kv_cache[1, attn_idx].view( - runner.num_physical_kvcache_blocks, - runner.num_kv_heads, - hf_config.head_dim, - runner.physical_block_size, + fused_shuffle_path = impl is not None and ( + getattr(impl, "rotary_emb", None) is not None + and getattr(impl, "q_norm", None) is not None + and getattr(impl, "k_norm", None) is not None + ) + use_flash_layout = ( + use_triton_attn + and not fused_shuffle_path + and envs.ATOM_USE_UNIFIED_ATTN ) + if use_flash_layout: + k_cache = runner.kv_cache[0, attn_idx].view( + runner.num_physical_kvcache_blocks, + runner.physical_block_size, + runner.num_kv_heads, + hf_config.head_dim, + ) + v_cache = runner.kv_cache[1, attn_idx].view( + runner.num_physical_kvcache_blocks, + runner.physical_block_size, + runner.num_kv_heads, + hf_config.head_dim, + ) + else: + k_cache = runner.kv_cache[0, attn_idx].view( + runner.num_physical_kvcache_blocks, + runner.num_kv_heads, + hf_config.head_dim // x, + runner.physical_block_size, + x, + ) + v_cache = runner.kv_cache[1, attn_idx].view( + runner.num_physical_kvcache_blocks, + runner.num_kv_heads, + hf_config.head_dim, + runner.physical_block_size, + ) if config.kv_cache_dtype == "fp8": module.k_scale = runner.kv_scale[0, attn_idx] module.v_scale = runner.kv_scale[1, attn_idx] diff --git a/atom/model_ops/attentions/triton_mla.py b/atom/model_ops/attentions/triton_mla.py new file mode 100644 index 000000000..3a52dfd4f --- /dev/null +++ b/atom/model_ops/attentions/triton_mla.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +import logging +from typing import Type + +import torch +from aiter.ops.triton.attention.mla_decode import csr_to_dense_block_table +from atom.model_engine.scheduler import ScheduledBatch +from atom.model_ops.attention_mla import MLAAttention +from atom.utils.forward_context import AttentionMetaData + +from .aiter_mla import AiterMLAMetadataBuilder +from .backends import AttentionBackend + +logger = logging.getLogger("atom") + + +class TritonMLABackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "ROCM_TRITON_MLA" + + @staticmethod + def get_builder_cls() -> Type["TritonMLAMetadataBuilder"]: + return TritonMLAMetadataBuilder + + @staticmethod + def get_impl_cls() -> Type["MLAAttention"]: + return MLAAttention + + +class TritonMLAMetadataBuilder(AiterMLAMetadataBuilder): + + def __init__(self, model_runner): + super().__init__(model_runner) + + hf = model_runner.config.hf_config + kv_lora_rank = hf.kv_lora_rank + num_kv_splits = 4 + triton_mla_buffers = { + "triton_block_table": torch.zeros( + self.max_bs, + self.max_num_blocks_per_seq, + dtype=torch.int32, + device=self.device, + ), + "triton_attn_logits": torch.empty( + self.max_bs, + self.padded_num_attention_heads, + num_kv_splits, + kv_lora_rank + 1, + dtype=torch.float32, + device=self.device, + ), + "triton_lse": torch.empty( + self.max_bs, + self.padded_num_attention_heads, + dtype=torch.float32, + device=self.device, + ), + } + self.model_runner.forward_vars.update(triton_mla_buffers) + + def set_mla_persistent_worker_buffers( + self, bs, max_q_len, only_update=False, num_reject_tokens=None + ): + # Triton MLA does not use aiter persistent worker buffers + return {} + + def prepare_decode(self, batch: ScheduledBatch, bs: int): + attn_metadata, positions = super().prepare_decode(batch, bs) + + scheduled_bs = batch.total_seqs_num_decode + max_seqlen_k = attn_metadata.max_seqlen_k + var = self.model_runner.forward_vars + + triton_bt = var["triton_block_table"][:scheduled_bs, :max_seqlen_k] + triton_bt.zero_() + csr_to_dense_block_table( + attn_metadata.kv_indices, + attn_metadata.kv_indptr, + triton_bt, + max_seqlen_k, + scheduled_bs, + ) + attn_metadata.triton_block_table = triton_bt + attn_metadata.triton_attn_logits = var["triton_attn_logits"][:scheduled_bs] + attn_metadata.triton_lse = var["triton_lse"][:scheduled_bs] + + return attn_metadata, positions + + def build_for_cudagraph_capture(self, bs: int) -> AttentionMetaData: + attn_metadata, context = super().build_for_cudagraph_capture(bs) + + var = self.model_runner.forward_vars + attn_metadata.triton_block_table = var["triton_block_table"][:bs] + attn_metadata.triton_attn_logits = var["triton_attn_logits"][:bs] + attn_metadata.triton_lse = var["triton_lse"][:bs] + + return attn_metadata, context diff --git a/atom/model_ops/fused_moe_triton.py b/atom/model_ops/fused_moe_triton.py index 96aafcfe9..e604b5351 100644 --- a/atom/model_ops/fused_moe_triton.py +++ b/atom/model_ops/fused_moe_triton.py @@ -24,6 +24,7 @@ from typing import Any import logging from math import prod +from aiter import ActivationType from aiter.jit.utils.chip_info import get_gfx from atom.model_ops.utils import has_triton_kernels @@ -32,9 +33,18 @@ if has_triton_kernels(): try: - from triton_kernels.matmul_ogs import matmul_ogs + import triton_kernels.swiglu + from triton_kernels.matmul_ogs import ( + FnSpecs, + FusedActivation, + PrecisionConfig, + matmul_ogs, + ) + from triton_kernels.matmul_ogs_details.opt_flags import ( + update_opt_flags_constraints, + reset_opt_flags_constraints, + ) from triton_kernels.routing import routing - from triton_kernels.matmul_ogs import PrecisionConfig except (AttributeError, ImportError) as e: logger.error( "Failed to import Triton kernels. Please make sure your triton " @@ -283,28 +293,42 @@ def triton_kernel_fused_experts( ) with _amd_smem_safe_tile(): - matmul_ogs( - hidden_states, - w1, - w1_bias, - routing_data, - gather_indx=gather_indx, - precision_config=w13_precision_config, - gammas=gammas if apply_router_weight_on_input else None, - y=raw_intermediate, - ) + if activation == ActivationType.Swiglu: + # SwiGLU (GPT OSS): fused activation with interleaved [gate, up] layout + act = FusedActivation( + FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), + (swiglu_alpha, swiglu_limit), + 2, + ) + matmul_ogs( + hidden_states, + w1, + w1_bias, + routing_data, + gather_indx=gather_indx, + precision_config=w13_precision_config, + gammas=gammas if apply_router_weight_on_input else None, + fused_activation=act, + y=intermediate_cache, + ) + else: + # SiLU (DeepSeek): concatenated [gate | up] layout, manual activation + raw_intermediate = matmul_ogs( + hidden_states, + w1, + w1_bias, + routing_data, + gather_indx=gather_indx, + precision_config=w13_precision_config, + gammas=gammas if apply_router_weight_on_input else None, + ) + raw_2d = raw_intermediate.view(M * topk, N) + gate = raw_2d[:, :half_N] + up = raw_2d[:, half_N:] + intermediate_cache = intermediate_cache.view(M * topk, half_N) + intermediate_cache.copy_(torch.nn.functional.silu(gate) * up) + intermediate_cache = intermediate_cache.view(batch_dim, M * topk, half_N) - # Standard SiLU/SwiGLU activation: silu(gate) * up - # With optional swiglu_limit clamping (V4: limit=10.0) - raw_2d = raw_intermediate.view(M * topk, N) - gate = raw_2d[:, :half_N] - up = raw_2d[:, half_N:] - if swiglu_limit > 0: - gate = gate.clamp(max=swiglu_limit) - up = up.clamp(-swiglu_limit, swiglu_limit) - intermediate_cache[0] = torch.nn.functional.silu(gate) * up - - with _amd_smem_safe_tile(): matmul_ogs( intermediate_cache.view(M * topk, half_N), w2, diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 9f6e0a0f4..f241ab219 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -686,11 +686,14 @@ def __init__(self, quant_config: LayerQuantConfig, moe: FusedMoEConfig): or self.quant_type == QuantType.per_1x32 ) gfx = get_gfx() - self.use_triton = ( - gfx.startswith("gfx94") - or (gfx.startswith("gfx95") and envs.ATOM_USE_TRITON_GEMM) - or os.environ.get("ATOM_USE_TRITON_MOE") == "1" - ) + if envs.is_set("ATOM_USE_TRITON_MOE"): + self.use_triton = envs.ATOM_USE_TRITON_MOE + else: + self.use_triton = ( + gfx.startswith("gfx94") + or gfx.startswith("gfx12") + or (gfx.startswith("gfx95") and envs.ATOM_USE_TRITON_GEMM) + ) if self.use_triton: from atom.model_ops.utils import has_triton_kernels @@ -975,11 +978,13 @@ def apply( num_fused_shared_experts=layer.num_fused_shared_experts, routed_scaling_factor=layer.routed_scaling_factor, ) + n_expts_act = topk_weights.shape[1] # Convert to triton routing data structures n_expts_tot = router_logits.shape[-1] if global_num_experts > 0: n_expts_tot = global_num_experts + n_expts_tot = n_expts_tot + layer.num_fused_shared_experts routing_data, gather_idx, scatter_idx = routing_from_topk( topk_weights, topk_ids, n_expts_tot @@ -994,7 +999,7 @@ def apply( routing_data, gather_idx, scatter_idx, - topk=top_k, + topk=n_expts_act, activation=activation, w13_precision_config=self.w13_precision_config, w2_precision_config=self.w2_precision_config, @@ -1002,7 +1007,7 @@ def apply( w2_bias=layer.w2_bias, swiglu_limit=getattr(layer, "swiglu_limit", 0.0), apply_router_weight_on_input=layer.apply_router_weight_on_input, - global_num_experts=global_num_experts, + global_num_experts=n_expts_tot, expert_map=expert_map, ) return _moe_result diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 269fa0c20..babd6cc2c 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -33,6 +33,8 @@ "ATOM_USE_TRITON_MXFP4_BMM": lambda: ( os.getenv("ATOM_USE_TRITON_MXFP4_BMM", "0") == "1" ), + "ATOM_USE_TRITON_MLA": lambda: os.getenv("ATOM_USE_TRITON_MLA", "0") == "1", + "ATOM_USE_TRITON_MOE": lambda: os.getenv("ATOM_USE_TRITON_MOE", "0") == "1", # --- Kernel Fusion Toggles --- # QK-norm-rope-cache-quant fusion for Qwen3-MoE; disabled by default. # Enable for Qwen3-MoE to get better performance. @@ -69,6 +71,14 @@ "ATOM_DISABLE_MMAP": lambda: ( os.getenv("ATOM_DISABLE_MMAP", "false").lower() == "true" ), + # Use a thread pool for weight loading instead of main-process sequential I/O. + # Set to 0 to disable if the thread pool causes hangs (e.g. on gfx1250). + "ATOM_LOADER_USE_THREADPOOL": lambda: os.getenv("ATOM_LOADER_USE_THREADPOOL", "1") + == "1", + # --- Attention Backend --- + # Use unified_attention (flash-style) for MHA paged/prefill attention instead + # of pa_decode_gluon. Set to 1 to enable the unified_attention path. + "ATOM_USE_UNIFIED_ATTN": lambda: os.getenv("ATOM_USE_UNIFIED_ATTN", "0") == "1", # --- Plugin Mode --- "ATOM_DISABLE_VLLM_PLUGIN": lambda: ( os.getenv("ATOM_DISABLE_VLLM_PLUGIN", "0").lower() == "1" diff --git a/atom/utils/selector.py b/atom/utils/selector.py index e87b1f819..d76708c3e 100644 --- a/atom/utils/selector.py +++ b/atom/utils/selector.py @@ -7,6 +7,7 @@ from atom.model_ops.attentions.backends import AttentionBackend from atom.utils import resolve_obj_by_qualname from atom.plugin.prepare import is_sglang, is_vllm +from atom.utils import envs def get_attn_backend( @@ -51,13 +52,9 @@ def get_attn_backend_cls( if use_v4: return "atom.model_ops.attentions.deepseek_v4_attn.DeepseekV4Backend" if use_mla: - # if block_size == 1: - return "atom.model_ops.attentions.aiter_mla.AiterMLABackend" # noqa: E501 - # else: - # raise ValueError( - # f" The selected backend" - # f"does not support block size {block_size}." - # "(currently only supports block size 1)") + if envs.ATOM_USE_TRITON_MLA: + return "atom.model_ops.attentions.triton_mla.TritonMLABackend" + return "atom.model_ops.attentions.aiter_mla.AiterMLABackend" if use_gdn: if use_vllm: return "atom.plugin.vllm.attention_backend.gdn_attn.GDNAttentionBackend" From 06d85ba205afa4076b0510d02c613de1ceec00b6 Mon Sep 17 00:00:00 2001 From: HaonanWang98 Date: Sat, 9 May 2026 02:35:23 +0000 Subject: [PATCH 2/2] fix format --- atom/model_ops/fused_moe_triton.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/atom/model_ops/fused_moe_triton.py b/atom/model_ops/fused_moe_triton.py index e604b5351..6fc4c6e7b 100644 --- a/atom/model_ops/fused_moe_triton.py +++ b/atom/model_ops/fused_moe_triton.py @@ -40,10 +40,6 @@ PrecisionConfig, matmul_ogs, ) - from triton_kernels.matmul_ogs_details.opt_flags import ( - update_opt_flags_constraints, - reset_opt_flags_constraints, - ) from triton_kernels.routing import routing except (AttributeError, ImportError) as e: logger.error(