45 changes: 45 additions & 0 deletions README.md
@@ -89,6 +89,51 @@ torchrun --nnodes 1 --nproc_per_node=8 --master_port 17154 \
Set `--nproc_per_node` to the number of GPUs you use. Logs and checkpoints go under `experiments/<run_name>/` (the `name` field in the YAML).


## 🤗 Using with Hugging Face `diffusers`

The `nvidia/AnyFlow-*-Diffusers` checkpoints can be loaded through the standard `diffusers` API:

```python
import torch
from diffusers import AnyFlowPipeline
from diffusers.utils import export_to_video

pipe = AnyFlowPipeline.from_pretrained(
"nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers",
torch_dtype=torch.bfloat16,
).to("cuda")

video = pipe(
prompt="A red panda eating bamboo in a forest, cinematic lighting",
num_inference_steps=4,
num_frames=33,
).frames[0]
export_to_video(video, "anyflow_t2v.mp4", fps=16)
```
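
If the full pipeline does not fit on your GPU, the generic `diffusers` offloading helpers should apply here as well, since loading goes through the standard `DiffusionPipeline` machinery (this is an assumption about these checkpoints rather than something documented for them; use offloading instead of `.to("cuda")`, not on top of it):

```python
import torch
from diffusers import AnyFlowPipeline

pipe = AnyFlowPipeline.from_pretrained(
    "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers",
    torch_dtype=torch.bfloat16,
)
# Keeps only the submodule that is currently executing on the GPU.
pipe.enable_model_cpu_offload()
```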

For the FAR variant (T2V / I2V / V2V via `context_sequence`):

```python
import torch
from diffusers import AnyFlowFARPipeline
from diffusers.utils import export_to_video

pipe = AnyFlowFARPipeline.from_pretrained(
"nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers",
torch_dtype=torch.bfloat16,
).to("cuda")

video = pipe(
prompt="A red panda eating bamboo in a forest, cinematic lighting",
num_inference_steps=4,
num_frames=81,
).frames[0]
export_to_video(video, "anyflow_far_t2v.mp4", fps=16)
```
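
Image- or video-conditioned generation goes through the same call. The exact format `context_sequence` expects is not pinned down here, so the snippet below is only a sketch that assumes a list of conditioning PIL images; `demo.py` is the authoritative reference for the real format:

```python
import torch
from diffusers import AnyFlowFARPipeline
from diffusers.utils import export_to_video, load_image

pipe = AnyFlowFARPipeline.from_pretrained(
    "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Hypothetical I2V call: condition generation on a single first frame.
first_frame = load_image("path/to/first_frame.png")  # hypothetical path
video = pipe(
    prompt="A red panda eating bamboo in a forest, cinematic lighting",
    context_sequence=[first_frame],  # assumed format; see demo.py
    num_inference_steps=4,
    num_frames=81,
).frames[0]
export_to_video(video, "anyflow_far_i2v.mp4", fps=16)
```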

The same checkpoints also work with `demo.py` and the training entry points in this repository. See the [diffusers AnyFlow docs](https://huggingface.co/docs/diffusers/api/pipelines/anyflow) for the full reference.


## 📊 Evaluation

Evaluation uses **`mode: eval`** configs under `options/test/anyflow/`.
55 changes: 55 additions & 0 deletions far/models/transformer_far_wan_model.py
@@ -584,6 +584,16 @@ def forward(
return hidden_states


# Bind this class under the AnyFlow* names that `model_index.json` resolves via
# `getattr(diffusers, ...)`. Idempotent: if the diffusers AnyFlow classes are
# already importable, the existing bindings win.
def _register_diffusers_aliases(cls):
import diffusers as _diffusers
for name in ('AnyFlowTransformer3DModel', 'AnyFlowFARTransformer3DModel'):
if not hasattr(_diffusers, name):
setattr(_diffusers, name, cls)
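# For reference, the kind of `model_index.json` entry these aliases make
# resolvable looks like this (illustrative values, not copied from a checkpoint):
#
#     "transformer": ["diffusers", "AnyFlowFARTransformer3DModel"]
#
# `DiffusionPipeline.from_pretrained` resolves that (library, class) pair via
# `getattr(diffusers, "AnyFlowFARTransformer3DModel")`, which is why the alias
# has to live on the `diffusers` module itself.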


@MODEL_REGISTRY.register()
class FAR_Wan_Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
r"""
Expand Down Expand Up @@ -695,6 +705,47 @@ def __init__(
if init_flowmap_model:
self.setup_flowmap_model(gate_value=self.config.gate_value, deltatime_type=self.config.deltatime_type)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
"""Load checkpoints whose `transformer/config.json` omits `init_*_model`.

When the config does not specify which submodules to build, derive the
flags from `_class_name`:

AnyFlowTransformer3DModel -> flow-map embedder
AnyFlowFARTransformer3DModel -> flow-map embedder + FAR patch embedding
+ default `chunk_partition` for 81-frame inference

Configs that already set these fields are passed through unchanged; user
kwargs always win.
"""
load_kwargs = {
k: kwargs[k]
for k in (
'subfolder', 'cache_dir', 'force_download', 'proxies',
'local_files_only', 'token', 'revision', 'variant'
)
if k in kwargs
}
try:
config_dict = cls.load_config(pretrained_model_name_or_path, **load_kwargs)
except Exception:
return super().from_pretrained(pretrained_model_name_or_path, **kwargs)

if (
'init_flowmap_model' not in config_dict
and 'init_flowmap_model' not in kwargs
):
cls_name = config_dict.get('_class_name', '') or ''
is_far = 'FAR' in cls_name
kwargs.setdefault('init_flowmap_model', True)
kwargs.setdefault('init_far_model', is_far)
# The pipeline in this repository reads `chunk_partition` from the
# transformer config; fall back to the 81-frame schedule when absent.
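# (Assumption, for orientation: with Wan's 4x temporal VAE compression,
# 81 frames correspond to (81 - 1) / 4 + 1 = 21 latent frames, and the
# default partition below, 1 + 3*6 + 2, also sums to 21.)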
if is_far and 'chunk_partition' not in config_dict and 'chunk_partition' not in kwargs:
kwargs.setdefault('chunk_partition', [1, 3, 3, 3, 3, 3, 3, 2])
return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
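
# Illustrative standalone load of just the transformer (checkpoint name taken
# from the README; the `transformer` subfolder layout is an assumption about
# the Diffusers export):
#
#     transformer = FAR_Wan_Transformer3DModel.from_pretrained(
#         "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers",
#         subfolder="transformer", torch_dtype=torch.bfloat16)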

def setup_flowmap_model(self, gate_value=0, deltatime_type='r'):
inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

@@ -1214,3 +1265,7 @@ def _forward_bidirection(
return (output,)

return Transformer2DModelOutput(sample=output)


# See _register_diffusers_aliases above.
_register_diffusers_aliases(FAR_Wan_Transformer3DModel)
33 changes: 33 additions & 0 deletions far/pipelines/pipeline_far_wan_anyflow.py
@@ -112,6 +112,39 @@ def __init__(
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
self.use_mean_velocity = use_mean_velocity

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
"""Load checkpoints whose `model_index.json` references the diffusers AnyFlow
class names.

Pre-instantiates the transformer and scheduler with the classes defined in
this repository and passes them as kwargs, so `DiffusionPipeline.from_pretrained`
skips its module class lookup for those entries. text_encoder / tokenizer / vae
still load normally.
"""
load_kwargs = {
k: kwargs[k]
for k in (
'cache_dir', 'force_download', 'proxies', 'local_files_only',
'token', 'revision', 'variant'
)
if k in kwargs
}
if 'transformer' not in kwargs:
kwargs['transformer'] = FAR_Wan_Transformer3DModel.from_pretrained(
pretrained_model_name_or_path,
subfolder='transformer',
torch_dtype=kwargs.get('torch_dtype'),
**load_kwargs,
)
if 'scheduler' not in kwargs:
kwargs['scheduler'] = FlowMapDiscreteScheduler.from_pretrained(
pretrained_model_name_or_path,
subfolder='scheduler',
**load_kwargs,
)
return super().from_pretrained(pretrained_model_name_or_path, **kwargs)

def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
33 changes: 33 additions & 0 deletions far/pipelines/pipeline_wan_anyflow.py
@@ -111,6 +111,39 @@ def __init__(
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
self.use_mean_velocity = use_mean_velocity

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
"""Load checkpoints whose `model_index.json` references the diffusers AnyFlow
class names.

Pre-instantiates the transformer and scheduler with the classes defined in
this repository and passes them as kwargs, so `DiffusionPipeline.from_pretrained`
skips its module class lookup for those entries. text_encoder / tokenizer / vae
still load normally.
"""
load_kwargs = {
k: kwargs[k]
for k in (
'cache_dir', 'force_download', 'proxies', 'local_files_only',
'token', 'revision', 'variant'
)
if k in kwargs
}
if 'transformer' not in kwargs:
kwargs['transformer'] = FAR_Wan_Transformer3DModel.from_pretrained(
pretrained_model_name_or_path,
subfolder='transformer',
torch_dtype=kwargs.get('torch_dtype'),
**load_kwargs,
)
if 'scheduler' not in kwargs:
kwargs['scheduler'] = FlowMapDiscreteScheduler.from_pretrained(
pretrained_model_name_or_path,
subfolder='scheduler',
**load_kwargs,
)
return super().from_pretrained(pretrained_model_name_or_path, **kwargs)

def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
12 changes: 12 additions & 0 deletions far/schedulers/scheduling_flowmap_euler_discrete.py
@@ -104,3 +104,15 @@ def step(
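# The lines below broadcast the target timestep r against the model output and
# take a single flow-map Euler step, x_r = x_t - (t - r) * v, i.e. one
# straight-line jump from timestep t to timestep r along the predicted velocity.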
r_timestep = r_timestep.view(*r_timestep.shape, *([1] * (model_output.ndim - r_timestep.ndim)))
prev_sample = sample - (timestep - r_timestep) * model_output
return prev_sample.to(model_output.dtype)


# Expose this scheduler under the name used by the diffusers AnyFlow pipeline.
FlowMapEulerDiscreteScheduler = FlowMapDiscreteScheduler

# Bind the same alias on the `diffusers` package so it can be resolved via
# `getattr`. Idempotent: if diffusers already provides this class, the existing
# binding wins.
import diffusers as _diffusers # noqa: E402

if not hasattr(_diffusers, 'FlowMapEulerDiscreteScheduler'):
_diffusers.FlowMapEulerDiscreteScheduler = FlowMapEulerDiscreteScheduler