-
Notifications
You must be signed in to change notification settings - Fork 52
Expand file tree
/
Copy path colorize.py
More file actions
73 lines (62 loc) · 2.57 KB
/
colorize.py
File metadata and controls
73 lines (62 loc) · 2.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import argparse # to parse script arguments
from statistics import mean # to compute the mean of a list
from tqdm import tqdm #used to generate progress bar during training
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid #to generate image grids, will be used in tensorboard
from data_utils import get_colorized_dataset_loader # dataloarder
from unet import UNet
# Compute device shared by the whole script: the default CUDA GPU when
# PyTorch reports one is available, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def train(net, optimizer, loader, epochs=5, writer=None):
    """Train ``net`` to colorize images and return the last epoch's mean loss.

    Args:
        net: model mapping a black-and-white batch ``x`` to a predicted colored
            batch that is compared against the target ``y``.
        optimizer: optimizer over ``net``'s parameters; stepped once per batch.
        loader: iterable of ``(x, y)`` pairs, ``x`` being the black-and-white
            input and ``y`` the colored ground truth.
        epochs: number of full passes over ``loader``.
        writer: optional TensorBoard ``SummaryWriter``. When given, once per
            epoch it receives the epoch's mean loss plus image grids (first 16
            samples of the epoch's last batch) of input, prediction and ground
            truth.

    Returns:
        Mean training loss over the batches of the last epoch, or ``nan`` when
        no batch was processed (``epochs == 0`` or an empty ``loader``).
    """
    # NOTE(review): the template left the loss unspecified; MSE on pixel values
    # is an assumption suited to regressing colors -- confirm for this task.
    criterion = torch.nn.MSELoss()
    # Send batches wherever the model's weights live instead of relying on a
    # module-level device (assumes net has at least one parameter).
    model_device = next(net.parameters()).device
    net.train()
    # Defined up front so epochs == 0 returns nan instead of raising.
    epoch_loss = float('nan')
    for epoch in range(epochs):
        running_loss = []
        loss_sum = 0.0
        t = tqdm(loader)
        for x, y in t:  # x: black and white image, y: colored image
            x, y = x.to(model_device), y.to(model_device)
            optimizer.zero_grad()
            outputs = net(x)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            running_loss.append(batch_loss)
            # Running total keeps the progress-bar average O(1) per batch.
            loss_sum += batch_loss
            t.set_description(f'training loss: {loss_sum / len(running_loss):.4f}')
        if not running_loss:
            # Empty loader: no loss to average and no batch to visualize.
            continue
        epoch_loss = mean(running_loss)
        if writer is not None:
            # Logging loss in tensorboard
            writer.add_scalar('training loss', epoch_loss, epoch)
            # Logging a sample of inputs in tensorboard
            input_grid = make_grid(x[:16].detach().cpu())
            writer.add_image('Input', input_grid, epoch)
            # Logging a sample of predicted outputs in tensorboard
            colorized_grid = make_grid(outputs[:16].detach().cpu())
            writer.add_image('Predicted', colorized_grid, epoch)
            # Logging a sample of ground truth in tensorboard
            original_grid = make_grid(y[:16].detach().cpu())
            writer.add_image('Ground truth', original_grid, epoch)
    return epoch_loss
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', type=str, default='Colorize', help='experiment name')
    parser.add_argument('--data_path', type=str, required=True, help='path to the image dataset')
    # NOTE(review): batch size default chosen here; adjust to the available memory.
    parser.add_argument('--batch_size', type=int, default=32, help='mini-batch size')
    # Default matches train()'s own default number of epochs.
    parser.add_argument('--epochs', type=int, default=5, help='number of training epochs')
    # Default matches torch.optim.Adam's default learning rate.
    parser.add_argument('--lr', type=float, default=1e-3, help='Adam learning rate')
    # Parse before reading any setting: every value below comes from args.
    args = parser.parse_args()
    exp_name = args.exp_name
    data_path = args.data_path
    batch_size = args.batch_size
    epochs = args.epochs
    lr = args.lr

    unet = UNet().to(device)
    loader = get_colorized_dataset_loader(path=data_path,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=0)
    optimizer = optim.Adam(unet.parameters(), lr=lr)
    # The context manager closes the writer so every event is flushed to disk.
    with SummaryWriter(f'runs/{exp_name}') as writer:
        train(unet, optimizer, loader, epochs=epochs, writer=writer)
        # add_graph traces the model, so it needs an example input; one sample
        # of a real batch is enough.
        sample_x, _ = next(iter(loader))
        writer.add_graph(unet, sample_x[:1].to(device))
    # Save model weights
    torch.save(unet.state_dict(), 'unet.pth')