[Benchmark] Add compute_seq_len_sweep_config_with_probe with linear/quadratic scaling support #1218
Open
shivam2199 wants to merge 2 commits into linkedin:main from
…r scaling (linkedin#1200)

Adds a new helper alongside the existing `compute_seq_len_sweep_config` that internalizes both the probe and the seq-len inversion, with a `scaling_method` argument supporting `"linear"` (default) and `"quadratic"`. For O(L^2) kernels, the inversion uses `L_max = sqrt(usable / (B * c_per_BL2))` instead of the linear `max_tokens / batch_size` path.

Migrates `benchmark_sparse_multi_token_attention.py` to the new helper and drops its manual `peak_bytes // (probe_L * probe_L)` workaround. The existing `estimate_kernel_peak_memory` and `compute_seq_len_sweep_config` are unchanged; linear-scaling benchmark callers don't need to migrate.
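The two inversion paths described in the commit message can be sketched as follows. This is an illustrative sketch only, not the helper's actual implementation; the function and parameter names (`invert_seq_len`, `usable_bytes`, `probe_peak_bytes`) are assumptions, and only the `L_max = sqrt(usable / (B * c_per_BL2))` formula comes from the PR text:

```python
import math

def invert_seq_len(usable_bytes, batch_size, probe_seq_len, probe_peak_bytes,
                   scaling_method="linear"):
    """Illustrative inversion of a probe measurement into a max seq len.

    Hypothetical sketch of the logic described in this PR, not the
    actual helper in benchmark_model_configs.py.
    """
    if scaling_method == "linear":
        # peak ~ c * B * L  =>  c = peak / (B * L),  L_max = usable / (B * c)
        c_per_token = probe_peak_bytes / (batch_size * probe_seq_len)
        return int(usable_bytes / (batch_size * c_per_token))
    elif scaling_method == "quadratic":
        # peak ~ c * B * L^2  =>  c = peak / (B * L^2),
        # L_max = sqrt(usable / (B * c)), per the formula in the commit message
        c_per_BL2 = probe_peak_bytes / (batch_size * probe_seq_len ** 2)
        return int(math.sqrt(usable_bytes / (batch_size * c_per_BL2)))
    raise ValueError(f"unknown scaling_method: {scaling_method}")
```

With the same probe measurement, the linear path grows the predicted boundary proportionally to the remaining budget, while the quadratic path grows it only with the square root, which is why the two predictions diverge so sharply in the validation below.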
Summary
Refs #1200. Addresses non-linear memory scaling in benchmark sweep config inference.
The existing `compute_seq_len_sweep_config` inverts memory via `max_tokens = usable_bytes / kernel_bytes_per_token`, which only holds for linear-scaling kernels. For O(L²) kernels (e.g. `benchmark_sparse_multi_token_attention.py`), this overestimates capacity by orders of magnitude. The existing workaround there divides by `probe_L * probe_L`, but the downstream sweep math still treats the result as linear bytes-per-token.

Per discussion on the issue (#1200 (comment)), this PR adds a new helper rather than threading `scaling_method` through the existing function: 16+ benchmark scripts call `estimate_kernel_peak_memory` today, and a wider signature change would conflict with in-flight benchmark refactors (#1199, #1180). Linear-scaling callers are unchanged; only quadratic-scaling benchmarks opt in.

What changed
- `benchmark/scripts/benchmark_model_configs.py`: adds `compute_seq_len_sweep_config_with_probe(model_cfg, probe_fn, probe_seq_len, probe_batch_size=1, scaling_method="linear" | "quadratic", ...)`. Internalizes the probe call + inversion; reuses `estimate_kernel_peak_memory` for the measurement.
- `benchmark/scripts/benchmark_sparse_multi_token_attention.py`: switches the `token_length` sweep mode to the new helper with `scaling_method="quadratic"`, dropping the manual `peak_bytes // (probe_L * probe_L)` workaround.
- `estimate_kernel_peak_memory` and `compute_seq_len_sweep_config` are untouched.

Validation
Hardware: A10G 24GB (g5.xlarge).
Synthetic O(L²) probe (B=2, L=2048, allocates `B * L * L` floats) using the `LLAMA_3_8B` config and `max_seq_len=2**20` to bypass the model cap so the raw inversion is visible. The 8× gap (≈17× before snap-to-power-of-2) demonstrates the inversion difference:
- `linear` claims a sweep at L=65536 fits, when in reality L² at that size would require multiple TBs.
- `quadratic` lands at a realistic L=8192. This matches the issue's premise: for non-linear-scaling kernels, the existing inversion overestimates capacity and would OOM at the predicted boundary.

Testing Done

- `quadratic` predicts L=8192 vs `linear` predicts L=65536 for the same probe (8× separation, scales as expected).
- `benchmark_sparse_multi_token_attention.py` imports + helper resolution verified locally.

cc @Tcc0403
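For reference, the snap-to-power-of-2 step behind the ≈17×-raw vs 8×-final comparison above can be sketched as below. This is an assumed detail shown only for illustration; the sweep's actual snapping code is not part of this PR's quoted text:

```python
def snap_down_to_power_of_two(n: int) -> int:
    # Largest power of 2 <= n (assumes n >= 1). Illustrative only:
    # bit_length() gives the position of the highest set bit, so
    # shifting 1 up by (bit_length - 1) floors n to a power of 2.
    return 1 << (n.bit_length() - 1)

# e.g. a raw inversion result of 10000 snaps down to 8192, while an
# exact power of 2 such as 65536 is left unchanged.
```

Snapping both raw predictions down to powers of 2 is what compresses the raw ≈17× gap between the linear and quadratic inversions into the 8× separation (65536 vs 8192) reported above.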
quadraticpredicts L=8192 vslinearpredicts L=65536 for the same probe (8× separation, scales as expected).benchmark_sparse_multi_token_attention.pyimports + helper resolution verified locally.cc @Tcc0403