-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathwork3.py
More file actions
124 lines (103 loc) · 3.98 KB
/
work3.py
File metadata and controls
124 lines (103 loc) · 3.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torchvision import datasets
import torch.utils.data as util_data
import torch
# Device configuration: prefer the GPU when CUDA is available, but fall back
# to CPU so the script still runs on machines without a CUDA-capable GPU
# (the original hard-coded "cuda" raises at tensor-creation time on CPU-only hosts).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Preprocessing: ToTensor converts each loaded PIL image to a float tensor
# with values scaled into [0, 1].
transform = transforms.Compose([transforms.ToTensor()])

# Build the training DataLoader over the on-disk image folder.
batch_size = 64
path = 'C:\\Users\\s4682374\\Downloads\\keras_png_slices_data\\Trainer' # path to dataset
train_dataset = datasets.ImageFolder(path, transform=transform)
train_loader = util_data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Quick sanity check for the loader — uncomment to visualise one image:
#images, labels = next(iter(train_loader))
#plt.imshow(images[0].permute(1, 2, 0))
#plt.show()
#class defining methods and initialization for VAE. Inherits from nn module
class VAE(torch.nn.Module):
    """Fully-connected variational autoencoder over flattened images.

    Images of length ``input_dim`` are encoded through one hidden layer to a
    ``z_dim`` feature vector, which is projected to a 2-d mean and a 2-d
    log-variance. The decoder maps a 2-d latent sample back through the same
    sizes to an ``input_dim`` vector squashed into [0, 1] by a sigmoid.
    """

    def __init__(self, input_dim = (256 * 256), hidden_dim = 20000, z_dim = 600):
        super().__init__()
        self.z_dim = z_dim
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        # Final projections: 2-d mean and 2-d log-variance of the latent.
        # (The KL term in the loss calls var.exp(), so this head is a
        # log-variance, not a variance.)
        self.mean_layer = torch.nn.Linear(self.z_dim, 2)
        self.var_layer = torch.nn.Linear(self.z_dim, 2)
        # Encoding layers: image -> hidden -> z.
        self.img_2hid = torch.nn.Linear(self.input_dim, self.hidden_dim)
        self.hid_2z = torch.nn.Linear(self.hidden_dim, self.z_dim)
        # Decoding layers: 2-d latent -> z -> hidden -> image.
        # BUGFIX: lat_2z used to be created inside decode() on every call,
        # so its weights were re-initialised each forward pass and were
        # never registered with the optimizer; defining it here fixes both.
        self.lat_2z = torch.nn.Linear(2, self.z_dim)
        self.z_2hid = torch.nn.Linear(self.z_dim, self.hidden_dim)
        self.hid_2img = torch.nn.Linear(self.hidden_dim, self.input_dim)

    def encode(self, x):
        """Encode a batch of flattened images; return (mean, log_var)."""
        h = torch.relu(self.img_2hid(x))
        h = torch.relu(self.hid_2z(h))
        return self.mean_layer(h), self.var_layer(h)

    def decode(self, x):
        """Decode a batch of 2-d latent points to images in [0, 1]."""
        h = torch.relu(self.lat_2z(x))
        h = torch.relu(self.z_2hid(h))
        return torch.sigmoid(self.hid_2img(h))

    def forward(self, x):
        """Run encode -> reparameterize -> decode; return (recon, mean, log_var)."""
        mean, log_var = self.encode(x)
        # Reparameterization trick: z = mean + std * eps, std = exp(0.5*log_var).
        # BUGFIX: the original used z = mean + log_var * eps, which is
        # inconsistent with the KL term's var.exp() interpretation.
        # randn_like inherits the device of log_var, so no .to(device) needed.
        epsilon = torch.randn_like(log_var)
        z = mean + torch.exp(0.5 * log_var) * epsilon
        rebuil_image = self.decode(z)
        return rebuil_image, mean, log_var
# Model hyperparameters. input_dim matches a flattened 256x256 RGB image
# (ImageFolder loads PNGs as 3-channel), overriding the class defaults.
input_dim = (256*256*3)
hidden_dim = 800
z_dim = 80
epochs = 10
learning_rate = 1e-5
# Initialise the model directly on the target device and build its optimizer.
# The .to(device) here already moves the model, so the original second
# model.to(device) call was redundant and has been removed.
model = VAE(input_dim, hidden_dim, z_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
# VAE objective = reconstruction term + latent regulariser.
def loss_function(image, rebuil_image, mean, var):
    """Return the summed ELBO loss: binary cross-entropy reconstruction
    plus the KL divergence of N(mean, exp(var)) from the unit Gaussian
    (``var`` is a log-variance)."""
    reconstruction = torch.nn.functional.binary_cross_entropy(
        rebuil_image, image, reduction='sum'
    )
    # KL(N(mu, sigma^2) || N(0, 1)) summed over the batch; algebraically
    # identical rearrangement of -0.5 * sum(1 + var - mean^2 - exp(var)).
    divergence = 0.5 * torch.sum(mean.pow(2) + var.exp() - var - 1)
    return reconstruction + divergence
# Train the VAE. The ImageFolder labels are not needed for reconstruction,
# so they are unpacked and ignored.
model.train()
for epoch in range(epochs):
    print("epoch", epoch+1)
    for batch_idx, (batch, _label) in enumerate(train_loader):
        # Flatten each image in the batch into a single vector on the device.
        flat = batch.view(batch.shape[0], input_dim).to(device)
        reconstruction, mu, log_var = model(flat)
        batch_loss = loss_function(flat, reconstruction, mu, log_var)
        # Standard optimisation step: reset grads, backprop, update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
# Now attempt to generate an image from a hand-picked point in latent space.
# Switch to eval mode and disable autograd: the original sampled while still
# in train mode and built an unnecessary gradient graph for pure inference.
model.eval()
# Coordinates of the 2-d latent point to decode (named after the latent
# distribution's parameters, but used here as a fixed sample location).
mean = 0.0
var = 1.0
sample = torch.tensor([[mean, var]], dtype=torch.float).to(device)
with torch.no_grad():
    x_decode = model.decode(sample)
# Reshape the flat output vector back to an HxWxC image for matplotlib.
generated_image = x_decode.cpu().reshape(256, 256, 3)
plt.imshow(generated_image)
plt.show()