diff --git a/prepare.py b/prepare.py
index 06bea9165..c44a0948b 100644
--- a/prepare.py
+++ b/prepare.py
@@ -35,7 +35,10 @@
 # Configuration
 # ---------------------------------------------------------------------------
 
-CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "autoresearch")
+CACHE_DIR = os.environ.get(
+    "AUTORESEARCH_CACHE_DIR",
+    os.path.join(os.path.expanduser("~"), ".cache", "autoresearch"),
+)
 DATA_DIR = os.path.join(CACHE_DIR, "data")
 TOKENIZER_DIR = os.path.join(CACHE_DIR, "tokenizer")
 BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
diff --git a/train.py b/train.py
index d3e9700b7..b25251c28 100644
--- a/train.py
+++ b/train.py
@@ -16,6 +16,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.attention import SDPBackend, sdpa_kernel
 
 # Detect AMD ROCm vs NVIDIA CUDA
 IS_ROCM = hasattr(torch.version, 'hip') and torch.version.hip is not None
@@ -29,6 +30,18 @@
 else:
     fa3 = None  # Will use PyTorch SDPA on ROCm (dispatches to AOTriton)
 
+
+def _run_rocm_sdpa(q, k, v, causal):
+    last_error = None
+    for backend in (SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.MATH):
+        try:
+            with sdpa_kernel(backends=[backend]):
+                return F.scaled_dot_product_attention(q, k, v, is_causal=causal)
+        except RuntimeError as exc:
+            last_error = exc
+    raise RuntimeError("ROCm SDPA failed for efficient, flash, and math backends") from last_error
+
+
 from prepare import MAX_SEQ_LEN, TIME_BUDGET, Tokenizer, make_dataloader, evaluate_bpb
 
 # ---------------------------------------------------------------------------
@@ -97,13 +110,12 @@ def forward(self, x, ve, cos_sin, window_size):
         q, k = norm(q), norm(k)
 
         if IS_ROCM:
-            # PyTorch SDPA on ROCm dispatches to AOTriton
-            # Note: SDPA doesn't support window_size, so SSSL pattern degrades to
-            # full causal attention on all layers
+            # Prefer the ROCm efficient-attention path but keep flash and math as
+            # fallbacks so training still runs across PyTorch/ROCm variants.
             q = q.transpose(1, 2)  # (B, T, H, D) -> (B, H, T, D)
             k = k.transpose(1, 2)
             v = v.transpose(1, 2)
-            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+            y = _run_rocm_sdpa(q, k, v, causal=True)
             y = y.transpose(1, 2).contiguous().view(B, T, -1)
         else:
             y = fa3.flash_attn_func(q, k, v, causal=True, window_size=window_size)
@@ -297,7 +309,6 @@ def forward(self, idx, targets=None, reduction='mean'):
 
         softcap = 15
         logits = self.lm_head(x)
-        logits = logits.float()
         logits = softcap * torch.tanh(logits / softcap)
 
         if targets is not None:
@@ -320,6 +331,17 @@
 
 _maybe_compile = torch.compile(dynamic=False, fullgraph=True) if not IS_ROCM else lambda fn: fn
 
+
+def _resolve_window_pattern(pattern):
+    pattern = pattern.upper()
+    if IS_ROCM and any(c != "L" for c in pattern):
+        print(
+            f"ROCm detected: overriding WINDOW_PATTERN={pattern!r} -> 'L' because "
+            "the ROCm SDPA path is full causal attention only."
+        )
+        return "L"
+    return pattern
+
 @_maybe_compile
 def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
     p.mul_(1 - lr_t * wd_t)
@@ -486,6 +508,8 @@ def step(self):
     "A100": 312.0e12,
     "B200": 2250.0e12,
     # AMD Instinct
+    "MI350X": 2300.0e12,
+    "MI355X": 2500.0e12,
     "MI300X": 1307.4e12,
     "MI308X": 1307.4e12,
     "MI325X": 1307.4e12,
@@ -514,7 +538,7 @@ def build_model_config(depth):
     return GPTConfig(
         sequence_len=MAX_SEQ_LEN, vocab_size=vocab_size,
         n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
-        window_pattern=WINDOW_PATTERN,
+        window_pattern=_resolve_window_pattern(WINDOW_PATTERN),
     )
 
 config = build_model_config(DEPTH)