-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathbase.py
More file actions
118 lines (104 loc) · 4.03 KB
/
base.py
File metadata and controls
118 lines (104 loc) · 4.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import random
import numpy as np
from PIL import Image, ImageOps, ImageFilter
import torch
import torch.utils.data as data
__all__ = ['BaseDataset', 'test_batchify_fn']
class BaseDataset(data.Dataset):
    """Abstract base class for semantic-segmentation datasets.

    Subclasses must implement ``__getitem__`` and define a ``NUM_CLASS``
    class attribute. This base provides the shared augmentation
    pipelines: random mirror / random scale / pad / random crop for
    training, and fixed short-side resize + center crop for validation.

    Args:
        root: dataset root directory.
        split: dataset split name (e.g. 'train', 'val').
        mode: pipeline mode; defaults to ``split`` when None.
        transform: optional transform applied to the image by subclasses.
        target_transform: optional transform applied to the mask by subclasses.
        base_size: reference short-side size used for random scaling.
        crop_size: output crop size for both training and validation.
        logger: optional logger for informational messages.
        scale: if False, disable multi-scale training (fixed base_size).
    """

    def __init__(self, root, split, mode=None, transform=None,
                 target_transform=None, base_size=520, crop_size=480,
                 logger=None, scale=True):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.split = split
        self.mode = mode if mode is not None else split
        self.base_size = base_size
        self.crop_size = crop_size
        self.logger = logger
        self.scale = scale
        if self.mode == 'train':
            print('BaseDataset: base_size {}, crop_size {}'. \
                format(base_size, crop_size))
        if not self.scale:
            if self.logger is not None:
                self.logger.info('single scale training!!!')

    def __getitem__(self, index):
        # BUG fix: the original raised `NotImplemented`, which is not an
        # exception class and produces a confusing TypeError at runtime.
        raise NotImplementedError

    @property
    def num_class(self):
        """Number of segmentation classes (``NUM_CLASS`` set by subclass)."""
        return self.NUM_CLASS

    @property
    def pred_offset(self):
        # BUG fix: same NotImplemented -> NotImplementedError correction.
        raise NotImplementedError

    def _val_sync_transform(self, img, mask):
        """Deterministic val-time transform: resize short side to
        ``crop_size``, then center-crop both image and mask."""
        outsize = self.crop_size
        short_size = outsize
        w, h = img.size
        # Resize so the SHORTER side equals short_size, keeping aspect ratio.
        if w > h:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        else:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # center crop
        w, h = img.size
        x1 = int(round((w - outsize) / 2.))
        y1 = int(round((h - outsize) / 2.))
        img = img.crop((x1, y1, x1+outsize, y1+outsize))
        mask = mask.crop((x1, y1, x1+outsize, y1+outsize))
        # final transform
        return img, self._mask_transform(mask)

    def _sync_transform(self, img, mask):
        """Training-time augmentation applied jointly to image and mask:
        random horizontal flip, random (or fixed) scale, zero/ignore
        padding when needed, then a random ``crop_size`` crop."""
        # random mirror
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        crop_size = self.crop_size
        if self.scale:
            # multi-scale: short side sampled in [0.75, 2.0] * base_size
            short_size = random.randint(int(self.base_size*0.75), int(self.base_size*2.0))
        else:
            short_size = self.base_size
        w, h = img.size
        # Resize so the SHORTER side equals short_size, keeping aspect ratio.
        if h > w:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        else:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # random rotate -10~10, mask using NN rotate
        # deg = random.uniform(-10, 10)
        # img = img.rotate(deg, resample=Image.BILINEAR)
        # mask = mask.rotate(deg, resample=Image.NEAREST)
        # pad crop
        if short_size < crop_size:
            padh = crop_size - oh if oh < crop_size else 0
            padw = crop_size - ow if ow < crop_size else 0
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=255)#pad 255 for cityscapes
        # random crop crop_size
        w, h = img.size
        x1 = random.randint(0, w - crop_size)
        y1 = random.randint(0, h - crop_size)
        img = img.crop((x1, y1, x1+crop_size, y1+crop_size))
        mask = mask.crop((x1, y1, x1+crop_size, y1+crop_size))
        # gaussian blur as in PSP
        # if random.random() < 0.5:
        #     img = img.filter(ImageFilter.GaussianBlur(
        #         radius=random.random()))
        # final transform
        return img, self._mask_transform(mask)

    def _mask_transform(self, mask):
        """Convert a PIL mask (or array-like) to a LongTensor of class ids."""
        return torch.from_numpy(np.array(mask)).long()
def test_batchify_fn(data):
error_msg = "batch must contain tensors, tuples or lists; found {}"
if isinstance(data[0], (str, torch.Tensor)):
return list(data)
elif isinstance(data[0], (tuple, list)):
data = zip(*data)
return [test_batchify_fn(i) for i in data]
raise TypeError((error_msg.format(type(batch[0]))))