-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathdataset.py
148 lines (129 loc) · 4.7 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import numbers
import os
import queue as Queue
import threading
import mxnet as mx
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
class BackgroundGenerator(threading.Thread):
    """Iterate ``generator`` on a daemon thread, buffering items in a queue.

    The worker thread starts as soon as the object is constructed. It calls
    ``torch.cuda.set_device(local_rank)`` and then pushes every item of
    ``generator`` into a bounded queue; the consumer pulls items through the
    normal iterator protocol (``next`` / ``__next__`` / ``__iter__``).

    Args:
        generator: Iterable drained by the worker thread.
        local_rank: Device index handed to ``torch.cuda.set_device`` inside
            the worker thread.
        max_prefetch: Capacity of the queue, i.e. how many items the worker
            may produce ahead of the consumer before ``put`` blocks.

    Raises:
        Any exception raised in the worker thread (by ``set_device`` or by
        ``generator``) is re-raised by the ``next`` call that reaches the end
        of the buffered items.
    """

    # Private end-of-stream marker. A dedicated object (rather than ``None``)
    # lets the wrapped generator yield ``None`` without ending the stream.
    _SENTINEL = object()

    def __init__(self, generator, local_rank, max_prefetch=6):
        super(BackgroundGenerator, self).__init__()
        self.queue = Queue.Queue(max_prefetch)
        self.generator = generator
        self.local_rank = local_rank
        self.daemon = True
        # Both flags must exist before start(): the worker may run right away.
        self._worker_error = None
        self._finished = False
        self.start()

    def run(self):
        """Worker body: select the device, then feed the queue."""
        try:
            torch.cuda.set_device(self.local_rank)
            for item in self.generator:
                self.queue.put(item)
        except Exception as exc:
            # Hand the failure to the consumer instead of dying silently.
            self._worker_error = exc
        finally:
            # Always post the marker; otherwise a failing generator would
            # leave the consumer blocked in queue.get() forever.
            self.queue.put(self._SENTINEL)

    def next(self):
        """Return the next buffered item, blocking until one is available.

        Raises:
            StopIteration: When the generator is exhausted (and on every
                later call).
            Exception: Whatever the worker thread raised, re-raised once.
        """
        if self._finished:
            # Nothing more will ever be queued; do not block on an empty queue.
            raise StopIteration
        next_item = self.queue.get()
        if next_item is self._SENTINEL:
            self._finished = True
            if self._worker_error is not None:
                raise self._worker_error
            raise StopIteration
        return next_item

    def __next__(self):
        return self.next()

    def __iter__(self):
        return self
class DataLoaderX(DataLoader):
    """``DataLoader`` that prefetches batches and copies them on a side stream.

    Iteration wraps the regular ``DataLoader`` iterator in a
    ``BackgroundGenerator`` and always keeps one batch staged ahead: each
    element of the staged batch is moved to device ``local_rank`` with
    ``non_blocking=True`` while ``self.stream`` is the active CUDA stream.

    Iteration state (``self.iter``, ``self.batch``) lives on the loader
    itself and ``__iter__`` returns ``self``, so only one iteration can be
    in flight per loader object.

    Args:
        local_rank: Device index used for the side stream and as the copy
            target.
        **kwargs: Forwarded unchanged to ``DataLoader.__init__``.
    """

    def __init__(self, local_rank, **kwargs):
        super().__init__(**kwargs)
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank

    def __iter__(self):
        """Start a new pass and stage its first batch."""
        base_iterator = super().__iter__()
        self.iter = BackgroundGenerator(base_iterator, self.local_rank)
        self.preload()
        return self

    def preload(self):
        """Fetch the next batch into ``self.batch`` and start its device copy.

        ``self.batch`` is set to ``None`` once the underlying iterator is
        exhausted. The batch is updated in place by index.
        """
        self.batch = next(self.iter, None)
        if self.batch is not None:
            with torch.cuda.stream(self.stream):
                for position, element in enumerate(self.batch):
                    self.batch[position] = element.to(
                        device=self.local_rank, non_blocking=True)

    def __next__(self):
        """Return the staged batch and begin staging the following one.

        Raises:
            StopIteration: When no batch is staged.
        """
        # Make the consumer's stream wait for the copies queued on the side
        # stream before the batch is handed out.
        torch.cuda.current_stream().wait_stream(self.stream)
        ready = self.batch
        if ready is None:
            raise StopIteration
        self.preload()
        return ready
class MXFaceDataset(Dataset):
    """Dataset reading samples from an MXNet indexed RecordIO file pair.

    ``train.rec`` and ``train.idx`` are opened from ``root_dir``. Record 0 is
    unpacked as a header record: when its ``flag`` is positive, sample
    indices run from 1 up to (but excluding) ``int(header.label[0])``;
    otherwise every key of the record file is used.

    Each sample is decoded, passed through the transform pipeline
    (``ToPILImage`` -> ``RandomHorizontalFlip`` -> ``ToTensor`` ->
    ``Normalize`` with mean and std 0.5 per channel) and returned together
    with its label as a ``torch.long`` tensor.

    Args:
        root_dir: Directory containing ``train.rec`` and ``train.idx``.
        local_rank: Stored on the instance; not otherwise used by this class.
    """

    def __init__(self, root_dir, local_rank):
        super().__init__()
        pipeline = [
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
        self.transform = transforms.Compose(pipeline)
        self.root_dir = root_dir
        self.local_rank = local_rank

        rec_path = os.path.join(root_dir, 'train.rec')
        idx_path = os.path.join(root_dir, 'train.idx')
        self.imgrec = mx.recordio.MXIndexedRecordIO(idx_path, rec_path, 'r')

        # Record 0 carries the metadata header rather than an image.
        meta_header, _ = mx.recordio.unpack(self.imgrec.read_idx(0))
        if meta_header.flag > 0:
            first = int(meta_header.label[0])
            second = int(meta_header.label[1])
            self.header0 = (first, second)
            self.imgidx = np.array(range(1, first))
        else:
            self.imgidx = np.array(list(self.imgrec.keys))

    def __getitem__(self, index):
        """Return ``(sample, label)`` for position ``index``."""
        record = self.imgrec.read_idx(self.imgidx[index])
        header, encoded = mx.recordio.unpack(record)

        # The header label may be a scalar or a sequence; a sequence
        # contributes its first entry.
        raw_label = header.label
        if isinstance(raw_label, numbers.Number):
            scalar_label = raw_label
        else:
            scalar_label = raw_label[0]
        label = torch.tensor(scalar_label, dtype=torch.long)

        sample = mx.image.imdecode(encoded).asnumpy()
        if self.transform is None:
            return sample, label
        return self.transform(sample), label

    def __len__(self):
        """Number of indexed samples."""
        return len(self.imgidx)
class FaceDatasetFolder(Dataset):
    """Dataset built from a directory tree with one sub-directory per class.

    ``root_dir`` is scanned once at construction. Sub-directories are sorted
    by name and numbered 0, 1, 2, ...; every file inside a sub-directory
    becomes one sample carrying that number as its label.

    Each sample is read with OpenCV, converted from BGR to RGB, passed
    through the transform pipeline (``ToPILImage`` ->
    ``RandomHorizontalFlip`` -> ``ToTensor`` -> ``Normalize`` with mean and
    std 0.5 per channel) and returned with its label as a ``torch.long``
    tensor.

    Args:
        root_dir: Directory whose sub-directories each hold one class.
        local_rank: Stored on the instance; not otherwise used by this class.
    """

    def __init__(self, root_dir, local_rank):
        super(FaceDatasetFolder, self).__init__()
        self.transform = transforms.Compose(
            [transforms.ToPILImage(),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
             ])
        self.root_dir = root_dir
        self.local_rank = local_rank
        self.imgidx, self.labels = self.scan(root_dir)

    def scan(self, root):
        """Collect relative image paths and integer labels under ``root``.

        Entries of ``root`` that are not directories are skipped, so stray
        files do not abort the scan or consume a label. Directory names and
        the file names inside each directory are sorted so the resulting
        order is reproducible.

        Returns:
            A pair ``(paths, labels)`` of equal-length lists; ``paths[i]`` is
            ``<class_dir>/<file>`` relative to ``root`` and ``labels[i]`` is
            the index of its class directory.
        """
        image_paths = []
        labels = []
        class_dirs = sorted(
            name for name in os.listdir(root)
            if os.path.isdir(os.path.join(root, name))
        )
        for label, class_dir in enumerate(class_dirs):
            for image_name in sorted(os.listdir(os.path.join(root, class_dir))):
                image_paths.append(os.path.join(class_dir, image_name))
                labels.append(label)
        return image_paths, labels

    def readImage(self, path):
        """Read ``path`` (relative to ``root_dir``) with ``cv2.imread``.

        Returns whatever ``cv2.imread`` returns, including ``None`` when the
        file cannot be read.
        """
        # Imported here because the module never imports cv2 at the top;
        # a local import keeps the module importable without OpenCV.
        import cv2
        return cv2.imread(os.path.join(self.root_dir, path))

    def __getitem__(self, index):
        """Return ``(sample, label)`` for position ``index``.

        Raises:
            RuntimeError: If the image file could not be read.
        """
        import cv2
        path = self.imgidx[index]
        img = self.readImage(path)
        if img is None:
            # Fail with the offending path instead of an opaque cvtColor error.
            raise RuntimeError(
                "could not read image: {}".format(
                    os.path.join(self.root_dir, path)))
        label = torch.tensor(self.labels[index], dtype=torch.long)
        sample = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

    def __len__(self):
        """Number of scanned samples."""
        return len(self.imgidx)