[feat] Add AnyFlow any-step video distillation (pretrain + on-policy)#1371
Enderfga wants to merge 15 commits into
Conversation
…tity)

Adds four fields to WanVideoArchConfig for AnyFlow dual-timestep conditioning:
- r_embedder: bool; enables the delta_embedder allocation
- r_embedder_fusion: 'additive' (default) or 'gated'
- r_embedder_gate_value: float; gate g in gated fusion
- r_embedder_deltatime_type: 'r' (default) or 't-r'

Also adds two regex entries to param_names_mapping so HF AnyFlow checkpoints load delta_embedder weights into the FastVideo-internal mlp.fc_in/fc_out layout. Both regexes are no-ops on plain Wan checkpoints.

The defaults preserve bit-identity with every existing Wan-based method (DMD2, Self-Forcing, KD, DFSFT): no delta_embedder is allocated and no extra computation runs on the embedder forward path.
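A regex-based key rename of this kind can be sketched as follows. The commit message does not show the actual patterns, so the HF-side layer names `linear_1` / `linear_2` and the helper `remap_key` are assumptions made purely for illustration; only the target layout (`mlp.fc_in` / `mlp.fc_out`) comes from the text above.

```python
import re

# Hypothetical patterns. The source names linear_1 / linear_2 are an assumption;
# only the target layout (mlp.fc_in / mlp.fc_out) is taken from the commit message.
PARAM_NAMES_MAPPING = {
    r"^condition_embedder\.delta_embedder\.linear_1\.(.*)$":
        r"condition_embedder.delta_embedder.mlp.fc_in.\1",
    r"^condition_embedder\.delta_embedder\.linear_2\.(.*)$":
        r"condition_embedder.delta_embedder.mlp.fc_out.\1",
}


def remap_key(key: str) -> str:
    """Return the renamed checkpoint key, or the key unchanged if nothing matches."""
    for pattern, replacement in PARAM_NAMES_MAPPING.items():
        new_key, n_subs = re.subn(pattern, replacement, key)
        if n_subs:
            return new_key
    return key
```

Keys without a `delta_embedder` component fall through untouched, which is the sense in which such entries are no-ops on plain Wan checkpoints.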
When r_embedder=True, allocates a delta_embedder (deep-copied from time_embedder for matching initialization, per the AnyFlow reference's setup_flowmap_model() pattern) and registers a non-persistent gate buffer.

Forward now accepts an optional r_timestep:
- additive (default): temb_t + g * delta_emb
- gated: (1 - g) * temb_t + g * delta_emb

Both branches are gated by r_embedder=True AND r_timestep is not None, so existing call sites that don't pass r_timestep stay byte-identical (verified by test_embedder_enabled_without_r_timestep_is_bit_identical_to_legacy).

The gate is a non-persistent buffer: it's a hyperparameter, not a learned weight, and stays out of state_dict so checkpoints remain portable across gate values.

deltatime_type controls whether the delta_embedder consumes r directly (default, matching AnyFlow's deltatime_type='r') or (t - r).
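The two fusion modes and the deltatime_type switch can be illustrated with a minimal pure-Python sketch. The function names below are hypothetical and plain lists stand in for tensors; the real module operates on embeddings produced by the time and delta embedders.

```python
from typing import List, Optional


def delta_embedder_input(t: float, r: float, deltatime_type: str = "r") -> float:
    """Pick the scalar fed to the delta_embedder: r itself, or the gap t - r."""
    if deltatime_type == "r":
        return r
    if deltatime_type == "t-r":
        return t - r
    raise ValueError(f"unknown deltatime_type: {deltatime_type!r}")


def fuse_embeddings(
    temb_t: List[float],
    delta_emb: Optional[List[float]],
    gate: float,
    fusion: str = "additive",
) -> List[float]:
    """Combine the t-embedding with the optional r-embedding.

    When delta_emb is None (no r_timestep was passed) the t-embedding is
    returned unchanged, mirroring the legacy single-timestep path.
    """
    if delta_emb is None:
        return list(temb_t)
    if fusion == "additive":
        return [a + gate * d for a, d in zip(temb_t, delta_emb)]
    if fusion == "gated":
        return [(1.0 - gate) * a + gate * d for a, d in zip(temb_t, delta_emb)]
    raise ValueError(f"unknown fusion mode: {fusion!r}")
```

For example, with `temb_t = [1.0, 2.0]`, `delta_emb = [4.0, 8.0]` and `gate = 0.25`, additive fusion gives `[2.0, 4.0]` while gated fusion gives `[1.75, 3.5]`.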
The forward signature now accepts r_timestep as an explicit kwarg (not silently swallowed by **kwargs), gets flattened in parallel with timestep when the input is 2-D (Wan 2.2 ti2v case), and is forwarded to the condition_embedder. The constructor now propagates the four arch_config flags (r_embedder/r_embedder_fusion/r_embedder_gate_value/r_embedder_deltatime_type) into the WanTimeTextImageEmbedding so a YAML can opt into AnyFlow's dual-timestep conditioning through training.dit_config overrides. Defaults preserve byte-identity with the legacy single-timestep path — when r_timestep is omitted the delta_embedder branch is skipped entirely (test_embedder_enabled_without_r_timestep_is_bit_identical_to_legacy).
Standalone any-step Euler scheduler implementing AnyFlow's flow-map step formula x_r = x_t - ((t - r) / num_train_timesteps) * u(x_t, t, r). Also provides the AnyFlow training-time helpers apply_shift (flow-matching shift transform), get_train_weight (per-timestep loss weight with beta08 = t * sqrt(1-t) renormalized to num_train_timesteps total mass), and add_noise (linear flow-matching interpolation). Subclasses BaseScheduler so it slots into the existing FastVideo scheduler discovery surface (timesteps/order/num_train_timesteps attributes plus set_shift/set_timesteps/scale_model_input). Has no diffusers ConfigMixin/SchedulerMixin dependency. set_timesteps accepts custom_timesteps so configs can pin AnyFlow's hand-tuned schedules (e.g. [999, 937, 833, 624, 0] for the 4-step Wan2.1 setting from the paper).
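The step formula, the linear interpolation, and the weight renormalization described above can be sketched in a few lines. This is an illustrative stand-in, not the scheduler class itself: the function names are hypothetical, plain lists replace tensors, `num_train_timesteps = 1000` is an assumption suggested by the 999-to-0 schedule, and the uniform grid used for the weights is also assumed.

```python
import math
from typing import Callable, List, Sequence

NUM_TRAIN_TIMESTEPS = 1000  # assumption: inferred from the [999, ..., 0] schedule


def flow_map_step(x_t: List[float], u: List[float], t: float, r: float,
                  n: int = NUM_TRAIN_TIMESTEPS) -> List[float]:
    """x_r = x_t - ((t - r) / num_train_timesteps) * u(x_t, t, r)."""
    scale = (t - r) / n
    return [x - scale * v for x, v in zip(x_t, u)]


def add_noise(x0: List[float], eps: List[float], t: float,
              n: int = NUM_TRAIN_TIMESTEPS) -> List[float]:
    """Linear flow-matching interpolation between clean data and noise."""
    sigma = t / n
    return [(1.0 - sigma) * a + sigma * e for a, e in zip(x0, eps)]


def beta08_train_weights(n: int = NUM_TRAIN_TIMESTEPS) -> List[float]:
    """w(t) = t * sqrt(1 - t) on an assumed uniform grid, rescaled to total mass n."""
    raw = [(i / n) * math.sqrt(1.0 - i / n) for i in range(n)]
    total = sum(raw)
    return [w * n / total for w in raw]


def sample(x_start: List[float],
           velocity_fn: Callable[[List[float], float, float], List[float]],
           timesteps: Sequence[float]) -> List[float]:
    """Walk a pinned schedule such as [999, 937, 833, 624, 0], one step per pair."""
    x = list(x_start)
    for t, r in zip(timesteps[:-1], timesteps[1:]):
        x = flow_map_step(x, velocity_fn(x, t, r), t, r)
    return x
```

On the exact linear path the average velocity is the constant `eps - x0`, so stepping a noised sample through any schedule that ends at 0 with that velocity recovers `x0` up to floating-point error; this makes a convenient sanity check for the step formula.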
Adds the single-student AnyFlow flow-map pretrain method as a TrainingMethod subclass. The __init__ parses and validates all method config knobs (diffusion_ratio, consistency_ratio, epsilon, weight_type, fuse_guidance_scale, shift), builds a FlowMapEulerDiscreteScheduler, and wires a single optimizer+scheduler over the student.

The (t, r) sampling helper _sample_pair_timesteps lives at module scope so it can be unit-tested without instantiating the full method. It matches the AnyFlow paper formulation:
- two uniform draws u1, u2; t = max, r = min
- first diffusion_ratio * B: r := t (plain flow matching)
- next consistency_ratio * B: r := 0 (consistency to clean data)
- remainder: free reconstruction range

single_train_step raises NotImplementedError for now; it is filled in by the next commit, which adds the central-difference target and the loss assembly (Task 7).
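The sampling rule listed above can be sketched without any tensor library. This is a simplified illustration under stated assumptions: the function name drops the leading underscore to mark it as a stand-in, the `rng` argument is added for reproducibility, and the real helper presumably returns batched tensors rather than a list of tuples.

```python
import random
from typing import List, Tuple


def sample_pair_timesteps(batch_size: int, diffusion_ratio: float,
                          consistency_ratio: float,
                          rng: random.Random) -> List[Tuple[float, float]]:
    """Draw (t, r) pairs in [0, 1] with r <= t, split into three regimes."""
    n_diffusion = int(diffusion_ratio * batch_size)
    n_consistency = int(consistency_ratio * batch_size)
    pairs = []
    for i in range(batch_size):
        u1, u2 = rng.random(), rng.random()
        t, r = max(u1, u2), min(u1, u2)
        if i < n_diffusion:
            r = t        # plain flow matching: zero-length interval
        elif i < n_diffusion + n_consistency:
            r = 0.0      # consistency to clean data
        # otherwise keep the free (t, r) pair
        pairs.append((t, r))
    return pairs
```

With `batch_size=8`, `diffusion_ratio=0.5` and `consistency_ratio=0.25`, the first four pairs have `r == t`, the next two have `r == 0`, and the last two are free.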
…n_step
Implements AnyFlow's central-difference loss path end-to-end:
1. WanModel gains predict_velocity_with_r(noisy, t, r, batch, ...) — mirrors
predict_noise's forward_context + autocast plumbing but injects
r_timestep into the transformer kwargs.
2. anyflow_pretrain._central_difference_dF_dt — symmetric finite difference
over (t ± delta) with the sample also moved along the flow trajectory by
v_pred * (delta / num_train_timesteps), mirroring AnyFlow's reference
trainer_wan_anyflow_pretrain.py::compute_central_difference. Wrapped in
torch.no_grad so the two extra forwards stay out of the backward graph.
3. AnyFlowPretrainMethod.single_train_step:
- sample (t, r) ∈ [0, 1] via _sample_pair_timesteps
- apply scheduler shift + scale to absolute units
- rebuild noisy with the flow-map scheduler's add_noise
- run conditional + (optional) unconditional student forwards
- apply guidance distillation when fuse_guidance_scale != 1
- compute target = (eps - x0) - (t - r) * dF/dt
- per-sample MSE * per-timestep weight
- stop-grad scale balance so non-diffusion losses match the diffusion
branch's magnitude
- emit metrics for diffusion/consistency fractions + scale weight mean
4. AnyFlowPretrainMethod.backward overrides TrainingMethod.backward to
route the loss through self.student.backward inside the correct
forward_context, matching DMD2Method's pattern.
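The central-difference target in items 2 and 3 can be illustrated on scalars. The sketch below is a toy under stated assumptions: it works in normalized time, so the `delta / num_train_timesteps` factor collapses to a plain `delta`; `F` is any callable standing in for the student; and both function names are hypothetical. In the real trainer the two extra forwards run under `torch.no_grad`.

```python
from typing import Callable


def central_difference_dF_dt(F: Callable[[float, float, float], float],
                             x_t: float, t: float, r: float,
                             v_pred: float, delta: float) -> float:
    """Symmetric finite difference of F along the flow trajectory.

    The sample is moved by v_pred * delta together with the time shift, so the
    result approximates the total derivative d/dt F(x_t, t, r).
    """
    shift = v_pred * delta
    f_plus = F(x_t + shift, t + delta, r)
    f_minus = F(x_t - shift, t - delta, r)
    return (f_plus - f_minus) / (2.0 * delta)


def flow_map_target(eps: float, x0: float, t: float, r: float,
                    dF_dt: float) -> float:
    """target = (eps - x0) - (t - r) * dF/dt."""
    return (eps - x0) - (t - r) * dF_dt
```

When `r == t` the second term vanishes and the target reduces to the plain flow-matching velocity `eps - x0`, which is why the diffusion fraction of the batch behaves like ordinary flow matching.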
Subclasses DMD2Method and overrides _student_rollout to run student_sample_steps Euler-flow steps from pure noise. One step is gradient-enabled, with the chosen index broadcast from rank 0 in distributed runs so every worker agrees (matching AnyFlow's WanAnyFlowPipeline.training_rollout).

Method config knobs:
- student_sample_steps (default 4)
- use_mean_velocity (default True; r = t_next vs r = t at each step)
- t_list_override (optional pinned schedule, e.g. AnyFlow paper's [999, 937, 833, 624, 0] for 4-step Wan2.1)
- dmd_score_r_value (default 0.0; r used by the inherited DMD loss)
- real_score_guidance_scale (inherited from DMD2; default 1.0)

The inherited _dmd_loss / _critic_flow_matching_loss machinery is reused verbatim: they call predict_x0 / predict_noise on the student, which go through the single-timestep WanModel forward path (since DMD2 scores against r=0 implicitly via the inherited code; if r=t is desired this can be wired in a follow-up by overriding _dmd_loss to call predict_velocity_with_r).

CPU-only tests via object.__new__ bypassing the full __init__ wire-up; GPU integration coverage lives in the smoke test (Task 11).
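The shape of such a rollout can be sketched as a plan of per-step settings. The helper below is hypothetical and deliberately leaves out the model and the distributed broadcast: it takes the gradient-enabled index as an argument (in a distributed run that index would be picked on rank 0 and broadcast), and it reads `use_mean_velocity=True` as "r = t_next", which is an interpretation of the knob description above.

```python
from typing import Dict, List, Sequence, Union

StepPlan = Dict[str, Union[float, bool]]


def build_rollout_plan(t_list: Sequence[float], grad_step: int,
                       use_mean_velocity: bool = True) -> List[StepPlan]:
    """Describe each Euler-flow step of a rollout over a pinned schedule."""
    if len(t_list) < 2:
        raise ValueError("t_list needs at least a start and an end timestep")
    num_steps = len(t_list) - 1
    if not 0 <= grad_step < num_steps:
        raise ValueError(f"grad_step must be in [0, {num_steps})")
    plan = []
    for i in range(num_steps):
        t, t_next = t_list[i], t_list[i + 1]
        plan.append({
            "t": t,
            "t_next": t_next,
            # Assumed reading: mean-velocity mode conditions on the step's end time.
            "r": t_next if use_mean_velocity else t,
            # Exactly one step keeps its autograd graph; the others would run
            # under torch.no_grad in the real rollout.
            "requires_grad": i == grad_step,
        })
    return plan
```

For the pinned schedule `[999, 937, 833, 624, 0]` this yields four steps, exactly one of which is marked gradient-enabled.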
Two reference configs slot into examples/train/configs/distribution_matching/wan/:
- anyflow_pretrain_t2v.yaml: single-student pretrain stage. Enables r_embedder under pipeline.dit_config, sets diffusion/consistency ratios, epsilon=5, weight_type=beta08, fuse_guidance_scale=3.0, lr 5e-5, 6k steps, Wan2.1 T2V 1.3B init.
- anyflow_onpolicy_t2v.yaml: full DMD2 trio (student + Wan2.1 14B teacher + Wan2.1 1.3B critic). Pins the 4-step rollout schedule [999, 937, 833, 624, 0] from the AnyFlow paper, use_mean_velocity=true, generator_update_interval=5, lr 2e-6, 4k steps.

The student loads from a path placeholder which can be either the local pretrain output or the published nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers checkpoint; the latter case relies on the param_names_mapping regex in WanVideoArchConfig to rename delta_embedder weights.
Mirrors test_distill_dmd.py's subprocess-torchrun pattern but points at the new YAML entrypoint (fastvideo.train.entrypoint.train). Auto-skipped when fewer than 2 CUDA devices are visible — the new TrainingMethod framework's distributed barriers don't tolerate single-GPU bring-up. On Buildkite the test fires under /test distillation; on GMI we can run it directly with NUM_GPUS=2 NUM_NODES=1 pytest -v -k anyflow_smoke.
Single-page summary of AnyFlow's two-stage recipe:
- Algorithm intuition (u_θ(x_t, t, r) → any-step Euler)
- Stage 1 pretrain: central-difference target + (t, r) sampling
- Stage 2 on-policy DMD: multi-step rollout with grad-step broadcast
- Launch commands for both YAMLs
- Note on loading nvidia/AnyFlow-Wan2.1-T2V-* checkpoints (handled by param_names_mapping, no separate adapter)
- Note on fuse_guidance_scale parameterization

Registered under mkdocs nav so it surfaces in the built doc tree.
…r CPU runs

FastVideo's ReplicatedLinear allocates weights via torch.empty and relies on a downstream load_weights pass to populate them. CPU unit tests bypass that pass, so the weights stay uninitialized (and may contain NaN/Inf), and the embedder forward can produce NaN everywhere. Add _init_uninitialized_weights, which Xavier-initializes every >=2D param and zeros the rest before .eval().

Verified on GMI (gpu-h200-06 with 8x H200): all 43 tests pass.
Six-stage end-to-end verification, run via single-rank torchrun-less
srun on a single H200:
(1) Build FastVideo WanTransformer3DModel with r_embedder=True,
r_embedder_fusion=gated, gate=0.25.
(2) Load nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers safetensors and
translate keys via WanVideoArchConfig.param_names_mapping
(0 missing / 0 unexpected — the delta_embedder regex is sufficient).
(3) Build AnyFlow's reference loader (FAR_Wan_Transformer3DModel).
(4) Forward parity on identical inputs — bf16 noise.
(5) 4-step Euler-flow sampling smoke via FlowMapEulerDiscreteScheduler.
(6) Training-step central-difference loss comparison (inline replica
of AnyFlow's train_bidirection).
Measured on Wan2.1-T2V-1.3B + nvidia/AnyFlow checkpoint:
forward rel mean diff : 2.55%
forward max abs diff : 7.81e-2
training loss diff : 1.33% (AnyFlow 0.381619 vs FastVideo 0.386694)
Both within bf16 kernel noise. Compare to the FastGen port at
NVlabs/FastGen#25 which reported 2.8% forward + 4.07% training-loss
on the same checkpoint — FastVideo's tighter result is consistent
with FastVideo's attention/normalization implementation having slightly
lower kernel noise on H200 than FastGen's.
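For readers who want to reproduce the style of these numbers, the metrics can be sketched as below. The exact definitions used by the parity script are not shown in this thread, so treating "rel mean diff" as mean absolute difference over mean absolute reference is an assumption, and the function names are hypothetical; the scalar relative difference does reproduce the reported 1.33% from the two loss values.

```python
from typing import Sequence


def relative_diff(reference: float, candidate: float) -> float:
    """Relative difference of two scalars, measured against the reference."""
    return abs(candidate - reference) / abs(reference)


def rel_mean_diff(reference: Sequence[float], candidate: Sequence[float]) -> float:
    """Assumed definition: mean |candidate - reference| / mean |reference|."""
    diff = sum(abs(c - r) for r, c in zip(reference, candidate)) / len(reference)
    scale = sum(abs(r) for r in reference) / len(reference)
    return diff / scale


def max_abs_diff(reference: Sequence[float], candidate: Sequence[float]) -> float:
    """Largest elementwise absolute difference."""
    return max(abs(c - r) for r, c in zip(reference, candidate))


# Loss values quoted above: AnyFlow 0.381619 vs FastVideo 0.386694.
loss_gap = relative_diff(0.381619, 0.386694)  # roughly 0.0133, i.e. 1.33%
```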
Single-rank demo script that loads nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers into FastVideo's WanTransformer3DModel via param_names_mapping, then samples 81 frames at 480x832 with the new FlowMapEulerDiscreteScheduler at both NFE=4 (~30s) and NFE=50 (~5min) and decodes each via the Wan VAE (tiling enabled). Uses guidance_scale=1.0 since the on-policy distilled checkpoint has fuse_guidance_scale=3.0 baked into the weights, matching the AnyFlow paper's official demo.py default.

Memory tactics for single H200 (141 GB HBM):
- Encode prompts with UMT5 first, free the text encoder.
- Build 14B transformer (~28 GB bf16), load AnyFlow shards.
- Sample at both NFEs, free transformer, then VAE decode with tiling.

Peak GPU usage ~57 GB on a single H200.
Welcome to FastVideo! Thanks for your first pull request.
How our CI works:
PRs run a two-tier CI system:
- Pre-commit — formatting (yapf), linting (ruff), type checking (mypy). Runs immediately on every PR.
- Fastcheck — core GPU tests (encoders, VAEs, transformers, kernels, unit tests). Runs automatically via Buildkite on relevant file changes (~10-15 min).
- Full Suite: integration tests, training pipelines, SSIM regression. Runs only when a reviewer adds the `ready` label.
Before your PR is reviewed:
- `pre-commit run --all-files` passes locally
- You've added or updated tests for your changes
- The PR description explains what and why
If pre-commit fails, a bot comment will explain how to fix it. Fastcheck and Full Suite results appear in the Checks section below.
Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🔴 PR merge requirements

Waiting for

This rule is failing.
Code Review
This pull request implements the AnyFlow any-step video distillation framework, adding dual-timestep conditioning to the Wan model, a new FlowMapEulerDiscreteScheduler, and training methods for pre-training and on-policy distillation. The PR includes comprehensive tests and documentation. A review comment correctly identified and provided a fix for a typo in the arXiv link for the AnyFlow paper.
@@ -0,0 +1,116 @@
# 🌊 AnyFlow Any-Step Video Distillation

**AnyFlow** ([paper](https://arxiv.org/abs/2605.13724), [project page](https://nvlabs.github.io/AnyFlow/), [official code](https://github.com/NVlabs/AnyFlow), [model weights](https://huggingface.co/collections/nvidia/anyflow)) is an any-step video diffusion framework built on flow maps. A single distilled checkpoint can be evaluated at NFE ∈ {1, 2, 4, 8, 16, 32} without retraining, and quality scales **monotonically** with steps, unlike consistency-based distillation, which often degrades as NFE grows.
The link to the AnyFlow paper appears to have a typo. The year-month part 2605 is likely incorrect for a 2024 paper. The correct link should probably point to arxiv.org/abs/2405.13724.
Suggested change:
- **AnyFlow** ([paper](https://arxiv.org/abs/2605.13724), [project page](https://nvlabs.github.io/AnyFlow/), [official code](https://github.com/NVlabs/AnyFlow), [model weights](https://huggingface.co/collections/nvidia/anyflow)) is an any-step video diffusion framework built on flow maps. A single distilled checkpoint can be evaluated at NFE ∈ {1, 2, 4, 8, 16, 32} without retraining, and quality scales **monotonically** with steps, unlike consistency-based distillation, which often degrades as NFE grows.
+ **AnyFlow** ([paper](https://arxiv.org/abs/2405.13724), [project page](https://nvlabs.github.io/AnyFlow/), [official code](https://github.com/NVlabs/AnyFlow), [model weights](https://huggingface.co/collections/nvidia/anyflow)) is an any-step video diffusion framework built on flow maps. A single distilled checkpoint can be evaluated at NFE ∈ {1, 2, 4, 8, 16, 32} without retraining, and quality scales **monotonically** with steps, unlike consistency-based distillation, which often degrades as NFE grows.
The arxiv ID is correct as written. AnyFlow was posted to arxiv this month (2026-05), so the prefix is 2605, not 2405. https://arxiv.org/abs/2605.13724 resolves to the right paper; https://arxiv.org/abs/2405.13724 is a different (2024) submission that this suggestion has been pattern-matched to.
|
Hi @Enderfga, this is a code review from one of @SolitaryThinker's AI reviewer agents (Gob). I run these to help triage PRs but @SolitaryThinker hasn't personally verified every finding. If anything below doesn't match what you know about the code, please ping @SolitaryThinker and they'll take a closer look.

TL;DR

The PR is in unusually good shape for its size (2,932 LoC). Bit-identity defaults are properly guarded and explicitly tested, numerical parity vs the upstream NVlabs/AnyFlow reference is asserted with real thresholds (2.55% / 1.33% rel-diff on bf16), distributed gradient-step broadcast is correct, scheduler conforms to

Verdict: approve-with-followup

What I checked

Default-path bit-identity (the central claim). The

The four explicit tests pin this down:

Existing Wan-based pipelines (DMD2, Self-Forcing, KD, DFSFT, base T2V/I2V) stay byte-equal to

Numerical parity (

The sister FastGen port (NVlabs/FastGen#25) reporting 2.8% / 4.07% on the same checkpoint is a nice external sanity check.

Distributed correctness (gradient-step broadcast).

All ranks converge on the same

Scheduler conformance.

AGENTS.md compliance. New method goes under

Commit hygiene. Scanned all 14 commits: no AI co-author trailers, no "Generated with Claude Code" lines, no Claude / Opus / Sonnet / Anthropic / OpenAI / GPT / Codex strings. All subjects use the

Test coverage. ~40 CPU unit tests across

Three small polish items (all S3, non-blocking)

1. arXiv link typo (
|
SolitaryThinker left a comment
Thank you for the contribution!
|
/merge |
Pre-commit checks failed

Hi @Enderfga, the pre-commit checks have failed. To fix them locally:

    # Install pre-commit if you haven't already
    uv pip install pre-commit
    pre-commit install
    # Run all checks and auto-fix what's possible
    pre-commit run --all-files

Common fixes:

After fixing, commit and push the changes. The checks will re-run automatically. For future commits,
- scripts/verify_anyflow_fastvideo_parity.py, scripts/demo_anyflow_14b.py: read ANYFLOW_LOCAL / ANYFLOW_REF / ANYFLOW_DEMO_OUT from the environment so the scripts are not tied to a single workstation layout. Drop the bespoke `srun` invocation example from the docstring in favour of a plain `python scripts/...` snippet.
- fastvideo/train/methods/distribution_matching/anyflow_pretrain.py: collapse the diff_mean if/else into a ternary so ruff SIM108 is happy.
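Reading script paths from the environment, as the first item describes, usually looks something like the sketch below. The helper name `resolve_path` and the fallback values are hypothetical; only the variable names ANYFLOW_LOCAL, ANYFLOW_REF, and ANYFLOW_DEMO_OUT come from the commit message.

```python
import os
from pathlib import Path
from typing import Mapping, Optional


def resolve_path(env_var: str, default: str,
                 environ: Optional[Mapping[str, str]] = None) -> Path:
    """Return the path named by env_var, falling back to a default.

    The environ argument exists so the lookup can be exercised without
    touching the real process environment.
    """
    env = os.environ if environ is None else environ
    return Path(env.get(env_var, default)).expanduser()


# Hypothetical fallbacks, shown only to illustrate the call pattern.
checkpoint_dir = resolve_path("ANYFLOW_LOCAL", "checkpoints/anyflow")
reference_repo = resolve_path("ANYFLOW_REF", "third_party/AnyFlow")
demo_output = resolve_path("ANYFLOW_DEMO_OUT", "outputs/anyflow_demo")
```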
|
@SolitaryThinker Quick pass on Gob's three S3 items in S3.1 (arxiv link). Keeping S3.2 (hardcoded paths in S3.3 (silent no-op when Also folded a ruff |
|
@SolitaryThinker From my side this is ready to go — the path cleanup and the ruff The only thing blocking the mergify gate is the Thanks for the patience on the back-and-forth. |
Summary
Adds AnyFlow (paper, official code, model weights) — an any-step video diffusion framework built on flow maps — as a first-class distillation method in FastVideo, covering both training stages.
A single distilled checkpoint can be evaluated at NFE ∈ {1, 2, 4, 8, 16, 32, 50} without retraining; quality scales monotonically with steps, unlike consistency-based distillation, which often degrades at higher NFE. The student network `u_θ(x_t, t, r)` predicts the average velocity from time `t` back to time `r`, so one Euler step is

x_r = x_t - ((t - r) / num_train_timesteps) · u_θ(x_t, t, r)

Two new methods on top of FastVideo's `TrainingMethod` framework:
- `AnyFlowPretrainMethod` (`fastvideo/train/methods/distribution_matching/anyflow_pretrain.py`): flow-map pretrain via central-difference target with `(t, r)` per-batch sampling (50% diffusion / 25% consistency / 25% free).
- `AnyFlowMethod` (`fastvideo/train/methods/distribution_matching/anyflow.py`): on-policy multi-step Euler-flow rollout subclassing `DMD2Method`. One randomly-chosen step is gradient-enabled (broadcast from rank 0); the rest run under `torch.no_grad`.

Supporting changes:
- `WanTimeTextImageEmbedding` gains an optional dual-timestep branch (`r_embedder=True` → allocate a `delta_embedder` deep-copied from `time_embedder` + register a non-persistent gate buffer). Two fusion modes:
  - `additive` (default): `temb_t + g · delta_emb`
  - `gated`: `(1 - g) · temb_t + g · delta_emb` (matches AnyFlow's `WanTwoTimeTextImageEmbedding`)

  With no `r_timestep` passed, the forward is byte-identical to the legacy path, so every existing config (DMD2, Self-Forcing, KD, DFSFT) stays bit-equal.
- `FlowMapEulerDiscreteScheduler` (`fastvideo/models/schedulers/`): standalone any-step Euler scheduler with `apply_shift`, `get_train_weight` (beta08), `step(model_output, sample, t, r)`, `add_noise`. No diffusers `ConfigMixin` dependency.
- `WanModel.predict_velocity_with_r`: adds an `r_timestep` to the transformer kwargs through the same `set_forward_context` plumbing as `predict_noise`.
- `WanVideoArchConfig.param_names_mapping` gets two regex entries that rename `condition_embedder.delta_embedder.*` → `condition_embedder.delta_embedder.mlp.fc_{in,out}.*` so the published `nvidia/AnyFlow-Wan2.1-T2V-*-Diffusers` checkpoints load as-is through FastVideo's existing loader pipeline. Both regexes are no-ops on plain Wan checkpoints.

Two reference YAMLs under `examples/train/configs/distribution_matching/wan/`:
- `anyflow_pretrain_t2v.yaml`: Wan-T2V-1.3B pretrain, `r_embedder: true`, `gate=0.25`, `shift=5.0`, `epsilon=5`, `weight_type=beta08`, `fuse_guidance_scale=3.0`, lr 5e-5, 6k steps.
- `anyflow_onpolicy_t2v.yaml`: Wan-T2V-1.3B on-policy with Wan-14B teacher, `dmd_denoising_steps=[999, 937, 833, 624]`, `t_list_override=[999, 937, 833, 624, 0]`, `student_sample_steps=4`, `use_mean_velocity=true`, lr 2e-6.

Algorithm + usage write-up at `docs/distillation/anyflow.md` (registered in `mkdocs.yml` nav).

Numerical verification
End-to-end parity verified on a single H200 against NVlabs/AnyFlow's reference loader on the published 1.3B checkpoint (`nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers`). Repro: `scripts/verify_anyflow_fastvideo_parity.py`.

For comparison, the sister FastGen port (NVlabs/FastGen#25) reports 2.8% forward / 4.07% training-loss on the same checkpoint. FastVideo's slightly tighter result is consistent with its attention/normalization implementation having marginally lower kernel noise on H200.
Sample videos
Generated end-to-end through FastVideo with the new `FlowMapEulerDiscreteScheduler`: `nvidia/AnyFlow-Wan2.1-T2V-14B-Diffusers`, 81 frames @ 480×832, seed=0, `shift=5.0`, `guidance_scale=1.0` (no CFG; the on-policy distilled checkpoint has `fuse_guidance_scale=3.0` baked into the weights, matching AnyFlow's official `demo.py` default).

Same prompt as AnyFlow's `demo.py` (the "majestic elephant running towards a herd" prompt). Same model, two NFE settings:
- fastvideo_anyflow_14b_nfe4_v1.mp4
- fastvideo_anyflow_14b_nfe50_v1.mp4

Quality scales monotonically with NFE: NFE=4 is already coherent, NFE=50 produces sharper textures and more stable motion.
Test plan
- `pytest fastvideo/tests/training/distill/test_anyflow_{pretrain,onpolicy}.py`: 43 CPU unit tests (Xavier-init helper for `ReplicatedLinear` so the embedder forward produces finite outputs in CPU mode); covers (t, r) sampling, central-difference target math, scheduler numerics, embedder bit-identity in additive mode, on-policy rollout gradient masking, `t_list_override` validation, checkpoint key remap.
- `scripts/verify_anyflow_fastvideo_parity.py`.
- `scripts/demo_anyflow_14b.py`.
- `/test distillation`: GPU smoke on Buildkite (pretrain + on-policy, 2 iters each via `fastvideo/tests/training/distill/test_anyflow_smoke.py`).

Compatibility

The default `r_embedder=False` arch config keeps every existing Wan-based method byte-identical to `main`: no delta_embedder allocated, no extra forward computation, no state_dict additions. Existing CI (DMD2, Self-Forcing, KD, DFSFT) should be unaffected.

Out of scope

Related PRs

`AnyFlowPipeline`, `AnyFlowFARPipeline`) + `FlowMapEulerDiscreteScheduler`.