diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py
index 316e79da4fd6..c841cc522d81 100644
--- a/src/diffusers/models/transformers/transformer_kandinsky.py
+++ b/src/diffusers/models/transformers/transformer_kandinsky.py
@@ -165,11 +165,20 @@ def __init__(self, model_dim, time_dim, max_period=10000.0):
         self.activation = nn.SiLU()
         self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
 
-    @torch.autocast(device_type="cuda", dtype=torch.float32)
     def forward(self, time):
-        args = torch.outer(time, self.freqs.to(device=time.device))
+        args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
         time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-        time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
+        time_embed = F.linear(
+            self.activation(
+                F.linear(
+                    time_embed,
+                    self.in_layer.weight.to(torch.float32),
+                    self.in_layer.bias.to(torch.float32),
+                )
+            ),
+            self.out_layer.weight.to(torch.float32),
+            self.out_layer.bias.to(torch.float32),
+        )
         return time_embed
 
 
@@ -269,9 +278,12 @@ def __init__(self, time_dim, model_dim, num_params):
         self.out_layer.weight.data.zero_()
         self.out_layer.bias.data.zero_()
 
-    @torch.autocast(device_type="cuda", dtype=torch.float32)
     def forward(self, x):
-        return self.out_layer(self.activation(x))
+        return F.linear(
+            self.activation(x.to(torch.float32)),
+            self.out_layer.weight.to(torch.float32),
+            self.out_layer.bias.to(torch.float32),
+        )
 
 
 class Kandinsky5AttnProcessor:
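
For context, a minimal self-contained sketch (not part of the patch) of the pattern the diff applies: instead of relying on `@torch.autocast(device_type="cuda", dtype=torch.float32)`, which is tied to CUDA, the forward pass casts the input and the layer weights to float32 explicitly via `F.linear`, so the computation stays in full precision regardless of device or surrounding autocast context. The class name `Fp32TimeEmbedding` and the `in_layer` shape are hypothetical; only the casting pattern mirrors the diff.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Fp32TimeEmbedding(nn.Module):
    """Sinusoidal timestep embedding whose MLP always runs in float32."""

    def __init__(self, model_dim: int, time_dim: int, max_period: float = 10000.0):
        super().__init__()
        half_dim = model_dim // 2
        # Sinusoidal frequencies, registered as a buffer so they follow .to(device).
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(half_dim, dtype=torch.float32) / half_dim
        )
        self.register_buffer("freqs", freqs, persistent=False)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)  # hypothetical shape
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time: torch.Tensor) -> torch.Tensor:
        # Cast the timesteps to fp32 before building the sinusoidal embedding,
        # rather than wrapping the whole method in torch.autocast.
        args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        # Run both projections through F.linear with weights cast to fp32, so the
        # MLP executes in full precision even if the module's parameters are bf16/fp16.
        hidden = F.linear(
            time_embed,
            self.in_layer.weight.to(torch.float32),
            self.in_layer.bias.to(torch.float32),
        )
        return F.linear(
            self.activation(hidden),
            self.out_layer.weight.to(torch.float32),
            self.out_layer.bias.to(torch.float32),
        )


if __name__ == "__main__":
    emb = Fp32TimeEmbedding(model_dim=64, time_dim=128).to(torch.bfloat16)
    t = torch.rand(4)  # a batch of timesteps
    out = emb(t)
    print(out.dtype)  # torch.float32, regardless of the module's parameter dtype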