diff --git a/fastgen/networks/cosmos_predict2/network.py b/fastgen/networks/cosmos_predict2/network.py index ff85360..65b1536 100644 --- a/fastgen/networks/cosmos_predict2/network.py +++ b/fastgen/networks/cosmos_predict2/network.py @@ -1352,6 +1352,13 @@ def forward( conditioning_latents_full = conditioning_latents # Replace conditioning frames model_input = conditioning_latents_full * condition_mask_C + x_t * (1 - condition_mask_C) + # Per-frame timesteps: assign timestep 0 to conditioning (clean) frames so the + # model knows they are noise-free and should not be denoised. Without this, the + # model treats the clean first frame as a fully-noised frame, causing incoherent + # video2world (image2world) generation. + t_expanded = t.unsqueeze(1).expand(B, T) + mask_B_T = condition_mask[:, 0, :, 0, 0] # (B, T) + t = t_expanded * (1 - mask_B_T) model_outputs = self.transformer( x_B_C_T_H_W=model_input,