diff --git a/benchmark/scripts/benchmark_model_configs.py b/benchmark/scripts/benchmark_model_configs.py
index da6f0fa0b..0be4d6a00 100644
--- a/benchmark/scripts/benchmark_model_configs.py
+++ b/benchmark/scripts/benchmark_model_configs.py
@@ -291,6 +291,68 @@ def estimate_kernel_peak_memory(probe_fn: Callable[[], torch.Tensor]) -> int:
     return max(1, peak_bytes)
 
 
+def compute_seq_len_sweep_config_with_probe(
+    model_cfg: ModelConfig,
+    probe_fn: Callable[[], torch.Tensor],
+    probe_seq_len: int,
+    probe_batch_size: int = 1,
+    scaling_method: Literal["linear", "quadratic"] = "linear",
+    memory_utilization: float = 0.4,
+    max_seq_len: Optional[int] = None,
+    max_batch_size: int = 32,
+) -> SeqLenSweepConfig:
+    """Compute safe sweep config from a probe, supporting non-linear seq-len scaling.
+
+    Wraps :func:`estimate_kernel_peak_memory` and inverts the memory model
+    according to *scaling_method*. Linear-scaling kernels can keep using
+    :func:`compute_seq_len_sweep_config` directly; this helper exists for
+    kernels whose memory grows non-linearly with seq_len (e.g. attention
+    kernels with O(L^2) scratch).
+
+    Args:
+        model_cfg: Model architecture config.
+        probe_fn: Callable that performs setup, runs a forward pass, and
+            returns an output tensor suitable for ``.backward()``. Same
+            contract as :func:`estimate_kernel_peak_memory`'s *probe_fn*.
+        probe_seq_len: Sequence length used inside *probe_fn*. Required so
+            the inversion can isolate the seq-len-dependent term.
+        probe_batch_size: Batch size used inside *probe_fn*. Defaults to 1.
+        scaling_method: How peak memory scales with seq_len, holding batch
+            size fixed.
+
+            - ``"linear"``: peak ~ B * L. Inversion: ``L_max = usable / (B * c_per_BL)``.
+            - ``"quadratic"``: peak ~ B * L^2. Inversion: ``L_max = sqrt(usable / (B * c_per_BL2))``.
+        memory_utilization: Fraction of total device memory to target (0 to 1).
+        max_seq_len: Hard upper bound for sequence length. Defaults to
+            ``model_cfg.max_position_embeddings``.
+        max_batch_size: Hard upper bound for batch size.
+    """
+    if scaling_method not in ("linear", "quadratic"):
+        raise ValueError(f"scaling_method must be 'linear' or 'quadratic', got {scaling_method!r}")
+
+    peak_bytes = estimate_kernel_peak_memory(probe_fn=probe_fn)
+
+    total_memory_gb = get_total_gpu_memory()
+    usable_bytes = total_memory_gb * (1024**3) * memory_utilization
+
+    if max_seq_len is None:
+        max_seq_len = model_cfg.max_position_embeddings
+
+    batch_size = max(1, min(max_batch_size, probe_batch_size))
+
+    if scaling_method == "linear":
+        c_per_BL = max(1.0, peak_bytes / (probe_batch_size * probe_seq_len))
+        max_seq_len_from_mem = max(1, int(usable_bytes / (batch_size * c_per_BL)))
+    else:
+        c_per_BL2 = max(1.0, peak_bytes / (probe_batch_size * probe_seq_len * probe_seq_len))
+        max_seq_len_from_mem = max(1, int(math.sqrt(usable_bytes / (batch_size * c_per_BL2))))
+
+    seq_len = min(max_seq_len, max_seq_len_from_mem)
+    seq_len = 2 ** int(math.log2(seq_len)) if seq_len >= 1024 else 1024
+
+    return SeqLenSweepConfig(batch_size=batch_size, seq_len=seq_len)
+
+
 def compute_seq_len_sweep_config(
     model_cfg: ModelConfig,
     kernel_bytes_per_token: Optional[int] = None,
diff --git a/benchmark/scripts/benchmark_sparse_multi_token_attention.py b/benchmark/scripts/benchmark_sparse_multi_token_attention.py
index e35fdbba3..646c229f0 100644
--- a/benchmark/scripts/benchmark_sparse_multi_token_attention.py
+++ b/benchmark/scripts/benchmark_sparse_multi_token_attention.py
@@ -7,8 +7,7 @@
 
 from benchmark_model_configs import MODEL_REGISTRY
 from benchmark_model_configs import compute_model_config_sweep_config
-from benchmark_model_configs import compute_seq_len_sweep_config
-from benchmark_model_configs import estimate_kernel_peak_memory
+from benchmark_model_configs import compute_seq_len_sweep_config_with_probe
 from benchmark_model_configs import get_benchmark_model_config
 from utils import QUANTILES
 from utils import SingleBenchmarkRunInput
@@ -329,10 +328,13 @@ def _probe():
         _, _, fwd_fn = _setup_sparse_multi_token_attention(probe_input)
         return fwd_fn()
 
-    peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
-    # Memory scales as O(L^2), so compute bytes per L^2
-    kernel_bpt = peak_bytes // (probe_L * probe_L)
-    config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
+    config = compute_seq_len_sweep_config_with_probe(
+        model_cfg=model,
+        probe_fn=_probe,
+        probe_seq_len=probe_L,
+        probe_batch_size=B,
+        scaling_method="quadratic",
+    )
 
     common_configs = {
         "kernel_name": "sparse_multi_token_attention",