73 changes: 36 additions & 37 deletions atom/model_loader/loader.py
@@ -318,6 +318,20 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None:

with concurrent.futures.ThreadPoolExecutor() as executor:
futures = []
use_threadpool = envs.ATOM_LOADER_USE_THREADPOOL
if use_threadpool:
executor = concurrent.futures.ThreadPoolExecutor()
else:
executor = None
futures = []

def _submit(fn, *args):
if executor is not None:
futures.append(executor.submit(fn, *args))
else:
fn(*args)

try:
disable_mmap = envs.ATOM_DISABLE_MMAP
for name, weight_tensor in safetensors_weights_iterator(
model_name_or_path, disable_mmap=disable_mmap
@@ -389,11 +403,7 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None:
)
continue
weight_loader = getattr(param, "weight_loader")
futures.append(
executor.submit(
weight_loader, param, weight_tensor, shard_idx
)
)
_submit(weight_loader, param, weight_tensor, shard_idx)
loaded_weights_record.add(prefix + param_name)
else:
# Checkpoint has separate weights, load into fused param
@@ -407,12 +417,7 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None:
dropped_ckpt_keys.append((_orig_ckpt_name, param_name))
break
weight_loader = getattr(param, "weight_loader")
# weight_loader(param, weight_tensor, shard_id)
futures.append(
executor.submit(
weight_loader, param, weight_tensor, shard_id
)
)
_submit(weight_loader, param, weight_tensor, shard_id)
loaded_weights_record.add(prefix + param_name)
break
else:
@@ -482,15 +487,13 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None:
matched = True
break
weight_loader = getattr(param, "weight_loader")
futures.append(
executor.submit(
weight_loader,
param,
weight_tensor,
name,
shard_id,
expert_id,
)
_submit(
weight_loader,
param,
weight_tensor,
name,
shard_id,
expert_id,
)
loaded_weights_record.add(prefix + name)
matched = True
@@ -508,15 +511,13 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None:
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
futures.append(
executor.submit(
weight_loader,
param,
weight_tensor,
"", # use merged moe loader
"",
expert_id,
)
_submit(
weight_loader,
param,
weight_tensor,
"", # use merged moe loader
"",
expert_id,
)
loaded_weights_record.add(prefix + name)
try:
@@ -527,9 +528,7 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None:
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
futures.append(
executor.submit(weight_loader, param, weight_tensor)
)
_submit(weight_loader, param, weight_tensor)
loaded_weights_record.add(prefix + name)
else:
# Model doesn't have expert mapping, use generic loading
@@ -541,12 +540,12 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None:
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
# weight_loader(param, weight_tensor)
futures.append(executor.submit(weight_loader, param, weight_tensor))
_submit(weight_loader, param, weight_tensor)
loaded_weights_record.add(prefix + name)
# Wait for all tasks to complete and raise any exceptions.
for future in concurrent.futures.as_completed(futures):
future.result()
finally:
if executor is not None:
concurrent.futures.wait(futures)
executor.shutdown(wait=True)

# Verify every model parameter actually got loaded from the checkpoint.
# Without this check, weights_mapping bugs (e.g. a substring rule
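Note on the loader.py change above: the unconditional ThreadPoolExecutor context manager is replaced by a _submit helper that either queues work on an optional pool or runs it inline, gated by ATOM_LOADER_USE_THREADPOOL. Below is a minimal standalone sketch of that pattern; the load_all wrapper, the weights iterable, and the plain os.environ parsing are illustrative assumptions, not the repo's actual envs helper. One stdlib caveat worth knowing: concurrent.futures.wait blocks until completion but does not re-raise worker exceptions, so the sketch also calls result() to surface failures.

import concurrent.futures
import os


def load_all(weights, weight_loader):
    # Mirror of the diff's pattern: the pool is optional, env-controlled.
    use_threadpool = os.environ.get("ATOM_LOADER_USE_THREADPOOL", "1") == "1"
    executor = concurrent.futures.ThreadPoolExecutor() if use_threadpool else None
    futures = []

    def _submit(fn, *args):
        # Queue on the pool when enabled; otherwise run synchronously.
        if executor is not None:
            futures.append(executor.submit(fn, *args))
        else:
            fn(*args)

    try:
        for param, tensor in weights:
            _submit(weight_loader, param, tensor)
    finally:
        if executor is not None:
            # wait() only blocks; result() re-raises the first worker failure.
            done, _ = concurrent.futures.wait(futures)
            executor.shutdown(wait=True)
            for future in done:
                future.result()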
186 changes: 121 additions & 65 deletions atom/model_ops/attention_mha.py
@@ -10,6 +10,7 @@
from aiter.ops.triton.gluon.pa_decode_gluon import get_recommended_splits
from aiter.ops.triton.unified_attention import unified_attention
from atom.config import get_current_atom_config
from atom.utils import envs
from atom.utils.forward_context import ForwardContext, get_forward_context
from torch import nn

@@ -229,7 +230,7 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):
k_scale,
v_scale,
self.rotary_emb.is_neox_style,
flash_layout=False,
flash_layout=envs.ATOM_USE_UNIFIED_ATTN,
apply_scale=self.kv_cache_dtype.startswith("fp8"),
offs=None,
q_out=q,
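Aside: the flash_layout toggle above means rope writes into one of two KV-cache shapes. The sketch below is inferred from this diff alone, so treat the exact dimension order as an assumption: the 5D gluon layout comes from the shape unpack in paged_attention_triton further down, and the 4D flash layout from nkv = k_cache.shape[2] in the unified path.

import torch

# Illustrative sizes; x is the packing factor of the 5D gluon layout.
num_blocks, num_kv_heads, head_size, block_size, x = 128, 8, 128, 16, 8

# Gluon/pa_decode layout (ATOM_USE_UNIFIED_ATTN unset), matching the unpack
# "num_blocks, num_kv_heads, _, block_size, _ = k_cache.shape" below:
k_cache_gluon = torch.empty(num_blocks, num_kv_heads, head_size // x, block_size, x)

# Assumed 4D flash layout (ATOM_USE_UNIFIED_ATTN set); dimension 2 must be
# the KV-head axis for "nkv = k_cache.shape[2]" to hold:
k_cache_flash = torch.empty(num_blocks, block_size, num_kv_heads, head_size)

assert k_cache_flash.shape[2] == num_kv_heads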
@@ -372,71 +373,101 @@ def paged_attention_triton(

o = torch.empty_like(q)
num_seqs = attn_metadata.context_lens.shape[0]
_, num_q_heads_total, head_size = q.shape
num_blocks, num_kv_heads, _, block_size, _ = k_cache.shape
# assume all queries have the same length
query_group_size = attn_metadata.max_seqlen_q * (
num_q_heads_total // num_kv_heads
)
assert num_q_heads_total % num_kv_heads == 0

max_context_partition_num = get_recommended_splits(num_seqs, num_kv_heads)
if envs.ATOM_USE_UNIFIED_ATTN:
sliding_window = (
(self.sliding_window - 1, 0) if self.sliding_window > 0 else (-1, -1)
)

context_partition_size = 256
if self.sliding_window > 0:
max_context_partition_num = 1
context_partition_size = 128
# KV cache is already in flash layout (4D) when
# ATOM_USE_UNIFIED_ATTN is set, allocated by model_runner.
nkv = k_cache.shape[2]
descale_shape = (num_seqs, nkv)

# Output buffers (same as Triton)
intermediate_shape = (
num_seqs,
num_kv_heads,
max_context_partition_num,
query_group_size,
)
exp_sums = torch.empty(intermediate_shape, dtype=torch.float32, device=q.device)
max_logits = torch.empty(
intermediate_shape, dtype=torch.float32, device=q.device
)
temporary_output = torch.empty(
*intermediate_shape,
head_size,
dtype=q.dtype,
device=q.device,
)
unified_attention(
q,
k_cache,
v_cache,
o,
cu_seqlens_q=attn_metadata.cu_seqlens_q,
seqused_k=attn_metadata.context_lens,
max_seqlen_q=attn_metadata.max_seqlen_q,
max_seqlen_k=attn_metadata.max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=None,
window_size=sliding_window,
block_table=attn_metadata.block_tables,
softcap=0,
q_descale=None,
k_descale=self.kv_scale.expand(descale_shape),
v_descale=self.kv_scale.expand(descale_shape),
sinks=self.sinks,
)
else:
_, num_q_heads_total, head_size = q.shape
num_blocks, num_kv_heads, _, block_size, _ = k_cache.shape
query_group_size = attn_metadata.max_seqlen_q * (
num_q_heads_total // num_kv_heads
)
assert num_q_heads_total % num_kv_heads == 0

if k_scale is not None and k_scale.numel() > 1:
k_scale = k_scale.unsqueeze(-1)
v_scale = v_scale.unsqueeze(-1)
max_context_partition_num = get_recommended_splits(num_seqs, num_kv_heads)

compute_type = (
torch.bfloat16
if self.kv_cache_dtype == "bf16" # or per_tensor
else aiter.dtypes.fp8
)
torch.ops.aiter.pa_decode_gluon(
o,
q,
k_cache,
v_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
self.scale,
attn_metadata.max_seqlen_q,
max_context_partition_num,
context_partition_size,
compute_type,
None, # q_scale
None if self.kv_cache_dtype == "bf16" else k_scale,
None if self.kv_cache_dtype == "bf16" else v_scale,
exp_sums=exp_sums,
max_logits=max_logits,
temporary_output=temporary_output,
alibi_slopes=None,
sinks=self.sinks,
sliding_window=self.sliding_window,
ps=True,
)
context_partition_size = 256
if self.sliding_window > 0:
max_context_partition_num = 1
context_partition_size = 128

intermediate_shape = (
num_seqs,
num_kv_heads,
max_context_partition_num,
query_group_size,
)
exp_sums = torch.empty(
intermediate_shape, dtype=torch.float32, device=q.device
)
max_logits = torch.empty(
intermediate_shape, dtype=torch.float32, device=q.device
)
temporary_output = torch.empty(
*intermediate_shape,
head_size,
dtype=q.dtype,
device=q.device,
)

if k_scale is not None and k_scale.numel() > 1:
k_scale = k_scale.unsqueeze(-1)
v_scale = v_scale.unsqueeze(-1)

compute_type = (
torch.bfloat16 if self.kv_cache_dtype == "bf16" else aiter.dtypes.fp8
)
torch.ops.aiter.pa_decode_gluon(
o,
q,
k_cache,
v_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
self.scale,
attn_metadata.max_seqlen_q,
max_context_partition_num,
context_partition_size,
compute_type,
None, # q_scale
None if self.kv_cache_dtype == "bf16" else k_scale,
None if self.kv_cache_dtype == "bf16" else v_scale,
exp_sums=exp_sums,
max_logits=max_logits,
temporary_output=temporary_output,
alibi_slopes=None,
sinks=self.sinks,
sliding_window=self.sliding_window,
ps=True,
)

return o

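The window_size tuple passed to unified_attention follows a (left, right) convention, with (-1, -1) meaning no window; note that the decode path above guards on sliding_window > 0 while the prefill path below guards on sliding_window is not None. A small sketch of the convention:

def make_window(sliding_window: int) -> tuple[int, int]:
    # A causal window of W tokens lets position i attend to [i - (W - 1), i],
    # hence left = W - 1 and right = 0; (-1, -1) disables windowing.
    return (sliding_window - 1, 0) if sliding_window > 0 else (-1, -1)


assert make_window(128) == (127, 0)
assert make_window(0) == (-1, -1)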
Expand Down Expand Up @@ -543,19 +574,42 @@ def prefill_attention_triton(
block_tables = attn_metadata.block_tables

o = torch.empty_like(q)
descale_shape = (attn_metadata.cu_seqlens_q.shape[0] - 1, k.shape[1])
num_seqs = attn_metadata.cu_seqlens_q.shape[0] - 1
descale_shape = (num_seqs, k.shape[1])
sliding_window = (
(self.sliding_window - 1, 0)
if self.sliding_window is not None
else (-1, -1)
)

if block_tables is None:
# Prefill has no block_table. Use k/v directly as kv_cache with
# block_size=1 and a fake block_table (see comments above).
# k: [total_tokens, num_kv_heads, head_size]
# -> [total_tokens, 1, num_kv_heads, head_size]
k_for_attn = k.unsqueeze(1)
v_for_attn = v.unsqueeze(1)
# Build per-seq block tables: seq i maps to token indices
# [cu_seqlens_k[i], cu_seqlens_k[i]+1, ..., cu_seqlens_k[i+1]-1]
max_seqlen_k = attn_metadata.max_seqlen_k
cu_k = attn_metadata.cu_seqlens_k
offsets = cu_k[:num_seqs] # [num_seqs]
block_tables = offsets.unsqueeze(1) + torch.arange(
max_seqlen_k, dtype=torch.int32, device=q.device
)
seqused_k = cu_k[1:] - cu_k[:num_seqs]
else:
k_for_attn = k_cache
v_for_attn = v_cache
seqused_k = attn_metadata.context_lens

unified_attention(
q,
k_cache,
v_cache,
k_for_attn,
v_for_attn,
o,
cu_seqlens_q=attn_metadata.cu_seqlens_q,
seqused_k=attn_metadata.context_lens,
seqused_k=seqused_k,
max_seqlen_q=attn_metadata.max_seqlen_q,
max_seqlen_k=attn_metadata.max_seqlen_k,
softmax_scale=self.scale,
Expand All @@ -577,6 +631,8 @@ def dispatch_backend(self, fwd_ctx: ForwardContext):
ctx = fwd_ctx.context

if ctx.is_prefill:
if envs.ATOM_USE_UNIFIED_ATTN and self.use_triton_attn:
return self.prefill_attention_triton
return self.prefill_attention
else:
if self.use_triton_attn:
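To make the fake-block-table trick in prefill_attention_triton concrete: with block_size=1, a block ID is just a token index into the packed K/V buffer, so row i of the table is an arange shifted by sequence i's start offset. A tiny worked example with illustrative values:

import torch

# Two sequences of lengths 3 and 4 packed back to back.
cu_seqlens_k = torch.tensor([0, 3, 7], dtype=torch.int32)
num_seqs = cu_seqlens_k.shape[0] - 1
max_seqlen_k = 4

# Row i enumerates token indices starting at cu_seqlens_k[i]; entries past
# the sequence end are present but masked off via seqused_k.
offsets = cu_seqlens_k[:num_seqs]
block_tables = offsets.unsqueeze(1) + torch.arange(max_seqlen_k, dtype=torch.int32)
seqused_k = cu_seqlens_k[1:] - cu_seqlens_k[:num_seqs]

print(block_tables)  # tensor([[0, 1, 2, 3], [3, 4, 5, 6]], dtype=torch.int32)
print(seqused_k)     # tensor([3, 4], dtype=torch.int32)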