diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 6cf2bb9ffb..bbb11c2d27 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -65,6 +65,21 @@ steps: - TEST_TYPE=ssim agents: queue: "default" + - path: + - "fastvideo/v1/tests/lora/**" + - "fastvideo/v1/models/loader/**" + - "fastvideo/v1/tests/transformers/**" + - "fastvideo/v1/pipelines/**" + - "fastvideo/v1/layers/lora/**" + - "pyproject.toml" + - "docker/Dockerfile.python3.12" + config: + command: "timeout 15m .buildkite/scripts/pr_test.sh" + label: "LoRA Inference Tests" + env: + - TEST_TYPE=inference_lora + agents: + queue: "default" - path: - "fastvideo/v1/**" - "pyproject.toml" diff --git a/.buildkite/scripts/pr_test.sh b/.buildkite/scripts/pr_test.sh index baf1bc1cce..f43f08711c 100755 --- a/.buildkite/scripts/pr_test.sh +++ b/.buildkite/scripts/pr_test.sh @@ -97,6 +97,10 @@ case "$TEST_TYPE" in log "Running precision VSA tests..." MODAL_COMMAND="$MODAL_ENV python3 -m modal run $MODAL_TEST_FILE::run_precision_tests_VSA" ;; + "inference_lora") + log "Running LoRA tests..." + MODAL_COMMAND="$MODAL_ENV python3 -m modal run $MODAL_TEST_FILE::run_inference_lora_tests" + ;; *) log "Error: Unknown test type: $TEST_TYPE" exit 1 diff --git a/.github/workflows/matchers/mypy.json b/.github/workflows/matchers/mypy.json index f048fce528..7c479786d9 100644 --- a/.github/workflows/matchers/mypy.json +++ b/.github/workflows/matchers/mypy.json @@ -13,4 +13,4 @@ ] } ] -} +} \ No newline at end of file diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 9a2a74f512..532f7e4934 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -372,4 +372,4 @@ jobs: JOB_IDS: '["encoder-test", "vae-test", "transformer-test", "ssim-test-py3.10", "ssim-test-py3.11", "ssim-test-py3.12", "training-test", "training-test-VSA", "inference-test-STA", "precision-test-STA", "precision-test-VSA"]' RUNPOD_API_KEY: ${{ secrets.RUNPOD_API_KEY }} GITHUB_RUN_ID: ${{ github.run_id }} - run: python .github/scripts/runpod_cleanup.py + run: python .github/scripts/runpod_cleanup.py \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d412c2b181..a276ec9785 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -60,7 +60,7 @@ repos: rev: v1.15.0 hooks: - id: mypy - args: [--python-version, '3.10', --follow-imports, "skip", ] + args: [--python-version, '3.10', --follow-imports, "skip" ] additional_dependencies: [types-cachetools, types-setuptools, types-PyYAML, types-requests] - repo: local hooks: @@ -69,7 +69,7 @@ repos: entry: bash args: - -c - - 'git ls-files | grep -v "^fastvideo/v1/tests/ssim/" | grep " " && echo "Filenames should not contain spaces!" && exit 1 || exit 0' + - 'git ls-files | grep -v "^fastvideo/v1/tests/ssim/" | grep -v "^fastvideo/v1/tests/inference/lora/L40S_reference_videos/" | grep " " && echo "Filenames should not contain spaces!" && exit 1 || exit 0' language: system always_run: true pass_filenames: false diff --git a/examples/inference/lora/wan_lora_inference.py b/examples/inference/lora/wan_lora_inference.py index cc3fa8dfce..b96599e300 100644 --- a/examples/inference/lora/wan_lora_inference.py +++ b/examples/inference/lora/wan_lora_inference.py @@ -6,7 +6,7 @@ def main(): # Initialize VideoGenerator with the Wan model generator = VideoGenerator.from_pretrained( "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", - num_gpus=2, + num_gpus=1, lora_path="benjamin-paine/steamboat-willie-1.3b", lora_nickname="steamboat" ) @@ -16,6 +16,7 @@ def main(): "num_frames": 81, "guidance_scale": 5.0, "num_inference_steps": 32, + "seed": 42, } # Generate video with LoRA style prompt = "steamboat willie style, golden era animation, close-up of a short fluffy monster kneeling beside a melting red candle. the mood is one of wonder and curiosity, as the monster gazes at the flame with wide eyes and open mouth. Its pose and expression convey a sense of innocence and playfulness, as if it is exploring the world around it for the first time. The use of warm colors and dramatic lighting further enhances the cozy atmosphere of the image." @@ -29,8 +30,17 @@ def main(): negative_prompt=negative_prompt, **kwargs ) - - generator.set_lora_adapter(lora_nickname="flat_color", lora_path="motimalu/wan-flat-color-1.3b-v2") + del generator + + # Until FSDP resharding bug is fixed, multi-lora requires reloading the model + # see https://github.com/pytorch/pytorch/issues/157209 + generator = VideoGenerator.from_pretrained( + "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + num_gpus=1, + lora_path="motimalu/wan-flat-color-1.3b-v2", + lora_nickname="flat_color" + ) + # generator.set_lora_adapter(lora_nickname="flat_color", lora_path="motimalu/wan-flat-color-1.3b-v2") prompt = "flat color, no lineart, blending, negative space, artist:[john kafka|ponsuke kaikai|hara id 21|yoneyama mai|fuzichoco], 1girl, sakura miko, pink hair, cowboy shot, white shirt, floral print, off shoulder, outdoors, cherry blossom, tree shade, wariza, looking up, falling petals, half-closed eyes, white sky, clouds, live2d animation, upper body, high quality cinematic video of a woman sitting under a sakura tree. Dreamy and lonely, the camera close-ups on the face of the woman as she turns towards the viewer. The Camera is steady, This is a cowboy shot. The animation is smooth and fluid." negative_prompt = "bad quality video,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" video = generator.generate_video( diff --git a/fastvideo/utils/collect_env.py b/fastvideo/utils/collect_env.py index d44bb15b32..a0b4dc91f6 100644 --- a/fastvideo/utils/collect_env.py +++ b/fastvideo/utils/collect_env.py @@ -62,6 +62,7 @@ DEFAULT_CONDA_PATTERNS = { "torch", "numpy", + "mypy" "cudatoolkit", "soumith", "mkl", @@ -80,7 +81,6 @@ DEFAULT_PIP_PATTERNS = { "torch", "numpy", - "mypy", "flake8", "triton", "optree", diff --git a/fastvideo/v1/configs/fasthunyuan_t2v.json b/fastvideo/v1/configs/fasthunyuan_t2v.json index f7dca37fe4..e58aa0d499 100644 --- a/fastvideo/v1/configs/fasthunyuan_t2v.json +++ b/fastvideo/v1/configs/fasthunyuan_t2v.json @@ -4,7 +4,7 @@ "use_cpu_offload": false, "disable_autocast": false, "precision": "bf16", - "vae_precision": "fp16", + "vae_precision": "fp32", "vae_tiling": true, "vae_sp": true, "vae_config": { diff --git a/fastvideo/v1/configs/models/dits/base.py b/fastvideo/v1/configs/models/dits/base.py index 0df4d62c11..f2484e9dc1 100644 --- a/fastvideo/v1/configs/models/dits/base.py +++ b/fastvideo/v1/configs/models/dits/base.py @@ -11,9 +11,9 @@ class DiTArchConfig(ArchConfig): _fsdp_shard_conditions: list = field(default_factory=list) _compile_conditions: list = field(default_factory=list) - _param_names_mapping: dict = field(default_factory=dict) - _reverse_param_names_mapping: dict = field(default_factory=dict) - _lora_param_names_mapping: dict = field(default_factory=dict) + param_names_mapping: dict = field(default_factory=dict) + reverse_param_names_mapping: dict = field(default_factory=dict) + lora_param_names_mapping: dict = field(default_factory=dict) _supported_attention_backends: tuple[AttentionBackendEnum, ...] = ( AttentionBackendEnum.SLIDING_TILE_ATTN, AttentionBackendEnum.SAGE_ATTN, AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.TORCH_SDPA, diff --git a/fastvideo/v1/configs/models/dits/hunyuanvideo.py b/fastvideo/v1/configs/models/dits/hunyuanvideo.py index 2709705c78..66f1a74453 100644 --- a/fastvideo/v1/configs/models/dits/hunyuanvideo.py +++ b/fastvideo/v1/configs/models/dits/hunyuanvideo.py @@ -31,7 +31,7 @@ class HunyuanVideoArchConfig(DiTArchConfig): _compile_conditions: list = field( default_factory=lambda: [is_double_block, is_single_block, is_txt_in]) - _param_names_mapping: dict = field( + param_names_mapping: dict = field( default_factory=lambda: { # 1. context_embedder.time_text_embed submodules (specific rules, applied first): r"^context_embedder\.time_text_embed\.timestep_embedder\.linear_1\.(.*)$": @@ -146,8 +146,8 @@ class HunyuanVideoArchConfig(DiTArchConfig): r"final_layer.linear.\1", }) - # Reverse mapping for saving checkpoints: training -> diffusers - _reverse_param_names_mapping: dict = field(default_factory=lambda: {}) + # Reverse mapping for saving checkpoints: custom -> hf + reverse_param_names_mapping: dict = field(default_factory=lambda: {}) patch_size: int = 2 patch_size_t: int = 1 diff --git a/fastvideo/v1/configs/models/dits/stepvideo.py b/fastvideo/v1/configs/models/dits/stepvideo.py index 67065f42e1..254ac3f015 100644 --- a/fastvideo/v1/configs/models/dits/stepvideo.py +++ b/fastvideo/v1/configs/models/dits/stepvideo.py @@ -10,7 +10,7 @@ class StepVideoArchConfig(DiTArchConfig): default_factory=lambda: [lambda n, m: "transformer_blocks" in n and n.split(".")[-1].isdigit()]) - _param_names_mapping: dict = field( + param_names_mapping: dict = field( default_factory=lambda: { # transformer block r"^transformer_blocks\.(\d+)\.norm1\.(weight|bias)$": diff --git a/fastvideo/v1/configs/models/dits/wanvideo.py b/fastvideo/v1/configs/models/dits/wanvideo.py index cbaa38d55c..8884f40c30 100644 --- a/fastvideo/v1/configs/models/dits/wanvideo.py +++ b/fastvideo/v1/configs/models/dits/wanvideo.py @@ -12,7 +12,7 @@ def is_blocks(n: str, m) -> bool: class WanVideoArchConfig(DiTArchConfig): _fsdp_shard_conditions: list = field(default_factory=lambda: [is_blocks]) - _param_names_mapping: dict = field( + param_names_mapping: dict = field( default_factory=lambda: { r"^patch_embedding\.(.*)$": r"patch_embedding.proj.\1", @@ -52,12 +52,12 @@ class WanVideoArchConfig(DiTArchConfig): r"blocks.\1.self_attn_residual_norm.norm.\2", }) - # Reverse mapping for saving checkpoints: training -> diffusers - _reverse_param_names_mapping: dict = field(default_factory=lambda: {}) + # Reverse mapping for saving checkpoints: custom -> hf + reverse_param_names_mapping: dict = field(default_factory=lambda: {}) # Some LoRA adapters use the original official layer names instead of hf layer names, # so apply this before the param_names_mapping - _lora_param_names_mapping: dict = field( + lora_param_names_mapping: dict = field( default_factory=lambda: { r"^blocks\.(\d+)\.self_attn\.q\.(.*)$": r"blocks.\1.attn1.to_q.\2", r"^blocks\.(\d+)\.self_attn\.k\.(.*)$": r"blocks.\1.attn1.to_k.\2", diff --git a/fastvideo/v1/configs/pipelines/base.py b/fastvideo/v1/configs/pipelines/base.py index 7ec4027b2c..9fc5ad5e63 100644 --- a/fastvideo/v1/configs/pipelines/base.py +++ b/fastvideo/v1/configs/pipelines/base.py @@ -62,11 +62,11 @@ class PipelineConfig: image_encoder_precision: str = "fp32" # Text encoder configuration - DEFAULT_TEXT_ENCODER_PRECISIONS = ("fp16", ) + DEFAULT_TEXT_ENCODER_PRECISIONS = ("fp32", ) text_encoder_configs: tuple[EncoderConfig, ...] = field( default_factory=lambda: (EncoderConfig(), )) text_encoder_precisions: tuple[str, ...] = field( - default_factory=lambda: ("fp16", )) + default_factory=lambda: ("fp32", )) preprocess_text_funcs: tuple[Callable[[str], str], ...] = field( default_factory=lambda: (preprocess_text, )) postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.tensor], diff --git a/fastvideo/v1/entrypoints/video_generator.py b/fastvideo/v1/entrypoints/video_generator.py index 1a2f475548..5c44adc7ff 100644 --- a/fastvideo/v1/entrypoints/video_generator.py +++ b/fastvideo/v1/entrypoints/video_generator.py @@ -70,7 +70,7 @@ def from_pretrained(cls, """ # If users also provide some kwargs, it will override the FastVideoArgs and PipelineConfig. kwargs['model_path'] = model_path - fastvideo_args = FastVideoArgs.from_kwargs(kwargs) + fastvideo_args = FastVideoArgs.from_kwargs(**kwargs) return cls.from_fastvideo_args(fastvideo_args) @@ -109,6 +109,7 @@ def generate_video( prompt: The prompt to use for generation negative_prompt: The negative prompt to use (overrides the one in fastvideo_args) output_path: Path to save the video (overrides the one in fastvideo_args) + output_video_name: Name of the video file to save. Default is the first 100 characters of the prompt. save_video: Whether to save the video to disk return_frames: Whether to return the raw frames num_inference_steps: Number of denoising steps (overrides fastvideo_args) @@ -228,6 +229,7 @@ def generate_video( n_tokens=n_tokens, VSA_sparsity=fastvideo_args.VSA_sparsity, extra={}, + output_video_name=kwargs.get("output_video_name", prompt[:100]), ) # Run inference @@ -251,7 +253,8 @@ def generate_video( output_path = batch.output_path if output_path: os.makedirs(output_path, exist_ok=True) - video_path = os.path.join(output_path, f"{prompt[:100]}.mp4") + video_path = os.path.join(output_path, + f"{batch.output_video_name}.mp4") imageio.mimsave(video_path, frames, fps=batch.fps, format="mp4") logger.info("Saved video to %s", video_path) else: @@ -267,7 +270,9 @@ def generate_video( "generation_time": gen_time } - def set_lora_adapter(self, lora_nickname: str, lora_path: str) -> None: + def set_lora_adapter(self, + lora_nickname: str, + lora_path: str | None = None) -> None: self.executor.set_lora_adapter(lora_nickname, lora_path) def shutdown(self): diff --git a/fastvideo/v1/fastvideo_args.py b/fastvideo/v1/fastvideo_args.py index 0c9eba8788..c16ba25231 100644 --- a/fastvideo/v1/fastvideo_args.py +++ b/fastvideo/v1/fastvideo_args.py @@ -280,7 +280,7 @@ def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs": return cls(**kwargs) # type: ignore @classmethod - def from_kwargs(cls, kwargs: dict[str, Any]) -> "FastVideoArgs": + def from_kwargs(cls, **kwargs: Any) -> "FastVideoArgs": kwargs['pipeline_config'] = PipelineConfig.from_kwargs(kwargs) return cls(**kwargs) diff --git a/fastvideo/v1/layers/lora/linear.py b/fastvideo/v1/layers/lora/linear.py index 37eba2d994..99b7b6990a 100644 --- a/fastvideo/v1/layers/lora/linear.py +++ b/fastvideo/v1/layers/lora/linear.py @@ -3,9 +3,12 @@ import torch from torch import nn -from torch.distributed.tensor import DTensor, distribute_tensor +from torch.distributed._composable.fsdp import (CPUOffloadPolicy, OffloadPolicy, + fully_shard) +from torch.distributed.tensor import DTensor -from fastvideo.v1.distributed import (get_tp_rank, split_tensor_along_last_dim, +from fastvideo.v1.distributed import (get_local_torch_device, get_tp_rank, + split_tensor_along_last_dim, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) from fastvideo.v1.layers.linear import (ColumnParallelLinear, LinearBase, @@ -13,6 +16,7 @@ QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from fastvideo.v1.layers.vocab_parallel_embedding import VocabParallelEmbedding +from fastvideo.v1.utils import get_mixed_precision_state class BaseLayerWithLoRA(nn.Module): @@ -26,12 +30,11 @@ def __init__( self.lora_A: torch.Tensor = None self.lora_B: torch.Tensor = None self.merged: bool = False - self.weight = base_layer.weight self.cpu_weight = base_layer.weight.to("cpu") - self.unmerge_count = 0 # indicates adapter weights don't contain this layer # (which shouldn't normally happen, but we want to separate it from the case of erroneous merging) self.disable_lora: bool = False + self.lora_path: str | None = None def forward(self, x: torch.Tensor) -> torch.Tensor: return self.base_layer.forward(x) @@ -45,12 +48,14 @@ def slice_lora_b_weights(self, B: torch.Tensor) -> torch.Tensor: def set_lora_weights(self, A: torch.Tensor, B: torch.Tensor, - training_mode: bool = False) -> None: + training_mode: bool = False, + lora_path: str | None = None) -> None: self.lora_A = A # share storage with weights in the pipeline self.lora_B = B self.disable_lora = False if not training_mode: self.merge_lora_weights() + self.lora_path = lora_path @torch.no_grad() def merge_lora_weights(self) -> None: @@ -58,27 +63,44 @@ def merge_lora_weights(self) -> None: return if self.merged: - raise ValueError( - "LoRA weights already merged. Please unmerge them first.") + self.unmerge_lora_weights() assert self.lora_A is not None and self.lora_B is not None, "LoRA weights not set. Please set them first." if isinstance(self.base_layer.weight, DTensor): mesh = self.base_layer.weight.data.device_mesh - placements = self.base_layer.weight.data.placements + # Using offload param is on CPU, so current_device is for "CPU -> GPU -> merge -> CPU" current_device = self.base_layer.weight.data.device data = self.base_layer.weight.data.to( - f"cuda:{torch.cuda.current_device()}").full_tensor() - data += (self.slice_lora_b_weights(self.lora_B) - @ self.slice_lora_a_weights(self.lora_A)).to(data) - self.base_layer.weight = nn.Parameter( - distribute_tensor(data, mesh, - placements=placements).to(current_device)) + get_local_torch_device()).full_tensor() + data += (self.slice_lora_b_weights(self.lora_B).to(data) + @ self.slice_lora_a_weights(self.lora_A).to(data)) + + # Must re-register updated weights for FSDP to recognize them + self.base_layer.weight = nn.Parameter(data.to(current_device)) + if isinstance(getattr(self.base_layer, "bias", None), DTensor): + self.base_layer.bias = nn.Parameter( + self.base_layer.bias.to( + get_local_torch_device(), + non_blocking=True).full_tensor().to(current_device)) + + offload_policy = CPUOffloadPolicy() if "cpu" in str( + current_device) else OffloadPolicy() + # see https://github.com/pytorch/torchtune/pull/2714/files#diff-909ee7ef184b0d834c40a1980ca4149afc38612ec7a4b344d8e2fc27641758c9R69-R79 + # After the 1st forward, self.base_layer becomes a FSDP module and needs to be resharded + if hasattr(self.base_layer, "unshard"): + self.base_layer.unshard() + mp_policy = get_mixed_precision_state().mp_policy + fully_shard(self.base_layer, + mesh=mesh, + mp_policy=mp_policy, + offload_policy=offload_policy) else: current_device = self.base_layer.weight.data.device - data = self.base_layer.weight.to( - f"cuda:{torch.cuda.current_device()}") + data = self.base_layer.weight.data.to(get_local_torch_device()) data += \ - (self.slice_lora_b_weights(self.lora_B) @ self.slice_lora_a_weights(self.lora_A)).to(data) - self.base_layer.weight = nn.Parameter(data.to(current_device)) + (self.slice_lora_b_weights(self.lora_B.to(data)) @ self.slice_lora_a_weights(self.lora_A.to(data))) + self.base_layer.weight.data = data.to(current_device, + non_blocking=True) + self.merged = True @torch.no_grad() @@ -90,27 +112,14 @@ def unmerge_lora_weights(self) -> None: raise ValueError( "LoRA weights not merged. Please merge them first before unmerging." ) - self.unmerge_count += 1 - - # Avoid precision loss - if self.unmerge_count % 3 == 0: - self.base_layer.weight.data = self.cpu_weight.data.to( - self.base_layer.weight) + # To avoid precision loss we do not subtract the LoRA weights here if isinstance(self.base_layer.weight, DTensor): - mesh = self.base_layer.weight.data.device_mesh - placement = self.base_layer.weight.data.placements device = self.base_layer.weight.data.device - data = self.base_layer.weight.data.to( - f"cuda:{torch.cuda.current_device()}").full_tensor() - data -= self.slice_lora_b_weights( - self.lora_B) @ self.slice_lora_a_weights(self.lora_A) - self.base_layer.weight = nn.Parameter( - distribute_tensor(data, mesh, placements=placement).to(device)) + self.base_layer.weight = nn.Parameter(self.cpu_weight.to(device)) else: - self.base_layer.weight.data -= \ - self.slice_lora_b_weights(self.lora_B) @\ - self.slice_lora_a_weights(self.lora_A) + self.base_layer.weight.data = self.cpu_weight.data.to( + self.base_layer.weight) self.merged = False diff --git a/fastvideo/v1/models/dits/base.py b/fastvideo/v1/models/dits/base.py index 1b0cbd1815..5022ac16a2 100644 --- a/fastvideo/v1/models/dits/base.py +++ b/fastvideo/v1/models/dits/base.py @@ -13,8 +13,8 @@ class BaseDiT(nn.Module, ABC): _fsdp_shard_conditions: list = [] _compile_conditions: list = [] - _param_names_mapping: dict - _reverse_param_names_mapping: dict + param_names_mapping: dict + reverse_param_names_mapping: dict hidden_size: int num_attention_heads: int num_channels_latents: int @@ -24,7 +24,7 @@ class BaseDiT(nn.Module, ABC): def __init_subclass__(cls) -> None: required_class_attrs = [ - "_fsdp_shard_conditions", "_param_names_mapping", + "_fsdp_shard_conditions", "param_names_mapping", "_compile_conditions" ] super().__init_subclass__() @@ -78,9 +78,9 @@ class CachableDiT(BaseDiT): """ # These are required class attributes that should be overridden by concrete implementations _fsdp_shard_conditions = [] - _param_names_mapping = {} - _reverse_param_names_mapping = {} - _lora_param_names_mapping: dict = {} + param_names_mapping = {} + reverse_param_names_mapping = {} + lora_param_names_mapping: dict = {} # Ensure these instance attributes are properly defined in subclasses hidden_size: int num_attention_heads: int diff --git a/fastvideo/v1/models/dits/hunyuanvideo.py b/fastvideo/v1/models/dits/hunyuanvideo.py index ab9c74f833..05fdef4bae 100644 --- a/fastvideo/v1/models/dits/hunyuanvideo.py +++ b/fastvideo/v1/models/dits/hunyuanvideo.py @@ -441,10 +441,10 @@ class HunyuanVideoTransformer3DModel(CachableDiT): _compile_conditions = HunyuanVideoConfig()._compile_conditions _supported_attention_backends = HunyuanVideoConfig( )._supported_attention_backends - _param_names_mapping = HunyuanVideoConfig()._param_names_mapping - _reverse_param_names_mapping = HunyuanVideoConfig( - )._reverse_param_names_mapping - _lora_param_names_mapping = HunyuanVideoConfig()._lora_param_names_mapping + param_names_mapping = HunyuanVideoConfig().param_names_mapping + reverse_param_names_mapping = HunyuanVideoConfig( + ).reverse_param_names_mapping + lora_param_names_mapping = HunyuanVideoConfig().lora_param_names_mapping def __init__(self, config: HunyuanVideoConfig, hf_config: dict[str, Any]): super().__init__(config=config, hf_config=hf_config) diff --git a/fastvideo/v1/models/dits/stepvideo.py b/fastvideo/v1/models/dits/stepvideo.py index 8a7525fe64..76c1af9b35 100644 --- a/fastvideo/v1/models/dits/stepvideo.py +++ b/fastvideo/v1/models/dits/stepvideo.py @@ -457,11 +457,13 @@ def forward(self, class StepVideoModel(BaseDiT): # (Optional) Keep the same attribute for compatibility with splitting, etc. - _fsdp_shard_conditions = StepVideoConfig()._fsdp_shard_conditions - _param_names_mapping = StepVideoConfig()._param_names_mapping - _reverse_param_names_mapping = StepVideoConfig( - )._reverse_param_names_mapping - _lora_param_names_mapping = StepVideoConfig()._lora_param_names_mapping + _fsdp_shard_conditions = [ + lambda n, m: "transformer_blocks" in n and n.split(".")[-1].isdigit(), + # lambda n, m: "pos_embed" in n # If needed for the patch embedding. + ] + param_names_mapping = StepVideoConfig().param_names_mapping + reverse_param_names_mapping = StepVideoConfig().reverse_param_names_mapping + lora_param_names_mapping = StepVideoConfig().lora_param_names_mapping _supported_attention_backends = StepVideoConfig( )._supported_attention_backends diff --git a/fastvideo/v1/models/dits/wanvideo.py b/fastvideo/v1/models/dits/wanvideo.py index dcb0248f57..b82f3395a8 100644 --- a/fastvideo/v1/models/dits/wanvideo.py +++ b/fastvideo/v1/models/dits/wanvideo.py @@ -515,9 +515,9 @@ class WanTransformer3DModel(CachableDiT): _compile_conditions = WanVideoConfig()._compile_conditions _supported_attention_backends = WanVideoConfig( )._supported_attention_backends - _param_names_mapping = WanVideoConfig()._param_names_mapping - _reverse_param_names_mapping = WanVideoConfig()._reverse_param_names_mapping - _lora_param_names_mapping = WanVideoConfig()._lora_param_names_mapping + param_names_mapping = WanVideoConfig().param_names_mapping + reverse_param_names_mapping = WanVideoConfig().reverse_param_names_mapping + lora_param_names_mapping = WanVideoConfig().lora_param_names_mapping def __init__(self, config: WanVideoConfig, hf_config: dict[str, Any]) -> None: diff --git a/fastvideo/v1/models/loader/component_loader.py b/fastvideo/v1/models/loader/component_loader.py index 96a15bc2ef..426a91cbbf 100644 --- a/fastvideo/v1/models/loader/component_loader.py +++ b/fastvideo/v1/models/loader/component_loader.py @@ -429,7 +429,6 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs): hsdp_shard_dim=fastvideo_args.hsdp_shard_dim, cpu_offload=fastvideo_args.use_cpu_offload, fsdp_inference=fastvideo_args.use_fsdp_inference, - default_dtype=default_dtype, # TODO(will): make these configurable param_dtype=torch.bfloat16, reduce_dtype=torch.float32, diff --git a/fastvideo/v1/models/loader/fsdp_load.py b/fastvideo/v1/models/loader/fsdp_load.py index 797f16032a..f0f1c925ff 100644 --- a/fastvideo/v1/models/loader/fsdp_load.py +++ b/fastvideo/v1/models/loader/fsdp_load.py @@ -5,7 +5,6 @@ # Copyright 2025 The FastVideo Authors. import contextlib -from collections import defaultdict from collections.abc import Callable, Generator from itertools import chain from typing import Any @@ -19,7 +18,8 @@ from torch.nn.modules.module import _IncompatibleKeys from fastvideo.v1.logger import init_logger -from fastvideo.v1.models.loader.utils import get_param_names_mapping +from fastvideo.v1.models.loader.utils import (get_param_names_mapping, + hf_to_custom_state_dict) from fastvideo.v1.models.loader.weight_utils import safetensors_weights_iterator from fastvideo.v1.utils import set_mixed_precision_policy @@ -62,7 +62,6 @@ def maybe_load_fsdp_model( device: torch.device, hsdp_replicate_dim: int, hsdp_shard_dim: int, - default_dtype: torch.dtype, param_dtype: torch.dtype, reduce_dtype: torch.dtype, cpu_offload: bool = False, @@ -81,12 +80,14 @@ def maybe_load_fsdp_model( output_dtype, cast_forward_inputs=False) - set_mixed_precision_policy(master_dtype=default_dtype, - param_dtype=param_dtype, - reduce_dtype=reduce_dtype, - output_dtype=output_dtype) + set_mixed_precision_policy( + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + output_dtype=output_dtype, + mp_policy=mp_policy, + ) - with set_default_dtype(default_dtype), torch.device("meta"): + with set_default_dtype(param_dtype), torch.device("meta"): model = model_cls(**init_params) world_size = hsdp_replicate_dim * hsdp_shard_dim if not training_mode and not fsdp_inference: @@ -106,9 +107,8 @@ def maybe_load_fsdp_model( fsdp_shard_conditions=model._fsdp_shard_conditions, pin_cpu_memory=pin_cpu_memory) - weight_iterator = safetensors_weights_iterator(weight_dir_list, - to_cpu=cpu_offload) - param_names_mapping_fn = get_param_names_mapping(model._param_names_mapping) + weight_iterator = safetensors_weights_iterator(weight_dir_list) + param_names_mapping_fn = get_param_names_mapping(model.param_names_mapping) load_model_from_full_model_state_dict( model, weight_iterator, @@ -233,36 +233,14 @@ def load_model_from_full_model_state_dict( NotImplementedError: If got FSDP with more than 1D. """ meta_sd = model.state_dict() - # Find new params - used_keys = set() sharded_sd = {} - to_merge_params: defaultdict[str, dict[Any, Any]] = defaultdict(dict) - reverse_param_names_mapping = {} - assert param_names_mapping is not None - for source_param_name, full_tensor in full_sd_iterator: - target_param_name, merge_index, num_params_to_merge = param_names_mapping( - source_param_name) - reverse_param_names_mapping[target_param_name] = (source_param_name, - merge_index, - num_params_to_merge) - used_keys.add(target_param_name) - if merge_index is not None: - to_merge_params[target_param_name][merge_index] = full_tensor - if len(to_merge_params[target_param_name]) == num_params_to_merge: - # cat at output dim according to the merge_index order - sorted_tensors = [ - to_merge_params[target_param_name][i] - for i in range(num_params_to_merge) - ] - full_tensor = torch.cat(sorted_tensors, dim=0) - del to_merge_params[target_param_name] - else: - continue - + custom_param_sd, reverse_param_names_mapping = hf_to_custom_state_dict( + full_sd_iterator, param_names_mapping) # type: ignore + for target_param_name, full_tensor in custom_param_sd.items(): meta_sharded_param = meta_sd.get(target_param_name) if meta_sharded_param is None: raise ValueError( - f"Parameter {source_param_name}-->{target_param_name} not found in meta sharded state dict" + f"Parameter {target_param_name} not found in custom model state dict. The hf to custom mapping may be incorrect." ) if not hasattr(meta_sharded_param, "device_mesh"): full_tensor = full_tensor.to(device=device, dtype=param_dtype) @@ -279,10 +257,10 @@ def load_model_from_full_model_state_dict( sharded_tensor = sharded_tensor.cpu() sharded_sd[target_param_name] = nn.Parameter(sharded_tensor) - model._reverse_param_names_mapping = reverse_param_names_mapping - unused_keys = set(meta_sd.keys()) - used_keys + model.reverse_param_names_mapping = reverse_param_names_mapping + unused_keys = set(meta_sd.keys()) - set(sharded_sd.keys()) if unused_keys: - logger.warning("Found new parameters in meta state dict: %s", + logger.warning("Found unloaded parameters in meta state dict: %s", unused_keys) # List of allowed parameter name patterns diff --git a/fastvideo/v1/models/loader/utils.py b/fastvideo/v1/models/loader/utils.py index dcd3bd2cb3..9294d8759f 100644 --- a/fastvideo/v1/models/loader/utils.py +++ b/fastvideo/v1/models/loader/utils.py @@ -2,7 +2,8 @@ """Utilities for selecting and loading models.""" import contextlib import re -from collections.abc import Callable +from collections import defaultdict +from collections.abc import Callable, Iterator from typing import Any import torch @@ -35,7 +36,6 @@ def get_param_names_mapping( """ def mapping_fn(name: str) -> tuple[str, Any, Any]: - # Try to match and transform the name using the regex patterns in mapping_dict for pattern, replacement in mapping_dict.items(): match = re.match(pattern, name) @@ -52,4 +52,46 @@ def mapping_fn(name: str) -> tuple[str, Any, Any]: # If no pattern matches, return the original name return name, None, None - return mapping_fn \ No newline at end of file + return mapping_fn + + +def hf_to_custom_state_dict( + hf_param_sd: dict[str, torch.Tensor] | Iterator[tuple[str, torch.Tensor]], + param_names_mapping: Callable[[str], tuple[str, Any, Any]] +) -> tuple[dict[str, torch.Tensor], dict[str, tuple[str, Any, Any]]]: + """ + Converts a Hugging Face parameter state dictionary to a custom parameter state dictionary. + + Args: + hf_param_sd (Dict[str, torch.Tensor]): The Hugging Face parameter state dictionary + param_names_mapping (Callable[[str], tuple[str, Any, Any]]): A function that maps parameter names from source to target format + + Returns: + custom_param_sd (Dict[str, torch.Tensor]): The custom formatted parameter state dict + reverse_param_names_mapping (Dict[str, Tuple[str, Any, Any]]): Maps back from custom to hf + """ + custom_param_sd = {} + to_merge_params = defaultdict(dict) # type: ignore + reverse_param_names_mapping = {} + if isinstance(hf_param_sd, dict): + hf_param_sd = hf_param_sd.items() # type: ignore + for source_param_name, full_tensor in hf_param_sd: # type: ignore + target_param_name, merge_index, num_params_to_merge = param_names_mapping( + source_param_name) + reverse_param_names_mapping[target_param_name] = (source_param_name, + merge_index, + num_params_to_merge) + if merge_index is not None: + to_merge_params[target_param_name][merge_index] = full_tensor + if len(to_merge_params[target_param_name]) == num_params_to_merge: + # cat at output dim according to the merge_index order + sorted_tensors = [ + to_merge_params[target_param_name][i] + for i in range(num_params_to_merge) + ] + full_tensor = torch.cat(sorted_tensors, dim=0) + del to_merge_params[target_param_name] + else: + continue + custom_param_sd[target_param_name] = full_tensor + return custom_param_sd, reverse_param_names_mapping diff --git a/fastvideo/v1/pipelines/composed_pipeline_base.py b/fastvideo/v1/pipelines/composed_pipeline_base.py index 6e25ac298a..7444e20311 100644 --- a/fastvideo/v1/pipelines/composed_pipeline_base.py +++ b/fastvideo/v1/pipelines/composed_pipeline_base.py @@ -113,7 +113,7 @@ def from_pretrained(cls, if args is None or args.inference_mode: kwargs['model_path'] = model_path - fastvideo_args = FastVideoArgs.from_kwargs(kwargs) + fastvideo_args = FastVideoArgs.from_kwargs(**kwargs) else: assert args is not None, "args must be provided for training mode" fastvideo_args = TrainingArgs.from_cli_args(args) diff --git a/fastvideo/v1/pipelines/lora_pipeline.py b/fastvideo/v1/pipelines/lora_pipeline.py index e3f4bbf292..38352a6096 100644 --- a/fastvideo/v1/pipelines/lora_pipeline.py +++ b/fastvideo/v1/pipelines/lora_pipeline.py @@ -7,6 +7,7 @@ import torch.distributed as dist from safetensors.torch import load_file +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.layers.lora.linear import (BaseLayerWithLoRA, get_lora_layer, replace_submodule) @@ -29,13 +30,13 @@ class LoRAPipeline(ComposedPipelineBase): lora_layers: dict[str, BaseLayerWithLoRA] = {} fastvideo_args: FastVideoArgs exclude_lora_layers: list[str] = [] - device: torch.device = torch.device(f"cuda:{torch.cuda.current_device()}") + device: torch.device | None = None def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.exclude_lora_layers = self.modules[ "transformer"].config.arch_config.exclude_lora_layers - + self.device = get_local_torch_device() self.convert_to_lora_layers() if self.fastvideo_args.pipeline_config.lora_path is not None: self.set_lora_adapter( @@ -53,7 +54,6 @@ def convert_to_lora_layers(self) -> None: """ Converts the transformer to a LoRA transformer. """ - for name, layer in self.modules["transformer"].named_modules(): if not self.is_target_layer(name): continue @@ -85,16 +85,18 @@ def set_lora_adapter(self, raise ValueError( f"Adapter {lora_nickname} not found in the pipeline. Please provide lora_path to load it." ) + adapter_updated = False rank = dist.get_rank() if lora_path is not None: lora_local_path = maybe_download_lora(lora_path) - lora_state_dict = load_file(lora_local_path) + lora_state_dict = load_file(lora_local_path, + device=str(self.device)) # Map the hf layer names to our custom layer names param_names_mapping_fn = get_param_names_mapping( - self.modules["transformer"]._param_names_mapping) + self.modules["transformer"].param_names_mapping) lora_param_names_mapping_fn = get_param_names_mapping( - self.modules["transformer"]._lora_param_names_mapping) + self.modules["transformer"].lora_param_names_mapping) to_merge_params: defaultdict[Hashable, dict[Any, Any]] = defaultdict(dict) @@ -119,6 +121,11 @@ def set_lora_adapter(self, del to_merge_params[target_name] else: continue + + if target_name in self.lora_adapters[lora_nickname]: + raise ValueError( + f"Target name {target_name} already exists in lora_adapters[{lora_nickname}]" + ) self.lora_adapters[lora_nickname][target_name] = weight.to( self.device) adapter_updated = True @@ -134,12 +141,11 @@ def set_lora_adapter(self, lora_B_name = name + ".lora_B" if lora_A_name in self.lora_adapters[lora_nickname]\ and lora_B_name in self.lora_adapters[lora_nickname]: - if layer.merged: - layer.unmerge_lora_weights() layer.set_lora_weights( self.lora_adapters[lora_nickname][lora_A_name], self.lora_adapters[lora_nickname][lora_B_name], - training_mode=self.fastvideo_args.training_mode) + training_mode=self.fastvideo_args.training_mode, + lora_path=lora_path) adapted_count += 1 else: if rank == 0: @@ -149,4 +155,5 @@ def set_lora_adapter(self, layer.disable_lora = True logger.info("Rank %d: LoRA adapter %s applied to %d layers", rank, lora_path, adapted_count) + self.cur_adapter_name = lora_nickname diff --git a/fastvideo/v1/pipelines/pipeline_batch_info.py b/fastvideo/v1/pipelines/pipeline_batch_info.py index acf7471aba..d9a01df87c 100644 --- a/fastvideo/v1/pipelines/pipeline_batch_info.py +++ b/fastvideo/v1/pipelines/pipeline_batch_info.py @@ -46,7 +46,7 @@ class ForwardBatch: negative_prompt: str | list[str] | None = None prompt_path: str | None = None output_path: str = "outputs/" - + output_video_name: str | None = None # Primary encoder embeddings prompt_embeds: list[torch.Tensor] = field(default_factory=list) negative_prompt_embeds: list[torch.Tensor] | None = None diff --git a/fastvideo/v1/pipelines/stages/text_encoding.py b/fastvideo/v1/pipelines/stages/text_encoding.py index f896498506..ab32cbdd7d 100644 --- a/fastvideo/v1/pipelines/stages/text_encoding.py +++ b/fastvideo/v1/pipelines/stages/text_encoding.py @@ -81,7 +81,6 @@ def forward( output_hidden_states=True, ) prompt_embeds = postprocess_func(outputs) - batch.prompt_embeds.append(prompt_embeds) if batch.prompt_attention_mask is not None: batch.prompt_attention_mask.append(attention_mask) diff --git a/fastvideo/v1/tests/__init__.py b/fastvideo/v1/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/fastvideo/v1/tests/inference/lora/L40S_reference_videos/Wan2.1-T2V-1.3B-Diffusers/TORCH_SDPA/steamboat-willie-1.3b_steamboat willie style, golden era animation, clos.mp4 b/fastvideo/v1/tests/inference/lora/L40S_reference_videos/Wan2.1-T2V-1.3B-Diffusers/TORCH_SDPA/steamboat-willie-1.3b_steamboat willie style, golden era animation, clos.mp4 new file mode 100644 index 0000000000..0e48e9aa19 Binary files /dev/null and b/fastvideo/v1/tests/inference/lora/L40S_reference_videos/Wan2.1-T2V-1.3B-Diffusers/TORCH_SDPA/steamboat-willie-1.3b_steamboat willie style, golden era animation, clos.mp4 differ diff --git a/fastvideo/v1/tests/inference/lora/L40S_reference_videos/Wan2.1-T2V-1.3B-Diffusers/TORCH_SDPA/wan-flat-color-1.3b-v2_flat color, no lineart, blending, negative space, .mp4 b/fastvideo/v1/tests/inference/lora/L40S_reference_videos/Wan2.1-T2V-1.3B-Diffusers/TORCH_SDPA/wan-flat-color-1.3b-v2_flat color, no lineart, blending, negative space, .mp4 new file mode 100644 index 0000000000..cc1fee2832 Binary files /dev/null and b/fastvideo/v1/tests/inference/lora/L40S_reference_videos/Wan2.1-T2V-1.3B-Diffusers/TORCH_SDPA/wan-flat-color-1.3b-v2_flat color, no lineart, blending, negative space, .mp4 differ diff --git a/fastvideo/v1/tests/inference/lora/test_lora_inference_similarity.py b/fastvideo/v1/tests/inference/lora/test_lora_inference_similarity.py new file mode 100644 index 0000000000..82b7546882 --- /dev/null +++ b/fastvideo/v1/tests/inference/lora/test_lora_inference_similarity.py @@ -0,0 +1,195 @@ +# SPDX-License-Identifier: Apache-2.0 +import json +import os + +import pytest + +from fastvideo import VideoGenerator +from fastvideo.v1.logger import init_logger +from fastvideo.v1.tests.utils import compute_video_ssim_torchvision, write_ssim_results +from diffusers import DiffusionPipeline +from fastvideo.v1.fastvideo_args import FastVideoArgs +from fastvideo.v1.pipelines import build_pipeline +from fastvideo.v1.models.loader.utils import hf_to_custom_state_dict, get_param_names_mapping +from torch.testing import assert_close +from torch.distributed.tensor import DTensor +from fastvideo.v1.worker import MultiprocExecutor +import torch +logger = init_logger(__name__) +os.environ["MASTER_ADDR"] = "localhost" +os.environ["MASTER_PORT"] = "29500" + +# Base parameters for LoRA inference tests +WAN_LORA_PARAMS = { + "num_gpus": 1, + "model_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + "height": 480, + "width": 832, + "num_frames": 45, + "num_inference_steps": 32, + "guidance_scale": 5.0, + "flow_shift": 3.0, + "seed": 42, + "fps": 24, + "neg_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + "text-encoder-precision": ("fp32",), + "use_cpu_offload": True, +} + +# LoRA configurations for testing +LORA_CONFIGS = [ + { + "lora_path": "benjamin-paine/steamboat-willie-1.3b", + "lora_nickname": "steamboat", + "prompt": "steamboat willie style, golden era animation, close-up of a short fluffy monster kneeling beside a melting red candle. the mood is one of wonder and curiosity, as the monster gazes at the flame with wide eyes and open mouth. Its pose and expression convey a sense of innocence and playfulness, as if it is exploring the world around it for the first time. The use of warm colors and dramatic lighting further enhances the cozy atmosphere of the image.", + "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + "ssim_threshold": 0.79 + }, + # { + # "lora_path": "motimalu/wan-flat-color-1.3b-v2", + # "lora_nickname": "flat_color", + # "prompt": "flat color, no lineart, blending, negative space, artist:[john kafka|ponsuke kaikai|hara id 21|yoneyama mai|fuzichoco], 1girl, sakura miko, pink hair, cowboy shot, white shirt, floral print, off shoulder, outdoors, cherry blossom, tree shade, wariza, looking up, falling petals, half-closed eyes, white sky, clouds, live2d animation, upper body, high quality cinematic video of a woman sitting under a sakura tree. Dreamy and lonely, the camera close-ups on the face of the woman as she turns towards the viewer. The Camera is steady, This is a cowboy shot. The animation is smooth and fluid.", + # "negative_prompt": "bad quality video,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + # "ssim_threshold": 0.79 + # } +] + +MODEL_TO_PARAMS = { + "Wan-AI/Wan2.1-T2V-1.3B-Diffusers": WAN_LORA_PARAMS, +} + +@pytest.mark.parametrize("model_id", list(MODEL_TO_PARAMS.keys())) +def test_merge_lora_weights(model_id): + lora_config = LORA_CONFIGS[0] # test only one + hf_pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + hf_pipe.enable_model_cpu_offload() + + lora_nickname = lora_config["lora_nickname"] + lora_path = lora_config["lora_path"] + args = FastVideoArgs.from_kwargs( + model_path=model_id, + use_cpu_offload=True, + dit_precision="bf16", + ) + pipe = build_pipeline(args) + pipe.set_lora_adapter(lora_nickname, lora_path) + custom_transformer = pipe.modules["transformer"] + custom_state_dict = custom_transformer.state_dict() + + hf_pipe.load_lora_weights(lora_path, adapter_name=lora_nickname) + for name, layer in hf_pipe.transformer.named_modules(): + if hasattr(layer, "unmerge"): + layer.unmerge() + layer.merge(adapter_names=[lora_nickname]) + + hf_transformer = hf_pipe.transformer + param_names_mapping = get_param_names_mapping(custom_transformer.param_names_mapping) + hf_state_dict, _ = hf_to_custom_state_dict(hf_transformer.state_dict(), param_names_mapping) + for key in hf_state_dict.keys(): + if "base_layer" not in key: + continue + hf_param = hf_state_dict[key] + custom_param = custom_state_dict[key].to_local() if isinstance(custom_state_dict[key], DTensor) else custom_state_dict[key] + assert_close(hf_param, custom_param, atol=7e-4, rtol=7e-4) + +@pytest.mark.parametrize("ATTENTION_BACKEND", ["TORCH_SDPA"]) +@pytest.mark.parametrize("model_id", list(MODEL_TO_PARAMS.keys())) +def test_lora_inference_similarity(ATTENTION_BACKEND, model_id): + """ + Test that runs LoRA inference with LoRA switching and compares the output + to reference videos using SSIM. + """ + os.environ["FASTVIDEO_ATTENTION_BACKEND"] = ATTENTION_BACKEND + + script_dir = os.path.dirname(os.path.abspath(__file__)) + + output_dir = os.path.join(script_dir, 'generated_videos', model_id.split('/')[-1], ATTENTION_BACKEND) + + os.makedirs(output_dir, exist_ok=True) + + BASE_PARAMS = MODEL_TO_PARAMS[model_id] + num_inference_steps = BASE_PARAMS["num_inference_steps"] + + init_kwargs = { + "num_gpus": BASE_PARAMS["num_gpus"], + "flow_shift": BASE_PARAMS["flow_shift"], + "use_cpu_offload": BASE_PARAMS["use_cpu_offload"], + } + if "text-encoder-precision" in BASE_PARAMS: + init_kwargs["text_encoder_precisions"] = BASE_PARAMS["text-encoder-precision"] + + generation_kwargs = { + "num_inference_steps": num_inference_steps, + "output_path": output_dir, + "height": BASE_PARAMS["height"], + "width": BASE_PARAMS["width"], + "num_frames": BASE_PARAMS["num_frames"], + "guidance_scale": BASE_PARAMS["guidance_scale"], + "seed": BASE_PARAMS["seed"], + "fps": BASE_PARAMS["fps"], + "save_video": True, + } + generator = VideoGenerator.from_pretrained(model_path=BASE_PARAMS["model_path"], **init_kwargs) + for lora_config in LORA_CONFIGS: + lora_nickname = lora_config["lora_nickname"] + lora_path = lora_config["lora_path"] + prompt = lora_config["prompt"] + generation_kwargs["negative_prompt"] = lora_config["negative_prompt"] + + generator.set_lora_adapter(lora_nickname=lora_nickname, lora_path=lora_path) + output_video_name = f"{lora_path.split('/')[-1]}_{prompt[:50]}" + generation_kwargs["output_path"] = output_dir + generation_kwargs["output_video_name"] = output_video_name + + generator.generate_video(prompt, **generation_kwargs) + + assert os.path.exists( + output_dir), f"Output video was not generated at {output_dir}" + + reference_folder = os.path.join(script_dir, 'L40S_reference_videos', model_id.split('/')[-1], ATTENTION_BACKEND) + + if not os.path.exists(reference_folder): + logger.error("Reference folder missing") + raise FileNotFoundError( + f"Reference video folder does not exist: {reference_folder}") + + # Find the matching reference video for the switched LoRA + reference_video_name = None + + for filename in os.listdir(reference_folder): + # Check if the filename starts with the expected output_video_name and ends with .mp4 + if filename.startswith(output_video_name) and filename.endswith('.mp4'): + reference_video_name = filename # Remove .mp4 extension to match the logic below + break + + if not reference_video_name: + logger.error(f"Reference video not found for adapter: {lora_path} with prompt: {prompt[:50]} and backend: {ATTENTION_BACKEND}") + raise FileNotFoundError(f"Reference video missing for adapter {lora_path}") + + reference_video_path = os.path.join(reference_folder, reference_video_name) + generated_video_path = os.path.join(output_dir, output_video_name + ".mp4") + + logger.info( + f"Computing SSIM between {reference_video_path} and {generated_video_path}" + ) + ssim_values = compute_video_ssim_torchvision(reference_video_path, + generated_video_path, + use_ms_ssim=True) + + mean_ssim = ssim_values[0] + logger.info(f"SSIM mean value: {mean_ssim}") + logger.info(f"Writing SSIM results to directory: {output_dir}") + + success = write_ssim_results(output_dir, ssim_values, reference_video_path, + generated_video_path, num_inference_steps, + prompt) + + if not success: + logger.error("Failed to write SSIM results to file") + + min_acceptable_ssim = lora_config["ssim_threshold"] + assert mean_ssim >= min_acceptable_ssim, f"SSIM value {mean_ssim} is below threshold {min_acceptable_ssim} for adapter {lora_config['lora_path']}" + + + + diff --git a/fastvideo/v1/tests/modal/pr_test.py b/fastvideo/v1/tests/modal/pr_test.py index 08cc962bcd..75422d1575 100644 --- a/fastvideo/v1/tests/modal/pr_test.py +++ b/fastvideo/v1/tests/modal/pr_test.py @@ -97,3 +97,8 @@ def run_precision_tests_STA(): @app.function(gpu="H100:1", image=image, timeout=900) def run_precision_tests_VSA(): run_test("python csrc/attn/tests/test_block_sparse.py") + + +@app.function(gpu="L40S:1", image=image, timeout=3600) +def run_inference_lora_tests(): + run_test("pytest ./fastvideo/v1/tests/inference/lora/test_lora_inference_similarity.py -vs") \ No newline at end of file diff --git a/fastvideo/v1/tests/ssim/test_inference_similarity.py b/fastvideo/v1/tests/ssim/test_inference_similarity.py index 41076378ea..2913e4f27b 100644 --- a/fastvideo/v1/tests/ssim/test_inference_similarity.py +++ b/fastvideo/v1/tests/ssim/test_inference_similarity.py @@ -7,7 +7,7 @@ from fastvideo import VideoGenerator from fastvideo.v1.logger import init_logger -from fastvideo.v1.tests.ssim.compute_ssim import compute_video_ssim_torchvision +from fastvideo.v1.tests.utils import compute_video_ssim_torchvision, write_ssim_results from fastvideo.v1.worker.multiproc_executor import MultiprocExecutor logger = init_logger(__name__) @@ -99,44 +99,6 @@ ] -def write_ssim_results(output_dir, ssim_values, reference_path, generated_path, - num_inference_steps, prompt): - """ - Write SSIM results to a JSON file in the same directory as the generated videos. - """ - try: - logger.info( - f"Attempting to write SSIM results to directory: {output_dir}") - - if not os.path.exists(output_dir): - os.makedirs(output_dir, exist_ok=True) - - mean_ssim, min_ssim, max_ssim = ssim_values - - result = { - "mean_ssim": mean_ssim, - "min_ssim": min_ssim, - "max_ssim": max_ssim, - "reference_video": reference_path, - "generated_video": generated_path, - "parameters": { - "num_inference_steps": num_inference_steps, - "prompt": prompt - } - } - - test_name = f"steps{num_inference_steps}_{prompt[:100]}" - result_file = os.path.join(output_dir, f"{test_name}_ssim.json") - logger.info(f"Writing JSON results to: {result_file}") - with open(result_file, 'w') as f: - json.dump(result, f, indent=2) - - logger.info(f"SSIM results written to {result_file}") - return True - except Exception as e: - logger.error(f"ERROR writing SSIM results: {str(e)}") - return False - @pytest.mark.parametrize("prompt", I2V_TEST_PROMPTS) @pytest.mark.parametrize("ATTENTION_BACKEND", ["FLASH_ATTN", "TORCH_SDPA"]) diff --git a/fastvideo/v1/tests/ssim/compute_ssim.py b/fastvideo/v1/tests/utils.py similarity index 53% rename from fastvideo/v1/tests/ssim/compute_ssim.py rename to fastvideo/v1/tests/utils.py index 4910363e64..aa2f3b388b 100644 --- a/fastvideo/v1/tests/ssim/compute_ssim.py +++ b/fastvideo/v1/tests/utils.py @@ -1,15 +1,31 @@ # SPDX-License-Identifier: Apache-2.0 import argparse import os +import json +from fastvideo.v1.logger import init_logger import numpy as np import torch from pytorch_msssim import ms_ssim, ssim from torchvision.io import read_video +logger = init_logger(__name__) + def compute_video_ssim_torchvision(video1_path, video2_path, use_ms_ssim=True): + """ + Compute SSIM between two videos. + + Args: + video1_path: Path to the first video. + video2_path: Path to the second video. + use_ms_ssim: Whether to use Multi-Scale Structural Similarity(MS-SSIM) instead of SSIM. + """ print(f"Computing SSIM between {video1_path} and {video2_path}...") + if not os.path.exists(video1_path): + raise FileNotFoundError(f"Video1 not found: {video1_path}") + if not os.path.exists(video2_path): + raise FileNotFoundError(f"Video2 not found: {video2_path}") frames1, _, _ = read_video(video1_path, pts_unit='sec', @@ -65,7 +81,26 @@ def compute_video_ssim_torchvision(video1_path, video2_path, use_ms_ssim=True): def compare_folders(reference_folder, generated_folder, use_ms_ssim=True): """ Compare videos with the same filename between reference_folder and generated_folder + + Example usage: + results = compare_folders(reference_folder, generated_folder, + args.use_ms_ssim) + for video_name, ssim_value in results.items(): + if ssim_value is not None: + print( + f"{video_name}: {ssim_value[0]:.4f}, Min SSIM: {ssim_value[1]:.4f}, Max SSIM: {ssim_value[2]:.4f}" + ) + else: + print(f"{video_name}: Error during comparison") + + valid_ssims = [v for v in results.values() if v is not None] + if valid_ssims: + avg_ssim = np.mean([v[0] for v in valid_ssims]) + print(f"\nAverage SSIM across all videos: {avg_ssim:.4f}") + else: + print("\nNo valid SSIM values to average") """ + reference_videos = [ f for f in os.listdir(reference_folder) if f.endswith('.mp4') ] @@ -92,54 +127,40 @@ def compare_folders(reference_folder, generated_folder, use_ms_ssim=True): return results - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description='Compare videos using SSIM/MS-SSIM metrics') - parser.add_argument('--reference', - '-r', - type=str, - help='Path to reference videos directory') - parser.add_argument('--generated', - '-g', - type=str, - help='Path to generated videos directory') - parser.add_argument('--use-ms-ssim', - action='store_true', - help='Use MS-SSIM instead of SSIM') - args = parser.parse_args() - - script_dir = os.path.dirname(os.path.abspath(__file__)) - - reference_folder = args.reference if args.reference else os.path.join( - script_dir, 'reference_videos') - generated_folder = args.generated if args.generated else os.path.join( - script_dir, 'generated_videos') - - if not os.path.exists(reference_folder): - print(f"ERROR: Reference folder {reference_folder} does not exist!") - exit(1) - - if not os.path.exists(generated_folder): - print(f"ERROR: Generated folder {generated_folder} does not exist!") - exit(1) - - print(f"Comparing videos between {reference_folder} and {generated_folder}") - results = compare_folders(reference_folder, generated_folder, - args.use_ms_ssim) - - print("\n===== SSIM Results Summary =====") - for video_name, ssim_value in results.items(): - if ssim_value is not None: - print( - f"{video_name}: {ssim_value[0]:.4f}, Min SSIM: {ssim_value[1]:.4f}, Max SSIM: {ssim_value[2]:.4f}" - ) - else: - print(f"{video_name}: Error during comparison") - - valid_ssims = [v for v in results.values() if v is not None] - if valid_ssims: - avg_ssim = np.mean([v[0] for v in valid_ssims]) - print(f"\nAverage SSIM across all videos: {avg_ssim:.4f}") - else: - print("\nNo valid SSIM values to average") +def write_ssim_results(output_dir, ssim_values, reference_path, generated_path, + num_inference_steps, prompt): + """ + Write SSIM results to a JSON file in the same directory as the generated videos. + """ + try: + logger.info( + f"Attempting to write SSIM results to directory: {output_dir}") + + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + + mean_ssim, min_ssim, max_ssim = ssim_values + + result = { + "mean_ssim": mean_ssim, + "min_ssim": min_ssim, + "max_ssim": max_ssim, + "reference_video": reference_path, + "generated_video": generated_path, + "parameters": { + "num_inference_steps": num_inference_steps, + "prompt": prompt + } + } + + test_name = f"steps{num_inference_steps}_{prompt[:100]}" + result_file = os.path.join(output_dir, f"{test_name}_ssim.json") + logger.info(f"Writing JSON results to: {result_file}") + with open(result_file, 'w') as f: + json.dump(result, f, indent=2) + + logger.info(f"SSIM results written to {result_file}") + return True + except Exception as e: + logger.error(f"ERROR writing SSIM results: {str(e)}") + return False \ No newline at end of file diff --git a/fastvideo/v1/training/training_utils.py b/fastvideo/v1/training/training_utils.py index 0bf1dbc5d0..93eefee420 100644 --- a/fastvideo/v1/training/training_utils.py +++ b/fastvideo/v1/training/training_utils.py @@ -3,6 +3,7 @@ import math import os import time +from collections.abc import Iterator from typing import Any import torch @@ -162,9 +163,9 @@ def save_checkpoint(transformer, weight_path, local_main_process_only=False) - # Convert fastvideo custom format to diffusers format and save - diffusers_state_dict = convert_custom_format_to_diffusers_format( - cpu_state, transformer) + # Convert training format to diffusers format and save + diffusers_state_dict = custom_to_hf_state_dict( + cpu_state, transformer.reverse_param_names_mapping) save_file(diffusers_state_dict, weight_path) logger.info("rank: %s, consolidated checkpoint saved to %s", @@ -487,24 +488,25 @@ def _has_foreach_support(tensors: list[torch.Tensor], t is None or type(t) in [torch.Tensor] for t in tensors) -def convert_custom_format_to_diffusers_format(state_dict: dict[str, Any], - transformer) -> dict[str, Any]: +def custom_to_hf_state_dict( + state_dict: dict[str, Any] | Iterator[tuple[str, torch.Tensor]], + reverse_param_names_mapping: dict[str, tuple[str, int, + int]]) -> dict[str, Any]: """ - Convert fastvideo custom format state dict to diffusers format using reverse_param_names_mapping. + Convert fastvideo's custom model format to diffusers format using reverse_param_names_mapping. Args: - state_dict: State dict in training format - transformer: Transformer model object with _reverse_param_names_mapping + state_dict: State dict in fastvideo's custom format + reverse_param_names_mapping: Reverse mapping from fastvideo's custom format to diffusers format Returns: State dict in diffusers format """ + assert len( + reverse_param_names_mapping) > 0, "reverse_param_names_mapping is empty" + if isinstance(state_dict, Iterator): + state_dict = dict(state_dict) new_state_dict = {} - - # Get the reverse mapping from the transformer - reverse_param_names_mapping = transformer._reverse_param_names_mapping - assert reverse_param_names_mapping != {}, "reverse_param_names_mapping is empty" - # Group parameters that need to be split (merged parameters) merge_groups: dict[str, list[tuple[str, int, int]]] = {} diff --git a/fastvideo/v1/utils.py b/fastvideo/v1/utils.py index 9cbf6ae5f9..1591d9cbee 100644 --- a/fastvideo/v1/utils.py +++ b/fastvideo/v1/utils.py @@ -29,6 +29,7 @@ _best_guess_weight_name) # watch out for potetential removal from diffusers from huggingface_hub import snapshot_download from remote_pdb import RemotePdb +from torch.distributed.fsdp import MixedPrecisionPolicy import fastvideo.v1.envs as envs from fastvideo.v1.logger import init_logger @@ -684,11 +685,11 @@ def remote_breakpoint() -> None: @dataclass class MixedPrecisionState: - master_dtype: torch.dtype | None = None param_dtype: torch.dtype | None = None reduce_dtype: torch.dtype | None = None output_dtype: torch.dtype | None = None compute_dtype: torch.dtype | None = None + mp_policy: MixedPrecisionPolicy | None = None # Thread-local storage for mixed precision state @@ -702,10 +703,12 @@ def get_mixed_precision_state() -> MixedPrecisionState: return cast(MixedPrecisionState, _mixed_precision_state.state) -def set_mixed_precision_policy(master_dtype: torch.dtype, - param_dtype: torch.dtype, - reduce_dtype: torch.dtype, - output_dtype: torch.dtype | None = None): +def set_mixed_precision_policy( + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + output_dtype: torch.dtype | None = None, + mp_policy: MixedPrecisionPolicy | None = None, +): """Set mixed precision policy globally. Args: @@ -714,10 +717,10 @@ def set_mixed_precision_policy(master_dtype: torch.dtype, output_dtype: Optional output dtype """ state = MixedPrecisionState( - master_dtype=master_dtype, param_dtype=param_dtype, reduce_dtype=reduce_dtype, output_dtype=output_dtype, + mp_policy=mp_policy, ) _mixed_precision_state.state = state diff --git a/fastvideo/v1/worker/__init__.py b/fastvideo/v1/worker/__init__.py new file mode 100644 index 0000000000..85413d3ca5 --- /dev/null +++ b/fastvideo/v1/worker/__init__.py @@ -0,0 +1,5 @@ +from .executor import Executor +from .gpu_worker import run_worker_process +from .multiproc_executor import MultiprocExecutor + +__all__ = ["Executor", "run_worker_process", "MultiprocExecutor"] \ No newline at end of file diff --git a/fastvideo/v1/worker/executor.py b/fastvideo/v1/worker/executor.py index 0a79fe5625..276dfe8d0a 100644 --- a/fastvideo/v1/worker/executor.py +++ b/fastvideo/v1/worker/executor.py @@ -49,7 +49,9 @@ def execute_forward( return cast(ForwardBatch, outputs[0]["output_batch"]) @abstractmethod - def set_lora_adapter(self, lora_nickname: str, lora_path: str) -> None: + def set_lora_adapter(self, + lora_nickname: str, + lora_path: str | None = None) -> None: """ Set the LoRA adapter for the workers. """ diff --git a/fastvideo/v1/worker/gpu_worker.py b/fastvideo/v1/worker/gpu_worker.py index 797cc6d48d..0490e57584 100644 --- a/fastvideo/v1/worker/gpu_worker.py +++ b/fastvideo/v1/worker/gpu_worker.py @@ -87,7 +87,9 @@ def execute_forward(self, forward_batch: ForwardBatch, output_batch = self.pipeline.forward(forward_batch, self.fastvideo_args) return cast(ForwardBatch, output_batch) - def set_lora_adapter(self, lora_nickname: str, lora_path: str) -> None: + def set_lora_adapter(self, + lora_nickname: str, + lora_path: str | None = None) -> None: self.pipeline.set_lora_adapter(lora_nickname, lora_path) def shutdown(self) -> dict[str, Any]: @@ -132,6 +134,13 @@ def event_loop(self) -> None: output_batch = self.execute_forward(forward_batch, fastvideo_args) self.pipe.send({"output_batch": output_batch.output.cpu()}) + elif method_name == 'set_lora_adapter': + lora_nickname = recv_rpc['kwargs']['lora_nickname'] + lora_path = recv_rpc['kwargs']['lora_path'] + self.set_lora_adapter(lora_nickname, lora_path) + logger.info("Worker %d set LoRA adapter %s with path %s", + self.rank, lora_nickname, lora_path) + self.pipe.send({"status": "lora_adapter_set"}) else: # Handle other methods dynamically if needed args = recv_rpc.get('args', ()) diff --git a/fastvideo/v1/worker/multiproc_executor.py b/fastvideo/v1/worker/multiproc_executor.py index 48df08bf50..2a2d488563 100644 --- a/fastvideo/v1/worker/multiproc_executor.py +++ b/fastvideo/v1/worker/multiproc_executor.py @@ -75,12 +75,18 @@ def execute_forward(self, forward_batch: ForwardBatch, }) return cast(ForwardBatch, responses[0]["output_batch"]) - def set_lora_adapter(self, lora_nickname: str, lora_path: str) -> None: - self.collective_rpc("set_lora_adapter", - kwargs={ - "lora_nickname": lora_nickname, - "lora_path": lora_path - }) + def set_lora_adapter(self, + lora_nickname: str, + lora_path: str | None = None) -> None: + responses = self.collective_rpc("set_lora_adapter", + kwargs={ + "lora_nickname": lora_nickname, + "lora_path": lora_path + }) + for i, response in enumerate(responses): + if response["status"] != "lora_adapter_set": + raise RuntimeError( + f"Worker {i} failed to set LoRA adapter to {lora_path}") def collective_rpc(self, method: str | Callable, diff --git a/pyproject.toml b/pyproject.toml index 9bcdd741de..2768492a7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,7 @@ exclude = ["assets*", "docker*", "docs", "scripts*"] [tool.wheel] exclude = ["assets*", "docker*", "docs", "scripts*"] + [tool.mypy] warn_unused_configs = true ignore_missing_imports = true diff --git a/tests/test_data_preprocess.py b/tests/test_data_preprocess.py deleted file mode 100644 index 3fdf33b60a..0000000000 --- a/tests/test_data_preprocess.py +++ /dev/null @@ -1,113 +0,0 @@ -import os -import unittest - -import torch -from transformers import AutoTokenizer, T5EncoderModel - -from fastvideo.models.hunyuan.vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D - - -class TestAutoencoderKLCausal3D(unittest.TestCase): - - @classmethod - def setUpClass(cls): - """ - setUpClass is called once, before any test is run. - We can set environment variables or load heavy resources here. - """ - os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" - - # Load tokenizer/model that can be reused across all tests - cls.tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") - cls.text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") - - def setUp(self): - """ - setUp is called before each test method to prepare fresh state. - """ - self.batch_size = 1 - self.init_time_len = 9 - self.init_height = 16 - self.init_width = 16 - self.latent_channels = 4 - self.spatial_compression_ratio = 8 - self.time_compression_ratio = 4 - - # Model initialization config - self.init_dict = { - "in_channels": - 3, - "out_channels": - 3, - "latent_channels": - self.latent_channels, - "down_block_types": ( - "DownEncoderBlockCausal3D", - "DownEncoderBlockCausal3D", - "DownEncoderBlockCausal3D", - "DownEncoderBlockCausal3D", - ), - "up_block_types": ( - "UpDecoderBlockCausal3D", - "UpDecoderBlockCausal3D", - "UpDecoderBlockCausal3D", - "UpDecoderBlockCausal3D", - ), - "block_out_channels": (8, 8, 8, 8), - "layers_per_block": - 1, - "act_fn": - "silu", - "norm_num_groups": - 4, - "scaling_factor": - 0.476986, - "spatial_compression_ratio": - self.spatial_compression_ratio, - "time_compression_ratio": - self.time_compression_ratio, - "mid_block_add_attention": - True, - } - - # Instantiate the model - self.model = AutoencoderKLCausal3D(**self.init_dict) - - # Create a random input tensor - self.input_tensor = torch.rand(self.batch_size, 3, self.init_time_len, self.init_height, self.init_width) - - def test_encode_shape(self): - """ - Check that the shape of the encoded output matches expectations. - """ - vae_encoder_output = self.model.encode(self.input_tensor) - - # The distribution from the VAE has a .sample() method - # so we verify the shape of that sample. - sample_shape = vae_encoder_output["latent_dist"].sample().shape - - # We expect shape: [batch_size, latent_channels, - # (init_time_len // time_compression_ratio) + 1, - # init_height // spatial_compression_ratio, - # init_width // spatial_compression_ratio] - expected_shape = ( - self.batch_size, - self.latent_channels, - (self.init_time_len // self.time_compression_ratio) + 1, - self.init_height // self.spatial_compression_ratio, - self.init_width // self.spatial_compression_ratio, - ) - - # (Optional) Print them if you like, or just rely on assertions: - print(f"sample_shape: {sample_shape}") - print(f"expected_shape: {expected_shape}") - - self.assertEqual( - sample_shape, - expected_shape, - f"Encoded sample shape {sample_shape} does not match {expected_shape}.", - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_save_checkpoint.py b/tests/test_save_checkpoint.py deleted file mode 100644 index 49b47d18ad..0000000000 --- a/tests/test_save_checkpoint.py +++ /dev/null @@ -1,39 +0,0 @@ -import os -import shutil - -import pytest -import torch -import torch.distributed as dist -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - - -@pytest.fixture(scope="module", autouse=True) -def setup_distributed(): - os.environ["RANK"] = "0" - os.environ["WORLD_SIZE"] = "1" - os.environ["LOCAL_RANK"] = "0" - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = "12345" - - dist.init_process_group("nccl") - yield - dist.destroy_process_group() - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires at least 2 GPUs to run NCCL tests") -def test_save_and_remove_checkpoint(): - from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel - from fastvideo.utils.checkpoint import save_checkpoint - from fastvideo.utils.fsdp_util import get_dit_fsdp_kwargs - - transformer = MochiTransformer3DModel(num_layers=0) - fsdp_kwargs, _ = get_dit_fsdp_kwargs(transformer, "none") - transformer = FSDP(transformer, **fsdp_kwargs) - - test_folder = "./test_checkpoint" - save_checkpoint(transformer, 0, test_folder, 0) - - assert os.path.exists(test_folder), "Checkpoint folder was not created." - - shutil.rmtree(test_folder) - assert not os.path.exists(test_folder), "Checkpoint folder still exists." diff --git a/tests/test_sequence_parallel.py b/tests/test_sequence_parallel.py deleted file mode 100644 index 3f83782901..0000000000 --- a/tests/test_sequence_parallel.py +++ /dev/null @@ -1,111 +0,0 @@ -from functools import partial -from multiprocessing import Manager - -import pytest -import torch -import torch.distributed as dist -import torch.multiprocessing as mp - -from fastvideo.utils.communications import nccl_info, prepare_sequence_parallel_data - - -def _init_distributed_test_gpu(rank, world_size, backend, port, data, results): - dist.init_process_group( - backend=backend, - init_method=f"tcp://127.0.0.1:{port}", - world_size=world_size, - rank=rank, - ) - - device = torch.device(f"cuda:{rank}") - - nccl_info.sp_size = world_size - nccl_info.rank_within_group = rank - nccl_info.group_id = 0 - - seq_group = dist.new_group(ranks=list(range(world_size))) - nccl_info.group = seq_group - - hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask = data - hidden_states = hidden_states[rank].unsqueeze(dim=0).to(device) - encoder_hidden_states = encoder_hidden_states.to(device) - attention_mask = attention_mask.to(device) - encoder_attention_mask = encoder_attention_mask.to(device) - print(f"Rank {rank} input hidden_states:\n", hidden_states) - print(f"Rank {rank} input hidden_states shape:\n", hidden_states.shape) - out_hidden, out_encoder, out_attn_mask, out_encoder_mask = prepare_sequence_parallel_data( - hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask) - print(f"Rank {rank} output out_hidden:\n", out_hidden) - - shapes = ( - out_hidden.shape, - out_encoder.shape, - out_attn_mask.shape, - out_encoder_mask.shape, - ) - shape_tensor = torch.tensor([*shapes[0], *shapes[1], *shapes[2], *shapes[3]], dtype=torch.int32, device=device) - shape_list = [torch.zeros_like(shape_tensor) for _ in range(world_size)] - dist.all_gather(shape_list, shape_tensor, group=seq_group) - gathered_shapes = [tuple(s.tolist()) for s in shape_list] - out_hidden_cpu = out_hidden.to("cpu") - - results[rank] = { - "shapes": gathered_shapes, - "out_hidden": out_hidden_cpu, - } - - dist.barrier() - dist.destroy_process_group() - - -@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, - reason="Requires at least 2 GPUs to run NCCL tests") -def test_prepare_sequence_parallel_data_gpu(): - world_size = 2 - backend = "nccl" - port = 12355 # or use a random free port if collisions occur - - # Create test tensors on CPU; the dimension at index=2 should be divisible by world_size=2 (if applicable). - hidden_states = torch.randn(2, 1, 2, 1, 1) - encoder_hidden_states = torch.randn(2, 2) - attention_mask = torch.randn(2, 2) - encoder_attention_mask = torch.randn(2, 2) - - print("init hidden states", hidden_states) - - manager = Manager() - results_dict = manager.dict() - - # Wrap our helper function with partial - mp_func = partial(_init_distributed_test_gpu, - world_size=world_size, - backend=backend, - port=port, - data=(hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask), - results=results_dict) - - # Spawn two GPU processes (rank=0, rank=1) - mp.spawn(mp_func, nprocs=world_size) - - first_rank_shapes = None - - overall_hidden_out = [] - - for rank in sorted(results_dict.keys()): - rank_data = results_dict[rank] - rank_shapes = rank_data["shapes"] - if first_rank_shapes is None: - first_rank_shapes = rank_shapes - assert rank_shapes == first_rank_shapes, ( - f"Mismatch in shapes across ranks: {rank_shapes} != {first_rank_shapes}") - overall_hidden_out.append(rank_data["out_hidden"]) - - overall_hidden_out = torch.cat(overall_hidden_out, dim=2) - print("overall_hidden_out", overall_hidden_out) - print("overall_hidden_out_shape", overall_hidden_out.shape) - - assert torch.allclose(hidden_states, torch.tensor(overall_hidden_out), rtol=1e-7, atol=1e-6) - - -if __name__ == "__main__": - test_prepare_sequence_parallel_data_gpu()