flash-mHC is a standalone, production-focused extraction of the mHC-lite Torch+Triton path.
It provides:
- Triton fused kernels for mHC-lite K1/K3/K4 steps
- `torch.library` custom ops compatible with `torch.compile`
- A reusable `MHCLiteBlock` wrapper module
- Kernel microbenchmark + one-shot launch-parameter grid search
- CUDA correctness tests against PyTorch references
- Current fused kernels are specialized for `n_streams=4`.
- The Triton path is used for CUDA + bf16 inputs with shape `(T, n, C)`.
- The fallback path remains pure PyTorch in `MHCLiteBlock`.
- Runtime Triton autotune is enabled by default for the K1/K3/K4 kernels (capped config sets + Triton disk cache).
- `MHCLiteBlock` supports optional activation checkpointing and disables autotune during eval/inference by default.
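The dispatch conditions above can be summarized as a small predicate. The sketch below is illustrative only: `TensorMeta` and `use_triton_fused` are hypothetical names standing in for the real dispatch logic, not part of the package API.

```python
from dataclasses import dataclass


@dataclass
class TensorMeta:
    """Lightweight stand-in for the tensor properties the dispatch checks."""
    device: str   # e.g. "cuda" or "cpu"
    dtype: str    # e.g. "bfloat16"
    shape: tuple  # expected (T, n, C)


def use_triton_fused(meta: TensorMeta, n_streams: int = 4) -> bool:
    """Hypothetical dispatch predicate: the Triton path only handles CUDA
    bf16 inputs of shape (T, n, C) with the specialized n_streams=4."""
    return (
        meta.device == "cuda"
        and meta.dtype == "bfloat16"
        and len(meta.shape) == 3
        and meta.shape[1] == n_streams
        and n_streams == 4
    )


# Anything else falls back to the pure-PyTorch path in MHCLiteBlock.
```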
Installation:

```bash
pip install -e .
# optional test tooling
pip install -e .[dev]
```

Quickstart:

```python
import torch
import torch.nn as nn

from flash_mhc import MHCLiteBlock

layer = nn.Linear(1024, 1024, bias=False).cuda().bfloat16()
block = MHCLiteBlock(
    n_streams=4,
    hidden_size=1024,
    layer=layer,
    triton_fused=True,
    activation_checkpointing=True,
).cuda().bfloat16()

x = torch.randn(65536, 4, 1024, device="cuda", dtype=torch.bfloat16)
out = block(x)
print(out.shape)  # (65536, 4, 1024)
```

Kernel microbenchmark:
```bash
python scripts/benchmark_kernels.py --T 65536 --C 1024 --n 4 --peak-gbps 1792
```
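The `--peak-gbps 1792` flag suggests the benchmark reports achieved memory bandwidth as a fraction of a stated device peak. A minimal sketch of that arithmetic (the helper name and byte accounting are assumptions, not the script's actual code):

```python
def bandwidth_efficiency(bytes_moved: int, seconds: float, peak_gbps: float) -> float:
    """Achieved bandwidth as a fraction of peak (GB here means 1e9 bytes)."""
    achieved_gbps = bytes_moved / seconds / 1e9
    return achieved_gbps / peak_gbps


# Example: a kernel that reads and writes one bf16 (65536, 4, 1024) tensor
T, n, C = 65536, 4, 1024
bytes_moved = 2 * (T * n * C * 2)  # read + write, 2 bytes per bf16 element
eff = bandwidth_efficiency(bytes_moved, seconds=1.5e-3, peak_gbps=1792.0)
print(f"{eff:.1%} of peak")  # roughly 40% of peak for these assumed numbers
```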
Grid search for launch parameters:

```bash
python scripts/gridsearch.py --T 65536 --C 1024 --n 4 --json-out output/gridsearch_sm120.json
```

By default, flash-mHC uses Triton's runtime autotune wrappers for fused kernels.
Environment variables:

- `FLASH_MHC_TRITON_AUTOTUNE=0`: disable runtime autotune and use the legacy hardcoded launch params from `ops.py`.
- `FLASH_MHC_TRITON_AUTOTUNE_STATUS=0`: silence the one-time autotune status logs.
- `TRITON_CACHE_DIR=/path/to/cache`: override the Triton disk cache location (default is Triton's normal cache path).
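A sketch of how a `=0`-style toggle like `FLASH_MHC_TRITON_AUTOTUNE` is typically read (the `env_flag` helper is illustrative; the package's actual parsing may differ):

```python
import os


def env_flag(name: str, default: bool = True) -> bool:
    """Treat unset as the default; '0'/'false' (any case) as off, else on."""
    val = os.environ.get(name)
    if val is None:
        return default
    return val.strip().lower() not in ("0", "false", "")


# Autotune stays on unless explicitly disabled:
os.environ["FLASH_MHC_TRITON_AUTOTUNE"] = "0"
print(env_flag("FLASH_MHC_TRITON_AUTOTUNE"))  # False
```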
Run the correctness tests:

```bash
pytest -q tests/test_correctness.py
```

Repository layout:

- `src/flash_mhc/kernels.py`: Triton kernels
- `src/flash_mhc/ops.py`: `torch.library` op registration + autograd
- `src/flash_mhc/block.py`: standalone `MHCLiteBlock`
- `scripts/benchmark_kernels.py`: per-kernel timing and efficiency
- `scripts/gridsearch.py`: one-shot launch-parameter sweep
- `tests/test_correctness.py`: CUDA correctness tests
- `docs/ARCHITECTURE.md`: kernel/dataflow details
- `docs/TUNING_SM120.md`: SM120 tuning summary and selected params
- `docs/PUBLISHING.md`: packaging and release checklist