Opt-in inference path for advantage-conditioned VLA policies trained
with the RECAP recipe (π*0.6, arXiv:2511.14759).
At test time the model runs twice per denoising step — once with the
"Advantage: positive" indicator appended to the prompt, once
without — and the two velocity predictions are combined to sharpen
the action distribution toward the high-advantage subset:
v_guided = v_uncond + beta * (v_cond - v_uncond)
This document covers the public API, measured performance on RTX and Thor, and the numerical contract the implementation meets.
| | RTX | Thor (Jetson AGX, SM110) |
|---|---|---|
| Model | Pi0.5 | Pi0.5 |
| Hardware | RTX 5090 / 4090 (SM120 / SM89) | Jetson AGX Thor (SM110, aarch64) |
| Frontends | Pi05TorchFrontendRtx (safetensors) / Pi05JaxFrontendRtx (Orbax) | Pi05TorchFrontendThor (safetensors) / Pi05JaxFrontendThor (Orbax) |
| Serial CFG | ✅ 37 ms (β=1.5) | ✅ 88 ms (torch) / 96 ms (JAX) |
| Fused CFG (B=2, paper-correct per-step) | ✅ 25.9 ms (β=1.5) | ✅ ~67 ms (torch / JAX, with autotune≥3) |
| Generic B>2 batched (RL rollout) | not yet | not yet |
Conditioned-prompt strings are byte-equal across the four frontends
(shared builder in flash_rt/core/rl/),
so the same merged LoRA checkpoint serves all four backends.
CFG is opt-in. The default (no set_rl_mode call) inference path is
bit-for-bit unchanged.

```python
from flash_rt.frontends.torch.pi05_thor import Pi05TorchFrontendThor

# Construct. autotune>0 enables the B=2 outer-graph autotuner;
# recommended for production, see "Performance" below.
pipe = Pi05TorchFrontendThor(
    "/path/to/pi05_libero_pytorch", num_views=2, autotune=3)

# Recommended: enable the fused B=2 CFG path BEFORE set_prompt.
pipe.set_batched_mode(enable=True, batch_size=2)

# Configure CFG: β must be >= 1.0; π*0.6 paper recommends [1.5, 2.5].
pipe.set_rl_mode(cfg_enable=True, cfg_beta=1.5, advantage_positive=True)
pipe.set_prompt("fold the t-shirt")

# First infer call lazy-recalibrates FP8 scales against the cond
# prompt and (when batched) recaptures the B=2 graph.
actions = pipe.infer(obs)["actions"]  # shape: (chunk_size, action_dim)

# Revert to the standard non-CFG path.
pipe.set_rl_mode(cfg_enable=False)
pipe.set_prompt("fold the t-shirt")
```

The JAX frontend has the same call surface. The four arg-compatible
frontends are Pi05TorchFrontendRtx, Pi05JaxFrontendRtx,
Pi05TorchFrontendThor, Pi05JaxFrontendThor.
- `cfg_enable` (bool): activate CFG inference. `False` clears any previous configuration; the next `set_prompt` rebuilds the standard pipeline.
- `cfg_beta` (float, default `1.5`): guidance strength. Must be `>= 1.0`. `1.0` mathematically reduces to cond-only output (combine collapses to `v_cond`); useful as a correctness gate but wasteful in production. Prefer the default non-CFG path for unconditioned inference.
- `advantage_positive` (bool, default `True`): conditioned prompt uses the positive advantage tag. Set `False` only for debugging the guidance direction.
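
The `cfg_beta` rule above can be sketched as a small validator. This is an illustrative sketch, not the frontend's code: `check_cfg_beta` is a hypothetical helper, and rejecting non-finite values is an added assumption. Only the `>= 1.0` rule and the error wording come from this document.

```python
import math


def check_cfg_beta(cfg_beta: float) -> float:
    """Hypothetical helper mirroring the documented rule: cfg_beta >= 1.0."""
    beta = float(cfg_beta)
    # NaN fails every comparison, so non-finite values are rejected
    # explicitly (an assumption; the document only states beta >= 1.0).
    if not math.isfinite(beta) or beta < 1.0:
        raise RuntimeError(f"cfg_beta must be >= 1.0 (got {cfg_beta!r})")
    return beta
```

Validating once at configuration time keeps the per-step combine free of branches.
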
autotune=N runs N capture+benchmark trials per CUDA-Graph build
(both the B=1 enc+ae graph and, when set_batched_mode is enabled,
the B=2 outer fused-CFG graph). Each trial lets cuBLASLt re-query
its heuristic; the fastest captured schedule is kept.
- `autotune=0` (default): one capture, whatever cuBLASLt picks first.
- `autotune=3`: recommended for RL CFG deployment; eliminates the cuBLASLt-tactic variance between Python frameworks (see Performance).
- Higher values cost ~0.5–1 s startup per additional trial.
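
The "capture N times, keep the fastest" idea can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the library's autotuner: `autotune`, `build`, and `bench_iters` are hypothetical names, `build` stands in for one graph capture, and the callable it returns stands in for graph replay. A real GPU benchmark would also need to synchronize the stream before reading the clock.

```python
import time
from typing import Callable, Tuple


def autotune(build: Callable[[], Callable[[], None]],
             trials: int,
             bench_iters: int = 10) -> Tuple[Callable[[], None], float]:
    """Build a candidate per trial, time each one, and keep the fastest.

    `build` is a placeholder for one capture; the returned callable is a
    placeholder for replaying what was captured.
    """
    best_fn, best_ms = None, float("inf")
    # trials=0 mirrors autotune=0: a single build, accepted as-is.
    for _ in range(max(1, trials)):
        candidate = build()
        start = time.perf_counter()
        for _ in range(bench_iters):
            candidate()
        elapsed_ms = (time.perf_counter() - start) * 1e3 / bench_iters
        if elapsed_ms < best_ms:
            best_fn, best_ms = candidate, elapsed_ms
    return best_fn, best_ms
```

Because selection is by measured time only, nothing in the sketch pins a particular algorithm; each build is free to come back with a different schedule.
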
The Pi0.5 action expert is a 10-step flow-matching diffusion model. Standard
single-forward inference integrates one velocity per step:

```
for k in 0..9:
    v = action_head(x_k, prompt, image)
    x_{k+1} = x_k + v
```
CFG runs the action head twice per step with two prompts and combines
the velocities per step (paper Eq. 2; matches Pi05CFGBatchedPipeline
on RTX and decoder_forward_b2(cfg_beta=...) on Thor):

```
for k in 0..9:
    v_cond   = action_head(x_k, "task\nAdvantage: positive", image)
    v_uncond = action_head(x_k, "task", image)
    x_{k+1}  = x_k + v_uncond + beta * (v_cond - v_uncond)
```
Both branches must enter step k+1 from the same x_{k+1}
(otherwise the trajectories drift apart and combining their final
velocities is no longer the paper's CFG). The fused B=2 path
enforces this by writing the guided update into the cond slot via
the cfg_combine_into_residual kernel and mirroring it into the
uncond slot via a cudaMemcpyAsync — both inside the captured graph.
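
The per-step loop above can be sketched in NumPy with toy velocity functions. This is a minimal sketch, not the pipeline code: `cfg_sample`, `toy_cond`, and `toy_uncond` are hypothetical stand-ins for the action head, the array shape is arbitrary, and velocities are assumed to be already scaled by the step size, as in the pseudocode.

```python
import numpy as np


def cfg_sample(x0, v_cond_fn, v_uncond_fn, beta, num_steps=10):
    """Integrate with per-step CFG; both branches share one state x."""
    x = np.asarray(x0, dtype=np.float32).copy()
    for _ in range(num_steps):
        v_cond = v_cond_fn(x)      # conditioned branch
        v_uncond = v_uncond_fn(x)  # unconditioned branch
        # Guided update applied to the single shared state, so both
        # branches enter the next step from the same x.
        x = x + v_uncond + beta * (v_cond - v_uncond)
    return x


# Toy stand-ins for the action head (assumptions, not the real model).
def toy_cond(x):
    return 0.1 * (1.0 - x)


def toy_uncond(x):
    return 0.1 * (0.5 - x)


x0 = np.zeros((4, 8), dtype=np.float32)
guided = cfg_sample(x0, toy_cond, toy_uncond, beta=1.5)
collapsed = cfg_sample(x0, toy_cond, toy_uncond, beta=1.0)
cond_only = cfg_sample(x0, toy_cond, toy_cond, beta=1.0)
```

With `beta=1.0` the guided trajectory matches the cond-only one up to float rounding, which is the collapse property used as a correctness gate.
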

```
RTX:  Pi05Pipeline → Pi05BatchedPipeline → Pi05CFGBatchedPipeline (B=2 fused CFG)
                   → Pi05CFGPipeline (serial CFG)
Thor: Pi05ThorPipeline → Pi05ThorBatchedPipeline → Pi05ThorCFGBatchedPipeline
                       → Pi05ThorCFGPipeline (serial CFG)
```

Each *BatchedPipeline runs the encoder + 10-step decoder once at
B=2. Slot 0 is the conditioned context, slot 1 the unconditioned;
the per-step cfg_combine_into_residual kernel (single fused
elementwise call, FP16/BF16 packed-2) writes the guided velocity into
slot 0 and a D2D copy mirrors it into slot 1.
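
The slot layout can be emulated in NumPy to make the combine-then-mirror order concrete. This is a sketch under assumptions: `cfg_step_b2` is a hypothetical name, the array shape is arbitrary, and the real path runs a CUDA kernel plus a device-to-device copy rather than NumPy operations.

```python
import numpy as np


def cfg_step_b2(x, v, beta):
    """One fused-CFG step on a B=2 state. Slot 0 = cond, slot 1 = uncond."""
    v_cond, v_uncond = v[0], v[1]
    # Combine: add the guided velocity to the cond slot's residual.
    x[0] += v_uncond + beta * (v_cond - v_uncond)
    # Mirror: copy the updated cond slot into the uncond slot so both
    # slots start the next step from the same state.
    x[1] = x[0]
    return x


rng = np.random.default_rng(0)
# Both slots start from the same noise sample.
x = np.repeat(rng.standard_normal((1, 4, 8)).astype(np.float32), 2, axis=0)
v = rng.standard_normal((2, 4, 8)).astype(np.float32)
x_before = x[0].copy()
x = cfg_step_b2(x, v, beta=1.5)
```
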
The RTX backend captures the entire B=2 forward (vision encoder,
text encoder, per-step decoder, cfg_combine, mirror) as one
torch.cuda.CUDAGraph. forward() is a single graph.replay().
The Thor backend captures the same shape — outer graph wraps two
B=1 SigLIP runs (one per language slot, lang-emb swap is a graph-
internal D2D from a pre-staged device buffer), one B=2 enc_ae graph,
and the per-step CFG combine + noise mirror inside decoder_forward_b2.
Pi05ThorCFGBatchedPipeline.forward() calls outer_graph.replay()
followed by a stream sync.
RTX: median over 20 infer invocations after 5 warmup calls.
| path | β | median (ms) | vs baseline |
|---|---|---|---|
| baseline (no CFG) | — | 19.0 | 1.00× |
| serial CFG | 1.5 | 37.1 | 1.95× |
| fused CFG batched | 1.5 | 25.9 | 1.36× |
β does not affect latency — it is a multiplier inside the combine
kernel only. Fused batched is faster than the equivalent generic
B=2 path (27.5 ms) because the cfg_combine kernel replaces (does not
add to) the cond-slot per-step residual_add the generic batched path
performs.
The 25.9 ms median plus a typical 3 ms of control-loop overhead outside
infer() comes to about 29 ms. That fits inside the 33 ms period of a
30 Hz control loop, but not inside the 20 ms period that 50 Hz
real-robot control demands.
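
The measurement protocol and the control-rate arithmetic can be sketched as follows. This is a generic harness, not the benchmark script used for the table: `median_latency_ms` and `max_control_rate_hz` are hypothetical names, and for GPU work the timed callable is assumed to block until the result is ready.

```python
import statistics
import time


def median_latency_ms(fn, warmup=5, iters=20):
    """Median wall-clock latency of fn() in ms after untimed warmup calls."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


def max_control_rate_hz(infer_ms, overhead_ms):
    """Highest loop rate whose period still covers inference plus overhead."""
    return 1000.0 / (infer_ms + overhead_ms)


# Figures quoted in this section: 25.9 ms fused CFG, 3 ms loop overhead.
rate = max_control_rate_hz(25.9, 3.0)  # about 34.6 Hz
```
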
Thor: median over 50 timed iterations per back-to-back A/B subprocess
pair, across 3 cycles. Both backends use autotune=3.
| backend | path | β | median (ms) |
|---|---|---|---|
| torch | baseline (no CFG) | — | 44.6 |
| torch | serial CFG | 1.5 | 88 |
| torch | fused CFG batched | 1.5 | ~67 |
| JAX | baseline (no CFG) | — | 44.9 |
| JAX | serial CFG | 1.5 | 96 |
| JAX | fused CFG batched | 1.5 | ~67 |
Without autotune, the JAX frontend's fused-CFG p50 lands ~3–4 ms
above torch's. Root cause is process-state-dependent cuBLASLt
heuristic divergence — the two Python frameworks load different
libcublas.so versions (system 13.2.0 for torch, pip-bundled 13.2.1
for JAX) and start cuBLASLt with different internal cache states.
Given the same (M, N, K), cuBLASLt can return a tactic that
launches ~36 extra cutlass::Kernel2 sub-launches per inference
in the JAX process.
autotune=N recaptures the outer graph N times and keeps the
fastest. Each capture lets cuBLASLt re-query the heuristic; with
N≥3 the JAX backend converges on the same fast tactic torch picks
on the first try. This keeps the heuristic-first design (we never
pin a specific algo, which would break on cuBLAS upgrades or
hardware revisions) while erasing the cross-backend gap.
Thor's qkv_split_rope_kvcache_fp16 and attention_qkv_fp16
launches run as a per-sample inline Python loop (no batch-aware
fused-attention kernel for SM110 yet); these account for ~20 ms of
the fused-CFG path. The dense FP8 GEMMs amortise across the two
slots correctly (M = B*Seq) — only the per-token-indexed kernels
pay the per-sample cost. A future SM110 batch-aware attention
kernel would close most of the Thor↔RTX gap.
Default path (no set_rl_mode):
- bit-identical to the pre-RL implementation on all four frontends.
CFG path:
- `cfg_combine_into_residual` kernel vs FP32 reference on random inputs at the production size (`chunk_size * action_dim = 320`): `max abs diff = 0`, `cos = 1.0`.
- `cfg_beta=1.0` collapse: `cos(CFG, cond_only) >= 0.999` on all serial and fused paths, both backends, both hardware platforms (mathematical identity: `v_uncond + 1*(v_cond - v_uncond) = v_cond`).
- B=2 slot symmetry: same observation in both slots, identical noise `R` → `cos(slot 0, slot 1) = 1.000000`, `maxdiff = 0` on torch and JAX.
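
The two metrics used in these gates, cosine similarity and max absolute difference, can be sketched in NumPy together with an FP32 reference combine. This is an illustrative sketch, not the test suite's code: `cos_sim`, `max_abs_diff`, and `cfg_combine` are hypothetical helper names, and the inputs are random vectors at the 320-element size quoted above.

```python
import numpy as np


def cos_sim(a, b):
    """Cosine similarity of two arrays, flattened, computed in float64."""
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


def max_abs_diff(a, b):
    """Largest elementwise absolute difference, computed in float64."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    return float(np.max(np.abs(a - b)))


def cfg_combine(v_cond, v_uncond, beta):
    """FP32 reference for v_uncond + beta * (v_cond - v_uncond)."""
    v_cond = np.asarray(v_cond, dtype=np.float32)
    v_uncond = np.asarray(v_uncond, dtype=np.float32)
    return v_uncond + np.float32(beta) * (v_cond - v_uncond)


rng = np.random.default_rng(0)
v_cond = rng.standard_normal(320).astype(np.float32)
v_uncond = rng.standard_normal(320).astype(np.float32)
# beta = 1.0 should collapse the combine to the conditioned velocity.
collapsed = cfg_combine(v_cond, v_uncond, beta=1.0)
```
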
| β | regime | batched vs serial | batched vs FP32 ref |
|---|---|---|---|
| 1.0 | paper default | 0.9997 | 0.9958 |
| 1.5 | moderate (lower) | 0.9991 | 0.9919 |
| 2.0 | mid-moderate | 0.9982 | 0.9854 |
| 2.5 | moderate (upper) | 0.9971 | 0.9756 |
The fused batched path tracks both serial and the FP32 reference
within the FP8 quantisation budget across the tested β range
[1.0, 2.5], which includes the paper's recommended [1.5, 2.5].
Same numpy-seeded R fed to both backends:
| β | torch vs JAX cos |
|---|---|
| 1.0 | ≥ 0.9997 |
| 1.5 | ≥ 0.9986 |
| 2.5 | ≥ 0.9979 |
The residual gap (~0.001–0.002) is per-frontend FP8 calibration noise amplified by the CFG combine; it is not a correctness issue (well inside the deployment cosine floor of 0.99 vs PyTorch FP32 reference).
| test | what it validates |
|---|---|
| `tests/test_rl_cfg_inference.py` | RTX serial + batched CFG, all βs, validation gates |
| `tests/test_thor_rl_cfg_inference.py --backends torch,jax` | Thor serial CFG: validation, β=1.0 collapse, β=1.5 finite |
| `tests/test_cfg_correctness_oracle.py` | per-step C1–C5 contract (RTX) vs frozen reference |
Calibration warning about scale ceiling during RL mode — the conditioned prompt has slightly different token statistics than pure task text. If the ratio is within ~25× the median, output is correct; the warning flags calibration-set diversity, not a bug.
`RuntimeError: cfg_beta must be >= 1.0`: pass a value in
[1.0, …]. A value in (0, 1) would weaken guidance below the cond-only
output and a negative value would invert it; the frontend rejects
everything below 1.0 to prevent silent sign bugs.
RuntimeError: set_prompt must be called before calibrate —
RL mode rebuilds the pipeline at the next set_prompt. Order is
always set_rl_mode → set_prompt → calibrate.
Two Pi05TorchFrontendRtx instances in the same process segfault —
pre-existing single-instance constraint of the calibration path,
unrelated to RL mode. Use one frontend per process (the test suite
does this).
JAX fused-CFG is consistently 3–4 ms slower than torch — pass
autotune=3 (or higher) to the frontend constructor. See the
"Performance" → "Why autotune matters on Thor" section.
- π*0.6 paper: arXiv:2511.14759, Appendix E for the CFG derivation from the flow-matching likelihood gradient.
- `flash_rt/core/rl/`: framework-agnostic combine math, ACP-tag prompt builder.
- `csrc/kernels/elementwise.cu`: `cfg_combine_into_residual` kernel (packed-2 vectorised, FP32 internally for numerical stability at β > 1).
- Pipeline classes: