diff --git a/atom/model_loader/loader.py b/atom/model_loader/loader.py
index 987b05e18..df2fbf02a 100644
--- a/atom/model_loader/loader.py
+++ b/atom/model_loader/loader.py
@@ -318,6 +318,18 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None:
-    with concurrent.futures.ThreadPoolExecutor() as executor:
-        futures = []
+    use_threadpool = envs.ATOM_LOADER_USE_THREADPOOL
+    if use_threadpool:
+        executor = concurrent.futures.ThreadPoolExecutor()
+    else:
+        executor = None
+    futures = []
+
+    def _submit(fn, *args):
+        if executor is not None:
+            futures.append(executor.submit(fn, *args))
+        else:
+            fn(*args)
+
+    try:
         disable_mmap = envs.ATOM_DISABLE_MMAP
         for name, weight_tensor in safetensors_weights_iterator(
             model_name_or_path, disable_mmap=disable_mmap
@@ -389,11 +403,7 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None:
                     )
                     continue
                 weight_loader = getattr(param, "weight_loader")
-                futures.append(
-                    executor.submit(
-                        weight_loader, param, weight_tensor, shard_idx
-                    )
-                )
+                _submit(weight_loader, param, weight_tensor, shard_idx)
                 loaded_weights_record.add(prefix + param_name)
             else:
                 # Checkpoint has separate weights, load into fused param
@@ -407,12 +417,7 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None:
                         dropped_ckpt_keys.append((_orig_ckpt_name, param_name))
                         break
                     weight_loader = getattr(param, "weight_loader")
-                    # weight_loader(param, weight_tensor, shard_id)
-                    futures.append(
-                        executor.submit(
-                            weight_loader, param, weight_tensor, shard_id
-                        )
-                    )
+                    _submit(weight_loader, param, weight_tensor, shard_id)
                     loaded_weights_record.add(prefix + param_name)
                     break
             else:
@@ -482,15 +487,13 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None:
                         matched = True
                         break
                     weight_loader = getattr(param, "weight_loader")
-                    futures.append(
-                        executor.submit(
-                            weight_loader,
-                            param,
-                            weight_tensor,
-                            name,
-                            shard_id,
-                            expert_id,
-                        )
+                    _submit(
+                        weight_loader,
+                        param,
+                        weight_tensor,
+                        name,
+                        shard_id,
+                        expert_id,
                     )
                     loaded_weights_record.add(prefix + name)
                     matched = True
@@ -508,15 +511,13 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None:
                     weight_loader = getattr(
                         param, "weight_loader", default_weight_loader
                     )
-                    futures.append(
-                        executor.submit(
-                            weight_loader,
-                            param,
-                            weight_tensor,
-                            "",  # use merged moe loader
-                            "",
-                            expert_id,
-                        )
+                    _submit(
+                        weight_loader,
+                        param,
+                        weight_tensor,
+                        "",  # use merged moe loader
+                        "",
+                        expert_id,
                     )
                     loaded_weights_record.add(prefix + name)
                     try:
@@ -527,9 +528,7 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None:
                     weight_loader = getattr(
                         param, "weight_loader", default_weight_loader
                     )
-                    futures.append(
-                        executor.submit(weight_loader, param, weight_tensor)
-                    )
+                    _submit(weight_loader, param, weight_tensor)
                     loaded_weights_record.add(prefix + name)
                 else:
                     # Model doesn't have expert mapping, use generic loading
@@ -541,12 +540,16 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None:
                 weight_loader = getattr(
                     param, "weight_loader", default_weight_loader
                 )
-                # weight_loader(param, weight_tensor)
-                futures.append(executor.submit(weight_loader, param, weight_tensor))
+                _submit(weight_loader, param, weight_tensor)
                 loaded_weights_record.add(prefix + name)

-    # Wait for all tasks to complete and raise any exceptions.
-    for future in concurrent.futures.as_completed(futures):
-        future.result()
+    finally:
+        if executor is not None:
+            concurrent.futures.wait(futures)
+            executor.shutdown(wait=True)
+    # Surface exceptions from the loader threads: wait() alone does not
+    # re-raise them the way the removed as_completed()/result() loop did.
+    for future in futures:
+        future.result()

     # Verify every model parameter actually got loaded from the checkpoint.
     # Without this check, weights_mapping bugs (e.g. a substring rule
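For reference, the loading pattern above reduces to a small self-contained idiom: submit through an optional executor, fall back to inline execution when the pool is disabled, and surface worker errors once all tasks are drained. A minimal sketch under those assumptions (`tasks` is an illustrative placeholder, not an atom API):

import concurrent.futures

def run_all(tasks, use_threadpool=True):
    # tasks: iterable of (fn, args) pairs to execute
    executor = concurrent.futures.ThreadPoolExecutor() if use_threadpool else None
    futures = []

    def _submit(fn, *args):
        # Run through the pool when available, otherwise inline (serial
        # fallback for platforms where the thread pool is known to hang).
        if executor is not None:
            futures.append(executor.submit(fn, *args))
        else:
            fn(*args)

    try:
        for fn, args in tasks:
            _submit(fn, *args)
    finally:
        if executor is not None:
            concurrent.futures.wait(futures)
            executor.shutdown(wait=True)
    # result() re-raises any exception captured by a worker thread.
    for future in futures:
        future.result()
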
diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py
index c3ebd525a..05d389fbf 100644
--- a/atom/model_ops/attention_mha.py
+++ b/atom/model_ops/attention_mha.py
@@ -77,6 +77,10 @@ def __init__(
         self.rotary_emb = rotary_emb
         self.q_norm = q_norm
         self.k_norm = k_norm
+        # Set by the attention backend's build_kv_cache_tensor when KV cache is
+        # allocated in flash layout [num_blocks, block_size, num_kv_heads, head_dim]
+        # for aiter triton unified_attention. AiterBackend keeps this False.
+        self.use_flash_layout = False

         # for plugin mode(vllm), the query quant is disabled for now
         if is_vllm():
@@ -229,7 +233,7 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):
             k_scale,
             v_scale,
             self.rotary_emb.is_neox_style,
-            flash_layout=False,
+            flash_layout=self.use_flash_layout,
             apply_scale=self.kv_cache_dtype.startswith("fp8"),
             offs=None,
             q_out=q,
@@ -372,71 +376,101 @@ def paged_attention_triton(
         o = torch.empty_like(q)
         num_seqs = attn_metadata.context_lens.shape[0]
-        _, num_q_heads_total, head_size = q.shape
-        num_blocks, num_kv_heads, _, block_size, _ = k_cache.shape
-        # assume all query have same length
-        query_group_size = attn_metadata.max_seqlen_q * (
-            num_q_heads_total // num_kv_heads
-        )
-        assert num_q_heads_total % num_kv_heads == 0
-        max_context_partition_num = get_recommended_splits(num_seqs, num_kv_heads)
+        if self.use_flash_layout:
+            sliding_window = (
+                (self.sliding_window - 1, 0) if self.sliding_window > 0 else (-1, -1)
+            )
-        context_partition_size = 256
-        if self.sliding_window > 0:
-            max_context_partition_num = 1
-            context_partition_size = 128
+            # KV cache is already in flash layout (4D), allocated by
+            # TritonMHAMetadataBuilder.build_kv_cache_tensor.
+            nkv = k_cache.shape[2]
+            descale_shape = (num_seqs, nkv)
-        # Output buffers (same as Triton)
-        intermediate_shape = (
-            num_seqs,
-            num_kv_heads,
-            max_context_partition_num,
-            query_group_size,
-        )
-        exp_sums = torch.empty(intermediate_shape, dtype=torch.float32, device=q.device)
-        max_logits = torch.empty(
-            intermediate_shape, dtype=torch.float32, device=q.device
-        )
-        temporary_output = torch.empty(
-            *intermediate_shape,
-            head_size,
-            dtype=q.dtype,
-            device=q.device,
-        )
+            unified_attention(
+                q,
+                k_cache,
+                v_cache,
+                o,
+                cu_seqlens_q=attn_metadata.cu_seqlens_q,
+                seqused_k=attn_metadata.context_lens,
+                max_seqlen_q=attn_metadata.max_seqlen_q,
+                max_seqlen_k=attn_metadata.max_seqlen_k,
+                softmax_scale=self.scale,
+                causal=True,
+                alibi_slopes=None,
+                window_size=sliding_window,
+                block_table=attn_metadata.block_tables,
+                softcap=0,
+                q_descale=None,
+                k_descale=self.kv_scale.expand(descale_shape),
+                v_descale=self.kv_scale.expand(descale_shape),
+                sinks=self.sinks,
+            )
+        else:
+            _, num_q_heads_total, head_size = q.shape
+            num_blocks, num_kv_heads, _, block_size, _ = k_cache.shape
+            query_group_size = attn_metadata.max_seqlen_q * (
+                num_q_heads_total // num_kv_heads
+            )
+            assert num_q_heads_total % num_kv_heads == 0
-        if k_scale is not None and k_scale.numel() > 1:
-            k_scale = k_scale.unsqueeze(-1)
-            v_scale = v_scale.unsqueeze(-1)
+            max_context_partition_num = get_recommended_splits(num_seqs, num_kv_heads)
-        compute_type = (
-            torch.bfloat16
-            if self.kv_cache_dtype == "bf16"  # or per_tensor
-            else aiter.dtypes.fp8
-        )
-        torch.ops.aiter.pa_decode_gluon(
-            o,
-            q,
-            k_cache,
-            v_cache,
-            attn_metadata.context_lens,
-            attn_metadata.block_tables,
-            self.scale,
-            attn_metadata.max_seqlen_q,
-            max_context_partition_num,
-            context_partition_size,
-            compute_type,
-            None,  # q_scale
-            None if self.kv_cache_dtype == "bf16" else k_scale,
-            None if self.kv_cache_dtype == "bf16" else v_scale,
-            exp_sums=exp_sums,
-            max_logits=max_logits,
-            temporary_output=temporary_output,
-            alibi_slopes=None,
-            sinks=self.sinks,
-            sliding_window=self.sliding_window,
-            ps=True,
-        )
+            context_partition_size = 256
+            if self.sliding_window > 0:
+                max_context_partition_num = 1
+                context_partition_size = 128
+
+            intermediate_shape = (
+                num_seqs,
+                num_kv_heads,
+                max_context_partition_num,
+                query_group_size,
+            )
+            exp_sums = torch.empty(
+                intermediate_shape, dtype=torch.float32, device=q.device
+            )
+            max_logits = torch.empty(
+                intermediate_shape, dtype=torch.float32, device=q.device
+            )
+            temporary_output = torch.empty(
+                *intermediate_shape,
+                head_size,
+                dtype=q.dtype,
+                device=q.device,
+            )
+
+            if k_scale is not None and k_scale.numel() > 1:
+                k_scale = k_scale.unsqueeze(-1)
+                v_scale = v_scale.unsqueeze(-1)
+
+            compute_type = (
+                torch.bfloat16 if self.kv_cache_dtype == "bf16" else aiter.dtypes.fp8
+            )
+            torch.ops.aiter.pa_decode_gluon(
+                o,
+                q,
+                k_cache,
+                v_cache,
+                attn_metadata.context_lens,
+                attn_metadata.block_tables,
+                self.scale,
+                attn_metadata.max_seqlen_q,
+                max_context_partition_num,
+                context_partition_size,
+                compute_type,
+                None,  # q_scale
+                None if self.kv_cache_dtype == "bf16" else k_scale,
+                None if self.kv_cache_dtype == "bf16" else v_scale,
+                exp_sums=exp_sums,
+                max_logits=max_logits,
+                temporary_output=temporary_output,
+                alibi_slopes=None,
+                sinks=self.sinks,
+                sliding_window=self.sliding_window,
+                ps=True,
+            )

         return o
@@ -543,19 +577,42 @@ def prefill_attention_triton(
         block_tables = attn_metadata.block_tables
         o = torch.empty_like(q)
-        descale_shape = (attn_metadata.cu_seqlens_q.shape[0] - 1, k.shape[1])
+        num_seqs = attn_metadata.cu_seqlens_q.shape[0] - 1
+        descale_shape = (num_seqs, k.shape[1])
         sliding_window = (
             (self.sliding_window - 1, 0) if self.sliding_window is not None else (-1, -1)
         )
+
+        if block_tables is None:
+            # Prefill has no block_table. Use k/v directly as kv_cache with
+            # block_size=1 and a fake block_table (see comments above).
+            # k: [total_tokens, num_kv_heads, head_size]
+            #    -> [total_tokens, 1, num_kv_heads, head_size]
+            k_for_attn = k.unsqueeze(1)
+            v_for_attn = v.unsqueeze(1)
+            # Build per-seq block tables: seq i maps to token indices
+            # [cu_seqlens_k[i], cu_seqlens_k[i]+1, ..., cu_seqlens_k[i+1]-1]
+            max_seqlen_k = attn_metadata.max_seqlen_k
+            cu_k = attn_metadata.cu_seqlens_k
+            offsets = cu_k[:num_seqs]  # [num_seqs]
+            block_tables = offsets.unsqueeze(1) + torch.arange(
+                max_seqlen_k, dtype=torch.int32, device=q.device
+            )
+            seqused_k = cu_k[1:] - cu_k[:num_seqs]
+        else:
+            k_for_attn = k_cache
+            v_for_attn = v_cache
+            seqused_k = attn_metadata.context_lens
+
         unified_attention(
             q,
-            k_cache,
-            v_cache,
+            k_for_attn,
+            v_for_attn,
             o,
             cu_seqlens_q=attn_metadata.cu_seqlens_q,
-            seqused_k=attn_metadata.context_lens,
+            seqused_k=seqused_k,
             max_seqlen_q=attn_metadata.max_seqlen_q,
             max_seqlen_k=attn_metadata.max_seqlen_k,
             softmax_scale=self.scale,
@@ -577,9 +634,11 @@ def dispatch_backend(self, fwd_ctx: ForwardContext):
         ctx = fwd_ctx.context

         if ctx.is_prefill:
+            if self.use_flash_layout:
+                return self.prefill_attention_triton
             return self.prefill_attention
         else:
-            if self.use_triton_attn:
+            if self.use_triton_attn or self.use_flash_layout:
                 return self.paged_attention_triton
             else:
                 # Only use pa persistent when block_size == 1024
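The prefill fallback above synthesizes a block table out of `cu_seqlens_k` alone: with block_size 1, sequence i's "pages" are exactly its token indices. A small numeric check of that construction (values chosen arbitrarily):

import torch

cu_seqlens_k = torch.tensor([0, 3, 7], dtype=torch.int32)  # two seqs, lens 3 and 4
num_seqs = cu_seqlens_k.shape[0] - 1
max_seqlen_k = 4
offsets = cu_seqlens_k[:num_seqs]  # [0, 3]
block_tables = offsets.unsqueeze(1) + torch.arange(max_seqlen_k, dtype=torch.int32)
# tensor([[0, 1, 2, 3],
#         [3, 4, 5, 6]], dtype=torch.int32)
seqused_k = cu_seqlens_k[1:] - cu_seqlens_k[:num_seqs]  # [3, 4]
# Row 0's trailing entry (3) lies past seqused_k[0]; the kernel masks it out.
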
diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py
index 556ec8cf0..393953221 100644
--- a/atom/model_ops/attention_mla.py
+++ b/atom/model_ops/attention_mla.py
@@ -547,94 +547,124 @@ def _forward_decode(
             device=q.device,
         )

-        kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
-        paged_cu_seqlens_q = attn_metadata.cu_seqlens_q
-        paged_kv_indptr = attn_metadata.kv_indptr
-        paged_kv_indices = attn_metadata.kv_indices
-        paged_kv_last_page_lens = attn_metadata.kv_last_page_lens
-        max_q_len = attn_metadata.max_seqlen_q
-        if self.topk_indices_buffer is not None:
-            if attn_metadata.max_seqlen_q > 1:
-                # MTP verify: per-token layout with max_q_len=1.
-                # Persistent metadata is per-token (from _set_mla_persistent_worker_buffers_sparse_mtp).
-                paged_cu_seqlens_q = attn_metadata.sparse_cu_seqlens_q
-                paged_kv_indptr = attn_metadata.sparse_kv_indptr
-                paged_kv_last_page_lens = attn_metadata.sparse_kv_last_page_lens
-                # Gather physical page indices from kv_indices using topk positions.
-                # block_tables contains large-block IDs (block_ratio > 1) that
-                # need expansion; kv_indices already has per-token page indices.
-                paged_kv_indices = triton_gather_kv_indices_sparse(
-                    paged_kv_indptr,
-                    attn_metadata.token_to_seq_idxs,
-                    self.topk_indices_buffer[:B],
-                    attn_metadata.kv_indices,
-                    attn_metadata.kv_indptr,
-                    NUM_TOPK_TOKENS=self.topk_indices_buffer.shape[1],
-                )
-                max_q_len = 1
-            else:
-                paged_kv_indptr = attn_metadata.sparse_kv_indptr
-                paged_kv_indices = triton_convert_req_index_to_global_index(
-                    attn_metadata.cu_seqlens_q,
-                    attn_metadata.kv_indptr,
-                    paged_kv_indptr,
-                    attn_metadata.kv_indices,
-                    self.topk_indices_buffer[:B],
-                    NUM_TOPK_TOKENS=self.topk_indices_buffer.shape[1],
-                )
+        if hasattr(attn_metadata, "triton_block_table"):
+            from aiter.ops.triton.attention.mla_decode import decode_attention_fwd
-        dp_size = get_dp_group().world_size
-        use_persistent_mode = not (dp_size > 1)
+
+            k_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
+            v_buffer = k_buffer[..., : self.kv_lora_rank]
+            page_size = k_buffer.shape[1]
-        # Sparse layers in MTP verify use separate persistent metadata
-        # (per-token, max_seqlen_qo=1) while dense layers use normal metadata
-        # (max_seqlen_qo=2).
-        is_sparse_mtp = (
-            self.topk_indices_buffer is not None and attn_metadata.max_seqlen_q > 1
-        )
+
+            q_for_triton = (
+                q.to(torch.bfloat16)
+                if q.dtype.is_floating_point and q.element_size() == 1
+                else q
+            )
-        if not use_persistent_mode:
-            work_meta_data = None
-            work_indptr = None
-            work_info_set = None
-            reduce_indptr = None
-            reduce_final_map = None
-            reduce_partial_map = None
-        elif is_sparse_mtp:
-            work_meta_data = attn_metadata.sparse_mtp_work_meta_data
-            work_indptr = attn_metadata.sparse_mtp_work_indptr
-            work_info_set = attn_metadata.sparse_mtp_work_info_set
-            reduce_indptr = attn_metadata.sparse_mtp_reduce_indptr
-            reduce_final_map = attn_metadata.sparse_mtp_reduce_final_map
-            reduce_partial_map = attn_metadata.sparse_mtp_reduce_partial_map
+
+            # Use pre-built dense block_table from prepare_decode()
+            decode_attention_fwd(
+                q_for_triton,
+                k_buffer,
+                v_buffer,
+                o,
+                attn_metadata.triton_lse,
+                attn_metadata.triton_block_table,
+                attn_metadata.context_lens,
+                attn_metadata.triton_attn_logits,
+                4,  # num_kv_splits
+                self.scale,
+                page_size,
+                k_scale=self._k_scale,
+                v_scale=self._k_scale,
+            )
         else:
-            work_meta_data = attn_metadata.work_meta_data
-            work_indptr = attn_metadata.work_indptr
-            work_info_set = attn_metadata.work_info_set
-            reduce_indptr = attn_metadata.reduce_indptr
-            reduce_final_map = attn_metadata.reduce_final_map
-            reduce_partial_map = attn_metadata.reduce_partial_map
-
-        mla_decode_fwd(
-            q,
-            kv_buffer.view(-1, 1, 1, q.shape[-1]),
-            o,
-            paged_cu_seqlens_q,
-            paged_kv_indptr,
-            paged_kv_indices,
-            paged_kv_last_page_lens,
-            max_q_len,
-            num_kv_splits=16,
-            sm_scale=self.scale,
-            work_meta_data=work_meta_data,
-            work_indptr=work_indptr,
-            work_info_set=work_info_set,
-            reduce_indptr=reduce_indptr,
-            reduce_final_map=reduce_final_map,
-            reduce_partial_map=reduce_partial_map,
-            q_scale=self._q_scale,
-            kv_scale=self._k_scale,
-        )
+            kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
+            paged_cu_seqlens_q = attn_metadata.cu_seqlens_q
+            paged_kv_indptr = attn_metadata.kv_indptr
+            paged_kv_indices = attn_metadata.kv_indices
+            paged_kv_last_page_lens = attn_metadata.kv_last_page_lens
+            max_q_len = attn_metadata.max_seqlen_q
+            if self.topk_indices_buffer is not None:
+                if attn_metadata.max_seqlen_q > 1:
+                    # MTP verify: per-token layout with max_q_len=1.
+                    # Persistent metadata is per-token (from _set_mla_persistent_worker_buffers_sparse_mtp).
+                    paged_cu_seqlens_q = attn_metadata.sparse_cu_seqlens_q
+                    paged_kv_indptr = attn_metadata.sparse_kv_indptr
+                    paged_kv_last_page_lens = attn_metadata.sparse_kv_last_page_lens
+                    # Gather physical page indices from kv_indices using topk positions.
+                    # block_tables contains large-block IDs (block_ratio > 1) that
+                    # need expansion; kv_indices already has per-token page indices.
+                    paged_kv_indices = triton_gather_kv_indices_sparse(
+                        paged_kv_indptr,
+                        attn_metadata.token_to_seq_idxs,
+                        self.topk_indices_buffer[:B],
+                        attn_metadata.kv_indices,
+                        attn_metadata.kv_indptr,
+                        NUM_TOPK_TOKENS=self.topk_indices_buffer.shape[1],
+                    )
+                    max_q_len = 1
+                else:
+                    paged_kv_indptr = attn_metadata.sparse_kv_indptr
+                    paged_kv_indices = triton_convert_req_index_to_global_index(
+                        attn_metadata.cu_seqlens_q,
+                        attn_metadata.kv_indptr,
+                        paged_kv_indptr,
+                        attn_metadata.kv_indices,
+                        self.topk_indices_buffer[:B],
+                        NUM_TOPK_TOKENS=self.topk_indices_buffer.shape[1],
+                    )
+
+            dp_size = get_dp_group().world_size
+            use_persistent_mode = not (dp_size > 1)
+
+            # Sparse layers in MTP verify use separate persistent metadata
+            # (per-token, max_seqlen_qo=1) while dense layers use normal metadata
+            # (max_seqlen_qo=2).
+            is_sparse_mtp = (
+                self.topk_indices_buffer is not None and attn_metadata.max_seqlen_q > 1
+            )
+
+            if not use_persistent_mode:
+                work_meta_data = None
+                work_indptr = None
+                work_info_set = None
+                reduce_indptr = None
+                reduce_final_map = None
+                reduce_partial_map = None
+            elif is_sparse_mtp:
+                work_meta_data = attn_metadata.sparse_mtp_work_meta_data
+                work_indptr = attn_metadata.sparse_mtp_work_indptr
+                work_info_set = attn_metadata.sparse_mtp_work_info_set
+                reduce_indptr = attn_metadata.sparse_mtp_reduce_indptr
+                reduce_final_map = attn_metadata.sparse_mtp_reduce_final_map
+                reduce_partial_map = attn_metadata.sparse_mtp_reduce_partial_map
+            else:
+                work_meta_data = attn_metadata.work_meta_data
+                work_indptr = attn_metadata.work_indptr
+                work_info_set = attn_metadata.work_info_set
+                reduce_indptr = attn_metadata.reduce_indptr
+                reduce_final_map = attn_metadata.reduce_final_map
+                reduce_partial_map = attn_metadata.reduce_partial_map
+
+            mla_decode_fwd(
+                q,
+                kv_buffer.view(-1, 1, 1, q.shape[-1]),
+                o,
+                paged_cu_seqlens_q,
+                paged_kv_indptr,
+                paged_kv_indices,
+                paged_kv_last_page_lens,
+                max_q_len,
+                num_kv_splits=16,
+                sm_scale=self.scale,
+                work_meta_data=work_meta_data,
+                work_indptr=work_indptr,
+                work_info_set=work_info_set,
+                reduce_indptr=reduce_indptr,
+                reduce_final_map=reduce_final_map,
+                reduce_partial_map=reduce_partial_map,
+                q_scale=self._q_scale,
+                kv_scale=self._k_scale,
+            )

         if self.head_repeat_factor > 1:
             o = o[:, :: self.head_repeat_factor, :].contiguous()
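The Triton MLA decode path consumes a dense per-sequence page table built from the CSR-style `kv_indptr`/`kv_indices` pair (see the TritonMLA builder further below, which calls aiter's `csr_to_dense_block_table`). A plain-PyTorch equivalent of that conversion, assuming the CSR layout described here:

import torch

def csr_to_dense(kv_indices, kv_indptr, num_seqs, max_len):
    # Dense [num_seqs, max_len] table, zero-padded past each row's length.
    table = torch.zeros(num_seqs, max_len, dtype=torch.int32)
    for i in range(num_seqs):
        row = kv_indices[kv_indptr[i]:kv_indptr[i + 1]]
        table[i, : row.numel()] = row
    return table

kv_indptr = torch.tensor([0, 2, 5])
kv_indices = torch.tensor([7, 9, 1, 4, 8], dtype=torch.int32)
print(csr_to_dense(kv_indices, kv_indptr, num_seqs=2, max_len=3))
# tensor([[7, 9, 0],
#         [1, 4, 8]], dtype=torch.int32)
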
diff --git a/atom/model_ops/attentions/triton_mha.py b/atom/model_ops/attentions/triton_mha.py
new file mode 100644
index 000000000..f6404ab75
--- /dev/null
+++ b/atom/model_ops/attentions/triton_mha.py
@@ -0,0 +1,109 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+import logging
+from typing import Type
+
+import atom.model_ops as ops
+from atom.config import KVCacheTensor
+from atom.model_ops.attention_mha import PagedAttentionImpl
+from atom.model_ops.paged_attention import PagedAttention
+
+from .aiter_attention import AiterAttentionMetadataBuilder
+from .backends import AttentionBackend
+
+logger = logging.getLogger("atom")
+
+
+class TritonMHABackend(AttentionBackend):
+    @staticmethod
+    def get_name() -> str:
+        return "ROCM_TRITON_MHA"
+
+    @staticmethod
+    def get_builder_cls() -> Type["TritonMHAMetadataBuilder"]:
+        return TritonMHAMetadataBuilder
+
+    @staticmethod
+    def get_impl_cls():
+        attn_cls = ops.Attention
+        if attn_cls == PagedAttention:
+            return PagedAttentionImpl
+        raise NotImplementedError(
+            f"TritonMHABackend does not support attention class {attn_cls!r}"
+        )
+
+
+class TritonMHAMetadataBuilder(AiterAttentionMetadataBuilder):
+    """MHA metadata builder that allocates KV cache in flash layout.
+
+    Flash layout: K/V both [num_blocks, block_size, num_kv_heads, head_dim].
+    Consumed directly by aiter triton `unified_attention` for prefill+decode.
+    """
+
+    def build_kv_cache_tensor(self, layer_id: int, module):
+        if not (
+            hasattr(module, "base_attention")
+            and hasattr(module, "use_mla")
+            and not module.use_mla
+        ):
+            return None
+
+        runner = self.model_runner
+        config = runner.config
+        hf_config = config.hf_config
+
+        if runner.is_mimo_v2():
+            raise NotImplementedError(
+                "TritonMHABackend does not support MiMo-V2 (per-layer alloc path)"
+            )
+
+        impl = getattr(module, "impl", None)
+        if impl is not None and (
+            getattr(impl, "rotary_emb", None) is not None
+            and getattr(impl, "q_norm", None) is not None
+            and getattr(impl, "k_norm", None) is not None
+        ):
+            raise NotImplementedError(
+                "TritonMHABackend is incompatible with the fused qk_norm+rope+shuffle "
+                "cache path; use AiterBackend for this model."
+            )
+
+        if runner.is_qwen_next():
+            mtp_start = runner.mtp_start_layer_idx
+            if layer_id < mtp_start:
+                attn_idx = layer_id // runner.full_attention_interval
+            else:
+                attn_idx = runner.num_full_attn + (layer_id - mtp_start)
+        else:
+            attn_idx = layer_id
+
+        k_cache = runner.kv_cache[0, attn_idx].view(
+            runner.num_physical_kvcache_blocks,
+            runner.physical_block_size,
+            runner.num_kv_heads,
+            hf_config.head_dim,
+        )
+        v_cache = runner.kv_cache[1, attn_idx].view(
+            runner.num_physical_kvcache_blocks,
+            runner.physical_block_size,
+            runner.num_kv_heads,
+            hf_config.head_dim,
+        )
+        if config.kv_cache_dtype == "fp8":
+            module.k_scale = runner.kv_scale[0, attn_idx]
+            module.v_scale = runner.kv_scale[1, attn_idx]
+
+        module.max_model_len = config.max_model_len
+        module.k_cache = k_cache
+        module.v_cache = v_cache
+        if impl is not None:
+            impl.use_flash_layout = True
+
+        return KVCacheTensor(
+            layer_num=layer_id,
+            k_cache=k_cache,
+            v_cache=v_cache,
+            k_scale=module.k_scale,
+            v_scale=module.v_scale,
+        )
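TritonMHAMetadataBuilder views the KV pool in flash layout, [num_blocks, block_size, num_kv_heads, head_dim], whereas the aiter decode path unpacks a 5-D shuffled layout (inferred from `num_blocks, num_kv_heads, _, block_size, _ = k_cache.shape` in paged_attention_triton). A shape-only sketch of the two views over one buffer; the x-element packing factor is an assumption, not taken from this patch:

import torch

num_blocks, block_size, num_kv_heads, head_dim, x = 4, 16, 2, 64, 8
pool = torch.empty(num_blocks * block_size * num_kv_heads * head_dim)

# Flash layout: token-major; unified_attention indexes [block, slot, head, :].
flash = pool.view(num_blocks, block_size, num_kv_heads, head_dim)

# aiter-style shuffled layout matching the unpacking above:
# [num_blocks, num_kv_heads, head_dim // x, block_size, x]
shuffled = pool.view(num_blocks, num_kv_heads, head_dim // x, block_size, x)

assert flash[2, 5].shape == (num_kv_heads, head_dim)   # slot 5 of block 2
assert shuffled[2, 1, :, 5, :].numel() == head_dim     # same token, head 1
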
diff --git a/atom/model_ops/attentions/triton_mla.py b/atom/model_ops/attentions/triton_mla.py
new file mode 100644
index 000000000..3a52dfd4f
--- /dev/null
+++ b/atom/model_ops/attentions/triton_mla.py
@@ -0,0 +1,101 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+import logging
+from typing import Type
+
+import torch
+from aiter.ops.triton.attention.mla_decode import csr_to_dense_block_table
+from atom.model_engine.scheduler import ScheduledBatch
+from atom.model_ops.attention_mla import MLAAttention
+from atom.utils.forward_context import AttentionMetaData
+
+from .aiter_mla import AiterMLAMetadataBuilder
+from .backends import AttentionBackend
+
+logger = logging.getLogger("atom")
+
+
+class TritonMLABackend(AttentionBackend):
+    @staticmethod
+    def get_name() -> str:
+        return "ROCM_TRITON_MLA"
+
+    @staticmethod
+    def get_builder_cls() -> Type["TritonMLAMetadataBuilder"]:
+        return TritonMLAMetadataBuilder
+
+    @staticmethod
+    def get_impl_cls() -> Type["MLAAttention"]:
+        return MLAAttention
+
+
+class TritonMLAMetadataBuilder(AiterMLAMetadataBuilder):
+
+    def __init__(self, model_runner):
+        super().__init__(model_runner)
+
+        hf = model_runner.config.hf_config
+        kv_lora_rank = hf.kv_lora_rank
+        num_kv_splits = 4
+        triton_mla_buffers = {
+            "triton_block_table": torch.zeros(
+                self.max_bs,
+                self.max_num_blocks_per_seq,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "triton_attn_logits": torch.empty(
+                self.max_bs,
+                self.padded_num_attention_heads,
+                num_kv_splits,
+                kv_lora_rank + 1,
+                dtype=torch.float32,
+                device=self.device,
+            ),
+            "triton_lse": torch.empty(
+                self.max_bs,
+                self.padded_num_attention_heads,
+                dtype=torch.float32,
+                device=self.device,
+            ),
+        }
+        self.model_runner.forward_vars.update(triton_mla_buffers)
+
+    def set_mla_persistent_worker_buffers(
+        self, bs, max_q_len, only_update=False, num_reject_tokens=None
+    ):
+        # Triton MLA does not use aiter persistent worker buffers
+        return {}
+
+    def prepare_decode(self, batch: ScheduledBatch, bs: int):
+        attn_metadata, positions = super().prepare_decode(batch, bs)
+
+        scheduled_bs = batch.total_seqs_num_decode
+        max_seqlen_k = attn_metadata.max_seqlen_k
+        var = self.model_runner.forward_vars
+
+        triton_bt = var["triton_block_table"][:scheduled_bs, :max_seqlen_k]
+        triton_bt.zero_()
+        csr_to_dense_block_table(
+            attn_metadata.kv_indices,
+            attn_metadata.kv_indptr,
+            triton_bt,
+            max_seqlen_k,
+            scheduled_bs,
+        )
+        attn_metadata.triton_block_table = triton_bt
+        attn_metadata.triton_attn_logits = var["triton_attn_logits"][:scheduled_bs]
+        attn_metadata.triton_lse = var["triton_lse"][:scheduled_bs]
+
+        return attn_metadata, positions
+
+    def build_for_cudagraph_capture(self, bs: int) -> AttentionMetaData:
+        attn_metadata, context = super().build_for_cudagraph_capture(bs)
+
+        var = self.model_runner.forward_vars
+        attn_metadata.triton_block_table = var["triton_block_table"][:bs]
+        attn_metadata.triton_attn_logits = var["triton_attn_logits"][:bs]
+        attn_metadata.triton_lse = var["triton_lse"][:bs]
+
+        return attn_metadata, context
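The builder above pre-allocates its scratch tensors at `max_bs` once and hands out per-step slices, so CUDA-graph capture always replays against the same storage. The pattern in isolation, as a minimal sketch:

import torch

max_bs, num_heads = 64, 16
buf = torch.empty(max_bs, num_heads)  # allocated once at startup

def metadata_for(bs: int) -> torch.Tensor:
    view = buf[:bs]  # a view, not a copy: same data_ptr on every call
    assert view.data_ptr() == buf.data_ptr()
    return view
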
diff --git a/atom/model_ops/fused_moe_triton.py b/atom/model_ops/fused_moe_triton.py
index 851cb3dce..18c513e80 100644
--- a/atom/model_ops/fused_moe_triton.py
+++ b/atom/model_ops/fused_moe_triton.py
@@ -24,6 +24,7 @@
 from typing import Any
 import logging
 from math import prod
+from aiter import ActivationType
 from aiter.jit.utils.chip_info import get_gfx

 from atom.model_ops.utils import has_triton_kernels
@@ -32,9 +33,14 @@

 if has_triton_kernels():
     try:
-        from triton_kernels.matmul_ogs import matmul_ogs
+        import triton_kernels.swiglu
+        from triton_kernels.matmul_ogs import (
+            FnSpecs,
+            FusedActivation,
+            PrecisionConfig,
+            matmul_ogs,
+        )
         from triton_kernels.routing import routing
-        from triton_kernels.matmul_ogs import PrecisionConfig
     except (AttributeError, ImportError) as e:
         logger.error(
             "Failed to import Triton kernels. Please make sure your triton "
@@ -283,28 +289,45 @@ def triton_kernel_fused_experts(
     )

     with _amd_smem_safe_tile():
-        matmul_ogs(
-            hidden_states,
-            w1,
-            w1_bias,
-            routing_data,
-            gather_indx=gather_indx,
-            precision_config=w13_precision_config,
-            gammas=gammas if apply_router_weight_on_input else None,
-            y=raw_intermediate,
-        )
-
-    # Standard SiLU/SwiGLU activation: silu(gate) * up
-    # With optional swiglu_limit clamping (V4: limit=10.0)
-    raw_2d = raw_intermediate.view(M * topk, N)
-    gate = raw_2d[:, :half_N]
-    up = raw_2d[:, half_N:]
-    if swiglu_limit > 0:
-        gate = gate.clamp(max=swiglu_limit)
-        up = up.clamp(-swiglu_limit, swiglu_limit)
-    intermediate_cache[0] = torch.nn.functional.silu(gate) * up
+        if activation == ActivationType.Swiglu:
+            # SwiGLU (GPT OSS): fused activation with interleaved [gate, up] layout
+            act = FusedActivation(
+                FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
+                (swiglu_alpha, swiglu_limit),
+                2,
+            )
+            matmul_ogs(
+                hidden_states,
+                w1,
+                w1_bias,
+                routing_data,
+                gather_indx=gather_indx,
+                precision_config=w13_precision_config,
+                gammas=gammas if apply_router_weight_on_input else None,
+                fused_activation=act,
+                y=intermediate_cache,
+            )
+        else:
+            # SiLU (DeepSeek): concatenated [gate | up] layout, manual activation
+            raw_intermediate = matmul_ogs(
+                hidden_states,
+                w1,
+                w1_bias,
+                routing_data,
+                gather_indx=gather_indx,
+                precision_config=w13_precision_config,
+                gammas=gammas if apply_router_weight_on_input else None,
+            )
+            raw_2d = raw_intermediate.view(M * topk, N)
+            gate = raw_2d[:, :half_N]
+            up = raw_2d[:, half_N:]
+            if swiglu_limit > 0:
+                gate = gate.clamp(max=swiglu_limit)
+                up = up.clamp(-swiglu_limit, swiglu_limit)
+            intermediate_cache = intermediate_cache.view(M * topk, half_N)
+            intermediate_cache.copy_(torch.nn.functional.silu(gate) * up)
+            intermediate_cache = intermediate_cache.view(batch_dim, M * topk, half_N)

-    with _amd_smem_safe_tile():
         matmul_ogs(
             intermediate_cache.view(M * topk, half_N),
             w2,
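The two activation branches differ in weight layout as well as formula: the manual SiLU path splits a concatenated [gate | up] projection, while triton_kernels' fused `swiglu_fn` expects gate/up interleaved along the last dimension. A reference sketch of both; the interleaved formula follows the GPT-OSS convention (`g * sigmoid(alpha * g) * (u + 1)`) and is an approximation of `swiglu_fn`, not its source:

import torch
import torch.nn.functional as F

def silu_concat(raw, half_n, limit=0.0):
    # Concatenated [gate | up] layout (the manual path above).
    gate, up = raw[:, :half_n], raw[:, half_n:]
    if limit > 0:
        gate = gate.clamp(max=limit)
        up = up.clamp(-limit, limit)
    return F.silu(gate) * up

def swiglu_interleaved(raw, alpha=1.702, limit=7.0):
    # Interleaved [gate, up] layout assumed by the fused triton_kernels path.
    gate, up = raw[..., ::2], raw[..., 1::2]
    gate = gate.clamp(max=limit)
    up = up.clamp(-limit, limit)
    return gate * torch.sigmoid(alpha * gate) * (up + 1)
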
diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py
index 5a42b12b8..dd100bd67 100644
--- a/atom/model_ops/moe.py
+++ b/atom/model_ops/moe.py
@@ -686,11 +686,14 @@ def __init__(self, quant_config: LayerQuantConfig, moe: FusedMoEConfig):
             or self.quant_type == QuantType.per_1x32
         )
         gfx = get_gfx()
-        self.use_triton = (
-            gfx.startswith("gfx94")
-            or (gfx.startswith("gfx95") and envs.ATOM_USE_TRITON_GEMM)
-            or os.environ.get("ATOM_USE_TRITON_MOE") == "1"
-        )
+        if envs.is_set("ATOM_USE_TRITON_MOE"):
+            self.use_triton = envs.ATOM_USE_TRITON_MOE
+        else:
+            self.use_triton = (
+                gfx.startswith("gfx94")
+                or gfx.startswith("gfx12")
+                or (gfx.startswith("gfx95") and envs.ATOM_USE_TRITON_GEMM)
+            )

         if self.use_triton:
             from atom.model_ops.utils import has_triton_kernels
@@ -975,11 +978,13 @@ def apply(
             num_fused_shared_experts=layer.num_fused_shared_experts,
             routed_scaling_factor=layer.routed_scaling_factor,
         )
+        n_expts_act = topk_weights.shape[1]

         # Convert to triton routing data structures
         n_expts_tot = router_logits.shape[-1]
         if global_num_experts > 0:
             n_expts_tot = global_num_experts
+        n_expts_tot = n_expts_tot + layer.num_fused_shared_experts

         routing_data, gather_idx, scatter_idx = routing_from_topk(
             topk_weights, topk_ids, n_expts_tot
@@ -994,7 +999,7 @@ def apply(
             routing_data,
             gather_idx,
             scatter_idx,
-            topk=top_k,
+            topk=n_expts_act,
             activation=activation,
             w13_precision_config=self.w13_precision_config,
             w2_precision_config=self.w2_precision_config,
@@ -1002,7 +1007,7 @@ def apply(
             w2_bias=layer.w2_bias,
             swiglu_limit=getattr(layer, "swiglu_limit", 0.0),
             apply_router_weight_on_input=layer.apply_router_weight_on_input,
-            global_num_experts=global_num_experts,
+            global_num_experts=n_expts_tot,
             expert_map=expert_map,
         )
         return _moe_result
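The `envs.is_set` guard above gives an explicitly exported ATOM_USE_TRITON_MOE absolute precedence, including an explicit "0" that now force-disables Triton MoE (previously "0" and unset were indistinguishable, since the flag was only OR'd in). The same precedence in standalone form, using plain os.environ in place of atom's envs helper:

import os

def resolve_use_triton(gfx: str, use_triton_gemm: bool) -> bool:
    # An explicitly set variable wins outright: "0" disables, "1" enables.
    if "ATOM_USE_TRITON_MOE" in os.environ:
        return os.environ["ATOM_USE_TRITON_MOE"] == "1"
    # Otherwise fall back to the per-architecture default.
    return (
        gfx.startswith("gfx94")
        or gfx.startswith("gfx12")
        or (gfx.startswith("gfx95") and use_triton_gemm)
    )
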
diff --git a/atom/utils/envs.py b/atom/utils/envs.py
index 51fee833a..7282248b8 100644
--- a/atom/utils/envs.py
+++ b/atom/utils/envs.py
@@ -33,6 +33,8 @@
     "ATOM_USE_TRITON_MXFP4_BMM": lambda: (
         os.getenv("ATOM_USE_TRITON_MXFP4_BMM", "0") == "1"
     ),
+    "ATOM_USE_TRITON_MLA": lambda: os.getenv("ATOM_USE_TRITON_MLA", "0") == "1",
+    "ATOM_USE_TRITON_MOE": lambda: os.getenv("ATOM_USE_TRITON_MOE", "0") == "1",
     # --- Kernel Fusion Toggles ---
     # QK-norm-rope-cache-quant fusion for Qwen3-MoE; disabled by default.
     # Enable for Qwen3-MoE to get better performance.
@@ -69,6 +71,14 @@
     "ATOM_DISABLE_MMAP": lambda: (
         os.getenv("ATOM_DISABLE_MMAP", "false").lower() == "true"
    ),
+    # Use a thread pool for weight loading instead of main-process sequential I/O.
+    # Set to 0 to disable if the thread pool causes hangs (e.g. on gfx1250).
+    "ATOM_LOADER_USE_THREADPOOL": lambda: os.getenv("ATOM_LOADER_USE_THREADPOOL", "1")
+    == "1",
+    # --- Attention Backend ---
+    # Use unified_attention (flash-style) for MHA paged/prefill attention instead
+    # of pa_decode_gluon. Set to 1 to enable the unified_attention path.
+    "ATOM_USE_UNIFIED_ATTN": lambda: os.getenv("ATOM_USE_UNIFIED_ATTN", "0") == "1",
     # --- Plugin Mode ---
     "ATOM_DISABLE_VLLM_PLUGIN": lambda: (
         os.getenv("ATOM_DISABLE_VLLM_PLUGIN", "0").lower() == "1"
diff --git a/atom/utils/selector.py b/atom/utils/selector.py
index e87b1f819..bb9b1c7d6 100644
--- a/atom/utils/selector.py
+++ b/atom/utils/selector.py
@@ -7,6 +7,7 @@
 from atom.model_ops.attentions.backends import AttentionBackend
 from atom.utils import resolve_obj_by_qualname
 from atom.plugin.prepare import is_sglang, is_vllm
+from atom.utils import envs


 def get_attn_backend(
@@ -51,19 +52,17 @@ def get_attn_backend_cls(
     if use_v4:
         return "atom.model_ops.attentions.deepseek_v4_attn.DeepseekV4Backend"
     if use_mla:
-        # if block_size == 1:
-        return "atom.model_ops.attentions.aiter_mla.AiterMLABackend"  # noqa: E501
-        # else:
-        #     raise ValueError(
-        #         f" The selected backend"
-        #         f"does not support block size {block_size}."
-        #         "(currently only supports block size 1)")
+        if envs.ATOM_USE_TRITON_MLA:
+            return "atom.model_ops.attentions.triton_mla.TritonMLABackend"
+        return "atom.model_ops.attentions.aiter_mla.AiterMLABackend"
     if use_gdn:
         if use_vllm:
             return "atom.plugin.vllm.attention_backend.gdn_attn.GDNAttentionBackend"
         if is_sglang():
             return (
                 "atom.plugin.sglang.attention_backend.attention_gdn.GDNAttentionBackend"
             )
         return "atom.model_ops.attentions.gdn_attn.GDNAttentionBackend"
+    if envs.ATOM_USE_UNIFIED_ATTN:
+        return "atom.model_ops.attentions.triton_mha.TritonMHABackend"
     return "atom.model_ops.attentions.aiter_attention.AiterBackend"  # noqa: E501