5 changes: 4 additions & 1 deletion prepare.py
@@ -35,7 +35,10 @@
 # Configuration
 # ---------------------------------------------------------------------------

-CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "autoresearch")
+CACHE_DIR = os.environ.get(
+    "AUTORESEARCH_CACHE_DIR",
+    os.path.join(os.path.expanduser("~"), ".cache", "autoresearch"),
+)
 DATA_DIR = os.path.join(CACHE_DIR, "data")
 TOKENIZER_DIR = os.path.join(CACHE_DIR, "tokenizer")
 BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
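A standalone sketch (not part of the diff) of the resolution order this change introduces: the AUTORESEARCH_CACHE_DIR environment variable wins when set, otherwise the ~/.cache/autoresearch default applies. The /tmp path below is hypothetical.

import os

os.environ["AUTORESEARCH_CACHE_DIR"] = "/tmp/ar-cache"  # hypothetical override
cache_dir = os.environ.get(
    "AUTORESEARCH_CACHE_DIR",
    os.path.join(os.path.expanduser("~"), ".cache", "autoresearch"),
)
assert cache_dir == "/tmp/ar-cache"  # env var takes precedence

del os.environ["AUTORESEARCH_CACHE_DIR"]
cache_dir = os.environ.get(
    "AUTORESEARCH_CACHE_DIR",
    os.path.join(os.path.expanduser("~"), ".cache", "autoresearch"),
)
assert cache_dir.endswith(os.path.join(".cache", "autoresearch"))  # default path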
36 changes: 30 additions & 6 deletions train.py
@@ -16,6 +16,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.attention import SDPBackend, sdpa_kernel

 # Detect AMD ROCm vs NVIDIA CUDA
 IS_ROCM = hasattr(torch.version, 'hip') and torch.version.hip is not None
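The detection can be sanity-checked in a REPL (not part of the diff; output depends on the installed wheel):

import torch

# CUDA/CPU wheels ship torch.version.hip = None; ROCm wheels set it to a
# version string such as '6.2...'.
print(getattr(torch.version, 'hip', None))
is_rocm = hasattr(torch.version, 'hip') and torch.version.hip is not None
print(is_rocm)  # False on CUDA/CPU builds, True on ROCm builds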
@@ -29,6 +30,18 @@
 else:
     fa3 = None  # Will use PyTorch SDPA on ROCm (dispatches to AOTriton)

+
+def _run_rocm_sdpa(q, k, v, causal):
+    last_error = None
+    for backend in (SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.MATH):
+        try:
+            with sdpa_kernel(backends=[backend]):
+                return F.scaled_dot_product_attention(q, k, v, is_causal=causal)
+        except RuntimeError as exc:
+            last_error = exc
+    raise RuntimeError("ROCm SDPA failed for efficient, flash, and math backends") from last_error
+
+
 from prepare import MAX_SEQ_LEN, TIME_BUDGET, Tokenizer, make_dataloader, evaluate_bpb

 # ---------------------------------------------------------------------------
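A self-contained demo of the fallback pattern the helper uses (not part of the diff; shapes illustrative). Each iteration restricts SDPA to a single backend; a backend that cannot handle the inputs raises RuntimeError and the next one is tried:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(2, 8, 128, 64)  # (B, H, T, D), illustrative sizes

for backend in (SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.MATH):
    try:
        with sdpa_kernel(backends=[backend]):
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        print(f"{backend} ok, output {tuple(y.shape)}")
        break
    except RuntimeError:
        print(f"{backend} unavailable here, trying next")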
@@ -97,13 +110,12 @@ def forward(self, x, ve, cos_sin, window_size):
         q, k = norm(q), norm(k)

         if IS_ROCM:
-            # PyTorch SDPA on ROCm dispatches to AOTriton
-            # Note: SDPA doesn't support window_size, so SSSL pattern degrades to
-            # full causal attention on all layers
+            # Prefer the ROCm efficient-attention path but keep flash and math as
+            # fallbacks so training still runs across PyTorch/ROCm variants.
             q = q.transpose(1, 2)  # (B, T, H, D) -> (B, H, T, D)
             k = k.transpose(1, 2)
             v = v.transpose(1, 2)
-            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+            y = _run_rocm_sdpa(q, k, v, causal=True)
             y = y.transpose(1, 2).contiguous().view(B, T, -1)
         else:
             y = fa3.flash_attn_func(q, k, v, causal=True, window_size=window_size)
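Context for the transposes (not part of the diff): fa3.flash_attn_func consumes (B, T, H, D) while F.scaled_dot_product_attention expects (B, H, T, D), hence the round-trip around the new helper. Illustrative shape check:

import torch

B, T, H, D = 2, 128, 8, 64                     # illustrative sizes
q = torch.randn(B, T, H, D)                    # flash-attn layout
q_sdpa = q.transpose(1, 2)                     # -> (B, H, T, D) for SDPA
assert q_sdpa.shape == (B, H, T, D)
y = q_sdpa.transpose(1, 2).contiguous().view(B, T, -1)
assert y.shape == (B, T, H * D)                # merged back for the output projection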
@@ -297,7 +309,6 @@ def forward(self, idx, targets=None, reduction='mean'):

         softcap = 15
         logits = self.lm_head(x)
-        logits = logits.float()
         logits = softcap * torch.tanh(logits / softcap)

         if targets is not None:
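For reference (not part of the diff): the softcap bounds logits to (-15, 15) while staying near-identity for small values, and with the .float() upcast removed the tanh now runs in the model's compute dtype. Quick numeric check:

import torch

softcap = 15
logits = torch.tensor([0.5, 10.0, 100.0])
capped = softcap * torch.tanh(logits / softcap)
print(capped)  # ~[0.50, 8.74, 15.00]: small logits pass through, large ones saturate toward ±15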
@@ -320,6 +331,17 @@ def forward(self, idx, targets=None, reduction='mean'):

 _maybe_compile = torch.compile(dynamic=False, fullgraph=True) if not IS_ROCM else lambda fn: fn

+
+def _resolve_window_pattern(pattern):
+    pattern = pattern.upper()
+    if IS_ROCM and any(c != "L" for c in pattern):
+        print(
+            f"ROCm detected: overriding WINDOW_PATTERN={pattern!r} -> 'L' because "
+            "the ROCm SDPA path is full causal attention only."
+        )
+        return "L"
+    return pattern
+
 @_maybe_compile
 def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
     p.mul_(1 - lr_t * wd_t)
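Behavior sketch for the new guard (standalone copy, not part of the diff; 'SSSL' is the sliding-window pattern the removed comment referred to):

IS_ROCM = True  # assumed for this sketch

def resolve_window_pattern(pattern):
    pattern = pattern.upper()
    if IS_ROCM and any(c != "L" for c in pattern):
        print(f"overriding {pattern!r} -> 'L'")
        return "L"
    return pattern

assert resolve_window_pattern("sssl") == "L"  # sliding-window layers collapse to full causal
assert resolve_window_pattern("L") == "L"     # already full causal, passes through silently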
@@ -486,6 +508,8 @@ def step(self):
"A100": 312.0e12,
"B200": 2250.0e12,
# AMD Instinct
"MI350X": 2300.0e12,
"MI355X": 2500.0e12,
"MI300X": 1307.4e12,
"MI308X": 1307.4e12,
"MI325X": 1307.4e12,
@@ -514,7 +538,7 @@ def build_model_config(depth):
     return GPTConfig(
         sequence_len=MAX_SEQ_LEN, vocab_size=vocab_size,
         n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
-        window_pattern=WINDOW_PATTERN,
+        window_pattern=_resolve_window_pattern(WINDOW_PATTERN),
     )

 config = build_model_config(DEPTH)