diff --git a/README.md b/README.md index ed40abb..4da9897 100755 --- a/README.md +++ b/README.md @@ -89,6 +89,51 @@ torchrun --nnodes 1 --nproc_per_node=8 --master_port 17154 \ Set `--nproc_per_node` to the number of GPUs you use. Logs and checkpoints go under `experiments//` (the `name` field in the YAML). +## 🤗 Using with Hugging Face `diffusers` + +The `nvidia/AnyFlow-*-Diffusers` checkpoints can be loaded through the standard `diffusers` API: + +```python +import torch +from diffusers import AnyFlowPipeline +from diffusers.utils import export_to_video + +pipe = AnyFlowPipeline.from_pretrained( + "nvidia/AnyFlow-Wan2.1-T2V-1.3B-Diffusers", + torch_dtype=torch.bfloat16, +).to("cuda") + +video = pipe( + prompt="A red panda eating bamboo in a forest, cinematic lighting", + num_inference_steps=4, + num_frames=33, +).frames[0] +export_to_video(video, "anyflow_t2v.mp4", fps=16) +``` + +For the FAR variant (T2V / I2V / V2V via `context_sequence`): + +```python +import torch +from diffusers import AnyFlowFARPipeline +from diffusers.utils import export_to_video + +pipe = AnyFlowFARPipeline.from_pretrained( + "nvidia/AnyFlow-FAR-Wan2.1-1.3B-Diffusers", + torch_dtype=torch.bfloat16, +).to("cuda") + +video = pipe( + prompt="A red panda eating bamboo in a forest, cinematic lighting", + num_inference_steps=4, + num_frames=81, +).frames[0] +export_to_video(video, "anyflow_far_t2v.mp4", fps=16) +``` + +The same checkpoints also work with the `demo.py` and training entry points in this repository. See the [diffusers AnyFlow docs](https://huggingface.co/docs/diffusers/api/pipelines/anyflow) for the full reference. + + ## 📊 Evaluation Evaluation uses **`mode: eval`** configs under `options/test/anyflow/`. diff --git a/far/models/transformer_far_wan_model.py b/far/models/transformer_far_wan_model.py index 1be2ede..121979f 100644 --- a/far/models/transformer_far_wan_model.py +++ b/far/models/transformer_far_wan_model.py @@ -584,6 +584,16 @@ def forward( return hidden_states +# Bind this class under the AnyFlow* names that `model_index.json` resolves via +# `getattr(diffusers, ...)`. Idempotent: if the diffusers AnyFlow classes are +# already importable, the existing bindings win. +def _register_diffusers_aliases(cls): + import diffusers as _diffusers + for name in ('AnyFlowTransformer3DModel', 'AnyFlowFARTransformer3DModel'): + if not hasattr(_diffusers, name): + setattr(_diffusers, name, cls) + + @MODEL_REGISTRY.register() class FAR_Wan_Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): r""" @@ -695,6 +705,47 @@ def __init__( if init_flowmap_model: self.setup_flowmap_model(gate_value=self.config.gate_value, deltatime_type=self.config.deltatime_type) + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + """Load checkpoints whose `transformer/config.json` omits `init_*_model`. + + When the config does not specify which submodules to build, derive the + flags from `_class_name`: + + AnyFlowTransformer3DModel -> flow-map embedder + AnyFlowFARTransformer3DModel -> flow-map embedder + FAR patch embedding + + default `chunk_partition` for 81-frame inference + + Configs that already set these fields are passed through unchanged; user + kwargs always win. + """ + load_kwargs = { + k: kwargs[k] + for k in ( + 'subfolder', 'cache_dir', 'force_download', 'proxies', + 'local_files_only', 'token', 'revision', 'variant' + ) + if k in kwargs + } + try: + config_dict = cls.load_config(pretrained_model_name_or_path, **load_kwargs) + except Exception: + return super().from_pretrained(pretrained_model_name_or_path, **kwargs) + + if ( + 'init_flowmap_model' not in config_dict + and 'init_flowmap_model' not in kwargs + ): + cls_name = config_dict.get('_class_name', '') or '' + is_far = 'FAR' in cls_name + kwargs.setdefault('init_flowmap_model', True) + kwargs.setdefault('init_far_model', is_far) + # The pipeline in this repository reads `chunk_partition` from the + # transformer config; fall back to the 81-frame schedule when absent. + if is_far and 'chunk_partition' not in config_dict and 'chunk_partition' not in kwargs: + kwargs.setdefault('chunk_partition', [1, 3, 3, 3, 3, 3, 3, 2]) + return super().from_pretrained(pretrained_model_name_or_path, **kwargs) + def setup_flowmap_model(self, gate_value=0, deltatime_type='r'): inner_dim = self.config.num_attention_heads * self.config.attention_head_dim @@ -1214,3 +1265,7 @@ def _forward_bidirection( return (output,) return Transformer2DModelOutput(sample=output) + + +# See _register_diffusers_aliases above. +_register_diffusers_aliases(FAR_Wan_Transformer3DModel) diff --git a/far/pipelines/pipeline_far_wan_anyflow.py b/far/pipelines/pipeline_far_wan_anyflow.py index e2f152a..2b82815 100644 --- a/far/pipelines/pipeline_far_wan_anyflow.py +++ b/far/pipelines/pipeline_far_wan_anyflow.py @@ -112,6 +112,39 @@ def __init__( self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) self.use_mean_velocity = use_mean_velocity + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + """Load checkpoints whose `model_index.json` references the diffusers AnyFlow + class names. + + Pre-instantiates the transformer and scheduler with the classes defined in + this repository and passes them as kwargs, so `DiffusionPipeline.from_pretrained` + skips its module class lookup for those entries. text_encoder / tokenizer / vae + still load normally. + """ + load_kwargs = { + k: kwargs[k] + for k in ( + 'cache_dir', 'force_download', 'proxies', 'local_files_only', + 'token', 'revision', 'variant' + ) + if k in kwargs + } + if 'transformer' not in kwargs: + kwargs['transformer'] = FAR_Wan_Transformer3DModel.from_pretrained( + pretrained_model_name_or_path, + subfolder='transformer', + torch_dtype=kwargs.get('torch_dtype'), + **load_kwargs, + ) + if 'scheduler' not in kwargs: + kwargs['scheduler'] = FlowMapDiscreteScheduler.from_pretrained( + pretrained_model_name_or_path, + subfolder='scheduler', + **load_kwargs, + ) + return super().from_pretrained(pretrained_model_name_or_path, **kwargs) + def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, diff --git a/far/pipelines/pipeline_wan_anyflow.py b/far/pipelines/pipeline_wan_anyflow.py index ab84098..5803c63 100644 --- a/far/pipelines/pipeline_wan_anyflow.py +++ b/far/pipelines/pipeline_wan_anyflow.py @@ -111,6 +111,39 @@ def __init__( self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) self.use_mean_velocity = use_mean_velocity + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + """Load checkpoints whose `model_index.json` references the diffusers AnyFlow + class names. + + Pre-instantiates the transformer and scheduler with the classes defined in + this repository and passes them as kwargs, so `DiffusionPipeline.from_pretrained` + skips its module class lookup for those entries. text_encoder / tokenizer / vae + still load normally. + """ + load_kwargs = { + k: kwargs[k] + for k in ( + 'cache_dir', 'force_download', 'proxies', 'local_files_only', + 'token', 'revision', 'variant' + ) + if k in kwargs + } + if 'transformer' not in kwargs: + kwargs['transformer'] = FAR_Wan_Transformer3DModel.from_pretrained( + pretrained_model_name_or_path, + subfolder='transformer', + torch_dtype=kwargs.get('torch_dtype'), + **load_kwargs, + ) + if 'scheduler' not in kwargs: + kwargs['scheduler'] = FlowMapDiscreteScheduler.from_pretrained( + pretrained_model_name_or_path, + subfolder='scheduler', + **load_kwargs, + ) + return super().from_pretrained(pretrained_model_name_or_path, **kwargs) + def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, diff --git a/far/schedulers/scheduling_flowmap_euler_discrete.py b/far/schedulers/scheduling_flowmap_euler_discrete.py index 8d18f60..9290980 100755 --- a/far/schedulers/scheduling_flowmap_euler_discrete.py +++ b/far/schedulers/scheduling_flowmap_euler_discrete.py @@ -104,3 +104,15 @@ def step( r_timestep = r_timestep.view(*r_timestep.shape, *([1] * (model_output.ndim - r_timestep.ndim))) prev_sample = sample - (timestep - r_timestep) * model_output return prev_sample.to(model_output.dtype) + + +# Expose this scheduler under the name used by the diffusers AnyFlow pipeline. +FlowMapEulerDiscreteScheduler = FlowMapDiscreteScheduler + +# Bind the same alias on the `diffusers` package so it can be resolved via +# `getattr`. Idempotent: if diffusers already provides this class, the existing +# binding wins. +import diffusers as _diffusers # noqa: E402 + +if not hasattr(_diffusers, 'FlowMapEulerDiscreteScheduler'): + _diffusers.FlowMapEulerDiscreteScheduler = FlowMapEulerDiscreteScheduler