diff --git a/benchmark/README.md b/benchmark/README.md
index aefba404b..9597d94e3 100644
--- a/benchmark/README.md
+++ b/benchmark/README.md
@@ -1,6 +1,97 @@
 ## Benchmarking Liger Kernels
 
-Follow these steps to benchmark and visualize kernel performance:
+### Benchmark Framework Overview
+
+The benchmarking system is designed to provide a **consistent, low-boilerplate way** to evaluate kernel performance across:
+
+* Different **model configurations** (e.g., LLaMA and Qwen variants)
+* Different **scaling dimensions** (sequence length `T`, batch size `B`, or total tokens `B * T`)
+* Multiple **kernel providers** (e.g., `liger`, `huggingface`)
+
+#### Core Concepts
+
+1. `setup_fn`
+
+   Defines how to **construct inputs and modules** for a single forward pass.
+
+   * Input: `SingleBenchmarkRunInput`
+   * Output: tuple of tensors / modules
+
+   ```python
+   def _setup_fn(input: SingleBenchmarkRunInput) -> Tuple[Any, ...]:
+       x = ...
+       layer = ...
+       return x, layer
+   ```
+
+2. Benchmark Function Builders
+
+   Reusable helpers that generate benchmark functions, handling:
+
+   * forward / backward / full modes
+   * timing and memory measurement
+
+   ```python
+   build_speed_bench_fn(setup_fn)
+   build_memory_bench_fn(setup_fn)
+   ```
+
+3. Sweep Builders
+
+   (a) `build_model_config_sweep`
+
+   * Sweeps across **model configurations** (e.g. hidden size, dtype, vocab size)
+   * Keeps total tokens (`B * T`) approximately constant
+   * Automatically derives a suitable `(B, T)` that will not cause OOM under the given token budget
+   * `probe_dim` must align with how `input.x` is interpreted in `setup_fn`
+
+   ```python
+   common_configs = build_model_config_sweep(
+       kernel_name=...,
+       all_model_configs=...,
+       setup_fn=...,
+       model_keys=[...],
+       probe_dim="T",  # Literal["T", "B", "BT"]
+   )
+   ```
+
+   (b) `build_token_length_sweep`
+
+   * Sweeps along a **chosen scaling dimension**:
+     * `"T"` → sequence length
+     * `"B"` → batch size
+     * `"BT"` → total tokens
+   * Uses a **single fixed model configuration**
+   * Maintains a consistent memory model via bytes-per-token estimation
+   * `scale_dim` must align with how `input.x` is interpreted in `setup_fn`
+
+   ```python
+   common_configs = build_token_length_sweep(
+       kernel_name=...,
+       probe_x=...,
+       model=...,
+       setup_fn=...,
+       model_keys=[...],
+       scale_dim="T",  # Literal["T", "B", "BT"]
+   )
+   ```
+
+4. `model_keys` and `extra_configs`
+
+   * `model_keys`: attributes pulled from `ModelConfig`
+     * e.g. `["hidden_size", "dtype"]`
+   * `extra_configs`: static overrides
+     * e.g. `{"eps": 1e-6}`
+
+   These form `extra_benchmark_config`, passed into `setup_fn`. A complete minimal script combining these pieces is sketched after the workflow example below.
+
+### Benchmark Workflow
 
 1. Create a benchmark script
    - Add your script under `benchmark/scripts/`
@@ -12,7 +103,8 @@ Follow these steps to benchmark and visualize kernel performance:
    Example: Benchmarking KTO Loss
    ```bash
    cd benchmark
-   python scripts/benchmark_kto_loss.py
+   python scripts/benchmark_kto_loss.py --sweep-mode model_config [--bt 2048]
+   python scripts/benchmark_kto_loss.py [--sweep-mode token_length] [--model llama_3_8b]
    ```
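+
+   For reference, here is a minimal end-to-end sketch of a new benchmark
+   script (a hypothetical `scripts/benchmark_my_kernel.py`; the `LayerNorm`
+   module is a stand-in -- real scripts select a `liger` or `huggingface`
+   implementation based on `input.kernel_provider`, as in
+   `benchmark_layer_norm.py`):
+
+   ```python
+   import torch
+
+   from benchmark_model_configs import MODEL_REGISTRY
+   from benchmark_model_configs import build_model_config_sweep
+   from utils import SingleBenchmarkRunInput
+   from utils import build_speed_bench_fn
+   from utils import parse_benchmark_script_args
+   from utils import run_benchmarks
+
+   from liger_kernel.utils import infer_device
+
+   device = infer_device()
+
+
+   def setup_my_kernel(input: SingleBenchmarkRunInput):
+       cfg = input.extra_benchmark_config
+       if isinstance(input.x, str):  # model-config sweep: x is a model name
+           model_cfg = MODEL_REGISTRY[input.x]
+           n_rows = cfg["seq_len"]
+           hidden_size = model_cfg.hidden_size
+           dtype = model_cfg.dtype
+       else:  # probing / token-length sweep: x is an int along probe_dim
+           n_rows = input.x
+           hidden_size = cfg["hidden_size"]
+           dtype = cfg["dtype"]
+       x = torch.randn(n_rows, hidden_size, device=device, dtype=dtype, requires_grad=True)
+       layer = torch.nn.LayerNorm(hidden_size).to(device).to(dtype)  # stand-in module
+       return x, layer  # the default forward_fn runs layer(x)
+
+
+   if __name__ == "__main__":
+       args = parse_benchmark_script_args()
+       common_configs = build_model_config_sweep(
+           kernel_name="my_kernel",
+           setup_fn=setup_my_kernel,
+           model_keys=["hidden_size", "dtype"],
+           probe_dim="BT",  # must match how setup_my_kernel reads input.x
+           bt=args.bt,
+           overwrite=args.overwrite,
+       )
+       common_configs["kernel_providers"] = ["liger", "huggingface"]
+       run_benchmarks(
+           bench_test_fn=build_speed_bench_fn(setup_my_kernel),
+           kernel_operation_modes=["forward", "backward", "full"],
+           metric_name="speed",
+           metric_unit="ms",
+           **common_configs,
+       )
+   ```
+
 3. 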
Visualize results diff --git a/benchmark/scripts/benchmark_cpo_loss.py b/benchmark/scripts/benchmark_cpo_loss.py index 07501275e..f2d89f5f1 100644 --- a/benchmark/scripts/benchmark_cpo_loss.py +++ b/benchmark/scripts/benchmark_cpo_loss.py @@ -3,17 +3,14 @@ import sys import torch -import triton from benchmark_model_configs import MODEL_REGISTRY -from benchmark_model_configs import compute_model_config_sweep_config -from benchmark_model_configs import compute_seq_len_sweep_config -from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import build_model_config_sweep +from benchmark_model_configs import build_token_length_sweep from benchmark_model_configs import get_benchmark_model_config -from utils import QUANTILES from utils import SingleBenchmarkRunInput -from utils import SingleBenchmarkRunOutput -from utils import _test_memory +from utils import build_memory_bench_fn +from utils import build_speed_bench_fn from utils import parse_benchmark_script_args from utils import run_benchmarks @@ -24,17 +21,24 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) -def _setup_cpo_loss(input: SingleBenchmarkRunInput): +def setup_cpo_loss(input: SingleBenchmarkRunInput): """Create input tensors and CPO loss from benchmark config.""" from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO cfg = input.extra_benchmark_config - H = cfg["hidden_size"] - V = cfg["vocab_size"] - dtype = cfg["dtype"] - B = input.x T = cfg["T"] + if isinstance(input.x, str): + model_cfg = MODEL_REGISTRY[input.x] + H = model_cfg.hidden_size + V = model_cfg.vocab_size + dtype = model_cfg.dtype + B = cfg["bsz"] + else: + B = input.x + H = cfg["hidden_size"] + V = cfg["vocab_size"] + dtype = cfg["dtype"] _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) target = torch.randint(V, (B, T), dtype=torch.long, device=device) @@ -46,253 +50,56 @@ def _setup_cpo_loss(input: SingleBenchmarkRunInput): else: raise ValueError(f"Invalid provider: {input.kernel_provider} for CPOLoss") - fwd_fn = lambda: loss_module(_input, target)[0] + fwd_fn = lambda x: loss_module(x, target)[0] return _input, fwd_fn -def bench_speed_cpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - _input, fwd_fn = _setup_cpo_loss(input) - mode = input.kernel_operation_mode - - if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - fwd_fn, - rep=100, - quantiles=QUANTILES, - ) - elif mode == "backward": - y = fwd_fn() - ms_50, ms_20, ms_80 = triton.testing.do_bench( - lambda: y.backward(retain_graph=True), - grad_to_none=[_input], - rep=100, - quantiles=QUANTILES, - ) - elif mode == "full": - - def full(): - y = fwd_fn() - y.backward() - - ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) - else: - raise ValueError(f"Unsupported mode: {mode}") - - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, - ) - - -def bench_memory_cpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - _input, fwd_fn = _setup_cpo_loss(input) - - def full(): - y = fwd_fn() - y.backward() - - mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) - - -def _resolve_model_config_cpo_loss(input: SingleBenchmarkRunInput): - """Resolve model-config-sweep input into standard setup args.""" - cfg = input.extra_benchmark_config - 
model_info = cfg["model_configs"][input.x] - return _setup_cpo_loss( - SingleBenchmarkRunInput( - x=cfg["B"], - kernel_provider=input.kernel_provider, - extra_benchmark_config={ - "hidden_size": model_info["hidden_size"], - "vocab_size": model_info["vocab_size"], - "dtype": model_info["dtype"], - "T": cfg["T"], - }, - ) - ) - - -def bench_speed_cpo_loss_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - _input, fwd = _resolve_model_config_cpo_loss(input) - mode = input.kernel_operation_mode - - if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - fwd, - rep=100, - quantiles=QUANTILES, - ) - elif mode == "backward": - y = fwd() - ms_50, ms_20, ms_80 = triton.testing.do_bench( - lambda: y.backward(retain_graph=True), - grad_to_none=[_input], - rep=100, - quantiles=QUANTILES, - ) - elif mode == "full": - - def full(): - y = fwd() - y.backward() - - ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, - rep=100, - quantiles=QUANTILES, - ) - else: - raise ValueError(f"Unsupported mode: {mode}") - - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, - ) - - -def bench_memory_cpo_loss_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - _input, fwd_fn = _resolve_model_config_cpo_loss(input) - - def full(): - y = fwd_fn() - y.backward() - - mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) - - if __name__ == "__main__": args = parse_benchmark_script_args() + T = 1024 if args.sweep_mode == "model_config": - all_model_configs = list(MODEL_REGISTRY.values()) - T = 1024 - - def _probe_factory(model_cfg, probe_bt): - def _probe(): - B = max(1, probe_bt // T) - probe_input = SingleBenchmarkRunInput( - x=B, - kernel_provider="huggingface", - extra_benchmark_config={ - "hidden_size": model_cfg.hidden_size, - "vocab_size": model_cfg.vocab_size, - "dtype": model_cfg.dtype, - "T": T, - }, - ) - _, fwd_fn = _setup_cpo_loss(probe_input) - return fwd_fn() - - return _probe - - sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) - - model_configs_info = { - cfg.name: { - "hidden_size": cfg.hidden_size, - "vocab_size": cfg.vocab_size, - "dtype": cfg.dtype, - } - for cfg in sweep.model_configs - } - - B = max(1, sweep.bt // T) - - common_configs = { - "kernel_name": "fused_linear_cpo_loss", - "x_name": "model_config", - "x_label": "model configuration", - "x_values": [cfg.name for cfg in sweep.model_configs], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "model_configs": model_configs_info, - "B": B, - "T": T, - } - ], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_cpo_loss_model_config, - kernel_operation_modes=["forward", "backward", "full"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_cpo_loss_model_config, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs, + common_configs = build_model_config_sweep( + kernel_name="cpo_loss", + setup_fn=setup_cpo_loss, + model_keys=["hidden_size", "vocab_size", "dtype"], + extra_configs={"T": T}, + probe_dim="B", + probe_provider="huggingface", + bt=args.bt, + overwrite=args.overwrite, ) else: model = get_benchmark_model_config(args.model) - T = 1024 - probe_bt = 1024 - - def _probe(): - B = probe_bt // T - probe_input = 
SingleBenchmarkRunInput(
-                x=B,
-                kernel_provider="huggingface",
-                extra_benchmark_config={
-                    "hidden_size": model.hidden_size,
-                    "vocab_size": model.vocab_size,
-                    "dtype": model.dtype,
-                    "T": T,
-                },
-            )
-            _, fwd_fn = _setup_cpo_loss(probe_input)
-            return fwd_fn()
-
-        peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
-        kernel_bpt = peak_bytes // probe_bt
-
-        config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
-        common_configs = {
-            "kernel_name": "fused_linear_cpo_loss",
-            "x_name": "B",
-            "x_label": "Batch Size (B)",
-            "x_values": [2**i for i in range(1, int(math.log2(max(2, config.batch_size * config.seq_len // T))) + 1)],
-            "kernel_providers": ["liger", "huggingface"],
-            "extra_benchmark_configs": [
-                {
-                    "hidden_size": model.hidden_size,
-                    "vocab_size": model.vocab_size,
-                    "dtype": model.dtype,
-                    "T": T,
-                }
+        common_configs = build_token_length_sweep(
+            kernel_name="cpo_loss",
+            probe_x=1,
+            model=model,
+            setup_fn=setup_cpo_loss,
+            model_keys=["hidden_size", "vocab_size", "dtype"],
+            extra_configs={"T": T},
+            scale_dim="B",
+            x_label="Batch Size (B)",
+            probe_provider="huggingface",
+            x_values_fn=lambda config: [
+                2**i for i in range(1, int(math.log2(max(2, config.batch_size * config.seq_len // T))) + 1)
             ],
-            "overwrite": args.overwrite,
-        }
-
-        run_benchmarks(
-            bench_test_fn=bench_speed_cpo_loss,
-            kernel_operation_modes=["forward", "backward", "full"],
-            metric_name="speed",
-            metric_unit="ms",
-            **common_configs,
-        )
-        run_benchmarks(
-            bench_test_fn=bench_memory_cpo_loss,
-            kernel_operation_modes=["full"],
-            metric_name="memory",
-            metric_unit="MB",
-            **common_configs,
+            overwrite=args.overwrite,
         )
+
+    common_configs["kernel_providers"] = ["liger", "huggingface"]
+
+    run_benchmarks(
+        bench_test_fn=build_speed_bench_fn(setup_cpo_loss),
+        kernel_operation_modes=["forward", "backward", "full"],
+        metric_name="speed",
+        metric_unit="ms",
+        **common_configs,
+    )
+    run_benchmarks(
+        bench_test_fn=build_memory_bench_fn(setup_cpo_loss),
+        kernel_operation_modes=["full"],
+        metric_name="memory",
+        metric_unit="MB",
+        **common_configs,
+    )
diff --git a/benchmark/scripts/benchmark_layer_norm.py b/benchmark/scripts/benchmark_layer_norm.py
index 456e62b69..7e8c5c2a1 100644
--- a/benchmark/scripts/benchmark_layer_norm.py
+++ b/benchmark/scripts/benchmark_layer_norm.py
@@ -1,18 +1,14 @@
-import math
-
 import torch
 from benchmark_model_configs import MODEL_REGISTRY
-from benchmark_model_configs import compute_model_config_sweep_config
-from benchmark_model_configs import compute_seq_len_sweep_config
-from benchmark_model_configs import estimate_kernel_peak_memory
+from benchmark_model_configs import build_model_config_sweep
+from benchmark_model_configs import build_token_length_sweep
 from benchmark_model_configs import get_benchmark_model_config
 from utils import SingleBenchmarkRunInput
-from utils import SingleBenchmarkRunOutput
+from utils import build_memory_bench_fn
+from utils import build_speed_bench_fn
 from utils import parse_benchmark_script_args
 from utils import run_benchmarks
-from utils import run_memory_benchmark
-from utils import run_speed_benchmark
 
 from liger_kernel.transformers.layer_norm import LigerLayerNorm
 from liger_kernel.utils import infer_device
@@ -20,175 +16,84 @@
 
 device = infer_device()
 
-def _setup_layer_norm(input: SingleBenchmarkRunInput):
+def setup_layer_norm(input: SingleBenchmarkRunInput):
     """Create input tensor and LayerNorm layer from benchmark config."""
     cfg = input.extra_benchmark_config
-    hidden_size = cfg["hidden_size"]
+    if 
isinstance(input.x, str): + model_cfg = MODEL_REGISTRY[input.x] + seq_len = cfg["seq_len"] + hidden_size = model_cfg.hidden_size + dtype = model_cfg.dtype + else: + seq_len = input.x + hidden_size = cfg["hidden_size"] + dtype = cfg["dtype"] + eps = cfg["eps"] x = torch.randn( - input.x, + seq_len, hidden_size, device=device, - dtype=cfg["dtype"], + dtype=dtype, requires_grad=True, ) if input.kernel_provider == "liger": - layer = LigerLayerNorm(hidden_size=hidden_size, eps=eps).to(device) + layer = LigerLayerNorm(hidden_size=hidden_size, eps=eps).to(device).to(dtype) elif input.kernel_provider == "huggingface": - layer = torch.nn.LayerNorm(hidden_size, eps=eps).to(device) + layer = torch.nn.LayerNorm(hidden_size, eps=eps).to(device).to(dtype) else: raise ValueError(f"Invalid provider: {input.kernel_provider} for LayerNorm") return x, layer -def bench_speed_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - x, layer = _setup_layer_norm(input) - return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) - - -def bench_memory_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - x, layer = _setup_layer_norm(input) - return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) - - -def _resolve_model_config_layer_norm(input: SingleBenchmarkRunInput): - """Resolve model-config-sweep input into standard setup args.""" - cfg = input.extra_benchmark_config - model_info = cfg["model_configs"][input.x] - return _setup_layer_norm( - SingleBenchmarkRunInput( - x=cfg["BT"], - kernel_provider=input.kernel_provider, - extra_benchmark_config={ - "hidden_size": model_info["hidden_size"], - "dtype": model_info["dtype"], - "eps": cfg["eps"], - }, - ) - ) - - -def bench_speed_layer_norm_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - x, layer = _resolve_model_config_layer_norm(input) - return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) - - -def bench_memory_layer_norm_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - x, layer = _resolve_model_config_layer_norm(input) - return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) - - if __name__ == "__main__": args = parse_benchmark_script_args() if args.sweep_mode == "model_config": - all_model_configs = list(MODEL_REGISTRY.values()) - - def _probe_factory(model_cfg, probe_bt): - def _probe(): - probe_input = SingleBenchmarkRunInput( - x=probe_bt, - kernel_provider="huggingface", - extra_benchmark_config={ - "hidden_size": model_cfg.hidden_size, - "dtype": model_cfg.dtype, - "eps": 1e-6, - }, - ) - x, layer = _setup_layer_norm(probe_input) - return layer(x) - - return _probe - - sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) - - model_configs_info = { - cfg.name: { - "hidden_size": cfg.hidden_size, - "dtype": cfg.dtype, - } - for cfg in sweep.model_configs - } - - common_configs = { - "kernel_name": "layer_norm", - "x_name": "model_config", - "x_label": "model configuration", - "x_values": [cfg.name for cfg in sweep.model_configs], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "model_configs": model_configs_info, - "BT": sweep.bt, - "eps": 1e-6, - } - ], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_layer_norm_model_config, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - 
run_benchmarks(
-            bench_test_fn=bench_memory_layer_norm_model_config,
-            kernel_operation_modes=["full", "forward", "backward"],
-            metric_name="memory",
-            metric_unit="MB",
-            **common_configs,
+        common_configs = build_model_config_sweep(
+            kernel_name="layer_norm",
+            setup_fn=setup_layer_norm,
+            model_keys=["hidden_size", "dtype"],
+            extra_configs={
+                "eps": 1e-6,
+            },
+            probe_dim="BT",
+            probe_provider="huggingface",
+            bt=args.bt,
+            overwrite=args.overwrite,
         )
+
     else:
         model = get_benchmark_model_config(args.model)
-        probe_bt = 1024
-
-        def _probe():
-            probe_input = SingleBenchmarkRunInput(
-                x=probe_bt,
-                kernel_provider="huggingface",
-                extra_benchmark_config={
-                    "hidden_size": model.hidden_size,
-                    "dtype": model.dtype,
-                    "eps": 1e-6,
-                },
-            )
-            x, layer = _setup_layer_norm(probe_input)
-            return layer(x)
-
-        peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
-        kernel_bpt = peak_bytes // probe_bt
-
-        config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
+        probe_seq_len = 1024
+
+        common_configs = build_token_length_sweep(
+            kernel_name="layer_norm",
+            probe_x=probe_seq_len,
+            model=model,
+            setup_fn=setup_layer_norm,
+            model_keys=["hidden_size", "dtype"],
+            extra_configs={
+                "eps": 1e-6,
+            },
+            scale_dim="BT",
+            x_label="B * T",
+            probe_provider="huggingface",
+            overwrite=args.overwrite,
+        )
 
-        common_configs = {
-            "kernel_name": "layer_norm",
-            "x_name": "BT",
-            "x_label": "B * T",
-            "x_values": [2**i for i in range(10, int(math.log2(config.batch_size * config.seq_len)) + 1)],
-            "kernel_providers": ["liger", "huggingface"],
-            "extra_benchmark_configs": [
-                {
-                    "hidden_size": model.hidden_size,
-                    "dtype": model.dtype,
-                    "eps": 1e-6,
-                }
-            ],
-            "overwrite": args.overwrite,
-        }
+    common_configs["kernel_providers"] = ["liger", "huggingface"]
 
-        run_benchmarks(
-            bench_test_fn=bench_speed_layer_norm,
-            kernel_operation_modes=["full", "forward", "backward"],
-            metric_name="speed",
-            metric_unit="ms",
-            **common_configs,
-        )
-        run_benchmarks(
-            bench_test_fn=bench_memory_layer_norm,
-            kernel_operation_modes=["full", "forward", "backward"],
-            metric_name="memory",
-            metric_unit="MB",
-            **common_configs,
-        )
+    run_benchmarks(
+        bench_test_fn=build_speed_bench_fn(setup_layer_norm),
+        kernel_operation_modes=["full", "forward", "backward"],
+        metric_name="speed",
+        metric_unit="ms",
+        **common_configs,
+    )
+    run_benchmarks(
+        bench_test_fn=build_memory_bench_fn(setup_layer_norm),
+        kernel_operation_modes=["full", "forward", "backward"],
+        metric_name="memory",
+        metric_unit="MB",
+        **common_configs,
+    )
diff --git a/benchmark/scripts/benchmark_model_configs.py b/benchmark/scripts/benchmark_model_configs.py
index b7137942b..da6f0fa0b 100644
--- a/benchmark/scripts/benchmark_model_configs.py
+++ b/benchmark/scripts/benchmark_model_configs.py
@@ -26,14 +26,19 @@
 import math
 from dataclasses import dataclass
+from typing import Any
 from typing import Callable
 from typing import Dict
 from typing import List
+from typing import Literal
 from typing import Optional
 from typing import Tuple
 
 import torch
 
+from utils import SingleBenchmarkRunInput
+from utils import default_forward_fn
+
 from liger_kernel.utils import get_total_gpu_memory
 from liger_kernel.utils import infer_device
 
@@ -341,7 +346,7 @@ def compute_seq_len_sweep_config(
 
 def compute_model_config_sweep_config(
     model_configs: List[ModelConfig],
-    probe_fn_factory: Callable[[ModelConfig, int], Callable[[], torch.Tensor]],
-    bt: int = 2048,
+    probe_fn_factory: Callable[[ModelConfig], Callable[[], torch.Tensor]],
+    bt: int = 1024,
     memory_utilization: float = 0.4,
 ) -> ModelConfigSweepConfig:
     """Find safe 
(batch_size, seq_len) that works across all model configs.
@@ -351,7 +356,7 @@ def compute_model_config_sweep_config(
     Args:
         model_configs: Model configs to benchmark.
-        probe_fn_factory: Factory ``(model_cfg, probe_seq_len) -> probe_fn``.
+        probe_fn_factory: Factory ``(model_cfg) -> probe_fn``.
             The returned probe_fn should perform setup + forward pass and
             return a tensor suitable for ``.backward()``, same contract as
             :func:`estimate_kernel_peak_memory`'s *probe_fn*.
@@ -365,7 +370,7 @@ def compute_model_config_sweep_config(
     max_bytes_per_token = 0
     for model_cfg in model_configs:
-        probe_fn = probe_fn_factory(model_cfg, probe_seq_len)
+        probe_fn = probe_fn_factory(model_cfg)
         peak_bytes = estimate_kernel_peak_memory(probe_fn)
         bpt = max(1, peak_bytes // probe_seq_len)
         max_bytes_per_token = max(max_bytes_per_token, bpt)
@@ -383,3 +388,200 @@
         batch_size=batch_size,
         seq_len=seq_len,
     )
+
+
+def build_extra_config(
+    model: ModelConfig,
+    model_keys: List[str],
+    extra_configs: Optional[Dict] = None,
+) -> Dict:
+    """Construct the extra_benchmark_config dict.
+
+    Args:
+        model: The model configuration object.
+        model_keys: List of attribute names to read from `model`
+            (e.g. ["hidden_size", "dtype"]).
+        extra_configs: Optional dictionary of additional key/value pairs
+            that override or extend the extracted attributes.
+    """
+    extra_configs = extra_configs or {}
+    cfg = {k: getattr(model, k) for k in model_keys}
+    cfg.update(extra_configs)
+    return cfg
+
+
+def build_model_config_sweep(
+    kernel_name: str,
+    all_model_configs: Optional[List[ModelConfig]] = None,
+    setup_fn: Optional[Callable[[SingleBenchmarkRunInput], Tuple[Any, ...]]] = None,
+    model_keys: Optional[List[str]] = None,
+    probe_dim: Literal["T", "B", "BT"] = "T",
+    forward_fn: Callable[..., torch.Tensor] = default_forward_fn,
+    probe_provider: str = "huggingface",
+    extra_configs: Optional[Dict] = None,
+    bt: int = 2048,
+    overwrite: bool = False,
+) -> Dict:
+    """Build benchmark config dict for model-config sweep.
+
+    Args:
+        kernel_name: Name of the kernel being benchmarked.
+        all_model_configs: List of model configurations to sweep over.
+            Defaults to all entries in ``MODEL_REGISTRY``.
+        setup_fn: Function that prepares inputs and modules given a
+            `SingleBenchmarkRunInput`. Returns a tuple of objects consumed
+            by `forward_fn`.
+        model_keys: List of attributes to extract from each `ModelConfig`
+            and include in `extra_benchmark_config`.
+        probe_dim: How ``input.x`` is interpreted during memory probing
+            ("T", "B", or "BT"); must match how `setup_fn` consumes
+            ``input.x``.
+        forward_fn: Function that executes the kernel given the outputs of
+            `setup_fn`. Defaults to `(x, layer) -> layer(x)`.
+        probe_provider: Kernel provider used during memory probing.
+        extra_configs: Optional static overrides merged into the benchmark config.
+        bt: Target total tokens (batch_size * seq_len) used to derive the sweep.
+        overwrite: Whether to overwrite existing benchmark results.
+
+    Returns:
+        A dictionary consumable by `run_benchmarks`. 
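+
+    Example:
+        A typical call, mirroring ``benchmark_layer_norm.py`` in this repo::
+
+            common_configs = build_model_config_sweep(
+                kernel_name="layer_norm",
+                setup_fn=setup_layer_norm,
+                model_keys=["hidden_size", "dtype"],
+                extra_configs={"eps": 1e-6},
+                probe_dim="BT",
+                probe_provider="huggingface",
+                bt=args.bt,
+                overwrite=args.overwrite,
+            )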
+ """ + + if all_model_configs is None: + all_model_configs = list(MODEL_REGISTRY.values()) + + def probe_fn_factory(model_cfg): + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=1 if probe_dim == "B" else bt, + kernel_provider=probe_provider, + extra_benchmark_config=build_extra_config( + model_cfg, + model_keys, + extra_configs=extra_configs, + ), + ) + setup_out = setup_fn(probe_input) + + return forward_fn(*setup_out) + + return _probe + + sweep = compute_model_config_sweep_config( + all_model_configs, + probe_fn_factory=probe_fn_factory, + bt=bt, + ) + + base_config = {"bsz": sweep.batch_size, "seq_len": sweep.seq_len} + + if extra_configs: + base_config.update(extra_configs) + + return { + "kernel_name": kernel_name, + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "extra_benchmark_configs": [base_config], + "overwrite": overwrite, + } + + +def build_token_length_sweep( + kernel_name: str, + probe_x: int, + model: ModelConfig, + setup_fn: Callable[[SingleBenchmarkRunInput], Tuple[Any, ...]], + model_keys: List[str], + extra_configs: Optional[Dict] = None, + scale_dim: Literal["T", "B", "BT"] = "T", + forward_fn: Callable[..., torch.Tensor] = default_forward_fn, + probe_provider: str = "huggingface", + x_label: str = "sequence length", + x_values_fn: Optional[Callable[[SeqLenSweepConfig], List[int]]] = None, + overwrite: bool = False, +) -> Dict: + """Build benchmark config dict for token-length sweep. + + Args: + kernel_name: Name of the kernel being benchmarked. + probe_x: Value of x passed to setup_fn during probing. + model: Model configuration used for the sweep. + setup_fn: Function that prepares inputs and modules given a + `SingleBenchmarkRunInput`. Returns a tuple of objects consumed + by `forward_fn`. + model_keys: List of attributes to extract from `model` and include + in `extra_benchmark_config`. + extra_configs: Optional static overrides merged into the config. + forward_fn: Function that executes the kernel given the outputs of + `setup_fn`. Defaults to `(x, layer) -> layer(x)`. + probe_provider: Kernel provider used during memory probing. + scale_dim: Dimension along which to scale the sweep (e.g. "T", "B", or "BT"). + x_label: Label for the x-axis (e.g. "sequence length" or "batch size"). + x_values_fn: Optional function mapping `SeqLenSweepConfig` to a list + of x values. Defaults to powers of 2 up to max seq_len. + overwrite: Whether to overwrite existing benchmark results. + + Returns: + A dictionary consumable by `run_benchmarks`. 
+ """ + extra_configs = extra_configs or {} + + def probe_fn(): + probe_input = SingleBenchmarkRunInput( + x=probe_x, + kernel_provider=probe_provider, + extra_benchmark_config=build_extra_config( + model, + model_keys, + extra_configs=extra_configs, + ), + ) + setup_out = setup_fn(probe_input) + return forward_fn(*setup_out) + + peak_bytes = estimate_kernel_peak_memory(probe_fn=probe_fn) + # --------------------------------------- + # derive total tokens (BT) based on scale_dim + # --------------------------------------- + if scale_dim == "T": + B = extra_configs.get("B", 1) + probe_bt = probe_x * B + + elif scale_dim == "B": + T = extra_configs.get("T") + if T is None: + raise ValueError("For B sweep, extra_configs['T'] must be provided") + probe_bt = probe_x * T + + elif scale_dim == "BT": + probe_bt = probe_x + + else: + raise ValueError(f"Unsupported scale_dim: {scale_dim}") + + kernel_bpt = max(1, peak_bytes // probe_bt) + + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + if x_values_fn is None: + if scale_dim == "T": + x_values_fn = lambda cfg: [2**i for i in range(10, int(math.log2(cfg.seq_len)) + 1)] + elif scale_dim == "B": + x_values_fn = lambda cfg: [2**i for i in range(0, int(math.log2(cfg.batch_size)) + 1)] + elif scale_dim == "BT": + x_values_fn = lambda cfg: [2**i for i in range(10, int(math.log2(cfg.seq_len * cfg.batch_size)) + 1)] + + return { + "kernel_name": kernel_name, + "x_name": scale_dim, + "x_label": x_label, + "x_values": x_values_fn(config), + "extra_benchmark_configs": [ + build_extra_config( + model, + model_keys, + extra_configs=extra_configs, + ) + ], + "overwrite": overwrite, + } diff --git a/benchmark/scripts/benchmark_swiglu.py b/benchmark/scripts/benchmark_swiglu.py index dc34fd60d..c54f96a8a 100644 --- a/benchmark/scripts/benchmark_swiglu.py +++ b/benchmark/scripts/benchmark_swiglu.py @@ -1,20 +1,16 @@ -import math - import torch from benchmark_model_configs import MODEL_REGISTRY -from benchmark_model_configs import compute_model_config_sweep_config -from benchmark_model_configs import compute_seq_len_sweep_config -from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import build_model_config_sweep +from benchmark_model_configs import build_token_length_sweep from benchmark_model_configs import get_benchmark_model_config from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaMLP from utils import SingleBenchmarkRunInput -from utils import SingleBenchmarkRunOutput +from utils import build_memory_bench_fn +from utils import build_speed_bench_fn from utils import parse_benchmark_script_args from utils import run_benchmarks -from utils import run_memory_benchmark -from utils import run_speed_benchmark from liger_kernel.transformers.swiglu import LigerSwiGLUMLP from liger_kernel.utils import infer_device @@ -22,189 +18,89 @@ device = infer_device() -def _setup_swiglu(input: SingleBenchmarkRunInput): +def setup_swiglu(input: SingleBenchmarkRunInput): """Create input tensor and SwiGLU layer from benchmark config.""" cfg = input.extra_benchmark_config + if isinstance(input.x, str): + model_cfg = MODEL_REGISTRY[input.x] + seq_len = cfg["seq_len"] + hidden_size = model_cfg.hidden_size + intermediate_size = model_cfg.intermediate_size + dtype = model_cfg.dtype + else: + seq_len = input.x + hidden_size = cfg["hidden_size"] + intermediate_size = cfg["intermediate_size"] + dtype = cfg["dtype"] + llama_config = 
LlamaConfig( - hidden_size=cfg["hidden_size"], - intermediate_size=cfg["intermediate_size"], + hidden_size=hidden_size, + intermediate_size=intermediate_size, hidden_act=cfg["hidden_act"], ) x = torch.randn( cfg["bsz"], - input.x, - cfg["hidden_size"], + seq_len, + hidden_size, device=device, - dtype=cfg["dtype"], + dtype=dtype, requires_grad=True, ) if input.kernel_provider == "liger": - layer = LigerSwiGLUMLP(config=llama_config).to(device).to(cfg["dtype"]) + layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype) elif input.kernel_provider == "huggingface": - layer = LlamaMLP(config=llama_config).to(device).to(cfg["dtype"]) + layer = LlamaMLP(config=llama_config).to(device).to(dtype) else: raise ValueError(f"Invalid provider: {input.kernel_provider} for SwiGLU") return x, layer -def bench_speed_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - x, layer = _setup_swiglu(input) - return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) - - -def bench_memory_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - x, layer = _setup_swiglu(input) - return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) - - -def _resolve_model_config_swiglu(input: SingleBenchmarkRunInput): - """Resolve model-config-sweep input into standard setup args.""" - cfg = input.extra_benchmark_config - model_info = cfg["model_configs"][input.x] - return _setup_swiglu( - SingleBenchmarkRunInput( - x=cfg["seq_len"], - kernel_provider=input.kernel_provider, - extra_benchmark_config={ - "bsz": cfg["bsz"], - "hidden_size": model_info["hidden_size"], - "intermediate_size": model_info["intermediate_size"], - "hidden_act": cfg["hidden_act"], - "dtype": model_info["dtype"], - }, - ) - ) - - -def bench_speed_swiglu_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - x, layer = _resolve_model_config_swiglu(input) - return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) - - -def bench_memory_swiglu_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - x, layer = _resolve_model_config_swiglu(input) - return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) - - if __name__ == "__main__": args = parse_benchmark_script_args() if args.sweep_mode == "model_config": - all_model_configs = list(MODEL_REGISTRY.values()) - - def _probe_factory(model_cfg, probe_seq_len): - def _probe(): - probe_input = SingleBenchmarkRunInput( - x=probe_seq_len, - kernel_provider="huggingface", - extra_benchmark_config={ - "bsz": 1, - "hidden_size": model_cfg.hidden_size, - "intermediate_size": model_cfg.intermediate_size, - "hidden_act": "silu", - "dtype": model_cfg.dtype, - }, - ) - x, layer = _setup_swiglu(probe_input) - return layer(x) - - return _probe - - sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) - - model_configs_info = { - cfg.name: { - "hidden_size": cfg.hidden_size, - "intermediate_size": cfg.intermediate_size, - "dtype": cfg.dtype, - } - for cfg in sweep.model_configs - } - - common_configs = { - "kernel_name": "swiglu", - "x_name": "model_config", - "x_label": "model configuration", - "x_values": [cfg.name for cfg in sweep.model_configs], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "model_configs": model_configs_info, - "bsz": sweep.batch_size, - "seq_len": sweep.seq_len, - "hidden_act": "silu", - } - ], - "overwrite": args.overwrite, - } - - run_benchmarks( - 
bench_test_fn=bench_speed_swiglu_model_config, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_swiglu_model_config, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="memory", - metric_unit="MB", - **common_configs, + common_configs = build_model_config_sweep( + kernel_name="swiglu", + setup_fn=setup_swiglu, + model_keys=["hidden_size", "intermediate_size", "dtype"], + probe_provider="huggingface", + extra_configs={ + "bsz": 1, + "hidden_act": "silu", + }, + probe_dim="BT", + bt=args.bt, + overwrite=args.overwrite, ) else: model = get_benchmark_model_config(args.model) probe_seq_len = 1024 - def _probe(): - probe_input = SingleBenchmarkRunInput( - x=probe_seq_len, - kernel_provider="huggingface", - extra_benchmark_config={ - "bsz": 1, - "hidden_size": model.hidden_size, - "intermediate_size": model.intermediate_size, - "hidden_act": "silu", - "dtype": model.dtype, - }, - ) - x, layer = _setup_swiglu(probe_input) - return layer(x) - - peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) - kernel_bpt = peak_bytes // probe_seq_len - - config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + common_configs = build_token_length_sweep( + kernel_name="swiglu", + probe_x=probe_seq_len, + model=model, + setup_fn=setup_swiglu, + model_keys=["hidden_size", "intermediate_size", "dtype"], + extra_configs={"hidden_act": "silu", "bsz": 1}, + scale_dim="BT", + probe_provider="huggingface", + overwrite=args.overwrite, + ) - common_configs = { - "kernel_name": "swiglu", - "x_name": "T", - "x_label": "sequence length", - "x_values": [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "bsz": config.batch_size, - "hidden_size": model.hidden_size, - "intermediate_size": model.intermediate_size, - "hidden_act": "silu", - "dtype": model.dtype, - } - ], - "overwrite": args.overwrite, - } + common_configs["kernel_providers"] = ["liger", "huggingface"] - run_benchmarks( - bench_test_fn=bench_speed_swiglu, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_swiglu, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="memory", - metric_unit="MB", - **common_configs, - ) + run_benchmarks( + bench_test_fn=build_speed_bench_fn(setup_swiglu), + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=build_memory_bench_fn(setup_swiglu), + kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/utils.py b/benchmark/scripts/utils.py index 0cb307d19..f3774a133 100644 --- a/benchmark/scripts/utils.py +++ b/benchmark/scripts/utils.py @@ -191,6 +191,33 @@ def full(): return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) +def default_forward_fn(*setup_out): + x, layer = setup_out[0], setup_out[1] + return layer(x) + + +def build_speed_bench_fn( + setup_fn: Callable[["SingleBenchmarkRunInput"], Any], + forward_fn: Callable[..., torch.Tensor] = default_forward_fn, +) -> Callable: + def bench_speed(input: "SingleBenchmarkRunInput") -> SingleBenchmarkRunOutput: + setup_out = setup_fn(input) + return 
run_speed_benchmark(lambda: forward_fn(*setup_out), input.kernel_operation_mode, [setup_out[0]]) + + return bench_speed + + +def build_memory_bench_fn( + setup_fn: Callable[["SingleBenchmarkRunInput"], Any], + forward_fn: Callable[..., torch.Tensor] = default_forward_fn, +) -> Callable: + def bench_memory(input: "SingleBenchmarkRunInput") -> SingleBenchmarkRunOutput: + setup_out = setup_fn(input) + return run_memory_benchmark(lambda: forward_fn(*setup_out), input.kernel_operation_mode) + + return bench_memory + + def get_current_file_directory() -> str: """ Returns the directory path of the current Python file.