-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathxray_data.py
More file actions
95 lines (77 loc) · 3.06 KB
/
xray_data.py
File metadata and controls
95 lines (77 loc) · 3.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import pandas
import torch
import SimpleITK as sitk
import numpy as np
import os
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from torchvision import transforms, datasets
from torch.utils import data
# Root directory under which the dataset folders used in this module live
# (each loader below appends its own dataset sub-directory to this prefix).
DATA_PATH = "/media/user/disk/"
def read_data():
    """Sort the RSNA training images into one folder per class.

    Reads ``stage_2_detailed_class_info.csv`` from the RSNA dataset directory
    below ``DATA_PATH``, maps every ``patientId`` to the numeric code of its
    ``class`` column and copies each file found in ``stage_2_train_images/``
    into ``normal/``, ``not_normal/`` or ``lung_opacity/`` inside the same
    dataset directory. The destination folders are created when missing.

    Files whose name (up to the first ``.``) has no row in the CSV are
    skipped instead of aborting the whole run.

    Raises:
        KeyError: if the CSV contains a class name outside the three
            known ones.
    """
    # Local import: only this one-off preparation step needs shutil.
    import shutil

    path = DATA_PATH + 'rsna-pneumonia-detection-challenge/'
    src_dir = os.path.join(path, 'stage_2_train_images')
    csv = pandas.read_csv(path + 'stage_2_detailed_class_info.csv')

    class_label = {
        'Normal': 0,
        'No Lung Opacity / Not Normal': 1,
        'Lung Opacity': 2,
    }
    # Destination folder for every numeric class code.
    class_dir = {0: 'normal', 1: 'not_normal', 2: 'lung_opacity'}

    # If a patientId occurs on several rows, the last row wins.
    patient_dict = {
        name: class_label[cls]
        for name, cls in zip(csv['patientId'], csv['class'])
    }

    # Make sure the targets exist; copying into a missing folder would fail.
    for folder in class_dir.values():
        os.makedirs(os.path.join(path, folder), exist_ok=True)

    for file_name in tqdm(os.listdir(src_dir)):
        patient = file_name.split('.')[0]
        target = patient_dict.get(patient)
        if target is None:
            # No label for this file (e.g. a stray non-image entry).
            continue
        # shutil.copy avoids building a shell command, so paths containing
        # spaces or shell metacharacters are handled correctly and a failed
        # copy raises instead of being silently ignored.
        shutil.copy(
            os.path.join(src_dir, file_name),
            os.path.join(path, class_dir[target]),
        )
def _identity_transform(x):
    """Return *x* unchanged; used when no transform is supplied.

    A module-level function is used instead of a lambda so that dataset
    instances can be pickled (needed when DataLoader workers are started
    with the ``spawn`` method).
    """
    return x


class Xray(data.Dataset):
    """In-memory dataset of grayscale X-ray images with labels 0 / 1.

    ``main_path`` is scanned for sub-directories named ``0`` and ``1``; any
    other entry is ignored. Every file inside those directories is read with
    SimpleITK, converted to an 8-bit grayscale PIL image, resized to
    ``img_size`` x ``img_size`` and kept in memory together with the integer
    value of its directory name as label.

    Args:
        main_path: directory containing the ``0`` / ``1`` class folders.
        img_size: edge length, in pixels, of the stored square images.
        transform: optional callable applied to the PIL image on access;
            when ``None`` the image is returned unchanged.
    """

    def __init__(self, main_path, img_size=64, transform=None):
        super().__init__()
        # Attribute kept for backward compatibility; it is never populated.
        self.file_path = []
        self.labels = []
        self.slices = []
        self.transform = (
            transform if transform is not None else _identity_transform
        )

        # os.listdir returns entries in arbitrary order; sorting makes the
        # sample order reproducible across runs and machines.
        for label in sorted(os.listdir(main_path)):
            if label not in ('0', '1'):
                continue
            label_dir = os.path.join(main_path, label)
            for file_name in tqdm(sorted(os.listdir(label_dir))):
                image = sitk.ReadImage(os.path.join(label_dir, file_name))
                # squeeze() drops singleton axes from the array SimpleITK
                # returns before it is handed to PIL.
                pixels = sitk.GetArrayFromImage(image).squeeze()
                img = Image.fromarray(pixels).convert('L').resize(
                    (img_size, img_size), resample=Image.BILINEAR)
                self.slices.append(img)
                self.labels.append(int(label))

    def __getitem__(self, index):
        """Return ``(transformed image, label)`` for the sample at *index*."""
        img = self.transform(self.slices[index])
        return img, self.labels[index]

    def __len__(self):
        """Return the number of images held in memory."""
        return len(self.slices)
def get_xray_dataloader(bs, workers, dtype='train', img_size=64, dataset='rsna'):
    """Build a DataLoader over one split of an X-ray dataset.

    Args:
        bs: batch size passed to the DataLoader.
        workers: number of DataLoader worker processes.
        dtype: name of the split; it is appended to the dataset directory
            to form the folder that is loaded. ``'train'`` additionally
            enables shuffling and dropping of the last incomplete batch.
        img_size: image edge length used by the dataset and the resize
            transform.
        dataset: ``'rsna'`` or ``'pedia'``.

    Returns:
        A ``torch.utils.data.DataLoader`` over an ``Xray`` dataset whose
        images are resized, converted to tensors and normalised with
        mean 0.5 and std 0.5.

    Raises:
        ValueError: if *dataset* is not one of the supported names.
    """
    dataset_dirs = {
        'rsna': 'rsna-pneumonia-detection-challenge/',
        'pedia': 'pediatric/',
    }
    # Reject unknown names up front so the failure is an explicit ValueError
    # rather than an unbound ``path`` variable further down.
    if dataset not in dataset_dirs:
        raise ValueError(
            'unknown dataset {!r}; expected one of {}'.format(
                dataset, sorted(dataset_dirs)))

    path = DATA_PATH + dataset_dirs[dataset] + dtype

    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    dset = Xray(main_path=path, transform=transform, img_size=img_size)

    # Only the training split is shuffled and has its last partial batch
    # dropped; other splits are iterated in order and in full.
    is_train = dtype == 'train'
    return data.DataLoader(dset, bs, shuffle=is_train, drop_last=is_train,
                           num_workers=workers, pin_memory=True)
# Running this file as a script performs the one-off sorting of the RSNA
# training images into per-class folders.
if __name__ == "__main__":
    read_data()