How to correctly normalize VAE encoder latents before feeding them into SiT (inverse transform of denormalize_latents)? #27

@Zhou-Weichen

Hello, and thank you for releasing REPA-E. I’m currently building an image reconstruction / img2img pipeline using sit-repae-sdvae.

The pipeline is:

  1. Encode image using the model’s VAE encoder → latent z_real
  2. Feed z_real into SiT (as the initial latent)
  3. Decode the output latent using the model’s VAE decoder

From the released code, I understand that latents are decoded as follows:

import torch

# Per-channel statistics stored in the checkpoint's EMA BatchNorm layer
latents_scale = state_dict["ema"]["bn.running_var"].rsqrt().view(1, in_channels, 1, 1).to(device)
latents_bias = state_dict["ema"]["bn.running_mean"].view(1, in_channels, 1, 1).to(device)

def denormalize_latents(latents, latents_scale, latents_bias):
    # Map SiT-space latents back to the raw VAE latent space
    return latents / latents_scale + latents_bias

samples = euler_maruyama_sampler(**sampling_kwargs).to(torch.float32)
images = vae.decode(denormalize_latents(samples, latents_scale, latents_bias)).sample
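Written per channel c, the decoding step above is an affine map, which makes its inverse explicit. Here \mu_c and \sigma_c^2 stand for the c-th entries of bn.running_mean and bn.running_var, \hat z is a SiT-space latent and z a raw VAE latent. This is only the algebra of denormalize_latents, not a statement of how REPA-E itself preprocesses latents:

```latex
s_c = \sigma_c^{-1}, \qquad
z = \frac{\hat z}{s_c} + \mu_c = \sigma_c \hat z + \mu_c
\quad\Longleftrightarrow\quad
\hat z = s_c \,(z - \mu_c) = \frac{z - \mu_c}{\sigma_c}
```

In other words, the inverse is a per-channel standardization. If the BatchNorm statistics match the encoder's latent distribution, the normalized latents have roughly zero mean and unit variance in each channel.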

My question is about the inverse transformation: I want the encoder latents to match the distribution that the pre-trained SiT was trained on.

If I feed the raw VAE encoder latents directly into SiT (without applying any normalization), the model produces out-of-distribution results or diverges numerically. What preprocessing or normalization should be applied to encoder latents before passing them into SiT so that they match the latent distribution SiT was trained on?
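A minimal sketch of the candidate inverse, assuming it is simply the algebraic inverse of denormalize_latents and that z_real is the raw encoder output with no extra scaling factor applied. normalize_latents is a hypothetical helper name, not a function confirmed to exist in the REPA-E repo. NumPy arrays stand in for torch tensors, and the running mean/variance are synthetic placeholders for the checkpoint's bn.running_mean / bn.running_var:

```python
import numpy as np

def denormalize_latents(latents, latents_scale, latents_bias):
    # Same formula as the released decoding code: SiT space -> raw VAE space
    return latents / latents_scale + latents_bias

def normalize_latents(latents, latents_scale, latents_bias):
    # Hypothetical helper: algebraic inverse, raw VAE space -> SiT space
    return (latents - latents_bias) * latents_scale

# Synthetic per-channel statistics standing in for bn.running_mean / bn.running_var
in_channels = 4
rng = np.random.default_rng(0)
running_mean = rng.normal(size=in_channels)
running_var = rng.uniform(0.5, 2.0, size=in_channels)

# Same shapes as the released code: scale = rsqrt(var), bias = mean, both (1, C, 1, 1)
latents_scale = (1.0 / np.sqrt(running_var)).reshape(1, in_channels, 1, 1)
latents_bias = running_mean.reshape(1, in_channels, 1, 1)

# Fake "encoder latents" drawn with exactly those per-channel statistics
std = np.sqrt(running_var).reshape(1, in_channels, 1, 1)
z_real = latents_bias + std * rng.normal(size=(8, in_channels, 32, 32))

# Normalize before SiT, then check the round trip through denormalize_latents
z_sit = normalize_latents(z_real, latents_scale, latents_bias)
z_back = denormalize_latents(z_sit, latents_scale, latents_bias)
```

Because both functions use only broadcasting arithmetic, the same code should work unchanged on torch tensors of shape (N, C, H, W) built from the checkpoint statistics as in the released snippet.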
