62 changes: 62 additions & 0 deletions benchmark/scripts/benchmark_model_configs.py
@@ -291,6 +291,68 @@ def estimate_kernel_peak_memory(probe_fn: Callable[[], torch.Tensor]) -> int:
     return max(1, peak_bytes)
 
 
+def compute_seq_len_sweep_config_with_probe(
+    model_cfg: ModelConfig,
+    probe_fn: Callable[[], torch.Tensor],
+    probe_seq_len: int,
+    probe_batch_size: int = 1,
+    scaling_method: Literal["linear", "quadratic"] = "linear",
+    memory_utilization: float = 0.4,
+    max_seq_len: Optional[int] = None,
+    max_batch_size: int = 32,
+) -> SeqLenSweepConfig:
+    """Compute safe sweep config from a probe, supporting non-linear seq-len scaling.
+
+    Wraps :func:`estimate_kernel_peak_memory` and inverts the memory model
+    according to *scaling_method*. Linear-scaling kernels can keep using
+    :func:`compute_seq_len_sweep_config` directly; this helper exists for
+    kernels whose memory grows non-linearly with seq_len (e.g. attention
+    kernels with O(L^2) scratch).
+
+    Args:
+        model_cfg: Model architecture config.
+        probe_fn: Callable that performs setup, runs a forward pass, and
+            returns an output tensor suitable for ``.backward()``. Same
+            contract as :func:`estimate_kernel_peak_memory`'s *probe_fn*.
+        probe_seq_len: Sequence length used inside *probe_fn*. Required so
+            the inversion can isolate the seq-len-dependent term.
+        probe_batch_size: Batch size used inside *probe_fn*. Defaults to 1.
+        scaling_method: How peak memory scales with seq_len, holding batch
+            size fixed.
+
+            - ``"linear"``: peak ~ B * L. Inversion: ``L_max = usable / (B * c_per_BL)``.
+            - ``"quadratic"``: peak ~ B * L^2. Inversion: ``L_max = sqrt(usable / (B * c_per_BL2))``.
+        memory_utilization: Fraction of total device memory to target (0 to 1).
+        max_seq_len: Hard upper bound for sequence length. Defaults to
+            ``model_cfg.max_position_embeddings``.
+        max_batch_size: Hard upper bound for batch size.
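+
+    Example (illustrative numbers only; ``my_kernel`` is a stand-in, and the
+    probe is assumed to peak at 2 GiB for ``B=1, L=1024`` on an 80 GB
+    device)::
+
+        def probe_fn():
+            x = torch.randn(1, 1024, model_cfg.hidden_size, device="cuda")
+            return my_kernel(x)  # tensor suitable for .backward()
+
+        config = compute_seq_len_sweep_config_with_probe(
+            model_cfg=model_cfg,
+            probe_fn=probe_fn,
+            probe_seq_len=1024,
+            scaling_method="quadratic",
+        )
+        # c_per_BL2 = 2 GiB / (1 * 1024**2) = 2048 bytes
+        # usable    = 80 GB * 0.4 = 32 GiB
+        # L_max     = sqrt(32 GiB / (1 * 2048)) = 4096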
"""
if scaling_method not in ("linear", "quadratic"):
raise ValueError(f"scaling_method must be 'linear' or 'quadratic', got {scaling_method!r}")

peak_bytes = estimate_kernel_peak_memory(probe_fn=probe_fn)

total_memory_gb = get_total_gpu_memory()
usable_bytes = total_memory_gb * (1024**3) * memory_utilization

if max_seq_len is None:
max_seq_len = model_cfg.max_position_embeddings

batch_size = max(1, min(max_batch_size, probe_batch_size))

if scaling_method == "linear":
c_per_BL = max(1.0, peak_bytes / (probe_batch_size * probe_seq_len))
max_seq_len_from_mem = max(1, int(usable_bytes / (batch_size * c_per_BL)))
else:
c_per_BL2 = max(1.0, peak_bytes / (probe_batch_size * probe_seq_len * probe_seq_len))
max_seq_len_from_mem = max(1, int(math.sqrt(usable_bytes / (batch_size * c_per_BL2))))

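+    # Clamp to the model's context limit, then round down to a power of two
+    # for clean sweep points; 1024 is the minimum sweep length.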
+    seq_len = min(max_seq_len, max_seq_len_from_mem)
+    seq_len = 2 ** int(math.log2(seq_len)) if seq_len >= 1024 else 1024
+
+    return SeqLenSweepConfig(batch_size=batch_size, seq_len=seq_len)
+
+
 def compute_seq_len_sweep_config(
     model_cfg: ModelConfig,
     kernel_bytes_per_token: Optional[int] = None,
14 changes: 8 additions & 6 deletions benchmark/scripts/benchmark_sparse_multi_token_attention.py
@@ -7,8 +7,7 @@

 from benchmark_model_configs import MODEL_REGISTRY
 from benchmark_model_configs import compute_model_config_sweep_config
-from benchmark_model_configs import compute_seq_len_sweep_config
-from benchmark_model_configs import estimate_kernel_peak_memory
+from benchmark_model_configs import compute_seq_len_sweep_config_with_probe
 from benchmark_model_configs import get_benchmark_model_config
 from utils import QUANTILES
 from utils import SingleBenchmarkRunInput
@@ -329,10 +328,13 @@ def _probe():
         _, _, fwd_fn = _setup_sparse_multi_token_attention(probe_input)
         return fwd_fn()
 
-    peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
-    # Memory scales as O(L^2), so compute bytes per L^2
-    kernel_bpt = peak_bytes // (probe_L * probe_L)
-    config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
+    config = compute_seq_len_sweep_config_with_probe(
+        model_cfg=model,
+        probe_fn=_probe,
+        probe_seq_len=probe_L,
+        probe_batch_size=B,
+        scaling_method="quadratic",
+    )

     common_configs = {
         "kernel_name": "sparse_multi_token_attention",