-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset_creater.py
More file actions
62 lines (53 loc) · 2.93 KB
/
dataset_creater.py
File metadata and controls
62 lines (53 loc) · 2.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import torchvision
import torchvision.transforms as transforms
import random
from imshow import visualize
def dataset_creater(batchsize, bool_visualize):
    """Create CIFAR-10 train/validation/test DataLoaders.

    Train set: 45000 samples (the CIFAR-10 train set minus validation samples).
    Validation set: 5000 samples drawn from the CIFAR-10 train set, with an
    equal number of randomly selected samples from each of the 10 classes.
    Test set: the default CIFAR-10 test set.

    Args:
        batchsize: batch size for the train and test loaders (the validation
            loader uses a fixed batch size of 4096).
        bool_visualize: if True, display random samples from the train loader.

    Returns:
        A (train_generator, validation_generator, test_generator) tuple of
        torch.utils.data.DataLoader objects.
    """
    # Transformation pipeline applied to every image as it is read.
    transform = transforms.Compose([
        torchvision.transforms.ToTensor(),  # transforms input images to tensor type to compute gradients easier
        torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),  # adjust the distribution of the dataset statistically
        torchvision.transforms.Grayscale()  # convert the RGB images to greyscale
    ])
    # download train dataset
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    # number of samples in validation set from each class (500 for CIFAR-10)
    valid_from_class = len(trainset) // 100
    # Record the indices at which each class appears. Reading the CIFAR10
    # `targets` attribute gives the labels directly; iterating the dataset
    # itself would load and transform all 50 000 images just to read labels.
    indice_dict = {label: [] for label in range(10)}
    for indice, label in enumerate(trainset.targets):
        indice_dict[label].append(indice)
    # Randomly sample an equal number of indices per class for validation.
    idxs = []
    for label in range(10):
        idxs += random.sample(indice_dict[label], valid_from_class)  # validation data indexes in train set
    # Remove the validation indexes from the train set indexes. A set makes
    # each membership test O(1); the list.remove loop was O(n) per element.
    idx_set = set(idxs)
    train_list = [i for i in range(len(trainset)) if i not in idx_set]
    trainset_new = torch.utils.data.Subset(trainset, train_list)  # validation removed version of trainset
    valid_set = torch.utils.data.Subset(trainset, idxs)  # validation set is created with random samples
    # download test dataset (fix: download=True was missing, which raises
    # RuntimeError on a machine where ./data has not been populated yet)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    # creating the batches of train, validation and test sets as iterables:
    train_generator = torch.utils.data.DataLoader(trainset_new, batch_size=batchsize, shuffle=True)
    validation_generator = torch.utils.data.DataLoader(valid_set, batch_size=4096, shuffle=False)
    test_generator = torch.utils.data.DataLoader(testset, batch_size=batchsize, shuffle=False)
    # if requested, print random samples from the dataset
    if bool_visualize:
        visualize.imshow(train_generator)
    return train_generator, validation_generator, test_generator